diff options
-rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 1 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 40 |
2 files changed, 23 insertions, 18 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 2562301..ed7b9ec 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -784,6 +784,7 @@ public: /// place. class PatternRewriter : public RewriterBase { public: + explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} using RewriterBase::RewriterBase; /// A hook used to indicate if the pattern rewriter can recover from failure diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index cfd4f9c0..597cb29 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -319,8 +319,7 @@ private: /// This abstract class manages the worklist and contains helper methods for /// rewriting ops on the worklist. Derived classes specify how ops are added /// to the worklist in the beginning. -class GreedyPatternRewriteDriver : public PatternRewriter, - public RewriterBase::Listener { +class GreedyPatternRewriteDriver : public RewriterBase::Listener { protected: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, @@ -339,7 +338,8 @@ protected: /// Notify the driver that the specified operation was inserted. Update the /// worklist as needed: The operation is enqueued depending on scope and /// strict mode. - void notifyOperationInserted(Operation *op, InsertPoint previous) override; + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override; /// Notify the driver that the specified operation was removed. Update the /// worklist as needed: The operation and its children are removed from the @@ -354,6 +354,10 @@ protected: /// reached. Return `true` if any IR was changed. bool processWorklist(); + /// The pattern rewriter that is used for making IR modifications and is + /// passed to rewrite patterns. + PatternRewriter rewriter; + /// The worklist for this transformation keeps track of the operations that /// need to be (re)visited. #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED @@ -407,7 +411,7 @@ private: GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : PatternRewriter(ctx), config(config), matcher(patterns) + : rewriter(ctx), config(config), matcher(patterns) #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // clang-format off , expensiveChecks( @@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Send IR notifications to the debug handler. This handler will then forward // all notifications to this GreedyPatternRewriteDriver. - setListener(&expensiveChecks); + rewriter.setListener(&expensiveChecks); #else - setListener(this); + rewriter.setListener(this); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } @@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { // If the operation is trivially dead - remove it. if (isOpTriviallyDead(op)) { - eraseOp(op); + rewriter.eraseOp(op); changed = true; LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); @@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() { // Op results can be replaced with `foldResults`. assert(foldResults.size() == op->getNumResults() && "folder produced incorrect number of results"); - OpBuilder::InsertionGuard g(*this); - setInsertionPoint(op); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); SmallVector<Value> replacements; bool materializationSucceeded = true; for (auto [ofr, resultType] : @@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { } // Materialize Attributes as SSA values. Operation *constOp = op->getDialect()->materializeConstant( - *this, ofr.get<Attribute>(), resultType, op->getLoc()); + rewriter, ofr.get<Attribute>(), resultType, op->getLoc()); if (!constOp) { // If materialization fails, cleanup any operations generated for @@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { replacementOps.insert(replacement.getDefiningOp()); } for (Operation *op : replacementOps) { - eraseOp(op); + rewriter.eraseOp(op); } materializationSucceeded = false; @@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { } if (materializationSucceeded) { - replaceOp(op, replacements); + rewriter.replaceOp(op, replacements); changed = true; LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS @@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS LogicalResult matchResult = - matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); + matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess); if (succeeded(matchResult)) { LLVM_DEBUG(logResultWithLine("success", "pattern matched")); @@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) { config.listener->notifyBlockErased(block); } -void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op, - InsertPoint previous) { +void GreedyPatternRewriteDriver::notifyOperationInserted( + Operation *op, OpBuilder::InsertPoint previous) { LLVM_DEBUG({ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; @@ -822,7 +826,7 @@ private: LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { bool continueRewrites = false; int64_t iteration = 0; - MLIRContext *ctx = getContext(); + MLIRContext *ctx = rewriter.getContext(); do { // Check if the iteration limit was reached. if (++iteration > config.maxIterations && @@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { // `OperationFolder` CSE's constant ops (and may move them into parents // regions to enable more aggressive CSE'ing). - OperationFolder folder(getContext(), this); + OperationFolder folder(ctx, this); auto insertKnownConstant = [&](Operation *op) { // Check for existing constants when populating the worklist. This avoids // accidentally reversing the constant order during processing. @@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { // After applying patterns, make sure that the CFG of each of the // regions is kept up to date. if (config.enableRegionSimplification) - continueRewrites |= succeeded(simplifyRegions(*this, region)); + continueRewrites |= succeeded(simplifyRegions(rewriter, region)); }, {®ion}, iteration); } while (continueRewrites); |