aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
blob: daa1c371f27c927c1d836b44577b815ce01c1669 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Custom constraint invoked from PDL.
static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
                                                  Operation *rootOp) {
  return success(rootOp->getName().getStringRef() == "test.op");
}
static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
                                                 Operation *root,
                                                 Operation *rootCopy) {
  return customSingleEntityConstraint(rewriter, rootCopy);
}
static LogicalResult customMultiEntityVariadicConstraint(
    PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
  if (operandValues.size() != 2 || typeValues.size() != 2)
    return failure();
  return success();
}

// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
  return rewriter.create(OperationState(op->getLoc(), "test.success"));
}
static auto customVariadicResultCreate(PatternRewriter &rewriter,
                                       Operation *root) {
  return std::make_pair(root->getOperands(), root->getOperands().getTypes());
}
static Type customCreateType(PatternRewriter &rewriter) {
  return rewriter.getF32Type();
}
static std::string customCreateStrAttr(PatternRewriter &rewriter) {
  return "test.str";
}

/// Custom rewriter invoked from PDL.
static void customRewriter(PatternRewriter &rewriter, Operation *root,
                           Value input) {
  rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
                  input);
  rewriter.eraseOp(root);
}

namespace {
struct TestPDLByteCodePass
    : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)

  StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
  StringRef getDescription() const final {
    return "Test PDL ByteCode functionality";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    // Mark the pdl_interp dialect as a dependent. This is needed, because we
    // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
    registry.insert<pdl_interp::PDLInterpDialect>();
  }
  void runOnOperation() final {
    ModuleOp module = getOperation();

    // The test cases are encompassed via two modules, one containing the
    // patterns and one containing the operations to rewrite.
    ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
        StringAttr::get(module->getContext(), "patterns"));
    ModuleOp irModule = module.lookupSymbol<ModuleOp>(
        StringAttr::get(module->getContext(), "ir"));
    if (!patternModule || !irModule)
      return;

    RewritePatternSet patternList(module->getContext());

    // Register ahead of time to test when functions are registered without a
    // pattern.
    patternList.getPDLPatterns().registerConstraintFunction(
        "multi_entity_constraint", customMultiEntityConstraint);
    patternList.getPDLPatterns().registerConstraintFunction(
        "single_entity_constraint", customSingleEntityConstraint);

    // Process the pattern module.
    patternModule.getOperation()->remove();
    PDLPatternModule pdlPattern(patternModule);

    // Note: This constraint was already registered, but we re-register here to
    // ensure that duplication registration is allowed (the duplicate mapping
    // will be ignored). This tests that we support separating the registration
    // of library functions from the construction of patterns, and also that we
    // allow multiple patterns to depend on the same library functions (without
    // asserting/crashing).
    pdlPattern.registerConstraintFunction("multi_entity_constraint",
                                          customMultiEntityConstraint);
    pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                          customMultiEntityVariadicConstraint);
    pdlPattern.registerRewriteFunction("creator", customCreate);
    pdlPattern.registerRewriteFunction("var_creator",
                                       customVariadicResultCreate);
    pdlPattern.registerRewriteFunction("type_creator", customCreateType);
    pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
    pdlPattern.registerRewriteFunction("rewriter", customRewriter);
    patternList.add(std::move(pdlPattern));

    // Invoke the pattern driver with the provided patterns.
    (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
                                       std::move(patternList));
  }
};
} // namespace

namespace mlir {
namespace test {
void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
} // namespace test
} // namespace mlir