diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 73 |
1 files changed, 33 insertions, 40 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 4a37730..2ef46b1 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -59,7 +59,7 @@ public: protected: /// Add the given operation to the worklist. - virtual void addSingleOpToWorklist(Operation *op); + void addSingleOpToWorklist(Operation *op); // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. @@ -102,6 +102,12 @@ protected: /// Configuration information for how to simplify. const GreedyRewriteConfig config; + /// The list of ops we are restricting our rewrites to. These include the + /// supplied set of ops as well as new ops created while rewriting those ops + /// depending on `strictMode`. This set is not maintained when + /// `config.strictMode` is GreedyRewriteStrictness::AnyOp. + llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; + private: #ifndef NDEBUG /// A logger used to emit information during the application process. @@ -150,6 +156,12 @@ bool GreedyPatternRewriteDriver::simplify(Region ®ion) && { return false; }; + // Populate strict mode ops. + if (config.strictMode != GreedyRewriteStrictness::AnyOp) { + strictModeFilteredOps.clear(); + region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); + } + bool changed = false; int64_t iteration = 0; do { @@ -323,12 +335,15 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { } void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { - // Check to see if the worklist already contains this op. - if (worklistMap.count(op)) - return; - - worklistMap[op] = worklist.size(); - worklist.push_back(op); + if (config.strictMode == GreedyRewriteStrictness::AnyOp || + strictModeFilteredOps.contains(op)) { + // Check to see if the worklist already contains this op. + if (worklistMap.count(op)) + return; + + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } } Operation *GreedyPatternRewriteDriver::popFromWorklist() { @@ -355,6 +370,8 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); + if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps) + strictModeFilteredOps.insert(op); addToWorklist(op); } @@ -391,6 +408,9 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { removeFromWorklist(operation); folder.notifyRemoval(operation); }); + + if (config.strictMode != GreedyRewriteStrictness::AnyOp) + strictModeFilteredOps.erase(op); } void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op, @@ -459,10 +479,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config, + const GreedyRewriteConfig &config, llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr) : GreedyPatternRewriteDriver(ctx, patterns, config), - strictMode(strictMode), survivingOps(survivingOps) {} + survivingOps(survivingOps) {} /// Performs the specified rewrites on `ops` while also trying to fold these /// ops. `strictMode` controls which other ops are simplified. Only ops @@ -476,38 +496,13 @@ public: LogicalResult simplifyLocally(ArrayRef<Operation *> op, bool *changed = nullptr) &&; -protected: - void addSingleOpToWorklist(Operation *op) override { - if (strictMode == GreedyRewriteStrictness::AnyOp || - strictModeFilteredOps.contains(op)) - GreedyPatternRewriteDriver::addSingleOpToWorklist(op); - } - private: - void notifyOperationInserted(Operation *op) override { - if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps) - strictModeFilteredOps.insert(op); - GreedyPatternRewriteDriver::notifyOperationInserted(op); - } - void notifyOperationRemoved(Operation *op) override { GreedyPatternRewriteDriver::notifyOperationRemoved(op); if (survivingOps) survivingOps->erase(op); - if (strictMode != GreedyRewriteStrictness::AnyOp) - strictModeFilteredOps.erase(op); } - /// `strictMode` control which ops are added to the worklist during - /// simplification. - const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; - - /// The list of ops we are restricting our rewrites to. These include the - /// supplied set of ops as well as new ops created while rewriting those ops - /// depending on `strictMode`. This set is not maintained when `strictMode` - /// is GreedyRewriteStrictness::AnyOp. - llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; - /// An optional set of ops that survived the rewrite. This set is populated /// at the beginning of `simplifyLocally` with the inititally provided list /// of ops. @@ -524,7 +519,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops, survivingOps->insert(ops.begin(), ops.end()); } - if (strictMode != GreedyRewriteStrictness::AnyOp) { + if (config.strictMode != GreedyRewriteStrictness::AnyOp) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); } @@ -549,7 +544,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops, if (op == nullptr) continue; - assert((strictMode == GreedyRewriteStrictness::AnyOp || + assert((config.strictMode == GreedyRewriteStrictness::AnyOp || strictModeFilteredOps.contains(op)) && "unexpected op was inserted under strict mode"); @@ -637,8 +632,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) { LogicalResult mlir::applyOpPatternsAndFold( ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, GreedyRewriteConfig config, - bool *changed, bool *allErased) { + GreedyRewriteConfig config, bool *changed, bool *allErased) { if (ops.empty()) { if (changed) *changed = false; @@ -664,8 +658,7 @@ LogicalResult mlir::applyOpPatternsAndFold( // Start the pattern driver. llvm::SmallDenseSet<Operation *, 4> surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strictMode, config, - allErased ? &surviving : nullptr); + config, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) *allErased = surviving.empty(); |