aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/ir/module.py
blob: 6065e59fd6ed9008bab6dbada9b208296489f53e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# RUN: %PYTHON %s | FileCheck %s

import gc
from tempfile import NamedTemporaryFile
from mlir.ir import *


def run(f):
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
@run
def testParseSuccess():
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert module.context is ctx
    print("CLEAR CONTEXT")
    ctx = None  # Ensure that module captures the context.
    gc.collect()
    module.dump()  # Just outputs to stderr. Verifies that it functions.
    print(str(module))


# Verify successful parse from file.
# CHECK-LABEL: TEST: testParseFromFileSuccess
# CHECK: module @successfulParse
@run
def testParseFromFileSuccess():
    ctx = Context()
    with NamedTemporaryFile(mode="w") as tmp_file:
        tmp_file.write(r"""module @successfulParse {}""")
        tmp_file.flush()
        module = Module.parseFile(tmp_file.name, ctx)
        assert module.context is ctx
        print("CLEAR CONTEXT")
        ctx = None  # Ensure that module captures the context.
        gc.collect()
        module.operation.verify()
        print(str(module))


# Verify parse error.
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: <
# CHECK:   Unable to parse module assembly:
# CHECK:   error: "-":1:1: expected operation name in quotes
# CHECK: >
@run
def testParseError():
    ctx = Context()
    try:
        module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
    except MLIRError as e:
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")


# Verify successful parse.
# CHECK-LABEL: TEST: testCreateEmpty
# CHECK: module {
@run
def testCreateEmpty():
    ctx = Context()
    loc = Location.unknown(ctx)
    module = Module.create(loc)
    print("CLEAR CONTEXT")
    ctx = None  # Ensure that module captures the context.
    gc.collect()
    print(str(module))


# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
# CHECK-LABEL: TEST: testRoundtripUnicode
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
    ctx = Context()
    module = Module.parse(
        r"""
    func.func private @roundtripUnicode() attributes { foo = "😊" }
  """,
        ctx,
    )
    print(str(module))


# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
# CHECK-LABEL: TEST: testRoundtripBinary
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
    with Context():
        module = Module.parse(
            r"""
      func.func private @roundtripUnicode() attributes { foo = "😊" }
    """
        )
        binary_asm = module.operation.get_asm(binary=True)
        assert isinstance(binary_asm, bytes)
        module = Module.parse(binary_asm)
        print(module)


# Tests that module.operation works and correctly interns instances.
# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    live_ops = ctx._get_live_operation_objects()
    assert len(live_ops) == 1
    assert live_ops[0] is op1
    live_ops = None
    # CHECK: module @successfulParse
    print(op1)

    # Ensure that operations are the same on multiple calls.
    op2 = module.operation
    assert ctx._get_live_operation_count() == 1
    assert op1 is op2

    # Test live operation clearing.
    op1 = module.operation
    assert ctx._get_live_operation_count() == 1
    num_invalidated = ctx._clear_live_operations()
    assert num_invalidated == 1
    assert ctx._get_live_operation_count() == 0
    op1 = None
    gc.collect()
    op1 = module.operation

    # Ensure that if module is de-referenced, the operations are still valid.
    module = None
    gc.collect()
    print(op1)

    # Collect and verify lifetime.
    op1 = None
    op2 = None
    gc.collect()
    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
    assert ctx._get_live_operation_count() == 0
    assert ctx._get_live_module_count() == 0


# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    # CHECK: "mlir.ir.Module._CAPIPtr"
    module_capsule = module._CAPIPtr
    print(module_capsule)
    module_dup = Module._CAPICreate(module_capsule)
    assert module is module_dup
    assert module_dup.context is ctx
    # Gc and verify destructed.
    module = None
    module_capsule = None
    module_dup = None
    gc.collect()
    assert ctx._get_live_module_count() == 0