diff options
Diffstat (limited to 'llvm')
| -rw-r--r-- | llvm/include/llvm/IR/ProfDataUtils.h | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index a7bcbf010..f1c2f38 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -18,6 +18,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/Compiler.h" +#include <cstddef> +#include <type_traits> namespace llvm { struct MDProfLabels { @@ -216,9 +218,13 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T); /// branch weights B1 and B2, respectively. In both B1 and B2, the first /// position (index 0) is for the 'true' branch, and the second position (index /// 1) is for the 'false' branch. +template <typename T1, typename T2, + typename = typename std::enable_if< + std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> && + sizeof(T1) <= sizeof(uint64_t) && sizeof(T2) <= sizeof(uint64_t)>> inline SmallVector<uint64_t, 2> -getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1, - const SmallVector<uint32_t, 2> &B2) { +getDisjunctionWeights(const SmallVector<T1, 2> &B1, + const SmallVector<T2, 2> &B2) { // For the first conditional branch, the probability the "true" case is taken // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is // p(not b1) = B1[1] / (B1[0] + B1[1]). @@ -235,8 +241,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1, // the product of sums, the subtracted one cancels out). assert(B1.size() == 2); assert(B2.size() == 2); - auto FalseWeight = B1[1] * B2[1]; - auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0]; + uint64_t FalseWeight = B1[1] * B2[1]; + uint64_t TrueWeight = B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0]; return {TrueWeight, FalseWeight}; } } // namespace llvm |
