diff options
author | Nikita Popov <npopov@redhat.com> | 2024-08-22 16:57:09 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-22 16:57:09 +0200 |
commit | 4d85285ff68d11fcb8c6b296799a11074e7ff7d7 (patch) | |
tree | dd4ff1d9e68b6fb5ee2a72cde4c07f87d705705d /llvm/lib/Transforms/Utils/SimplifyCFG.cpp | |
parent | 58ac764b013606a67043cde6a287db3648d87582 (diff) | |
download | llvm-4d85285ff68d11fcb8c6b296799a11074e7ff7d7.zip llvm-4d85285ff68d11fcb8c6b296799a11074e7ff7d7.tar.gz llvm-4d85285ff68d11fcb8c6b296799a11074e7ff7d7.tar.bz2 |
[SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (#105636)
If we switch over ucmp/scmp and have two switch cases going to the same
destination, we can convert into icmp+br.
Fixes https://github.com/llvm/llvm-project/issues/105632.
Diffstat (limited to 'llvm/lib/Transforms/Utils/SimplifyCFG.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 00efd3c..da4d57f 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -7131,6 +7131,119 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have +/// the same destination. +static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder, + DomTreeUpdater *DTU) { + auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition()); + if (!Cmp || !Cmp->hasOneUse()) + return false; + + SmallVector<uint32_t, 4> Weights; + bool HasWeights = extractBranchWeights(getBranchWeightMDNode(*SI), Weights); + if (!HasWeights) + Weights.resize(4); // Avoid checking HasWeights everywhere. + + // Normalize to [us]cmp == Res ? Succ : OtherSucc. + int64_t Res; + BasicBlock *Succ, *OtherSucc; + uint32_t SuccWeight = 0, OtherSuccWeight = 0; + BasicBlock *Unreachable = nullptr; + + if (SI->getNumCases() == 2) { + // Find which of 1, 0 or -1 is missing (handled by default dest). + SmallSet<int64_t, 3> Missing; + Missing.insert(1); + Missing.insert(0); + Missing.insert(-1); + + Succ = SI->getDefaultDest(); + SuccWeight = Weights[0]; + OtherSucc = nullptr; + for (auto &Case : SI->cases()) { + std::optional<int64_t> Val = + Case.getCaseValue()->getValue().trySExtValue(); + if (!Val) + return false; + if (!Missing.erase(*Val)) + return false; + if (OtherSucc && OtherSucc != Case.getCaseSuccessor()) + return false; + OtherSucc = Case.getCaseSuccessor(); + OtherSuccWeight += Weights[Case.getSuccessorIndex()]; + } + + assert(Missing.size() == 1 && "Should have one case left"); + Res = *Missing.begin(); + } else if (SI->getNumCases() == 3 && SI->defaultDestUndefined()) { + // Normalize so that Succ is taken once and OtherSucc twice. + Unreachable = SI->getDefaultDest(); + Succ = OtherSucc = nullptr; + for (auto &Case : SI->cases()) { + BasicBlock *NewSucc = Case.getCaseSuccessor(); + uint32_t Weight = Weights[Case.getSuccessorIndex()]; + if (!OtherSucc || OtherSucc == NewSucc) { + OtherSucc = NewSucc; + OtherSuccWeight += Weight; + } else if (!Succ) { + Succ = NewSucc; + SuccWeight = Weight; + } else if (Succ == NewSucc) { + std::swap(Succ, OtherSucc); + std::swap(SuccWeight, OtherSuccWeight); + } else + return false; + } + for (auto &Case : SI->cases()) { + std::optional<int64_t> Val = + Case.getCaseValue()->getValue().trySExtValue(); + if (!Val || (Val != 1 && Val != 0 && Val != -1)) + return false; + if (Case.getCaseSuccessor() == Succ) { + Res = *Val; + break; + } + } + } else { + return false; + } + + // Determine predicate for the missing case. + ICmpInst::Predicate Pred; + switch (Res) { + case 1: + Pred = ICmpInst::ICMP_UGT; + break; + case 0: + Pred = ICmpInst::ICMP_EQ; + break; + case -1: + Pred = ICmpInst::ICMP_ULT; + break; + } + if (Cmp->isSigned()) + Pred = ICmpInst::getSignedPredicate(Pred); + + MDNode *NewWeights = nullptr; + if (HasWeights) + NewWeights = MDBuilder(SI->getContext()) + .createBranchWeights(SuccWeight, OtherSuccWeight); + + BasicBlock *BB = SI->getParent(); + Builder.SetInsertPoint(SI->getIterator()); + Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS()); + Builder.CreateCondBr(ICmp, Succ, OtherSucc, NewWeights, + SI->getMetadata(LLVMContext::MD_unpredictable)); + OtherSucc->removePredecessor(BB); + if (Unreachable) + Unreachable->removePredecessor(BB); + SI->eraseFromParent(); + Cmp->eraseFromParent(); + if (DTU && Unreachable) + DTU->applyUpdates({{DominatorTree::Delete, BB, Unreachable}}); + return true; +} + bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { BasicBlock *BB = SI->getParent(); @@ -7163,6 +7276,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL)) return requestResimplify(); + if (simplifySwitchOfCmpIntrinsic(SI, Builder, DTU)) + return requestResimplify(); + if (trySwitchToSelect(SI, Builder, DTU, DL, TTI)) return requestResimplify(); |