aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/ProfDataUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/IR/ProfDataUtils.cpp')
-rw-r--r--llvm/lib/IR/ProfDataUtils.cpp123
1 files changed, 78 insertions, 45 deletions
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 99029c1..fc2be51 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -12,6 +12,7 @@
#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
@@ -19,11 +20,10 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/Support/CommandLine.h"
using namespace llvm;
-namespace {
-
// MD_prof nodes have the following layout
//
// In general:
@@ -39,14 +39,15 @@ namespace {
// correctly, and can change the behavior in the future if the layout changes
// the minimum number of operands for MD_prof nodes with branch weights
-constexpr unsigned MinBWOps = 3;
+static constexpr unsigned MinBWOps = 3;
// the minimum number of operands for MD_prof nodes with value profiles
-constexpr unsigned MinVPOps = 5;
+static constexpr unsigned MinVPOps = 5;
// We may want to add support for other MD_prof types, so provide an abstraction
// for checking the metadata type.
-bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
+static bool isTargetMD(const MDNode *ProfData, const char *Name,
+ unsigned MinOps) {
// TODO: This routine may be simplified if MD_prof used an enum instead of a
// string to differentiate the types of MD_prof nodes.
if (!ProfData || !Name || MinOps < 2)
@@ -84,10 +85,28 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
}
}
-} // namespace
-
-namespace llvm {
+/// Push the weights right to fit in uint32_t.
+static SmallVector<uint32_t> fitWeights(ArrayRef<uint64_t> Weights) {
+ SmallVector<uint32_t> Ret;
+ Ret.reserve(Weights.size());
+ uint64_t Max = *llvm::max_element(Weights);
+ if (Max > UINT_MAX) {
+ unsigned Offset = 32 - llvm::countl_zero(Max);
+ for (const uint64_t &Value : Weights)
+ Ret.push_back(static_cast<uint32_t>(Value >> Offset));
+ } else {
+ append_range(Ret, Weights);
+ }
+ return Ret;
+}
+static cl::opt<bool> ElideAllZeroBranchWeights("elide-all-zero-branch-weights",
+#if defined(LLVM_ENABLE_PROFCHECK)
+ cl::init(false)
+#else
+ cl::init(true)
+#endif
+);
const char *MDProfLabels::BranchWeights = "branch_weights";
const char *MDProfLabels::ExpectedBranchWeights = "expected";
const char *MDProfLabels::ValueProfile = "VP";
@@ -95,21 +114,21 @@ const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
const char *MDProfLabels::SyntheticFunctionEntryCount =
"synthetic_function_entry_count";
const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
-const char *LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
+const char *llvm::LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
-bool hasProfMD(const Instruction &I) {
+bool llvm::hasProfMD(const Instruction &I) {
return I.hasMetadata(LLVMContext::MD_prof);
}
-bool isBranchWeightMD(const MDNode *ProfileData) {
+bool llvm::isBranchWeightMD(const MDNode *ProfileData) {
return isTargetMD(ProfileData, MDProfLabels::BranchWeights, MinBWOps);
}
-bool isValueProfileMD(const MDNode *ProfileData) {
+bool llvm::isValueProfileMD(const MDNode *ProfileData) {
return isTargetMD(ProfileData, MDProfLabels::ValueProfile, MinVPOps);
}
-bool hasBranchWeightMD(const Instruction &I) {
+bool llvm::hasBranchWeightMD(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return isBranchWeightMD(ProfileData);
}
@@ -124,16 +143,16 @@ static bool hasCountTypeMD(const Instruction &I) {
return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
}
-bool hasValidBranchWeightMD(const Instruction &I) {
+bool llvm::hasValidBranchWeightMD(const Instruction &I) {
return getValidBranchWeightMDNode(I);
}
-bool hasBranchWeightOrigin(const Instruction &I) {
+bool llvm::hasBranchWeightOrigin(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return hasBranchWeightOrigin(ProfileData);
}
-bool hasBranchWeightOrigin(const MDNode *ProfileData) {
+bool llvm::hasBranchWeightOrigin(const MDNode *ProfileData) {
if (!isBranchWeightMD(ProfileData))
return false;
auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
@@ -145,54 +164,54 @@ bool hasBranchWeightOrigin(const MDNode *ProfileData) {
return ProfDataName != nullptr;
}
-unsigned getBranchWeightOffset(const MDNode *ProfileData) {
+unsigned llvm::getBranchWeightOffset(const MDNode *ProfileData) {
return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
}
-unsigned getNumBranchWeights(const MDNode &ProfileData) {
+unsigned llvm::getNumBranchWeights(const MDNode &ProfileData) {
return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
}
-MDNode *getBranchWeightMDNode(const Instruction &I) {
+MDNode *llvm::getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
return nullptr;
return ProfileData;
}
-MDNode *getValidBranchWeightMDNode(const Instruction &I) {
+MDNode *llvm::getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
return ProfileData;
return nullptr;
}
-void extractFromBranchWeightMD32(const MDNode *ProfileData,
- SmallVectorImpl<uint32_t> &Weights) {
+void llvm::extractFromBranchWeightMD32(const MDNode *ProfileData,
+ SmallVectorImpl<uint32_t> &Weights) {
extractFromBranchWeightMD(ProfileData, Weights);
}
-void extractFromBranchWeightMD64(const MDNode *ProfileData,
- SmallVectorImpl<uint64_t> &Weights) {
+void llvm::extractFromBranchWeightMD64(const MDNode *ProfileData,
+ SmallVectorImpl<uint64_t> &Weights) {
extractFromBranchWeightMD(ProfileData, Weights);
}
-bool extractBranchWeights(const MDNode *ProfileData,
- SmallVectorImpl<uint32_t> &Weights) {
+bool llvm::extractBranchWeights(const MDNode *ProfileData,
+ SmallVectorImpl<uint32_t> &Weights) {
if (!isBranchWeightMD(ProfileData))
return false;
extractFromBranchWeightMD(ProfileData, Weights);
return true;
}
-bool extractBranchWeights(const Instruction &I,
- SmallVectorImpl<uint32_t> &Weights) {
+bool llvm::extractBranchWeights(const Instruction &I,
+ SmallVectorImpl<uint32_t> &Weights) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return extractBranchWeights(ProfileData, Weights);
}
-bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
- uint64_t &FalseVal) {
+bool llvm::extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+ uint64_t &FalseVal) {
assert((I.getOpcode() == Instruction::Br ||
I.getOpcode() == Instruction::Select) &&
"Looking for branch weights on something besides branch, select, or "
@@ -211,7 +230,8 @@ bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
return true;
}
-bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
+bool llvm::extractProfTotalWeight(const MDNode *ProfileData,
+ uint64_t &TotalVal) {
TotalVal = 0;
if (!ProfileData)
return false;
@@ -239,11 +259,12 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return false;
}
-bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
+bool llvm::extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
-void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName) {
+void llvm::setExplicitlyUnknownBranchWeights(Instruction &I,
+ StringRef PassName) {
MDBuilder MDB(I.getContext());
I.setMetadata(
LLVMContext::MD_prof,
@@ -252,14 +273,16 @@ void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName) {
MDB.createString(PassName)}));
}
-void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, Function &F,
- StringRef PassName) {
+void llvm::setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I,
+ Function &F,
+ StringRef PassName) {
if (std::optional<Function::ProfileCount> EC = F.getEntryCount();
EC && EC->getCount() > 0)
setExplicitlyUnknownBranchWeights(I, PassName);
}
-void setExplicitlyUnknownFunctionEntryCount(Function &F, StringRef PassName) {
+void llvm::setExplicitlyUnknownFunctionEntryCount(Function &F,
+ StringRef PassName) {
MDBuilder MDB(F.getContext());
F.setMetadata(
LLVMContext::MD_prof,
@@ -268,28 +291,40 @@ void setExplicitlyUnknownFunctionEntryCount(Function &F, StringRef PassName) {
MDB.createString(PassName)}));
}
-bool isExplicitlyUnknownProfileMetadata(const MDNode &MD) {
+bool llvm::isExplicitlyUnknownProfileMetadata(const MDNode &MD) {
if (MD.getNumOperands() != 2)
return false;
return MD.getOperand(0).equalsStr(MDProfLabels::UnknownBranchWeightsMarker);
}
-bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
+bool llvm::hasExplicitlyUnknownBranchWeights(const Instruction &I) {
auto *MD = I.getMetadata(LLVMContext::MD_prof);
if (!MD)
return false;
return isExplicitlyUnknownProfileMetadata(*MD);
}
-void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
- bool IsExpected) {
+void llvm::setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
+ bool IsExpected, bool ElideAllZero) {
+ if ((ElideAllZeroBranchWeights && ElideAllZero) &&
+ llvm::all_of(Weights, [](uint32_t V) { return V == 0; })) {
+ I.setMetadata(LLVMContext::MD_prof, nullptr);
+ return;
+ }
+
MDBuilder MDB(I.getContext());
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
-SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
- std::optional<uint64_t> KnownMaxCount) {
+void llvm::setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
+ bool IsExpected, bool ElideAllZero) {
+ setBranchWeights(I, fitWeights(Weights), IsExpected, ElideAllZero);
+}
+
+SmallVector<uint32_t>
+llvm::downscaleWeights(ArrayRef<uint64_t> Weights,
+ std::optional<uint64_t> KnownMaxCount) {
uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
: *llvm::max_element(Weights);
assert(MaxCount > 0 && "Bad max count");
@@ -300,7 +335,7 @@ SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
return DownscaledWeights;
}
-void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
+void llvm::scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
assert(T != 0 && "Caller should guarantee");
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (ProfileData == nullptr)
@@ -353,5 +388,3 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
}
I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
}
-
-} // namespace llvm