From bbcfe6f4311af8cf6095a5bc5937fa68a87b4289 Mon Sep 17 00:00:00 2001 From: srcarroll <50210727+srcarroll@users.noreply.github.com> Date: Fri, 22 Mar 2024 12:37:39 -0500 Subject: [mlir][transform] Emit error message with `emitSilenceableFailure` (#86146) The previous implementation used a `notifyMatchFailure` to emit failure message inappropriately and then used the `emitDefaultSilenceableFailure`. This patch changes this to use the more appropriate `emitSilenceableFailure` with error message. Additionally a failure test has been added. --- .../Linalg/TransformOps/LinalgTransformOps.cpp | 14 +++++---- mlir/test/Dialect/Linalg/flatten-unsupported.mlir | 33 ++++++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/flatten-unsupported.mlir (limited to 'mlir') diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ecf9983..88819cd9 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3269,22 +3269,24 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); - if (!isElementwise(target)) { - failed(rewriter.notifyMatchFailure( - target, "only elementwise flattening is supported")); - return emitDefaultSilenceableFailure(target); - } + if (!isElementwise(target)) + return mlir::emitSilenceableFailure(target->getLoc()) + << "only elementwise flattening is supported"; + // If rank <= 1, do nothing if (target.getNumLoops() <= 1) { results.push_back(target); return DiagnosedSilenceableFailure::success(); } + + // Attempt to flatten all dims to one. ReassociationIndices reassociation(target.getNumLoops()); std::iota(reassociation.begin(), reassociation.end(), 0); auto maybeFlattened = collapseOpIterationDims(target, reassociation, rewriter); if (failed(maybeFlattened)) - return emitDefaultSilenceableFailure(target); + return mlir::emitSilenceableFailure(target->getLoc()) + << "attempted to flatten, but failed"; results.push_back(maybeFlattened->collapsedOp); rewriter.replaceOp(target, maybeFlattened->results); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/test/Dialect/Linalg/flatten-unsupported.mlir b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir new file mode 100644 index 0000000..499db4c --- /dev/null +++ b/mlir/test/Dialect/Linalg/flatten-unsupported.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics + +func.func @non_elementwise(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) { + // expected-error @below {{only elementwise flattening is supported}} + linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>) outs(%arg2: memref<2x4xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @unsupported_memref(%arg0: memref<32x7xf32, strided<[7, 2]>>, %arg1: memref<32x7xf32, strided<[7, 2]>>, %arg2: memref<32x7xf32, strided<[7, 2]>>) { + // expected-error @below {{attempted to flatten, but failed}} + linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32, strided<[7, 2]>>, memref<32x7xf32, strided<[7, 2]>>) outs(%arg2: memref<32x7xf32, strided<[7, 2]>>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} -- cgit v1.1