aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-05-24 16:14:47 +0200
committerMatthias Springer <me@m-sp.org>2023-05-24 16:22:08 +0200
commite6d90a0d5e202166a9846f1845196086aa02f35e (patch)
treefa5258afd4bc36356f0f92f21cccef304e01ed51 /mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
parent0ea5eb143c7a63105b4479c455642b4c4ef3278c (diff)
downloadllvm-e6d90a0d5e202166a9846f1845196086aa02f35e.zip
llvm-e6d90a0d5e202166a9846f1845196086aa02f35e.tar.gz
llvm-e6d90a0d5e202166a9846f1845196086aa02f35e.tar.bz2
[mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns
Compute operation finger prints to detect incorrect API usage in RewritePatterns. Does not work for dialect conversion patterns. Detect patterns that: * Returned `failure` but changed the IR. * Returned `success` but did not change the IR. * Inserted/removed/modified ops, bypassing the rewriter. Not all cases are detected. These new checks are quite expensive, so they are only enabled with `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON`. Failures manifest as fatal errors (`llvm::report_fatal_error`) or crashes (accessing deallocated memory). To get better debugging information, run `mlir-opt -debug` (to see which pattern is broken) with ASAN (to see where memory was deallocated). Differential Revision: https://reviews.llvm.org/D144552
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp135
1 files changed, 132 insertions, 3 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index f6e7fa1..c05b639 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -30,10 +32,108 @@ using namespace mlir;
#define DEBUG_TYPE "greedy-rewriter"
//===----------------------------------------------------------------------===//
-// GreedyPatternRewriteDriver
+// Debugging Infrastructure
//===----------------------------------------------------------------------===//
namespace {
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+/// A helper struct that stores finger prints of ops in order to detect broken
+/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
+/// using the rewriter API or if it returns an inconsistent return value.
+struct DebugFingerPrints : public RewriterBase::ForwardingListener {
+ DebugFingerPrints(RewriterBase::Listener *driver)
+ : RewriterBase::ForwardingListener(driver) {}
+
+ /// Compute finger prints of the given op and its nested ops.
+ void computeFingerPrints(Operation *topLevel) {
+ this->topLevel = topLevel;
+ this->topLevelFingerPrint.emplace(topLevel);
+ topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
+ }
+
+ /// Clear all finger prints.
+ void clear() {
+ topLevel = nullptr;
+ topLevelFingerPrint.reset();
+ fingerprints.clear();
+ }
+
+ void notifyRewriteSuccess() {
+ // Pattern application success => IR must have changed.
+ OperationFingerPrint afterFingerPrint(topLevel);
+ if (*topLevelFingerPrint == afterFingerPrint) {
+ // Note: Run "mlir-opt -debug" to see which pattern is broken.
+ llvm::report_fatal_error(
+ "pattern returned success but IR did not change");
+ }
+ for (const auto &it : fingerprints) {
+ // Skip top-level op, its finger print is never invalidated.
+ if (it.first == topLevel)
+ continue;
+ // Note: Finger print computation may crash when an op was erased
+ // without notifying the rewriter. (Run with ASAN to see where the op was
+ // erased; the op was probably erased directly, bypassing the rewriter
+ // API.) Finger print computation does may not crash if a new op was
+ // created at the same memory location. (But then the finger print should
+ // have changed.)
+ if (it.second != OperationFingerPrint(it.first)) {
+ // Note: Run "mlir-opt -debug" to see which pattern is broken.
+ llvm::report_fatal_error("operation finger print changed");
+ }
+ }
+ }
+
+ void notifyRewriteFailure() {
+ // Pattern application failure => IR must not have changed.
+ OperationFingerPrint afterFingerPrint(topLevel);
+ if (*topLevelFingerPrint != afterFingerPrint) {
+ // Note: Run "mlir-opt -debug" to see which pattern is broken.
+ llvm::report_fatal_error("pattern returned failure but IR did change");
+ }
+ }
+
+protected:
+ /// Invalidate the finger print of the given op, i.e., remove it from the map.
+ void invalidateFingerPrint(Operation *op) {
+ // Invalidate all finger prints until the top level.
+ while (op && op != topLevel) {
+ fingerprints.erase(op);
+ op = op->getParentOp();
+ }
+ }
+
+ void notifyOperationInserted(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationInserted(op);
+ invalidateFingerPrint(op->getParentOp());
+ }
+
+ void notifyOperationModified(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationModified(op);
+ invalidateFingerPrint(op);
+ }
+
+ void notifyOperationRemoved(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationRemoved(op);
+ op->walk([this](Operation *op) { invalidateFingerPrint(op); });
+ }
+
+ /// Operation finger prints to detect invalid pattern API usage. IR is checked
+ /// against these finger prints after pattern application to detect cases
+ /// where IR was modified directly, bypassing the rewriter API.
+ DenseMap<Operation *, OperationFingerPrint> fingerprints;
+
+ /// Top-level operation of the current greedy rewrite.
+ Operation *topLevel = nullptr;
+
+ /// Finger print of the top-level operation.
+ std::optional<OperationFingerPrint> topLevelFingerPrint;
+};
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns.
///
@@ -122,21 +222,36 @@ private:
/// The low-level pattern applicator.
PatternApplicator matcher;
+
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ DebugFingerPrints debugFingerPrints;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), folder(ctx, this), config(config),
- matcher(patterns) {
+ : PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // clang-format off
+ , debugFingerPrints(this)
+// clang-format on
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+{
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
// Set up listener.
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Send IR notifications to the debug handler. This handler will then forward
+ // all notifications to this GreedyPatternRewriteDriver.
+ setListener(&debugFingerPrints);
+#else
setListener(this);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
bool GreedyPatternRewriteDriver::processWorklist() {
@@ -231,15 +346,28 @@ bool GreedyPatternRewriteDriver::processWorklist() {
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
#endif
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ debugFingerPrints.computeFingerPrints(
+ /*topLevel=*/config.scope ? config.scope->getParentOp() : op);
+ auto clearFingerprints =
+ llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ debugFingerPrints.notifyRewriteSuccess();
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ debugFingerPrints.notifyRewriteFailure();
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
@@ -247,6 +375,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+ assert(op && "expected valid op");
// Gather potential ancestors while looking for a "scope" parent region.
SmallVector<Operation *, 8> ancestors;
Region *region = nullptr;