diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils')
| -rw-r--r-- | llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp | 51 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/BuildLibCalls.cpp | 6 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/BypassSlowDivision.cpp | 32 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/CodeExtractor.cpp | 1 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp | 52 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/LoopSimplify.cpp | 60 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp | 21 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/LoopVersioning.cpp | 9 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 173 |
9 files changed, 288 insertions, 117 deletions
diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index 42b1fdf..8aa8aa2 100644 --- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -39,36 +39,36 @@ using namespace llvm; STATISTIC(NumBroken, "Number of blocks inserted"); namespace { - struct BreakCriticalEdges : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - BreakCriticalEdges() : FunctionPass(ID) { - initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); - } +struct BreakCriticalEdges : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + BreakCriticalEdges() : FunctionPass(ID) { + initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); + } - bool runOnFunction(Function &F) override { - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + bool runOnFunction(Function &F) override { + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; - auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); - auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; + auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); + auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; - auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); - auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; - unsigned N = - SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT)); - NumBroken += N; - return N > 0; - } + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); + auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; + unsigned N = SplitAllCriticalEdges( + F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT)); + NumBroken += N; + return N > 0; + } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); - // No loop canonicalization guarantees are broken by this pass. - AU.addPreservedID(LoopSimplifyID); - } - }; -} + // No loop canonicalization guarantees are broken by this pass. + AU.addPreservedID(LoopSimplifyID); + } +}; +} // namespace char BreakCriticalEdges::ID = 0; INITIALIZE_PASS(BreakCriticalEdges, "break-crit-edges", @@ -76,6 +76,7 @@ INITIALIZE_PASS(BreakCriticalEdges, "break-crit-edges", // Publicly exposed interface to pass... char &llvm::BreakCriticalEdgesID = BreakCriticalEdges::ID; + FunctionPass *llvm::createBreakCriticalEdgesPass() { return new BreakCriticalEdges(); } diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 573a781..02b73e8 100644 --- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -1283,6 +1283,12 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_ilogbl: case LibFunc_logf: case LibFunc_logl: + case LibFunc_nextafter: + case LibFunc_nextafterf: + case LibFunc_nextafterl: + case LibFunc_nexttoward: + case LibFunc_nexttowardf: + case LibFunc_nexttowardl: case LibFunc_pow: case LibFunc_powf: case LibFunc_powl: diff --git a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index 7343c79..9f6d89e 100644 --- a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -40,22 +40,22 @@ using namespace llvm; namespace { - struct QuotRemPair { - Value *Quotient; - Value *Remainder; - - QuotRemPair(Value *InQuotient, Value *InRemainder) - : Quotient(InQuotient), Remainder(InRemainder) {} - }; - - /// A quotient and remainder, plus a BB from which they logically "originate". - /// If you use Quotient or Remainder in a Phi node, you should use BB as its - /// corresponding predecessor. - struct QuotRemWithBB { - BasicBlock *BB = nullptr; - Value *Quotient = nullptr; - Value *Remainder = nullptr; - }; +struct QuotRemPair { + Value *Quotient; + Value *Remainder; + + QuotRemPair(Value *InQuotient, Value *InRemainder) + : Quotient(InQuotient), Remainder(InRemainder) {} +}; + +/// A quotient and remainder, plus a BB from which they logically "originate". +/// If you use Quotient or Remainder in a Phi node, you should use BB as its +/// corresponding predecessor. +struct QuotRemWithBB { + BasicBlock *BB = nullptr; + Value *Quotient = nullptr; + Value *Remainder = nullptr; +}; using DivCacheTy = DenseMap<DivRemMapKey, QuotRemPair>; using BypassWidthsTy = DenseMap<unsigned, unsigned>; diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 5ba6f95f..6086615 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -933,6 +933,7 @@ Function *CodeExtractor::constructFunctionDeclaration( case Attribute::CoroDestroyOnlyWhenComplete: case Attribute::CoroElideSafe: case Attribute::NoDivergenceSource: + case Attribute::NoCreateUndefOrPoison: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: diff --git a/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp b/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp index 0642d51..dd8706c 100644 --- a/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp +++ b/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp @@ -16,22 +16,62 @@ using namespace llvm; +static void mergeAttributes(LLVMContext &Ctx, const Module &M, + const DataLayout &DL, const Triple &TT, + Function *Func, FunctionType *FuncTy, + AttributeList FuncAttrs) { + AttributeList OldAttrs = Func->getAttributes(); + AttributeList NewAttrs = OldAttrs; + + { + AttrBuilder OldBuilder(Ctx, OldAttrs.getFnAttrs()); + AttrBuilder NewBuilder(Ctx, FuncAttrs.getFnAttrs()); + OldBuilder.merge(NewBuilder); + NewAttrs = NewAttrs.addFnAttributes(Ctx, OldBuilder); + } + + { + AttrBuilder OldBuilder(Ctx, OldAttrs.getRetAttrs()); + AttrBuilder NewBuilder(Ctx, FuncAttrs.getRetAttrs()); + OldBuilder.merge(NewBuilder); + NewAttrs = NewAttrs.addRetAttributes(Ctx, OldBuilder); + } + + for (unsigned I = 0, E = FuncTy->getNumParams(); I != E; ++I) { + AttrBuilder OldBuilder(Ctx, OldAttrs.getParamAttrs(I)); + AttrBuilder NewBuilder(Ctx, FuncAttrs.getParamAttrs(I)); + OldBuilder.merge(NewBuilder); + NewAttrs = NewAttrs.addParamAttributes(Ctx, I, OldBuilder); + } + + Func->setAttributes(NewAttrs); +} + PreservedAnalyses DeclareRuntimeLibcallsPass::run(Module &M, ModuleAnalysisManager &MAM) { RTLIB::RuntimeLibcallsInfo RTLCI(M.getTargetTriple()); LLVMContext &Ctx = M.getContext(); + const DataLayout &DL = M.getDataLayout(); + const Triple &TT = M.getTargetTriple(); - for (RTLIB::LibcallImpl Impl : RTLCI.getLibcallImpls()) { - if (Impl == RTLIB::Unsupported) + for (RTLIB::LibcallImpl Impl : RTLIB::libcall_impls()) { + if (!RTLCI.isAvailable(Impl)) continue; - // TODO: Declare with correct type, calling convention, and attributes. + auto [FuncTy, FuncAttrs] = RTLCI.getFunctionTy(Ctx, TT, DL, Impl); - FunctionType *FuncTy = - FunctionType::get(Type::getVoidTy(Ctx), {}, /*IsVarArgs=*/true); + // TODO: Declare with correct type, calling convention, and attributes. + if (!FuncTy) + FuncTy = FunctionType::get(Type::getVoidTy(Ctx), {}, /*IsVarArgs=*/true); StringRef FuncName = RTLCI.getLibcallImplName(Impl); - M.getOrInsertFunction(FuncName, FuncTy); + + Function *Func = + cast<Function>(M.getOrInsertFunction(FuncName, FuncTy).getCallee()); + if (Func->getFunctionType() == FuncTy) { + mergeAttributes(Ctx, M, DL, TT, Func, FuncTy, FuncAttrs); + Func->setCallingConv(RTLCI.getLibcallImplCallingConv(Impl)); + } } return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 61ffb49..8da6a980 100644 --- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -378,7 +378,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, if (P != Preheader) BackedgeBlocks.push_back(P); } - // Create and insert the new backedge block... + // Create and insert the new backedge block. BasicBlock *BEBlock = BasicBlock::Create(Header->getContext(), Header->getName() + ".backedge", F); BranchInst *BETerminator = BranchInst::Create(Header, BEBlock); @@ -737,39 +737,39 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, } namespace { - struct LoopSimplify : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - LoopSimplify() : FunctionPass(ID) { - initializeLoopSimplifyPass(*PassRegistry::getPassRegistry()); - } +struct LoopSimplify : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + LoopSimplify() : FunctionPass(ID) { + initializeLoopSimplifyPass(*PassRegistry::getPassRegistry()); + } - bool runOnFunction(Function &F) override; + bool runOnFunction(Function &F) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); - // We need loop information to identify the loops... - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); + // We need loop information to identify the loops. + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<SCEVAAWrapperPass>(); - AU.addPreservedID(LCSSAID); - AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. - AU.addPreserved<BranchProbabilityInfoWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addPreserved<BasicAAWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addPreserved<SCEVAAWrapperPass>(); + AU.addPreservedID(LCSSAID); + AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. + AU.addPreserved<BranchProbabilityInfoWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } - /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees. - void verifyAnalysis() const override; - }; -} + /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees. + void verifyAnalysis() const override; +}; +} // namespace char LoopSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", @@ -780,12 +780,12 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", "Canonicalize natural loops", false, false) -// Publicly exposed interface to pass... +// Publicly exposed interface to pass. char &llvm::LoopSimplifyID = LoopSimplify::ID; Pass *llvm::createLoopSimplifyPass() { return new LoopSimplify(); } /// runOnFunction - Run down all loops in the CFG (recursively, but we could do -/// it in any convenient order) inserting preheaders... +/// it in any convenient order) inserting preheaders. /// bool LoopSimplify::runOnFunction(Function &F) { bool Changed = false; diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 1e8f6cc..6c9467b 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -202,6 +202,27 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, /// probability of executing at least one more iteration? static BranchProbability probOfNextInRemainder(BranchProbability OriginalLoopProb, unsigned N) { + // OriginalLoopProb == 1 would produce a division by zero in the calculation + // below. The problem is that case indicates an always infinite loop, but a + // remainder loop cannot be calculated at run time if the original loop is + // infinite as infinity % UnrollCount is undefined. We then choose + // probabilities indicating that all remainder loop iterations will always + // execute. + // + // Currently, the remainder loop here is an epilogue, which cannot be reached + // if the original loop is infinite, so the aforementioned choice is + // arbitrary. + // + // FIXME: Branch weights still need to be fixed in the case of prologues + // (issue #135812). In that case, the aforementioned choice seems reasonable + // for the goal of maintaining the original loop's block frequencies. That + // is, an infinite loop's initial iterations are not skipped, and the prologue + // loop body might have unique blocks that execute a finite number of times + // if, for example, the original loop body contains conditionals like i < + // UnrollCount. + if (OriginalLoopProb == BranchProbability::getOne()) + return BranchProbability::getOne(); + // Each of these variables holds the original loop's probability that the // number of iterations it will execute is some m in the specified range. BranchProbability ProbOne = OriginalLoopProb; // 1 <= m diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index ec2e6c1..9c8b6ef 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -109,8 +110,12 @@ void LoopVersioning::versionLoop( // Insert the conditional branch based on the result of the memchecks. Instruction *OrigTerm = RuntimeCheckBB->getTerminator(); Builder.SetInsertPoint(OrigTerm); - Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(), - VersionedLoop->getLoopPreheader()); + auto *BI = + Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(), + VersionedLoop->getLoopPreheader()); + // We don't know what the probability of executing the versioned vs the + // unversioned variants is. + setExplicitlyUnknownBranchWeightsIfProfiled(*BI, DEBUG_TYPE); OrigTerm->eraseFromParent(); // The loops merge in the original exit block. This is now dominated by the diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index cbc604e..37c048f 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -778,8 +778,10 @@ private: return false; // Add all values from the range to the set - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) + APInt Tmp = Span.getLower(); + do Vals.push_back(ConstantInt::get(I->getContext(), Tmp)); + while (++Tmp != Span.getUpper()); UsedICmps++; return true; @@ -5212,8 +5214,7 @@ bool SimplifyCFGOpt::simplifyBranchOnICmpChain(BranchInst *BI, // We don't have any info about this condition. auto *Br = TrueWhenEqual ? Builder.CreateCondBr(ExtraCase, EdgeBB, NewBB) : Builder.CreateCondBr(ExtraCase, NewBB, EdgeBB); - setExplicitlyUnknownBranchWeightsIfProfiled(*Br, *NewBB->getParent(), - DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*Br, DEBUG_TYPE); OldTI->eraseFromParent(); @@ -6020,6 +6021,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const DataLayout &DL) { Value *Cond = SI->getCondition(); KnownBits Known = computeKnownBits(Cond, DL, AC, SI); + SmallPtrSet<const Constant *, 4> KnownValues; + bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4); // We can also eliminate cases by determining that their values are outside of // the limited range of the condition based on how many significant (non-sign) @@ -6039,15 +6042,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, UniqueSuccessors.push_back(Successor); ++It->second; } - const APInt &CaseVal = Case.getCaseValue()->getValue(); + ConstantInt *CaseC = Case.getCaseValue(); + const APInt &CaseVal = CaseC->getValue(); if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || - (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) { - DeadCases.push_back(Case.getCaseValue()); + (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) || + (IsKnownValuesValid && !KnownValues.contains(CaseC))) { + DeadCases.push_back(CaseC); if (DTU) --NumPerSuccessorCases[Successor]; LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); - } + } else if (IsKnownValuesValid) + KnownValues.erase(CaseC); } // If we can prove that the cases must cover all possible values, the @@ -6058,33 +6064,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const unsigned NumUnknownBits = Known.getBitWidth() - (Known.Zero | Known.One).popcount(); assert(NumUnknownBits <= Known.getBitWidth()); - if (HasDefault && DeadCases.empty() && - NumUnknownBits < 64 /* avoid overflow */) { - uint64_t AllNumCases = 1ULL << NumUnknownBits; - if (SI->getNumCases() == AllNumCases) { + if (HasDefault && DeadCases.empty()) { + if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) { createUnreachableSwitchDefault(SI, DTU); return true; } - // When only one case value is missing, replace default with that case. - // Eliminating the default branch will provide more opportunities for - // optimization, such as lookup tables. - if (SI->getNumCases() == AllNumCases - 1) { - assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); - IntegerType *CondTy = cast<IntegerType>(Cond->getType()); - if (CondTy->getIntegerBitWidth() > 64 || - !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) - return false; - uint64_t MissingCaseVal = 0; - for (const auto &Case : SI->cases()) - MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); - auto *MissingCase = - cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal)); - SwitchInstProfUpdateWrapper SIW(*SI); - SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0)); - createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false); - SIW.setSuccessorWeight(0, 0); - return true; + if (NumUnknownBits < 64 /* avoid overflow */) { + uint64_t AllNumCases = 1ULL << NumUnknownBits; + if (SI->getNumCases() == AllNumCases) { + createUnreachableSwitchDefault(SI, DTU); + return true; + } + // When only one case value is missing, replace default with that case. + // Eliminating the default branch will provide more opportunities for + // optimization, such as lookup tables. + if (SI->getNumCases() == AllNumCases - 1) { + assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); + IntegerType *CondTy = cast<IntegerType>(Cond->getType()); + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + + uint64_t MissingCaseVal = 0; + for (const auto &Case : SI->cases()) + MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); + auto *MissingCase = cast<ConstantInt>( + ConstantInt::get(Cond->getType(), MissingCaseVal)); + SwitchInstProfUpdateWrapper SIW(*SI); + SIW.addCase(MissingCase, SI->getDefaultDest(), + SIW.getSuccessorWeight(0)); + createUnreachableSwitchDefault(SI, DTU, + /*RemoveOrigDefaultBlock*/ false); + SIW.setSuccessorWeight(0, 0); + return true; + } } } @@ -7570,6 +7584,81 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Tries to transform the switch when the condition is umin with a constant. +/// In that case, the default branch can be replaced by the constant's branch. +/// This method also removes dead cases when the simplification cannot replace +/// the default branch. +/// +/// For example: +/// switch(umin(a, 3)) { +/// case 0: +/// case 1: +/// case 2: +/// case 3: +/// case 4: +/// // ... +/// default: +/// unreachable +/// } +/// +/// Transforms into: +/// +/// switch(a) { +/// case 0: +/// case 1: +/// case 2: +/// default: +/// // This is case 3 +/// } +static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) { + Value *A; + ConstantInt *Constant; + + if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant)))) + return false; + + SmallVector<DominatorTree::UpdateType> Updates; + SwitchInstProfUpdateWrapper SIW(*SI); + BasicBlock *BB = SIW->getParent(); + + // Dead cases are removed even when the simplification fails. + // A case is dead when its value is higher than the Constant. + for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) { + if (!I->getCaseValue()->getValue().ugt(Constant->getValue())) { + ++I; + continue; + } + BasicBlock *DeadCaseBB = I->getCaseSuccessor(); + DeadCaseBB->removePredecessor(BB); + Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); + I = SIW->removeCase(I); + E = SIW->case_end(); + } + + auto Case = SI->findCaseValue(Constant); + // If the case value is not found, `findCaseValue` returns the default case. + // In this scenario, since there is no explicit `case 3:`, the simplification + // fails. The simplification also fails when the switch’s default destination + // is reachable. + if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { + if (DTU) + DTU->applyUpdates(Updates); + return !Updates.empty(); + } + + BasicBlock *Unreachable = SI->getDefaultDest(); + SIW.replaceDefaultDest(Case); + SIW.removeCase(Case); + SIW->setCondition(A); + + Updates.push_back({DominatorTree::Delete, BB, Unreachable}); + + if (DTU) + DTU->applyUpdates(Updates); + + return true; +} + /// Tries to transform switch of powers of two to reduce switch range. /// For example, switch like: /// switch (C) { case 1: case 2: case 64: case 128: } @@ -7642,19 +7731,24 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder, // label. The other is those powers of 2 that don't appear in the case // statement. We don't know the distribution of the values coming in, so // the safest is to split 50-50 the original probability to `default`. - uint64_t OrigDenominator = sum_of(map_range( - Weights, [](const auto &V) { return static_cast<uint64_t>(V); })); + uint64_t OrigDenominator = + sum_of(map_range(Weights, StaticCastTo<uint64_t>)); SmallVector<uint64_t> NewWeights(2); NewWeights[1] = Weights[0] / 2; NewWeights[0] = OrigDenominator - NewWeights[1]; setFittedBranchWeights(*BI, NewWeights, /*IsExpected=*/false); - - // For the original switch, we reduce the weight of the default by the - // amount by which the previous branch contributes to getting to default, - // and then make sure the remaining weights have the same relative ratio - // wrt eachother. + // The probability of executing the default block stays constant. It was + // p_d = Weights[0] / OrigDenominator + // we rewrite as W/D + // We want to find the probability of the default branch of the switch + // statement. Let's call it X. We have W/D = W/2D + X * (1-W/2D) + // i.e. the original probability is the probability we go to the default + // branch from the BI branch, or we take the default branch on the SI. + // Meaning X = W / (2D - W), or (W/2) / (D - W/2) + // This matches using W/2 for the default branch probability numerator and + // D-W/2 as the denominator. + Weights[0] = NewWeights[1]; uint64_t CasesDenominator = OrigDenominator - Weights[0]; - Weights[0] /= 2; for (auto &W : drop_begin(Weights)) W = NewWeights[0] * static_cast<double>(W) / CasesDenominator; @@ -8037,6 +8131,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (simplifyDuplicateSwitchArms(SI, DTU)) return requestResimplify(); + if (simplifySwitchWhenUMin(SI, DTU)) + return requestResimplify(); + return false; } |
