diff options
author | Matthias Springer <springerm@google.com> | 2023-01-27 14:14:41 +0100 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2023-01-27 14:33:54 +0100 |
commit | 977cddb95eac67a6dc6680a7d0fadee81114de11 (patch) | |
tree | 305259835d8452d8823b5adec7d017881437e0fc /mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | |
parent | 78fee46d8d8e82158b79a4ad948e8723e89f7f65 (diff) | |
download | llvm-977cddb95eac67a6dc6680a7d0fadee81114de11.zip llvm-977cddb95eac67a6dc6680a7d0fadee81114de11.tar.gz llvm-977cddb95eac67a6dc6680a7d0fadee81114de11.tar.bz2 |
[mlir] GreedyPatternRewriteDriver: All entry points take a config
The multi-op entry point now also takes a GreedyPatternRewriteConfig and respects config.maxNumRewrites. The scope is also a part of the config now.
Differential Revision: https://reviews.llvm.org/D142614
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 45 |
1 files changed, 22 insertions, 23 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 36317e0..4a37730 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,8 +39,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, - const Region &scope); + const GreedyRewriteConfig &config); /// Simplify the ops within the given region. bool simplify(Region ®ion) &&; @@ -103,9 +102,6 @@ protected: /// Configuration information for how to simplify. const GreedyRewriteConfig config; - /// Only ops within this scope are simplified. - const Region &scope; - private: #ifndef NDEBUG /// A logger used to emit information during the application process. @@ -116,9 +112,9 @@ private: GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, const Region &scope) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), - scope(scope) { + const GreedyRewriteConfig &config) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + assert(config.scope && "scope is not specified"); worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. @@ -313,7 +309,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { SmallVector<Operation *, 8> ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (&scope == region) { + if (config.scope == region) { // All gathered ops are in fact ancestors. for (Operation *op : ancestors) addSingleOpToWorklist(op); @@ -434,9 +430,12 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion, assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && "patterns can only be applied to operations IsolatedFromAbove"); + // Set scope if not specified. + if (!config.scope) + config.scope = ®ion; + // Start the pattern driver. - GreedyPatternRewriteDriver driver(region.getContext(), patterns, config, - region); + GreedyPatternRewriteDriver driver(region.getContext(), patterns, config); bool converged = std::move(driver).simplify(region); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " @@ -460,9 +459,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const Region &scope, GreedyRewriteStrictness strictMode, + GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config, llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr) - : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope), + : GreedyPatternRewriteDriver(ctx, patterns, config), strictMode(strictMode), survivingOps(survivingOps) {} /// Performs the specified rewrites on `ops` while also trying to fold these @@ -636,11 +635,10 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) { return region; } -LogicalResult -mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, - const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, bool *changed, - bool *allErased, Region *scope) { +LogicalResult mlir::applyOpPatternsAndFold( + ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns, + GreedyRewriteStrictness strictMode, GreedyRewriteConfig config, + bool *changed, bool *allErased) { if (ops.empty()) { if (changed) *changed = false; @@ -649,14 +647,15 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, return success(); } - if (!scope) { + // Determine scope of rewrite. + if (!config.scope) { // Compute scope if none was provided. - scope = findCommonAncestor(ops); + config.scope = findCommonAncestor(ops); } else { // If a scope was provided, make sure that all ops are in scope. #ifndef NDEBUG bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) { - return static_cast<bool>(scope->findAncestorOpInRegion(*op)); + return static_cast<bool>(config.scope->findAncestorOpInRegion(*op)); }); assert(allOpsInScope && "ops must be within the specified scope"); #endif // NDEBUG @@ -665,14 +664,14 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, // Start the pattern driver. llvm::SmallDenseSet<Operation *, 4> surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - *scope, strictMode, + strictMode, config, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) *allErased = surviving.empty(); LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after " - << GreedyRewriteConfig().maxNumRewrites << " rewrites"; + << config.maxNumRewrites << " rewrites"; }); return converged; } |