aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
blob: 56af3c15b905f1c913faa35cebefca18dd23962a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Single-operation constraint invoked from PDL: it holds only when the name
/// of `rootOp` is exactly "test.op". The rewriter is not used.
static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
                                                  Operation *rootOp) {
  StringRef opName = rootOp->getName().getStringRef();
  if (opName == "test.op")
    return success();
  return failure();
}
/// Two-operation constraint invoked from PDL. Only `rootCopy` is inspected, by
/// forwarding it to `customSingleEntityConstraint`; `root` is not examined.
static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
                                                 Operation *root,
                                                 Operation *rootCopy) {
  LogicalResult copyVerdict = customSingleEntityConstraint(rewriter, rootCopy);
  return copyVerdict;
}
/// Variadic constraint invoked from PDL: it holds only when both the value
/// range and the type range contain exactly two elements.
static LogicalResult customMultiEntityVariadicConstraint(
    PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
  const bool hasTwoOperands = operandValues.size() == 2;
  const bool hasTwoTypes = typeValues.size() == 2;
  return success(hasTwoOperands && hasTwoTypes);
}

// Custom constraint that produces an attribute. The first argument is read as
// an operation; when that operation is named test.success_op, a "test.success"
// string attribute is appended to `results` and the constraint succeeds.
// Otherwise the constraint fails and `results` is left untouched.
static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
                                                 PDLResultList &results,
                                                 ArrayRef<PDLValue> args) {
  Operation *candidate = args[0].cast<Operation *>();
  if (candidate->getName().getStringRef() != "test.success_op")
    return failure();

  StringAttr successAttr = rewriter.getStringAttr("test.success");
  results.push_back(successAttr);
  return success();
}

// Custom constraint that produces a type. The first argument is read as an
// operation; when that operation is named test.success_op, the f32 type is
// appended to `results` and the constraint succeeds. Otherwise the constraint
// fails and `results` is left untouched.
static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
                                                PDLResultList &results,
                                                ArrayRef<PDLValue> args) {
  Operation *candidate = args[0].cast<Operation *>();
  if (candidate->getName().getStringRef() != "test.success_op")
    return failure();

  results.push_back(rewriter.getF32Type());
  return success();
}

// Custom constraint that produces a type range of variable length.
//
// Arguments (both are decoded before the name check, so a malformed argument
// list is diagnosed by the casts regardless of the operation name):
//   args[0] - the operation to inspect.
//   args[1] - an integer attribute holding the number of types to produce.
//
// When the operation is named test.success_op, a range holding that many f32
// types is appended to `results` and the constraint succeeds; a non-positive
// count yields an empty range. Otherwise the constraint fails and `results`
// is left untouched.
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
                                                     PDLResultList &results,
                                                     ArrayRef<PDLValue> args) {
  auto *op = args[0].cast<Operation *>();
  // Use the free-function cast on the attribute (the member form is
  // deprecated) and keep the full 64-bit value instead of narrowing to `int`.
  const int64_t numTypes =
      cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();

  if (op->getName().getStringRef() != "test.success_op")
    return failure();

  SmallVector<Type> types;
  // The guard keeps a negative count from being converted to a huge size.
  if (numTypes > 0)
    types.assign(static_cast<size_t>(numTypes), rewriter.getF32Type());
  results.push_back(TypeRange(types));
  return success();
}

// Custom creator invoked from PDL: builds a "test.success" operation at the
// location of `op` and returns it.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
  OperationState successState(op->getLoc(), "test.success");
  return rewriter.create(successState);
}
// Custom creator with two range results: the operands of `root` paired with
// the types of those operands. Nothing is created through the rewriter.
static auto customVariadicResultCreate(PatternRewriter &rewriter,
                                       Operation *root) {
  auto rootOperands = root->getOperands();
  return std::make_pair(rootOperands, rootOperands.getTypes());
}
// Custom creator that always yields the f32 type.
static Type customCreateType(PatternRewriter &rewriter) {
  Type f32 = rewriter.getF32Type();
  return f32;
}
// Custom creator that always yields the string "test.str". The rewriter is
// not used.
static std::string customCreateStrAttr(PatternRewriter &rewriter) {
  return std::string("test.str");
}

/// Custom rewriter invoked from PDL: creates a "test.success" operation at the
/// location of `root` with `input` as its operand, then erases `root`.
static void customRewriter(PatternRewriter &rewriter, Operation *root,
                           Value input) {
  StringAttr replacementName = rewriter.getStringAttr("test.success");
  rewriter.create(root->getLoc(), replacementName, input);
  rewriter.eraseOp(root);
}

namespace {
/// Module pass that exercises the PDL bytecode. It expects the module it runs
/// on to contain two nested modules, "patterns" and "ir", and applies the
/// former to the latter using the native functions defined in this file.
struct TestPDLByteCodePass
    : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)

  /// Argument string identifying this pass.
  StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }

  /// Short human-readable description of this pass.
  StringRef getDescription() const final {
    return "Test PDL ByteCode functionality";
  }

  /// Declares pdl_interp as a dependent dialect, because ops from that dialect
  /// are created as a part of the PDL-to-PDLInterp lowering.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<pdl_interp::PDLInterpDialect>();
  }

  /// Collects the patterns from the "patterns" module, registers the native
  /// constraint/rewrite functions, and runs the greedy driver on the body of
  /// the "ir" module.
  void runOnOperation() final {
    ModuleOp topLevel = getOperation();
    MLIRContext *ctx = topLevel->getContext();

    // Each test case is made of two nested modules: "patterns" supplies the
    // patterns and "ir" supplies the operations to rewrite. Without both there
    // is nothing to do.
    ModuleOp patternsModule =
        topLevel.lookupSymbol<ModuleOp>(StringAttr::get(ctx, "patterns"));
    ModuleOp payloadModule =
        topLevel.lookupSymbol<ModuleOp>(StringAttr::get(ctx, "ir"));
    if (!(patternsModule && payloadModule))
      return;

    RewritePatternSet patternSet(ctx);

    // These two are registered up front, before any pattern exists, to cover
    // the case of functions being registered without a pattern.
    PDLPatternModule &earlyRegistry = patternSet.getPDLPatterns();
    earlyRegistry.registerConstraintFunction("multi_entity_constraint",
                                             customMultiEntityConstraint);
    earlyRegistry.registerConstraintFunction("single_entity_constraint",
                                             customSingleEntityConstraint);

    // Detach the pattern module from its parent and build the PDL pattern
    // module from it.
    patternsModule->remove();
    PDLPatternModule pdlModule(patternsModule);

    // "multi_entity_constraint" is deliberately registered a second time: a
    // duplicate registration must be accepted (the duplicate mapping is
    // ignored). This covers keeping library-function registration separate
    // from pattern construction, and several patterns depending on the same
    // library functions, without asserting or crashing.
    pdlModule.registerConstraintFunction("multi_entity_constraint",
                                         customMultiEntityConstraint);
    pdlModule.registerConstraintFunction("multi_entity_var_constraint",
                                         customMultiEntityVariadicConstraint);
    pdlModule.registerConstraintFunction("op_constr_return_attr",
                                         customValueResultConstraint);
    pdlModule.registerConstraintFunction("op_constr_return_type",
                                         customTypeResultConstraint);
    pdlModule.registerConstraintFunction("op_constr_return_type_range",
                                         customTypeRangeResultConstraint);

    pdlModule.registerRewriteFunction("creator", customCreate);
    pdlModule.registerRewriteFunction("var_creator",
                                      customVariadicResultCreate);
    pdlModule.registerRewriteFunction("type_creator", customCreateType);
    pdlModule.registerRewriteFunction("str_creator", customCreateStrAttr);
    pdlModule.registerRewriteFunction("rewriter", customRewriter);
    patternSet.add(std::move(pdlModule));

    // Run the greedy driver over the IR module; its result is intentionally
    // discarded.
    (void)applyPatternsAndFoldGreedily(payloadModule.getBodyRegion(),
                                       std::move(patternSet));
  }
};
} // namespace

namespace mlir::test {
/// Registers TestPDLByteCodePass by constructing a PassRegistration for it.
void registerTestPDLByteCodePass() {
  PassRegistration<TestPDLByteCodePass>();
}
} // namespace mlir::test