Diffstat (limited to 'mlir/test/Dialect/Shard/backward-sharding-propagation.mlir')
 -rw-r--r--  mlir/test/Dialect/Shard/backward-sharding-propagation.mlir | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+), 0 deletions(-)
diff --git a/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
new file mode 100644
index 0000000..8894c4a
--- /dev/null
+++ b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  shard.grid @grid(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+    // CHECK-COUNT-2: shard.shard
+    %sharded = shard.shard %0 to %sharding : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharded : tensor<6x6xi32>) -> tensor<6x6xi32>
+    // CHECK: tensor.empty()
+    // CHECK-NOT: shard.shard @
+    %2 = tensor.empty() : tensor<6x6xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    // CHECK: return
+    return %3 : tensor<6x6xi32>
+  }
+}
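
Since this is a lit test, the RUN line at the top of the file can be reproduced outside the test harness. A minimal sketch, assuming mlir-opt and FileCheck are on PATH (from a build that includes the Shard dialect) and that the file is saved under its in-tree name:

  # Run backward sharding propagation on the test input, then let FileCheck
  # verify the CHECK / CHECK-COUNT-2 / CHECK-NOT expectations embedded in it.
  mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" \
      mlir/test/Dialect/Shard/backward-sharding-propagation.mlir \
    | FileCheck mlir/test/Dialect/Shard/backward-sharding-propagation.mlir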