diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 2515b16..ebcd820 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" @@ -505,6 +506,32 @@ static bool canProfitablyUnrollMultiExitLoop( // know of kinds of multiexit loops that would benefit from unrolling. } +// Assign the maximum possible trip count as the back edge weight for the +// remainder loop if the original loop comes with a branch weight. +static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop, + Loop *RemainderLoop, + uint64_t UnrollFactor) { + uint64_t TrueWeight, FalseWeight; + BranchInst *LatchBR = + cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()); + if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) { + uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() + ? FalseWeight + : TrueWeight; + assert(UnrollFactor > 1); + uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight; + BasicBlock *Header = RemainderLoop->getHeader(); + BasicBlock *Latch = RemainderLoop->getLoopLatch(); + auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator()); + unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1); + MDBuilder MDB(RemainderLatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) + : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); + } +} + /// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. /// @@ -788,6 +815,11 @@ bool llvm::UnrollRuntimeLoopRemainder( InsertTop, InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); + // Assign the maximum possible trip count as the back edge weight for the + // remainder loop if the original loop comes with a branch weight. + if (remainderLoop && !UnrollRemainder) + updateLatchBranchWeightsForRemainderLoop(L, remainderLoop, Count); + // Insert the cloned blocks into the function. F->getBasicBlockList().splice(InsertBot->getIterator(), F->getBasicBlockList(), |