diff options
author | Matthias Springer <me@m-sp.org> | 2024-01-16 08:55:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-16 08:55:25 +0100 |
commit | a02a0e806fab01f4cf4307443cdaed76a2488752 (patch) | |
tree | a9e10b1b7a3e40cee0852e3a973f6eb7337574ff /mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | |
parent | af1463d403182720ae0e3fab07634817dd0f41be (diff) | |
download | llvm-a02a0e806fab01f4cf4307443cdaed76a2488752.zip llvm-a02a0e806fab01f4cf4307443cdaed76a2488752.tar.gz llvm-a02a0e806fab01f4cf4307443cdaed76a2488752.tar.bz2 |
[mlir][Transforms] `GreedyPatternRewriteDriver`: Better expensive checks encapsulation (#78175)
This change moves most IR verification logic (which is part of the
expensive checks) into `DebugFingerPrints` and renames the struct to
`ExpensiveChecks`. This isolates the debugging logic better from the
remaining code.
This commit also removes a redundant check: the IR is no longer verified
after a failed pattern application. We already assert that the IR did
not change. (We know that the IR was valid before the attempted pattern
application.)
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 66 |
1 files changed, 42 insertions, 24 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index d31408d..36d63d6 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -43,12 +43,18 @@ namespace { //===----------------------------------------------------------------------===// #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS -/// A helper struct that stores finger prints of ops in order to detect broken -/// RewritePatterns. A rewrite pattern is broken if it modifies IR without -/// using the rewriter API or if it returns an inconsistent return value. -struct DebugFingerPrints : public RewriterBase::ForwardingListener { - DebugFingerPrints(RewriterBase::Listener *driver) - : RewriterBase::ForwardingListener(driver) {} +/// A helper struct that performs various "expensive checks" to detect broken +/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is +/// broken if: +/// * IR does not verify after pattern application / folding. +/// * Pattern returns "failure" but the IR has changed. +/// * Pattern returns "success" but the IR has not changed. +/// +/// This struct stores finger prints of ops to determine whether the IR has +/// changed or not. +struct ExpensiveChecks : public RewriterBase::ForwardingListener { + ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel) + : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {} /// Compute finger prints of the given op and its nested ops. void computeFingerPrints(Operation *topLevel) { @@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener { } void notifyRewriteSuccess() { + if (!topLevel) + return; + + // Make sure that the IR still verifies. + if (failed(verify(topLevel))) + llvm::report_fatal_error("IR failed to verify after pattern application"); + // Pattern application success => IR must have changed. OperationFingerPrint afterFingerPrint(topLevel); if (*topLevelFingerPrint == afterFingerPrint) { @@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener { } void notifyRewriteFailure() { + if (!topLevel) + return; + // Pattern application failure => IR must not have changed. OperationFingerPrint afterFingerPrint(topLevel); if (*topLevelFingerPrint != afterFingerPrint) { @@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener { } } + void notifyFoldingSuccess() { + if (!topLevel) + return; + + // Make sure that the IR still verifies. + if (failed(verify(topLevel))) + llvm::report_fatal_error("IR failed to verify after folding"); + } + protected: /// Invalidate the finger print of the given op, i.e., remove it from the map. void invalidateFingerPrint(Operation *op) { @@ -362,7 +387,7 @@ private: PatternApplicator matcher; #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - DebugFingerPrints debugFingerPrints; + ExpensiveChecks expensiveChecks; #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS }; } // namespace @@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( : PatternRewriter(ctx), config(config), matcher(patterns) #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // clang-format off - , debugFingerPrints(this) + , expensiveChecks( + /*driver=*/this, + /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr) // clang-format on #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS { @@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Send IR notifications to the debug handler. This handler will then forward // all notifications to this GreedyPatternRewriteDriver. - setListener(&debugFingerPrints); + setListener(&expensiveChecks); #else setListener(this); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS @@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { changed = true; LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (config.scope && failed(verify(config.scope->getParentOp()))) - llvm::report_fatal_error("IR failed to verify after folding"); + expensiveChecks.notifyFoldingSuccess(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS continue; } @@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { changed = true; LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (config.scope && failed(verify(config.scope->getParentOp()))) - llvm::report_fatal_error("IR failed to verify after folding"); + expensiveChecks.notifyFoldingSuccess(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS continue; } @@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() { #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (config.scope) { - debugFingerPrints.computeFingerPrints(config.scope->getParentOp()); + expensiveChecks.computeFingerPrints(config.scope->getParentOp()); } auto clearFingerprints = - llvm::make_scope_exit([&]() { debugFingerPrints.clear(); }); + llvm::make_scope_exit([&]() { expensiveChecks.clear(); }); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS LogicalResult matchResult = matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (config.scope && failed(verify(config.scope->getParentOp()))) - llvm::report_fatal_error("IR failed to verify after pattern application"); -#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (succeeded(matchResult)) { LLVM_DEBUG(logResultWithLine("success", "pattern matched")); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (config.scope) - debugFingerPrints.notifyRewriteSuccess(); + expensiveChecks.notifyRewriteSuccess(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS changed = true; ++numRewrites; } else { LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (config.scope) - debugFingerPrints.notifyRewriteFailure(); + expensiveChecks.notifyRewriteFailure(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } } |