aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xmlir/utils/generate-test-checks.py28
1 files changed, 23 insertions, 5 deletions
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index f77c968..14a790e 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -220,12 +220,19 @@ def process_source_lines(source_lines, note, args):
source_segments[-1].append(line + "\n")
return source_segments
-def process_attribute_definition(line, attribute_namer, output):
+
+def process_attribute_definition(line, attribute_namer):
m = ATTR_DEF_RE.match(line)
if m:
attribute_name = attribute_namer.generate_name(m.group(1))
- line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n'
- output.write(line)
+ return (
+ "// CHECK: #[["
+ + attribute_name
+ + ":.+]] ="
+ + line[len(m.group(0)) :]
+ + "\n"
+ )
+ return None
def process_attribute_references(line, attribute_namer):
@@ -340,6 +347,9 @@ def main():
variable_namer = VariableNamer(args.variable_names)
attribute_namer = AttributeNamer(args.attribute_names)
+ # Store attribute definitions to emit at appropriate scope
+ pending_attr_defs = []
+
# Process lines
for input_line in input_lines:
if not input_line:
@@ -350,8 +360,9 @@ def main():
if input_line.startswith("// -----"):
continue
- # Check if this is an attribute definition and process it
- process_attribute_definition(input_line, attribute_namer, output)
+ if ATTR_DEF_RE.match(input_line):
+ pending_attr_defs.append(input_line)
+ continue
# Lines with blocks begin with a ^. These lines have a trailing comment
# that needs to be stripped.
@@ -407,6 +418,13 @@ def main():
output_line += process_line(ssa_split[1:], variable_namer)
else:
+ # Emit any pending attribute definitions at the start of this scope
+ for attr in pending_attr_defs:
+ attr_line = process_attribute_definition(attr, attribute_namer)
+ if attr_line:
+ output_segments[-1].append(attr_line)
+ pending_attr_defs.clear()
+
# Output the first line chunk that does not contain an SSA name for the
# label.
output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"