diff options
author | Nicolas Vasilache <nico.vasilache@amd.com> | 2025-07-04 10:32:39 +0200 |
---|---|---|
committer | Nicolas Vasilache <nico.vasilache@amd.com> | 2025-07-04 10:51:43 +0200 |
commit | 2b8f82b2bad6b2ada988fb2b874d676aa748a35b (patch) | |
tree | cfb669d83bbe5ad73c4378a2a272c254c485bfb7 /mlir/test/python | |
parent | 34f124b06ffd3a4e5befafe3cf5daf7753f415ff (diff) | |
download | llvm-users/nico/python-1.zip llvm-users/nico/python-1.tar.gz llvm-users/nico/python-1.tar.bz2 |
[mlir][python] Add utils for more pythonic context creation and registration managementusers/nico/python-1
Co-authored-by: Fabian Mora <fmora.dev@gmail.com
Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
Diffstat (limited to 'mlir/test/python')
-rw-r--r-- | mlir/test/python/utils.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py new file mode 100644 index 0000000..8435fdd --- /dev/null +++ b/mlir/test/python/utils.py @@ -0,0 +1,58 @@ +# RUN: %python %s | FileCheck %s + +import unittest + +from mlir import ir +from mlir.dialects import arith, builtin +from mlir.extras import types as T +from mlir.utils import ( + call_with_toplevel_context_create_module, + caller_mlir_context, + using_mlir_context, +) + + +class TestRequiredContext(unittest.TestCase): + def test_shared_context(self): + """Test that the context is reused, so values can be passed/returned between functions.""" + + @using_mlir_context() + def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value: + return arith.AddFOp( + lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf + ).result + + @using_mlir_context() + def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value: + return create_add(create_add(lhs, rhs), create_add(lhs, rhs)) + + @call_with_toplevel_context_create_module + def _(module) -> None: + c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result + multiple_adds(c, c) + + # CHECK: constant + # CHECK-NEXT: arith.addf + # CHECK-NEXT: arith.addf + # CHECK-NEXT: arith.addf + print(module) + + def test_unregistered_op_asserts(self): + """Confirm that with_mlir_context fails if an operation is still not registered.""" + with self.assertRaises(AssertionError), using_mlir_context( + required_extension_operations=["func.fake_extension_op"], + registration_funcs=[], + ): + pass + + def test_required_op_asserts(self): + """Confirm that with_mlir_context fails if an operation is still not registered.""" + with self.assertRaises(AssertionError), caller_mlir_context( + required_extension_operations=["func.fake_extension_op"], + registration_funcs=[], + ): + pass + + +if __name__ == "__main__": + unittest.main() |