aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp45
1 files changed, 22 insertions, 23 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36317e0..4a37730 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,8 +39,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config,
- const Region &scope);
+ const GreedyRewriteConfig &config);
/// Simplify the ops within the given region.
bool simplify(Region &region) &&;
@@ -103,9 +102,6 @@ protected:
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
- /// Only ops within this scope are simplified.
- const Region &scope;
-
private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
@@ -116,9 +112,9 @@ private:
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, const Region &scope)
- : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
- scope(scope) {
+ const GreedyRewriteConfig &config)
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+ assert(config.scope && "scope is not specified");
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
@@ -313,7 +309,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
SmallVector<Operation *, 8> ancestors;
ancestors.push_back(op);
while (Region *region = op->getParentRegion()) {
- if (&scope == region) {
+ if (config.scope == region) {
// All gathered ops are in fact ancestors.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -434,9 +430,12 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"patterns can only be applied to operations IsolatedFromAbove");
+ // Set scope if not specified.
+ if (!config.scope)
+ config.scope = &region;
+
// Start the pattern driver.
- GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
- region);
+ GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
bool converged = std::move(driver).simplify(region);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -460,9 +459,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const Region &scope, GreedyRewriteStrictness strictMode,
+ GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
- : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
+ : GreedyPatternRewriteDriver(ctx, patterns, config),
strictMode(strictMode), survivingOps(survivingOps) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
@@ -636,11 +635,10 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
return region;
}
-LogicalResult
-mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
- const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, bool *changed,
- bool *allErased, Region *scope) {
+LogicalResult mlir::applyOpPatternsAndFold(
+ ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
+ GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
+ bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
@@ -649,14 +647,15 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
return success();
}
- if (!scope) {
+ // Determine scope of rewrite.
+ if (!config.scope) {
// Compute scope if none was provided.
- scope = findCommonAncestor(ops);
+ config.scope = findCommonAncestor(ops);
} else {
// If a scope was provided, make sure that all ops are in scope.
#ifndef NDEBUG
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
- return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+ return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
});
assert(allOpsInScope && "ops must be within the specified scope");
#endif // NDEBUG
@@ -665,14 +664,14 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- *scope, strictMode,
+ strictMode, config,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after "
- << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+ << config.maxNumRewrites << " rewrites";
});
return converged;
}