aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
blob: 9729d2bfb384e246c57d0082e000673764498980 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s

module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {

  // CHECK: shard.grid @grid0
  shard.grid @grid0(shape = 3x4x5)

  // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @grid0
  // (24 = 1 * (4 * 5) + 0 * 5 + 4). Every shard_shape query below is made for
  // this device, so each test expects the local sizes at grid position [1, 0, 4].
  // When more than one distinct constant is expected, CHECK-DAG is used because
  // the order in which canonicalization materializes constants is not part of
  // what these tests verify.

  // all shards are equal: 9/3, 12/4 and 15/5 all divide evenly into 3
  // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
  func.func @shard_shape_equal() -> (index, index, index) {
    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
    %0:3 = shard.process_multi_index on @grid0 : index, index, index
    %c9 = arith.constant 9 : index
    %c12 = arith.constant 12 : index
    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
    %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
    return %1#0, %1#1, %1#2 : index, index, index
  }

  // last shard in last dim gets an extra element: 16 over 5 shards gives
  // sizes [3 3 3 3 4], and this device (index 4 in the last dim) is the last one
  // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
  func.func @shard_shape_odd_1() -> (index, index, index) {
    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
    %0:3 = shard.process_multi_index on @grid0 : index, index, index
    %c9 = arith.constant 9 : index
    %c12 = arith.constant 12 : index
    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
    %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
    // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
    return %1#0, %1#1, %1#2 : index, index, index
  }

  // In the second dimension the shard sizes are now [3 4 4 4] (15 over 4 shards);
  // this device has index 0 there, so it still gets 3
  // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
  func.func @shard_shape_odd_2() -> (index, index, index) {
    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
    %0:3 = shard.process_multi_index on @grid0 : index, index, index
    %c9 = arith.constant 9 : index
    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
    %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
    return %1#0, %1#1, %1#2 : index, index, index
  }

  // In the first dimension the shard sizes are now [3 4 4] (11 over 3 shards);
  // this device has index 1 there, so it gets 4
  // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
  func.func @shard_shape_odd_3() -> (index, index, index) {
    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
    %0:3 = shard.process_multi_index on @grid0 : index, index, index
    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
    %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
    // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
    return %1#0, %1#1, %1#2 : index, index, index
  }

  // extract from sharded_dims_offsets: the offsets per split dim are
  //   dim0 [0, 1, 4, 9], dim1 [0, 2, 6, 12, 12], dim2 [0, 3, 6, 9, 12, 15].
  // For device [1, 0, 4] the local sizes are 4 - 1 = 3, 2 - 0 = 2 and 15 - 12 = 3.
  // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
  func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]]
        sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding
    %0:3 = shard.process_multi_index on @grid0 : index, index, index
    %c9 = arith.constant 9 : index
    %c12 = arith.constant 12 : index
    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
    // CHECK-DAG: [[vc2:%.*]] = arith.constant 2 : index
    %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
    // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
    return %1#0, %1#1, %1#2 : index, index, index
  }
}