diff options
author | srcarroll <50210727+srcarroll@users.noreply.github.com> | 2024-03-21 00:25:07 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 00:25:07 -0500 |
commit | df9ed9cf52f82aed023adc968ca2a0e7f7cccc69 (patch) | |
tree | d8126007d6ca50faaeedc7c97dca92c918aaa44e /mlir | |
parent | 733640d29ede70585e0e3e1dcc47b935981f791e (diff) | |
download | llvm-df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.zip llvm-df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.tar.gz llvm-df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.tar.bz2 |
[mlir][transform] Fix failure in flattening already flattened linalg ops (#86037)
The previous implementation was doing an early successful return on
`rank <= 1` without adding the original op to transform results. This
resulted in errors about number of returns. This patch fixes this by
adding the original op to results. Additionally, we first check if op is
elementwise and return a slienceable failure early if not.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 15 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/flatten-elementwise.mlir | 21 |
2 files changed, 31 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index d82a6beb..ecf9983 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); - if (target.getNumLoops() <= 1) + if (!isElementwise(target)) { + failed(rewriter.notifyMatchFailure( + target, "only elementwise flattening is supported")); + return emitDefaultSilenceableFailure(target); + } + // If rank <= 1, do nothing + if (target.getNumLoops() <= 1) { + results.push_back(target); return DiagnosedSilenceableFailure::success(); + } ReassociationIndices reassociation(target.getNumLoops()); std::iota(reassociation.begin(), reassociation.end(), 0); auto maybeFlattened = - (isElementwise(target)) - ? collapseOpIterationDims(target, reassociation, rewriter) - : FailureOr<CollapseResult>(rewriter.notifyMatchFailure( - target, "only elementwise flattening is supported")); + collapseOpIterationDims(target, reassociation, rewriter); if (failed(maybeFlattened)) return emitDefaultSilenceableFailure(target); results.push_back(maybeFlattened->collapsedOp); diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir index 858c133..5a27fe7 100644 --- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir @@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @map_already_flat( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32> +// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>) +func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) { + linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @generic // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32> |