diff options
author | Matthias Springer <me@m-sp.org> | 2024-04-02 10:53:57 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-02 10:53:57 +0900 |
commit | 38113a083283d2f30a677befaa5fb86dce731c8b (patch) | |
tree | ac8d5821e9110632d36b0195e7a5898078e8d0a3 | |
parent | d7a43a00fe80007de5d7614576b180d3d21d541b (diff) | |
download | llvm-38113a083283d2f30a677befaa5fb86dce731c8b.zip llvm-38113a083283d2f30a677befaa5fb86dce731c8b.tar.gz llvm-38113a083283d2f30a677befaa5fb86dce731c8b.tar.bz2 |
[mlir][IR] Trigger `notifyOperationReplaced` on `replaceAllOpUsesWith` (#84721)
Before this change: `notifyOperationReplaced` was triggered when calling
`RewriteBase::replaceOp`.
After this change: `notifyOperationReplaced` is triggered when
`RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is
called.
Until now, every `notifyOperationReplaced` was always sent together with
a `notifyOperationErased`, which made that `notifyOperationErased`
callback irrelevant. More importantly, when a user called
`RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of
`RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was
sent, even though the two notations are semantically equivalent. As an
example, this can be a problem when applying patterns with the transform
dialect because the `TrackingListener` will only see the
`notifyOperationErased` callback and the payload op is dropped from the
mappings.
Note: It is still possible to write semantically equivalent code that
does not trigger a `notifyOperationReplaced` (e.g., when op results are
replaced one-by-one), but this commit already improves the situation a
lot.
-rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 29 | ||||
-rw-r--r-- | mlir/lib/IR/PatternMatch.cpp | 24 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 |
3 files changed, 37 insertions, 21 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 070e6ed..ac2b0d5 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -409,9 +409,9 @@ public: /// Notify the listener that the specified operation was modified in-place. virtual void notifyOperationModified(Operation *op) {} - /// Notify the listener that the specified operation is about to be replaced - /// with another operation. This is called before the uses of the old - /// operation have been changed. + /// Notify the listener that all uses of the specified operation's results + /// are about to be replaced with the results of another operation. This is + /// called before the uses of the old operation have been changed. /// /// By default, this function calls the "operation replaced with values" /// notification. @@ -420,9 +420,10 @@ public: notifyOperationReplaced(op, replacement->getResults()); } - /// Notify the listener that the specified operation is about to be replaced - /// with the a range of values, potentially produced by other operations. - /// This is called before the uses of the operation have been changed. + /// Notify the listener that all uses of the specified operation's results + /// are about to be replaced with the a range of values, potentially + /// produced by other operations. This is called before the uses of the + /// operation have been changed. virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} @@ -648,12 +649,16 @@ public: for (auto it : llvm::zip(from, to)) replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); } - // Note: This function cannot be called `replaceAllUsesWith` because the - // overload resolution, when called with an op that can be implicitly - // converted to a Value, would be ambiguous. - void replaceAllOpUsesWith(Operation *from, ValueRange to) { - replaceAllUsesWith(from->getResults(), to); - } + + /// Find uses of `from` and replace them with `to`. Also notify the listener + /// about every in-place op modification (for every use that was replaced) + /// and that the `from` operation is about to be replaced. + /// + /// Note: This function cannot be called `replaceAllUsesWith` because the + /// overload resolution, when called with an op that can be implicitly + /// converted to a Value, would be ambiguous. + void replaceAllOpUsesWith(Operation *from, ValueRange to); + void replaceAllOpUsesWith(Operation *from, Operation *to); /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. Also notify the listener about every in-place op modification (for diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 4079ccc..5944a0e 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() { // Out of line to provide a vtable anchor for the class. } +void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) { + // Notify the listener that we're about to replace this op. + if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) + rewriteListener->notifyOperationReplaced(from, to); + + replaceAllUsesWith(from->getResults(), to); +} + +void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) { + // Notify the listener that we're about to replace this op. + if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) + rewriteListener->notifyOperationReplaced(from, to); + + replaceAllUsesWith(from->getResults(), to->getResults()); +} + /// This method replaces the results of the operation with the specified list of /// values. The number of provided values must match the number of results of /// the operation. The replaced op is erased. @@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, newValues); - // Replace all result uses. Also notifies the listener of modifications. replaceAllOpUsesWith(op, newValues); @@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { assert(op->getNumResults() == newOp->getNumResults() && "ops have different number of results"); - // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, newOp); - // Replace all result uses. Also notifies the listener of modifications. replaceAllOpUsesWith(op, newOp->getResults()); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 2da184b..76dc825 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -489,7 +489,10 @@ private: OperationName("test.new_op", op->getContext()).getIdentifier(), op->getOperands(), op->getResultTypes()); } - rewriter.replaceOp(op, newOp->getResults()); + // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". + // A "notifyOperationReplaced" callback is triggered in either case. + rewriter.replaceAllOpUsesWith(op, newOp->getResults()); + rewriter.eraseOp(op); return success(); } }; |