// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s

shard.grid @grid0(shape = 4x2)
shard.grid @grid1(shape = 4)

// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
func.func @all_reduce_arith_addf_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xf32>
}

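// The transformation also applies when the combined result has multiple uses:
// both returned values are rewritten to use the single new all-reduce result.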
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
  // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
  return %2, %2 : tensor<5xf32>, tensor<5xf32>
}

// Do not simplify if there is another use of one of the all-reduces.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]]
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]]
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ALL_REDUCE_0_RES]], %[[ADD_RES]]
  return %0, %2 : tensor<5xf32>, tensor<5xf32>
}

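// Do not simplify if the all-reduces are over different grids.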
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid
func.func @all_reduce_arith_addf_no_endomorphism_different_grid(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid1
  %1 = shard.all_reduce %arg1 on @grid1 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

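// Do not simplify if the all-reduces are over different grid axes.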
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes
func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [1]
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [1]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

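// Do not simplify if the reduction kind of one all-reduce (here max) does not
// match the arithmetic op.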
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

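// Do not simplify if the all-reduce changes the element type (f32 -> f64):
// hoisting the addition above the reduction would compute it in f32 instead
// of f64.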
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf64> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf64>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf64>
}

// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
// `all_reduce(min(x, y))`.
// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
func.func @all_reduce_arith_minimumf_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
  %2 = arith.minimumf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[MIN_RES]] on @grid0 grid_axes = [0] reduction = min
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xf32>
}

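// Same as the minimumf case above, but with signed integer min.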
// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
func.func @all_reduce_arith_minsi_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
    %arg0: tensor<5xi32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
    %arg1: tensor<5xi32>) -> tensor<5xi32> {
  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
  %2 = arith.minsi %0, %1 : tensor<5xi32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[MIN_RES]] on @grid0 grid_axes = [0] reduction = min
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xi32>
}

// Ensure that this case, which has no endomorphism op, does not crash.
// CHECK-LABEL: func.func @no_endomorphism_op
func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
  %c0 = arith.constant 0 : index
  %c1_i64 = arith.constant 1 : i64
  // CHECK: tensor.extract
  %extracted = tensor.extract %arg0[%c0] : tensor<2xi64>
  // CHECK: arith.maxsi
  %0 = arith.maxsi %extracted, %c1_i64 : i64
  return %0 : i64
}