aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/dialects/shard.py
blob: c3ba60529634541f9ec8b51264a605e5ca7840e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import shard
from mlir.dialects import func


def constructAndPrintInModule(f):
    """Decorator that runs a test body immediately and dumps the module it built.

    When applied, it prints a ``TEST: <name>`` header. It then opens a fresh MLIR
    context with an unknown location, creates an empty module and calls ``f``
    with the insertion point set to the module body. Afterwards it prints the
    module and verifies it; verification happens after the context block exits.
    The decorated function is returned unchanged.
    """
    print(f"\nTEST: {f.__name__}")
    with Context(), Location.unknown():
        mod = Module.create()
        body_ip = InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    mod.operation.verify()
    return f


# CHECK-LABEL: TEST: testShardGrid
@constructAndPrintInModule
def testShardGrid():
    """Create a rank-2 and a rank-1 shard grid directly in the module body."""
    # Constructing each op inserts it at the active insertion point (the module
    # body, set up by the decorator), so the returned OpViews need no names.
    shard.GridOp("grid_2d", [2, 2])
    shard.GridOp("grid_1d", [4])

    # CHECK: shard.grid @grid_2d(shape = 2x2)
    # CHECK: shard.grid @grid_1d(shape = 4)


# CHECK-LABEL: TEST: testCollectiveOperations
@constructAndPrintInModule
def testCollectiveOperations():
    """Build a function applying shard.all_gather and shard.all_reduce on a 2x2 grid.

    The function takes a tensor<4x2xi32> argument. It gathers that argument
    along tensor axis 1 over grid axis 1, and all-reduces it with Sum. It returns
    the all-reduce result.
    """
    # The grid is referenced only by symbol name below, so its OpView is not
    # kept; a single (immutable, uniqued) symbol ref is shared by both ops.
    shard.GridOp("grid_2x2", [2, 2])
    grid_ref = FlatSymbolRefAttr.get("grid_2x2")

    i32 = IntegerType.get_signless(32)
    index_type = IndexType.get()
    input_type = RankedTensorType.get([4, 2], i32)
    # Grid axis 1 has size 2, so gathering along tensor axis 1 turns 4x2 into 4x4.
    gather_result_type = RankedTensorType.get([4, 4], i32)

    # Create a function to hold the operations.
    func_type = FunctionType.get([input_type], [input_type])
    test_func = func.FuncOp("test_collectives", func_type)

    with InsertionPoint(test_func.add_entry_block()):
        arg = test_func.entry_block.arguments[0]

        # The gathered value is not returned; it is only inspected in the
        # printed IR, so the op is built purely for its side effect.
        shard.AllGatherOp(
            input=arg,
            grid=grid_ref,
            grid_axes=DenseI16ArrayAttr.get([1]),
            gather_axis=IntegerAttr.get(index_type, 1),
            result=gather_result_type,
        )

        reduce_op = shard.AllReduceOp(
            input=arg,
            grid=grid_ref,
            reduction=shard.ReductionKind.Sum,
            result=input_type,
        )

        # Pass the result Value explicitly rather than relying on OpView unwrapping.
        func.ReturnOp([reduce_op.result])

    # CHECK: shard.grid @grid_2x2(shape = 2x2)
    # CHECK: func.func @test_collectives(%arg0: tensor<4x2xi32>) -> tensor<4x2xi32>
    # CHECK: %all_gather = shard.all_gather %arg0 on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
    # CHECK: %all_reduce = shard.all_reduce %arg0 on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32>
    # CHECK: return %all_reduce : tensor<4x2xi32>