aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
blob: ee5c642c943c451b9b27909f9ebea43831450ebf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements mlir::walkAndApplyPatterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "walk-rewriter"

namespace mlir {

namespace {
/// Tracing action that wraps a single run of the walk-based pattern driver.
/// Its tag is the only information it prints.
struct WalkAndApplyPatternsAction final
    : public tracing::ActionImpl<WalkAndApplyPatternsAction> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)

  /// Identifier under which this action is reported.
  static constexpr StringLiteral tag = "walk-and-apply-patterns";

  // Reuse the base-class constructors unchanged.
  using ActionImpl::ActionImpl;

  /// Emits the action's tag to `os`.
  void print(raw_ostream &os) const override {
    os << tag;
  }
};

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Forwarding listener that aborts on erasures the walk cannot tolerate.
///
/// Patterns are applied while walking the IR, so erasing an op or block that
/// the walk still has to reach (e.g., a user of the visited op) is not valid.
/// Only the op currently being visited and IR nested inside it may be erased.
/// After the check passes, each notification is forwarded unchanged. This
/// listener is only compiled in with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  using RewriterBase::ForwardingListener::ForwardingListener;

  /// The op the walk is currently visiting; the driver updates it before each
  /// pattern application.
  Operation *visitedOp = nullptr;

  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  void notifyBlockErased(Block *block) override {
    // A block is judged by the op that owns the region containing it.
    checkErasure(block->getParentOp());
    ForwardingListener::notifyBlockErased(block);
  }

  /// Calls report_fatal_error unless `op` is `visitedOp` itself or is nested,
  /// at any depth, inside it.
  void checkErasure(Operation *op) const {
    if (!isWithinVisitedOp(op))
      llvm::report_fatal_error(
          "unsupported erasure in WalkPatternRewriter; "
          "erasure is only supported for matched ops and their descendants");
  }

private:
  /// Climbs the parent chain starting at `op` and reports whether
  /// `visitedOp` is reached. Running off the top of the chain counts as
  /// reaching it only when `visitedOp` is null.
  bool isWithinVisitedOp(Operation *op) const {
    for (Operation *cursor = op; cursor; cursor = cursor->getParentOp())
      if (cursor == visitedOp)
        return true;
    return visitedOp == nullptr;
  }
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
} // namespace

/// Walks every region nested under `op` (but not `op` itself) in post-order.
/// For each visited op, it tries the given patterns once using the default
/// cost model. `listener`, when non-null, receives all rewriter notifications.
/// When expensive pattern API checks are enabled, the IR is verified before
/// and after the walk. Erasures outside the visited op are also rejected.
void walkAndApplyPatterns(Operation *op,
                          const FrozenRewritePatternSet &patterns,
                          RewriterBase::Listener *listener) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // Reject malformed input up front so that later verification failures can
  // be blamed on the patterns rather than on the caller.
  if (failed(verify(op)))
    llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  MLIRContext *context = op->getContext();
  PatternRewriter rewriter(context);

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // Interpose the erasure checker between the rewriter and the user listener.
  ErasedOpsListener erasedListener(listener);
  rewriter.setListener(&erasedListener);
#else
  rewriter.setListener(listener);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  PatternApplicator applicator(patterns);
  applicator.applyDefaultCostModel();

  // Attempts a single pattern application on one op reached by the walk.
  auto visitOne = [&](Operation *visitedOp) {
    LLVM_DEBUG({
      llvm::dbgs() << "Visiting op: ";
      visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
      llvm::dbgs() << "\n";
    });
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
    erasedListener.visitedOp = visitedOp;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
    LogicalResult matchResult = applicator.matchAndRewrite(visitedOp, rewriter);
    if (succeeded(matchResult)) {
      LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n");
    }
  };

  context->executeAction<WalkAndApplyPatternsAction>(
      [&] {
        for (Region &region : op->getRegions())
          region.walk(visitOne);
      },
      {op});

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  if (failed(verify(op)))
    llvm::report_fatal_error(
        "walk pattern rewriter result IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}

} // namespace mlir