From a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 1 Apr 2024 12:32:23 -0400 Subject: [mlir][TD] Allow op printing flags as `transform.print` attrs (#86846) Introduce 3 new optional attributes to the `transform.print` ops: * `assume_verified` * `use_local_scope` * `skip_regions` The primary motivation is to allow printing on large inputs that otherwise take forever to print and verify. For the full context, see this IREE issue: https://github.com/openxla/iree/issues/16901. Also add some tests and fix the op description. --- .../mlir/Dialect/Transform/IR/TransformOps.td | 19 ++++++-- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 19 ++++++-- mlir/test/Dialect/Transform/ops.mlir | 10 ++-- .../Transform/test-interpreter-printing.mlir | 56 ++++++++++++++++++++++ 4 files changed, 94 insertions(+), 10 deletions(-) create mode 100644 mlir/test/Dialect/Transform/test-interpreter-printing.mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index bf1a801..21c9595 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print", MatchOpInterface]> { let summary = "Dump each payload op"; let description = [{ - This op dumps each payload op that is associated with the `target` operand - to stderr. It also prints the `name` string attribute. If no target is + Prints each payload op that is associated with the `target` operand to + `stdout`. It also prints the `name` string attribute. If no target is specified, the top-level op is dumped. This op is useful for printf-style debugging. + + Supported printing flag attributes: + * `assume_verified` -- skips verification when the unit attribute is + specified. This improves performace but may lead to crashes and + unexpected behavior when the printed payload op is invalid. + * `use_local_scope` -- prints in local scope when the unit attribute is + specified. This improves performance but may not be identical to + printing within the full module. + * `skip_regions` -- does not print regions of operations when the unit + attribute is specified. }]; let arguments = (ins Optional:$target, - OptionalAttr:$name); + OptionalAttr:$name, + OptionalAttr:$assume_verified, + OptionalAttr:$use_local_scope, + OptionalAttr:$skip_regions); let results = (outs); let builders = [ diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index c8d06ba..dc19022 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/CallInterfaces.h" @@ -2627,14 +2628,26 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter, if (getName().has_value()) llvm::outs() << *getName() << " "; + OpPrintingFlags printFlags; + if (getAssumeVerified().value_or(false)) + printFlags.assumeVerified(); + if (getUseLocalScope().value_or(false)) + printFlags.useLocalScope(); + if (getSkipRegions().value_or(false)) + printFlags.skipRegions(); + if (!getTarget()) { - llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n"; + llvm::outs() << "top-level ]]]\n"; + state.getTopLevel()->print(llvm::outs(), printFlags); + llvm::outs() << "\n"; return DiagnosedSilenceableFailure::success(); } llvm::outs() << "]]]\n"; - for (Operation *target : state.getPayloadOps(getTarget())) - llvm::outs() << *target << "\n"; + for (Operation *target : state.getPayloadOps(getTarget())) { + target->print(llvm::outs(), printFlags); + llvm::outs() << "\n"; + } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir index a718d6a9..ecef7e1 100644 --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -86,16 +86,18 @@ transform.sequence failures(propagate) { } // CHECK: transform.sequence -// CHECK: print -// CHECK: print -// CHECK: print -// CHECK: print +// CHECK-COUNT-9: print transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): transform.print %arg0 : !transform.any_op transform.print transform.print %arg0 {name = "test"} : !transform.any_op transform.print {name = "test"} + transform.print {name = "test", assume_verified} + transform.print %arg0 {assume_verified} : !transform.any_op + transform.print %arg0 {use_local_scope} : !transform.any_op + transform.print %arg0 {skip_regions} : !transform.any_op + transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op } // CHECK: transform.sequence diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir new file mode 100644 index 0000000..a54c83d --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics | FileCheck %s + +// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics \ +// RUN: --mlir-print-debuginfo | FileCheck %s --check-prefix=CHECK-LOC + +func.func @nested_ops() { + "test.qux"() ({ + // expected-error @below{{fail_to_verify is set}} + "test.baz"() ({ + "test.bar"() : () -> () + }) : () -> () + }) : () -> () +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]] + // CHECK-NEXT: module { + // CHECK-LOC-LABEL{LITERAL}: [[[ IR printer: START top-level ]]] + // CHECK-LOC-NEXT: #{{.+}} = loc( + // CHECK-LOC-NEXT: module { + transform.print {name = "START"} + + // CHECK{LITERAL}: [[[ IR printer: Local scope top-level ]]] + // CHECK-NEXT: module { + // CHECK-LOC{LITERAL}: [[[ IR printer: Local scope top-level ]]] + // CHECK-LOC-NEXT: module { + transform.print {name = "Local scope", use_local_scope} + + %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: ]]] + // CHECK-NEXT: "test.baz"() ({ + // CHECK-NEXT: "test.bar"() : () -> () + // CHECK-NEXT: }) : () -> () + transform.print %baz : !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: Baz ]]] + // CHECK-NEXT: "test.baz"() ({ + transform.print %baz {name = "Baz"} : !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: No region ]]] + // CHECK-NEXT: "test.baz"() ({...}) : () -> () + transform.print %baz {name = "No region", skip_regions} : !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: No verify ]]] + // CHECK-NEXT: "test.baz"() ({ + // CHECK-NEXT: transform.test_dummy_payload_op {fail_to_verify} : () -> () + transform.test_produce_invalid_ir %baz : !transform.any_op + transform.print %baz {name = "No verify", assume_verified} : !transform.any_op + + // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]] + transform.print {name = "END"} + transform.yield + } +} -- cgit v1.1