diff options
author | Jakub Kuderski <jakub@nod-labs.com> | 2024-04-01 12:32:23 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-01 12:32:23 -0400 |
commit | a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9 (patch) | |
tree | 837c57a9aed0f34e8a2dcc37557df2dc5dfa8c86 | |
parent | 985c1a44f8d49e0afeba907fe29d881c19b319fc (diff) | |
download | llvm-a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9.zip llvm-a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9.tar.gz llvm-a8cfa7cbdf6cc1a94ed25c90897d2e031f77a5a9.tar.bz2 |
[mlir][TD] Allow op printing flags as `transform.print` attrs (#86846)
Introduce 3 new optional attributes to the `transform.print` ops:
* `assume_verified`
* `use_local_scope`
* `skip_regions`
The primary motivation is to allow printing on large inputs that
otherwise take forever to print and verify. For the full context, see
this IREE issue: https://github.com/openxla/iree/issues/16901.
Also add some tests and fix the op description.
-rw-r--r-- | mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 19 | ||||
-rw-r--r-- | mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 19 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/ops.mlir | 10 | ||||
-rw-r--r-- | mlir/test/Dialect/Transform/test-interpreter-printing.mlir | 56 |
4 files changed, 94 insertions, 10 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index bf1a801..21c9595 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -1098,15 +1098,28 @@ def PrintOp : TransformDialectOp<"print", MatchOpInterface]> { let summary = "Dump each payload op"; let description = [{ - This op dumps each payload op that is associated with the `target` operand - to stderr. It also prints the `name` string attribute. If no target is + Prints each payload op that is associated with the `target` operand to + `stdout`. It also prints the `name` string attribute. If no target is specified, the top-level op is dumped. This op is useful for printf-style debugging. + + Supported printing flag attributes: + * `assume_verified` -- skips verification when the unit attribute is + specified. This improves performace but may lead to crashes and + unexpected behavior when the printed payload op is invalid. + * `use_local_scope` -- prints in local scope when the unit attribute is + specified. This improves performance but may not be identical to + printing within the full module. + * `skip_regions` -- does not print regions of operations when the unit + attribute is specified. }]; let arguments = (ins Optional<TransformHandleTypeInterface>:$target, - OptionalAttr<StrAttr>:$name); + OptionalAttr<StrAttr>:$name, + OptionalAttr<UnitAttr>:$assume_verified, + OptionalAttr<UnitAttr>:$use_local_scope, + OptionalAttr<UnitAttr>:$skip_regions); let results = (outs); let builders = [ diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index c8d06ba..dc19022 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/CallInterfaces.h" @@ -2627,14 +2628,26 @@ transform::PrintOp::apply(transform::TransformRewriter &rewriter, if (getName().has_value()) llvm::outs() << *getName() << " "; + OpPrintingFlags printFlags; + if (getAssumeVerified().value_or(false)) + printFlags.assumeVerified(); + if (getUseLocalScope().value_or(false)) + printFlags.useLocalScope(); + if (getSkipRegions().value_or(false)) + printFlags.skipRegions(); + if (!getTarget()) { - llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n"; + llvm::outs() << "top-level ]]]\n"; + state.getTopLevel()->print(llvm::outs(), printFlags); + llvm::outs() << "\n"; return DiagnosedSilenceableFailure::success(); } llvm::outs() << "]]]\n"; - for (Operation *target : state.getPayloadOps(getTarget())) - llvm::outs() << *target << "\n"; + for (Operation *target : state.getPayloadOps(getTarget())) { + target->print(llvm::outs(), printFlags); + llvm::outs() << "\n"; + } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir index a718d6a9..ecef7e1 100644 --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -86,16 +86,18 @@ transform.sequence failures(propagate) { } // CHECK: transform.sequence -// CHECK: print -// CHECK: print -// CHECK: print -// CHECK: print +// CHECK-COUNT-9: print transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): transform.print %arg0 : !transform.any_op transform.print transform.print %arg0 {name = "test"} : !transform.any_op transform.print {name = "test"} + transform.print {name = "test", assume_verified} + transform.print %arg0 {assume_verified} : !transform.any_op + transform.print %arg0 {use_local_scope} : !transform.any_op + transform.print %arg0 {skip_regions} : !transform.any_op + transform.print %arg0 {assume_verified, use_local_scope, skip_regions} : !transform.any_op } // CHECK: transform.sequence diff --git a/mlir/test/Dialect/Transform/test-interpreter-printing.mlir b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir new file mode 100644 index 0000000..a54c83d --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-printing.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics | FileCheck %s + +// RUN: mlir-opt %s --transform-interpreter --allow-unregistered-dialect --verify-diagnostics \ +// RUN: --mlir-print-debuginfo | FileCheck %s --check-prefix=CHECK-LOC + +func.func @nested_ops() { + "test.qux"() ({ + // expected-error @below{{fail_to_verify is set}} + "test.baz"() ({ + "test.bar"() : () -> () + }) : () -> () + }) : () -> () +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // CHECK-LABEL{LITERAL}: [[[ IR printer: START top-level ]]] + // CHECK-NEXT: module { + // CHECK-LOC-LABEL{LITERAL}: [[[ IR printer: START top-level ]]] + // CHECK-LOC-NEXT: #{{.+}} = loc( + // CHECK-LOC-NEXT: module { + transform.print {name = "START"} + + // CHECK{LITERAL}: [[[ IR printer: Local scope top-level ]]] + // CHECK-NEXT: module { + // CHECK-LOC{LITERAL}: [[[ IR printer: Local scope top-level ]]] + // CHECK-LOC-NEXT: module { + transform.print {name = "Local scope", use_local_scope} + + %baz = transform.structured.match ops{["test.baz"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: ]]] + // CHECK-NEXT: "test.baz"() ({ + // CHECK-NEXT: "test.bar"() : () -> () + // CHECK-NEXT: }) : () -> () + transform.print %baz : !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: Baz ]]] + // CHECK-NEXT: "test.baz"() ({ + transform.print %baz {name = "Baz"} : !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: No region ]]] + // CHECK-NEXT: "test.baz"() ({...}) : () -> () + transform.print %baz {name = "No region", skip_regions} : !transform.any_op + + // CHECK{LITERAL}: [[[ IR printer: No verify ]]] + // CHECK-NEXT: "test.baz"() ({ + // CHECK-NEXT: transform.test_dummy_payload_op {fail_to_verify} : () -> () + transform.test_produce_invalid_ir %baz : !transform.any_op + transform.print %baz {name = "No verify", assume_verified} : !transform.any_op + + // CHECK-LABEL{LITERAL}: [[[ IR printer: END top-level ]]] + transform.print {name = "END"} + transform.yield + } +} |