aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
diff options
context:
space:
mode:
authorNikita Popov <npopov@redhat.com>2024-08-22 16:57:09 +0200
committerGitHub <noreply@github.com>2024-08-22 16:57:09 +0200
commit4d85285ff68d11fcb8c6b296799a11074e7ff7d7 (patch)
treedd4ff1d9e68b6fb5ee2a72cde4c07f87d705705d /llvm/lib/Transforms/Utils/SimplifyCFG.cpp
parent58ac764b013606a67043cde6a287db3648d87582 (diff)
downloadllvm-4d85285ff68d11fcb8c6b296799a11074e7ff7d7.zip
llvm-4d85285ff68d11fcb8c6b296799a11074e7ff7d7.tar.gz
llvm-4d85285ff68d11fcb8c6b296799a11074e7ff7d7.tar.bz2
[SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (#105636)
If we switch over ucmp/scmp and have two switch cases going to the same destination, we can convert into icmp+br. Fixes https://github.com/llvm/llvm-project/issues/105632.
Diffstat (limited to 'llvm/lib/Transforms/Utils/SimplifyCFG.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp116
1 files changed, 116 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 00efd3c..da4d57f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7131,6 +7131,119 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}
+/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
+/// the same destination.
+static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
+ DomTreeUpdater *DTU) {
+ auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
+ if (!Cmp || !Cmp->hasOneUse())
+ return false;
+
+ SmallVector<uint32_t, 4> Weights;
+ bool HasWeights = extractBranchWeights(getBranchWeightMDNode(*SI), Weights);
+ if (!HasWeights)
+ Weights.resize(4); // Avoid checking HasWeights everywhere.
+
+ // Normalize to [us]cmp == Res ? Succ : OtherSucc.
+ int64_t Res;
+ BasicBlock *Succ, *OtherSucc;
+ uint32_t SuccWeight = 0, OtherSuccWeight = 0;
+ BasicBlock *Unreachable = nullptr;
+
+ if (SI->getNumCases() == 2) {
+ // Find which of 1, 0 or -1 is missing (handled by default dest).
+ SmallSet<int64_t, 3> Missing;
+ Missing.insert(1);
+ Missing.insert(0);
+ Missing.insert(-1);
+
+ Succ = SI->getDefaultDest();
+ SuccWeight = Weights[0];
+ OtherSucc = nullptr;
+ for (auto &Case : SI->cases()) {
+ std::optional<int64_t> Val =
+ Case.getCaseValue()->getValue().trySExtValue();
+ if (!Val)
+ return false;
+ if (!Missing.erase(*Val))
+ return false;
+ if (OtherSucc && OtherSucc != Case.getCaseSuccessor())
+ return false;
+ OtherSucc = Case.getCaseSuccessor();
+ OtherSuccWeight += Weights[Case.getSuccessorIndex()];
+ }
+
+ assert(Missing.size() == 1 && "Should have one case left");
+ Res = *Missing.begin();
+ } else if (SI->getNumCases() == 3 && SI->defaultDestUndefined()) {
+ // Normalize so that Succ is taken once and OtherSucc twice.
+ Unreachable = SI->getDefaultDest();
+ Succ = OtherSucc = nullptr;
+ for (auto &Case : SI->cases()) {
+ BasicBlock *NewSucc = Case.getCaseSuccessor();
+ uint32_t Weight = Weights[Case.getSuccessorIndex()];
+ if (!OtherSucc || OtherSucc == NewSucc) {
+ OtherSucc = NewSucc;
+ OtherSuccWeight += Weight;
+ } else if (!Succ) {
+ Succ = NewSucc;
+ SuccWeight = Weight;
+ } else if (Succ == NewSucc) {
+ std::swap(Succ, OtherSucc);
+ std::swap(SuccWeight, OtherSuccWeight);
+ } else
+ return false;
+ }
+ for (auto &Case : SI->cases()) {
+ std::optional<int64_t> Val =
+ Case.getCaseValue()->getValue().trySExtValue();
+ if (!Val || (Val != 1 && Val != 0 && Val != -1))
+ return false;
+ if (Case.getCaseSuccessor() == Succ) {
+ Res = *Val;
+ break;
+ }
+ }
+ } else {
+ return false;
+ }
+
+ // Determine predicate for the missing case.
+ ICmpInst::Predicate Pred;
+ switch (Res) {
+ case 1:
+ Pred = ICmpInst::ICMP_UGT;
+ break;
+ case 0:
+ Pred = ICmpInst::ICMP_EQ;
+ break;
+ case -1:
+ Pred = ICmpInst::ICMP_ULT;
+ break;
+ }
+ if (Cmp->isSigned())
+ Pred = ICmpInst::getSignedPredicate(Pred);
+
+ MDNode *NewWeights = nullptr;
+ if (HasWeights)
+ NewWeights = MDBuilder(SI->getContext())
+ .createBranchWeights(SuccWeight, OtherSuccWeight);
+
+ BasicBlock *BB = SI->getParent();
+ Builder.SetInsertPoint(SI->getIterator());
+ Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
+ Builder.CreateCondBr(ICmp, Succ, OtherSucc, NewWeights,
+ SI->getMetadata(LLVMContext::MD_unpredictable));
+ OtherSucc->removePredecessor(BB);
+ if (Unreachable)
+ Unreachable->removePredecessor(BB);
+ SI->eraseFromParent();
+ Cmp->eraseFromParent();
+ if (DTU && Unreachable)
+ DTU->applyUpdates({{DominatorTree::Delete, BB, Unreachable}});
+ return true;
+}
+
bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
BasicBlock *BB = SI->getParent();
@@ -7163,6 +7276,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
return requestResimplify();
+ if (simplifySwitchOfCmpIntrinsic(SI, Builder, DTU))
+ return requestResimplify();
+
if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
return requestResimplify();