aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/utils.py
blob: 8435fdd363ae31c5c94b33f785e1d457c53f89dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# RUN: %python %s | FileCheck %s

import unittest

from mlir import ir
from mlir.dialects import arith, builtin
from mlir.extras import types as T
from mlir.utils import (
    call_with_toplevel_context_create_module,
    caller_mlir_context,
    using_mlir_context,
)


class TestRequiredContext(unittest.TestCase):
    """Tests for the context helpers imported from ``mlir.utils``."""

    def test_shared_context(self):
        """Verify the context is reused, so values can be passed/returned between functions.

        Both helpers below are wrapped in ``using_mlir_context()``. The IR they
        build must end up in the module handed to ``_`` by
        ``call_with_toplevel_context_create_module``. The FileCheck directives
        next to ``print(module)`` verify this on the printed module.
        """

        @using_mlir_context()
        def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
            # Emits a single ``arith.addf`` with the nnan and ninf fastmath flags.
            return arith.AddFOp(
                lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            ).result

        @using_mlir_context()
        def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
            # Three ``create_add`` calls -> three ``arith.addf`` ops. The outer
            # call consumes the results produced by the two inner calls.
            return create_add(create_add(lhs, rhs), create_add(lhs, rhs))

        # Nothing else calls ``_``. The decorator is expected to invoke it
        # right away with a freshly created module, presumably with an active
        # context/location/insertion point so the ops below need no explicit
        # ``ip``/``loc`` (confirm in mlir.utils).
        @call_with_toplevel_context_create_module
        def _(module) -> None:
            c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
            multiple_adds(c, c)

            # CHECK: constant
            # CHECK-NEXT: arith.addf
            # CHECK-NEXT: arith.addf
            # CHECK-NEXT: arith.addf
            print(module)

    def test_unregistered_op_asserts(self):
        """Verify ``using_mlir_context`` fails when a required op is not registered.

        ``func.fake_extension_op`` is requested as a required extension
        operation while ``registration_funcs`` is empty. The ``with`` statement
        is therefore expected to raise ``AssertionError``.
        """
        # ``assertRaises`` is entered before ``using_mlir_context``, so an
        # AssertionError raised while entering (or leaving) the MLIR context
        # manager is caught and counted as the expected failure.
        with self.assertRaises(AssertionError), using_mlir_context(
            required_extension_operations=["func.fake_extension_op"],
            registration_funcs=[],
        ):
            pass

    def test_required_op_asserts(self):
        """Verify ``caller_mlir_context`` fails when a required op is not registered.

        Mirrors ``test_unregistered_op_asserts`` but exercises
        ``caller_mlir_context``. It uses the same fake extension op and an
        empty ``registration_funcs``, so ``AssertionError`` is expected.
        """
        # Same ordering as above: ``assertRaises`` wraps entry into the MLIR
        # context manager.
        with self.assertRaises(AssertionError), caller_mlir_context(
            required_extension_operations=["func.fake_extension_op"],
            registration_funcs=[],
        ):
            pass


if __name__ == "__main__":
    # Entry point for the lit RUN line at the top of this file. It runs the
    # tests in this module, and the IR printed to stdout is piped into FileCheck.
    unittest.main(module="__main__")