aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-12-06 16:32:22 +0900
committerGitHub <noreply@github.com>2023-12-06 16:32:22 +0900
commite8ae0e72b794bc1b879a3a90215a940572306c25 (patch)
tree14ed9dda9498421078841a4ca0230770632d6b5b /mlir
parentc630f95f33e31fe11ec6242560d9bf5d57007673 (diff)
downloadllvm-e8ae0e72b794bc1b879a3a90215a940572306c25.zip
llvm-e8ae0e72b794bc1b879a3a90215a940572306c25.tar.gz
llvm-e8ae0e72b794bc1b879a3a90215a940572306c25.tar.bz2
[mlir][transform] TrackingListener: Improve dead handles detection (#74290)
The tracking listener should not report op replacement errors for payload ops that are not mapped to any live handles. The handle liveless analysis did not work properly with transform IR that has named sequences. A handle is live if it has a user after the transform op that is currently being applied. With named sequences, we need to maintain a stack of currently applied transform ops. That stack already exists (`regionStack`), the only thing that's missing is the current transform op for each stack frame. This commit fixes #72931.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h33
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp104
-rw-r--r--mlir/test/Dialect/Transform/test-pattern-application.mlir30
3 files changed, 112 insertions, 55 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2fdc15d..23871cc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -310,10 +310,8 @@ public:
/// with the type of the handle value.
LogicalResult mapBlockArguments(BlockArgument argument,
ArrayRef<Operation *> operations) {
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- assert(argument.getParentRegion() == regionStack.back() &&
+ assert(argument.getParentRegion() == regionStack.back()->region &&
"mapping block arguments from a region other than the active one");
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return setPayloadOps(argument, operations);
}
LogicalResult mapBlockArgument(BlockArgument argument,
@@ -350,9 +348,7 @@ public:
std::make_pair(&region, std::make_unique<Mappings>()));
assert(res.second && "the region scope is already present");
(void)res;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- state.regionStack.push_back(&region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ state.regionStack.push_back(this);
}
/// Back-reference to the transform state.
@@ -361,7 +357,10 @@ public:
/// The region this scope is associated with.
Region *region;
- friend RegionScope TransformState::make_region_scope(Region &);
+ /// The transform op within this region that is currently being applied.
+ TransformOpInterface currentTransform;
+
+ friend class transform::TransformState;
};
friend class RegionScope;
@@ -784,12 +783,14 @@ private:
/// location.
InvalidatedHandleMap invalidatedHandles;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// A stack of nested regions that are being processed in the transform IR.
/// Each region must be an ancestor of the following regions in this list.
/// These are also the keys for "mappings".
- SmallVector<Region *> regionStack;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ SmallVector<RegionScope *> regionStack;
+
+ /// The top-level region scope. The first (bottom) element of `regionStack`
+ /// is the top-level region scope object.
+ std::unique_ptr<RegionScope> topLevelRegionScope;
};
/// Local mapping between values defined by a specific op implementing the
@@ -926,8 +927,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
+ /// A function that returns "true" for handles that do not have to be updated.
+ using SkipHandleFn = std::function<bool(Value)>;
+
/// Create a new TrackingListener for usage in the specified transform op.
- TrackingListener(TransformState &state, TransformOpInterface op);
+ /// Optionally, a function can be specified to identify handles that should
+ /// do not have to be updated.
+ TrackingListener(TransformState &state, TransformOpInterface op,
+ SkipHandleFn skipHandleFn = nullptr);
protected:
/// Return a replacement payload op for the given op, which is going to be
@@ -1015,6 +1022,10 @@ private:
/// The handles that are consumed by the transform op.
DenseSet<Value> consumedHandles;
+
+ /// Handles for which this function evaluates to "true" do not have to be
+ /// updated. These are typically dead or consumed handles.
+ SkipHandleFn skipHandleFn;
};
/// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index de5b7a8..cd66a0e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -31,6 +31,23 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+ do {
+ if (a->isProperAncestor(b))
+ return false;
+ if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+ return a->isBeforeInBlock(bAncestor);
+ }
+ } while ((a = a->getParentOp()));
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
@@ -44,14 +61,10 @@ transform::TransformState::TransformState(
topLevelMappedValues.reserve(extraMappings.size());
for (ArrayRef<MappedValue> mapping : extraMappings)
topLevelMappedValues.push_back(mapping);
-
- auto result =
- mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
- assert(result.second && "the region scope is already present");
- (void)result;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- regionStack.push_back(region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ if (region) {
+ RegionScope *scope = new RegionScope(*this, *region);
+ topLevelRegionScope.reset(scope);
+ }
}
Operation *transform::TransformState::getTopLevel() const { return topLevel; }
@@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
});
+
+ // Set current transform op.
+ regionStack.back()->currentTransform = transform;
+
+ // Expensive checks to detect invalid transform IR.
if (options.getExpensiveChecksEnabled()) {
FULL_LDBG("ExpensiveChecksEnabled\n");
if (failed(checkAndRecordHandleInvalidation(transform)))
@@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
// Prepare rewriter and listener.
- transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
+ TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+ // Skip handle if it is dead.
+ auto scopeIt =
+ llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
+ return handle.getParentRegion() == scope->region;
+ });
+ assert(scopeIt != regionStack.rend() &&
+ "could not find region scope for handle");
+ RegionScope *scope = *scopeIt;
+ for (Operation *user : handle.getUsers()) {
+ if (user != scope->currentTransform &&
+ !happensBefore(user, scope->currentTransform))
+ return false;
+ }
+ return true;
+ };
+ transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
+ skipHandleFn);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);
@@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
state.mappings.erase(region);
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.pop_back();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
//===----------------------------------------------------------------------===//
@@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
//===----------------------------------------------------------------------===//
transform::TrackingListener::TrackingListener(TransformState &state,
- TransformOpInterface op)
- : TransformState::Extension(state), transformOp(op) {
+ TransformOpInterface op,
+ SkipHandleFn skipHandleFn)
+ : TransformState::Extension(state), transformOp(op),
+ skipHandleFn(skipHandleFn) {
if (op) {
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
consumedHandles.insert(opOperand->get());
@@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
});
}
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
- do {
- if (a->isProperAncestor(b))
- return false;
- if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
- return a->isBeforeInBlock(bAncestor);
- }
- } while ((a = a->getParentOp()));
- return false;
-}
-
void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
@@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
[&](Value h) { return consumedHandles.contains(h); });
};
- // Helper function to check if the handle is alive.
- auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
- for (Value v : opHandles) {
- for (OpOperand &use : v.getUses())
- if (use.getOwner() != transformOp &&
- !happensBefore(use.getOwner(), transformOp))
- return &use;
- }
- return std::nullopt;
- }();
-
- if (!firstAliveUser.has_value() || handleWasConsumed()) {
+ // Check if there are any handles that must be updated.
+ Value aliveHandle;
+ if (skipHandleFn) {
+ auto it =
+ llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+ if (it != opHandles.end())
+ aliveHandle = *it;
+ } else if (!opHandles.empty()) {
+ aliveHandle = opHandles.front();
+ }
+ if (!aliveHandle || handleWasConsumed()) {
// The op is tracked but the corresponding handles are dead or were
// consumed. Drop the op form the mapping.
(void)replacePayloadOp(op, nullptr);
@@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
// If the op is tracked but no replacement op was found, send a
// notification.
if (!diag.succeeded()) {
- diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
- << "replacement is required because alive handle(s) exist "
- << "(first use in this op as operand number "
- << (*firstAliveUser)->getOperandNumber() << ")";
+ diag.attachNote(aliveHandle.getLoc())
+ << "replacement is required because this handle must be updated";
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
(void)replacePayloadOp(op, nullptr);
return;
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 2d57d4a..2fd47c6 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -36,6 +36,7 @@ func.func @replacement_op_not_found() {
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-note @below {{replacement is required because this handle must be updated}}
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
// expected-note @below {{ran out of suitable replacement values}}
@@ -44,7 +45,6 @@ transform.sequence failures(propagate) {
} : !transform.any_op
// %1 must be used in some way. If no replacement payload op could be found,
// an error is thrown only if the handle is not dead.
- // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
transform.annotate %1 "annotated" : !transform.any_op
}
@@ -363,3 +363,31 @@ transform.sequence failures(propagate) {
legal_ops = ["func.func", "func.return", "test.new_op"]}
: !transform.any_op
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+func.func @replacement_op_not_found() {
+ // No op replacement can be found, but there are no handles that must be
+ // updated. No error should be reported.
+ "test.container"() ({
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) {
+ transform.apply_patterns to %container {
+ transform.apply_patterns.transform.test_patterns
+ } : !transform.any_op
+ transform.yield
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.annotate %1 "annotated" : !transform.any_op
+ transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> ()
+}
+}