aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Pass/TestDynamicPipeline.cpp
blob: b46bf08657d57a02d69398bc7a14103ebe180eaf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
//===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test the dynamic pipeline feature.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

namespace {

/// Test pass for the dynamic pipeline feature.
///
/// Builds an OpPassManager from the textual `dynamic-pipeline` option and runs
/// it via `runPipeline` on one of the following:
///  - the visited operation (default),
///  - isolated-from-above nested operations with the same name as the visited
///    operation (`run-on-nested-operations`), or
///  - the visited operation's parent (`run-on-parent`), which is expected to be
///    rejected.
/// Only operations implementing SymbolOpInterface are processed. When `op-name`
/// is non-empty, only symbols whose name appears in that list are processed.
class TestDynamicPipelinePass
    : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDynamicPipelinePass)

  StringRef getArgument() const final { return "test-dynamic-pipeline"; }
  StringRef getDescription() const final {
    return "Tests the dynamic pipeline feature by applying "
           "a pipeline on a selected set of functions";
  }

  /// Collects the dialects required by the configured pipeline.
  ///
  /// Parses `pipeline` into a throwaway pass manager anchored on ModuleOp only
  /// to query its dependent dialects. Parse errors are printed to llvm::errs(),
  /// but this hook cannot report failure. runOnOperation() re-parses the
  /// pipeline and fails the pass if it is malformed.
  void getDependentDialects(DialectRegistry &registry) const override {
    OpPassManager pm(ModuleOp::getOperationName(),
                     OpPassManager::Nesting::Implicit);
    (void)parsePassPipeline(pipeline, pm, llvm::errs());
    pm.getDependentDialects(registry);
  }

  TestDynamicPipelinePass() = default;
  // Intentionally copies no members. NOTE(review): option values are presumably
  // transferred by the pass-cloning machinery rather than here — confirm.
  TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}

  void runOnOperation() override {
    Operation *currentOp = getOperation();

    llvm::errs() << "Dynamic execute '" << pipeline << "' on "
                 << currentOp->getName() << "\n";
    if (pipeline.empty()) {
      llvm::errs() << "Empty pipeline\n";
      return;
    }
    auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
    if (!symbolOp) {
      currentOp->emitWarning()
          << "Ignoring because not implementing SymbolOpInterface\n";
      return;
    }

    // When `op-name` is given, only process the listed symbols.
    auto opName = symbolOp.getName();
    if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
      llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
      return;
    }

    // Build the dynamic pipeline anchored on the visited operation's name. A
    // malformed pipeline string fails the pass instead of silently running a
    // possibly incomplete pipeline. The parser has already printed its
    // diagnostics to llvm::errs().
    OpPassManager pm(currentOp->getName().getIdentifier(),
                     OpPassManager::Nesting::Implicit);
    if (failed(parsePassPipeline(pipeline, pm, llvm::errs()))) {
      signalPassFailure();
      return;
    }

    // Running on the parent operation is expected to be rejected. If it
    // unexpectedly succeeds, report that as a pass failure.
    if (runOnParent) {
      if (currentOp->getParentOp())
        if (!failed(runPipeline(pm, currentOp->getParentOp())))
          signalPassFailure();
      return;
    }

    if (runOnNestedOp) {
      llvm::errs() << "Run on nested op\n";
      currentOp->walk([&](Operation *op) {
        // Only target strictly nested, isolated-from-above ops that have the
        // same name as the visited op, since `pm` is anchored on that name.
        if (op == currentOp || !op->hasTrait<OpTrait::IsIsolatedFromAbove>() ||
            op->getName() != currentOp->getName())
          return;
        llvm::errs() << "Run on " << *op << "\n";
        // Run on this nested operation.
        if (failed(runPipeline(pm, op)))
          signalPassFailure();
      });
    } else {
      // Run on the current operation.
      if (failed(runPipeline(pm, currentOp)))
        signalPassFailure();
    }
  }

  /// Apply the pipeline to matching nested operations instead of the visited
  /// operation itself.
  Option<bool> runOnNestedOp{
      *this, "run-on-nested-operations",
      llvm::cl::desc("This will apply the pipeline on nested operations under "
                     "the visited operation.")};
  /// Apply the pipeline to the parent operation; this is expected to fail.
  Option<bool> runOnParent{
      *this, "run-on-parent",
      llvm::cl::desc("This will apply the pipeline on the parent operation if "
                     "it exist, this is expected to fail.")};
  /// Textual pipeline description to run dynamically.
  Option<std::string> pipeline{
      *this, "dynamic-pipeline",
      llvm::cl::desc("The pipeline description that "
                     "will run on the filtered function.")};
  /// Optional allow-list of symbol names to process; empty means all.
  ListOption<std::string> opNames{
      *this, "op-name",
      llvm::cl::desc("List of function name to apply the pipeline to")};
};
} // namespace

namespace mlir {
namespace test {
void registerTestDynamicPipelinePass() {
  PassRegistration<TestDynamicPipelinePass>();
}
} // namespace test
} // namespace mlir