diff options
author | Joel E. Denny <jdenny.ornl@gmail.com> | 2025-07-31 12:28:25 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-31 12:28:25 -0400 |
commit | f7b65011de519b1bd987892475db61f99dde44ce (patch) | |
tree | 5550563d8ac0f12271a6cb7d948ba559f4301b78 /llvm/lib/Transforms/Utils/LoopUtils.cpp | |
parent | 3e579d93ab50952628a51bda05f3a39f6a5a631c (diff) | |
download | llvm-f7b65011de519b1bd987892475db61f99dde44ce.zip llvm-f7b65011de519b1bd987892475db61f99dde44ce.tar.gz llvm-f7b65011de519b1bd987892475db61f99dde44ce.tar.bz2 |
[PGO] Add `llvm.loop.estimated_trip_count` metadata (#148758)
This patch implements the `llvm.loop.estimated_trip_count` metadata
discussed in [[RFC] Fix Loop Transformations to Preserve Block
Frequencies](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785).
As [suggested in the RFC
comments](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785/4),
it adds the new metadata to all loops at the time of profile ingestion
and estimates each trip count from the loop's `branch_weights` metadata.
As [suggested in the PR #128785
review](https://github.com/llvm/llvm-project/pull/128785#discussion_r2151091036),
it does so via a new `PGOEstimateTripCountsPass` pass, which creates the
new metadata for each loop but omits the value if it cannot estimate a
trip count due to the loop's form.
An important observation not previously discussed is that
`PGOEstimateTripCountsPass` *often* cannot estimate a loop's trip count,
but later passes can sometimes transform the loop in a way that makes it
possible. Currently, such passes do not necessarily update the metadata,
but eventually that should be fixed. Until then, if the new metadata is
missing or has no value, `llvm::getLoopEstimatedTripCount` disregards it and
tries again to estimate the trip count from the loop's current
`branch_weights` metadata.
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 184 |
1 files changed, 133 insertions, 51 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index e7623aa..9043baa 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -201,34 +201,40 @@ void llvm::initializeLoopPassPass(PassRegistry &Registry) { } /// Create MDNode for input string. -static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) { +static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, + std::optional<unsigned> V) { LLVMContext &Context = TheLoop->getHeader()->getContext(); - Metadata *MDs[] = { - MDString::get(Context, Name), - ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))}; - return MDNode::get(Context, MDs); + if (V) { + Metadata *MDs[] = {MDString::get(Context, Name), + ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Context), *V))}; + return MDNode::get(Context, MDs); + } + return MDNode::get(Context, {MDString::get(Context, Name)}); } -/// Set input string into loop metadata by keeping other values intact. -/// If the string is already in loop metadata update value if it is -/// different. -void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD, - unsigned V) { +bool llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD, + std::optional<unsigned> V) { SmallVector<Metadata *, 4> MDs(1); // If the loop already has metadata, retain it. MDNode *LoopID = TheLoop->getLoopID(); if (LoopID) { for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); - // If it is of form key = value, try to parse it. - if (Node->getNumOperands() == 2) { + // If it is of form key [= value], try to parse it. 
+ unsigned NumOps = Node->getNumOperands(); + if (NumOps == 1 || NumOps == 2) { MDString *S = dyn_cast<MDString>(Node->getOperand(0)); if (S && S->getString() == StringMD) { - ConstantInt *IntMD = - mdconst::extract_or_null<ConstantInt>(Node->getOperand(1)); - if (IntMD && IntMD->getSExtValue() == V) - // It is already in place. Do nothing. - return; + // If the metadata and any value are already as specified, do nothing. + if (NumOps == 2 && V) { + ConstantInt *IntMD = + mdconst::extract_or_null<ConstantInt>(Node->getOperand(1)); + if (IntMD && IntMD->getSExtValue() == *V) + return false; + } else if (NumOps == 1 && !V) { + return false; + } // We need to update the value, so just skip it here and it will // be added after copying other existed nodes. continue; @@ -245,6 +251,7 @@ void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD, // Set operand 0 to refer to the loop id itself. NewLoopID->replaceOperandWith(0, NewLoopID); TheLoop->setLoopID(NewLoopID); + return true; } std::optional<ElementCount> @@ -804,26 +811,48 @@ static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) { return LatchBR; } -/// Return the estimated trip count for any exiting branch which dominates -/// the loop latch. -static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch, - Loop *L, - uint64_t &OrigExitWeight) { +struct DbgLoop { + const Loop *L; + explicit DbgLoop(const Loop *L) : L(L) {} +}; +static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) { + OS << "function "; + D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false); + return OS << " " << *D.L; +} + +static std::optional<unsigned> estimateLoopTripCount(Loop *L) { + // Currently we take the estimate exit count only from the loop latch, + // ignoring other exiting blocks. This can overestimate the trip count + // if we exit through another exit, but can never underestimate it. 
+ // TODO: incorporate information from other exits + BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L); + if (!ExitingBranch) { + LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting " + << "latch branch of required form in " << DbgLoop(L) + << "\n"); + return std::nullopt; + } + // To estimate the number of times the loop body was executed, we want to // know the number of times the backedge was taken, vs. the number of times // we exited the loop. uint64_t LoopWeight, ExitWeight; - if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) + if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) { + LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch " + << "weights for " << DbgLoop(L) << "\n"); return std::nullopt; + } if (L->contains(ExitingBranch->getSuccessor(1))) std::swap(LoopWeight, ExitWeight); - if (!ExitWeight) + if (!ExitWeight) { // Don't have a way to return predicated infinite + LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit " + << "probability for " << DbgLoop(L) << "\n"); return std::nullopt; - - OrigExitWeight = ExitWeight; + } // Estimated exit count is a ratio of the loop weight by the weight of the // edge exiting the loop, rounded to nearest. @@ -834,33 +863,86 @@ static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch, return std::numeric_limits<unsigned>::max(); // Estimated trip count is one plus estimated exit count. - return ExitCount + 1; + uint64_t TC = ExitCount + 1; + LLVM_DEBUG(dbgs() << "estimateLoopTripCount: estimated trip count of " << TC + << " for " << DbgLoop(L) << "\n"); + return TC; } -std::optional<unsigned> -llvm::getLoopEstimatedTripCount(Loop *L, - unsigned *EstimatedLoopInvocationWeight) { - // Currently we take the estimate exit count only from the loop latch, - // ignoring other exiting blocks. This can overestimate the trip count - // if we exit through another exit, but can never underestimate it. 
- // TODO: incorporate information from other exits - if (BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L)) { - uint64_t ExitWeight; - if (std::optional<uint64_t> EstTripCount = - getEstimatedTripCount(LatchBranch, L, ExitWeight)) { - if (EstimatedLoopInvocationWeight) - *EstimatedLoopInvocationWeight = ExitWeight; - return *EstTripCount; +std::optional<unsigned> llvm::getLoopEstimatedTripCount( + Loop *L, unsigned *EstimatedLoopInvocationWeight, bool DbgForInit) { + // If requested, either compute *EstimatedLoopInvocationWeight or return + // nullopt if cannot. + // + // TODO: Eventually, once all passes have migrated away from setting branch + // weights to indicate estimated trip counts, this function will drop the + // EstimatedLoopInvocationWeight parameter. + if (EstimatedLoopInvocationWeight) { + if (BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L)) { + uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused. + if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) + return std::nullopt; + if (L->contains(ExitingBranch->getSuccessor(1))) + std::swap(LoopWeight, ExitWeight); + if (!ExitWeight) + return std::nullopt; + *EstimatedLoopInvocationWeight = ExitWeight; } } - return std::nullopt; + + // Return the estimated trip count from metadata unless the metadata is + // missing or has no value. + bool Missing = false; // Initialization is expected to be unused. + if (auto TC = getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount, + &Missing)) { + LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: " + << LLVMLoopEstimatedTripCount << " metadata has trip " + << "count of " << *TC << " for " << DbgLoop(L) << "\n"); + return TC; + } + + // Estimate the trip count from latch branch weights. + std::optional<unsigned> TC = estimateLoopTripCount(L); + if (DbgForInit) { + // We expect no existing metadata as we are responsible for creating it. + LLVM_DEBUG(dbgs() << (Missing ? 
"" : "WARNING: ") + << "getLoopEstimatedTripCount: " + << LLVMLoopEstimatedTripCount << " metadata " + << (Missing ? "" : "not ") << "missing as expected " + << "during its init for " << DbgLoop(L) << "\n"); + } else if (Missing) { + // We expect that metadata was already created. + LLVM_DEBUG(dbgs() << "WARNING: getLoopEstimatedTripCount: " + << LLVMLoopEstimatedTripCount << " metadata missing for " + << DbgLoop(L) << "\n"); + } else { + // If the trip count is estimable, the value should have been added already. + LLVM_DEBUG(dbgs() << (TC ? "WARNING: " : "") + << "getLoopEstimatedTripCount: " + << LLVMLoopEstimatedTripCount << " metadata " + << (TC ? "incorrectly " : "correctly ") + << "indicates trip count is inestimable for " + << DbgLoop(L) << "\n"); + } + return TC; } -bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount, - unsigned EstimatedloopInvocationWeight) { - // At the moment, we currently support changing the estimate trip count of - // the latch branch only. We could extend this API to manipulate estimated - // trip counts for any exit. +bool llvm::setLoopEstimatedTripCount( + Loop *L, std::optional<unsigned> EstimatedTripCount, + std::optional<unsigned> EstimatedloopInvocationWeight) { + // Set the metadata. + bool Updated = addStringMetadataToLoop(L, LLVMLoopEstimatedTripCount, + EstimatedTripCount); + if (!EstimatedTripCount || !EstimatedloopInvocationWeight) + return Updated; + + // At the moment, we currently support changing the estimated trip count in + // the latch branch's branch weights only. We could extend this API to + // manipulate estimated trip counts for any exit. + // + // TODO: Eventually, once all passes have migrated away from setting branch + // weights to indicate estimated trip counts, we will not set branch weights + // here at all. 
BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L); if (!LatchBranch) return false; @@ -869,9 +951,9 @@ bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount, unsigned LatchExitWeight = 0; unsigned BackedgeTakenWeight = 0; - if (EstimatedTripCount > 0) { - LatchExitWeight = EstimatedloopInvocationWeight; - BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight; + if (*EstimatedTripCount != 0) { + LatchExitWeight = *EstimatedloopInvocationWeight; + BackedgeTakenWeight = (*EstimatedTripCount - 1) * LatchExitWeight; } // Make a swap if back edge is taken when condition is "false". @@ -885,7 +967,7 @@ bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount, LLVMContext::MD_prof, MDB.createBranchWeights(BackedgeTakenWeight, LatchExitWeight)); - return true; + return Updated; } bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, |