aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/InlinerPass.cpp
blob: 43ca5cac8b76f3ac9b60358f8937e5f645135703 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
//===- InlinerPass.cpp - Pass to inline function calls --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom-up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/CallGraph.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Inliner.h"

#include <cstdint>

namespace mlir {
#define GEN_PASS_DEF_INLINER
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "inliner-pass"

using namespace mlir;

/// Populates `pm` with the default inliner optimization pipeline, which
/// consists of a single canonicalizer pass.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  auto canonicalizer = createCanonicalizerPass();
  pm.addPass(std::move(canonicalizer));
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
/// Pass that drives the `Inliner` over the operation it is scheduled on.
///
/// The pass keeps an `InlinerConfig` that is populated from the constructor
/// arguments and refreshed from the pass options in `initializeOptions`.
class InlinerPass : public impl::InlinerBase<InlinerPass> {
public:
  /// Constructs the pass with `defaultInlinerOptPipeline` as the default
  /// pipeline and no op-specific pipelines.
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  /// Constructs the pass with the given default pipeline builder and no
  /// op-specific pipelines. Marked `explicit` so that a `std::function` is
  /// never silently converted into a pass object.
  explicit InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  /// Constructs the pass with the given default pipeline builder and the given
  /// op-specific pipelines.
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

  /// A callback provided to the inliner driver to execute
  /// the specified pass pipeline on the given operation
  /// within the context of the current inliner pass,
  /// which is passed as the first argument.
  /// runPipeline API is protected within the Pass class,
  /// so this helper is required to call it from the foreign
  /// inliner driver.
  static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline,
                                         Operation *op) {
    // `pass` is expected to be the InlinerPass that created the driver; the
    // cast asserts otherwise.
    return mlir::cast<InlinerPass>(pass).runPipeline(pipeline, op);
  }

private:
  /// Initializes the options of this pass from the given string by invoking
  /// the base `Pass::initializeOptions`, and then updates `config` (default
  /// pipeline, op-specific pipelines and maximum number of inlining
  /// iterations) from the resulting option values. Returns failure if the
  /// base class fails to initialize the options.
  LogicalResult initializeOptions(
      StringRef options,
      function_ref<LogicalResult(const Twine &)> errorHandler) override;

  /// Inliner configuration parameters created from the pass options.
  InlinerConfig config;
};
} // namespace

/// Default constructor: delegates to the single-argument constructor with
/// `defaultInlinerOptPipeline` as the default pipeline builder.
InlinerPass::InlinerPass()
    : InlinerPass(/*defaultPipeline=*/defaultInlinerOptPipeline) {}

/// Delegates to the two-argument constructor, forwarding the given default
/// pipeline builder together with an empty map of op-specific pipelines.
InlinerPass::InlinerPass(
    std::function<void(OpPassManager &)> defaultPipelineArg)
    : InlinerPass(std::move(defaultPipelineArg),
                  /*opPipelines=*/llvm::StringMap<OpPassManager>()) {}

/// Builds the pass from a default pipeline builder and a map of op-specific
/// pipelines. `config` is initialized with the default pipeline and the
/// current value of `maxInliningIterations`. When `opPipelines` is non-empty,
/// each of its pipelines is appended to the `opPipelineList` option and the
/// map itself is handed over to `config`.
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : config(std::move(defaultPipeline), maxInliningIterations) {
  if (!opPipelines.empty()) {
    // Keep the pass option in sync with the op-specific pipelines before the
    // map is moved into the configuration.
    for (auto &entry : opPipelines)
      opPipelineList.addValue(entry.getValue());
    config.setOpPipelines(std::move(opPipelines));
  }
}

/// Returns true if inlining `resolvedCall` is considered profitable, i.e. if
/// the callee-to-caller operation count ratio (in percent) does not exceed
/// `inliningThreshold`.
///
/// A threshold of 0 rejects every call and a threshold of -1U accepts every
/// call; in both cases the IR is not inspected. The call is also accepted
/// when either callable region is missing or when the caller region contains
/// no operations.
static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
                                 unsigned inliningThreshold) {
  // A zero threshold rejects every call without looking at the IR.
  if (inliningThreshold == 0U)
    return false;
  // Return early, ratio <= -1U will always be true.
  if (inliningThreshold == -1U)
    return true;

  Region *callerRegion = resolvedCall.sourceNode->getCallableRegion();
  Region *calleeRegion = resolvedCall.targetNode->getCallableRegion();

  // We should not get external nodes here, but just return true
  // for now to preserve the original behavior of the inliner pass.
  if (!callerRegion || !calleeRegion)
    return true;

  // Counts every operation visited by a walk over `region`.
  auto countOps = [](Region *region) {
    unsigned count = 0;
    region->walk([&](Operation *) { ++count; });
    return count;
  };

  unsigned callerOps = countOps(callerRegion);

  // A caller without any operations would make the ratio below divide by
  // zero; treat the call as profitable in that case.
  if (callerOps == 0)
    return true;

  // Compute the percentage in 64 bits so that `calleeOps * 100` cannot wrap
  // around for very large callees.
  uint64_t ratio =
      static_cast<uint64_t>(countOps(calleeRegion)) * 100 / callerOps;
  LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
                          << inliningThreshold << "%): " << ratio << "%\n");
  return ratio <= inliningThreshold;
}

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // By default, assume that any inlining is profitable.
  auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
    return isProfitableToInline(call, inliningThreshold);
  };

  // Get an instance of the inliner.
  Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper,
                  config, profitabilityCb);

  // Run the inlining.
  if (failed(inliner.doInlining()))
    signalPassFailure();
  return;
}

/// Initializes the pass options from `options` through the base
/// `Pass::initializeOptions`, then refreshes `config` from the resulting
/// option values: the default pipeline, the op-specific pipelines and the
/// maximum number of inlining iterations. Returns failure if the base class
/// fails to initialize the options, and success otherwise.
LogicalResult InlinerPass::initializeOptions(
    StringRef options,
    function_ref<LogicalResult(const Twine &)> errorHandler) {
  if (failed(Pass::initializeOptions(options, errorHandler)))
    return failure();

  // Initialize the pipeline builder for operations without the dedicated
  // optimization pipeline in opPipelineList to use the option string.
  // TODO: Use a generic pass manager for the pre-inline pipeline, and remove
  // this.
  if (!defaultPipelineStr.empty()) {
    // The builder may run after this function returns, so it owns its own
    // copy of the pipeline string.
    std::string defaultPipelineCopy = defaultPipelineStr;
    config.setDefaultPipeline(
        [pipelineStr = std::move(defaultPipelineCopy)](OpPassManager &pm) {
          // NOTE(review): the parse result is discarded, so a malformed
          // pipeline string is not reported as a failure from here.
          (void)parsePassPipeline(pipelineStr, pm);
        });
  } else if (defaultPipelineStr.getNumOccurrences()) {
    // The option was given explicitly but is empty: clear the default
    // pipeline builder.
    config.setDefaultPipeline(nullptr);
  }

  // Initialize the op specific pass pipelines. Iterate by const reference so
  // that each pipeline is copied only once, into the map.
  llvm::StringMap<OpPassManager> pipelines;
  for (const OpPassManager &pipeline : opPipelineList)
    if (!pipeline.empty())
      pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
  config.setOpPipelines(std::move(pipelines));

  config.setMaxInliningIterations(maxInliningIterations);

  return success();
}

/// Creates an inliner pass using the default constructor of `InlinerPass`.
std::unique_ptr<Pass> mlir::createInlinerPass() {
  std::unique_ptr<Pass> pass = std::make_unique<InlinerPass>();
  return pass;
}
/// Creates an inliner pass that uses `defaultInlinerOptPipeline` as the
/// default pipeline and the given op-specific pipelines.
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  // Same construction as the two-argument overload, with the default builder.
  return createInlinerPass(std::move(opPipelines), defaultInlinerOptPipeline);
}
/// Creates an inliner pass that uses `defaultPipelineBuilder` as the default
/// pipeline builder and the given op-specific pipelines.
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  std::unique_ptr<Pass> pass = std::make_unique<InlinerPass>(
      std::move(defaultPipelineBuilder), std::move(opPipelines));
  return pass;
}