aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils')
-rw-r--r--llvm/lib/Transforms/Utils/SCCPSolver.cpp39
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp189
2 files changed, 163 insertions, 65 deletions
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index af216cd..9693ae6 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -317,24 +317,29 @@ static Value *simplifyInstruction(SCCPSolver &Solver,
// Early exit if we know nothing about X.
if (LRange.isFullSet())
return nullptr;
- // We are allowed to refine the comparison to either true or false for out
- // of range inputs. Here we refine the comparison to true, i.e. we relax
- // the range check.
- auto NewCR = CR->exactUnionWith(LRange.inverse());
- // TODO: Check if we can narrow the range check to an equality test.
- // E.g, for X in [0, 4), X - 3 u< 2 -> X == 3
- if (!NewCR)
+ auto ConvertCRToICmp =
+ [&](const std::optional<ConstantRange> &NewCR) -> Value * {
+ ICmpInst::Predicate Pred;
+ APInt RHS;
+ // Check if we can represent NewCR as an icmp predicate.
+ if (NewCR && NewCR->getEquivalentICmp(Pred, RHS)) {
+ IRBuilder<NoFolder> Builder(&Inst);
+ Value *NewICmp =
+ Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS));
+ InsertedValues.insert(NewICmp);
+ return NewICmp;
+ }
return nullptr;
- ICmpInst::Predicate Pred;
- APInt RHS;
- // Check if we can represent NewCR as an icmp predicate.
- if (NewCR->getEquivalentICmp(Pred, RHS)) {
- IRBuilder<NoFolder> Builder(&Inst);
- Value *NewICmp =
- Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS));
- InsertedValues.insert(NewICmp);
- return NewICmp;
- }
+ };
+ // We are allowed to refine the comparison to either true or false for out
+ // of range inputs.
+ // Here we refine the comparison to false, and check if we can narrow the
+ // range check to a simpler test.
+ if (auto *V = ConvertCRToICmp(CR->exactIntersectWith(LRange)))
+ return V;
+ // Here we refine the comparison to true, i.e. we relax the range check.
+ if (auto *V = ConvertCRToICmp(CR->exactUnionWith(LRange.inverse())))
+ return V;
}
}
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 48055ad..b8cfe3a 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4895,9 +4895,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm,
// We found both of the successors we were looking for.
// Create a conditional branch sharing the condition of the select.
BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
- if (TrueWeight != FalseWeight)
- setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
- /*IsExpected=*/false, /*ElideAllZero=*/true);
+ setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+ /*IsExpected=*/false, /*ElideAllZero=*/true);
}
} else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
// Neither of the selected blocks were successors, so this
@@ -4982,9 +4981,15 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI,
BasicBlock *TrueBB = TBA->getBasicBlock();
BasicBlock *FalseBB = FBA->getBasicBlock();
+ // The select's profile becomes the profile of the conditional branch that
+ // replaces the indirect branch.
+ SmallVector<uint32_t> SelectBranchWeights(2);
+ if (!ProfcheckDisableMetadataFixes)
+ extractBranchWeights(*SI, SelectBranchWeights);
// Perform the actual simplification.
- return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0,
- 0);
+ return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB,
+ SelectBranchWeights[0],
+ SelectBranchWeights[1]);
}
/// This is called when we find an icmp instruction
@@ -5734,15 +5739,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
return Changed;
}
-static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
+struct ContiguousCasesResult {
+ ConstantInt *Min;
+ ConstantInt *Max;
+ BasicBlock *Dest;
+ BasicBlock *OtherDest;
+ SmallVectorImpl<ConstantInt *> *Cases;
+ SmallVectorImpl<ConstantInt *> *OtherCases;
+};
+
+static std::optional<ContiguousCasesResult>
+findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
+ SmallVectorImpl<ConstantInt *> &OtherCases,
+ BasicBlock *Dest, BasicBlock *OtherDest) {
assert(Cases.size() >= 1);
array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
- for (size_t I = 1, E = Cases.size(); I != E; ++I) {
- if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
- return false;
+ const APInt &Min = Cases.back()->getValue();
+ const APInt &Max = Cases.front()->getValue();
+ APInt Offset = Max - Min;
+ size_t ContiguousOffset = Cases.size() - 1;
+ if (Offset == ContiguousOffset) {
+ return ContiguousCasesResult{
+ /*Min=*/Cases.back(),
+ /*Max=*/Cases.front(),
+ /*Dest=*/Dest,
+ /*OtherDest=*/OtherDest,
+ /*Cases=*/&Cases,
+ /*OtherCases=*/&OtherCases,
+ };
}
- return true;
+ ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
+ // If this is a wrapping contiguous range, that is, [Min, OtherMin] +
+ // [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
+ // contiguous range for the other destination. N.B. If CR is not a full range,
+ // Max+1 is not equal to Min. It's not continuous in arithmetic.
+ if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
+ assert(Cases.size() >= 2);
+ auto *It =
+ std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
+ return L->getValue() != R->getValue() + 1;
+ });
+ if (It == Cases.end())
+ return std::nullopt;
+ auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));
+ if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
+ Cases.size() - 2) {
+ return ContiguousCasesResult{
+ /*Min=*/cast<ConstantInt>(
+ ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)),
+ /*Max=*/
+ cast<ConstantInt>(
+ ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)),
+ /*Dest=*/OtherDest,
+ /*OtherDest=*/Dest,
+ /*Cases=*/&OtherCases,
+ /*OtherCases=*/&Cases,
+ };
+ }
+ }
+ return std::nullopt;
}
static void createUnreachableSwitchDefault(SwitchInst *Switch,
@@ -5779,7 +5835,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
bool HasDefault = !SI->defaultDestUnreachable();
auto *BB = SI->getParent();
-
// Partition the cases into two sets with different destinations.
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr;
BasicBlock *DestB = nullptr;
@@ -5813,37 +5868,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
assert(!CasesA.empty() || HasDefault);
// Figure out if one of the sets of cases form a contiguous range.
- SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
- BasicBlock *ContiguousDest = nullptr;
- BasicBlock *OtherDest = nullptr;
- if (!CasesA.empty() && casesAreContiguous(CasesA)) {
- ContiguousCases = &CasesA;
- ContiguousDest = DestA;
- OtherDest = DestB;
- } else if (casesAreContiguous(CasesB)) {
- ContiguousCases = &CasesB;
- ContiguousDest = DestB;
- OtherDest = DestA;
- } else
- return false;
+ std::optional<ContiguousCasesResult> ContiguousCases;
+
+ // Only one icmp is needed when there is only one case.
+ if (!HasDefault && CasesA.size() == 1)
+ ContiguousCases = ContiguousCasesResult{
+ /*Min=*/CasesA[0],
+ /*Max=*/CasesA[0],
+ /*Dest=*/DestA,
+ /*OtherDest=*/DestB,
+ /*Cases=*/&CasesA,
+ /*OtherCases=*/&CasesB,
+ };
+ else if (CasesB.size() == 1)
+ ContiguousCases = ContiguousCasesResult{
+ /*Min=*/CasesB[0],
+ /*Max=*/CasesB[0],
+ /*Dest=*/DestB,
+ /*OtherDest=*/DestA,
+ /*Cases=*/&CasesB,
+ /*OtherCases=*/&CasesA,
+ };
+ // Correctness: Cases to the default destination cannot be contiguous cases.
+ else if (!HasDefault)
+ ContiguousCases =
+ findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB);
- // Start building the compare and branch.
+ if (!ContiguousCases)
+ ContiguousCases =
+ findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA);
+
+ if (!ContiguousCases)
+ return false;
- Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
- Constant *NumCases =
- ConstantInt::get(Offset->getType(), ContiguousCases->size());
+ auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;
- Value *Sub = SI->getCondition();
- if (!Offset->isNullValue())
- Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
+ // Start building the compare and branch.
- Value *Cmp;
+ Constant *Offset = ConstantExpr::getNeg(Min);
+ Constant *NumCases = ConstantInt::get(Offset->getType(),
+ Max->getValue() - Min->getValue() + 1);
+ BranchInst *NewBI;
+ if (NumCases->isOneValue()) {
+ assert(Max->getValue() == Min->getValue());
+ Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min);
+ NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
+ }
// If NumCases overflowed, then all possible values jump to the successor.
- if (NumCases->isNullValue() && !ContiguousCases->empty())
- Cmp = ConstantInt::getTrue(SI->getContext());
- else
- Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
- BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
+ else if (NumCases->isNullValue() && !Cases->empty()) {
+ NewBI = Builder.CreateBr(Dest);
+ } else {
+ Value *Sub = SI->getCondition();
+ if (!Offset->isNullValue())
+ Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
+ Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
+ NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
+ }
// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
@@ -5853,7 +5933,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
uint64_t TrueWeight = 0;
uint64_t FalseWeight = 0;
for (size_t I = 0, E = Weights.size(); I != E; ++I) {
- if (SI->getSuccessor(I) == ContiguousDest)
+ if (SI->getSuccessor(I) == Dest)
TrueWeight += Weights[I];
else
FalseWeight += Weights[I];
@@ -5868,15 +5948,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
}
// Prune obsolete incoming values off the successors' PHI nodes.
- for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) {
- unsigned PreviousEdges = ContiguousCases->size();
- if (ContiguousDest == SI->getDefaultDest())
+ for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) {
+ unsigned PreviousEdges = Cases->size();
+ if (Dest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
- unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
+ unsigned PreviousEdges = OtherCases->size();
if (OtherDest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
@@ -7877,19 +7957,27 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) {
BasicBlock *BB = IBI->getParent();
bool Changed = false;
+ SmallVector<uint32_t> BranchWeights;
+ const bool HasBranchWeights = !ProfcheckDisableMetadataFixes &&
+ extractBranchWeights(*IBI, BranchWeights);
+
+ DenseMap<const BasicBlock *, uint64_t> TargetWeight;
+ if (HasBranchWeights)
+ for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I)
+ TargetWeight[IBI->getDestination(I)] += BranchWeights[I];
// Eliminate redundant destinations.
SmallPtrSet<Value *, 8> Succs;
SmallSetVector<BasicBlock *, 8> RemovedSuccs;
- for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) {
- BasicBlock *Dest = IBI->getDestination(i);
+ for (unsigned I = 0, E = IBI->getNumDestinations(); I != E; ++I) {
+ BasicBlock *Dest = IBI->getDestination(I);
if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) {
if (!Dest->hasAddressTaken())
RemovedSuccs.insert(Dest);
Dest->removePredecessor(BB);
- IBI->removeDestination(i);
- --i;
- --e;
+ IBI->removeDestination(I);
+ --I;
+ --E;
Changed = true;
}
}
@@ -7915,7 +8003,12 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) {
eraseTerminatorAndDCECond(IBI);
return true;
}
-
+ if (HasBranchWeights) {
+ SmallVector<uint64_t> NewBranchWeights(IBI->getNumDestinations());
+ for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I)
+ NewBranchWeights[I] += TargetWeight.find(IBI->getDestination(I))->second;
+ setFittedBranchWeights(*IBI, NewBranchWeights, /*IsExpected=*/false);
+ }
if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) {
if (simplifyIndirectBrOnSelect(IBI, SI))
return requestResimplify();