aboutsummaryrefslogtreecommitdiff
path: root/llvm/utils/update_llubi_test_checks.py
blob: 6b2600069835439648949931237a337d0fe3325a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""A test case update script.

This script is a utility to update LLVM 'llubi' based test cases with new
FileCheck patterns.
"""

from __future__ import print_function

from sys import stderr
from traceback import print_exc
import argparse
import os
import subprocess
import sys

from UpdateTestChecks import common


# Invoke the tool that is being tested.
def invoke_tool(exe, cmd_args, ir, check_rc):
    """Run the tool under test on an IR file and return its captured output.

    Args:
      exe: Executable name or path to run.
      cmd_args: Argument string taken from the RUN line; the substitutions
        returned by ``common.getSubstitutions(ir)`` are applied to it with
        ``common.applySubstitutions`` before running.
      ir: Path of the IR file; the file is connected to the tool's stdin.
      check_rc: When true, a non-zero exit status raises
        ``subprocess.CalledProcessError``.

    Returns:
      The tool's stdout and stderr (merged into one stream), decoded with
      the default UTF-8 codec, with CRLF line endings normalized to LF.
    """
    with open(ir) as stdin_file:
        expanded_args = common.applySubstitutions(
            cmd_args, common.getSubstitutions(ir)
        )
        # The command line is a single string, so it is run via the shell.
        completed = subprocess.run(
            f"{exe} {expanded_args}",
            shell=True,
            stdin=stdin_file,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            check=check_rc,
        )
    text = completed.stdout.decode()
    # Normalize Windows CRLF line endings to Unix LF.
    return text.replace("\r\n", "\n")


def _make_check_lines(tool_output_lines, prefix):
    """Build FileCheck directives, using *prefix*, for the given output lines.

    The first non-empty line is matched with ``PREFIX:`` and every later
    line with ``PREFIX-NEXT:`` so the remaining output must follow it line
    by line.  Empty lines become ``PREFIX-EMPTY:`` because FileCheck rejects
    a directive whose pattern is empty.  Leading empty lines are skipped
    since neither ``-NEXT`` nor ``-EMPTY`` may begin a check sequence.
    Returns an empty list when there is no non-empty line.
    """
    first = 0
    while first < len(tool_output_lines) and not tool_output_lines[first]:
        first += 1
    if first == len(tool_output_lines):
        return []
    check_lines = ["; {}: {}".format(prefix, tool_output_lines[first])]
    for line in tool_output_lines[first + 1 :]:
        if line:
            check_lines.append("; {}-NEXT: {}".format(prefix, line))
        else:
            check_lines.append("; {}-EMPTY:".format(prefix))
    return check_lines


def update_test(ti: common.TestInfo):
    """Regenerate the check lines of one llubi test file in place.

    Only the first RUN line is used.  It must look like
    ``[not] llubi <args> | FileCheck <args>``; other shapes are skipped
    with a warning.  The test file is run through llubi via ``invoke_tool``
    and the file is rewritten as the output of ``common.dump_input_lines``
    (given the FileCheck prefixes) followed by one check line per line of
    llubi output.

    Args:
      ti: Test description produced by ``common.itertests``; ``ti.args``
        carries the ``--llubi-binary`` and ``--tool`` options.

    Raises:
      subprocess.CalledProcessError: if llubi exits non-zero and the RUN
        line is not prefixed with ``not``.
    """
    if not ti.run_lines:
        common.warn("No RUN lines found in test: " + ti.path)
        return
    if len(ti.run_lines) > 1:
        common.warn("Multiple RUN lines found in test: " + ti.path)
        common.warn("Only the first RUN line will be processed.")

    run_line = ti.run_lines[0]
    if "|" not in run_line:
        common.warn("Skipping unparsable RUN line: " + run_line)
        return

    commands = [cmd.strip() for cmd in run_line.split("|")]
    # The test file is fed straight to llubi below, so extra pipeline stages
    # cannot be reproduced.  Skip such lines explicitly rather than relying
    # on an assert, which is stripped under -O.
    if len(commands) != 2:
        common.warn("Skipping unparsable RUN line: " + run_line)
        return
    llubi_cmd, filecheck_cmd = commands

    # split() (not split(" ")) so repeated spaces don't produce empty tokens.
    args = llubi_cmd.split()
    llubi_tool = args[0] if args else ""
    check_rc = True
    if len(args) > 1 and args[0] == "not":
        # "not llubi ..." expects a failing exit status; don't raise on it.
        llubi_tool = args[1]
        check_rc = False

    common.verify_filecheck_prefixes(filecheck_cmd)

    # --tool lets callers treat another executable name as llubi-like.
    llubi_like_tools = {"llubi"}
    extra_tool = getattr(ti.args, "tool", None)
    if extra_tool:
        llubi_like_tools.add(extra_tool)
    if llubi_tool not in llubi_like_tools:
        common.warn("Skipping non-llubi RUN line: " + run_line)
        return

    if not filecheck_cmd.startswith("FileCheck "):
        common.warn("Skipping non-FileChecked RUN line: " + run_line)
        return

    # Arguments after the tool name, minus references to the test file
    # (invoke_tool supplies the file on stdin itself).
    llubi_args = llubi_cmd[llubi_cmd.index(llubi_tool) + len(llubi_tool) :].strip()
    llubi_args = llubi_args.replace("< %s", "").replace("%s", "").strip()
    prefixes = common.get_check_prefixes(filecheck_cmd)
    # Emit the new check lines with the first FileCheck prefix so FileCheck
    # actually verifies them and they match the prefix set given to
    # dump_input_lines.
    check_prefix = prefixes[0] if prefixes else "CHECK"

    common.debug("Extracted llubi cmd:", llubi_tool, llubi_args)
    common.debug("Extracted FileCheck prefixes:", str(prefixes))
    prefix_set = set(prefixes)

    raw_tool_output = invoke_tool(
        ti.args.llubi_binary or llubi_tool,
        llubi_args,
        ti.path,
        check_rc=check_rc,
    )
    if ti.args.llubi_binary:
        # Keep the local binary path out of the expected output.
        raw_tool_output = raw_tool_output.replace(ti.args.llubi_binary, llubi_tool)

    output_lines = []
    common.dump_input_lines(output_lines, ti, prefix_set, ";")
    tool_output_lines = raw_tool_output.splitlines()
    if not tool_output_lines:
        common.warn("No output from llubi.")
    output_lines.extend(_make_check_lines(tool_output_lines, check_prefix))

    common.debug("Writing %d lines to %s..." % (len(output_lines), ti.path))
    with open(ti.path, "wb") as f:
        f.writelines(["{}\n".format(line).encode("utf-8") for line in output_lines])


def main():
    """Parse the command line and regenerate checks for every given test.

    A failure in one test is reported on stderr together with its traceback,
    and processing continues with the next test.

    Returns:
      0 if every test was processed without an exception, otherwise 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--llubi-binary",
        default=None,
        help='The "llubi" binary to use to generate the test case',
    )
    parser.add_argument(
        "--tool",
        default=None,
        help="Treat the given tool name as an llubi-like tool for which check lines should be generated",
    )
    parser.add_argument("tests", nargs="+")
    parsed_args = common.parse_commandline_args(parser)

    updater_name = "utils/" + os.path.basename(__file__)

    any_failed = False
    for test_info in common.itertests(
        parsed_args.tests, parser, script_name=updater_name
    ):
        try:
            update_test(test_info)
        except Exception:
            # Keep going so one broken test does not block the others.
            stderr.write(f"Error: Failed to update test {test_info.path}\n")
            print_exc()
            any_failed = True
    return 1 if any_failed else 0


if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 if any test failed to
    # update) as the process exit status.
    sys.exit(main())