aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/IR/ProfDataUtils.h14
1 files changed, 10 insertions, 4 deletions
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index a7bcbf010..f1c2f38 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -18,6 +18,8 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Compiler.h"
+#include <cstddef>
+#include <type_traits>
namespace llvm {
struct MDProfLabels {
@@ -216,9 +218,13 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
/// branch weights B1 and B2, respectively. In both B1 and B2, the first
/// position (index 0) is for the 'true' branch, and the second position (index
/// 1) is for the 'false' branch.
+template <typename T1, typename T2,
+ typename = typename std::enable_if<
+ std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
+ sizeof(T1) <= sizeof(uint64_t) && sizeof(T2) <= sizeof(uint64_t)>>
inline SmallVector<uint64_t, 2>
-getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
- const SmallVector<uint32_t, 2> &B2) {
+getDisjunctionWeights(const SmallVector<T1, 2> &B1,
+ const SmallVector<T2, 2> &B2) {
// For the first conditional branch, the probability the "true" case is taken
// is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
// p(not b1) = B1[1] / (B1[0] + B1[1]).
@@ -235,8 +241,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
// the product of sums, the subtracted one cancels out).
assert(B1.size() == 2);
assert(B2.size() == 2);
- auto FalseWeight = B1[1] * B2[1];
- auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
+ uint64_t FalseWeight = B1[1] * B2[1];
+ uint64_t TrueWeight = B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0];
return {TrueWeight, FalseWeight};
}
} // namespace llvm