aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp133
1 files changed, 85 insertions, 48 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e06b095..6f7dd79 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15473,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
}
}
+// Return a new SCEV that modifies \p Expr to the closest number divides by
+// \p Divisor and less or equal than Expr. For now, only handle constant
+// Expr.
+static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
+ const APInt &DivisorVal,
+ ScalarEvolution &SE) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
+ return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ // return the SCEV: Expr - Expr % Divisor
+ return SE.getConstant(*ExprVal - Rem);
+}
+
+// Return a new SCEV that modifies \p Expr to the closest number divides by
+// \p Divisor and greater or equal than Expr. For now, only handle constant
+// Expr.
+static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
+ const APInt &DivisorVal,
+ ScalarEvolution &SE) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
+ return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ if (Rem.isZero())
+ return Expr;
+ // return the SCEV: Expr + Divisor - Expr % Divisor
+ return SE.getConstant(*ExprVal + DivisorVal - Rem);
+}
+
void ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
@@ -15540,36 +15572,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
match(LHS, m_scev_APInt(C)) && C->isNonNegative();
};
- // Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and greater or equal than Expr. For now, only handle constant
- // Expr.
- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
- const APInt &DivisorVal) {
- const APInt *ExprVal;
- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
- DivisorVal.isNonPositive())
- return Expr;
- APInt Rem = ExprVal->urem(DivisorVal);
- if (Rem.isZero())
- return Expr;
- // return the SCEV: Expr + Divisor - Expr % Divisor
- return SE.getConstant(*ExprVal + DivisorVal - Rem);
- };
-
- // Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and less or equal than Expr. For now, only handle constant
- // Expr.
- auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
- const APInt &DivisorVal) {
- const APInt *ExprVal;
- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
- DivisorVal.isNonPositive())
- return Expr;
- APInt Rem = ExprVal->urem(DivisorVal);
- // return the SCEV: Expr - Expr % Divisor
- return SE.getConstant(*ExprVal - Rem);
- };
-
// Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
// recursively. This is done by aligning up/down the constant value to the
// Divisor.
@@ -15591,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
assert(SE.isKnownNonNegative(MinMaxLHS) &&
"Expected non-negative operand!");
auto *DivisibleExpr =
- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
- : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal);
+ IsMin
+ ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
+ : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
SmallVector<const SCEV *> Ops = {
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
return SE.getMinMaxExpr(SCTy, Ops);
@@ -15669,21 +15672,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
[[fallthrough]];
case CmpInst::ICMP_SLT: {
RHS = SE.getMinusSCEV(RHS, One);
- RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
}
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_SGT:
RHS = SE.getAddExpr(RHS, One);
- RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
case CmpInst::ICMP_ULE:
case CmpInst::ICMP_SLE:
- RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
case CmpInst::ICMP_UGE:
case CmpInst::ICMP_SGE:
- RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
default:
break;
@@ -15737,22 +15740,29 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
case CmpInst::ICMP_NE:
if (match(RHS, m_scev_Zero())) {
const SCEV *OneAlignedUp =
- GetNextSCEVDividesByDivisor(One, DividesBy);
+ getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
} else {
+ // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
+ // but creating the subtraction eagerly is expensive. Track the
+ // inequalities in a separate map, and materialize the rewrite lazily
+ // when encountering a suitable subtraction while re-writing.
if (LHS->getType()->isPointerTy()) {
LHS = SE.getLosslessPtrToIntExpr(LHS);
RHS = SE.getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
break;
}
- auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
- const SCEV *Sub = SE.getMinusSCEV(A, B);
- AddRewrite(Sub, Sub,
- SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
- };
- AddSubRewrite(LHS, RHS);
- AddSubRewrite(RHS, LHS);
+ const SCEVConstant *C;
+ const SCEV *A, *B;
+ if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
+ match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
+ RHS = A;
+ LHS = B;
+ }
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ Guards.NotEqual.insert({LHS, RHS});
continue;
}
break;
@@ -15886,13 +15896,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
class SCEVLoopGuardRewriter
: public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
+ const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
const ScalarEvolution::LoopGuards &Guards)
- : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+ : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
+ NotEqual(Guards.NotEqual) {
if (Guards.PreserveNUW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
if (Guards.PreserveNSW)
@@ -15947,14 +15959,39 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+ // Helper to check if S is a subtraction (A - B) where A != B, and if so,
+ // return UMax(S, 1).
+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
+ const SCEV *LHS, *RHS;
+ if (MatchBinarySub(S, LHS, RHS)) {
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ if (NotEqual.contains({LHS, RHS})) {
+ const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
+ SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
+ return SE.getUMaxExpr(OneAlignedUp, S);
+ }
+ }
+ return nullptr;
+ };
+
+ // Check if Expr itself is a subtraction pattern with guard info.
+ if (const SCEV *Rewritten = RewriteSubtraction(Expr))
+ return Rewritten;
+
// Trip count expressions sometimes consist of adding 3 operands, i.e.
// (Const + A + B). There may be guard info for A + B, and if so, apply
// it.
// TODO: Could more generally apply guards to Add sub-expressions.
if (isa<SCEVConstant>(Expr->getOperand(0)) &&
Expr->getNumOperands() == 3) {
- if (const SCEV *S = Map.lookup(
- SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
+ const SCEV *Add =
+ SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
+ if (const SCEV *Rewritten = RewriteSubtraction(Add))
+ return SE.getAddExpr(
+ Expr->getOperand(0), Rewritten,
+ ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
+ if (const SCEV *S = Map.lookup(Add))
return SE.getAddExpr(Expr->getOperand(0), S);
}
SmallVector<const SCEV *, 2> Operands;
@@ -15989,7 +16026,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
};
- if (RewriteMap.empty())
+ if (RewriteMap.empty() && NotEqual.empty())
return Expr;
SCEVLoopGuardRewriter Rewriter(SE, *this);