diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils')
-rw-r--r-- | llvm/lib/Transforms/Utils/SCCPSolver.cpp | 39 | ||||
-rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 189 |
2 files changed, 163 insertions, 65 deletions
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index af216cd..9693ae6 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -317,24 +317,29 @@ static Value *simplifyInstruction(SCCPSolver &Solver, // Early exit if we know nothing about X. if (LRange.isFullSet()) return nullptr; - // We are allowed to refine the comparison to either true or false for out - // of range inputs. Here we refine the comparison to true, i.e. we relax - // the range check. - auto NewCR = CR->exactUnionWith(LRange.inverse()); - // TODO: Check if we can narrow the range check to an equality test. - // E.g, for X in [0, 4), X - 3 u< 2 -> X == 3 - if (!NewCR) + auto ConvertCRToICmp = + [&](const std::optional<ConstantRange> &NewCR) -> Value * { + ICmpInst::Predicate Pred; + APInt RHS; + // Check if we can represent NewCR as an icmp predicate. + if (NewCR && NewCR->getEquivalentICmp(Pred, RHS)) { + IRBuilder<NoFolder> Builder(&Inst); + Value *NewICmp = + Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS)); + InsertedValues.insert(NewICmp); + return NewICmp; + } return nullptr; - ICmpInst::Predicate Pred; - APInt RHS; - // Check if we can represent NewCR as an icmp predicate. - if (NewCR->getEquivalentICmp(Pred, RHS)) { - IRBuilder<NoFolder> Builder(&Inst); - Value *NewICmp = - Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS)); - InsertedValues.insert(NewICmp); - return NewICmp; - } + }; + // We are allowed to refine the comparison to either true or false for out + // of range inputs. + // Here we refine the comparison to false, and check if we can narrow the + // range check to a simpler test. + if (auto *V = ConvertCRToICmp(CR->exactIntersectWith(LRange))) + return V; + // Here we refine the comparison to true, i.e. we relax the range check. + if (auto *V = ConvertCRToICmp(CR->exactUnionWith(LRange.inverse()))) + return V; } } diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 48055ad..b8cfe3a 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -4895,9 +4895,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm, // We found both of the successors we were looking for. // Create a conditional branch sharing the condition of the select. BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); - if (TrueWeight != FalseWeight) - setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, - /*IsExpected=*/false, /*ElideAllZero=*/true); + setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this @@ -4982,9 +4981,15 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI, BasicBlock *TrueBB = TBA->getBasicBlock(); BasicBlock *FalseBB = FBA->getBasicBlock(); + // The select's profile becomes the profile of the conditional branch that + // replaces the indirect branch. + SmallVector<uint32_t> SelectBranchWeights(2); + if (!ProfcheckDisableMetadataFixes) + extractBranchWeights(*SI, SelectBranchWeights); // Perform the actual simplification. - return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0, - 0); + return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, + SelectBranchWeights[0], + SelectBranchWeights[1]); } /// This is called when we find an icmp instruction @@ -5734,15 +5739,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { return Changed; } -static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) { +struct ContiguousCasesResult { + ConstantInt *Min; + ConstantInt *Max; + BasicBlock *Dest; + BasicBlock *OtherDest; + SmallVectorImpl<ConstantInt *> *Cases; + SmallVectorImpl<ConstantInt *> *OtherCases; +}; + +static std::optional<ContiguousCasesResult> +findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases, + SmallVectorImpl<ConstantInt *> &OtherCases, + BasicBlock *Dest, BasicBlock *OtherDest) { assert(Cases.size() >= 1); array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate); - for (size_t I = 1, E = Cases.size(); I != E; ++I) { - if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1) - return false; + const APInt &Min = Cases.back()->getValue(); + const APInt &Max = Cases.front()->getValue(); + APInt Offset = Max - Min; + size_t ContiguousOffset = Cases.size() - 1; + if (Offset == ContiguousOffset) { + return ContiguousCasesResult{ + /*Min=*/Cases.back(), + /*Max=*/Cases.front(), + /*Dest=*/Dest, + /*OtherDest=*/OtherDest, + /*Cases=*/&Cases, + /*OtherCases=*/&OtherCases, + }; } - return true; + ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false); + // If this is a wrapping contiguous range, that is, [Min, OtherMin] + + // [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a + // contiguous range for the other destination. N.B. If CR is not a full range, + // Max+1 is not equal to Min. It's not continuous in arithmetic. + if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) { + assert(Cases.size() >= 2); + auto *It = + std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) { + return L->getValue() != R->getValue() + 1; + }); + if (It == Cases.end()) + return std::nullopt; + auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It)); + if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) == + Cases.size() - 2) { + return ContiguousCasesResult{ + /*Min=*/cast<ConstantInt>( + ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)), + /*Max=*/ + cast<ConstantInt>( + ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)), + /*Dest=*/OtherDest, + /*OtherDest=*/Dest, + /*Cases=*/&OtherCases, + /*OtherCases=*/&Cases, + }; + } + } + return std::nullopt; } static void createUnreachableSwitchDefault(SwitchInst *Switch, @@ -5779,7 +5835,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, bool HasDefault = !SI->defaultDestUnreachable(); auto *BB = SI->getParent(); - // Partition the cases into two sets with different destinations. BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr; BasicBlock *DestB = nullptr; @@ -5813,37 +5868,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, assert(!CasesA.empty() || HasDefault); // Figure out if one of the sets of cases form a contiguous range. - SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr; - BasicBlock *ContiguousDest = nullptr; - BasicBlock *OtherDest = nullptr; - if (!CasesA.empty() && casesAreContiguous(CasesA)) { - ContiguousCases = &CasesA; - ContiguousDest = DestA; - OtherDest = DestB; - } else if (casesAreContiguous(CasesB)) { - ContiguousCases = &CasesB; - ContiguousDest = DestB; - OtherDest = DestA; - } else - return false; + std::optional<ContiguousCasesResult> ContiguousCases; + + // Only one icmp is needed when there is only one case. + if (!HasDefault && CasesA.size() == 1) + ContiguousCases = ContiguousCasesResult{ + /*Min=*/CasesA[0], + /*Max=*/CasesA[0], + /*Dest=*/DestA, + /*OtherDest=*/DestB, + /*Cases=*/&CasesA, + /*OtherCases=*/&CasesB, + }; + else if (CasesB.size() == 1) + ContiguousCases = ContiguousCasesResult{ + /*Min=*/CasesB[0], + /*Max=*/CasesB[0], + /*Dest=*/DestB, + /*OtherDest=*/DestA, + /*Cases=*/&CasesB, + /*OtherCases=*/&CasesA, + }; + // Correctness: Cases to the default destination cannot be contiguous cases. + else if (!HasDefault) + ContiguousCases = + findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB); - // Start building the compare and branch. + if (!ContiguousCases) + ContiguousCases = + findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA); + + if (!ContiguousCases) + return false; - Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back()); - Constant *NumCases = - ConstantInt::get(Offset->getType(), ContiguousCases->size()); + auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases; - Value *Sub = SI->getCondition(); - if (!Offset->isNullValue()) - Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off"); + // Start building the compare and branch. - Value *Cmp; + Constant *Offset = ConstantExpr::getNeg(Min); + Constant *NumCases = ConstantInt::get(Offset->getType(), + Max->getValue() - Min->getValue() + 1); + BranchInst *NewBI; + if (NumCases->isOneValue()) { + assert(Max->getValue() == Min->getValue()); + Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min); + NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest); + } // If NumCases overflowed, then all possible values jump to the successor. - if (NumCases->isNullValue() && !ContiguousCases->empty()) - Cmp = ConstantInt::getTrue(SI->getContext()); - else - Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch"); - BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest); + else if (NumCases->isNullValue() && !Cases->empty()) { + NewBI = Builder.CreateBr(Dest); + } else { + Value *Sub = SI->getCondition(); + if (!Offset->isNullValue()) + Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off"); + Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch"); + NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest); + } // Update weight for the newly-created conditional branch. if (hasBranchWeightMD(*SI)) { @@ -5853,7 +5933,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, uint64_t TrueWeight = 0; uint64_t FalseWeight = 0; for (size_t I = 0, E = Weights.size(); I != E; ++I) { - if (SI->getSuccessor(I) == ContiguousDest) + if (SI->getSuccessor(I) == Dest) TrueWeight += Weights[I]; else FalseWeight += Weights[I]; @@ -5868,15 +5948,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, } // Prune obsolete incoming values off the successors' PHI nodes. - for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) { - unsigned PreviousEdges = ContiguousCases->size(); - if (ContiguousDest == SI->getDefaultDest()) + for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) { + unsigned PreviousEdges = Cases->size(); + if (Dest == SI->getDefaultDest()) ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); } for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) { - unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size(); + unsigned PreviousEdges = OtherCases->size(); if (OtherDest == SI->getDefaultDest()) ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) @@ -7877,19 +7957,27 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { BasicBlock *BB = IBI->getParent(); bool Changed = false; + SmallVector<uint32_t> BranchWeights; + const bool HasBranchWeights = !ProfcheckDisableMetadataFixes && + extractBranchWeights(*IBI, BranchWeights); + + DenseMap<const BasicBlock *, uint64_t> TargetWeight; + if (HasBranchWeights) + for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I) + TargetWeight[IBI->getDestination(I)] += BranchWeights[I]; // Eliminate redundant destinations. SmallPtrSet<Value *, 8> Succs; SmallSetVector<BasicBlock *, 8> RemovedSuccs; - for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { - BasicBlock *Dest = IBI->getDestination(i); + for (unsigned I = 0, E = IBI->getNumDestinations(); I != E; ++I) { + BasicBlock *Dest = IBI->getDestination(I); if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { if (!Dest->hasAddressTaken()) RemovedSuccs.insert(Dest); Dest->removePredecessor(BB); - IBI->removeDestination(i); - --i; - --e; + IBI->removeDestination(I); + --I; + --E; Changed = true; } } @@ -7915,7 +8003,12 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { eraseTerminatorAndDCECond(IBI); return true; } - + if (HasBranchWeights) { + SmallVector<uint64_t> NewBranchWeights(IBI->getNumDestinations()); + for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I) + NewBranchWeights[I] += TargetWeight.find(IBI->getDestination(I))->second; + setFittedBranchWeights(*IBI, NewBranchWeights, /*IsExpected=*/false); + } if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (simplifyIndirectBrOnSelect(IBI, SI)) return requestResimplify(); |