aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOleksandr "Alex" Zinenko <zinenko@google.com>2024-03-28 18:52:10 +0100
committerGitHub <noreply@github.com>2024-03-28 18:52:10 +0100
commit0b790572b108bf691d11dece07bca65ca457fc88 (patch)
treed8567ca5edf6efb7ad8d48ee67f465bca34eea33
parent2af3b43642017d13de2b6d9802915851517fa0ca (diff)
downloadllvm-0b790572b108bf691d11dece07bca65ca457fc88.zip
llvm-0b790572b108bf691d11dece07bca65ca457fc88.tar.gz
llvm-0b790572b108bf691d11dece07bca65ca457fc88.tar.bz2
[mlir] propagate silenceable failures in transform.foreach_match (#86956)
The original implementation was eagerly reporting silenceable failures from actions as definite failures. Since silenceable failures are intended for cases when the IR has not been irreversibly modified, it's okay to propagate them as silenceable failures of the parent op. Fixes #86834.
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp17
-rw-r--r--mlir/test/Dialect/Transform/foreach-match.mlir80
2 files changed, 95 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 9423410..578b249 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1020,6 +1020,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
}
+ DiagnosedSilenceableFailure overallDiag =
+ DiagnosedSilenceableFailure::success();
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1058,8 +1060,19 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
action.getFunctionBody().front().without_terminator()) {
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
- if (failed(result.checkAndReport()))
+ if (result.isDefiniteFailure())
return WalkResult::interrupt();
+ if (result.isSilenceableFailure()) {
+ if (overallDiag.succeeded()) {
+ overallDiag = emitSilenceableError() << "actions failed";
+ }
+ overallDiag.attachNote(action->getLoc())
+ << "failed action: " << result.getMessage();
+ overallDiag.attachNote(op->getLoc())
+ << "when applied to this matching payload";
+ (void)result.silence();
+ continue;
+ }
}
break;
}
@@ -1075,7 +1088,7 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
// by actions, are invalidated.
results.set(llvm::cast<OpResult>(getUpdated()),
state.getPayloadOps(getRoot()));
- return DiagnosedSilenceableFailure::success();
+ return overallDiag;
}
void transform::ForeachMatchOp::getEffects(
diff --git a/mlir/test/Dialect/Transform/foreach-match.mlir b/mlir/test/Dialect/Transform/foreach-match.mlir
new file mode 100644
index 0000000..206625a
--- /dev/null
+++ b/mlir/test/Dialect/Transform/foreach-match.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+// Silenceable diagnostics suppressed.
+module attributes { transform.with_named_sequence } {
+ func.func @test_loop_peeling_not_beneficial() {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 40 : index
+ %step = arith.constant 5 : index
+ scf.for %i = %lb to %ub step %step {
+ arith.addi %i, %i : index
+ }
+ return
+ }
+
+ transform.named_sequence @peel(%arg0: !transform.op<"scf.for"> {transform.consumed}) {
+ transform.loop.peel %arg0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.named_sequence @match_for(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %arg0 ["scf.for"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.sequence %root : !transform.any_op failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_for -> @peel
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// Silenceable diagnostics propagated.
+module attributes { transform.with_named_sequence } {
+ func.func @test_loop_peeling_not_beneficial() {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 40 : index
+ %step = arith.constant 5 : index
+ // expected-note @below {{when applied to this matching payload}}
+ scf.for %i = %lb to %ub step %step {
+ arith.addi %i, %i : index
+ }
+ return
+ }
+
+ // expected-note @below {{failed to peel the last iteration}}
+ transform.named_sequence @peel(%arg0: !transform.op<"scf.for"> {transform.consumed}) {
+ transform.loop.peel %arg0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.named_sequence @match_for(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %arg0 ["scf.for"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+ transform.named_sequence @main_suppress(%root: !transform.any_op) {
+ transform.sequence %root : !transform.any_op failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.foreach_match in %arg0
+ @match_for -> @peel
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ transform.yield
+ }
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.sequence %root : !transform.any_op failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{actions failed}}
+ transform.foreach_match in %arg0
+ @match_for -> @peel
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+ transform.yield
+ }
+}