aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp193
1 files changed, 96 insertions, 97 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a64b93d..425420f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
// = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
// = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
//
- if (SM->getNumOperands() == 2)
- if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
- if (MulLHS->getAPInt().isPowerOf2())
- if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
- int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
- MulLHS->getAPInt().logBase2();
- Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
- return getMulExpr(
- getZeroExtendExpr(MulLHS, Ty),
- getZeroExtendExpr(
- getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
- SCEV::FlagNUW, Depth + 1);
- }
+ const APInt *C;
+ const SCEV *TruncRHS;
+ if (match(SM,
+ m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
+ C->isPowerOf2()) {
+ int NewTruncBits =
+ getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
+ Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
+ return getMulExpr(
+ getZeroExtendExpr(SM->getOperand(0), Ty),
+ getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
+ SCEV::FlagNUW, Depth + 1);
+ }
}
// zext(umin(x, y)) -> umin(zext(x), zext(y))
@@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
if (Ops.size() == 2) {
// C1*(C2+V) -> C1*C2 + C1*V
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
- // If any of Add's ops are Adds or Muls with a constant, apply this
- // transformation as well.
- //
- // TODO: There are some cases where this transformation is not
- // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
- // this transformation should be narrowed down.
- if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
- const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
- SCEV::FlagAnyWrap, Depth + 1);
- const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
- SCEV::FlagAnyWrap, Depth + 1);
- return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
- }
+ // If any of Add's ops are Adds or Muls with a constant, apply this
+ // transformation as well.
+ //
+ // TODO: There are some cases where this transformation is not
+ // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
+ // this transformation should be narrowed down.
+ const SCEV *Op0, *Op1;
+ if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
+ containsConstantInAddMulChain(Ops[1])) {
+ const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
+ const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
+ return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
+ }
if (Ops[0]->isAllOnesValue()) {
// If we have a mul by -1 of an add, try distributing the -1 among the
@@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
}
// ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
- if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
- AE && AE->getNumOperands() == 2) {
- if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
- const APInt &NegC = VC->getAPInt();
- if (NegC.isNegative() && !NegC.isMinSignedValue()) {
- const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
- if (MME && MME->getNumOperands() == 2 &&
- isa<SCEVConstant>(MME->getOperand(0)) &&
- cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
- MME->getOperand(1) == RHS)
- return getZero(LHS->getType());
- }
- }
- }
+ const APInt *NegC, *C;
+ if (match(LHS,
+ m_scev_Add(m_scev_APInt(NegC),
+ m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) &&
+ NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
+ return getZero(LHS->getType());
// TODO: Generalize to handle any common factors.
// udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
@@ -4623,17 +4614,11 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
/// If Expr computes ~A, return A else return nullptr
static const SCEV *MatchNotExpr(const SCEV *Expr) {
- const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
- if (!Add || Add->getNumOperands() != 2 ||
- !Add->getOperand(0)->isAllOnesValue())
- return nullptr;
-
- const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
- if (!AddRHS || AddRHS->getNumOperands() != 2 ||
- !AddRHS->getOperand(0)->isAllOnesValue())
- return nullptr;
-
- return AddRHS->getOperand(1);
+ const SCEV *MulOp;
+ if (match(Expr, m_scev_Add(m_scev_AllOnes(),
+ m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
+ return MulOp;
+ return nullptr;
}
/// Return a SCEV corresponding to ~V = -1-V
@@ -10797,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
}
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
- const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
- if (!Add || Add->getNumOperands() != 2)
+ const SCEV *Op0, *Op1;
+ if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
return false;
- if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
- ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
- LHS = Add->getOperand(1);
- RHS = ME->getOperand(1);
+ if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
+ LHS = Op1;
return true;
}
- if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
- ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
- LHS = Add->getOperand(0);
- RHS = ME->getOperand(1);
+ if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
+ LHS = Op0;
return true;
}
return false;
@@ -12172,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
const SCEV *&L, const SCEV *&R,
SCEV::NoWrapFlags &Flags) {
- const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
- if (!AE || AE->getNumOperands() != 2)
+ if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
return false;
- L = AE->getOperand(0);
- R = AE->getOperand(1);
- Flags = AE->getNoWrapFlags();
+ Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
return true;
}
@@ -12220,12 +12198,11 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// Try to match a common constant multiply.
auto MatchConstMul =
[](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
- auto *M = dyn_cast<SCEVMulExpr>(S);
- if (!M || M->getNumOperands() != 2 ||
- !isa<SCEVConstant>(M->getOperand(0)))
- return std::nullopt;
- return {
- {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
+ const APInt *C;
+ const SCEV *Op;
+ if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
+ return {{Op, *C}};
+ return std::nullopt;
};
if (auto MatchedMore = MatchConstMul(More)) {
if (auto MatchedLess = MatchConstMul(Less)) {
@@ -15557,19 +15534,10 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
auto IsMinMaxSCEVWithNonNegativeConstant =
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
const SCEV *&RHS) {
- if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
- if (MinMax->getNumOperands() != 2)
- return false;
- if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
- if (C->getAPInt().isNegative())
- return false;
- SCTy = MinMax->getSCEVType();
- LHS = MinMax->getOperand(0);
- RHS = MinMax->getOperand(1);
- return true;
- }
- }
- return false;
+ const APInt *C;
+ SCTy = Expr->getSCEVType();
+ return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
+ match(LHS, m_scev_APInt(C)) && C->isNonNegative();
};
// Return a new SCEV that modifies \p Expr to the closest number divides by
@@ -15772,19 +15740,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
GetNextSCEVDividesByDivisor(One, DividesBy);
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
} else {
+ // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
+ // but creating the subtraction eagerly is expensive. Track the
+ // inequalities in a separate map, and materialize the rewrite lazily
+ // when encountering a suitable subtraction while re-writing.
if (LHS->getType()->isPointerTy()) {
LHS = SE.getLosslessPtrToIntExpr(LHS);
RHS = SE.getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
break;
}
- auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
- const SCEV *Sub = SE.getMinusSCEV(A, B);
- AddRewrite(Sub, Sub,
- SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
- };
- AddSubRewrite(LHS, RHS);
- AddSubRewrite(RHS, LHS);
+ const SCEVConstant *C;
+ const SCEV *A, *B;
+ if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
+ match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
+ RHS = A;
+ LHS = B;
+ }
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ Guards.NotEqual.insert({LHS, RHS});
continue;
}
break;
@@ -15918,13 +15893,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
class SCEVLoopGuardRewriter
: public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
+ const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
const ScalarEvolution::LoopGuards &Guards)
- : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+ : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
+ NotEqual(Guards.NotEqual) {
if (Guards.PreserveNUW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
if (Guards.PreserveNSW)
@@ -15979,14 +15956,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+ // Helper to check if S is a subtraction (A - B) where A != B, and if so,
+ // return UMax(S, 1).
+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
+ const SCEV *LHS, *RHS;
+ if (MatchBinarySub(S, LHS, RHS)) {
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ if (NotEqual.contains({LHS, RHS}))
+ return SE.getUMaxExpr(S, SE.getOne(S->getType()));
+ }
+ return nullptr;
+ };
+
+ // Check if Expr itself is a subtraction pattern with guard info.
+ if (const SCEV *Rewritten = RewriteSubtraction(Expr))
+ return Rewritten;
+
// Trip count expressions sometimes consist of adding 3 operands, i.e.
// (Const + A + B). There may be guard info for A + B, and if so, apply
// it.
// TODO: Could more generally apply guards to Add sub-expressions.
if (isa<SCEVConstant>(Expr->getOperand(0)) &&
Expr->getNumOperands() == 3) {
- if (const SCEV *S = Map.lookup(
- SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
+ const SCEV *Add =
+ SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
+ if (const SCEV *Rewritten = RewriteSubtraction(Add))
+ return SE.getAddExpr(
+ Expr->getOperand(0), Rewritten,
+ ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
+ if (const SCEV *S = Map.lookup(Add))
return SE.getAddExpr(Expr->getOperand(0), S);
}
SmallVector<const SCEV *, 2> Operands;
@@ -16021,7 +16020,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
};
- if (RewriteMap.empty())
+ if (RewriteMap.empty() && NotEqual.empty())
return Expr;
SCEVLoopGuardRewriter Rewriter(SE, *this);