From 68033aaac5c94d1199e6cc3e6406e4d3fa10b040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 15 Sep 2023 21:32:17 +0200 Subject: [mlir][transform] Fix crash in transform.get_parent_op. (#66492) The previous implementation crashed if run on a `builtin.module` using an `op_name` filter (because the initial value of `parent` in the while loop was a `nullptr`). This PR fixes the crash and adds a test. --- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 5 +++-- mlir/test/Dialect/Transform/test-interpreter.mlir | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index de3cd1b..f1d07b8 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1233,7 +1233,7 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, DenseSet resultSet; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target->getParentOp(); - do { + while (parent) { bool checkIsolatedFromAbove = !getIsolatedFromAbove() || parent->hasTrait(); @@ -1241,7 +1241,8 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, parent->getName().getStringRef() == *getOpName(); if (checkIsolatedFromAbove && checkOpName) break; - } while ((parent = parent->getParentOp())); + parent = parent->getParentOp(); + } if (!parent) { DiagnosedSilenceableFailure diag = emitSilenceableError() diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 68e3a48..daa179c 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) { test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op } + +// ----- + +// expected-note @below {{target op}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below{{could not find a parent op that matches all requirements}} + %3 = get_parent_op %arg0 {op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op + } +} + // ----- func.func @cast(%arg0: f32) -> f64 { -- cgit v1.1