Diffstat (limited to 'llvm/utils/UpdateTestChecks/common.py')
-rw-r--r-- | llvm/utils/UpdateTestChecks/common.py | 513 |
1 file changed, 494 insertions, 19 deletions
diff --git a/llvm/utils/UpdateTestChecks/common.py b/llvm/utils/UpdateTestChecks/common.py
index a3365fe..f766d54 100644
--- a/llvm/utils/UpdateTestChecks/common.py
+++ b/llvm/utils/UpdateTestChecks/common.py
@@ -1,6 +1,8 @@
 from __future__ import print_function
 
 import argparse
+import bisect
+import collections
 import copy
 import glob
 import itertools
@@ -10,7 +12,7 @@ import subprocess
 import sys
 import shlex
 
-from typing import List
+from typing import List, Mapping, Set
 
 ##### Common utilities for update_*test_checks.py
 
@@ -420,6 +422,48 @@ def should_add_line_to_output(
     return True
 
 
+def collect_original_check_lines(ti: TestInfo, prefix_set: set):
+    """
+    Collect pre-existing check lines into a dictionary `result` which is
+    returned.
+
+    result[func_name][prefix] is filled with a list of right-hand-sides of check
+    lines.
+    """
+    result = {}
+
+    current_function = None
+    for input_line_info in ti.ro_iterlines():
+        input_line = input_line_info.line
+        if current_function is not None:
+            if input_line == "":
+                continue
+            if input_line.lstrip().startswith(";"):
+                m = CHECK_RE.match(input_line)
+                if (
+                    m is not None
+                    and m.group(1) in prefix_set
+                    and m.group(2) not in ["LABEL", "SAME"]
+                ):
+                    if m.group(1) not in current_function:
+                        current_function[m.group(1)] = []
+                    current_function[m.group(1)].append(input_line[m.end() :].strip())
+                continue
+            current_function = None
+
+        m = IR_FUNCTION_RE.match(input_line)
+        if m is not None:
+            func_name = m.group(1)
+            if ti.args.function is not None and func_name != ti.args.function:
+                # When filtering on a specific function, skip all others.
+                continue
+
+            assert func_name not in result
+            current_function = result[func_name] = {}
+
+    return result
+
+
 # Perform lit-like substitutions
 def getSubstitutions(sourcepath):
     sourcedir = os.path.dirname(sourcepath)
@@ -491,7 +535,7 @@ RUN_LINE_RE = re.compile(r"^\s*(?://|[;#])\s*RUN:\s*(.*)$")
 CHECK_PREFIX_RE = re.compile(r"--?check-prefix(?:es)?[= ](\S+)")
 PREFIX_RE = re.compile("^[a-zA-Z0-9_-]+$")
 CHECK_RE = re.compile(
-    r"^\s*(?://|[;#])\s*([^:]+?)(?:-NEXT|-NOT|-DAG|-LABEL|-SAME|-EMPTY)?:"
+    r"^\s*(?://|[;#])\s*([^:]+?)(?:-(NEXT|NOT|DAG|LABEL|SAME|EMPTY))?:"
 )
 CHECK_SAME_RE = re.compile(r"^\s*(?://|[;#])\s*([^:]+?)(?:-SAME)?:")
 
@@ -1187,20 +1231,325 @@ def may_clash_with_default_check_prefix_name(check_prefix, var):
     )
 
 
+def find_diff_matching(lhs: List[str], rhs: List[str]) -> List[tuple]:
+    """
+    Find a large ordered matching between strings in lhs and rhs.
+
+    Think of this as finding the *unchanged* lines in a diff, where the entries
+    of lhs and rhs are lines of the files being diffed.
+
+    Returns a list of matched (lhs_idx, rhs_idx) pairs.
+    """
+
+    if not lhs or not rhs:
+        return []
+
+    # Collect matches in reverse order.
+    matches = []
+
+    # First, collect a set of candidate matching edges. We limit this to a
+    # constant multiple of the input size to avoid quadratic runtime.
+    patterns = collections.defaultdict(lambda: ([], []))
+
+    for idx in range(len(lhs)):
+        patterns[lhs[idx]][0].append(idx)
+    for idx in range(len(rhs)):
+        patterns[rhs[idx]][1].append(idx)
+
+    multiple_patterns = []
+
+    candidates = []
+    for pattern in patterns.values():
+        if not pattern[0] or not pattern[1]:
+            continue
+
+        if len(pattern[0]) == len(pattern[1]) == 1:
+            candidates.append((pattern[0][0], pattern[1][0]))
+        else:
+            multiple_patterns.append(pattern)
+
+    multiple_patterns.sort(key=lambda pattern: len(pattern[0]) * len(pattern[1]))
+
+    for pattern in multiple_patterns:
+        if len(candidates) + len(pattern[0]) * len(pattern[1]) > 2 * (
+            len(lhs) + len(rhs)
+        ):
+            break
+        for lhs_idx in pattern[0]:
+            for rhs_idx in pattern[1]:
+                candidates.append((lhs_idx, rhs_idx))
+
+    if not candidates:
+        # The LHS and RHS either share nothing in common, or lines are just too
+        # identical. In that case, let's give up and not match anything.
+        return []
+
+    # Compute a maximal crossing-free matching via an algorithm that is
+    # inspired by a mixture of dynamic programming and line-sweeping in
+    # discrete geometry.
+    #
+    # I would be surprised if this algorithm didn't exist somewhere in the
+    # literature, but I found it without consciously recalling any
+    # references, so you'll have to make do with the explanation below.
+    # Sorry.
+    #
+    # The underlying graph is bipartite:
+    # - nodes on the LHS represent lines in the original check
+    # - nodes on the RHS represent lines in the new (updated) check
+    #
+    # Nodes are implicitly sorted by the corresponding line number.
+    # Edges (candidates) are sorted by the line number on the LHS.
+    #
+    # Here's the geometric intuition for the algorithm.
+    #
+    # * Plot the edges as points in the plane, with the original line
+    #   number on the X axis and the updated line number on the Y axis.
+    # * The goal is to find a longest "chain" of points where each point
+    #   is strictly above and to the right of the previous point.
+    # * The algorithm proceeds by sweeping a vertical line from left to
+    #   right.
+    # * The algorithm maintains a table where `table[N]` answers the
+    #   question "What is currently the 'best' way to build a chain of N+1
+    #   points to the left of the vertical line". Here, 'best' means
+    #   that the last point of the chain is as low as possible (minimal
+    #   Y coordinate).
+    # * `table[N]` is `(y, point_idx)` where `point_idx` is the index of
+    #   the last point in the chain and `y` is its Y coordinate
+    # * A key invariant is that the Y values in the table are
+    #   monotonically increasing
+    # * Thanks to these properties, the table can be used to answer the
+    #   question "What is the longest chain that can be built to the left
+    #   of the vertical line using only points below a certain Y value",
+    #   using a binary search over the table.
+    # * The algorithm also builds a backlink structure in which every point
+    #   links back to the previous point on a best (longest) chain ending
+    #   at that point
+    #
+    # The core loop of the algorithm sweeps the line and updates the table
+    # and backlink structure for every point that we cross during the sweep.
+    # Therefore, the algorithm is trivially O(M log M) in the number of
+    # points.
+    candidates.sort(key=lambda candidate: (candidate[0], -candidate[1]))
+
+    backlinks = []
+    table = []
+    for _, rhs_idx in candidates:
+        candidate_idx = len(backlinks)
+        ti = bisect.bisect_left(table, rhs_idx, key=lambda entry: entry[0])
+
+        # Update the table to record a best chain ending in the current point.
+        # There always is one, and if any of the previously visited points had
+        # a higher Y coordinate, then there is always a previously recorded best
+        # chain that can be improved upon by using the current point.
+        #
+        # There is only one case where there is some ambiguity. If the
+        # pre-existing entry table[ti] has the same Y coordinate / rhs_idx as
+        # the current point (this can only happen if the same line appeared
+        # multiple times on the LHS), then we could choose to keep the
+        # previously recorded best chain instead. That would bias the algorithm
+        # differently but should have no systematic impact on the quality of the
+        # result.
+        if ti < len(table):
+            table[ti] = (rhs_idx, candidate_idx)
+        else:
+            table.append((rhs_idx, candidate_idx))
+        if ti > 0:
+            backlinks.append(table[ti - 1][1])
+        else:
+            backlinks.append(None)
+
+    # Commit to names in the matching by walking the backlinks. Recursively
+    # attempt to fill in more matches in-between.
+    match_idx = table[-1][1]
+    while match_idx is not None:
+        current = candidates[match_idx]
+        matches.append(current)
+        match_idx = backlinks[match_idx]
+
+    matches.reverse()
+    return matches
+
+
+VARIABLE_TAG = "[[@@]]"
+METAVAR_RE = re.compile(r"\[\[([A-Z0-9_]+)(?::[^]]+)?\]\]")
+NUMERIC_SUFFIX_RE = re.compile(r"[0-9]*$")
+
+
+class CheckValueInfo:
+    def __init__(
+        self,
+        nameless_value: NamelessValue,
+        var: str,
+        prefix: str,
+    ):
+        self.nameless_value = nameless_value
+        self.var = var
+        self.prefix = prefix
+
+
+# Represent a check line in a way that allows us to compare check lines while
+# ignoring some or all of the FileCheck variable names.
+class CheckLineInfo:
+    def __init__(self, line, values):
+        # Line with all FileCheck variable name occurrences replaced by VARIABLE_TAG
+        self.line: str = line
+
+        # Information on each FileCheck variable name occurrence in the line
+        self.values: List[CheckValueInfo] = values
+
+    def __repr__(self):
+        return f"CheckLineInfo(line={self.line}, self.values={self.values})"
+
+
+def remap_metavar_names(
+    old_line_infos: List[CheckLineInfo],
+    new_line_infos: List[CheckLineInfo],
+    committed_names: Set[str],
+) -> Mapping[str, str]:
+    """
+    Map all FileCheck variable names that appear in new_line_infos to new
+    FileCheck variable names in an attempt to reduce the diff from old_line_infos
+    to new_line_infos.
+
+    This is done by:
+    * Matching old check lines and new check lines using a diffing algorithm
+      applied after replacing names with wildcards.
+    * Committing to variable names such that the matched lines become equal
+      (without wildcards) if possible
+    * This is done recursively to handle cases where many lines are equal
+      after wildcard replacement
+    """
+    # Initialize uncommitted identity mappings
+    new_mapping = {}
+    for line in new_line_infos:
+        for value in line.values:
+            new_mapping[value.var] = value.var
+
+    # Recursively commit to the identity mapping or find a better one
+    def recurse(old_begin, old_end, new_begin, new_end):
+        if old_begin == old_end or new_begin == new_end:
+            return
+
+        # Find a matching of lines where uncommitted names are replaced
+        # with a placeholder.
+        def diffify_line(line, mapper):
+            values = []
+            for value in line.values:
+                mapped = mapper(value.var)
+                values.append(mapped if mapped in committed_names else "?")
+            return line.line.strip() + " @@@ " + " @ ".join(values)
+
+        lhs_lines = [
+            diffify_line(line, lambda x: x)
+            for line in old_line_infos[old_begin:old_end]
+        ]
+        rhs_lines = [
+            diffify_line(line, lambda x: new_mapping[x])
+            for line in new_line_infos[new_begin:new_end]
+        ]
+
+        candidate_matches = find_diff_matching(lhs_lines, rhs_lines)
+
+        # Apply commits greedily on a match-by-match basis
+        matches = [(-1, -1)]
+        committed_anything = False
+        for lhs_idx, rhs_idx in candidate_matches:
+            lhs_line = old_line_infos[lhs_idx]
+            rhs_line = new_line_infos[rhs_idx]
+
+            local_commits = {}
+
+            for lhs_value, rhs_value in zip(lhs_line.values, rhs_line.values):
+                if new_mapping[rhs_value.var] in committed_names:
+                    # The new value has already been committed. If it was mapped
+                    # to the same name as the original value, we can consider
+                    # committing other values from this line. Otherwise, we
+                    # should ignore this line.
+                    if new_mapping[rhs_value.var] == lhs_value.var:
+                        continue
+                    else:
+                        break
+
+                if rhs_value.var in local_commits:
+                    # Same, but for a possible commit happening on the same line
+                    if local_commits[rhs_value.var] == lhs_value.var:
+                        continue
+                    else:
+                        break
+
+                if lhs_value.var in committed_names:
+                    # We can't map this value because the name we would map it to has already been
+                    # committed for something else. Give up on this line.
+                    break
+
+                local_commits[rhs_value.var] = lhs_value.var
+            else:
+                # No reason not to add any commitments for this line
+                for rhs_var, lhs_var in local_commits.items():
+                    new_mapping[rhs_var] = lhs_var
+                    committed_names.add(lhs_var)
+                    committed_anything = True
+
+                    if (
+                        lhs_var != rhs_var
+                        and lhs_var in new_mapping
+                        and new_mapping[lhs_var] == lhs_var
+                    ):
+                        new_mapping[lhs_var] = "conflict_" + lhs_var
+
+                matches.append((lhs_idx, rhs_idx))
+
+        matches.append((old_end, new_end))
+
+        # Recursively handle sequences between matches
+        if committed_anything:
+            for (lhs_prev, rhs_prev), (lhs_next, rhs_next) in zip(matches, matches[1:]):
+                recurse(lhs_prev + 1, lhs_next, rhs_prev + 1, rhs_next)
+
+    recurse(0, len(old_line_infos), 0, len(new_line_infos))
+
+    # Commit to remaining names and resolve conflicts
+    for new_name, mapped_name in new_mapping.items():
+        if mapped_name in committed_names:
+            continue
+        if not mapped_name.startswith("conflict_"):
+            assert mapped_name == new_name
+            committed_names.add(mapped_name)
+
+    for new_name, mapped_name in new_mapping.items():
+        if mapped_name in committed_names:
+            continue
+        assert mapped_name.startswith("conflict_")
+
+        m = NUMERIC_SUFFIX_RE.search(new_name)
+        base_name = new_name[: m.start()]
+        suffix = int(new_name[m.start() :]) if m.start() != m.end() else 1
+        while True:
+            candidate = f"{base_name}{suffix}"
+            if candidate not in committed_names:
+                new_mapping[new_name] = candidate
+                committed_names.add(candidate)
+                break
+            suffix += 1
+
+    return new_mapping
+
+
 def generalize_check_lines_common(
     lines,
     is_analyze,
     vars_seen,
     global_vars_seen,
     nameless_values,
-    nameless_value_regex,
+    nameless_value_regex: re.Pattern,
     is_asm,
     preserve_names,
+    original_check_lines=None,
 ):
     # This gets called for each match that occurs in
     # a line. We transform variables we haven't seen
     # into defs, and variables we have seen into uses.
-    def transform_line_vars(match):
+    def transform_line_vars(match, transform_locals=True):
         var = get_name_from_ir_value_match(match)
         nameless_value = get_nameless_value_from_match(match, nameless_values)
         if may_clash_with_default_check_prefix_name(nameless_value.check_prefix, var):
@@ -1210,6 +1559,8 @@ def generalize_check_lines_common(
             )
         key = (var, nameless_value.check_key)
         is_local_def = nameless_value.is_local_def_ir_value()
+        if is_local_def and not transform_locals:
+            return None
         if is_local_def and key in vars_seen:
             rv = nameless_value.get_value_use(var, match)
         elif not is_local_def and key in global_vars_seen:
@@ -1228,13 +1579,15 @@ def generalize_check_lines_common(
         # including the commas and spaces.
         return match.group(1) + rv + match.group(match.lastindex)
 
-    lines_with_def = []
+    def transform_non_local_line_vars(match):
+        return transform_line_vars(match, False)
+
     multiple_braces_re = re.compile(r"({{+)|(}}+)")
 
     def escape_braces(match_obj):
         return '{{' + re.escape(match_obj.group(0)) + '}}'
 
-    for i, line in enumerate(lines):
-        if not is_asm and not is_analyze:
+    if not is_asm and not is_analyze:
+        for i, line in enumerate(lines):
             # An IR variable named '%.' matches the FileCheck regex string.
             line = line.replace("%.", "%dot")
             for regex in _global_hex_value_regex:
@@ -1252,25 +1605,136 @@ def generalize_check_lines_common(
             # Ignore any comments, since the check lines will too.
             scrubbed_line = SCRUB_IR_COMMENT_RE.sub(r"", line)
             lines[i] = scrubbed_line
-        if not preserve_names:
-            # It can happen that two matches are back-to-back and for some reason sub
-            # will not replace both of them. For now we work around this by
-            # substituting until there is no more match.
-            changed = True
-            while changed:
-                (lines[i], changed) = nameless_value_regex.subn(
-                    transform_line_vars, lines[i], count=1
-                )
-        if is_analyze:
+
+    if not preserve_names:
+        if is_asm:
+            for i, _ in enumerate(lines):
+                # It can happen that two matches are back-to-back and for some reason sub
+                # will not replace both of them. For now we work around this by
+                # substituting until there is no more match.
+                changed = True
+                while changed:
+                    (lines[i], changed) = nameless_value_regex.subn(
+                        transform_line_vars, lines[i], count=1
+                    )
+        else:
+            # LLVM IR case. Start by handling global meta variables (global IR variables,
+            # metadata, attributes)
+            for i, _ in enumerate(lines):
+                start = 0
+                while True:
+                    m = nameless_value_regex.search(lines[i][start:])
+                    if m is None:
+                        break
+                    start += m.start()
+                    sub = transform_non_local_line_vars(m)
+                    if sub is not None:
+                        lines[i] = (
+                            lines[i][:start] + sub + lines[i][start + len(m.group(0)) :]
+                        )
+                    start += 1
+
+            # Collect information about new check lines and original check lines (if any)
+            new_line_infos = []
+            for line in lines:
+                filtered_line = ""
+                values = []
+                while True:
+                    m = nameless_value_regex.search(line)
+                    if m is None:
+                        filtered_line += line
+                        break
+
+                    var = get_name_from_ir_value_match(m)
+                    nameless_value = get_nameless_value_from_match(m, nameless_values)
+                    var = nameless_value.get_value_name(
+                        var, nameless_value.check_prefix
+                    )
+
+                    # Replace with a [[@@]] tag, but be sure to keep the spaces and commas.
+                    filtered_line += (
+                        line[: m.start()]
+                        + m.group(1)
+                        + VARIABLE_TAG
+                        + m.group(m.lastindex)
+                    )
+                    line = line[m.end() :]
+                    values.append(
+                        CheckValueInfo(
+                            nameless_value=nameless_value,
+                            var=var,
+                            prefix=nameless_value.get_ir_prefix_from_ir_value_match(m)[
+                                0
+                            ],
+                        )
+                    )
+                new_line_infos.append(CheckLineInfo(filtered_line, values))
+
+            orig_line_infos = []
+            for line in original_check_lines or []:
+                filtered_line = ""
+                values = []
+                while True:
+                    m = METAVAR_RE.search(line)
+                    if m is None:
+                        filtered_line += line
+                        break
+
+                    # Replace with a [[@@]] tag, but be sure to keep the spaces and commas.
+                    filtered_line += line[: m.start()] + VARIABLE_TAG
+                    line = line[m.end() :]
+                    values.append(
+                        CheckValueInfo(
+                            nameless_value=None,
+                            var=m.group(1),
+                            prefix=None,
+                        )
+                    )
+                orig_line_infos.append(CheckLineInfo(filtered_line, values))
+
+            # Compute the variable name mapping
+            committed_names = set(vars_seen)
+
+            mapping = remap_metavar_names(
+                orig_line_infos, new_line_infos, committed_names
+            )
+
+            for i, line_info in enumerate(new_line_infos):
+                line_template = line_info.line
+                line = ""
+
+                for value in line_info.values:
+                    idx = line_template.find(VARIABLE_TAG)
+                    line += line_template[:idx]
+                    line_template = line_template[idx + len(VARIABLE_TAG) :]
+
+                    key = (mapping[value.var], nameless_value.check_key)
+                    is_local_def = nameless_value.is_local_def_ir_value()
+                    if is_local_def:
+                        if mapping[value.var] in vars_seen:
+                            line += f"[[{mapping[value.var]}]]"
+                        else:
+                            line += f"[[{mapping[value.var]}:{value.prefix}{value.nameless_value.get_ir_regex()}]]"
+                            vars_seen.add(mapping[value.var])
+                    else:
+                        raise RuntimeError("not implemented")
+
+                line += line_template
+
+                lines[i] = line
+
+    if is_analyze:
+        for i, _ in enumerate(lines):
             # Escape multiple {{ or }} as {{}} denotes a FileCheck regex.
             scrubbed_line = multiple_braces_re.sub(escape_braces, lines[i])
             lines[i] = scrubbed_line
+
     return lines
 
 
 # Replace IR value defs and uses with FileCheck variables.
 def generalize_check_lines(
-    lines, is_analyze, vars_seen, global_vars_seen, preserve_names
+    lines, is_analyze, vars_seen, global_vars_seen, preserve_names, original_check_lines
 ):
     return generalize_check_lines_common(
         lines,
@@ -1281,6 +1745,7 @@ def generalize_check_lines(
         IR_VALUE_RE,
         False,
         preserve_names,
+        original_check_lines=original_check_lines,
     )
 
 
@@ -1337,6 +1802,7 @@ def add_checks(
     global_vars_seen_dict,
     is_filtered,
     preserve_names=False,
+    original_check_lines: Mapping[str, List[str]] = {},
 ):
     # prefix_exclusions are prefixes we cannot use to print the function because it doesn't exist in run lines that use these prefixes as well.
     prefix_exclusions = set()
@@ -1409,6 +1875,7 @@ def add_checks(
                     vars_seen,
                     global_vars_seen,
                     preserve_names,
+                    original_check_lines=[],
                 )[0]
                 func_name_separator = func_dict[checkprefix][func_name].func_name_separator
                 if "[[" in args_and_sig:
@@ -1516,7 +1983,12 @@ def add_checks(
             # to variable naming fashions.
             else:
                 func_body = generalize_check_lines(
-                    func_body, False, vars_seen, global_vars_seen, preserve_names
+                    func_body,
+                    False,
+                    vars_seen,
+                    global_vars_seen,
+                    preserve_names,
+                    original_check_lines=original_check_lines.get(checkprefix),
                 )
 
                 # This could be selectively enabled with an optional invocation argument.
@@ -1578,6 +2050,7 @@ def add_ir_checks(
     version,
     global_vars_seen_dict,
     is_filtered,
+    original_check_lines={},
 ):
     # Label format is based on IR string.
     if function_sig and version > 1:
@@ -1602,6 +2075,7 @@ def add_ir_checks(
         global_vars_seen_dict,
         is_filtered,
         preserve_names,
+        original_check_lines=original_check_lines,
     )
 
 
@@ -1890,6 +2364,7 @@ def get_autogennote_suffix(parser, args):
             "llvm_bin",
             "verbose",
             "force_update",
+            "reset_variable_names",
         ):
             continue
         value = getattr(args, action.dest)
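
The CHECK_RE change above turns the check-type suffix into its own capturing group, which is what lets collect_original_check_lines() skip LABEL and SAME lines via m.group(2). A minimal standalone sketch of that behaviour (illustrative only, not part of the patch):

    import re

    CHECK_RE = re.compile(
        r"^\s*(?://|[;#])\s*([^:]+?)(?:-(NEXT|NOT|DAG|LABEL|SAME|EMPTY))?:"
    )

    # group(1) is the check prefix, group(2) the optional check-type suffix.
    m = CHECK_RE.match("; CHECK-NEXT:    ret i32 [[ADD]]")
    print(m.group(1), m.group(2))                          # -> CHECK NEXT
    print(CHECK_RE.match("; CHECK-LABEL: @foo(").group(2)) # -> LABEL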
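The table/backlinks loop in find_diff_matching() is the classic longest-increasing-chain construction sketched in the long comment above. The following self-contained sketch applies the same idea to plain (lhs_idx, rhs_idx) pairs; the helper name is hypothetical and it assumes Python 3.10+ for bisect's key= parameter, which the patch itself also relies on:

    import bisect

    def longest_crossing_free_chain(candidates):
        # candidates: (lhs_idx, rhs_idx) pairs; keep a maximal subset that is
        # strictly increasing in both coordinates, i.e. no two matches cross.
        candidates = sorted(candidates, key=lambda c: (c[0], -c[1]))
        table = []      # table[n] = (lowest rhs_idx ending a chain of length n+1, candidate index)
        backlinks = []  # backlinks[i] = previous candidate on the best chain ending at i
        for i, (_, rhs_idx) in enumerate(candidates):
            ti = bisect.bisect_left(table, rhs_idx, key=lambda entry: entry[0])
            if ti < len(table):
                table[ti] = (rhs_idx, i)
            else:
                table.append((rhs_idx, i))
            backlinks.append(table[ti - 1][1] if ti > 0 else None)
        chain = []
        idx = table[-1][1] if table else None
        while idx is not None:
            chain.append(candidates[idx])
            idx = backlinks[idx]
        return chain[::-1]

    # (1, 3) crosses both (2, 2) and (3, 4); the longest chain keeps the rest.
    print(longest_crossing_free_chain([(0, 0), (1, 3), (2, 2), (3, 4)]))
    # -> [(0, 0), (2, 2), (3, 4)]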
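Inside remap_metavar_names(), recurse() compares lines only through diffify_line(): committed variable names stay visible while everything else collapses to a "?" placeholder, so two lines with the same template compare equal whenever their committed variables agree. A small standalone sketch of that comparison, with hypothetical inputs:

    def diffify(template, names, committed):
        shown = [n if n in committed else "?" for n in names]
        return template.strip() + " @@@ " + " @ ".join(shown)

    old = diffify("[[@@]] = add i32 [[@@]], %n", ["TMP0", "N"], {"N"})
    new = diffify("[[@@]] = add i32 [[@@]], %n", ["TMP7", "N"], {"N"})
    print(old == new)  # -> True: the TMP0 vs. TMP7 rename does not block a match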
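The new orig_line_infos loop reduces each pre-existing check line to a template plus the list of metavariable names by replacing every [[NAME]] / [[NAME:regex]] occurrence with the VARIABLE_TAG placeholder. A simplified standalone sketch of that splitting step (the real loop also records CheckValueInfo objects):

    import re

    VARIABLE_TAG = "[[@@]]"
    METAVAR_RE = re.compile(r"\[\[([A-Z0-9_]+)(?::[^]]+)?\]\]")

    def split_check_line(line):
        template, names, pos = "", [], 0
        for m in METAVAR_RE.finditer(line):
            template += line[pos : m.start()] + VARIABLE_TAG
            names.append(m.group(1))
            pos = m.end()
        return template + line[pos:], names

    print(split_check_line("[[TMP1:%.*]] = add i32 [[TMP0]], %n"))
    # -> ('[[@@]] = add i32 [[@@]], %n', ['TMP1', 'TMP0'])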
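The final loop of remap_metavar_names() frees conflicting names by bumping a numeric suffix until an unused name is found. An equivalent standalone sketch of that suffix search (helper name hypothetical):

    import re

    NUMERIC_SUFFIX_RE = re.compile(r"[0-9]*$")

    def free_name(name, taken):
        m = NUMERIC_SUFFIX_RE.search(name)
        base = name[: m.start()]
        suffix = int(name[m.start() :]) if m.start() != m.end() else 1
        while f"{base}{suffix}" in taken:
            suffix += 1
        return f"{base}{suffix}"

    print(free_name("TMP1", {"TMP1", "TMP2"}))  # -> TMP3
    print(free_name("VAL", {"VAL"}))            # -> VAL1

Names that never conflict keep their original spelling, which is the whole point of the patch: unrelated changes to a test no longer renumber every TMPn variable in the generated checks.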