aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp75
1 files changed, 32 insertions, 43 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 30bcff7..b5b4cd9 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15633,47 +15633,34 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
return false;
};
- // Checks whether Expr is a non-negative constant, and Divisor is a positive
- // constant, and returns their APInt in ExprVal and in DivisorVal.
- auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
- APInt &ExprVal, APInt &DivisorVal) {
- auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
- if (!ConstExpr || !ConstDivisor)
- return false;
- ExprVal = ConstExpr->getAPInt();
- DivisorVal = ConstDivisor->getAPInt();
- return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
- };
-
// Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and greater or equal than Expr.
- // For now, only handle constant Expr and Divisor.
+ // \p Divisor and greater or equal than Expr. For now, only handle constant
+ // Expr.
auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
- const SCEV *Divisor) {
- APInt ExprVal;
- APInt DivisorVal;
- if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
+ const APInt &DivisorVal) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
return Expr;
- APInt Rem = ExprVal.urem(DivisorVal);
- if (!Rem.isZero())
- // return the SCEV: Expr + Divisor - Expr % Divisor
- return SE.getConstant(ExprVal + DivisorVal - Rem);
- return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ if (Rem.isZero())
+ return Expr;
+ // return the SCEV: Expr + Divisor - Expr % Divisor
+ return SE.getConstant(*ExprVal + DivisorVal - Rem);
};
// Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and less or equal than Expr.
- // For now, only handle constant Expr and Divisor.
+ // \p Divisor and less or equal than Expr. For now, only handle constant
+ // Expr.
auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
- const SCEV *Divisor) {
- APInt ExprVal;
- APInt DivisorVal;
- if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
+ const APInt &DivisorVal) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
return Expr;
- APInt Rem = ExprVal.urem(DivisorVal);
+ APInt Rem = ExprVal->urem(DivisorVal);
// return the SCEV: Expr - Expr % Divisor
- return SE.getConstant(ExprVal - Rem);
+ return SE.getConstant(*ExprVal - Rem);
};
// Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
@@ -15682,6 +15669,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
std::function<const SCEV *(const SCEV *, const SCEV *)>
ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
const SCEV *Divisor) {
+ auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
+ if (!ConstDivisor)
+ return MinMaxExpr;
+ const APInt &DivisorVal = ConstDivisor->getAPInt();
+
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
SCEVTypes SCTy;
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
@@ -15692,8 +15684,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
assert(SE.isKnownNonNegative(MinMaxLHS) &&
"Expected non-negative operand!");
auto *DivisibleExpr =
- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
- : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
+ IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
+ : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal);
SmallVector<const SCEV *> Ops = {
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
return SE.getMinMaxExpr(SCTy, Ops);
@@ -15750,10 +15742,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
};
const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
- const SCEV *DividesBy = nullptr;
- const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS);
- if (!Multiple.isOne())
- DividesBy = SE.getConstant(Multiple);
+ const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
// Collect rewrites for LHS and its transitive operands based on the
// condition.
@@ -15775,21 +15764,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
[[fallthrough]];
case CmpInst::ICMP_SLT: {
RHS = SE.getMinusSCEV(RHS, One);
- RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+ RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
break;
}
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_SGT:
RHS = SE.getAddExpr(RHS, One);
- RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+ RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
break;
case CmpInst::ICMP_ULE:
case CmpInst::ICMP_SLE:
- RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+ RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
break;
case CmpInst::ICMP_UGE:
case CmpInst::ICMP_SGE:
- RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+ RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
break;
default:
break;
@@ -15843,7 +15832,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
case CmpInst::ICMP_NE:
if (match(RHS, m_scev_Zero())) {
const SCEV *OneAlignedUp =
- DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
+ GetNextSCEVDividesByDivisor(One, DividesBy);
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
}
break;