aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp101
1 files changed, 60 insertions, 41 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 831b487..1c88500 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -56,6 +56,17 @@ static cl::opt<bool> UnrollRuntimeOtherExitPredictable(
"unroll-runtime-other-exit-predictable", cl::init(false), cl::Hidden,
cl::desc("Assume the non latch exit block to be predictable"));
+// Probability that the loop trip count is so small that after the prolog
+// we do not enter the unrolled loop at all.
+// It is unlikely that the loop trip count is smaller than the unroll factor;
+// other than that, the choice of constant is not tuned yet.
+static const uint32_t UnrolledLoopHeaderWeights[] = {1, 127};
+// Probability that the loop trip count is so small that we skip the unrolled
+// loop completely and immediately enter the epilogue loop.
+// It is unlikely that the loop trip count is smaller than the unroll factor;
+// other than that, the choice of constant is not tuned yet.
+static const uint32_t EpilogHeaderWeights[] = {1, 127};
+
/// Connect the unrolling prolog code to the original loop.
/// The unrolling prolog code contains code to execute the
/// 'extra' iterations if the run-time trip count modulo the
@@ -169,7 +180,14 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
SplitBlockPredecessors(OriginalLoopLatchExit, Preds, ".unr-lcssa", DT, LI,
nullptr, PreserveLCSSA);
// Add the branch to the exit block (around the unrolled loop)
- B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*Latch->getTerminator())) {
+ // Assume loop is nearly always entered.
+ MDBuilder MDB(B.getContext());
+ BranchWeights = MDB.createBranchWeights(UnrolledLoopHeaderWeights);
+ }
+ B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader,
+ BranchWeights);
InsertPt->eraseFromParent();
if (DT) {
auto *NewDom = DT->findNearestCommonDominator(OriginalLoopLatchExit,
@@ -194,8 +212,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
BasicBlock *Exit, BasicBlock *PreHeader,
BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
ValueToValueMapTy &VMap, DominatorTree *DT,
- LoopInfo *LI, bool PreserveLCSSA,
- ScalarEvolution &SE) {
+ LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
+ unsigned Count) {
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Loop must have a latch");
BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -292,7 +310,13 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, nullptr,
PreserveLCSSA);
// Add the branch to the exit block (around the unrolling loop)
- B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*Latch->getTerminator())) {
+ // Assume equal distribution in interval [0, Count).
+ MDBuilder MDB(B.getContext());
+ BranchWeights = MDB.createBranchWeights(1, Count - 1);
+ }
+ B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
InsertPt->eraseFromParent();
if (DT) {
auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit);
@@ -316,8 +340,9 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
const bool UnrollRemainder,
BasicBlock *InsertTop,
BasicBlock *InsertBot, BasicBlock *Preheader,
- std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks,
- ValueToValueMapTy &VMap, DominatorTree *DT, LoopInfo *LI) {
+ std::vector<BasicBlock *> &NewBlocks,
+ LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
+ DominatorTree *DT, LoopInfo *LI, unsigned Count) {
StringRef suffix = UseEpilogRemainder ? "epil" : "prol";
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
@@ -371,7 +396,26 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
Value *IdxNext =
Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next");
Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp");
- Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*LatchBR)) {
+ uint32_t ExitWeight;
+ uint32_t BackEdgeWeight;
+ if (Count >= 3) {
+ // Note: We do not enter this loop for zero-remainders. The check
+ // is at the end of the loop. We assume equal distribution between
+ // possible remainders in [1, Count).
+ ExitWeight = 1;
+ BackEdgeWeight = (Count - 2) / 2;
+ } else {
+ // Unnecessary backedge, should never be taken. The conditional
+ // jump should be optimized away later.
+ ExitWeight = 1;
+ BackEdgeWeight = 0;
+ }
+ MDBuilder MDB(Builder.getContext());
+ BranchWeights = MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
+ }
+ Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
NewIdx->addIncoming(Zero, InsertTop);
NewIdx->addIncoming(IdxNext, NewBB);
LatchBR->eraseFromParent();
@@ -465,32 +509,6 @@ static bool canProfitablyUnrollMultiExitLoop(
// know of kinds of multiexit loops that would benefit from unrolling.
}
-// Assign the maximum possible trip count as the back edge weight for the
-// remainder loop if the original loop comes with a branch weight.
-static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
- Loop *RemainderLoop,
- uint64_t UnrollFactor) {
- uint64_t TrueWeight, FalseWeight;
- BranchInst *LatchBR =
- cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
- if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
- return;
- uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
- ? FalseWeight
- : TrueWeight;
- assert(UnrollFactor > 1);
- uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight;
- BasicBlock *Header = RemainderLoop->getHeader();
- BasicBlock *Latch = RemainderLoop->getLoopLatch();
- auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator());
- unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1);
- MDBuilder MDB(RemainderLatchBR->getContext());
- MDNode *WeightNode =
- HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight)
- : MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
- RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
-}
-
/// Calculate ModVal = (BECount + 1) % Count on the abstract integer domain
/// accounting for the possibility of unsigned overflow in the 2s complement
/// domain. Preconditions:
@@ -776,7 +794,13 @@ bool llvm::UnrollRuntimeLoopRemainder(
BasicBlock *RemainderLoop = UseEpilogRemainder ? NewExit : PrologPreHeader;
BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
// Branch to either remainder (extra iterations) loop or unrolling loop.
- B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop);
+ MDNode *BranchWeights = nullptr;
+ if (hasBranchWeightMD(*Latch->getTerminator())) {
+ // Assume loop is nearly always entered.
+ MDBuilder MDB(B.getContext());
+ BranchWeights = MDB.createBranchWeights(EpilogHeaderWeights);
+ }
+ B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
PreHeaderBR->eraseFromParent();
if (DT) {
if (UseEpilogRemainder)
@@ -805,12 +829,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
Loop *remainderLoop = CloneLoopBlocks(
L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
- NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI);
-
- // Assign the maximum possible trip count as the back edge weight for the
- // remainder loop if the original loop comes with a branch weight.
- if (remainderLoop && !UnrollRemainder)
- updateLatchBranchWeightsForRemainderLoop(L, remainderLoop, Count);
+ NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
// Insert the cloned blocks into the function.
F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end());
@@ -904,7 +923,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
// Connect the epilog code to the original loop and update the
// PHI functions.
ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
- NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE);
+ NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count);
// Update counter in loop for unrolling.
// Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.