aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir')
-rw-r--r--mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir75
1 files changed, 75 insertions, 0 deletions
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
new file mode 100644
index 0000000..9729d2b
--- /dev/null
+++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s
+
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+
+ // CHECK: shard.grid @grid0
+ shard.grid @grid0(shape = 3x4x5)
+
+ // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @grid0
+
+ // all shards are equal
+ // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+ func.func @shard_shape_equal() -> (index, index, index) {
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // last shard in last dim gets an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+ func.func @shard_shape_odd_1() -> (index, index, index) {
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // In the second dimension the shard sizes are now [3 4 4 4]
+ // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+ func.func @shard_shape_odd_2() -> (index, index, index) {
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
+ %c9 = arith.constant 9 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // In the first dimension the shard sizes are now [3 4 4]
+ // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+ func.func @shard_shape_odd_3() -> (index, index, index) {
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // extract from sharded_dims_offsets
+ // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]]
+ sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+}