aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/IR')
-rw-r--r--llvm/lib/IR/Instruction.cpp19
-rw-r--r--llvm/lib/IR/Instructions.cpp6
-rw-r--r--llvm/lib/IR/MDBuilder.cpp14
-rw-r--r--llvm/lib/IR/Metadata.cpp8
-rw-r--r--llvm/lib/IR/ProfDataUtils.cpp40
-rw-r--r--llvm/lib/IR/Verifier.cpp9
6 files changed, 70 insertions, 26 deletions
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 29272e6..aec927a 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -1268,12 +1268,23 @@ Instruction *Instruction::cloneImpl() const {
void Instruction::swapProfMetadata() {
MDNode *ProfileData = getBranchWeightMDNode(*this);
- if (!ProfileData || ProfileData->getNumOperands() != 3)
+ if (!ProfileData)
+ return;
+ unsigned FirstIdx = getBranchWeightOffset(ProfileData);
+ if (ProfileData->getNumOperands() != 2 + FirstIdx)
return;
- // The first operand is the name. Fetch them backwards and build a new one.
- Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2),
- ProfileData->getOperand(1)};
+ unsigned SecondIdx = FirstIdx + 1;
+ SmallVector<Metadata *, 4> Ops;
+ // If there are more weights past the second, we can't swap them
+ if (ProfileData->getNumOperands() > SecondIdx + 1)
+ return;
+ for (unsigned Idx = 0; Idx < FirstIdx; ++Idx) {
+ Ops.push_back(ProfileData->getOperand(Idx));
+ }
+ // Switch the order of the weights
+ Ops.push_back(ProfileData->getOperand(SecondIdx));
+ Ops.push_back(ProfileData->getOperand(FirstIdx));
setMetadata(LLVMContext::MD_prof,
MDNode::get(ProfileData->getContext(), Ops));
}
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 1213f07..de369bd 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -5199,7 +5199,11 @@ void SwitchInstProfUpdateWrapper::init() {
if (!ProfileData)
return;
- if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
+ // FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
+ // getValidBranchWeightMDNode(), but the need to use llvm_unreachable
+ // makes them slightly different.
+ if (ProfileData->getNumOperands() !=
+ SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
llvm_unreachable("number of prof branch_weights metadata operands does "
"not correspond to number of succesors");
}
diff --git a/llvm/lib/IR/MDBuilder.cpp b/llvm/lib/IR/MDBuilder.cpp
index bd68db3..0000277 100644
--- a/llvm/lib/IR/MDBuilder.cpp
+++ b/llvm/lib/IR/MDBuilder.cpp
@@ -35,8 +35,8 @@ MDNode *MDBuilder::createFPMath(float Accuracy) {
}
MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
- uint32_t FalseWeight) {
- return createBranchWeights({TrueWeight, FalseWeight});
+ uint32_t FalseWeight, bool IsExpected) {
+ return createBranchWeights({TrueWeight, FalseWeight}, IsExpected);
}
MDNode *MDBuilder::createLikelyBranchWeights() {
@@ -49,15 +49,19 @@ MDNode *MDBuilder::createUnlikelyBranchWeights() {
return createBranchWeights(1, (1U << 20) - 1);
}
-MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
+MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,
+ bool IsExpected) {
assert(Weights.size() >= 1 && "Need at least one branch weights!");
- SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
+ unsigned int Offset = IsExpected ? 2 : 1;
+ SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
Vals[0] = createString("branch_weights");
+ if (IsExpected)
+ Vals[1] = createString("expected");
Type *Int32Ty = Type::getInt32Ty(Context);
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
- Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
+ Vals[i + Offset] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
return MDNode::get(Context, Vals);
}
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index b6c9324..5f42ce2 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -1196,10 +1196,10 @@ MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
StringRef AProfName = AMDS->getString();
StringRef BProfName = BMDS->getString();
if (AProfName == "branch_weights" && BProfName == "branch_weights") {
- ConstantInt *AInstrWeight =
- mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
- ConstantInt *BInstrWeight =
- mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
+ ConstantInt *AInstrWeight = mdconst::dyn_extract<ConstantInt>(
+ A->getOperand(getBranchWeightOffset(A)));
+ ConstantInt *BInstrWeight = mdconst::dyn_extract<ConstantInt>(
+ B->getOperand(getBranchWeightOffset(B)));
assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");
return MDNode::get(Ctx,
{MDHelper.createString("branch_weights"),
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 51e78dc..c4b1ed5 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -40,9 +40,6 @@ namespace {
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes
-// The index at which the weights vector starts
-constexpr unsigned WeightsIdx = 1;
-
// the minimum number of operands for MD_prof nodes with branch weights
constexpr unsigned MinBWOps = 3;
@@ -75,6 +72,7 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
unsigned NOps = ProfileData->getNumOperands();
+ unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
Weights.resize(NOps - WeightsIdx);
@@ -82,8 +80,8 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
ConstantInt *Weight =
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(Weight && "Malformed branch_weight in MD_prof node");
- assert(Weight->getValue().getActiveBits() <= 32 &&
- "Too many bits for uint32_t");
+ assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
+ "Too many bits for MD_prof branch_weight");
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
}
}
@@ -123,6 +121,26 @@ bool hasValidBranchWeightMD(const Instruction &I) {
return getValidBranchWeightMDNode(I);
}
+bool hasBranchWeightOrigin(const Instruction &I) {
+ auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+ return hasBranchWeightOrigin(ProfileData);
+}
+
+bool hasBranchWeightOrigin(const MDNode *ProfileData) {
+ if (!isBranchWeightMD(ProfileData))
+ return false;
+ auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
+ // NOTE: if we ever have more types of branch weight provenance,
+ // we need to check the string value is "expected". For now, we
+ // supply a more generic API, and avoid the spurious comparisons.
+ assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
+ return ProfDataName != nullptr;
+}
+
+unsigned getBranchWeightOffset(const MDNode *ProfileData) {
+ return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
+}
+
MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
@@ -132,7 +150,9 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
- if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
+ auto Offset = getBranchWeightOffset(ProfileData);
+ if (ProfileData &&
+ ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
return ProfileData;
return nullptr;
}
@@ -191,7 +211,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return false;
if (ProfDataName->getString() == "branch_weights") {
- for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
+ unsigned Offset = getBranchWeightOffset(ProfileData);
+ for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(V && "Malformed branch_weight in MD_prof node");
TotalVal += V->getValue().getZExtValue();
@@ -212,9 +233,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
-void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
+ bool IsExpected) {
MDBuilder MDB(I.getContext());
- MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+ MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index e592720..fe2253d 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -104,6 +104,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSlotTracker.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Statepoint.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
@@ -4808,8 +4809,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
// Check consistency of !prof branch_weights metadata.
if (ProfName == "branch_weights") {
+ unsigned int Offset = getBranchWeightOffset(MD);
if (isa<InvokeInst>(&I)) {
- Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3,
+ Check(MD->getNumOperands() == (1 + Offset) ||
+ MD->getNumOperands() == (2 + Offset),
"Wrong number of InvokeInst branch_weights operands", MD);
} else {
unsigned ExpectedNumOperands = 0;
@@ -4829,10 +4832,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
CheckFailed("!prof branch_weights are not allowed for this instruction",
MD);
- Check(MD->getNumOperands() == 1 + ExpectedNumOperands,
+ Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
"Wrong number of operands", MD);
}
- for (unsigned i = 1; i < MD->getNumOperands(); ++i) {
+ for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
auto &MDO = MD->getOperand(i);
Check(MDO, "second operand should not be null", MD);
Check(mdconst::dyn_extract<ConstantInt>(MDO),