aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Dialect/Shard/folding.mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/Dialect/Shard/folding.mlir')
-rw-r--r--mlir/test/Dialect/Shard/folding.mlir22
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir
new file mode 100644
index 0000000..5a0f35b
--- /dev/null
+++ b/mlir/test/Dialect/Shard/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+shard.grid @grid1(shape = 2x3)
+
+// CHECK-LABEL: func.func @grid_shape_op_folding
+func.func @grid_shape_op_folding() -> (index, index) {
+ // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = shard.grid_shape @grid0 axes = [1] : index
+ %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+ // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @grid_shape_op_folding_all_axes_static_grid
+func.func @grid_shape_op_folding_all_axes_static_grid() -> (index, index) {
+ // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+ %0:2 = shard.grid_shape @grid1 : index, index
+ // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}