aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Dialect/Mesh/spmdization.mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/Dialect/Mesh/spmdization.mlir')
-rw-r--r--mlir/test/Dialect/Mesh/spmdization.mlir14
1 files changed, 14 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2fb8029..572d3eb 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
return %7 : tensor<2xi8>
}
+
+// CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+ %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<4x16xf32>
+ return %2 : tensor<8x16xf32>
+}