diff options
author | Arda Unal <ardau@d-matrix.ai> | 2024-06-13 15:09:47 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-13 15:09:47 -0700 |
commit | 01a429c432620cad6deac99d48cf6ef96c7f86e8 (patch) | |
tree | 9dd35cef60c88ca5170dc1e5ddeb45eca0f0ca73 /mlir/test/Dialect | |
parent | 1ebda1173186c4c0ab776d1f140f903a49ace2a3 (diff) | |
download | llvm-01a429c432620cad6deac99d48cf6ef96c7f86e8.zip llvm-01a429c432620cad6deac99d48cf6ef96c7f86e8.tar.gz llvm-01a429c432620cad6deac99d48cf6ef96c7f86e8.tar.bz2 |
[mlir][mesh] Fix wrong argument passed to targetShardingInUnsplitLastAxis (#95059)
In unsplitLastAxisInResharding, wrong argument was passed when calling
targetShardingInUnsplitLastAxis.There weren't any tests to uncover this.
I added one in mesh-spmdization.mlir for Linalg and one in
resharding-spmdization.mlir for Mesh dialects.
Diffstat (limited to 'mlir/test/Dialect')
-rw-r--r-- | mlir/test/Dialect/Linalg/mesh-spmdization.mlir | 35 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/resharding-spmdization.mlir | 13 |
2 files changed, 48 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir index bd56c80..52f352c 100644 --- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir +++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir @@ -162,3 +162,38 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8> return %res_shared2 : tensor<4x8xi8> } + +// ----- + +mesh.mesh @mesh_1d(shape = 4) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis +func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>, + %in1: tensor<4x6xi8>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>, + %in2: tensor<6x8xi8>, + // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8> + %dps_out: tensor<4x8xi8> + // CHECK-SAME: -> tensor<4x8xi8> { +) -> tensor<4x8xi8> { + %in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8> + %in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8> + // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1 + %in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8> + %in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8> + // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1 + %dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8> + %dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8> + // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul + // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>) + // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>) + // CHECK-SAME: -> tensor<4x2xi8> + %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>) + outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8> + %res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8> + %res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8> + return %res_replicated : tensor<4x8xi8> +} diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir index ba05306..b3e3051 100644 --- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir +++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir @@ -96,6 +96,19 @@ func.func @unshard_static_axis( return %1 : tensor<10x14xf32> } +// CHECK-LABEL: func @unshard_static_last_axis +func.func @unshard_static_last_axis( + // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32> + %arg0: tensor<10x14xf32> +) -> tensor<10x14xf32> { + // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32> + // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32> + %0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32> + %1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32> + // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32> + return %1 : tensor<10x14xf32> +} + // CHECK-LABEL: func @unshard_dynamic_axis func.func @unshard_dynamic_axis( // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32> |