aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
authorFlorian Hahn <flo@fhahn.com>2024-12-16 15:51:22 +0000
committerGitHub <noreply@github.com>2024-12-16 15:51:22 +0000
commit7bfcf93527782f1ebf83880f35e981665308d89c (patch)
tree9eca7ec1d2d5990fffce3db38b10c01b3a1dd981 /llvm/lib/Analysis/ScalarEvolution.cpp
parentc53901405a309a414cb731c4b22f32eafccbbd2a (diff)
downloadllvm-7bfcf93527782f1ebf83880f35e981665308d89c.zip
llvm-7bfcf93527782f1ebf83880f35e981665308d89c.tar.gz
llvm-7bfcf93527782f1ebf83880f35e981665308d89c.tar.bz2
[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. (#118184)
A SCEVWrapPredicate A implies B, if * they have the same flag, * both steps are positive and * B's start and step are ULE/SLE (for NSUW/NSSW) than A's. See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future. PR: https://github.com/llvm/llvm-project/pull/118184
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp80
1 files changed, 58 insertions, 22 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e181339..e2c2500 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
+ if (Expr1 != Expr2 &&
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
return false;
return true;
};
@@ -14857,7 +14858,7 @@ private:
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
- return Pred && Pred->implies(P);
+ return Pred && Pred->implies(P, SE);
}
NewPreds->push_back(P);
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}
-bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
+bool SCEVComparePredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
if (!Op)
@@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+ if (!Op || setFlags(Flags, Op->Flags) != Flags)
+ return false;
+
+ if (Op->AR == AR)
+ return true;
+
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
+ Flags != SCEVWrapPredicate::IncrementNUSW)
+ return false;
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+ if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
+ return false;
+
+ // If both steps are positive, this implies N, if N's start and step are
+ // ULE/SLE (for NSUW/NSSW) than this'.
+ Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
+ Step = SE.getNoopOrZeroExtend(Step, WiderTy);
+ OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
+
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+ const SCEV *OpStart = Op->AR->getStart();
+ const SCEV *Start = AR->getStart();
+ OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
+ : SE.getNoopOrSignExtend(OpStart, WiderTy);
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
+ : SE.getNoopOrSignExtend(Start, WiderTy);
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
+ SE.isKnownPredicate(Pred, OpStart, Start);
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,10 +15047,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}
/// Union predicates don't get cached so create a dummy set ID for it.
-SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
+SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE)
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
- add(P);
+ add(P, SE);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
@@ -15026,13 +15059,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const {
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}
-bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
- return all_of(Set->Preds,
- [this](const SCEVPredicate *I) { return this->implies(I); });
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
+ return this->implies(I, SE);
+ });
return any_of(Preds,
- [N](const SCEVPredicate *I) { return I->implies(N); });
+ [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
@@ -15040,15 +15075,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
Pred->print(OS, Depth);
}
-void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
- add(Pred);
+ add(Pred, SE);
return;
}
// Only add predicate if it is not already implied by this union predicate.
- if (!implies(N))
+ if (!implies(N, SE))
Preds.push_back(N);
}
@@ -15056,7 +15091,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
- if (Preds->implies(&Pred))
+ if (Preds->implies(&Pred, SE))
return;
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
NewPreds.push_back(&Pred);
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
updateGeneration();
}
@@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
+ SE)),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}