aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Instrumentation/RemoveTrapsPass.cpp
blob: 694dd3c04407f7827b16ed3aba3e59bec256b10e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
//===- RemoveTrapsPass.cpp --------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Instrumentation/RemoveTrapsPass.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/RandomNumberGenerator.h"
#include <memory>
#include <random>

using namespace llvm;

#define DEBUG_TYPE "remove-traps"

/// Percentile passed to ProfileSummaryInfo::isHotCountNthPercentile when
/// deciding whether a trap's check is hot; traps classified as hot are
/// removed. No default is given, so the value is 0 unless set explicitly.
static cl::opt<int> HotPercentileCutoff("remove-traps-percentile-cutoff-hot",
                                        cl::desc("Hot percentile cutoff."));

/// When given on the command line, replaces the profile-based policy: every
/// trap is removed independently with this probability.
static cl::opt<float>
    RandomRate("remove-traps-random-rate",
               cl::desc("Probability value in the range [0.0, 1.0] of "
                        "unconditional pseudo-random checks removal."));

// Pass statistics: traps seen vs. traps erased.
STATISTIC(NumChecksTotal, "Number of checks");
STATISTIC(NumChecksRemoved, "Number of removed checks");

/// Erases `llvm.ubsantrap` intrinsic calls from \p F according to the removal
/// policy selected on the command line:
///  * if -remove-traps-random-rate is given, each trap is removed with that
///    probability, drawn from an RNG created via Module::createRNG keyed by
///    \p F's name;
///  * otherwise, if -remove-traps-percentile-cutoff-hot is given and a profile
///    summary \p PSI is available, traps whose check is hot are removed;
///  * otherwise nothing is removed.
///
/// \param F   The function to rewrite (the caller skips declarations).
/// \param BFI Block frequency info for \p F, used to read profile counts.
/// \param PSI Cached profile summary, or null when none is available.
/// \returns true if at least one trap call was erased.
static bool removeUbsanTraps(Function &F, const BlockFrequencyInfo &BFI,
                             const ProfileSummaryInfo *PSI) {
  SmallVector<IntrinsicInst *, 16> Remove;
  std::unique_ptr<RandomNumberGenerator> Rng;

  // Profile-based policy. The hotness of the check is approximated by the
  // summed profile counts of the predecessors of the block holding the trap
  // (presumably because the trap's own block is a rarely-taken failure path).
  auto IsHotCheck = [&](const BasicBlock &BB) {
    // HotPercentileCutoff has no explicit default; without this guard the
    // implicit value 0 would be handed to isHotCountNthPercentile whenever a
    // profile summary happens to be cached.
    if (!PSI || !HotPercentileCutoff.getNumOccurrences())
      return false;
    uint64_t Count = 0;
    for (const BasicBlock *Pred : predecessors(&BB))
      Count += BFI.getBlockProfileCount(Pred).value_or(0);
    return PSI->isHotCountNthPercentile(HotPercentileCutoff, Count);
  };

  auto ShouldRemove = [&](const BasicBlock &BB) {
    // The random policy, when requested, takes precedence over the profile,
    // so the (possibly costly) hotness estimate is skipped entirely.
    if (!RandomRate.getNumOccurrences())
      return IsHotCheck(BB);
    // Create the RNG lazily, only once a trap actually needs a decision.
    if (!Rng)
      Rng = F.getParent()->createRNG(F.getName());
    std::bernoulli_distribution D(RandomRate);
    return D(*Rng);
  };

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      switch (II->getIntrinsicID()) {
      case Intrinsic::ubsantrap: {
        ++NumChecksTotal;
        if (ShouldRemove(BB)) {
          Remove.push_back(II);
          ++NumChecksRemoved;
        }
        break;
      }
      default:
        break;
      }
    }
  }

  // Erase only after the walk so the instruction iteration above is never
  // invalidated.
  for (IntrinsicInst *II : Remove)
    II->eraseFromParent();

  return !Remove.empty();
}

PreservedAnalyses RemoveTrapsPass::run(Function &F,
                                       FunctionAnalysisManager &AM) {
  // A declaration has no body, so there is nothing to rewrite.
  if (F.isDeclaration())
    return PreservedAnalyses::all();

  // The profile summary lives at module level; only an already-cached result
  // is queried, so it may be null.
  auto &ModuleProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
  ProfileSummaryInfo *PSI =
      ModuleProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
  BlockFrequencyInfo &BFI = AM.getResult<BlockFrequencyAnalysis>(F);

  const bool Changed = removeUbsanTraps(F, BFI, PSI);
  if (!Changed)
    return PreservedAnalyses::all();
  // Conservatively invalidate everything once any trap was erased.
  return PreservedAnalyses::none();
}