aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/dialects/irdl.py
blob: ed62db9b6996866f47ad41741a513aab3a75adfc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.ir import *
from mlir.dialects.irdl import *
import sys


def run(f):
    """Announce and immediately execute a test function.

    Intended to be applied as a decorator. It first writes a blank line
    followed by ``TEST: <function name>`` to stderr, then calls ``f`` with
    no arguments. Nothing is returned, so the decorated module-level name
    ends up bound to ``None``.
    """
    # The RUN line merges stderr into FileCheck's input via ``2>&1``, so
    # this banner lands in the same stream the test's expectations scan.
    banner = f"\nTEST: {f.__name__}"
    print(banner, file=sys.stderr)
    f()


# CHECK: TEST: testIRDL
@run
def testIRDL():
    """Define a dialect with IRDL from Python, then parse IR that uses it.

    The first half builds an ``irdl.dialect`` named ``irdl_test`` holding
    one operation, one type and one attribute definition, verifies it and
    dumps it. The second half passes that module to ``load_dialects`` and
    parses and dumps a small module containing ``irdl_test.test_op``.
    """
    with Context(), Location.unknown():
        ir_module = Module.create()
        with InsertionPoint(ir_module.body):
            dialect_def = dialect("irdl_test")
            with InsertionPoint(dialect_def.body):
                # Operation "test_op": a single f32 operand named "input".
                test_op_def = operation_("test_op")
                with InsertionPoint(test_op_def.body):
                    operand_constraint = is_(TypeAttr.get(F32Type.get()))
                    operands_(
                        [operand_constraint], ["input"], [Variadicity.single]
                    )
                # Type "type1": one parameter "val" constrained to f32.
                type_def = type_("type1")
                with InsertionPoint(type_def.body):
                    type_param_constraint = is_(TypeAttr.get(F32Type.get()))
                    parameters([type_param_constraint], ["val"])
                # Attribute "attr1": one parameter "val" pinned to "test".
                attr_def = attribute("attr1")
                with InsertionPoint(attr_def.body):
                    attr_param_constraint = is_(StringAttr.get("test"))
                    parameters([attr_param_constraint], ["val"])

        # CHECK: module {
        # CHECK:   irdl.dialect @irdl_test {
        # CHECK:     irdl.operation @test_op {
        # CHECK:       %0 = irdl.is f32
        # CHECK:       irdl.operands(input: %0)
        # CHECK:     }
        # CHECK:     irdl.type @type1 {
        # CHECK:       %0 = irdl.is f32
        # CHECK:       irdl.parameters(val: %0)
        # CHECK:     }
        # CHECK:     irdl.attribute @attr1 {
        # CHECK:       %0 = irdl.is "test"
        # CHECK:       irdl.parameters(val: %0)
        # CHECK:     }
        # CHECK:   }
        # CHECK: }
        ir_module.operation.verify()
        ir_module.dump()

        # Hand the IRDL definitions above to the context; the parse below
        # relies on "irdl_test.test_op" being known.
        load_dialects(ir_module)

        source = """
          module {
            %a = arith.constant 1.0 : f32
            "irdl_test.test_op"(%a) : (f32) -> ()
          }
        """
        parsed = Module.parse(source)
        # CHECK: module {
        # CHECK:   "irdl_test.test_op"(%cst) : (f32) -> ()
        # CHECK: }
        parsed.dump()