aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp')
-rw-r--r--llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp39
1 files changed, 16 insertions, 23 deletions
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 096faad..a072ba2 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -29,6 +29,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
@@ -575,32 +576,26 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT,
return true;
}
-// Returns true and sets the true probability and false probability of an
-// MD_prof metadata if it's well-formed.
-static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb,
- BranchProbability &FalseProb) {
- if (!MD) return false;
- MDString *MDName = cast<MDString>(MD->getOperand(0));
- if (MDName->getString() != "branch_weights" ||
- MD->getNumOperands() != 3)
+// Constructs the true and false branch probabilities if the the instruction has
+// valid branch weights. Returns true when this was successful, false otherwise.
+static bool extractBranchProbabilities(Instruction *I,
+ BranchProbability &TrueProb,
+ BranchProbability &FalseProb) {
+ uint64_t TrueWeight;
+ uint64_t FalseWeight;
+ if (!extractBranchWeights(*I, TrueWeight, FalseWeight))
return false;
- ConstantInt *TrueWeight = mdconst::extract<ConstantInt>(MD->getOperand(1));
- ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2));
- if (!TrueWeight || !FalseWeight)
- return false;
- uint64_t TrueWt = TrueWeight->getValue().getZExtValue();
- uint64_t FalseWt = FalseWeight->getValue().getZExtValue();
- uint64_t SumWt = TrueWt + FalseWt;
+ uint64_t SumWeight = TrueWeight + FalseWeight;
- assert(SumWt >= TrueWt && SumWt >= FalseWt &&
+ assert(SumWeight >= TrueWeight && SumWeight >= FalseWeight &&
"Overflow calculating branch probabilities.");
// Guard against 0-to-0 branch weights to avoid a division-by-zero crash.
- if (SumWt == 0)
+ if (SumWeight == 0)
return false;
- TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
- FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
+ TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWeight);
+ FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWeight);
return true;
}
@@ -639,8 +634,7 @@ static bool checkBiasedBranch(BranchInst *BI, Region *R,
if (!BI->isConditional())
return false;
BranchProbability ThenProb, ElseProb;
- if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof),
- ThenProb, ElseProb))
+ if (!extractBranchProbabilities(BI, ThenProb, ElseProb))
return false;
BasicBlock *IfThen = BI->getSuccessor(0);
BasicBlock *IfElse = BI->getSuccessor(1);
@@ -669,8 +663,7 @@ static bool checkBiasedSelect(
DenseSet<SelectInst *> &FalseBiasedSelectsGlobal,
DenseMap<SelectInst *, BranchProbability> &SelectBiasMap) {
BranchProbability TrueProb, FalseProb;
- if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof),
- TrueProb, FalseProb))
+ if (!extractBranchProbabilities(SI, TrueProb, FalseProb))
return false;
CHR_DEBUG(dbgs() << "SI " << *SI << " ");
CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " ");