diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 84 |
1 files changed, 45 insertions, 39 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 425420f..6f7dd79 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15473,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI( } } +// Return a new SCEV that modifies \p Expr to the closest number divides by +// \p Divisor and less or equal than Expr. For now, only handle constant +// Expr. +static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, + const APInt &DivisorVal, + ScalarEvolution &SE) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) + return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return SE.getConstant(*ExprVal - Rem); +} + +// Return a new SCEV that modifies \p Expr to the closest number divides by +// \p Divisor and greater or equal than Expr. For now, only handle constant +// Expr. +static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr, + const APInt &DivisorVal, + ScalarEvolution &SE) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) + return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + if (Rem.isZero()) + return Expr; + // return the SCEV: Expr + Divisor - Expr % Divisor + return SE.getConstant(*ExprVal + DivisorVal - Rem); +} + void ScalarEvolution::LoopGuards::collectFromBlock( ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, const BasicBlock *Block, const BasicBlock *Pred, @@ -15540,36 +15572,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock( match(LHS, m_scev_APInt(C)) && C->isNonNegative(); }; - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and greater or equal than Expr. For now, only handle constant - // Expr. - auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const APInt &DivisorVal) { - const APInt *ExprVal; - if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || - DivisorVal.isNonPositive()) - return Expr; - APInt Rem = ExprVal->urem(DivisorVal); - if (Rem.isZero()) - return Expr; - // return the SCEV: Expr + Divisor - Expr % Divisor - return SE.getConstant(*ExprVal + DivisorVal - Rem); - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and less or equal than Expr. For now, only handle constant - // Expr. - auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const APInt &DivisorVal) { - const APInt *ExprVal; - if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || - DivisorVal.isNonPositive()) - return Expr; - APInt Rem = ExprVal->urem(DivisorVal); - // return the SCEV: Expr - Expr % Divisor - return SE.getConstant(*ExprVal - Rem); - }; - // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, // recursively. This is done by aligning up/down the constant value to the // Divisor. @@ -15591,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock( assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); auto *DivisibleExpr = - IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal) - : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal); + IsMin + ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE) + : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE); SmallVector<const SCEV *> Ops = { ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; return SE.getMinMaxExpr(SCTy, Ops); @@ -15669,21 +15672,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock( [[fallthrough]]; case CmpInst::ICMP_SLT: { RHS = SE.getMinusSCEV(RHS, One); - RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); + RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; } case CmpInst::ICMP_UGT: case CmpInst::ICMP_SGT: RHS = SE.getAddExpr(RHS, One); - RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); + RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; case CmpInst::ICMP_ULE: case CmpInst::ICMP_SLE: - RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); + RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; case CmpInst::ICMP_UGE: case CmpInst::ICMP_SGE: - RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); + RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; default: break; @@ -15737,7 +15740,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock( case CmpInst::ICMP_NE: if (match(RHS, m_scev_Zero())) { const SCEV *OneAlignedUp = - GetNextSCEVDividesByDivisor(One, DividesBy); + getNextSCEVDivisibleByDivisor(One, DividesBy, SE); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } else { // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS), @@ -15963,8 +15966,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { if (MatchBinarySub(S, LHS, RHS)) { if (LHS > RHS) std::swap(LHS, RHS); - if (NotEqual.contains({LHS, RHS})) - return SE.getUMaxExpr(S, SE.getOne(S->getType())); + if (NotEqual.contains({LHS, RHS})) { + const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor( + SE.getOne(S->getType()), SE.getConstantMultiple(S), SE); + return SE.getUMaxExpr(OneAlignedUp, S); + } } return nullptr; }; |