aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorsrcarroll <50210727+srcarroll@users.noreply.github.com>2024-03-21 00:25:07 -0500
committerGitHub <noreply@github.com>2024-03-21 00:25:07 -0500
commitdf9ed9cf52f82aed023adc968ca2a0e7f7cccc69 (patch)
treed8126007d6ca50faaeedc7c97dca92c918aaa44e /mlir
parent733640d29ede70585e0e3e1dcc47b935981f791e (diff)
downloadllvm-df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.zip
llvm-df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.tar.gz
llvm-df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.tar.bz2
[mlir][transform] Fix failure in flattening already flattened linalg ops (#86037)
The previous implementation was doing an early successful return on `rank <= 1` without adding the original op to transform results. This resulted in errors about number of returns. This patch fixes this by adding the original op to results. Additionally, we first check if op is elementwise and return a slienceable failure early if not.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp15
-rw-r--r--mlir/test/Dialect/Linalg/flatten-elementwise.mlir21
2 files changed, 31 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d82a6beb..ecf9983 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- if (target.getNumLoops() <= 1)
+ if (!isElementwise(target)) {
+ failed(rewriter.notifyMatchFailure(
+ target, "only elementwise flattening is supported"));
+ return emitDefaultSilenceableFailure(target);
+ }
+ // If rank <= 1, do nothing
+ if (target.getNumLoops() <= 1) {
+ results.push_back(target);
return DiagnosedSilenceableFailure::success();
+ }
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
- (isElementwise(target))
- ? collapseOpIterationDims(target, reassociation, rewriter)
- : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
- target, "only elementwise flattening is supported"));
+ collapseOpIterationDims(target, reassociation, rewriter);
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133..5a27fe7 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>