aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopPeel.cpp
diff options
context:
space:
mode:
authorNikita Popov <npopov@redhat.com>2022-09-27 18:30:32 +0200
committerNikita Popov <npopov@redhat.com>2022-10-07 12:35:52 +0200
commitb43a4d0850d5690e4399acda7ec6b5aca40b9eff (patch)
tree1ddbf487f9d7750ee2fc98f11e066da97ae85a40 /llvm/lib/Transforms/Utils/LoopPeel.cpp
parentc9b771b9fc2f17cccd9ccbf8f1d52e2642679b8a (diff)
downloadllvm-b43a4d0850d5690e4399acda7ec6b5aca40b9eff.zip
llvm-b43a4d0850d5690e4399acda7ec6b5aca40b9eff.tar.gz
llvm-b43a4d0850d5690e4399acda7ec6b5aca40b9eff.tar.bz2
[LoopPeeling] Support peeling loops with non-latch exits
Loop peeling currently requires that a) the latch is exiting b) a branch and c) other exits are unreachable/deopt. This patch removes all of these limitations, and adds the necessary branch weight updating support. It essentially works the same way as before with latch -> exiting terminator and loop trip count -> per exit trip count. It's worth noting that there are still other limitations in profitability heuristics: This patch enables peeling of loops to make conditions invariant (which is pretty much always highly profitable if possible), while peeling to make loads dereferenceable still checks that non-latch exits are unreachable and PGO-based peeling has even more conditions. Those checks could be relaxed later if we consider those cases profitable. The motivation for this change is that loops using iterator adaptors in Rust often optimize very badly, and end up with a loop phi of the form phi(true, false) in the final result. Peeling eliminates that phi and conditions based on it, which enables a lot of follow-on simplification. Differential Revision: https://reviews.llvm.org/D134803
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopPeel.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopPeel.cpp194
1 files changed, 90 insertions, 104 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 32c2427..d4f14df 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -77,32 +77,7 @@ static const char *PeeledCountMetaData = "llvm.loop.peeled.count";
// Check whether we are capable of peeling this loop.
bool llvm::canPeel(Loop *L) {
// Make sure the loop is in simplified form
- if (!L->isLoopSimplifyForm())
- return false;
-
- // Don't try to peel loops where the latch is not the exiting block.
- // This can be an indication of two different things:
- // 1) The loop is not rotated.
- // 2) The loop contains irreducible control flow that involves the latch.
- const BasicBlock *Latch = L->getLoopLatch();
- if (!L->isLoopExiting(Latch))
- return false;
-
- // Peeling is only supported if the latch is a branch.
- if (!isa<BranchInst>(Latch->getTerminator()))
- return false;
-
- SmallVector<BasicBlock *, 4> Exits;
- L->getUniqueNonLatchExitBlocks(Exits);
- // The latch must either be the only exiting block or all non-latch exit
- // blocks have either a deopt or unreachable terminator or compose a chain of
- // blocks where the last one is either deopt or unreachable terminated. Both
- // deopt and unreachable terminators are a strong indication they are not
- // taken. Note that this is a profitability check, not a legality check. Also
- // note that LoopPeeling currently can only update the branch weights of latch
- // blocks and branch weights to blocks with deopt or unreachable do not need
- // updating.
- return llvm::all_of(Exits, IsBlockFollowedByDeoptOrUnreachable);
+ return L->isLoopSimplifyForm();
}
// This function calculates the number of iterations after which the given Phi
@@ -487,82 +462,87 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
}
}
-/// Update the branch weights of the latch of a peeled-off loop
+struct WeightInfo {
+ // Weights for current iteration.
+ SmallVector<uint32_t> Weights;
+ // Weights to subtract after each iteration.
+ const SmallVector<uint32_t> SubWeights;
+};
+
+/// Update the branch weights of an exiting block of a peeled-off loop
/// iteration.
-/// This sets the branch weights for the latch of the recently peeled off loop
-/// iteration correctly.
-/// Let F is a weight of the edge from latch to header.
-/// Let E is a weight of the edge from latch to exit.
+/// Let F is a weight of the edge to continue (fallthrough) into the loop.
+/// Let E is a weight of the edge to an exit.
/// F/(F+E) is a probability to go to loop and E/(F+E) is a probability to
/// go to exit.
-/// Then, Estimated TripCount = F / E.
+/// Then, Estimated ExitCount = F / E.
/// For I-th (counting from 0) peeled off iteration we set the the weights for
-/// the peeled latch as (TC - I, 1). It gives us reasonable distribution,
-/// The probability to go to exit 1/(TC-I) increases. At the same time
-/// the estimated trip count of remaining loop reduces by I.
+/// the peeled exit as (EC - I, 1). It gives us reasonable distribution,
+/// The probability to go to exit 1/(EC-I) increases. At the same time
+/// the estimated exit count in the remainder loop reduces by I.
/// To avoid dealing with division rounding we can just multiple both part
/// of weights to E and use weight as (F - I * E, E).
-///
-/// \param Header The copy of the header block that belongs to next iteration.
-/// \param LatchBR The copy of the latch branch that belongs to this iteration.
-/// \param[in,out] FallThroughWeight The weight of the edge from latch to
-/// header before peeling (in) and after peeled off one iteration (out).
-static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
- uint64_t ExitWeight,
- uint64_t &FallThroughWeight) {
- // FallThroughWeight is 0 means that there is no branch weights on original
- // latch block or estimated trip count is zero.
- if (!FallThroughWeight)
- return;
-
- unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1);
- MDBuilder MDB(LatchBR->getContext());
- MDNode *WeightNode =
- HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
- : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
- LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
- FallThroughWeight =
- FallThroughWeight > ExitWeight ? FallThroughWeight - ExitWeight : 1;
+static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
+ MDBuilder MDB(Term->getContext());
+ Term->setMetadata(LLVMContext::MD_prof,
+ MDB.createBranchWeights(Info.Weights));
+ for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
+ if (SubWeight != 0)
+ Info.Weights[Idx] = Info.Weights[Idx] > SubWeight
+ ? Info.Weights[Idx] - SubWeight
+ : 1;
}
-/// Initialize the weights.
-///
-/// \param Header The header block.
-/// \param LatchBR The latch branch.
-/// \param[out] ExitWeight The weight of the edge from Latch to Exit.
-/// \param[out] FallThroughWeight The weight of the edge from Latch to Header.
-static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
- uint64_t &ExitWeight,
- uint64_t &FallThroughWeight) {
- uint64_t TrueWeight, FalseWeight;
- if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
- return;
- unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
- ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;
- FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight;
-}
+/// Initialize the weights for all exiting blocks.
+static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
+ Loop *L) {
+ SmallVector<BasicBlock *> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+ for (BasicBlock *ExitingBlock : ExitingBlocks) {
+ Instruction *Term = ExitingBlock->getTerminator();
+ SmallVector<uint32_t> Weights;
+ if (!extractBranchWeights(*Term, Weights))
+ continue;
-/// Update the weights of original Latch block after peeling off all iterations.
-///
-/// \param Header The header block.
-/// \param LatchBR The latch branch.
-/// \param ExitWeight The weight of the edge from Latch to Exit.
-/// \param FallThroughWeight The weight of the edge from Latch to Header.
-static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
- uint64_t ExitWeight,
- uint64_t FallThroughWeight) {
- // FallThroughWeight is 0 means that there is no branch weights on original
- // latch block or estimated trip count is zero.
- if (!FallThroughWeight)
- return;
+ // See the comment on updateBranchWeights() for an explanation of what we
+ // do here.
+ uint32_t FallThroughWeights = 0;
+ uint32_t ExitWeights = 0;
+ for (auto [Succ, Weight] : zip(successors(Term), Weights)) {
+ if (L->contains(Succ))
+ FallThroughWeights += Weight;
+ else
+ ExitWeights += Weight;
+ }
+
+ // Don't try to update weights for degenerate case.
+ if (FallThroughWeights == 0)
+ continue;
- // Sets the branch weights on the loop exit.
- MDBuilder MDB(LatchBR->getContext());
- unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
- MDNode *WeightNode =
- HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
- : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
- LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
+ SmallVector<uint32_t> SubWeights;
+ for (auto [Succ, Weight] : zip(successors(Term), Weights)) {
+ if (!L->contains(Succ)) {
+ // Exit weights stay the same.
+ SubWeights.push_back(0);
+ continue;
+ }
+
+ // Subtract exit weights on each iteration, distributed across all
+ // fallthrough edges.
+ double W = (double)Weight / (double)FallThroughWeights;
+ SubWeights.push_back((uint32_t)(ExitWeights * W));
+ }
+
+ WeightInfos.insert({Term, {std::move(Weights), std::move(SubWeights)}});
+ }
+}
+
+/// Update the weights of original exiting block after peeling off all
+/// iterations.
+static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
+ MDBuilder MDB(Term->getContext());
+ Term->setMetadata(LLVMContext::MD_prof,
+ MDB.createBranchWeights(Info.Weights));
}
/// Clones the body of the loop L, putting it between \p InsertTop and \p
@@ -644,10 +624,10 @@ static void cloneLoopBlocks(
// header (for the last peeled iteration) or the copied header of the next
// iteration (for every other iteration)
BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]);
- BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator());
- for (unsigned idx = 0, e = LatchBR->getNumSuccessors(); idx < e; ++idx)
- if (LatchBR->getSuccessor(idx) == Header) {
- LatchBR->setSuccessor(idx, InsertBot);
+ auto *LatchTerm = cast<Instruction>(NewLatch->getTerminator());
+ for (unsigned idx = 0, e = LatchTerm->getNumSuccessors(); idx < e; ++idx)
+ if (LatchTerm->getSuccessor(idx) == Header) {
+ LatchTerm->setSuccessor(idx, InsertBot);
break;
}
if (DT)
@@ -835,12 +815,13 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
ValueToValueMapTy LVMap;
+ Instruction *LatchTerm =
+ cast<Instruction>(cast<BasicBlock>(Latch)->getTerminator());
+
// If we have branch weight information, we'll want to update it for the
// newly created branches.
- BranchInst *LatchBR =
- cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator());
- uint64_t ExitWeight = 0, FallThroughWeight = 0;
- initBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight);
+ DenseMap<Instruction *, WeightInfo> Weights;
+ initBranchWeights(Weights, L);
// Identify what noalias metadata is inside the loop: if it is inside the
// loop, the associated metadata must be cloned for each iteration.
@@ -869,11 +850,15 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
assert(DT.verify(DominatorTree::VerificationLevel::Fast));
#endif
- auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]);
- updateBranchWeights(InsertBot, LatchBRCopy, ExitWeight, FallThroughWeight);
+ for (auto &[Term, Info] : Weights) {
+ auto *TermCopy = cast<Instruction>(VMap[Term]);
+ updateBranchWeights(TermCopy, Info);
+ }
+
// Remove Loop metadata from the latch branch instruction
// because it is not the Loop's latch branch anymore.
- LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr);
+ auto *LatchTermCopy = cast<Instruction>(VMap[LatchTerm]);
+ LatchTermCopy->setMetadata(LLVMContext::MD_loop, nullptr);
InsertTop = InsertBot;
InsertBot = SplitBlock(InsertBot, InsertBot->getTerminator(), &DT, LI);
@@ -896,7 +881,8 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}
- fixupBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight);
+ for (const auto &[Term, Info] : Weights)
+ fixupBranchWeights(Term, Info);
// Update Metadata for count of peeled off iterations.
unsigned AlreadyPeeled = 0;