diff options
Diffstat (limited to 'mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir')
-rw-r--r-- | mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir new file mode 100644 index 0000000..9729d2b --- /dev/null +++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s + +module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { + + // CHECK: shard.grid @grid0 + shard.grid @grid0(shape = 3x4x5) + + // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @grid0 + + // all shards are equal + // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) { + func.func @shard_shape_equal() -> (index, index, index) { + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index + %c9 = arith.constant 9 : index + %c12 = arith.constant 12 : index + // CHECK: [[vc3:%.*]] = arith.constant 3 : index + %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index + return %1#0, %1#1, %1#2 : index, index, index + } + + // last shard in last dim gets an extra element + // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) { + func.func @shard_shape_odd_1() -> (index, index, index) { + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index + %c9 = arith.constant 9 : index + %c12 = arith.constant 12 : index + // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index + // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index + %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index + return %1#0, %1#1, %1#2 : index, index, index + } + + // In the second dimension the shard sizes are now [3 4 4 4] + // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) { + func.func @shard_shape_odd_2() -> (index, index, index) { + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index + %c9 = arith.constant 9 : index + // CHECK: [[vc3:%.*]] = arith.constant 3 : index + %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index + return %1#0, %1#1, %1#2 : index, index, index + } + + // In the first dimension the shard sizes are now [3 4 4] + // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) { + func.func @shard_shape_odd_3() -> (index, index, index) { + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index + // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index + // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index + %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index + return %1#0, %1#1, %1#2 : index, index, index + } + + // extract from sharded_dims_offsets + // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) { + func.func @shard_shape_sharded_dims_offs() -> (index, index, index) { + %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] + sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding + %0:3 = shard.process_multi_index on @grid0 : index, index, index + %c9 = arith.constant 9 : index + %c12 = arith.constant 12 : index + // CHECK: [[vc3:%.*]] = arith.constant 3 : index + // CHECK: [[vc2:%.*]] = arith.constant 2 : index + %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index + // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index + return %1#0, %1#1, %1#2 : index, index, index + } +} |