aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python
diff options
context:
space:
mode:
authorNicolas Vasilache <nico.vasilache@amd.com>2025-07-04 10:32:39 +0200
committerNicolas Vasilache <nico.vasilache@amd.com>2025-07-04 10:51:43 +0200
commit2b8f82b2bad6b2ada988fb2b874d676aa748a35b (patch)
treecfb669d83bbe5ad73c4378a2a272c254c485bfb7 /mlir/test/python
parent34f124b06ffd3a4e5befafe3cf5daf7753f415ff (diff)
downloadllvm-users/nico/python-1.zip
llvm-users/nico/python-1.tar.gz
llvm-users/nico/python-1.tar.bz2
[mlir][python] Add utils for more pythonic context creation and registration managementusers/nico/python-1
Co-authored-by: Fabian Mora <fmora.dev@gmail.com Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com> Co-authored-by: Tres <tpopp@users.noreply.github.com>
Diffstat (limited to 'mlir/test/python')
-rw-r--r--mlir/test/python/utils.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py
new file mode 100644
index 0000000..8435fdd
--- /dev/null
+++ b/mlir/test/python/utils.py
@@ -0,0 +1,58 @@
+# RUN: %python %s | FileCheck %s
+
+import unittest
+
+from mlir import ir
+from mlir.dialects import arith, builtin
+from mlir.extras import types as T
+from mlir.utils import (
+ call_with_toplevel_context_create_module,
+ caller_mlir_context,
+ using_mlir_context,
+)
+
+
+class TestRequiredContext(unittest.TestCase):
+ def test_shared_context(self):
+ """Test that the context is reused, so values can be passed/returned between functions."""
+
+ @using_mlir_context()
+ def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+ return arith.AddFOp(
+ lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
+ ).result
+
+ @using_mlir_context()
+ def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+ return create_add(create_add(lhs, rhs), create_add(lhs, rhs))
+
+ @call_with_toplevel_context_create_module
+ def _(module) -> None:
+ c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
+ multiple_adds(c, c)
+
+ # CHECK: constant
+ # CHECK-NEXT: arith.addf
+ # CHECK-NEXT: arith.addf
+ # CHECK-NEXT: arith.addf
+ print(module)
+
+ def test_unregistered_op_asserts(self):
+ """Confirm that with_mlir_context fails if an operation is still not registered."""
+ with self.assertRaises(AssertionError), using_mlir_context(
+ required_extension_operations=["func.fake_extension_op"],
+ registration_funcs=[],
+ ):
+ pass
+
+ def test_required_op_asserts(self):
+ """Confirm that with_mlir_context fails if an operation is still not registered."""
+ with self.assertRaises(AssertionError), caller_mlir_context(
+ required_extension_operations=["func.fake_extension_op"],
+ registration_funcs=[],
+ ):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()