diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopPeel.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopPeel.cpp | 194 |
1 files changed, 90 insertions, 104 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp index 32c2427..d4f14df 100644 --- a/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -77,32 +77,7 @@ static const char *PeeledCountMetaData = "llvm.loop.peeled.count"; // Check whether we are capable of peeling this loop. bool llvm::canPeel(Loop *L) { // Make sure the loop is in simplified form - if (!L->isLoopSimplifyForm()) - return false; - - // Don't try to peel loops where the latch is not the exiting block. - // This can be an indication of two different things: - // 1) The loop is not rotated. - // 2) The loop contains irreducible control flow that involves the latch. - const BasicBlock *Latch = L->getLoopLatch(); - if (!L->isLoopExiting(Latch)) - return false; - - // Peeling is only supported if the latch is a branch. - if (!isa<BranchInst>(Latch->getTerminator())) - return false; - - SmallVector<BasicBlock *, 4> Exits; - L->getUniqueNonLatchExitBlocks(Exits); - // The latch must either be the only exiting block or all non-latch exit - // blocks have either a deopt or unreachable terminator or compose a chain of - // blocks where the last one is either deopt or unreachable terminated. Both - // deopt and unreachable terminators are a strong indication they are not - // taken. Note that this is a profitability check, not a legality check. Also - // note that LoopPeeling currently can only update the branch weights of latch - // blocks and branch weights to blocks with deopt or unreachable do not need - // updating. - return llvm::all_of(Exits, IsBlockFollowedByDeoptOrUnreachable); + return L->isLoopSimplifyForm(); } // This function calculates the number of iterations after which the given Phi @@ -487,82 +462,87 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, } } -/// Update the branch weights of the latch of a peeled-off loop +struct WeightInfo { + // Weights for current iteration. 
+ SmallVector<uint32_t> Weights; + // Weights to subtract after each iteration. + const SmallVector<uint32_t> SubWeights; +}; + +/// Update the branch weights of an exiting block of a peeled-off loop /// iteration. -/// This sets the branch weights for the latch of the recently peeled off loop -/// iteration correctly. -/// Let F is a weight of the edge from latch to header. -/// Let E is a weight of the edge from latch to exit. +/// Let F be the weight of the edge to continue (fallthrough) into the loop. +/// Let E be the weight of the edge to an exit. /// F/(F+E) is a probability to go to loop and E/(F+E) is a probability to /// go to exit. -/// Then, Estimated TripCount = F / E. +/// Then, Estimated ExitCount = F / E. /// For I-th (counting from 0) peeled off iteration we set the weights for -/// the peeled latch as (TC - I, 1). It gives us reasonable distribution, -/// The probability to go to exit 1/(TC-I) increases. At the same time -/// the estimated trip count of remaining loop reduces by I. +/// the peeled exit as (EC - I, 1). It gives us reasonable distribution, +/// The probability to go to exit 1/(EC-I) increases. At the same time +/// the estimated exit count in the remainder loop reduces by I. /// To avoid dealing with division rounding we can just multiply both parts /// of the weights by E and use weight as (F - I * E, E). -/// -/// \param Header The copy of the header block that belongs to next iteration. -/// \param LatchBR The copy of the latch branch that belongs to this iteration. -/// \param[in,out] FallThroughWeight The weight of the edge from latch to -/// header before peeling (in) and after peeled off one iteration (out). -static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t ExitWeight, - uint64_t &FallThroughWeight) { - // FallThroughWeight is 0 means that there is no branch weights on original - // latch block or estimated trip count is zero. 
- if (!FallThroughWeight) - return; - - unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); - MDBuilder MDB(LatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight) - : MDB.createBranchWeights(FallThroughWeight, ExitWeight); - LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); - FallThroughWeight = - FallThroughWeight > ExitWeight ? FallThroughWeight - ExitWeight : 1; +static void updateBranchWeights(Instruction *Term, WeightInfo &Info) { + MDBuilder MDB(Term->getContext()); + Term->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(Info.Weights)); + for (auto [Idx, SubWeight] : enumerate(Info.SubWeights)) + if (SubWeight != 0) + Info.Weights[Idx] = Info.Weights[Idx] > SubWeight + ? Info.Weights[Idx] - SubWeight + : 1; } -/// Initialize the weights. -/// -/// \param Header The header block. -/// \param LatchBR The latch branch. -/// \param[out] ExitWeight The weight of the edge from Latch to Exit. -/// \param[out] FallThroughWeight The weight of the edge from Latch to Header. -static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t &ExitWeight, - uint64_t &FallThroughWeight) { - uint64_t TrueWeight, FalseWeight; - if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight)) - return; - unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; - ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; - FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight; -} +/// Initialize the weights for all exiting blocks. 
+static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos, + Loop *L) { + SmallVector<BasicBlock *> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (BasicBlock *ExitingBlock : ExitingBlocks) { + Instruction *Term = ExitingBlock->getTerminator(); + SmallVector<uint32_t> Weights; + if (!extractBranchWeights(*Term, Weights)) + continue; -/// Update the weights of original Latch block after peeling off all iterations. -/// -/// \param Header The header block. -/// \param LatchBR The latch branch. -/// \param ExitWeight The weight of the edge from Latch to Exit. -/// \param FallThroughWeight The weight of the edge from Latch to Header. -static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t ExitWeight, - uint64_t FallThroughWeight) { - // FallThroughWeight is 0 means that there is no branch weights on original - // latch block or estimated trip count is zero. - if (!FallThroughWeight) - return; + // See the comment on updateBranchWeights() for an explanation of what we + // do here. + uint32_t FallThroughWeights = 0; + uint32_t ExitWeights = 0; + for (auto [Succ, Weight] : zip(successors(Term), Weights)) { + if (L->contains(Succ)) + FallThroughWeights += Weight; + else + ExitWeights += Weight; + } + + // Don't try to update weights for degenerate case. + if (FallThroughWeights == 0) + continue; - // Sets the branch weights on the loop exit. - MDBuilder MDB(LatchBR->getContext()); - unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight) - : MDB.createBranchWeights(FallThroughWeight, ExitWeight); - LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); + SmallVector<uint32_t> SubWeights; + for (auto [Succ, Weight] : zip(successors(Term), Weights)) { + if (!L->contains(Succ)) { + // Exit weights stay the same. 
+ SubWeights.push_back(0); + continue; + } + + // Subtract exit weights on each iteration, distributed across all + // fallthrough edges. + double W = (double)Weight / (double)FallThroughWeights; + SubWeights.push_back((uint32_t)(ExitWeights * W)); + } + + WeightInfos.insert({Term, {std::move(Weights), std::move(SubWeights)}}); + } +} + +/// Update the weights of original exiting block after peeling off all +/// iterations. +static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) { + MDBuilder MDB(Term->getContext()); + Term->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(Info.Weights)); } /// Clones the body of the loop L, putting it between \p InsertTop and \p @@ -644,10 +624,10 @@ static void cloneLoopBlocks( // header (for the last peeled iteration) or the copied header of the next // iteration (for every other iteration) BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); - BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator()); - for (unsigned idx = 0, e = LatchBR->getNumSuccessors(); idx < e; ++idx) - if (LatchBR->getSuccessor(idx) == Header) { - LatchBR->setSuccessor(idx, InsertBot); + auto *LatchTerm = cast<Instruction>(NewLatch->getTerminator()); + for (unsigned idx = 0, e = LatchTerm->getNumSuccessors(); idx < e; ++idx) + if (LatchTerm->getSuccessor(idx) == Header) { + LatchTerm->setSuccessor(idx, InsertBot); break; } if (DT) @@ -835,12 +815,13 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, ValueToValueMapTy LVMap; + Instruction *LatchTerm = + cast<Instruction>(cast<BasicBlock>(Latch)->getTerminator()); + // If we have branch weight information, we'll want to update it for the // newly created branches. 
- BranchInst *LatchBR = - cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator()); - uint64_t ExitWeight = 0, FallThroughWeight = 0; - initBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight); + DenseMap<Instruction *, WeightInfo> Weights; + initBranchWeights(Weights, L); // Identify what noalias metadata is inside the loop: if it is inside the // loop, the associated metadata must be cloned for each iteration. @@ -869,11 +850,15 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, assert(DT.verify(DominatorTree::VerificationLevel::Fast)); #endif - auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]); - updateBranchWeights(InsertBot, LatchBRCopy, ExitWeight, FallThroughWeight); + for (auto &[Term, Info] : Weights) { + auto *TermCopy = cast<Instruction>(VMap[Term]); + updateBranchWeights(TermCopy, Info); + } + // Remove Loop metadata from the latch branch instruction // because it is not the Loop's latch branch anymore. - LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr); + auto *LatchTermCopy = cast<Instruction>(VMap[LatchTerm]); + LatchTermCopy->setMetadata(LLVMContext::MD_loop, nullptr); InsertTop = InsertBot; InsertBot = SplitBlock(InsertBot, InsertBot->getTerminator(), &DT, LI); @@ -896,7 +881,8 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, PHI->setIncomingValueForBlock(NewPreHeader, NewVal); } - fixupBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight); + for (const auto &[Term, Info] : Weights) + fixupBranchWeights(Term, Info); // Update Metadata for count of peeled off iterations. unsigned AlreadyPeeled = 0; |