diff options
Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 75 |
1 files changed, 32 insertions, 43 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 30bcff7..b5b4cd9 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15633,47 +15633,34 @@ void ScalarEvolution::LoopGuards::collectFromBlock( return false; }; - // Checks whether Expr is a non-negative constant, and Divisor is a positive - // constant, and returns their APInt in ExprVal and in DivisorVal. - auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor, - APInt &ExprVal, APInt &DivisorVal) { - auto *ConstExpr = dyn_cast<SCEVConstant>(Expr); - auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor); - if (!ConstExpr || !ConstDivisor) - return false; - ExprVal = ConstExpr->getAPInt(); - DivisorVal = ConstDivisor->getAPInt(); - return ExprVal.isNonNegative() && !DivisorVal.isNonPositive(); - }; - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and greater or equal than Expr. - // For now, only handle constant Expr and Divisor. + // \p Divisor and greater or equal than Expr. For now, only handle constant + // Expr. auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { - APInt ExprVal; - APInt DivisorVal; - if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + const APInt &DivisorVal) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) return Expr; - APInt Rem = ExprVal.urem(DivisorVal); - if (!Rem.isZero()) - // return the SCEV: Expr + Divisor - Expr % Divisor - return SE.getConstant(ExprVal + DivisorVal - Rem); - return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + if (Rem.isZero()) + return Expr; + // return the SCEV: Expr + Divisor - Expr % Divisor + return SE.getConstant(*ExprVal + DivisorVal - Rem); }; // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and less or equal than Expr. - // For now, only handle constant Expr and Divisor. + // \p Divisor and less or equal than Expr. For now, only handle constant + // Expr. auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { - APInt ExprVal; - APInt DivisorVal; - if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + const APInt &DivisorVal) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) return Expr; - APInt Rem = ExprVal.urem(DivisorVal); + APInt Rem = ExprVal->urem(DivisorVal); // return the SCEV: Expr - Expr % Divisor - return SE.getConstant(ExprVal - Rem); + return SE.getConstant(*ExprVal - Rem); }; // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, @@ -15682,6 +15669,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock( std::function<const SCEV *(const SCEV *, const SCEV *)> ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, const SCEV *Divisor) { + auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor); + if (!ConstDivisor) + return MinMaxExpr; + const APInt &DivisorVal = ConstDivisor->getAPInt(); + const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; SCEVTypes SCTy; if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, @@ -15692,8 +15684,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock( assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); auto *DivisibleExpr = - IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor) - : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor); + IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal) + : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal); SmallVector<const SCEV *> Ops = { ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; return SE.getMinMaxExpr(SCTy, Ops); @@ -15750,10 +15742,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock( }; const SCEV *RewrittenLHS = GetMaybeRewritten(LHS); - const SCEV *DividesBy = nullptr; - const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS); - if (!Multiple.isOne()) - DividesBy = SE.getConstant(Multiple); + const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS); // Collect rewrites for LHS and its transitive operands based on the // condition. @@ -15775,21 +15764,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock( [[fallthrough]]; case CmpInst::ICMP_SLT: { RHS = SE.getMinusSCEV(RHS, One); - RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); break; } case CmpInst::ICMP_UGT: case CmpInst::ICMP_SGT: RHS = SE.getAddExpr(RHS, One); - RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); break; case CmpInst::ICMP_ULE: case CmpInst::ICMP_SLE: - RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); break; case CmpInst::ICMP_UGE: case CmpInst::ICMP_SGE: - RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); break; default: break; @@ -15843,7 +15832,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock( case CmpInst::ICMP_NE: if (match(RHS, m_scev_Zero())) { const SCEV *OneAlignedUp = - DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One; + GetNextSCEVDividesByDivisor(One, DividesBy); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } break; |