aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
blob: 19906f15ae85f6eacffa9f135117893badcc54ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"

using namespace mlir;

// Pull in the generated pass definitions for the interpreter pass. This
// provides `transform::impl::InterpreterPassBase`, which `InterpreterPass`
// derives from below. The generated .inc file is expected to emit only the
// definitions whose GEN_PASS_DEF_* macro is defined (MLIR's generated-pass
// convention), so the #define must precede the #include. The include also has
// to stay inside `mlir::transform` so the generated code lands in that
// namespace.
namespace mlir {
namespace transform {
#define GEN_PASS_DEF_INTERPRETERPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir

/// Selects the operation the transform script will treat as its payload root.
///
/// With an empty `tag`, `passRoot` itself is the payload root. Otherwise the
/// IR rooted at `passRoot` (including `passRoot`) is walked. The walk looks for
/// operations carrying a string attribute named
/// `transform::TransformDialect::kTargetTagAttrName` whose value equals `tag`.
/// Exactly one such operation must exist:
///   - if none is found, an error is emitted on `passRoot`;
///   - if a second one is encountered, an error is emitted on it, with a note
///     pointing at the first one, and the walk stops.
/// In both failure cases nullptr is returned.
static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
  if (tag.empty())
    return passRoot;

  StringAttr targetTagName = StringAttr::get(
      passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);

  // First tagged operation seen so far; a second hit aborts the walk.
  Operation *match = nullptr;
  bool foundDuplicate =
      passRoot
          ->walk([&](Operation *candidate) -> WalkResult {
            auto candidateTag =
                candidate->getAttrOfType<StringAttr>(targetTagName);
            bool isTagged = candidateTag && candidateTag.getValue() == tag;
            if (!isTagged)
              return WalkResult::advance();

            if (match) {
              InFlightDiagnostic diag =
                  candidate->emitError()
                  << "repeated operation with the target tag '" << tag << "'";
              diag.attachNote(match->getLoc()) << "previously seen operation";
              return WalkResult::interrupt();
            }

            match = candidate;
            return WalkResult::advance();
          })
          .wasInterrupted();

  if (!match) {
    passRoot->emitError()
        << "could not find the operation with transform.target_tag=\"" << tag
        << "\" attribute";
    return nullptr;
  }

  if (foundDuplicate)
    return nullptr;
  return match;
}

namespace {
class InterpreterPass
    : public transform::impl::InterpreterPassBase<InterpreterPass> {
  /// A `debugBindTrailingArgs` entry that is resolved by looking up payload
  /// operations by name.
  struct OpNameBinding {
    /// Name of the payload operations to collect.
    OperationName name;
    /// Index of the option entry, i.e. of the trailing binding to append to.
    size_t position;
    /// True for `^opname` entries, which bind the results of the matching
    /// operations instead of the operations themselves.
    bool bindResults;
  };

  /// Builds the values bound to the entry point's block arguments.
  ///
  /// The first binding is always `payloadRoot`. Each entry of the
  /// `debugBindTrailingArgs` option then contributes one more (possibly empty)
  /// binding, in option order:
  ///   - `#<int>[;<int>...]` binds the listed base-10 integers as i64 integer
  ///     attributes;
  ///   - `^<opname>` binds the results of every operation named `<opname>`
  ///     found by walking `payloadRoot` (including `payloadRoot` itself);
  ///   - any other `<opname>` binds every such operation.
  ///
  /// Returns std::nullopt, after emitting an error, when an integer entry
  /// cannot be parsed.
  std::optional<RaggedArray<transform::MappedValue>>
  parseArguments(Operation *payloadRoot) {
    MLIRContext *context = payloadRoot->getContext();
    Builder builder(context);

    SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
    trailingBindings.resize(debugBindTrailingArgs.size());

    // Classify every option entry once. Integer entries are resolved right
    // away. Op-name entries are recorded with their kind, so the payload walk
    // below does not have to re-inspect the raw option strings for every
    // visited operation.
    SmallVector<OpNameBinding> opNameBindings;
    opNameBindings.reserve(debugBindTrailingArgs.size());
    for (auto &&[position, nameString] :
         llvm::enumerate(debugBindTrailingArgs)) {
      StringRef name = nameString;

      if (name.starts_with("#")) {
        // `;`-separated list of integer literals.
        StringRef lhs = "";
        StringRef rhs = name.drop_front();
        do {
          std::tie(lhs, rhs) = rhs.split(';');
          int64_t value = 0;
          if (lhs.getAsInteger(10, value)) {
            emitError(UnknownLoc::get(context))
                << "couldn't parse integer pass argument " << name;
            return std::nullopt;
          }
          trailingBindings[position].push_back(
              builder.getI64IntegerAttr(value));
        } while (!rhs.empty());
      } else if (name.starts_with("^")) {
        opNameBindings.push_back(
            OpNameBinding{OperationName(name.drop_front(), context), position,
                          /*bindResults=*/true});
      } else {
        opNameBindings.push_back(OpNameBinding{OperationName(name, context),
                                               position,
                                               /*bindResults=*/false});
      }
    }

    // Collect operations or results for the op-name entries. Each binding
    // receives its entities in the order the walk visits them. The walk only
    // reads the IR, so it can be skipped when nothing needs collecting.
    if (!opNameBindings.empty()) {
      payloadRoot->walk([&](Operation *payload) {
        for (const OpNameBinding &binding : opNameBindings) {
          if (payload->getName() != binding.name)
            continue;

          if (binding.bindResults)
            llvm::append_range(trailingBindings[binding.position],
                               payload->getResults());
          else
            trailingBindings[binding.position].push_back(payload);
        }
      });
    }

    RaggedArray<transform::MappedValue> bindings;
    bindings.push_back(ArrayRef<Operation *>{payloadRoot});
    for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
      bindings.push_back(std::move(trailing));
    return bindings;
  }

public:
  using Base::Base;

  /// Runs the transform script on the payload IR.
  ///
  /// Steps:
  ///   1. Select the payload root: the pass root, or the operation tagged with
  ///      `debugPayloadRootTag`.
  ///   2. Resolve the `entryPoint` named sequence via
  ///      `findTransformEntryPoint`, given the pass root and the preloaded
  ///      transform module.
  ///   3. Bind the payload root plus any `debugBindTrailingArgs` to the entry
  ///      point's arguments.
  ///   4. Apply the named sequence.
  ///
  /// Signals pass failure if any of these steps fails.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp transformModule =
        transform::detail::getPreloadedTransformModule(context);
    Operation *payloadRoot =
        findPayloadRoot(getOperation(), debugPayloadRootTag);
    if (!payloadRoot)
      return signalPassFailure();

    Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
        getOperation(), transformModule, entryPoint);
    if (!transformEntryPoint)
      return signalPassFailure();

    std::optional<RaggedArray<transform::MappedValue>> bindings =
        parseArguments(payloadRoot);
    if (!bindings)
      return signalPassFailure();
    if (failed(transform::applyTransformNamedSequence(
            *bindings,
            cast<transform::TransformOpInterface>(transformEntryPoint),
            transformModule,
            options.enableExpensiveChecks(!disableExpensiveChecks)))) {
      return signalPassFailure();
    }
  }

private:
  /// Transform interpreter options.
  transform::TransformOptions options;
};
} // namespace