diff options
-rwxr-xr-x | mlir/utils/generate-test-checks.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py index f77c968..14a790e 100755 --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -220,12 +220,19 @@ def process_source_lines(source_lines, note, args): source_segments[-1].append(line + "\n") return source_segments -def process_attribute_definition(line, attribute_namer, output): + +def process_attribute_definition(line, attribute_namer): m = ATTR_DEF_RE.match(line) if m: attribute_name = attribute_namer.generate_name(m.group(1)) - line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n' - output.write(line) + return ( + "// CHECK: #[[" + + attribute_name + + ":.+]] =" + + line[len(m.group(0)) :] + + "\n" + ) + return None def process_attribute_references(line, attribute_namer): @@ -340,6 +347,9 @@ def main(): variable_namer = VariableNamer(args.variable_names) attribute_namer = AttributeNamer(args.attribute_names) + # Store attribute definitions to emit at appropriate scope + pending_attr_defs = [] + # Process lines for input_line in input_lines: if not input_line: @@ -350,8 +360,9 @@ def main(): if input_line.startswith("// -----"): continue - # Check if this is an attribute definition and process it - process_attribute_definition(input_line, attribute_namer, output) + if ATTR_DEF_RE.match(input_line): + pending_attr_defs.append(input_line) + continue # Lines with blocks begin with a ^. These lines have a trailing comment # that needs to be stripped. @@ -407,6 +418,13 @@ def main(): output_line += process_line(ssa_split[1:], variable_namer) else: + # Emit any pending attribute definitions at the start of this scope + for attr in pending_attr_defs: + attr_line = process_attribute_definition(attr, attribute_namer) + if attr_line: + output_segments[-1].append(attr_line) + pending_attr_defs.clear() + # Output the first line chunk that does not contain an SSA name for the # label. output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n" |