aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp73
1 files changed, 33 insertions, 40 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4a37730..2ef46b1 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -59,7 +59,7 @@ public:
protected:
/// Add the given operation to the worklist.
- virtual void addSingleOpToWorklist(Operation *op);
+ void addSingleOpToWorklist(Operation *op);
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
@@ -102,6 +102,12 @@ protected:
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
+ /// The list of ops we are restricting our rewrites to. These include the
+ /// supplied set of ops as well as new ops created while rewriting those ops
+ /// depending on `strictMode`. This set is not maintained when
+ /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
+ llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+
private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
@@ -150,6 +156,12 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
return false;
};
+ // Populate strict mode ops.
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+ strictModeFilteredOps.clear();
+ region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
+ }
+
bool changed = false;
int64_t iteration = 0;
do {
@@ -323,12 +335,15 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
- // Check to see if the worklist already contains this op.
- if (worklistMap.count(op))
- return;
-
- worklistMap[op] = worklist.size();
- worklist.push_back(op);
+ if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
+ strictModeFilteredOps.contains(op)) {
+ // Check to see if the worklist already contains this op.
+ if (worklistMap.count(op))
+ return;
+
+ worklistMap[op] = worklist.size();
+ worklist.push_back(op);
+ }
}
Operation *GreedyPatternRewriteDriver::popFromWorklist() {
@@ -355,6 +370,8 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
+ if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
+ strictModeFilteredOps.insert(op);
addToWorklist(op);
}
@@ -391,6 +408,9 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
removeFromWorklist(operation);
folder.notifyRemoval(operation);
});
+
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ strictModeFilteredOps.erase(op);
}
void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
@@ -459,10 +479,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
+ const GreedyRewriteConfig &config,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
: GreedyPatternRewriteDriver(ctx, patterns, config),
- strictMode(strictMode), survivingOps(survivingOps) {}
+ survivingOps(survivingOps) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
/// ops. `strictMode` controls which other ops are simplified. Only ops
@@ -476,38 +496,13 @@ public:
LogicalResult simplifyLocally(ArrayRef<Operation *> op,
bool *changed = nullptr) &&;
-protected:
- void addSingleOpToWorklist(Operation *op) override {
- if (strictMode == GreedyRewriteStrictness::AnyOp ||
- strictModeFilteredOps.contains(op))
- GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
- }
-
private:
- void notifyOperationInserted(Operation *op) override {
- if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
- strictModeFilteredOps.insert(op);
- GreedyPatternRewriteDriver::notifyOperationInserted(op);
- }
-
void notifyOperationRemoved(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
if (survivingOps)
survivingOps->erase(op);
- if (strictMode != GreedyRewriteStrictness::AnyOp)
- strictModeFilteredOps.erase(op);
}
- /// `strictMode` control which ops are added to the worklist during
- /// simplification.
- const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
-
- /// The list of ops we are restricting our rewrites to. These include the
- /// supplied set of ops as well as new ops created while rewriting those ops
- /// depending on `strictMode`. This set is not maintained when `strictMode`
- /// is GreedyRewriteStrictness::AnyOp.
- llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
-
/// An optional set of ops that survived the rewrite. This set is populated
/// at the beginning of `simplifyLocally` with the inititally provided list
/// of ops.
@@ -524,7 +519,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
survivingOps->insert(ops.begin(), ops.end());
}
- if (strictMode != GreedyRewriteStrictness::AnyOp) {
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
@@ -549,7 +544,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
if (op == nullptr)
continue;
- assert((strictMode == GreedyRewriteStrictness::AnyOp ||
+ assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op)) &&
"unexpected op was inserted under strict mode");
@@ -637,8 +632,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
LogicalResult mlir::applyOpPatternsAndFold(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
- bool *changed, bool *allErased) {
+ GreedyRewriteConfig config, bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
@@ -664,8 +658,7 @@ LogicalResult mlir::applyOpPatternsAndFold(
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- strictMode, config,
- allErased ? &surviving : nullptr);
+ config, allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
*allErased = surviving.empty();