aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIngo Müller <ingomueller@google.com>2023-09-15 21:32:17 +0200
committerGitHub <noreply@github.com>2023-09-15 21:32:17 +0200
commit68033aaac5c94d1199e6cc3e6406e4d3fa10b040 (patch)
treebab77807493224482891a7cb360b3abe6599e941
parent058e9b0374b09a9f70c9964c533caf2d49eb219a (diff)
downloadllvm-68033aaac5c94d1199e6cc3e6406e4d3fa10b040.zip
llvm-68033aaac5c94d1199e6cc3e6406e4d3fa10b040.tar.gz
llvm-68033aaac5c94d1199e6cc3e6406e4d3fa10b040.tar.bz2
[mlir][transform] Fix crash in transform.get_parent_op. (#66492)
The previous implementation crashed if run on a `builtin.module` using an `op_name` filter (because the initial value of `parent` in the while loop was a `nullptr`). This PR fixes the crash and adds a test.
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp5
-rw-r--r--mlir/test/Dialect/Transform/test-interpreter.mlir12
2 files changed, 15 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b..f1d07b8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1233,7 +1233,7 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
DenseSet<Operation *> resultSet;
for (Operation *target : state.getPayloadOps(getTarget())) {
Operation *parent = target->getParentOp();
- do {
+ while (parent) {
bool checkIsolatedFromAbove =
!getIsolatedFromAbove() ||
parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
@@ -1241,7 +1241,8 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
parent->getName().getStringRef() == *getOpName();
if (checkIsolatedFromAbove && checkOpName)
break;
- } while ((parent = parent->getParentOp()));
+ parent = parent->getParentOp();
+ }
if (!parent) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 68e3a48..daa179c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) {
test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
}
+
+// -----
+
+// expected-note @below {{target op}}
+module {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below{{could not find a parent op that matches all requirements}}
+ %3 = get_parent_op %arg0 {op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op
+ }
+}
+
// -----
func.func @cast(%arg0: f32) -> f64 {