diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 80 |
1 files changed, 58 insertions, 22 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e181339..e2c2500 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( return true; auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { - if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && - !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) + if (Expr1 != Expr2 && + !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) && + !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE)) return false; return true; }; @@ -14857,7 +14858,7 @@ private: bool addOverflowAssumption(const SCEVPredicate *P) { if (!NewPreds) { // Check if we've already made this assumption. - return Pred && Pred->implies(P); + return Pred && Pred->implies(P, SE); } NewPreds->push_back(P); return true; @@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, assert(LHS != RHS && "LHS and RHS are the same SCEV"); } -bool SCEVComparePredicate::implies(const SCEVPredicate *N) const { +bool SCEVComparePredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast<SCEVComparePredicate>(N); if (!Op) @@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; } -bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { +bool SCEVWrapPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast<SCEVWrapPredicate>(N); + if (!Op || setFlags(Flags, Op->Flags) != Flags) + return false; + + if (Op->AR == AR) + return true; + + if (Flags != SCEVWrapPredicate::IncrementNSSW && + Flags != SCEVWrapPredicate::IncrementNUSW) + return false; - return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; + const SCEV *Step = AR->getStepRecurrence(SE); + const SCEV *OpStep = Op->AR->getStepRecurrence(SE); + if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep)) + return false; + + // If both steps are positive, this implies N, if N's start and step are + // ULE/SLE (for NSUW/NSSW) than this'. + Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType()); + Step = SE.getNoopOrZeroExtend(Step, WiderTy); + OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy); + + bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW; + const SCEV *OpStart = Op->AR->getStart(); + const SCEV *Start = AR->getStart(); + OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy) + : SE.getNoopOrSignExtend(OpStart, WiderTy); + Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy) + : SE.getNoopOrSignExtend(Start, WiderTy); + CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE; + return SE.isKnownPredicate(Pred, OpStep, Step) && + SE.isKnownPredicate(Pred, OpStart, Start); } bool SCEVWrapPredicate::isAlwaysTrue() const { @@ -15015,10 +15047,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, } /// Union predicates don't get cached so create a dummy set ID for it. -SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds) - : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { +SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds, + ScalarEvolution &SE) + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { for (const auto *P : Preds) - add(P); + add(P, SE); } bool SCEVUnionPredicate::isAlwaysTrue() const { @@ -15026,13 +15059,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const { [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); } -bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { +bool SCEVUnionPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) - return all_of(Set->Preds, - [this](const SCEVPredicate *I) { return this->implies(I); }); + return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) { + return this->implies(I, SE); + }); return any_of(Preds, - [N](const SCEVPredicate *I) { return I->implies(N); }); + [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); }); } void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { @@ -15040,15 +15075,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { Pred->print(OS, Depth); } -void SCEVUnionPredicate::add(const SCEVPredicate *N) { +void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) { if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) { for (const auto *Pred : Set->Preds) - add(Pred); + add(Pred, SE); return; } // Only add predicate if it is not already implied by this union predicate. - if (!implies(N)) + if (!implies(N, SE)) Preds.push_back(N); } @@ -15056,7 +15091,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L) : SE(SE), L(L) { SmallVector<const SCEVPredicate*, 4> Empty; - Preds = std::make_unique<SCEVUnionPredicate>(Empty); + Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE); } void ScalarEvolution::registerUser(const SCEV *User, @@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() { } void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { - if (Preds->implies(&Pred)) + if (Preds->implies(&Pred, SE)) return; SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates()); NewPreds.push_back(&Pred); - Preds = std::make_unique<SCEVUnionPredicate>(NewPreds); + Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE); updateGeneration(); } @@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { PredicatedScalarEvolution::PredicatedScalarEvolution( const PredicatedScalarEvolution &Init) - : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), - Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())), - Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), + Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(), + SE)), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { for (auto I : Init.FlagsMap) FlagsMap.insert(I); } |