aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp32
1 files changed, 32 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 2515b16..ebcd820 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -27,6 +27,7 @@
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
@@ -505,6 +506,32 @@ static bool canProfitablyUnrollMultiExitLoop(
// know of kinds of multiexit loops that would benefit from unrolling.
}
+// Assign the maximum possible trip count as the back edge weight for the
+// remainder loop if the original loop comes with a branch weight.
+static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
+ Loop *RemainderLoop,
+ uint64_t UnrollFactor) {
+ uint64_t TrueWeight, FalseWeight;
+ BranchInst *LatchBR =
+ cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
+ if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) {
+ uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
+ ? FalseWeight
+ : TrueWeight;
+ assert(UnrollFactor > 1);
+ uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight;
+ BasicBlock *Header = RemainderLoop->getHeader();
+ BasicBlock *Latch = RemainderLoop->getLoopLatch();
+ auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator());
+ unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1);
+ MDBuilder MDB(RemainderLatchBR->getContext());
+ MDNode *WeightNode =
+ HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight)
+ : MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
+ RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
+ }
+}
+
/// Insert code in the prolog/epilog code when unrolling a loop with a
/// run-time trip-count.
///
@@ -788,6 +815,11 @@ bool llvm::UnrollRuntimeLoopRemainder(
InsertTop, InsertBot,
NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI);
+ // Assign the maximum possible trip count as the back edge weight for the
+ // remainder loop if the original loop comes with a branch weight.
+ if (remainderLoop && !UnrollRemainder)
+ updateLatchBranchWeightsForRemainderLoop(L, remainderLoop, Count);
+
// Insert the cloned blocks into the function.
F->getBasicBlockList().splice(InsertBot->getIterator(),
F->getBasicBlockList(),