1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
|
# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
def run(f):
    """Execute a test function immediately and hand it back unchanged.

    Meant to be used as a decorator. It prints a blank line followed by a
    ``TEST: <function name>`` banner, calls the test, forces a garbage
    collection, and then asserts that no MLIR ``Context`` objects are still
    alive afterwards (i.e. the test did not leak a context reference).
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise SymbolTable lookup, erasure and insertion on a parsed module.

    Covers membership tests and ``[]`` lookup, ``del`` erasure (including that
    a Python handle to the erased op becomes invalid), insertion of ops moved
    in from a second module, renaming of an inserted symbol whose name
    collides with an existing one, and the ValueError raised when inserting an
    op that has no symbol name.
    """
    with Context() as ctx:
        # m2 contains an op from an unregistered dialect ("foo.bar").
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse(
            """
func.func private @foo()
func.func private @bar()"""
        )
        m2 = Module.parse(
            """
func.func private @qux()
func.func private @foo()
"foo.bar"() : () -> ()"""
        )

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        # Hold a handle to @bar so we can later verify it is invalidated by erasure.
        bar = symbol_table["bar"]
        print(symbol_table["bar"])
        assert "qux" not in symbol_table

        del symbol_table["bar"]
        # @bar has been removed, so the lookup used as erase()'s argument is
        # expected to raise KeyError before erase() itself is reached.
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # Using the stale handle must fail with an invalidation error.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        # Move @qux from m2 into m1 and register it with the symbol table.
        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name
        assert isinstance(updated_name, StringAttr)

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        # With @qux and @foo moved out, m2's first op is the unregistered
        # "foo.bar" op, which carries no symbol name.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"
# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename a symbol and rewrite its uses, first locally, then module-wide.

    @bar is first renamed to @bam with uses rewritten only inside @foo, then
    to @baz with uses rewritten across the whole module. After each step the
    module and the symbol names of both functions are printed.
    """

    def print_symbol_names(first, second):
        # One line per function, always in the order Foo then Bar.
        for label, op in (("Foo", first), ("Bar", second)):
            print(f"{label} symbol: {SymbolTable.get_symbol_name(op)!r}")

    with Context():
        module = Module.parse(
            """
func.func private @foo() {
call @bar() : () -> ()
return
}
func.func private @bar()
"""
        )
        ops = module.body.operations
        foo, bar = ops[0], ops[1]

        # Do renaming just within `foo`.
        SymbolTable.set_symbol_name(bar, "bam")
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(module)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("bam")
        print_symbol_names(foo, bar)

        # Do renaming within the module.
        SymbolTable.set_symbol_name(bar, "baz")
        SymbolTable.replace_all_symbol_uses("bam", "baz", module.operation)
        # CHECK: call @baz()
        # CHECK: func private @baz
        print(module)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("baz")
        print_symbol_names(foo, bar)
# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Read a function's symbol visibility, switch it to public, and print."""
    with Context():
        parsed = Module.parse(
            """
func.func private @foo() {
return
}
"""
        )
        foo = parsed.body.operations[0]

        visibility = SymbolTable.get_visibility(foo)
        # CHECK: Existing visibility: StringAttr("private")
        print(f"Existing visibility: {visibility!r}")

        SymbolTable.set_visibility(foo, "public")
        # CHECK: func public @foo
        print(parsed)
# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Exercise SymbolTable.walk_symbol_tables on nested modules.

    First walks with a callback that prints each symbol-table op (the inner
    module is expected before the outer one). Then it walks with a callback
    that raises, and requires that the failure surfaces to the caller as a
    RuntimeError whose message is printed.
    """
    with Context():
        m = Module.parse(
            """
module @outer {
module @inner{
}
}
"""
        )

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled.
        def error_callback(symbol_table_op, uses_visible):
            assert False, "Raised from python"

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback:
            # CHECK: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            # Without this branch a walk that swallowed the callback's error
            # would pass silently on the Python side.
            assert False, "expected RuntimeError from error_callback"
|