diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/SimplifyCFG.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 149 |
1 files changed, 121 insertions, 28 deletions
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index cbc604e..3a3e3ad 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -778,8 +778,10 @@ private: return false; // Add all values from the range to the set - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) + APInt Tmp = Span.getLower(); + do Vals.push_back(ConstantInt::get(I->getContext(), Tmp)); + while (++Tmp != Span.getUpper()); UsedICmps++; return true; @@ -6020,6 +6022,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const DataLayout &DL) { Value *Cond = SI->getCondition(); KnownBits Known = computeKnownBits(Cond, DL, AC, SI); + SmallPtrSet<const Constant *, 4> KnownValues; + bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4); // We can also eliminate cases by determining that their values are outside of // the limited range of the condition based on how many significant (non-sign) @@ -6039,15 +6043,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, UniqueSuccessors.push_back(Successor); ++It->second; } - const APInt &CaseVal = Case.getCaseValue()->getValue(); + ConstantInt *CaseC = Case.getCaseValue(); + const APInt &CaseVal = CaseC->getValue(); if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || - (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) { - DeadCases.push_back(Case.getCaseValue()); + (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) || + (IsKnownValuesValid && !KnownValues.contains(CaseC))) { + DeadCases.push_back(CaseC); if (DTU) --NumPerSuccessorCases[Successor]; LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); - } + } else if (IsKnownValuesValid) + KnownValues.erase(CaseC); } // If we can prove that the cases must cover all possible values, the @@ -6058,33 +6065,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const unsigned NumUnknownBits = Known.getBitWidth() - (Known.Zero | Known.One).popcount(); assert(NumUnknownBits <= Known.getBitWidth()); - if (HasDefault && DeadCases.empty() && - NumUnknownBits < 64 /* avoid overflow */) { - uint64_t AllNumCases = 1ULL << NumUnknownBits; - if (SI->getNumCases() == AllNumCases) { + if (HasDefault && DeadCases.empty()) { + if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) { createUnreachableSwitchDefault(SI, DTU); return true; } - // When only one case value is missing, replace default with that case. - // Eliminating the default branch will provide more opportunities for - // optimization, such as lookup tables. - if (SI->getNumCases() == AllNumCases - 1) { - assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); - IntegerType *CondTy = cast<IntegerType>(Cond->getType()); - if (CondTy->getIntegerBitWidth() > 64 || - !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) - return false; - uint64_t MissingCaseVal = 0; - for (const auto &Case : SI->cases()) - MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); - auto *MissingCase = - cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal)); - SwitchInstProfUpdateWrapper SIW(*SI); - SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0)); - createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false); - SIW.setSuccessorWeight(0, 0); - return true; + if (NumUnknownBits < 64 /* avoid overflow */) { + uint64_t AllNumCases = 1ULL << NumUnknownBits; + if (SI->getNumCases() == AllNumCases) { + createUnreachableSwitchDefault(SI, DTU); + return true; + } + // When only one case value is missing, replace default with that case. + // Eliminating the default branch will provide more opportunities for + // optimization, such as lookup tables. + if (SI->getNumCases() == AllNumCases - 1) { + assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); + IntegerType *CondTy = cast<IntegerType>(Cond->getType()); + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + + uint64_t MissingCaseVal = 0; + for (const auto &Case : SI->cases()) + MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); + auto *MissingCase = cast<ConstantInt>( + ConstantInt::get(Cond->getType(), MissingCaseVal)); + SwitchInstProfUpdateWrapper SIW(*SI); + SIW.addCase(MissingCase, SI->getDefaultDest(), + SIW.getSuccessorWeight(0)); + createUnreachableSwitchDefault(SI, DTU, + /*RemoveOrigDefaultBlock*/ false); + SIW.setSuccessorWeight(0, 0); + return true; + } } } @@ -7570,6 +7585,81 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Tries to transform the switch when the condition is umin with a constant. +/// In that case, the default branch can be replaced by the constant's branch. +/// This method also removes dead cases when the simplification cannot replace +/// the default branch. +/// +/// For example: +/// switch(umin(a, 3)) { +/// case 0: +/// case 1: +/// case 2: +/// case 3: +/// case 4: +/// // ... +/// default: +/// unreachable +/// } +/// +/// Transforms into: +/// +/// switch(a) { +/// case 0: +/// case 1: +/// case 2: +/// default: +/// // This is case 3 +/// } +static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) { + Value *A; + ConstantInt *Constant; + + if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant)))) + return false; + + SmallVector<DominatorTree::UpdateType> Updates; + SwitchInstProfUpdateWrapper SIW(*SI); + BasicBlock *BB = SIW->getParent(); + + // Dead cases are removed even when the simplification fails. + // A case is dead when its value is higher than the Constant. + for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) { + if (!I->getCaseValue()->getValue().ugt(Constant->getValue())) { + ++I; + continue; + } + BasicBlock *DeadCaseBB = I->getCaseSuccessor(); + DeadCaseBB->removePredecessor(BB); + Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); + I = SIW->removeCase(I); + E = SIW->case_end(); + } + + auto Case = SI->findCaseValue(Constant); + // If the case value is not found, `findCaseValue` returns the default case. + // In this scenario, since there is no explicit `case 3:`, the simplification + // fails. The simplification also fails when the switch’s default destination + // is reachable. + if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { + if (DTU) + DTU->applyUpdates(Updates); + return !Updates.empty(); + } + + BasicBlock *Unreachable = SI->getDefaultDest(); + SIW.replaceDefaultDest(Case); + SIW.removeCase(Case); + SIW->setCondition(A); + + Updates.push_back({DominatorTree::Delete, BB, Unreachable}); + + if (DTU) + DTU->applyUpdates(Updates); + + return true; +} + /// Tries to transform switch of powers of two to reduce switch range. /// For example, switch like: /// switch (C) { case 1: case 2: case 64: case 128: } @@ -8037,6 +8127,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (simplifyDuplicateSwitchArms(SI, DTU)) return requestResimplify(); + if (simplifySwitchWhenUMin(SI, DTU)) + return requestResimplify(); + return false; } |
