aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2023-01-27 11:09:13 +0100
committerMatthias Springer <springerm@google.com>2023-01-27 11:23:04 +0100
commita2b837ab0448869c74cc042155dd454833c60d62 (patch)
treed4163edb5c7cabe9b1f2cd94858d0eb0aeb0409a /mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
parentbf5f63e59fd729efd6dd69318b365293f7ce385d (diff)
downloadllvm-a2b837ab0448869c74cc042155dd454833c60d62.zip
llvm-a2b837ab0448869c74cc042155dd454833c60d62.tar.gz
llvm-a2b837ab0448869c74cc042155dd454833c60d62.tar.bz2
[mlir] GreedyPatternRewriteDriver: Entry point takes single region
The rewrite driver is typically applied to a single region or all regions of the same op. There is no longer an overload to apply the rewrite driver to a list of regions. This simplifies the rewrite driver implementation because the scope is now a single region as opposed to a list of regions. Note: This change is not NFC because `config.maxIterations` and `config.maxNumRewrites` is now counted for each region separately. Furthermore, worklist filtering (`scope`) is now applied to each region separately. Differential Revision: https://reviews.llvm.org/D142611
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp48
1 files changed, 15 insertions, 33 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index a5ddd91..36317e0 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -40,10 +40,10 @@ public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config,
- const DenseSet<Region *> &scope);
+ const Region &scope);
- /// Simplify the operations within the given regions.
- bool simplify(MutableArrayRef<Region> regions) &&;
+ /// Simplify the ops within the given region.
+ bool simplify(Region &region) &&;
/// Add the given operation and its ancestors to the worklist.
void addToWorklist(Operation *op);
@@ -104,7 +104,7 @@ protected:
const GreedyRewriteConfig config;
/// Only ops within this scope are simplified.
- const DenseSet<Region *> scope;
+ const Region &scope;
private:
#ifndef NDEBUG
@@ -116,7 +116,7 @@ private:
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, const DenseSet<Region *> &scope)
+ const GreedyRewriteConfig &config, const Region &scope)
: PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
scope(scope) {
worklist.reserve(64);
@@ -125,7 +125,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
matcher.applyDefaultCostModel();
}
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
+bool GreedyPatternRewriteDriver::simplify(Region &region) && {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -167,15 +167,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
- for (auto &region : regions) {
region.walk([&](Operation *op) {
if (!insertKnownConstant(op))
addToWorklist(op);
});
- }
} else {
// Add all nested operations to the worklist in preorder.
- for (auto &region : regions) {
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (!insertKnownConstant(op)) {
worklist.push_back(op);
@@ -183,7 +180,6 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
}
return WalkResult::skip();
});
- }
// Reverse the list so our pop-back loop processes them in-order.
std::reverse(worklist.begin(), worklist.end());
@@ -305,7 +301,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
// After applying patterns, make sure that the CFG of each of the regions
// is kept up to date.
if (config.enableRegionSimplification)
- changed |= succeeded(simplifyRegions(*this, regions));
+ changed |= succeeded(simplifyRegions(*this, region));
} while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
@@ -317,7 +313,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
SmallVector<Operation *, 8> ancestors;
ancestors.push_back(op);
while (Region *region = op->getParentRegion()) {
- if (scope.contains(region)) {
+ if (&scope == region) {
// All gathered ops are in fact ancestors.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -429,31 +425,19 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
/// top-level operation itself.
///
LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config) {
- if (regions.empty())
- return success();
-
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
- auto regionIsIsolated = [](Region &region) {
- return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
- };
- (void)regionIsIsolated;
- assert(llvm::all_of(regions, regionIsIsolated) &&
+ assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"patterns can only be applied to operations IsolatedFromAbove");
- // Limit ops on the worklist to this scope.
- DenseSet<Region *> scope;
- for (Region &r : regions)
- scope.insert(&r);
-
// Start the pattern driver.
- GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
- scope);
- bool converged = std::move(driver).simplify(regions);
+ GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
+ region);
+ bool converged = std::move(driver).simplify(region);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
@@ -476,7 +460,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode,
+ const Region &scope, GreedyRewriteStrictness strictMode,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
: GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
strictMode(strictMode), survivingOps(survivingOps) {}
@@ -680,10 +664,8 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
- DenseSet<Region *> scopeSet;
- scopeSet.insert(scope);
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- scopeSet, strictMode,
+ *scope, strictMode,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)