diff options
Diffstat (limited to 'llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp')
-rw-r--r-- | llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp | 39 |
1 files changed, 16 insertions, 23 deletions
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 096faad..a072ba2 100644 --- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -575,32 +576,26 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT, return true; } -// Returns true and sets the true probability and false probability of an -// MD_prof metadata if it's well-formed. -static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb, - BranchProbability &FalseProb) { - if (!MD) return false; - MDString *MDName = cast<MDString>(MD->getOperand(0)); - if (MDName->getString() != "branch_weights" || - MD->getNumOperands() != 3) +// Constructs the true and false branch probabilities if the the instruction has +// valid branch weights. Returns true when this was successful, false otherwise. +static bool extractBranchProbabilities(Instruction *I, + BranchProbability &TrueProb, + BranchProbability &FalseProb) { + uint64_t TrueWeight; + uint64_t FalseWeight; + if (!extractBranchWeights(*I, TrueWeight, FalseWeight)) return false; - ConstantInt *TrueWeight = mdconst::extract<ConstantInt>(MD->getOperand(1)); - ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2)); - if (!TrueWeight || !FalseWeight) - return false; - uint64_t TrueWt = TrueWeight->getValue().getZExtValue(); - uint64_t FalseWt = FalseWeight->getValue().getZExtValue(); - uint64_t SumWt = TrueWt + FalseWt; + uint64_t SumWeight = TrueWeight + FalseWeight; - assert(SumWt >= TrueWt && SumWt >= FalseWt && + assert(SumWeight >= TrueWeight && SumWeight >= FalseWeight && "Overflow calculating branch probabilities."); // Guard against 0-to-0 branch weights to avoid a division-by-zero crash. - if (SumWt == 0) + if (SumWeight == 0) return false; - TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); - FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWeight); + FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWeight); return true; } @@ -639,8 +634,7 @@ static bool checkBiasedBranch(BranchInst *BI, Region *R, if (!BI->isConditional()) return false; BranchProbability ThenProb, ElseProb; - if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof), - ThenProb, ElseProb)) + if (!extractBranchProbabilities(BI, ThenProb, ElseProb)) return false; BasicBlock *IfThen = BI->getSuccessor(0); BasicBlock *IfElse = BI->getSuccessor(1); @@ -669,8 +663,7 @@ static bool checkBiasedSelect( DenseSet<SelectInst *> &FalseBiasedSelectsGlobal, DenseMap<SelectInst *, BranchProbability> &SelectBiasMap) { BranchProbability TrueProb, FalseProb; - if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof), - TrueProb, FalseProb)) + if (!extractBranchProbabilities(SI, TrueProb, FalseProb)) return false; CHR_DEBUG(dbgs() << "SI " << *SI << " "); CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " "); |