aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Dialect/Mesh/ops.mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/Dialect/Mesh/ops.mlir')
-rw-r--r--mlir/test/Dialect/Mesh/ops.mlir49
1 files changed, 49 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309..a7c3b3d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+
// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>