diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 48 |
1 files changed, 15 insertions, 33 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a5ddd91..36317e0 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -40,10 +40,10 @@ public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, - const DenseSet<Region *> &scope); + const Region &scope); - /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef<Region> regions) &&; + /// Simplify the ops within the given region. + bool simplify(Region ®ion) &&; /// Add the given operation and its ancestors to the worklist. void addToWorklist(Operation *op); @@ -104,7 +104,7 @@ protected: const GreedyRewriteConfig config; /// Only ops within this scope are simplified. - const DenseSet<Region *> scope; + const Region &scope; private: #ifndef NDEBUG @@ -116,7 +116,7 @@ private: GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, const DenseSet<Region *> &scope) + const GreedyRewriteConfig &config, const Region &scope) : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), scope(scope) { worklist.reserve(64); @@ -125,7 +125,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && { +bool GreedyPatternRewriteDriver::simplify(Region ®ion) && { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -167,15 +167,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && { if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. - for (auto ®ion : regions) { region.walk([&](Operation *op) { if (!insertKnownConstant(op)) addToWorklist(op); }); - } } else { // Add all nested operations to the worklist in preorder. - for (auto ®ion : regions) { region.walk<WalkOrder::PreOrder>([&](Operation *op) { if (!insertKnownConstant(op)) { worklist.push_back(op); @@ -183,7 +180,6 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && { } return WalkResult::skip(); }); - } // Reverse the list so our pop-back loop processes them in-order. std::reverse(worklist.begin(), worklist.end()); @@ -305,7 +301,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && { // After applying patterns, make sure that the CFG of each of the regions // is kept up to date. if (config.enableRegionSimplification) - changed |= succeeded(simplifyRegions(*this, regions)); + changed |= succeeded(simplifyRegions(*this, region)); } while (changed); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. @@ -317,7 +313,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { SmallVector<Operation *, 8> ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (scope.contains(region)) { + if (&scope == region) { // All gathered ops are in fact ancestors. for (Operation *op : ancestors) addSingleOpToWorklist(op); @@ -429,31 +425,19 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure( /// top-level operation itself. /// LogicalResult -mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, +mlir::applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config) { - if (regions.empty()) - return success(); - // The top-level operation must be known to be isolated from above to // prevent performing canonicalizations on operations defined at or above // the region containing 'op'. - auto regionIsIsolated = [](Region ®ion) { - return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>(); - }; - (void)regionIsIsolated; - assert(llvm::all_of(regions, regionIsIsolated) && + assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && "patterns can only be applied to operations IsolatedFromAbove"); - // Limit ops on the worklist to this scope. - DenseSet<Region *> scope; - for (Region &r : regions) - scope.insert(&r); - // Start the pattern driver. - GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config, - scope); - bool converged = std::move(driver).simplify(regions); + GreedyPatternRewriteDriver driver(region.getContext(), patterns, config, + region); + bool converged = std::move(driver).simplify(region); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; @@ -476,7 +460,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode, + const Region &scope, GreedyRewriteStrictness strictMode, llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr) : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope), strictMode(strictMode), survivingOps(survivingOps) {} @@ -680,10 +664,8 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, // Start the pattern driver. llvm::SmallDenseSet<Operation *, 4> surviving; - DenseSet<Region *> scopeSet; - scopeSet.insert(scope); MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - scopeSet, strictMode, + *scope, strictMode, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) |