aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
authorDavid Sherwood <david.sherwood@arm.com>2024-09-02 14:05:26 +0100
committerGitHub <noreply@github.com>2024-09-02 14:05:26 +0100
commitdf3d70b5a72fee43af3793c8b7a138bd44cac8cf (patch)
tree8b058fa7b659ae9f0206e5b34bfabf99d81431b7 /llvm/lib/Analysis/ScalarEvolution.cpp
parentef26afcb88dcb5f2de79bfc3cf88a8ea10f230ec (diff)
downloadllvm-df3d70b5a72fee43af3793c8b7a138bd44cac8cf.zip
llvm-df3d70b5a72fee43af3793c8b7a138bd44cac8cf.tar.gz
llvm-df3d70b5a72fee43af3793c8b7a138bd44cac8cf.tar.bz2
[Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)
Due to a reviewer request on PR #88385 I have created this patch to add a getPredicatedExitCount function, which is similar to getExitCount except that it uses the predicated backedge taken information. With PR #88385 we will start to care about more loops with multiple exits, and want the ability to query exit counts for a particular exiting block. Such loops may require predicates in order to be vectorised. New tests added here: Analysis/ScalarEvolution/predicated-exit-count.ll
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp86
1 files changed, 60 insertions, 26 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 54dde84..6b4a81c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8247,6 +8247,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
llvm_unreachable("Invalid ExitCountKind!");
}
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+ const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
+ switch (Kind) {
+ case Exact:
+ return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
+ Predicates);
+ case SymbolicMaximum:
+ return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
+ Predicates);
+ case ConstantMaximum:
+ return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
+ Predicates);
+ };
+ llvm_unreachable("Invalid ExitCountKind!");
+}
+
const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
@@ -8574,33 +8591,22 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
}
-/// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const {
- for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ExactNotTaken;
-
- return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+const ScalarEvolution::ExitNotTakenInfo *
+ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
+ const BasicBlock *ExitingBlock,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ConstantMaxNotTaken;
-
- return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
- for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.SymbolicMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return &ENT;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return &ENT;
+ }
+ }
- return SE->getCouldNotCompute();
+ return nullptr;
}
/// getConstantMax - Get the constant max backedge taken count for the loop.
@@ -13642,7 +13648,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
if (ExitingBlocks.size() > 1)
for (BasicBlock *ExitingBlock : ExitingBlocks) {
OS << " exit count for " << ExitingBlock->getName() << ": ";
- PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
+ const SCEV *EC = SE->getExitCount(L, ExitingBlock);
+ PrintSCEVWithTypeHint(OS, EC);
+ if (isa<SCEVCouldNotCompute>(EC)) {
+ // Retry with predicates.
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
+ if (!isa<SCEVCouldNotCompute>(EC)) {
+ OS << "\n predicated exit count for " << ExitingBlock->getName()
+ << ": ";
+ PrintSCEVWithTypeHint(OS, EC);
+ OS << "\n Predicates:\n";
+ for (const auto *P : Predicates)
+ P->print(OS, 4);
+ }
+ }
OS << "\n";
}
@@ -13682,6 +13702,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
ScalarEvolution::SymbolicMaximum);
PrintSCEVWithTypeHint(OS, ExitBTC);
+ if (isa<SCEVCouldNotCompute>(ExitBTC)) {
+ // Retry with predicates.
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
+ ScalarEvolution::SymbolicMaximum);
+ if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
+ OS << "\n predicated symbolic max exit count for "
+ << ExitingBlock->getName() << ": ";
+ PrintSCEVWithTypeHint(OS, ExitBTC);
+ OS << "\n Predicates:\n";
+ for (const auto *P : Predicates)
+ P->print(OS, 4);
+ }
+ }
OS << "\n";
}