diff options
author | David Sherwood <david.sherwood@arm.com> | 2024-09-02 14:05:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-02 14:05:26 +0100 |
commit | df3d70b5a72fee43af3793c8b7a138bd44cac8cf (patch) | |
tree | 8b058fa7b659ae9f0206e5b34bfabf99d81431b7 /llvm/lib/Analysis/ScalarEvolution.cpp | |
parent | ef26afcb88dcb5f2de79bfc3cf88a8ea10f230ec (diff) | |
download | llvm-df3d70b5a72fee43af3793c8b7a138bd44cac8cf.zip llvm-df3d70b5a72fee43af3793c8b7a138bd44cac8cf.tar.gz llvm-df3d70b5a72fee43af3793c8b7a138bd44cac8cf.tar.bz2 |
[Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)
Due to a reviewer request on PR #88385 I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge taken
information. With PR #88385 we will start to care about more
loops with multiple exits, and want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.
New tests added here:
Analysis/ScalarEvolution/predicated-exit-count.ll
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 86 |
1 files changed, 60 insertions, 26 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 54dde84..6b4a81c 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8247,6 +8247,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L, llvm_unreachable("Invalid ExitCountKind!"); } +const SCEV *ScalarEvolution::getPredicatedExitCount( + const Loop *L, const BasicBlock *ExitingBlock, + SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) { + switch (Kind) { + case Exact: + return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this, + Predicates); + case SymbolicMaximum: + return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this, + Predicates); + case ConstantMaximum: + return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this, + Predicates); + }; + llvm_unreachable("Invalid ExitCountKind!"); +} + const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount( const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) { return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds); @@ -8574,33 +8591,22 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact( return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true); } -/// Get the exact not taken count for this loop exit. -const SCEV * -ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock, - ScalarEvolution *SE) const { - for (const auto &ENT : ExitNotTaken) - if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) - return ENT.ExactNotTaken; - - return SE->getCouldNotCompute(); -} - -const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( - const BasicBlock *ExitingBlock, ScalarEvolution *SE) const { +const ScalarEvolution::ExitNotTakenInfo * +ScalarEvolution::BackedgeTakenInfo::getExitNotTaken( + const BasicBlock *ExitingBlock, + SmallVectorImpl<const SCEVPredicate *> *Predicates) const { for (const auto &ENT : ExitNotTaken) - if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) - return ENT.ConstantMaxNotTaken; - - return SE->getCouldNotCompute(); -} - -const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax( - const BasicBlock *ExitingBlock, ScalarEvolution *SE) const { - for (const auto &ENT : ExitNotTaken) - if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) - return ENT.SymbolicMaxNotTaken; + if (ENT.ExitingBlock == ExitingBlock) { + if (ENT.hasAlwaysTruePredicate()) + return &ENT; + else if (Predicates) { + for (const auto *P : ENT.Predicates) + Predicates->push_back(P); + return &ENT; + } + } - return SE->getCouldNotCompute(); + return nullptr; } /// getConstantMax - Get the constant max backedge taken count for the loop. @@ -13642,7 +13648,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, if (ExitingBlocks.size() > 1) for (BasicBlock *ExitingBlock : ExitingBlocks) { OS << " exit count for " << ExitingBlock->getName() << ": "; - PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock)); + const SCEV *EC = SE->getExitCount(L, ExitingBlock); + PrintSCEVWithTypeHint(OS, EC); + if (isa<SCEVCouldNotCompute>(EC)) { + // Retry with predicates. + SmallVector<const SCEVPredicate *, 4> Predicates; + EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates); + if (!isa<SCEVCouldNotCompute>(EC)) { + OS << "\n predicated exit count for " << ExitingBlock->getName() + << ": "; + PrintSCEVWithTypeHint(OS, EC); + OS << "\n Predicates:\n"; + for (const auto *P : Predicates) + P->print(OS, 4); + } + } OS << "\n"; } @@ -13682,6 +13702,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, auto *ExitBTC = SE->getExitCount(L, ExitingBlock, ScalarEvolution::SymbolicMaximum); PrintSCEVWithTypeHint(OS, ExitBTC); + if (isa<SCEVCouldNotCompute>(ExitBTC)) { + // Retry with predicates. + SmallVector<const SCEVPredicate *, 4> Predicates; + ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates, + ScalarEvolution::SymbolicMaximum); + if (!isa<SCEVCouldNotCompute>(ExitBTC)) { + OS << "\n predicated symbolic max exit count for " + << ExitingBlock->getName() << ": "; + PrintSCEVWithTypeHint(OS, ExitBTC); + OS << "\n Predicates:\n"; + for (const auto *P : Predicates) + P->print(OS, 4); + } + } OS << "\n"; } |