aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/multithreaded_tests.py
blob: 6e1a66834687291e94eb7c164fe187c66511e929 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
# RUN: %PYTHON %s
"""
This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
Tests can be run using pytest:
```bash
python3.13t -mpytest -vvv multithreaded_tests.py
```

IMPORTANT: the tests do not check correctness; they only run the original tests in a multi-threaded context.
A test passes if it runs without errors and TSAN reports no warnings, and fails otherwise.


Details on the generated tests and execution:
1) Multi-threaded execution: all generated tests are executed independently by
a pool of threads, running each test multiple times, see @multi_threaded for details

2) Tests generation: we use existing tests: test/python/ir/*.py,
test/python/dialects/*.py, etc to generate multi-threaded tests.
In detail:
a) `TEST_MODULES` lists the source test files used to generate the multi-threaded tests.
b) The `TestAllMultiThreaded` class receives the existing tests via the `add_existing_tests` decorator.
c) Each test file is copied into a temporary directory and modified: test/python/ir/affine_expr.py -> <tmpdir>/ir/affine_expr.py.
To import the copy as a Python module without executing its tests, we remove the lines that run them, such as
`@run` or `run(testMethod)`. See the `copy_and_update` and `add_existing_tests` functions for details.


Warnings observed from TSAN.

Known data races (CPython free-threading and LLVM):
1) ctypes-related races: https://github.com/python/cpython/issues/127945
2) LLVM-related data races: llvm::raw_ostream is not thread-safe. Affected:
- mlir pass manager
- dialects/transform_interpreter.py
- ir/diagnostic_handler.py
- ir/module.py
3) The gpu dialect's module-to-binary method is not thread-safe
"""
import concurrent.futures
import gc
import importlib.util
import os
import sys
import threading
import tempfile
import unittest

from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, List

import mlir.dialects.arith as arith
from mlir.dialects import transform
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint


def import_from_path(module_name: str, file_path: Path):
    """Execute the Python source at ``file_path`` as a module named ``module_name``.

    The module object is registered in ``sys.modules`` *before* its code runs,
    following the importlib "import a source file directly" recipe.

    Returns:
        The executed module object.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    new_module = importlib.util.module_from_spec(module_spec)
    sys.modules[module_name] = new_module
    loader = module_spec.loader
    loader.exec_module(new_module)
    return new_module


def copy_and_update(src_filepath: Path, dst_filepath: Path):
    """Copy a test file, dropping lines that execute tests on import.

    The original test files invoke their tests at module level, either via a
    decorator (``@run``, ``@constructAndPrintInModule``, ...) or via an explicit
    call such as ``run(testMethod)``. Every line that *starts* with one of the
    prefixes below is skipped, so the copy can be imported without running
    anything and its test functions can be collected individually.

    Args:
        src_filepath: Path of the original test file.
        dst_filepath: Destination of the filtered copy (overwritten).
    """
    # Only lines starting at column 0 match, i.e. module-level decorators and
    # calls; indented uses inside function bodies are kept. Built once, outside
    # the loop, as a tuple so a single ``startswith`` call tests all prefixes.
    skip_prefixes = (
        "run(",
        "@run",
        "@constructAndPrintInModule",
        "run_apply_patterns(",
        "@run_apply_patterns",
        "@test_in_context",
        "@construct_and_print_in_module",
    )
    # Explicit UTF-8 on both ends: the copy is later imported as Python source,
    # which is decoded as UTF-8, so it must not depend on the locale encoding.
    with open(src_filepath, "r", encoding="utf-8") as reader, open(
        dst_filepath, "w", encoding="utf-8"
    ) as writer:
        for src_line in reader:
            if not src_line.startswith(skip_prefixes):
                writer.write(src_line)


# Helper run functions
def run(f):
    """Call ``f()`` directly, with no MLIR context set up around it.

    The test function's return value is discarded and ``None`` is returned.
    """
    test_callable = f
    test_callable()


def run_with_context_and_location(f):
    """Call ``f()`` inside a fresh ``Context`` with an unknown ``Location``.

    A ``TEST:`` header naming ``f`` is printed first. ``f`` is returned
    unchanged so the helper can also serve as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            f()
    return f


def run_with_insertion_point(f):
    """Call ``f(ctx)`` with the insertion point at the body of a new module.

    A ``TEST:`` header naming ``f`` is printed first. A fresh ``Context`` (passed
    to ``f``) and an unknown ``Location`` are active while the module is built,
    and the resulting module is printed afterwards.

    Returns:
        ``f`` itself. This matches the other ``run_with_*`` helpers; previously
        this helper returned ``None``, which would rebind a decorated test
        function's name to ``None``.
    """
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f(ctx)
        print(module)
    return f


def run_with_insertion_point_v2(f):
    """Call ``f()`` (no arguments) with the insertion point at a new module's body.

    A ``TEST:`` header naming ``f`` is printed before anything is built; the
    module is printed once ``f`` returns. ``f`` itself is returned.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            module = Module.create()
            body_ip = InsertionPoint(module.body)
            with body_ip:
                f()
            print(module)
    return f


def run_with_insertion_point_v3(f):
    """Call ``f(module)`` with the insertion point at that module's body.

    Unlike most helpers here, the ``TEST:`` header is printed only after the
    context, location and insertion point are all active. The module is printed
    after ``f`` returns, and ``f`` itself is returned.
    """
    with Context():
        with Location.unknown():
            module = Module.create()
            with InsertionPoint(module.body):
                print("\nTEST:", f.__name__)
                f(module)
            print(module)
    return f


def run_with_insertion_point_v4(f):
    """Call ``f()`` at a new module's body in a context allowing unregistered dialects.

    A ``TEST:`` header naming ``f`` is printed first. The module is *not*
    printed afterwards. ``f`` itself is returned.
    """
    print("\nTEST:", f.__name__)
    with Context() as context:
        with Location.unknown():
            # Set after the location is entered, preserving the original order.
            context.allow_unregistered_dialects = True
            module = Module.create()
            with InsertionPoint(module.body):
                f()
    return f


def run_apply_patterns(f):
    """Call ``f()`` inside the ``patterns`` region of a ``transform.ApplyPatternsOp``.

    The op tree built in a fresh module is:
    ``transform.SequenceOp`` (``Propagate`` failure propagation, ``AnyOpType``)
    containing an ``ApplyPatternsOp`` on the sequence's ``bodyTarget``,
    followed by a ``transform.YieldOp``. The ``TEST:`` header and the module
    are printed only after construction; ``f`` itself is returned.
    """
    with Context():
        with Location.unknown():
            module = Module.create()
            with InsertionPoint(module.body):
                propagate = transform.FailurePropagationMode.Propagate
                any_op_type = transform.AnyOpType.get()
                seq_op = transform.SequenceOp(propagate, [], any_op_type)
                with InsertionPoint(seq_op.body):
                    apply_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(apply_op.patterns):
                        f()
                    transform.YieldOp()
            print("\nTEST:", f.__name__)
            print(module)
    return f


def run_transform_tensor_ext(f):
    """Call ``f(target)`` inside the body of a ``transform.SequenceOp``.

    ``target`` is the sequence's ``bodyTarget``; a ``transform.YieldOp`` is
    appended once ``f`` returns. The ``TEST:`` header is printed first, the
    module last; ``f`` itself is returned.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            module = Module.create()
            with InsertionPoint(module.body):
                propagate = transform.FailurePropagationMode.Propagate
                any_op_type = transform.AnyOpType.get()
                seq_op = transform.SequenceOp(propagate, [], any_op_type)
                with InsertionPoint(seq_op.body):
                    f(seq_op.bodyTarget)
                    transform.YieldOp()
            print(module)
    return f


def run_transform_structured_ext(f):
    """Call ``f()`` at a new module's body, then verify and print the module.

    The ``TEST:`` header is printed once the insertion point is active.
    ``module.operation.verify()`` runs before the module is printed; ``f``
    itself is returned.
    """
    with Context():
        with Location.unknown():
            module = Module.create()
            with InsertionPoint(module.body):
                print("\nTEST:", f.__name__)
                f()
            module_op = module.operation
            module_op.verify()
            print(module)
    return f


def run_construct_and_print_in_module(f):
    """Call ``f(module)`` at a new module's body and print what ``f`` returns.

    A ``TEST:`` header is printed first. Whatever ``f`` returns is printed
    unless it is ``None``. ``f`` itself is returned.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            result = f(module)
        module = result
        if module is None:
            return f
        print(module)
    return f


# Existing test files turned into multi-threaded tests. Each entry is
# ``(path, exec_fn)`` or ``(path, exec_fn, patterns)``:
#   - ``path``: relative to this directory, without the ``.py`` suffix;
#   - ``exec_fn``: called with each collected test function to run it;
#   - ``patterns``: when given, selects attributes whose name contains any of
#     these substrings instead of attributes whose name starts with ``test``.
# List each module only once: a duplicate re-imports the same file and merely
# regenerates the same test methods.
TEST_MODULES = [
    ("execution_engine", run),
    ("pass_manager", run),
    ("dialects/affine", run_with_insertion_point_v2),
    ("dialects/func", run_with_insertion_point_v2),
    ("dialects/arith_dialect", run),
    ("dialects/arith_llvm", run),
    ("dialects/async_dialect", run),
    ("dialects/builtin", run),
    ("dialects/cf", run_with_insertion_point_v4),
    ("dialects/complex_dialect", run),
    ("dialects/index_dialect", run_with_insertion_point),
    ("dialects/llvm", run_with_insertion_point_v2),
    ("dialects/math_dialect", run),
    ("dialects/memref", run),
    ("dialects/ml_program", run_with_insertion_point_v2),
    ("dialects/nvgpu", run_with_insertion_point_v2),
    ("dialects/nvvm", run_with_insertion_point_v2),
    ("dialects/ods_helpers", run),
    ("dialects/openmp_ops", run_with_insertion_point_v2),
    ("dialects/pdl_ops", run_with_insertion_point_v2),
    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
    ("dialects/quant", run),
    ("dialects/rocdl", run_with_insertion_point_v2),
    ("dialects/scf", run_with_insertion_point_v2),
    ("dialects/shape", run),
    ("dialects/spirv_dialect", run),
    ("dialects/tensor", run),
    # ("dialects/tosa", ),  # Nothing to test
    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),
    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
    (
        "dialects/transform_interpreter",
        run_with_context_and_location,
        ["print_", "transform_options", "failed", "include"],
    ),
    (
        "dialects/transform_loop_ext",
        run_with_insertion_point_v2,
        ["loopOutline"],
    ),
    ("dialects/transform_memref_ext", run_with_insertion_point_v2),
    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),
    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
    ("dialects/transform_structured_ext", run_transform_structured_ext),
    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
    (
        "dialects/transform_vector_ext",
        run_apply_patterns,
        ["configurable_patterns"],
    ),
    ("dialects/transform", run_with_insertion_point_v3),
    ("dialects/vector", run_with_context_and_location),
    ("dialects/gpu/dialect", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
    ("dialects/linalg/ops", run),
    # TODO: dialects/linalg/opsdsl/* has no suitable tests yet.
    # ("dialects/linalg/opsdsl/*", ...),
    ("dialects/sparse_tensor/dialect", run),
    ("dialects/sparse_tensor/passes", run),
    ("integration/dialects/pdl", run_construct_and_print_in_module),
    ("integration/dialects/transform", run_construct_and_print_in_module),
    ("integration/dialects/linalg/opsrun", run),
    ("ir/affine_expr", run),
    ("ir/affine_map", run),
    ("ir/array_attributes", run),
    ("ir/attributes", run),
    ("ir/blocks", run),
    ("ir/builtin_types", run),
    ("ir/context_managers", run),
    ("ir/debug", run),
    ("ir/diagnostic_handler", run),
    ("ir/dialects", run),
    ("ir/exception", run),
    ("ir/insertion_point", run),
    ("ir/integer_set", run),
    ("ir/location", run),
    ("ir/module", run),
    ("ir/operation", run),
    ("ir/symbol_table", run),
    ("ir/value", run),
]

# Generated multi-threaded tests that are not created at all.
# ``multi_threaded`` strips the ``_multi_threaded`` suffix from each entry and
# skips every generated test whose name *contains* the remainder as a
# substring, so an entry can also match longer test names sharing its prefix.
TESTS_TO_SKIP = [
    "test_execution_engine__testNanoTime_multi_threaded",  # testNanoTime can't run in multiple threads, even with GIL
    "test_execution_engine__testSharedLibLoad_multi_threaded",  # testSharedLibLoad can't run in multiple threads, even with GIL
    "test_dialects_arith_dialect__testArithValue_multi_threaded",  # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
    "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",  # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
    "test_ir_value__testValueCasters_multi_threaded",  # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
    # tests indirectly calling thread-unsafe llvm::raw_ostream
    "test_execution_engine__testInvalidModule_multi_threaded",  # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrAfterAll_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",  # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrTree_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testRunPipeline_multi_threaded",  # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__include_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__transform_options_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__print_self_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",  # mlirEmitError calls thread-unsafe llvm::raw_ostream
    "test_ir_module__testParseSuccess_multi_threaded",  # mlirOperationDump calls thread-unsafe llvm::raw_ostream
    # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
    # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
    "test_execution_engine__testCapsule_multi_threaded",
    "test_execution_engine__testDumpToObjectFile_multi_threaded",
]

# Generated multi-threaded tests that are still created, but wrapped with
# ``unittest.expectedFailure``. Unlike TESTS_TO_SKIP, an entry here must equal
# the full generated test name exactly (including ``_multi_threaded``).
TESTS_TO_XFAIL = [
    # execution_engine tests:
    # - ctypes related data-races: https://github.com/python/cpython/issues/127945
    "test_execution_engine__testBF16Memref_multi_threaded",
    "test_execution_engine__testBasicCallback_multi_threaded",
    "test_execution_engine__testComplexMemrefAdd_multi_threaded",
    "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
    "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
    "test_execution_engine__testF16MemrefAdd_multi_threaded",
    "test_execution_engine__testF8E5M2Memref_multi_threaded",
    "test_execution_engine__testInvokeFloatAdd_multi_threaded",
    "test_execution_engine__testInvokeVoid_multi_threaded",  # a ctypes race
    "test_execution_engine__testMemrefAdd_multi_threaded",
    "test_execution_engine__testRankedMemRefCallback_multi_threaded",
    "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
    # dialects tests
    "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",  # Related to ctypes data races
    "test_dialects_transform_interpreter__print_other_multi_threaded",  # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",  # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
    # integration tests (all related to ctypes data races)
    "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",  # ctypes
]


def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
    """Class decorator attaching tests collected from existing test files.

    Each entry of ``test_modules`` is ``(module_name, exec_fn)`` or
    ``(module_name, exec_fn, patterns)``. For each one:

    1. ``<this file's dir>/<module_name>.py`` is copied with ``copy_and_update``
       into a temporary directory stored on the class as ``output_folder``.
    2. The copy is imported under the name ``module_name``.
    3. Every callable attribute whose name starts with ``"test"`` (or, when
       ``patterns`` is given, contains any of the patterns) becomes a method
       ``<test_prefix>_<module_name with '/' replaced by '_'>__<attr>``.
       Calling that method runs ``exec_fn(test_function)``.
    """

    def decorator(test_cls):
        source_dir = Path(__file__).parent.absolute()
        test_cls.output_folder = tempfile.TemporaryDirectory()
        copy_root = Path(test_cls.output_folder.name)

        for entry in test_modules:
            assert isinstance(entry, tuple) and len(entry) in (2, 3)
            module_name, runner = entry[0], entry[1]
            patterns = entry[2] if len(entry) == 3 else None

            src_path = source_dir / f"{module_name}.py"
            dst_path = (copy_root / f"{module_name}.py").absolute()
            if not dst_path.parent.exists():
                dst_path.parent.mkdir(parents=True)
            copy_and_update(src_path, dst_path)
            loaded = import_from_path(module_name, dst_path)
            mangled_module = module_name.replace("/", "_")

            for attr_name in dir(loaded):
                if patterns is None:
                    selected = attr_name.startswith("test")
                else:
                    selected = any(p in attr_name for p in patterns)
                if not selected:
                    continue
                candidate = getattr(loaded, attr_name)
                if not callable(candidate):
                    continue

                # Default arguments bind the current test function and runner,
                # avoiding the late-binding closure pitfall inside this loop.
                def wrapped_test_fn(
                    self, *args, __test_fn__=candidate, __exec_fn__=runner, **kwargs
                ):
                    __exec_fn__(__test_fn__)

                method_name = f"{test_prefix}_{mangled_module}__{attr_name}"
                setattr(test_cls, method_name, wrapped_test_fn)
        return test_cls

    return decorator


@contextmanager
def _capture_output(fp):
    # Inspired from jax test_utils.py capture_stderr method
    # ``None`` means nothing has not been captured yet.
    captured = None

    def get_output() -> str:
        if captured is None:
            raise ValueError("get_output() called while the context is active.")
        return captured

    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
        original_fd = os.dup(fp.fileno())
        os.dup2(f.fileno(), fp.fileno())
        try:
            yield get_output
        finally:
            # Python also has its own buffers, make sure everything is flushed.
            fp.flush()
            os.fsync(fp.fileno())
            f.seek(0)
            captured = f.read()
            os.dup2(original_fd, fp.fileno())


# Context managers that redirect the file descriptor behind sys.stdout /
# sys.stderr (the objects bound when this module is imported) into a temporary
# file. Each yields a ``get_output()`` callable that returns the captured text
# once the ``with`` block has exited.
capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)


def multi_threaded(
    num_workers: int,
    num_runs: int = 5,
    skip_tests: Optional[List[str]] = None,
    xfail_tests: Optional[List[str]] = None,
    test_prefix: str = "_original_test",
    multithreaded_test_postfix: str = "_multi_threaded",
):
    """Class decorator that adds a multi-threaded variant of each prefixed test.

    For every callable class attribute named ``<test_prefix><rest>``, a method
    ``test<rest><multithreaded_test_postfix>`` is added. That method runs as
    follows:

    - stdout and stderr are redirected at the file-descriptor level;
    - ``num_workers`` threads wait on a shared barrier and then each call the
      original test ``num_runs`` times.

    The generated test fails if any worker raised, if any MLIR ``Context`` is
    still alive after ``gc.collect()``, or if the captured stderr mentions
    ``ThreadSanitizer``.

    Args:
        num_workers: Number of threads running the test concurrently.
        num_runs: Number of times each thread runs the test.
        skip_tests: Generated test names to omit. Each entry, with the postfix
            stripped, skips every test whose name contains it as a substring.
        xfail_tests: Generated test names (exact match) that are wrapped with
            ``unittest.expectedFailure``.
        test_prefix: Prefix identifying the source test methods.
        multithreaded_test_postfix: Suffix appended to generated test names.
    """

    def decorator(test_cls):
        for name, test_fn in test_cls.__dict__.copy().items():
            if not (name.startswith(test_prefix) and callable(test_fn)):
                continue

            name = f"test{name[len(test_prefix):]}"
            if skip_tests is not None and any(
                test_name.replace(multithreaded_test_postfix, "") in name
                for test_name in skip_tests
            ):
                continue

            # ``__test_fn__`` default binds the current test_fn (avoids the
            # late-binding closure pitfall in this loop).
            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
                with capture_stdout(), capture_stderr() as get_output:
                    barrier = threading.Barrier(num_workers)

                    def closure():
                        # Start all workers together to maximize overlap.
                        barrier.wait()
                        for _ in range(num_runs):
                            __test_fn__(self, *args, **kwargs)

                    with concurrent.futures.ThreadPoolExecutor(
                        max_workers=num_workers
                    ) as executor:
                        futures = [
                            executor.submit(closure) for _ in range(num_workers)
                        ]
                        # Call result() as a plain statement, not inside an
                        # ``assert`` (stripped by ``python -O``), so a worker's
                        # exception is always re-raised and fails the test.
                        for future in futures:
                            future.result()

                    gc.collect()
                    live_contexts = Context._get_live_count()
                    if live_contexts != 0:
                        raise AssertionError(
                            f"{live_contexts} MLIR Context(s) still alive after the test"
                        )

                captured = get_output()
                if "ThreadSanitizer" in captured:
                    raise RuntimeError(
                        f"ThreadSanitizer reported warnings:\n{captured}"
                    )

            test_new_name = f"{name}{multithreaded_test_postfix}"
            if xfail_tests is not None and test_new_name in xfail_tests:
                multi_threaded_test_fn = unittest.expectedFailure(
                    multi_threaded_test_fn
                )

            setattr(test_cls, test_new_name, multi_threaded_test_fn)

        return test_cls

    return decorator


@multi_threaded(
    num_workers=10,
    num_runs=20,
    skip_tests=TESTS_TO_SKIP,
    xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
    """Multi-threaded variants of the existing MLIR Python tests.

    ``add_existing_tests`` attaches the tests listed in ``TEST_MODULES`` as
    ``_original_test*`` methods. ``multi_threaded`` then generates the
    ``test*_multi_threaded`` methods that the test runner collects. The two
    ``_original_test*`` methods below are extra hand-written cases.
    """

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary directory holding the copied test files.
        if not hasattr(cls, "output_folder"):
            return
        cls.output_folder.cleanup()

    def _original_test_create_context(self):
        with Context() as ctx:
            counters = (
                ctx._get_live_count,
                ctx._get_live_module_count,
                ctx._get_live_operation_count,
                ctx._get_live_operation_objects,
            )
            for counter in counters:
                print(counter())
            print(ctx._get_context_again() is ctx)
            print(ctx._clear_live_operations())

    def _original_test_create_module_with_consts(self):
        py_values = [123, 234, 345]
        with Context():
            module = Module.create(loc=Location.file("foo.txt", 0, 0))

            dtype = IntegerType.get_signless(64)
            # One constant per value, each under its own named location.
            for loc_name, value in zip(("a", "b", "c"), py_values):
                with InsertionPoint(module.body), Location.name(loc_name):
                    arith.constant(dtype, value)


if __name__ == "__main__":
    # This suite targets free-threaded CPython. It runs only when the
    # interpreter exposes sys._is_gil_enabled and reports the GIL as disabled;
    # otherwise nothing is executed.
    gil_probe_available = hasattr(sys, "_is_gil_enabled")
    if gil_probe_available and not sys._is_gil_enabled():
        unittest.main()