diff options
author | Florian Hahn <flo@fhahn.com> | 2024-12-16 15:51:22 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-16 15:51:22 +0000 |
commit | 7bfcf93527782f1ebf83880f35e981665308d89c (patch) | |
tree | 9eca7ec1d2d5990fffce3db38b10c01b3a1dd981 /llvm/lib/Analysis/ScalarEvolution.cpp | |
parent | c53901405a309a414cb731c4b22f32eafccbbd2a (diff) | |
download | llvm-7bfcf93527782f1ebf83880f35e981665308d89c.zip llvm-7bfcf93527782f1ebf83880f35e981665308d89c.tar.gz llvm-7bfcf93527782f1ebf83880f35e981665308d89c.tar.bz2 |
[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. (#118184)
A SCEVWrapPredicate A implies B, if
* they have the same flag,
* both steps are positive and
* B's start and step are ULE/SLE (for NSUW/NSSW) than A's.
See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants
as strides, second pair with variable strides).
Note that this is limited to steps of the same size, due to NSUW having
slightly different semantics than regular NUW. We should be able to
remove this restriction for NSSW (which matches NSW) in the future.
PR: https://github.com/llvm/llvm-project/pull/118184
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 80 |
1 files changed, 58 insertions, 22 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e181339..e2c2500 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( return true; auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { - if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && - !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) + if (Expr1 != Expr2 && + !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) && + !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE)) return false; return true; }; @@ -14857,7 +14858,7 @@ private: bool addOverflowAssumption(const SCEVPredicate *P) { if (!NewPreds) { // Check if we've already made this assumption. - return Pred && Pred->implies(P); + return Pred && Pred->implies(P, SE); } NewPreds->push_back(P); return true; @@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, assert(LHS != RHS && "LHS and RHS are the same SCEV"); } -bool SCEVComparePredicate::implies(const SCEVPredicate *N) const { +bool SCEVComparePredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast<SCEVComparePredicate>(N); if (!Op) @@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; } -bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { +bool SCEVWrapPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast<SCEVWrapPredicate>(N); + if (!Op || setFlags(Flags, Op->Flags) != Flags) + return false; + + if (Op->AR == AR) + return true; + + if (Flags != SCEVWrapPredicate::IncrementNSSW && + Flags != SCEVWrapPredicate::IncrementNUSW) + return false; - return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; + const SCEV *Step = AR->getStepRecurrence(SE); + const SCEV *OpStep = Op->AR->getStepRecurrence(SE); + if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep)) + return false; + + // If both steps are positive, this implies N, if N's start and step are + // ULE/SLE (for NSUW/NSSW) than this'. + Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType()); + Step = SE.getNoopOrZeroExtend(Step, WiderTy); + OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy); + + bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW; + const SCEV *OpStart = Op->AR->getStart(); + const SCEV *Start = AR->getStart(); + OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy) + : SE.getNoopOrSignExtend(OpStart, WiderTy); + Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy) + : SE.getNoopOrSignExtend(Start, WiderTy); + CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE; + return SE.isKnownPredicate(Pred, OpStep, Step) && + SE.isKnownPredicate(Pred, OpStart, Start); } bool SCEVWrapPredicate::isAlwaysTrue() const { @@ -15015,10 +15047,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, } /// Union predicates don't get cached so create a dummy set ID for it. -SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds) - : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { +SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds, + ScalarEvolution &SE) + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { for (const auto *P : Preds) - add(P); + add(P, SE); } bool SCEVUnionPredicate::isAlwaysTrue() const { @@ -15026,13 +15059,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const { [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); } -bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { +bool SCEVUnionPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) - return all_of(Set->Preds, - [this](const SCEVPredicate *I) { return this->implies(I); }); + return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) { + return this->implies(I, SE); + }); return any_of(Preds, - [N](const SCEVPredicate *I) { return I->implies(N); }); + [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); }); } void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { @@ -15040,15 +15075,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { Pred->print(OS, Depth); } -void SCEVUnionPredicate::add(const SCEVPredicate *N) { +void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) { if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) { for (const auto *Pred : Set->Preds) - add(Pred); + add(Pred, SE); return; } // Only add predicate if it is not already implied by this union predicate. - if (!implies(N)) + if (!implies(N, SE)) Preds.push_back(N); } @@ -15056,7 +15091,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L) : SE(SE), L(L) { SmallVector<const SCEVPredicate*, 4> Empty; - Preds = std::make_unique<SCEVUnionPredicate>(Empty); + Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE); } void ScalarEvolution::registerUser(const SCEV *User, @@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() { } void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { - if (Preds->implies(&Pred)) + if (Preds->implies(&Pred, SE)) return; SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates()); NewPreds.push_back(&Pred); - Preds = std::make_unique<SCEVUnionPredicate>(NewPreds); + Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE); updateGeneration(); } @@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { PredicatedScalarEvolution::PredicatedScalarEvolution( const PredicatedScalarEvolution &Init) - : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), - Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())), - Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), + Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(), + SE)), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { for (auto I : Init.FlagsMap) FlagsMap.insert(I); } |