aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp80
1 files changed, 58 insertions, 22 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e181339..e2c2500 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
+ if (Expr1 != Expr2 &&
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
return false;
return true;
};
@@ -14857,7 +14858,7 @@ private:
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
- return Pred && Pred->implies(P);
+ return Pred && Pred->implies(P, SE);
}
NewPreds->push_back(P);
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}
-bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
+bool SCEVComparePredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
if (!Op)
@@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+ if (!Op || setFlags(Flags, Op->Flags) != Flags)
+ return false;
+
+ if (Op->AR == AR)
+ return true;
+
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
+ Flags != SCEVWrapPredicate::IncrementNUSW)
+ return false;
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+ if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
+ return false;
+
+ // If both steps are positive, this implies N, if N's start and step are
+ // ULE/SLE (for NSUW/NSSW) than this'.
+ Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
+ Step = SE.getNoopOrZeroExtend(Step, WiderTy);
+ OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
+
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+ const SCEV *OpStart = Op->AR->getStart();
+ const SCEV *Start = AR->getStart();
+ OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
+ : SE.getNoopOrSignExtend(OpStart, WiderTy);
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
+ : SE.getNoopOrSignExtend(Start, WiderTy);
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
+ SE.isKnownPredicate(Pred, OpStart, Start);
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,10 +15047,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}
/// Union predicates don't get cached so create a dummy set ID for it.
-SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
+SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE)
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
- add(P);
+ add(P, SE);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
@@ -15026,13 +15059,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const {
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}
-bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
- return all_of(Set->Preds,
- [this](const SCEVPredicate *I) { return this->implies(I); });
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
+ return this->implies(I, SE);
+ });
return any_of(Preds,
- [N](const SCEVPredicate *I) { return I->implies(N); });
+ [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
@@ -15040,15 +15075,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
Pred->print(OS, Depth);
}
-void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
- add(Pred);
+ add(Pred, SE);
return;
}
// Only add predicate if it is not already implied by this union predicate.
- if (!implies(N))
+ if (!implies(N, SE))
Preds.push_back(N);
}
@@ -15056,7 +15091,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
- if (Preds->implies(&Pred))
+ if (Preds->implies(&Pred, SE))
return;
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
NewPreds.push_back(&Pred);
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
updateGeneration();
}
@@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
+ SE)),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}