aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h1
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp40
2 files changed, 23 insertions, 18 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2562301..ed7b9ec 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -784,6 +784,7 @@ public:
/// place.
class PatternRewriter : public RewriterBase {
public:
+ explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
using RewriterBase::RewriterBase;
/// A hook used to indicate if the pattern rewriter can recover from failure
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c0..597cb29 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -319,8 +319,7 @@ private:
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter,
- public RewriterBase::Listener {
+class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ protected:
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
- void notifyOperationInserted(Operation *op, InsertPoint previous) override;
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override;
/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ protected:
/// reached. Return `true` if any IR was changed.
bool processWorklist();
+ /// The pattern rewriter that is used for making IR modifications and is
+ /// passed to rewrite patterns.
+ PatternRewriter rewriter;
+
/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ private:
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), config(config), matcher(patterns)
+ : rewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
- setListener(&expensiveChecks);
+ rewriter.setListener(&expensiveChecks);
#else
- setListener(this);
+ rewriter.setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
- eraseOp(op);
+ rewriter.eraseOp(op);
changed = true;
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
- OpBuilder::InsertionGuard g(*this);
- setInsertionPoint(op);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
SmallVector<Value> replacements;
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
- *this, ofr.get<Attribute>(), resultType, op->getLoc());
+ rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
if (!constOp) {
// If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
replacementOps.insert(replacement.getDefiningOp());
}
for (Operation *op : replacementOps) {
- eraseOp(op);
+ rewriter.eraseOp(op);
}
materializationSucceeded = false;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
if (materializationSucceeded) {
- replaceOp(op, replacements);
+ rewriter.replaceOp(op, replacements);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
LogicalResult matchResult =
- matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+ matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
config.listener->notifyBlockErased(block);
}
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
- InsertPoint previous) {
+void GreedyPatternRewriteDriver::notifyOperationInserted(
+ Operation *op, OpBuilder::InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
@@ -822,7 +826,7 @@ private:
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
bool continueRewrites = false;
int64_t iteration = 0;
- MLIRContext *ctx = getContext();
+ MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// `OperationFolder` CSE's constant ops (and may move them into parents
// regions to enable more aggressive CSE'ing).
- OperationFolder folder(getContext(), this);
+ OperationFolder folder(ctx, this);
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
- continueRewrites |= succeeded(simplifyRegions(*this, region));
+ continueRewrites |= succeeded(simplifyRegions(rewriter, region));
},
{&region}, iteration);
} while (continueRewrites);