diff options
Diffstat (limited to 'mlir/test/python/dialects')
-rw-r--r-- | mlir/test/python/dialects/shard.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/mlir/test/python/dialects/shard.py b/mlir/test/python/dialects/shard.py new file mode 100644 index 0000000..c3ba605 --- /dev/null +++ b/mlir/test/python/dialects/shard.py @@ -0,0 +1,67 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import shard +from mlir.dialects import func + + +def constructAndPrintInModule(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + module.operation.verify() + return f + + +# CHECK-LABEL: TEST: testShardGrid +@constructAndPrintInModule +def testShardGrid(): + # Test creating shard grids with different shapes + grid2d = shard.GridOp("grid_2d", [2, 2]) + grid1d = shard.GridOp("grid_1d", [4]) + + # CHECK: shard.grid @grid_2d(shape = 2x2) + # CHECK: shard.grid @grid_1d(shape = 4) + + +# CHECK-LABEL: TEST: testCollectiveOperations +@constructAndPrintInModule +def testCollectiveOperations(): + # Create grid and types + grid_op = shard.GridOp("grid_2x2", [2, 2]) + i32 = IntegerType.get_signless(32) + index_type = IndexType.get() + input_type = RankedTensorType.get([4, 2], i32) + gather_result_type = RankedTensorType.get([4, 4], i32) + + # Create a function to hold the operations + func_type = FunctionType.get([input_type], [input_type]) + test_func = func.FuncOp("test_collectives", func_type) + + with InsertionPoint(test_func.add_entry_block()): + arg = test_func.entry_block.arguments[0] + + gather_op = shard.AllGatherOp( + input=arg, + grid=FlatSymbolRefAttr.get("grid_2x2"), + grid_axes=DenseI16ArrayAttr.get([1]), + gather_axis=IntegerAttr.get(index_type, 1), + result=gather_result_type, + ) + + reduce_op = shard.AllReduceOp( + input=arg, + grid=FlatSymbolRefAttr.get("grid_2x2"), + reduction=shard.ReductionKind.Sum, + result=input_type, + ) + + func.ReturnOp([reduce_op]) + + # CHECK: shard.grid @grid_2x2(shape = 2x2) + # CHECK: func.func @test_collectives(%arg0: tensor<4x2xi32>) -> tensor<4x2xi32> + # CHECK: %all_gather = shard.all_gather %arg0 on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32> + # CHECK: %all_reduce = shard.all_reduce %arg0 on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32> |