aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/Local.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/Local.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/Local.cpp29
1 files changed, 12 insertions, 17 deletions
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index a8a7b64..00cbee9a 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -62,6 +62,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
@@ -210,20 +211,18 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Check to see if this branch is going to the same place as the default
// dest. If so, eliminate it as an explicit compare.
if (i->getCaseSuccessor() == DefaultDest) {
- MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
+ MDNode *MD = getValidBranchWeightMDNode(*SI);
unsigned NCases = SI->getNumCases();
// Fold the case metadata into the default if there will be any branches
// left, unless the metadata doesn't match the switch.
- if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) {
+ if (NCases > 1 && MD) {
// Collect branch weights into a vector.
SmallVector<uint32_t, 8> Weights;
- for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e;
- ++MD_i) {
- auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i));
- Weights.push_back(CI->getValue().getZExtValue());
- }
+ extractBranchWeights(MD, Weights);
+
// Merge weight of this case to the default weight.
unsigned idx = i->getCaseIndex();
+ // TODO: Add overflow check.
Weights[0] += Weights[idx+1];
// Remove weight for this case.
std::swap(Weights[idx+1], Weights.back());
@@ -313,18 +312,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
BranchInst *NewBr = Builder.CreateCondBr(Cond,
FirstCase.getCaseSuccessor(),
SI->getDefaultDest());
- MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
- if (MD && MD->getNumOperands() == 3) {
- ConstantInt *SICase =
- mdconst::dyn_extract<ConstantInt>(MD->getOperand(2));
- ConstantInt *SIDef =
- mdconst::dyn_extract<ConstantInt>(MD->getOperand(1));
- assert(SICase && SIDef);
+ SmallVector<uint32_t> Weights;
+ if (extractBranchWeights(*SI, Weights) && Weights.size() == 2) {
+ uint32_t DefWeight = Weights[0];
+ uint32_t CaseWeight = Weights[1];
// The TrueWeight should be the weight for the single case of SI.
NewBr->setMetadata(LLVMContext::MD_prof,
- MDBuilder(BB->getContext()).
- createBranchWeights(SICase->getValue().getZExtValue(),
- SIDef->getValue().getZExtValue()));
+ MDBuilder(BB->getContext())
+ .createBranchWeights(CaseWeight, DefWeight));
}
// Update make.implicit metadata to the newly-created conditional branch.