aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
diff options
context:
space:
mode:
authorMatthias Braun <matze@braunis.de>2023-08-04 15:14:51 -0700
committerMatthias Braun <matze@braunis.de>2023-09-11 10:38:06 -0700
commit5d7f84ee17f3f601c49f6124a3a51e557de3ab53 (patch)
treec4e8407d4a165bb587ff8699a21bf8e0798b837e /llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
parent285e0235f5f649e17eef8f0f25e1680c975dcfe4 (diff)
downloadllvm-5d7f84ee17f3f601c49f6124a3a51e557de3ab53.zip
llvm-5d7f84ee17f3f601c49f6124a3a51e557de3ab53.tar.gz
llvm-5d7f84ee17f3f601c49f6124a3a51e557de3ab53.tar.bz2
LoopRotate: Add code to update branch weights
This adds code to the loop rotation transformation to ensure that the computed block execution counts for the loop bodies are the same before and after the transformation. This isn't always true in practice, but I believe this is because of numeric inaccuracies in the BlockFrequency computation. The invariants this is modeled on and heuristic choice of 0-trip loop amount is explained in a lenghty comment in the new `updateBranchWeights()` function. Differential Revision: https://reviews.llvm.org/D157462
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopRotationUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopRotationUtils.cpp106
1 files changed, 102 insertions, 4 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index d81db56..22effcf 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -25,6 +25,8 @@
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -50,6 +52,9 @@ static cl::opt<bool>
cl::desc("Allow loop rotation multiple times in order to reach "
"a better latch exit"));
+// Probability that a rotated loop has zero trip count / is never entered.
+static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
+
namespace {
/// A simple loop rotation transformation.
class LoopRotate {
@@ -244,6 +249,93 @@ static bool canRotateDeoptimizingLatchExit(Loop *L) {
return false;
}
+static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
+ bool HasConditionalPreHeader,
+ bool SuccsSwapped) {
+ MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
+ if (WeightMD == nullptr)
+ return;
+
+ // LoopBI should currently be a clone of PreHeaderBI with the same
+ // metadata. But we double check to make sure we don't have a degenerate case
+ // where instsimplify changed the instructions.
+ if (WeightMD != getBranchWeightMDNode(LoopBI))
+ return;
+
+ SmallVector<uint32_t, 2> Weights;
+ extractFromBranchWeightMD(WeightMD, Weights);
+ if (Weights.size() != 2)
+ return;
+ uint32_t OrigLoopExitWeight = Weights[0];
+ uint32_t OrigLoopBackedgeWeight = Weights[1];
+
+ if (SuccsSwapped)
+ std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
+
+ // Update branch weights. Consider the following edge-counts:
+ //
+ // | |-------- |
+ // V V | V
+ // Br i1 ... | Br i1 ...
+ // | | | | |
+ // x| y| | becomes: | y0| |-----
+ // V V | | V V |
+ // Exit Loop | | Loop |
+ // | | | Br i1 ... |
+ // ----- | | | |
+ // x0| x1| y1 | |
+ // V V ----
+ // Exit
+ //
+ // The following must hold:
+ // - x == x0 + x1 # counts to "exit" must stay the same.
+ // - y0 == x - x0 == x1 # how often loop was entered at all.
+ // - y1 == y - y0 # How often loop was repeated (after first iter.).
+ //
+ // We cannot generally deduce how often we had a zero-trip count loop so we
+ // have to make a guess for how to distribute x among the new x0 and x1.
+
+ uint32_t ExitWeight0 = 0; // aka x0
+ if (HasConditionalPreHeader) {
+ // Here we cannot know how many 0-trip count loops we have, so we guess:
+ if (OrigLoopBackedgeWeight > OrigLoopExitWeight) {
+ // If the loop count is bigger than the exit count then we set
+ // probabilities as if 0-trip count nearly never happens.
+ ExitWeight0 = ZeroTripCountWeights[0];
+ // Scale up counts if necessary so we can match `ZeroTripCountWeights` for
+ // the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
+ while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
+ // ... but don't overflow.
+ uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
+ if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
+ (OrigLoopExitWeight & HighBit) != 0)
+ break;
+ OrigLoopBackedgeWeight <<= 1;
+ OrigLoopExitWeight <<= 1;
+ }
+ } else {
+ // If there's a higher exit-count than backedge-count then we set
+ // probabilities as if there are only 0-trip and 1-trip cases.
+ ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
+ }
+ }
+ uint32_t ExitWeight1 = OrigLoopExitWeight - ExitWeight0; // aka x1
+ uint32_t EnterWeight = ExitWeight1; // aka y0
+ uint32_t LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; // aka y1
+
+ MDBuilder MDB(LoopBI.getContext());
+ MDNode *LoopWeightMD =
+ MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
+ SuccsSwapped ? ExitWeight1 : LoopBackWeight);
+ LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+ if (HasConditionalPreHeader) {
+ MDNode *PreHeaderWeightMD =
+ MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
+ SuccsSwapped ? ExitWeight0 : EnterWeight);
+ PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+ }
+}
+
/// Rotate loop LP. Return true if the loop is rotated.
///
/// \param SimplifiedLatch is true if the latch was just folded into the final
@@ -363,7 +455,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// loop. Otherwise loop is not suitable for rotation.
BasicBlock *Exit = BI->getSuccessor(0);
BasicBlock *NewHeader = BI->getSuccessor(1);
- if (L->contains(Exit))
+ bool BISuccsSwapped = L->contains(Exit);
+ if (BISuccsSwapped)
std::swap(Exit, NewHeader);
assert(NewHeader && "Unable to determine new loop header");
assert(L->contains(NewHeader) && !L->contains(Exit) &&
@@ -605,9 +698,14 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
// to split as many edges.
BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator());
assert(PHBI->isConditional() && "Should be clone of BI condbr!");
- if (!isa<ConstantInt>(PHBI->getCondition()) ||
- PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) !=
- NewHeader) {
+ const Value *Cond = PHBI->getCondition();
+ const bool HasConditionalPreHeader =
+ !isa<ConstantInt>(Cond) ||
+ PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
+
+ updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
+
+ if (HasConditionalPreHeader) {
// The conditional branch can't be folded, handle the general case.
// Split edges as necessary to preserve LoopSimplify form.