aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils')
-rw-r--r--llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp51
-rw-r--r--llvm/lib/Transforms/Utils/BuildLibCalls.cpp6
-rw-r--r--llvm/lib/Transforms/Utils/BypassSlowDivision.cpp32
-rw-r--r--llvm/lib/Transforms/Utils/CodeExtractor.cpp1
-rw-r--r--llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp52
-rw-r--r--llvm/lib/Transforms/Utils/LoopSimplify.cpp60
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp21
-rw-r--r--llvm/lib/Transforms/Utils/LoopVersioning.cpp9
-rw-r--r--llvm/lib/Transforms/Utils/SimplifyCFG.cpp173
9 files changed, 288 insertions, 117 deletions
diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
index 42b1fdf..8aa8aa2 100644
--- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
+++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
@@ -39,36 +39,36 @@ using namespace llvm;
STATISTIC(NumBroken, "Number of blocks inserted");
namespace {
- struct BreakCriticalEdges : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- BreakCriticalEdges() : FunctionPass(ID) {
- initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry());
- }
+struct BreakCriticalEdges : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ BreakCriticalEdges() : FunctionPass(ID) {
+ initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry());
+ }
- bool runOnFunction(Function &F) override {
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
+ bool runOnFunction(Function &F) override {
+ auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+ auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
- auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
- auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr;
+ auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
+ auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr;
- auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
- auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
- unsigned N =
- SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT));
- NumBroken += N;
- return N > 0;
- }
+ auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
+ auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
+ unsigned N = SplitAllCriticalEdges(
+ F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT));
+ NumBroken += N;
+ return N > 0;
+ }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
- // No loop canonicalization guarantees are broken by this pass.
- AU.addPreservedID(LoopSimplifyID);
- }
- };
-}
+ // No loop canonicalization guarantees are broken by this pass.
+ AU.addPreservedID(LoopSimplifyID);
+ }
+};
+} // namespace
char BreakCriticalEdges::ID = 0;
INITIALIZE_PASS(BreakCriticalEdges, "break-crit-edges",
@@ -76,6 +76,7 @@ INITIALIZE_PASS(BreakCriticalEdges, "break-crit-edges",
// Publicly exposed interface to pass...
char &llvm::BreakCriticalEdgesID = BreakCriticalEdges::ID;
+
FunctionPass *llvm::createBreakCriticalEdgesPass() {
return new BreakCriticalEdges();
}
diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
index 573a781..02b73e8 100644
--- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp
@@ -1283,6 +1283,12 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F,
case LibFunc_ilogbl:
case LibFunc_logf:
case LibFunc_logl:
+ case LibFunc_nextafter:
+ case LibFunc_nextafterf:
+ case LibFunc_nextafterl:
+ case LibFunc_nexttoward:
+ case LibFunc_nexttowardf:
+ case LibFunc_nexttowardl:
case LibFunc_pow:
case LibFunc_powf:
case LibFunc_powl:
diff --git a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp
index 7343c79..9f6d89e 100644
--- a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp
+++ b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp
@@ -40,22 +40,22 @@ using namespace llvm;
namespace {
- struct QuotRemPair {
- Value *Quotient;
- Value *Remainder;
-
- QuotRemPair(Value *InQuotient, Value *InRemainder)
- : Quotient(InQuotient), Remainder(InRemainder) {}
- };
-
- /// A quotient and remainder, plus a BB from which they logically "originate".
- /// If you use Quotient or Remainder in a Phi node, you should use BB as its
- /// corresponding predecessor.
- struct QuotRemWithBB {
- BasicBlock *BB = nullptr;
- Value *Quotient = nullptr;
- Value *Remainder = nullptr;
- };
+struct QuotRemPair {
+ Value *Quotient;
+ Value *Remainder;
+
+ QuotRemPair(Value *InQuotient, Value *InRemainder)
+ : Quotient(InQuotient), Remainder(InRemainder) {}
+};
+
+/// A quotient and remainder, plus a BB from which they logically "originate".
+/// If you use Quotient or Remainder in a Phi node, you should use BB as its
+/// corresponding predecessor.
+struct QuotRemWithBB {
+ BasicBlock *BB = nullptr;
+ Value *Quotient = nullptr;
+ Value *Remainder = nullptr;
+};
using DivCacheTy = DenseMap<DivRemMapKey, QuotRemPair>;
using BypassWidthsTy = DenseMap<unsigned, unsigned>;
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 5ba6f95f..6086615 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -933,6 +933,7 @@ Function *CodeExtractor::constructFunctionDeclaration(
case Attribute::CoroDestroyOnlyWhenComplete:
case Attribute::CoroElideSafe:
case Attribute::NoDivergenceSource:
+ case Attribute::NoCreateUndefOrPoison:
continue;
// Those attributes should be safe to propagate to the extracted function.
case Attribute::AlwaysInline:
diff --git a/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp b/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp
index 0642d51..dd8706c 100644
--- a/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp
+++ b/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp
@@ -16,22 +16,62 @@
using namespace llvm;
+static void mergeAttributes(LLVMContext &Ctx, const Module &M,
+ const DataLayout &DL, const Triple &TT,
+ Function *Func, FunctionType *FuncTy,
+ AttributeList FuncAttrs) {
+ AttributeList OldAttrs = Func->getAttributes();
+ AttributeList NewAttrs = OldAttrs;
+
+ {
+ AttrBuilder OldBuilder(Ctx, OldAttrs.getFnAttrs());
+ AttrBuilder NewBuilder(Ctx, FuncAttrs.getFnAttrs());
+ OldBuilder.merge(NewBuilder);
+ NewAttrs = NewAttrs.addFnAttributes(Ctx, OldBuilder);
+ }
+
+ {
+ AttrBuilder OldBuilder(Ctx, OldAttrs.getRetAttrs());
+ AttrBuilder NewBuilder(Ctx, FuncAttrs.getRetAttrs());
+ OldBuilder.merge(NewBuilder);
+ NewAttrs = NewAttrs.addRetAttributes(Ctx, OldBuilder);
+ }
+
+ for (unsigned I = 0, E = FuncTy->getNumParams(); I != E; ++I) {
+ AttrBuilder OldBuilder(Ctx, OldAttrs.getParamAttrs(I));
+ AttrBuilder NewBuilder(Ctx, FuncAttrs.getParamAttrs(I));
+ OldBuilder.merge(NewBuilder);
+ NewAttrs = NewAttrs.addParamAttributes(Ctx, I, OldBuilder);
+ }
+
+ Func->setAttributes(NewAttrs);
+}
+
PreservedAnalyses DeclareRuntimeLibcallsPass::run(Module &M,
ModuleAnalysisManager &MAM) {
RTLIB::RuntimeLibcallsInfo RTLCI(M.getTargetTriple());
LLVMContext &Ctx = M.getContext();
+ const DataLayout &DL = M.getDataLayout();
+ const Triple &TT = M.getTargetTriple();
- for (RTLIB::LibcallImpl Impl : RTLCI.getLibcallImpls()) {
- if (Impl == RTLIB::Unsupported)
+ for (RTLIB::LibcallImpl Impl : RTLIB::libcall_impls()) {
+ if (!RTLCI.isAvailable(Impl))
continue;
- // TODO: Declare with correct type, calling convention, and attributes.
+ auto [FuncTy, FuncAttrs] = RTLCI.getFunctionTy(Ctx, TT, DL, Impl);
- FunctionType *FuncTy =
- FunctionType::get(Type::getVoidTy(Ctx), {}, /*IsVarArgs=*/true);
+ // TODO: Declare with correct type, calling convention, and attributes.
+ if (!FuncTy)
+ FuncTy = FunctionType::get(Type::getVoidTy(Ctx), {}, /*IsVarArgs=*/true);
StringRef FuncName = RTLCI.getLibcallImplName(Impl);
- M.getOrInsertFunction(FuncName, FuncTy);
+
+ Function *Func =
+ cast<Function>(M.getOrInsertFunction(FuncName, FuncTy).getCallee());
+ if (Func->getFunctionType() == FuncTy) {
+ mergeAttributes(Ctx, M, DL, TT, Func, FuncTy, FuncAttrs);
+ Func->setCallingConv(RTLCI.getLibcallImplCallingConv(Impl));
+ }
}
return PreservedAnalyses::none();
diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
index 61ffb49..8da6a980 100644
--- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
@@ -378,7 +378,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
if (P != Preheader) BackedgeBlocks.push_back(P);
}
- // Create and insert the new backedge block...
+ // Create and insert the new backedge block.
BasicBlock *BEBlock = BasicBlock::Create(Header->getContext(),
Header->getName() + ".backedge", F);
BranchInst *BETerminator = BranchInst::Create(Header, BEBlock);
@@ -737,39 +737,39 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI,
}
namespace {
- struct LoopSimplify : public FunctionPass {
- static char ID; // Pass identification, replacement for typeid
- LoopSimplify() : FunctionPass(ID) {
- initializeLoopSimplifyPass(*PassRegistry::getPassRegistry());
- }
+struct LoopSimplify : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ LoopSimplify() : FunctionPass(ID) {
+ initializeLoopSimplifyPass(*PassRegistry::getPassRegistry());
+ }
- bool runOnFunction(Function &F) override;
+ bool runOnFunction(Function &F) override;
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<AssumptionCacheTracker>();
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AssumptionCacheTracker>();
- // We need loop information to identify the loops...
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
+ // We need loop information to identify the loops.
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
- AU.addPreserved<BasicAAWrapperPass>();
- AU.addPreserved<AAResultsWrapperPass>();
- AU.addPreserved<GlobalsAAWrapperPass>();
- AU.addPreserved<ScalarEvolutionWrapperPass>();
- AU.addPreserved<SCEVAAWrapperPass>();
- AU.addPreservedID(LCSSAID);
- AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added.
- AU.addPreserved<BranchProbabilityInfoWrapperPass>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
+ AU.addPreserved<BasicAAWrapperPass>();
+ AU.addPreserved<AAResultsWrapperPass>();
+ AU.addPreserved<GlobalsAAWrapperPass>();
+ AU.addPreserved<ScalarEvolutionWrapperPass>();
+ AU.addPreserved<SCEVAAWrapperPass>();
+ AU.addPreservedID(LCSSAID);
+ AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added.
+ AU.addPreserved<BranchProbabilityInfoWrapperPass>();
+ AU.addPreserved<MemorySSAWrapperPass>();
+ }
- /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees.
- void verifyAnalysis() const override;
- };
-}
+ /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees.
+ void verifyAnalysis() const override;
+};
+} // namespace
char LoopSimplify::ID = 0;
INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify",
@@ -780,12 +780,12 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", "Canonicalize natural loops",
false, false)
-// Publicly exposed interface to pass...
+// Publicly exposed interface to pass.
char &llvm::LoopSimplifyID = LoopSimplify::ID;
Pass *llvm::createLoopSimplifyPass() { return new LoopSimplify(); }
/// runOnFunction - Run down all loops in the CFG (recursively, but we could do
-/// it in any convenient order) inserting preheaders...
+/// it in any convenient order) inserting preheaders.
///
bool LoopSimplify::runOnFunction(Function &F) {
bool Changed = false;
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 1e8f6cc..6c9467b 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -202,6 +202,27 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
/// probability of executing at least one more iteration?
static BranchProbability
probOfNextInRemainder(BranchProbability OriginalLoopProb, unsigned N) {
+ // OriginalLoopProb == 1 would produce a division by zero in the calculation
+ // below. The problem is that case indicates an always infinite loop, but a
+ // remainder loop cannot be calculated at run time if the original loop is
+ // infinite as infinity % UnrollCount is undefined. We then choose
+ // probabilities indicating that all remainder loop iterations will always
+ // execute.
+ //
+ // Currently, the remainder loop here is an epilogue, which cannot be reached
+ // if the original loop is infinite, so the aforementioned choice is
+ // arbitrary.
+ //
+ // FIXME: Branch weights still need to be fixed in the case of prologues
+ // (issue #135812). In that case, the aforementioned choice seems reasonable
+ // for the goal of maintaining the original loop's block frequencies. That
+ // is, an infinite loop's initial iterations are not skipped, and the prologue
+ // loop body might have unique blocks that execute a finite number of times
+ // if, for example, the original loop body contains conditionals like i <
+ // UnrollCount.
+ if (OriginalLoopProb == BranchProbability::getOne())
+ return BranchProbability::getOne();
+
// Each of these variables holds the original loop's probability that the
// number of iterations it will execute is some m in the specified range.
BranchProbability ProbOne = OriginalLoopProb; // 1 <= m
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
index ec2e6c1..9c8b6ef 100644
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -23,6 +23,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
@@ -109,8 +110,12 @@ void LoopVersioning::versionLoop(
// Insert the conditional branch based on the result of the memchecks.
Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
Builder.SetInsertPoint(OrigTerm);
- Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(),
- VersionedLoop->getLoopPreheader());
+ auto *BI =
+ Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(),
+ VersionedLoop->getLoopPreheader());
+ // We don't know what the probability of executing the versioned vs the
+ // unversioned variants is.
+ setExplicitlyUnknownBranchWeightsIfProfiled(*BI, DEBUG_TYPE);
OrigTerm->eraseFromParent();
// The loops merge in the original exit block. This is now dominated by the
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index cbc604e..37c048f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -778,8 +778,10 @@ private:
return false;
// Add all values from the range to the set
- for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp)
+ APInt Tmp = Span.getLower();
+ do
Vals.push_back(ConstantInt::get(I->getContext(), Tmp));
+ while (++Tmp != Span.getUpper());
UsedICmps++;
return true;
@@ -5212,8 +5214,7 @@ bool SimplifyCFGOpt::simplifyBranchOnICmpChain(BranchInst *BI,
// We don't have any info about this condition.
auto *Br = TrueWhenEqual ? Builder.CreateCondBr(ExtraCase, EdgeBB, NewBB)
: Builder.CreateCondBr(ExtraCase, NewBB, EdgeBB);
- setExplicitlyUnknownBranchWeightsIfProfiled(*Br, *NewBB->getParent(),
- DEBUG_TYPE);
+ setExplicitlyUnknownBranchWeightsIfProfiled(*Br, DEBUG_TYPE);
OldTI->eraseFromParent();
@@ -6020,6 +6021,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
const DataLayout &DL) {
Value *Cond = SI->getCondition();
KnownBits Known = computeKnownBits(Cond, DL, AC, SI);
+ SmallPtrSet<const Constant *, 4> KnownValues;
+ bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4);
// We can also eliminate cases by determining that their values are outside of
// the limited range of the condition based on how many significant (non-sign)
@@ -6039,15 +6042,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
UniqueSuccessors.push_back(Successor);
++It->second;
}
- const APInt &CaseVal = Case.getCaseValue()->getValue();
+ ConstantInt *CaseC = Case.getCaseValue();
+ const APInt &CaseVal = CaseC->getValue();
if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) ||
- (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) {
- DeadCases.push_back(Case.getCaseValue());
+ (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) ||
+ (IsKnownValuesValid && !KnownValues.contains(CaseC))) {
+ DeadCases.push_back(CaseC);
if (DTU)
--NumPerSuccessorCases[Successor];
LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal
<< " is dead.\n");
- }
+ } else if (IsKnownValuesValid)
+ KnownValues.erase(CaseC);
}
// If we can prove that the cases must cover all possible values, the
@@ -6058,33 +6064,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
const unsigned NumUnknownBits =
Known.getBitWidth() - (Known.Zero | Known.One).popcount();
assert(NumUnknownBits <= Known.getBitWidth());
- if (HasDefault && DeadCases.empty() &&
- NumUnknownBits < 64 /* avoid overflow */) {
- uint64_t AllNumCases = 1ULL << NumUnknownBits;
- if (SI->getNumCases() == AllNumCases) {
+ if (HasDefault && DeadCases.empty()) {
+ if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) {
createUnreachableSwitchDefault(SI, DTU);
return true;
}
- // When only one case value is missing, replace default with that case.
- // Eliminating the default branch will provide more opportunities for
- // optimization, such as lookup tables.
- if (SI->getNumCases() == AllNumCases - 1) {
- assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
- IntegerType *CondTy = cast<IntegerType>(Cond->getType());
- if (CondTy->getIntegerBitWidth() > 64 ||
- !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
- return false;
- uint64_t MissingCaseVal = 0;
- for (const auto &Case : SI->cases())
- MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
- auto *MissingCase =
- cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal));
- SwitchInstProfUpdateWrapper SIW(*SI);
- SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0));
- createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false);
- SIW.setSuccessorWeight(0, 0);
- return true;
+ if (NumUnknownBits < 64 /* avoid overflow */) {
+ uint64_t AllNumCases = 1ULL << NumUnknownBits;
+ if (SI->getNumCases() == AllNumCases) {
+ createUnreachableSwitchDefault(SI, DTU);
+ return true;
+ }
+ // When only one case value is missing, replace default with that case.
+ // Eliminating the default branch will provide more opportunities for
+ // optimization, such as lookup tables.
+ if (SI->getNumCases() == AllNumCases - 1) {
+ assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
+ IntegerType *CondTy = cast<IntegerType>(Cond->getType());
+ if (CondTy->getIntegerBitWidth() > 64 ||
+ !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
+ return false;
+
+ uint64_t MissingCaseVal = 0;
+ for (const auto &Case : SI->cases())
+ MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
+ auto *MissingCase = cast<ConstantInt>(
+ ConstantInt::get(Cond->getType(), MissingCaseVal));
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ SIW.addCase(MissingCase, SI->getDefaultDest(),
+ SIW.getSuccessorWeight(0));
+ createUnreachableSwitchDefault(SI, DTU,
+ /*RemoveOrigDefaultBlock*/ false);
+ SIW.setSuccessorWeight(0, 0);
+ return true;
+ }
}
}
@@ -7570,6 +7584,81 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}
+/// Tries to transform the switch when the condition is umin with a constant.
+/// In that case, the default branch can be replaced by the constant's branch.
+/// This method also removes dead cases when the simplification cannot replace
+/// the default branch.
+///
+/// For example:
+/// switch(umin(a, 3)) {
+/// case 0:
+/// case 1:
+/// case 2:
+/// case 3:
+/// case 4:
+/// // ...
+/// default:
+/// unreachable
+/// }
+///
+/// Transforms into:
+///
+/// switch(a) {
+/// case 0:
+/// case 1:
+/// case 2:
+/// default:
+/// // This is case 3
+/// }
+static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) {
+ Value *A;
+ ConstantInt *Constant;
+
+ if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant))))
+ return false;
+
+ SmallVector<DominatorTree::UpdateType> Updates;
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ BasicBlock *BB = SIW->getParent();
+
+ // Dead cases are removed even when the simplification fails.
+ // A case is dead when its value is higher than the Constant.
+ for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) {
+ if (!I->getCaseValue()->getValue().ugt(Constant->getValue())) {
+ ++I;
+ continue;
+ }
+ BasicBlock *DeadCaseBB = I->getCaseSuccessor();
+ DeadCaseBB->removePredecessor(BB);
+ Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB});
+ I = SIW->removeCase(I);
+ E = SIW->case_end();
+ }
+
+ auto Case = SI->findCaseValue(Constant);
+ // If the case value is not found, `findCaseValue` returns the default case.
+ // In this scenario, since there is no explicit `case 3:`, the simplification
+ // fails. The simplification also fails when the switch’s default destination
+ // is reachable.
+ if (!SI->defaultDestUnreachable() || Case == SI->case_default()) {
+ if (DTU)
+ DTU->applyUpdates(Updates);
+ return !Updates.empty();
+ }
+
+ BasicBlock *Unreachable = SI->getDefaultDest();
+ SIW.replaceDefaultDest(Case);
+ SIW.removeCase(Case);
+ SIW->setCondition(A);
+
+ Updates.push_back({DominatorTree::Delete, BB, Unreachable});
+
+ if (DTU)
+ DTU->applyUpdates(Updates);
+
+ return true;
+}
+
/// Tries to transform switch of powers of two to reduce switch range.
/// For example, switch like:
/// switch (C) { case 1: case 2: case 64: case 128: }
@@ -7642,19 +7731,24 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
// label. The other is those powers of 2 that don't appear in the case
// statement. We don't know the distribution of the values coming in, so
// the safest is to split 50-50 the original probability to `default`.
- uint64_t OrigDenominator = sum_of(map_range(
- Weights, [](const auto &V) { return static_cast<uint64_t>(V); }));
+ uint64_t OrigDenominator =
+ sum_of(map_range(Weights, StaticCastTo<uint64_t>));
SmallVector<uint64_t> NewWeights(2);
NewWeights[1] = Weights[0] / 2;
NewWeights[0] = OrigDenominator - NewWeights[1];
setFittedBranchWeights(*BI, NewWeights, /*IsExpected=*/false);
-
- // For the original switch, we reduce the weight of the default by the
- // amount by which the previous branch contributes to getting to default,
- // and then make sure the remaining weights have the same relative ratio
- // wrt eachother.
+ // The probability of executing the default block stays constant. It was
+ // p_d = Weights[0] / OrigDenominator
+ // we rewrite as W/D
+ // We want to find the probability of the default branch of the switch
+ // statement. Let's call it X. We have W/D = W/2D + X * (1-W/2D)
+ // i.e. the original probability is the probability we go to the default
+ // branch from the BI branch, or we take the default branch on the SI.
+ // Meaning X = W / (2D - W), or (W/2) / (D - W/2)
+ // This matches using W/2 for the default branch probability numerator and
+ // D-W/2 as the denominator.
+ Weights[0] = NewWeights[1];
uint64_t CasesDenominator = OrigDenominator - Weights[0];
- Weights[0] /= 2;
for (auto &W : drop_begin(Weights))
W = NewWeights[0] * static_cast<double>(W) / CasesDenominator;
@@ -8037,6 +8131,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (simplifyDuplicateSwitchArms(SI, DTU))
return requestResimplify();
+ if (simplifySwitchWhenUMin(SI, DTU))
+ return requestResimplify();
+
return false;
}