Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/IPO/SampleProfile.cpp | 8
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 7
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 29
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstructionCombining.cpp | 3
-rw-r--r-- | llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp | 4
-rw-r--r-- | llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp | 10
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 7
-rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 134
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp | 3
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 110
-rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 44
-rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlan.cpp | 30
-rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlan.h | 11
-rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 17
14 files changed, 205 insertions, 212 deletions
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 5bc7e34..99b8b88 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1664,8 +1664,9 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
         else if (OverwriteExistingWeights)
           I.setMetadata(LLVMContext::MD_prof, nullptr);
       } else if (!isa<IntrinsicInst>(&I)) {
-        setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])},
-                         /*IsExpected=*/false);
+        setBranchWeights(
+            I, ArrayRef<uint32_t>{static_cast<uint32_t>(BlockWeights[BB])},
+            /*IsExpected=*/false);
       }
     }
   } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1676,7 +1677,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       if (cast<CallBase>(I).isIndirectCall()) {
         I.setMetadata(LLVMContext::MD_prof, nullptr);
       } else {
-        setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false);
+        setBranchWeights(I, ArrayRef<uint32_t>{uint32_t(0)},
+                         /*IsExpected=*/false);
       }
     }
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 4b7793f..9b272c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3080,6 +3080,13 @@ InstCombinerImpl::convertOrOfShiftsToFunnelShift(Instruction &Or) {
       assert(ZextLowShlAmt->uge(HighSize) &&
              ZextLowShlAmt->ule(Width - LowSize) && "Invalid concat");
 
+      // We cannot reuse the result if it may produce poison.
+      // Drop poison generating flags in the expression tree.
+      // Or
+      cast<Instruction>(U)->dropPoisonGeneratingFlags();
+      // Shl
+      cast<Instruction>(X)->dropPoisonGeneratingFlags();
+
       FShiftArgs = {U, U, ConstantInt::get(Or0->getType(), *ZextHighShlAmt)};
       break;
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 6ef3066..18a45c6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -319,20 +319,20 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
   return nullptr;
 }
 
-/// Find elements of V demanded by UserInstr.
-static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
+/// Find elements of V demanded by UserInstr. If returns false, we were not able
+/// to determine all elements.
+static bool findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr,
+                                         APInt &UnionUsedElts) {
   unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
 
-  // Conservatively assume that all elements are needed.
-  APInt UsedElts(APInt::getAllOnes(VWidth));
-
   switch (UserInstr->getOpcode()) {
   case Instruction::ExtractElement: {
     ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr);
     assert(EEI->getVectorOperand() == V);
     ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand());
     if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) {
-      UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue());
+      UnionUsedElts.setBit(EEIIndexC->getZExtValue());
+      return true;
     }
     break;
   }
@@ -341,23 +341,23 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
     unsigned MaskNumElts =
         cast<FixedVectorType>(UserInstr->getType())->getNumElements();
 
-    UsedElts = APInt(VWidth, 0);
-    for (unsigned i = 0; i < MaskNumElts; i++) {
-      unsigned MaskVal = Shuffle->getMaskValue(i);
+    for (auto I : llvm::seq(MaskNumElts)) {
+      unsigned MaskVal = Shuffle->getMaskValue(I);
       if (MaskVal == -1u || MaskVal >= 2 * VWidth)
         continue;
       if (Shuffle->getOperand(0) == V && (MaskVal < VWidth))
-        UsedElts.setBit(MaskVal);
+        UnionUsedElts.setBit(MaskVal);
       if (Shuffle->getOperand(1) == V &&
           ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth)))
-        UsedElts.setBit(MaskVal - VWidth);
+        UnionUsedElts.setBit(MaskVal - VWidth);
     }
-    break;
+    return true;
   }
   default:
     break;
   }
-  return UsedElts;
+
+  return false;
 }
 
 /// Find union of elements of V demanded by all its users.
@@ -370,7 +370,8 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
   APInt UnionUsedElts(VWidth, 0);
   for (const Use &U : V->uses()) {
     if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
-      UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
+      if (!findDemandedEltsBySingleUser(V, I, UnionUsedElts))
+        return APInt::getAllOnes(VWidth);
     } else {
       UnionUsedElts = APInt::getAllOnes(VWidth);
       break;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index ff063f9..5d2d79e 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -5212,7 +5212,7 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) {
       else if (match(U, m_Select(m_Specific(&I), m_Constant(), m_Value())))
        V = ConstantInt::getTrue(Ty);
      else if (match(U, m_c_Select(m_Specific(&I), m_Value(V)))) {
-        if (!isGuaranteedNotToBeUndefOrPoison(V, &AC, &I, &DT))
+        if (V == &I || !isGuaranteedNotToBeUndefOrPoison(V, &AC, &I, &DT))
          V = NullValue;
      } else if (auto *PHI = dyn_cast<PHINode>(U)) {
        if (Value *MaybeV = pickCommonConstantFromPHI(*PHI))
@@ -5225,6 +5225,7 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) {
       BestValue = NullValue;
   }
   assert(BestValue && "Must have at least one use");
+  assert(BestValue != &I && "Cannot replace with itself");
   return BestValue;
 };
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index f451c2b..0249f21 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -672,8 +672,8 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
           createBranchWeights(CB.getContext(), Count, TotalCount - Count));
 
   if (AttachProfToDirectCall)
-    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)},
-                     /*IsExpected=*/false);
+    setFittedBranchWeights(NewInst, {Count},
+                           /*IsExpected=*/false);
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index 944b253..e9a3e98 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -190,12 +190,12 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
             std::vector<BasicBlock *> *NewBBs) {
   SelectInst *SI = SIToUnfold.getInst();
   PHINode *SIUse = SIToUnfold.getUse();
-  BasicBlock *StartBlock = SI->getParent();
+  assert(SI->hasOneUse());
+  // The select may come indirectly, instead of from where it is defined.
+  BasicBlock *StartBlock = SIUse->getIncomingBlock(*SI->use_begin());
   BranchInst *StartBlockTerm =
       dyn_cast<BranchInst>(StartBlock->getTerminator());
-  assert(StartBlockTerm);
-  assert(SI->hasOneUse());
 
   if (StartBlockTerm->isUnconditional()) {
     BasicBlock *EndBlock = StartBlock->getUniqueSuccessor();
@@ -332,7 +332,7 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
   }
 
   // Preserve loop info
-  if (Loop *L = LI->getLoopFor(SI->getParent())) {
+  if (Loop *L = LI->getLoopFor(StartBlock)) {
     for (BasicBlock *NewBB : *NewBBs)
       L->addBasicBlockToLoop(NewBB, *LI);
   }
@@ -533,6 +533,8 @@ private:
       return false;
 
     // Only fold the select coming from directly where it is defined.
+    // TODO: We have dealt with the select coming indirectly now. This
+    // constraint can be relaxed.
     PHINode *PHIUser = dyn_cast<PHINode>(SIUse);
     if (PHIUser && PHIUser->getIncomingBlock(*SI->use_begin()) != SIBB)
       return false;
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 0874b29..019536ca 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -1598,11 +1598,8 @@ bool LoopIdiomRecognize::optimizeCRCLoop(const PolynomialInfo &Info) {
   //    crc = (crc << 8) ^ tbl[(iv'th byte of data) ^ (top byte of crc)]
   {
     auto LoByte = [](IRBuilderBase &Builder, Value *Op, const Twine &Name) {
-      Type *OpTy = Op->getType();
-      unsigned OpBW = OpTy->getIntegerBitWidth();
-      return OpBW > 8
-                 ? Builder.CreateAnd(Op, ConstantInt::get(OpTy, 0XFF), Name)
-                 : Op;
+      return Builder.CreateZExtOrTrunc(
+          Op, IntegerType::getInt8Ty(Op->getContext()), Name);
     };
     auto HiIdx = [LoByte, CRCBW](IRBuilderBase &Builder, Value *Op,
                                  const Twine &Name) {
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 216bdf4..4d1f768 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -955,33 +955,6 @@ static bool valuesOverlap(std::vector<ValueEqualityComparisonCase> &C1,
   return false;
 }
 
-// Set branch weights on SwitchInst. This sets the metadata if there is at
-// least one non-zero weight.
-static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights,
-                             bool IsExpected) {
-  // Check that there is at least one non-zero weight. Otherwise, pass
-  // nullptr to setMetadata which will erase the existing metadata.
-  MDNode *N = nullptr;
-  if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; }))
-    N = MDBuilder(SI->getParent()->getContext())
-            .createBranchWeights(Weights, IsExpected);
-  SI->setMetadata(LLVMContext::MD_prof, N);
-}
-
-// Similar to the above, but for branch and select instructions that take
-// exactly 2 weights.
-static void setBranchWeights(Instruction *I, uint32_t TrueWeight,
-                             uint32_t FalseWeight, bool IsExpected) {
-  assert(isa<BranchInst>(I) || isa<SelectInst>(I));
-  // Check that there is at least one non-zero weight. Otherwise, pass
-  // nullptr to setMetadata which will erase the existing metadata.
-  MDNode *N = nullptr;
-  if (TrueWeight || FalseWeight)
-    N = MDBuilder(I->getParent()->getContext())
-            .createBranchWeights(TrueWeight, FalseWeight, IsExpected);
-  I->setMetadata(LLVMContext::MD_prof, N);
-}
-
 /// If TI is known to be a terminator instruction and its block is known to
 /// only have a single predecessor block, check to see if that predecessor is
 /// also a value comparison with the same value, and if that comparison
@@ -1181,16 +1154,6 @@ static void getBranchWeights(Instruction *TI,
   }
 }
 
-/// Keep halving the weights until all can fit in uint32_t.
-static void fitWeights(MutableArrayRef<uint64_t> Weights) {
-  uint64_t Max = *llvm::max_element(Weights);
-  if (Max > UINT_MAX) {
-    unsigned Offset = 32 - llvm::countl_zero(Max);
-    for (uint64_t &I : Weights)
-      I >>= Offset;
-  }
-}
-
 static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
     BasicBlock *BB, BasicBlock *PredBlock, ValueToValueMapTy &VMap) {
   Instruction *PTI = PredBlock->getTerminator();
@@ -1446,14 +1409,9 @@ bool SimplifyCFGOpt::performValueComparisonIntoPredecessorFolding(
   for (ValueEqualityComparisonCase &V : PredCases)
     NewSI->addCase(V.Value, V.Dest);
 
-  if (PredHasWeights || SuccHasWeights) {
-    // Halve the weights if any of them cannot fit in an uint32_t
-    fitWeights(Weights);
-
-    SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
-
-    setBranchWeights(NewSI, MDWeights, /*IsExpected=*/false);
-  }
+  if (PredHasWeights || SuccHasWeights)
+    setFittedBranchWeights(*NewSI, Weights, /*IsExpected=*/false,
+                           /*ElideAllZero=*/true);
 
   eraseTerminatorAndDCECond(PTI);
@@ -4053,39 +4011,34 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
 
   // Try to update branch weights.
   uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
-  SmallVector<uint32_t, 2> MDWeights;
+  SmallVector<uint64_t, 2> MDWeights;
   if (extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight,
                              SuccTrueWeight, SuccFalseWeight)) {
-    SmallVector<uint64_t, 8> NewWeights;
     if (PBI->getSuccessor(0) == BB) {
       // PBI: br i1 %x, BB, FalseDest
       // BI:  br i1 %y, UniqueSucc, FalseDest
       // TrueWeight is TrueWeight for PBI * TrueWeight for BI.
-      NewWeights.push_back(PredTrueWeight * SuccTrueWeight);
+      MDWeights.push_back(PredTrueWeight * SuccTrueWeight);
       // FalseWeight is FalseWeight for PBI * TotalWeight for BI +
       //              TrueWeight for PBI * FalseWeight for BI.
       // We assume that total weights of a BranchInst can fit into 32 bits.
       // Therefore, we will not have overflow using 64-bit arithmetic.
-      NewWeights.push_back(PredFalseWeight *
-                               (SuccFalseWeight + SuccTrueWeight) +
-                           PredTrueWeight * SuccFalseWeight);
+      MDWeights.push_back(PredFalseWeight * (SuccFalseWeight + SuccTrueWeight) +
+                          PredTrueWeight * SuccFalseWeight);
     } else {
       // PBI: br i1 %x, TrueDest, BB
      // BI:  br i1 %y, TrueDest, UniqueSucc
      // TrueWeight is TrueWeight for PBI * TotalWeight for BI +
      //             FalseWeight for PBI * TrueWeight for BI.
-      NewWeights.push_back(PredTrueWeight * (SuccFalseWeight + SuccTrueWeight) +
-                           PredFalseWeight * SuccTrueWeight);
+      MDWeights.push_back(PredTrueWeight * (SuccFalseWeight + SuccTrueWeight) +
+                          PredFalseWeight * SuccTrueWeight);
       // FalseWeight is FalseWeight for PBI * FalseWeight for BI.
-      NewWeights.push_back(PredFalseWeight * SuccFalseWeight);
+      MDWeights.push_back(PredFalseWeight * SuccFalseWeight);
     }
 
-    // Halve the weights if any of them cannot fit in an uint32_t
-    fitWeights(NewWeights);
-
-    append_range(MDWeights, NewWeights);
-    setBranchWeights(PBI, MDWeights[0], MDWeights[1], /*IsExpected=*/false);
+    setFittedBranchWeights(*PBI, MDWeights, /*IsExpected=*/false,
+                           /*ElideAllZero=*/true);
 
     // TODO: If BB is reachable from all paths through PredBlock, then we
     // could replace PBI's branch probabilities with BI's.
@@ -4125,8 +4078,8 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
     if (auto *SI = dyn_cast<SelectInst>(PBI->getCondition()))
       if (!MDWeights.empty()) {
         assert(isSelectInRoleOfConjunctionOrDisjunction(SI));
-        setBranchWeights(SI, MDWeights[0], MDWeights[1],
-                         /*IsExpected=*/false);
+        setFittedBranchWeights(*SI, {MDWeights[0], MDWeights[1]},
+                               /*IsExpected=*/false, /*ElideAllZero=*/true);
       }
 
   ++NumFoldBranchToCommonDest;
@@ -4478,9 +4431,9 @@ static bool mergeConditionalStoreToAddress(
       if (InvertQCond)
         std::swap(QWeights[0], QWeights[1]);
       auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
-      setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
-                       CombinedWeights[1],
-                       /*IsExpected=*/false);
+      setFittedBranchWeights(*PostBB->getTerminator(),
+                             {CombinedWeights[0], CombinedWeights[1]},
+                             /*IsExpected=*/false, /*ElideAllZero=*/true);
     }
 
   QB.SetInsertPoint(T);
@@ -4836,10 +4789,9 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
     uint64_t NewWeights[2] = {PredCommon * (SuccCommon + SuccOther) +
                                   PredOther * SuccCommon,
                               PredOther * SuccOther};
-    // Halve the weights if any of them cannot fit in an uint32_t
-    fitWeights(NewWeights);
 
-    setBranchWeights(PBI, NewWeights[0], NewWeights[1], /*IsExpected=*/false);
+    setFittedBranchWeights(*PBI, NewWeights, /*IsExpected=*/false,
+                           /*ElideAllZero=*/true);
     // Cond may be a select instruction with the first operand set to "true", or
     // the second to "false" (see how createLogicalOp works for `and` and `or`)
     if (!ProfcheckDisableMetadataFixes)
@@ -4849,8 +4801,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
         assert(dyn_cast<SelectInst>(SI)->getCondition() == PBICond);
         // The corresponding probabilities are what was referred to above as
         // PredCommon and PredOther.
-        setBranchWeights(SI, PredCommon, PredOther,
-                         /*IsExpected=*/false);
+        setFittedBranchWeights(*SI, {PredCommon, PredOther},
+                               /*IsExpected=*/false, /*ElideAllZero=*/true);
       }
   }
 
@@ -4876,8 +4828,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
       if (HasWeights) {
         uint64_t TrueWeight = PBIOp ? PredFalseWeight : PredTrueWeight;
         uint64_t FalseWeight = PBIOp ? PredTrueWeight : PredFalseWeight;
-        setBranchWeights(NV, TrueWeight, FalseWeight,
-                         /*IsExpected=*/false);
+        setFittedBranchWeights(*NV, {TrueWeight, FalseWeight},
+                               /*IsExpected=*/false, /*ElideAllZero=*/true);
       }
     }
   }
@@ -4940,7 +4892,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm,
       // Create a conditional branch sharing the condition of the select.
       BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
       if (TrueWeight != FalseWeight)
-        setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false);
+        setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+                         /*IsExpected=*/false, /*ElideAllZero=*/true);
     }
   } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
     // Neither of the selected blocks were successors, so this
@@ -5889,7 +5842,8 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
       TrueWeight /= 2;
       FalseWeight /= 2;
     }
-    setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false);
+    setFittedBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+                           /*IsExpected=*/false, /*ElideAllZero=*/true);
   }
 }
@@ -6364,9 +6318,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
         // BranchWeights. We want the probability and negative probability of
         // Condition == SecondCase.
         assert(BranchWeights.size() == 3);
-        setBranchWeights(SI, BranchWeights[2],
-                         BranchWeights[0] + BranchWeights[1],
-                         /*IsExpected=*/false);
+        setBranchWeights(
+            *SI, {BranchWeights[2], BranchWeights[0] + BranchWeights[1]},
+            /*IsExpected=*/false, /*ElideAllZero=*/true);
       }
     }
     Value *ValueCompare =
@@ -6381,9 +6335,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
       size_t FirstCasePos = (Condition != nullptr);
       size_t SecondCasePos = FirstCasePos + 1;
       uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
-      setBranchWeights(SI, BranchWeights[FirstCasePos],
-                       DefaultCase + BranchWeights[SecondCasePos],
-                       /*IsExpected=*/false);
+      setBranchWeights(*SI,
+                       {BranchWeights[FirstCasePos],
+                        DefaultCase + BranchWeights[SecondCasePos]},
+                       /*IsExpected=*/false, /*ElideAllZero=*/true);
     }
     return Ret;
   }
@@ -6427,8 +6382,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
       // We know there's a Default case. We base the resulting branch
      // weights off its probability.
      assert(BranchWeights.size() >= 2);
-      setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
-                       BranchWeights[0], /*IsExpected=*/false);
+      setBranchWeights(
+          *SI,
+          {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]},
+          /*IsExpected=*/false, /*ElideAllZero=*/true);
     }
     return Ret;
   }
@@ -6451,8 +6408,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
         Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
     if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
       assert(BranchWeights.size() >= 2);
-      setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
-                       BranchWeights[0], /*IsExpected=*/false);
+      setBranchWeights(
+          *SI,
+          {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]},
+          /*IsExpected=*/false, /*ElideAllZero=*/true);
     }
     return Ret;
   }
@@ -6469,8 +6428,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
         Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
     if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
       assert(BranchWeights.size() >= 2);
-      setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
-                       BranchWeights[0], /*IsExpected=*/false);
+      setBranchWeights(
+          *SI, {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]},
+          /*IsExpected=*/false, /*ElideAllZero=*/true);
     }
     return Ret;
   }
@@ -8152,8 +8112,8 @@ static bool mergeNestedCondBranch(BranchInst *BI, DomTreeUpdater *DTU) {
   if (HasWeight) {
     uint64_t Weights[2] = {BBTWeight * BB1FWeight + BBFWeight * BB2TWeight,
                            BBTWeight * BB1TWeight + BBFWeight * BB2FWeight};
-    fitWeights(Weights);
-    setBranchWeights(BI, Weights[0], Weights[1], /*IsExpected=*/false);
+    setFittedBranchWeights(*BI, Weights, /*IsExpected=*/false,
+                           /*ElideAllZero=*/true);
   }
   return true;
 }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index ff35db1..7d376c3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -293,9 +293,8 @@ void LoopVectorizeHints::getHintsFromMetadata() {
 }
 
 void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) {
-  if (!Name.starts_with(Prefix()))
+  if (!Name.consume_front(Prefix()))
     return;
-  Name = Name.substr(Prefix().size(), StringRef::npos);
 
   const ConstantInt *C = mdconst::dyn_extract<ConstantInt>(Arg);
   if (!C)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 12fb46d..e5d6c81 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5699,6 +5699,20 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) {
           Worklist.push_back(InstOp);
       }
 
+      auto UpdateMemOpUserCost = [this, VF](LoadInst *LI) {
+        // If there are direct memory op users of the newly scalarized load,
+        // their cost may have changed because there's no scalarization
+        // overhead for the operand. Update it.
+        for (User *U : LI->users()) {
+          if (!isa<LoadInst, StoreInst>(U))
+            continue;
+          if (getWideningDecision(cast<Instruction>(U), VF) != CM_Scalarize)
+            continue;
+          setWideningDecision(
+              cast<Instruction>(U), VF, CM_Scalarize,
+              getMemInstScalarizationCost(cast<Instruction>(U), VF));
+        }
+      };
       for (auto *I : AddrDefs) {
         if (isa<LoadInst>(I)) {
           // Setting the desired widening decision should ideally be handled in
@@ -5708,21 +5722,24 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) {
           InstWidening Decision = getWideningDecision(I, VF);
           if (Decision == CM_Widen || Decision == CM_Widen_Reverse ||
               (!isPredicatedInst(I) && !Legal->isUniformMemOp(*I, VF) &&
-               Decision == CM_Scalarize))
+               Decision == CM_Scalarize)) {
             // Scalarize a widened load of address or update the cost of a
             // scalar load of an address.
             setWideningDecision(
                 I, VF, CM_Scalarize,
                 (VF.getKnownMinValue() *
                  getMemoryInstructionCost(I, ElementCount::getFixed(1))));
-          else if (const auto *Group = getInterleavedAccessGroup(I)) {
+            UpdateMemOpUserCost(cast<LoadInst>(I));
+          } else if (const auto *Group = getInterleavedAccessGroup(I)) {
             // Scalarize an interleave group of address loads.
             for (unsigned I = 0; I < Group->getFactor(); ++I) {
-              if (Instruction *Member = Group->getMember(I))
+              if (Instruction *Member = Group->getMember(I)) {
                 setWideningDecision(
                     Member, VF, CM_Scalarize,
                     (VF.getKnownMinValue() *
                      getMemoryInstructionCost(Member, ElementCount::getFixed(1))));
+                UpdateMemOpUserCost(cast<LoadInst>(Member));
+              }
             }
           }
         } else {
@@ -9521,55 +9538,52 @@ static SmallVector<Instruction *> preparePlanForEpilogueVectorLoop(
   VPBasicBlock *Header = VectorLoop->getEntryBasicBlock();
   Header->setName("vec.epilog.vector.body");
 
-  DenseMap<Value *, Value *> ToFrozen;
-  SmallVector<Instruction *> InstsToMove;
   // Ensure that the start values for all header phi recipes are updated before
   // vectorizing the epilogue loop.
-  for (VPRecipeBase &R : Header->phis()) {
-    if (auto *IV = dyn_cast<VPCanonicalIVPHIRecipe>(&R)) {
-      // When vectorizing the epilogue loop, the canonical induction start
-      // value needs to be changed from zero to the value after the main
-      // vector loop. Find the resume value created during execution of the main
-      // VPlan. It must be the first phi in the loop preheader.
-      // FIXME: Improve modeling for canonical IV start values in the epilogue
-      // loop.
-      using namespace llvm::PatternMatch;
-      PHINode *EPResumeVal = &*L->getLoopPreheader()->phis().begin();
-      for (Value *Inc : EPResumeVal->incoming_values()) {
-        if (match(Inc, m_SpecificInt(0)))
-          continue;
-        assert(!EPI.VectorTripCount &&
-               "Must only have a single non-zero incoming value");
-        EPI.VectorTripCount = Inc;
-      }
-      // If we didn't find a non-zero vector trip count, all incoming values
-      // must be zero, which also means the vector trip count is zero. Pick the
-      // first zero as vector trip count.
-      // TODO: We should not choose VF * UF so the main vector loop is known to
-      // be dead.
-      if (!EPI.VectorTripCount) {
-        assert(
-            EPResumeVal->getNumIncomingValues() > 0 &&
-            all_of(EPResumeVal->incoming_values(),
-                   [](Value *Inc) { return match(Inc, m_SpecificInt(0)); }) &&
-            "all incoming values must be 0");
-        EPI.VectorTripCount = EPResumeVal->getOperand(0);
-      }
-      VPValue *VPV = Plan.getOrAddLiveIn(EPResumeVal);
-      assert(all_of(IV->users(),
-                    [](const VPUser *U) {
-                      return isa<VPScalarIVStepsRecipe>(U) ||
-                             isa<VPDerivedIVRecipe>(U) ||
-                             cast<VPRecipeBase>(U)->isScalarCast() ||
-                             cast<VPInstruction>(U)->getOpcode() ==
-                                 Instruction::Add;
-                    }) &&
-             "the canonical IV should only be used by its increment or "
-             "ScalarIVSteps when resetting the start value");
-      IV->setOperand(0, VPV);
+  VPCanonicalIVPHIRecipe *IV = Plan.getCanonicalIV();
+  // When vectorizing the epilogue loop, the canonical induction start
+  // value needs to be changed from zero to the value after the main
+  // vector loop. Find the resume value created during execution of the main
+  // VPlan. It must be the first phi in the loop preheader.
+  // FIXME: Improve modeling for canonical IV start values in the epilogue
+  // loop.
+  using namespace llvm::PatternMatch;
+  PHINode *EPResumeVal = &*L->getLoopPreheader()->phis().begin();
+  for (Value *Inc : EPResumeVal->incoming_values()) {
+    if (match(Inc, m_SpecificInt(0)))
      continue;
-    }
+    assert(!EPI.VectorTripCount &&
+           "Must only have a single non-zero incoming value");
+    EPI.VectorTripCount = Inc;
+  }
+  // If we didn't find a non-zero vector trip count, all incoming values
+  // must be zero, which also means the vector trip count is zero. Pick the
+  // first zero as vector trip count.
+  // TODO: We should not choose VF * UF so the main vector loop is known to
+  // be dead.
+  if (!EPI.VectorTripCount) {
+    assert(EPResumeVal->getNumIncomingValues() > 0 &&
+           all_of(EPResumeVal->incoming_values(),
+                  [](Value *Inc) { return match(Inc, m_SpecificInt(0)); }) &&
+           "all incoming values must be 0");
+    EPI.VectorTripCount = EPResumeVal->getOperand(0);
+  }
+  VPValue *VPV = Plan.getOrAddLiveIn(EPResumeVal);
+  assert(all_of(IV->users(),
+                [](const VPUser *U) {
+                  return isa<VPScalarIVStepsRecipe>(U) ||
+                         isa<VPDerivedIVRecipe>(U) ||
+                         cast<VPRecipeBase>(U)->isScalarCast() ||
+                         cast<VPInstruction>(U)->getOpcode() ==
+                             Instruction::Add;
+                }) &&
+         "the canonical IV should only be used by its increment or "
+         "ScalarIVSteps when resetting the start value");
+  IV->setOperand(0, VPV);
+  DenseMap<Value *, Value *> ToFrozen;
+  SmallVector<Instruction *> InstsToMove;
+  for (VPRecipeBase &R : drop_begin(Header->phis())) {
     Value *ResumeV = nullptr;
     // TODO: Move setting of resume values to prepareToExecute.
     if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) {
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f77d587..fedca65 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2241,10 +2241,9 @@ public:
   /// TODO: If load combining is allowed in the IR optimizer, this analysis
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
-  bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
-                     const DataLayout &DL, ScalarEvolution &SE,
-                     const int64_t Diff, StridedPtrInfo &SPtrInfo) const;
+  bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
+                     Align Alignment, const int64_t Diff, Value *Ptr0,
+                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -6824,13 +6823,10 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// 4. Any pointer operand is an instruction with the users outside of the
 ///    current graph (for masked gathers extra extractelement instructions
 ///    might be required).
-bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                            ArrayRef<unsigned> Order,
-                            const TargetTransformInfo &TTI,
-                            const DataLayout &DL, ScalarEvolution &SE,
-                            const int64_t Diff,
-                            StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = VL.size();
+bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
+                            Align Alignment, const int64_t Diff, Value *Ptr0,
+                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const size_t Sz = PointerOps.size();
   if (Diff % (Sz - 1) != 0)
     return false;
 
@@ -6842,7 +6838,6 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
       });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  Type *ScalarTy = VL.front()->getType();
   auto *VecTy = getWidenedType(ScalarTy, Sz);
   if (IsAnyPointerUsedOutGraph ||
       (AbsoluteDiff > Sz &&
@@ -6853,20 +6848,9 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
   int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
   if (Diff != Stride * static_cast<int64_t>(Sz - 1))
     return false;
-  Align Alignment =
-      cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
-          ->getAlign();
-  if (!TTI.isLegalStridedLoadStore(VecTy, Alignment))
+  if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
     return false;
-  Value *Ptr0;
-  Value *PtrN;
-  if (Order.empty()) {
-    Ptr0 = PointerOps.front();
-    PtrN = PointerOps.back();
-  } else {
-    Ptr0 = PointerOps[Order.front()];
-    PtrN = PointerOps[Order.back()];
-  }
+
   // Iterate through all pointers and check if all distances are
   // unique multiple of Dist.
   SmallSet<int64_t, 4> Dists;
@@ -6875,14 +6859,14 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
     if (Ptr == PtrN)
       Dist = Diff;
     else if (Ptr != Ptr0)
-      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
+      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
     // If the strides are not the same or repeated, we can't
     // vectorize.
     if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
       break;
   }
   if (Dists.size() == Sz) {
-    Type *StrideTy = DL.getIndexType(Ptr0->getType());
+    Type *StrideTy = DL->getIndexType(Ptr0->getType());
     SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
     SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
     return true;
@@ -6971,7 +6955,11 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
                    cast<Instruction>(V), UserIgnoreList);
           }))
         return LoadsState::CompressVectorize;
-      if (isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE, *Diff, SPtrInfo))
+      Align Alignment =
+          cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
+              ->getAlign();
+      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
+                        SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 81f1956..02eb637 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -968,24 +968,36 @@ void VPlan::execute(VPTransformState *State) {
     // logic generic during VPlan execution.
     State->CFG.DTU.applyUpdates(
         {{DominatorTree::Delete, ScalarPh, ScalarPh->getSingleSuccessor()}});
-  } else {
+  }
+
+  ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
+      Entry);
+  // Generate code for the VPlan, in parts of the vector skeleton, loop body and
+  // successor blocks including the middle, exit and scalar preheader blocks.
+  for (VPBlockBase *Block : RPOT)
+    Block->execute(State);
+
+  // If the original loop is unreachable, delete it and all its blocks.
+  if (!ScalarPhVPBB->hasPredecessors()) {
+    // DeleteDeadBlocks will remove single-entry phis. Remove them from the exit
+    // VPIRBBs in VPlan as well, otherwise we would retain references to deleted
+    // IR instructions.
+    for (VPIRBasicBlock *EB : getExitBlocks()) {
+      for (VPRecipeBase &R : make_early_inc_range(EB->phis())) {
+        if (R.getNumOperands() == 1)
+          R.eraseFromParent();
+      }
+    }
     Loop *OrigLoop =
         State->LI->getLoopFor(getScalarHeader()->getIRBasicBlock());
-    // If the original loop is unreachable, we need to delete it.
     auto Blocks = OrigLoop->getBlocksVector();
     Blocks.push_back(cast<VPIRBasicBlock>(ScalarPhVPBB)->getIRBasicBlock());
     for (auto *BB : Blocks)
       State->LI->removeBlock(BB);
+    DeleteDeadBlocks(Blocks, &State->CFG.DTU);
     State->LI->erase(OrigLoop);
   }
 
-  ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT(
-      Entry);
-  // Generate code for the VPlan, in parts of the vector skeleton, loop body and
-  // successor blocks including the middle, exit and scalar preheader blocks.
-  for (VPBlockBase *Block : RPOT)
-    Block->execute(State);
-
   State->CFG.DTU.flush();
 
   VPBasicBlock *Header = vputils::getFirstLoopHeader(*this, State->VPDT);
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 10d704d..c167dd7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -29,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/ilist.h"
@@ -2977,7 +2978,8 @@ public:
 /// the expression is elevated to connect the non-expression recipe with the
 /// VPExpressionRecipe itself.
 class VPExpressionRecipe : public VPSingleDefRecipe {
-  /// Recipes included in this VPExpressionRecipe.
+  /// Recipes included in this VPExpressionRecipe. This could contain
+  /// duplicates.
   SmallVector<VPSingleDefRecipe *> ExpressionRecipes;
 
   /// Temporary VPValues used for external operands of the expression, i.e.
@@ -3039,8 +3041,11 @@ public:
   }
 
   ~VPExpressionRecipe() override {
-    for (auto *R : reverse(ExpressionRecipes))
-      delete R;
+    SmallPtrSet<VPSingleDefRecipe *, 4> ExpressionRecipesSeen;
+    for (auto *R : reverse(ExpressionRecipes)) {
+      if (ExpressionRecipesSeen.insert(R).second)
+        delete R;
+    }
     for (VPValue *T : LiveInPlaceholders)
       delete T;
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 3a55710..46909a5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2755,10 +2755,7 @@ VPExpressionRecipe::VPExpressionRecipe(
     ExpressionTypes ExpressionType,
     ArrayRef<VPSingleDefRecipe *> ExpressionRecipes)
     : VPSingleDefRecipe(VPDef::VPExpressionSC, {}, {}),
-      ExpressionRecipes(SetVector<VPSingleDefRecipe *>(
-                            ExpressionRecipes.begin(), ExpressionRecipes.end())
-                            .takeVector()),
-      ExpressionType(ExpressionType) {
+      ExpressionRecipes(ExpressionRecipes), ExpressionType(ExpressionType) {
   assert(!ExpressionRecipes.empty() && "Nothing to combine?");
   assert(
       none_of(ExpressionRecipes,
@@ -2802,14 +2799,22 @@ VPExpressionRecipe::VPExpressionRecipe(
         continue;
       addOperand(Op);
       LiveInPlaceholders.push_back(new VPValue());
-      R->setOperand(Idx, LiveInPlaceholders.back());
     }
   }
+
+  // Replace each external operand with the first one created for it in
+  // LiveInPlaceholders.
+  for (auto *R : ExpressionRecipes)
+    for (auto const &[LiveIn, Tmp] : zip(operands(), LiveInPlaceholders))
+      R->replaceUsesOfWith(LiveIn, Tmp);
 }
 
 void VPExpressionRecipe::decompose() {
   for (auto *R : ExpressionRecipes)
-    R->insertBefore(this);
+    // Since the list could contain duplicates, make sure the recipe hasn't
+    // already been inserted.
+    if (!R->getParent())
+      R->insertBefore(this);
 
   for (const auto &[Idx, Op] : enumerate(operands()))
     LiveInPlaceholders[Idx]->replaceAllUsesWith(Op);