diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/SimplifyCFG.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 173 |
1 files changed, 135 insertions, 38 deletions
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index cbc604e..37c048f 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -778,8 +778,10 @@ private: return false; // Add all values from the range to the set - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) + APInt Tmp = Span.getLower(); + do Vals.push_back(ConstantInt::get(I->getContext(), Tmp)); + while (++Tmp != Span.getUpper()); UsedICmps++; return true; @@ -5212,8 +5214,7 @@ bool SimplifyCFGOpt::simplifyBranchOnICmpChain(BranchInst *BI, // We don't have any info about this condition. auto *Br = TrueWhenEqual ? Builder.CreateCondBr(ExtraCase, EdgeBB, NewBB) : Builder.CreateCondBr(ExtraCase, NewBB, EdgeBB); - setExplicitlyUnknownBranchWeightsIfProfiled(*Br, *NewBB->getParent(), - DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*Br, DEBUG_TYPE); OldTI->eraseFromParent(); @@ -6020,6 +6021,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const DataLayout &DL) { Value *Cond = SI->getCondition(); KnownBits Known = computeKnownBits(Cond, DL, AC, SI); + SmallPtrSet<const Constant *, 4> KnownValues; + bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4); // We can also eliminate cases by determining that their values are outside of // the limited range of the condition based on how many significant (non-sign) @@ -6039,15 +6042,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, UniqueSuccessors.push_back(Successor); ++It->second; } - const APInt &CaseVal = Case.getCaseValue()->getValue(); + ConstantInt *CaseC = Case.getCaseValue(); + const APInt &CaseVal = CaseC->getValue(); if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || - (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) { - DeadCases.push_back(Case.getCaseValue()); + (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) || + (IsKnownValuesValid && !KnownValues.contains(CaseC))) { + DeadCases.push_back(CaseC); if (DTU) --NumPerSuccessorCases[Successor]; LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); - } + } else if (IsKnownValuesValid) + KnownValues.erase(CaseC); } // If we can prove that the cases must cover all possible values, the @@ -6058,33 +6064,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const unsigned NumUnknownBits = Known.getBitWidth() - (Known.Zero | Known.One).popcount(); assert(NumUnknownBits <= Known.getBitWidth()); - if (HasDefault && DeadCases.empty() && - NumUnknownBits < 64 /* avoid overflow */) { - uint64_t AllNumCases = 1ULL << NumUnknownBits; - if (SI->getNumCases() == AllNumCases) { + if (HasDefault && DeadCases.empty()) { + if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) { createUnreachableSwitchDefault(SI, DTU); return true; } - // When only one case value is missing, replace default with that case. - // Eliminating the default branch will provide more opportunities for - // optimization, such as lookup tables. - if (SI->getNumCases() == AllNumCases - 1) { - assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); - IntegerType *CondTy = cast<IntegerType>(Cond->getType()); - if (CondTy->getIntegerBitWidth() > 64 || - !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) - return false; - uint64_t MissingCaseVal = 0; - for (const auto &Case : SI->cases()) - MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); - auto *MissingCase = - cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal)); - SwitchInstProfUpdateWrapper SIW(*SI); - SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0)); - createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false); - SIW.setSuccessorWeight(0, 0); - return true; + if (NumUnknownBits < 64 /* avoid overflow */) { + uint64_t AllNumCases = 1ULL << NumUnknownBits; + if (SI->getNumCases() == AllNumCases) { + createUnreachableSwitchDefault(SI, DTU); + return true; + } + // When only one case value is missing, replace default with that case. + // Eliminating the default branch will provide more opportunities for + // optimization, such as lookup tables. + if (SI->getNumCases() == AllNumCases - 1) { + assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); + IntegerType *CondTy = cast<IntegerType>(Cond->getType()); + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + + uint64_t MissingCaseVal = 0; + for (const auto &Case : SI->cases()) + MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); + auto *MissingCase = cast<ConstantInt>( + ConstantInt::get(Cond->getType(), MissingCaseVal)); + SwitchInstProfUpdateWrapper SIW(*SI); + SIW.addCase(MissingCase, SI->getDefaultDest(), + SIW.getSuccessorWeight(0)); + createUnreachableSwitchDefault(SI, DTU, + /*RemoveOrigDefaultBlock*/ false); + SIW.setSuccessorWeight(0, 0); + return true; + } } } @@ -7570,6 +7584,81 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Tries to transform the switch when the condition is umin with a constant. +/// In that case, the default branch can be replaced by the constant's branch. +/// This method also removes dead cases when the simplification cannot replace +/// the default branch. +/// +/// For example: +/// switch(umin(a, 3)) { +/// case 0: +/// case 1: +/// case 2: +/// case 3: +/// case 4: +/// // ... +/// default: +/// unreachable +/// } +/// +/// Transforms into: +/// +/// switch(a) { +/// case 0: +/// case 1: +/// case 2: +/// default: +/// // This is case 3 +/// } +static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) { + Value *A; + ConstantInt *Constant; + + if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant)))) + return false; + + SmallVector<DominatorTree::UpdateType> Updates; + SwitchInstProfUpdateWrapper SIW(*SI); + BasicBlock *BB = SIW->getParent(); + + // Dead cases are removed even when the simplification fails. + // A case is dead when its value is higher than the Constant. + for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) { + if (!I->getCaseValue()->getValue().ugt(Constant->getValue())) { + ++I; + continue; + } + BasicBlock *DeadCaseBB = I->getCaseSuccessor(); + DeadCaseBB->removePredecessor(BB); + Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); + I = SIW->removeCase(I); + E = SIW->case_end(); + } + + auto Case = SI->findCaseValue(Constant); + // If the case value is not found, `findCaseValue` returns the default case. + // In this scenario, since there is no explicit `case 3:`, the simplification + // fails. The simplification also fails when the switch’s default destination + // is reachable. + if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { + if (DTU) + DTU->applyUpdates(Updates); + return !Updates.empty(); + } + + BasicBlock *Unreachable = SI->getDefaultDest(); + SIW.replaceDefaultDest(Case); + SIW.removeCase(Case); + SIW->setCondition(A); + + Updates.push_back({DominatorTree::Delete, BB, Unreachable}); + + if (DTU) + DTU->applyUpdates(Updates); + + return true; +} + /// Tries to transform switch of powers of two to reduce switch range. /// For example, switch like: /// switch (C) { case 1: case 2: case 64: case 128: } @@ -7642,19 +7731,24 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder, // label. The other is those powers of 2 that don't appear in the case // statement. We don't know the distribution of the values coming in, so // the safest is to split 50-50 the original probability to `default`. - uint64_t OrigDenominator = sum_of(map_range( - Weights, [](const auto &V) { return static_cast<uint64_t>(V); })); + uint64_t OrigDenominator = + sum_of(map_range(Weights, StaticCastTo<uint64_t>)); SmallVector<uint64_t> NewWeights(2); NewWeights[1] = Weights[0] / 2; NewWeights[0] = OrigDenominator - NewWeights[1]; setFittedBranchWeights(*BI, NewWeights, /*IsExpected=*/false); - - // For the original switch, we reduce the weight of the default by the - // amount by which the previous branch contributes to getting to default, - // and then make sure the remaining weights have the same relative ratio - // wrt eachother. + // The probability of executing the default block stays constant. It was + // p_d = Weights[0] / OrigDenominator + // we rewrite as W/D + // We want to find the probability of the default branch of the switch + // statement. Let's call it X. We have W/D = W/2D + X * (1-W/2D) + // i.e. the original probability is the probability we go to the default + // branch from the BI branch, or we take the default branch on the SI. + // Meaning X = W / (2D - W), or (W/2) / (D - W/2) + // This matches using W/2 for the default branch probability numerator and + // D-W/2 as the denominator. + Weights[0] = NewWeights[1]; uint64_t CasesDenominator = OrigDenominator - Weights[0]; - Weights[0] /= 2; for (auto &W : drop_begin(Weights)) W = NewWeights[0] * static_cast<double>(W) / CasesDenominator; @@ -8037,6 +8131,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (simplifyDuplicateSwitchArms(SI, DTU)) return requestResimplify(); + if (simplifySwitchWhenUMin(SI, DTU)) + return requestResimplify(); + return false; } |
