aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/dialects/shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python/dialects/shard.py')
-rw-r--r--mlir/test/python/dialects/shard.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/mlir/test/python/dialects/shard.py b/mlir/test/python/dialects/shard.py
new file mode 100644
index 0000000..c3ba605
--- /dev/null
+++ b/mlir/test/python/dialects/shard.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import shard
+from mlir.dialects import func
+
+
+def constructAndPrintInModule(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ module.operation.verify()
+ return f
+
+
+# CHECK-LABEL: TEST: testShardGrid
+@constructAndPrintInModule
+def testShardGrid():
+ # Test creating shard grids with different shapes
+ grid2d = shard.GridOp("grid_2d", [2, 2])
+ grid1d = shard.GridOp("grid_1d", [4])
+
+ # CHECK: shard.grid @grid_2d(shape = 2x2)
+ # CHECK: shard.grid @grid_1d(shape = 4)
+
+
+# CHECK-LABEL: TEST: testCollectiveOperations
+@constructAndPrintInModule
+def testCollectiveOperations():
+ # Create grid and types
+ grid_op = shard.GridOp("grid_2x2", [2, 2])
+ i32 = IntegerType.get_signless(32)
+ index_type = IndexType.get()
+ input_type = RankedTensorType.get([4, 2], i32)
+ gather_result_type = RankedTensorType.get([4, 4], i32)
+
+ # Create a function to hold the operations
+ func_type = FunctionType.get([input_type], [input_type])
+ test_func = func.FuncOp("test_collectives", func_type)
+
+ with InsertionPoint(test_func.add_entry_block()):
+ arg = test_func.entry_block.arguments[0]
+
+ gather_op = shard.AllGatherOp(
+ input=arg,
+ grid=FlatSymbolRefAttr.get("grid_2x2"),
+ grid_axes=DenseI16ArrayAttr.get([1]),
+ gather_axis=IntegerAttr.get(index_type, 1),
+ result=gather_result_type,
+ )
+
+ reduce_op = shard.AllReduceOp(
+ input=arg,
+ grid=FlatSymbolRefAttr.get("grid_2x2"),
+ reduction=shard.ReductionKind.Sum,
+ result=input_type,
+ )
+
+ func.ReturnOp([reduce_op])
+
+ # CHECK: shard.grid @grid_2x2(shape = 2x2)
+ # CHECK: func.func @test_collectives(%arg0: tensor<4x2xi32>) -> tensor<4x2xi32>
+ # CHECK: %all_gather = shard.all_gather %arg0 on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
+ # CHECK: %all_reduce = shard.all_reduce %arg0 on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32>