aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
author: Joel E. Denny <jdenny.ornl@gmail.com> 2025-07-31 12:28:25 -0400
committer: GitHub <noreply@github.com> 2025-07-31 12:28:25 -0400
commit: f7b65011de519b1bd987892475db61f99dde44ce (patch)
tree: 5550563d8ac0f12271a6cb7d948ba559f4301b78 /llvm/lib/Transforms/Utils/LoopUtils.cpp
parent: 3e579d93ab50952628a51bda05f3a39f6a5a631c (diff)
downloadllvm-f7b65011de519b1bd987892475db61f99dde44ce.zip
llvm-f7b65011de519b1bd987892475db61f99dde44ce.tar.gz
llvm-f7b65011de519b1bd987892475db61f99dde44ce.tar.bz2
[PGO] Add `llvm.loop.estimated_trip_count` metadata (#148758)
This patch implements the `llvm.loop.estimated_trip_count` metadata discussed in [[RFC] Fix Loop Transformations to Preserve Block Frequencies](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785). As [suggested in the RFC comments](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785/4), it adds the new metadata to all loops at the time of profile ingestion and estimates each trip count from the loop's `branch_weights` metadata. As [suggested in the PR #128785 review](https://github.com/llvm/llvm-project/pull/128785#discussion_r2151091036), it does so via a new `PGOEstimateTripCountsPass` pass, which creates the new metadata for each loop but omits the value if it cannot estimate a trip count due to the loop's form. An important observation not previously discussed is that `PGOEstimateTripCountsPass` *often* cannot estimate a loop's trip count, but later passes can sometimes transform the loop in a way that makes it possible. Currently, such passes do not necessarily update the metadata, but eventually that should be fixed. Until then, if the new metadata has no value, `llvm::getLoopEstimatedTripCount` disregards it and tries again to estimate the trip count from the loop's current `branch_weights` metadata.
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- llvm/lib/Transforms/Utils/LoopUtils.cpp | 184
1 files changed, 133 insertions, 51 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index e7623aa..9043baa 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -201,34 +201,40 @@ void llvm::initializeLoopPassPass(PassRegistry &Registry) {
}
/// Create MDNode for input string.
-static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
+static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name,
+ std::optional<unsigned> V) {
LLVMContext &Context = TheLoop->getHeader()->getContext();
- Metadata *MDs[] = {
- MDString::get(Context, Name),
- ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))};
- return MDNode::get(Context, MDs);
+ if (V) {
+ Metadata *MDs[] = {MDString::get(Context, Name),
+ ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Context), *V))};
+ return MDNode::get(Context, MDs);
+ }
+ return MDNode::get(Context, {MDString::get(Context, Name)});
}
-/// Set input string into loop metadata by keeping other values intact.
-/// If the string is already in loop metadata update value if it is
-/// different.
-void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
- unsigned V) {
+bool llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
+ std::optional<unsigned> V) {
SmallVector<Metadata *, 4> MDs(1);
// If the loop already has metadata, retain it.
MDNode *LoopID = TheLoop->getLoopID();
if (LoopID) {
for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
MDNode *Node = cast<MDNode>(LoopID->getOperand(i));
- // If it is of form key = value, try to parse it.
- if (Node->getNumOperands() == 2) {
+ // If it is of form key [= value], try to parse it.
+ unsigned NumOps = Node->getNumOperands();
+ if (NumOps == 1 || NumOps == 2) {
MDString *S = dyn_cast<MDString>(Node->getOperand(0));
if (S && S->getString() == StringMD) {
- ConstantInt *IntMD =
- mdconst::extract_or_null<ConstantInt>(Node->getOperand(1));
- if (IntMD && IntMD->getSExtValue() == V)
- // It is already in place. Do nothing.
- return;
+ // If the metadata and any value are already as specified, do nothing.
+ if (NumOps == 2 && V) {
+ ConstantInt *IntMD =
+ mdconst::extract_or_null<ConstantInt>(Node->getOperand(1));
+ if (IntMD && IntMD->getSExtValue() == *V)
+ return false;
+ } else if (NumOps == 1 && !V) {
+ return false;
+ }
// We need to update the value, so just skip it here and it will
// be added after copying the other existing nodes.
continue;
@@ -245,6 +251,7 @@ void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
// Set operand 0 to refer to the loop id itself.
NewLoopID->replaceOperandWith(0, NewLoopID);
TheLoop->setLoopID(NewLoopID);
+ return true;
}
std::optional<ElementCount>
@@ -804,26 +811,48 @@ static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) {
return LatchBR;
}
-/// Return the estimated trip count for any exiting branch which dominates
-/// the loop latch.
-static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch,
- Loop *L,
- uint64_t &OrigExitWeight) {
+struct DbgLoop {
+ const Loop *L;
+ explicit DbgLoop(const Loop *L) : L(L) {}
+};
+static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) {
+ OS << "function ";
+ D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false);
+ return OS << " " << *D.L;
+}
+
+static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
+ // Currently we take the estimated exit count only from the loop latch,
+ // ignoring other exiting blocks. This can overestimate the trip count
+ // if we exit through another exit, but can never underestimate it.
+ // TODO: incorporate information from other exits
+ BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
+ if (!ExitingBranch) {
+ LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting "
+ << "latch branch of required form in " << DbgLoop(L)
+ << "\n");
+ return std::nullopt;
+ }
+
// To estimate the number of times the loop body was executed, we want to
// know the number of times the backedge was taken, vs. the number of times
// we exited the loop.
uint64_t LoopWeight, ExitWeight;
- if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
+ if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) {
+ LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch "
+ << "weights for " << DbgLoop(L) << "\n");
return std::nullopt;
+ }
if (L->contains(ExitingBranch->getSuccessor(1)))
std::swap(LoopWeight, ExitWeight);
- if (!ExitWeight)
+ if (!ExitWeight) {
// Don't have a way to return predicated infinite
+ LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit "
+ << "probability for " << DbgLoop(L) << "\n");
return std::nullopt;
-
- OrigExitWeight = ExitWeight;
+ }
// Estimated exit count is a ratio of the loop weight by the weight of the
// edge exiting the loop, rounded to nearest.
@@ -834,33 +863,86 @@ static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch,
return std::numeric_limits<unsigned>::max();
// Estimated trip count is one plus estimated exit count.
- return ExitCount + 1;
+ uint64_t TC = ExitCount + 1;
+ LLVM_DEBUG(dbgs() << "estimateLoopTripCount: estimated trip count of " << TC
+ << " for " << DbgLoop(L) << "\n");
+ return TC;
}
-std::optional<unsigned>
-llvm::getLoopEstimatedTripCount(Loop *L,
- unsigned *EstimatedLoopInvocationWeight) {
- // Currently we take the estimate exit count only from the loop latch,
- // ignoring other exiting blocks. This can overestimate the trip count
- // if we exit through another exit, but can never underestimate it.
- // TODO: incorporate information from other exits
- if (BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L)) {
- uint64_t ExitWeight;
- if (std::optional<uint64_t> EstTripCount =
- getEstimatedTripCount(LatchBranch, L, ExitWeight)) {
- if (EstimatedLoopInvocationWeight)
- *EstimatedLoopInvocationWeight = ExitWeight;
- return *EstTripCount;
+std::optional<unsigned> llvm::getLoopEstimatedTripCount(
+ Loop *L, unsigned *EstimatedLoopInvocationWeight, bool DbgForInit) {
+ // If requested, either compute *EstimatedLoopInvocationWeight or return
+ // nullopt if it cannot be computed.
+ //
+ // TODO: Eventually, once all passes have migrated away from setting branch
+ // weights to indicate estimated trip counts, this function will drop the
+ // EstimatedLoopInvocationWeight parameter.
+ if (EstimatedLoopInvocationWeight) {
+ if (BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L)) {
+ uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
+ if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
+ return std::nullopt;
+ if (L->contains(ExitingBranch->getSuccessor(1)))
+ std::swap(LoopWeight, ExitWeight);
+ if (!ExitWeight)
+ return std::nullopt;
+ *EstimatedLoopInvocationWeight = ExitWeight;
}
}
- return std::nullopt;
+
+ // Return the estimated trip count from metadata unless the metadata is
+ // missing or has no value.
+ bool Missing = false; // Initialization is expected to be unused.
+ if (auto TC = getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount,
+ &Missing)) {
+ LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
+ << LLVMLoopEstimatedTripCount << " metadata has trip "
+ << "count of " << *TC << " for " << DbgLoop(L) << "\n");
+ return TC;
+ }
+
+ // Estimate the trip count from latch branch weights.
+ std::optional<unsigned> TC = estimateLoopTripCount(L);
+ if (DbgForInit) {
+ // We expect no existing metadata as we are responsible for creating it.
+ LLVM_DEBUG(dbgs() << (Missing ? "" : "WARNING: ")
+ << "getLoopEstimatedTripCount: "
+ << LLVMLoopEstimatedTripCount << " metadata "
+ << (Missing ? "" : "not ") << "missing as expected "
+ << "during its init for " << DbgLoop(L) << "\n");
+ } else if (Missing) {
+ // We expect that metadata was already created.
+ LLVM_DEBUG(dbgs() << "WARNING: getLoopEstimatedTripCount: "
+ << LLVMLoopEstimatedTripCount << " metadata missing for "
+ << DbgLoop(L) << "\n");
+ } else {
+ // If the trip count is estimable, the value should have been added already.
+ LLVM_DEBUG(dbgs() << (TC ? "WARNING: " : "")
+ << "getLoopEstimatedTripCount: "
+ << LLVMLoopEstimatedTripCount << " metadata "
+ << (TC ? "incorrectly " : "correctly ")
+ << "indicates trip count is inestimable for "
+ << DbgLoop(L) << "\n");
+ }
+ return TC;
}
-bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
- unsigned EstimatedloopInvocationWeight) {
- // At the moment, we currently support changing the estimate trip count of
- // the latch branch only. We could extend this API to manipulate estimated
- // trip counts for any exit.
+bool llvm::setLoopEstimatedTripCount(
+ Loop *L, std::optional<unsigned> EstimatedTripCount,
+ std::optional<unsigned> EstimatedloopInvocationWeight) {
+ // Set the metadata.
+ bool Updated = addStringMetadataToLoop(L, LLVMLoopEstimatedTripCount,
+ EstimatedTripCount);
+ if (!EstimatedTripCount || !EstimatedloopInvocationWeight)
+ return Updated;
+
+ // At the moment, we support changing the estimated trip count in
+ // the latch branch's branch weights only. We could extend this API to
+ // manipulate estimated trip counts for any exit.
+ //
+ // TODO: Eventually, once all passes have migrated away from setting branch
+ // weights to indicate estimated trip counts, we will not set branch weights
+ // here at all.
BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
if (!LatchBranch)
return false;
@@ -869,9 +951,9 @@ bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
unsigned LatchExitWeight = 0;
unsigned BackedgeTakenWeight = 0;
- if (EstimatedTripCount > 0) {
- LatchExitWeight = EstimatedloopInvocationWeight;
- BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
+ if (*EstimatedTripCount != 0) {
+ LatchExitWeight = *EstimatedloopInvocationWeight;
+ BackedgeTakenWeight = (*EstimatedTripCount - 1) * LatchExitWeight;
}
// Make a swap if back edge is taken when condition is "false".
@@ -885,7 +967,7 @@ bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
LLVMContext::MD_prof,
MDB.createBranchWeights(BackedgeTakenWeight, LatchExitWeight));
- return true;
+ return Updated;
}
bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,