diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/Local.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/Local.cpp | 29 |
1 files changed, 12 insertions, 17 deletions
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index a8a7b64..00cbee9a 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -62,6 +62,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -210,20 +211,18 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i->getCaseSuccessor() == DefaultDest) { - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = getValidBranchWeightMDNode(*SI); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. - if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) { + if (NCases > 1 && MD) { // Collect branch weights into a vector. SmallVector<uint32_t, 8> Weights; - for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; - ++MD_i) { - auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + extractBranchWeights(MD, Weights); + // Merge weight of this case to the default weight. unsigned idx = i->getCaseIndex(); + // TODO: Add overflow check. Weights[0] += Weights[idx+1]; // Remove weight for this case. std::swap(Weights[idx+1], Weights.back()); @@ -313,18 +312,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); - if (MD && MD->getNumOperands() == 3) { - ConstantInt *SICase = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(2)); - ConstantInt *SIDef = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(1)); - assert(SICase && SIDef); + SmallVector<uint32_t> Weights; + if (extractBranchWeights(*SI, Weights) && Weights.size() == 2) { + uint32_t DefWeight = Weights[0]; + uint32_t CaseWeight = Weights[1]; // The TrueWeight should be the weight for the single case of SI. NewBr->setMetadata(LLVMContext::MD_prof, - MDBuilder(BB->getContext()). - createBranchWeights(SICase->getValue().getZExtValue(), - SIDef->getValue().getZExtValue())); + MDBuilder(BB->getContext()) + .createBranchWeights(CaseWeight, DefWeight)); } // Update make.implicit metadata to the newly-created conditional branch. |