aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Vasilache <nico.vasilache@amd.com>2025-07-04 10:32:39 +0200
committerNicolas Vasilache <nico.vasilache@amd.com>2025-07-04 10:51:43 +0200
commit2b8f82b2bad6b2ada988fb2b874d676aa748a35b (patch)
treecfb669d83bbe5ad73c4378a2a272c254c485bfb7
parent34f124b06ffd3a4e5befafe3cf5daf7753f415ff (diff)
downloadllvm-users/nico/python-1.zip
llvm-users/nico/python-1.tar.gz
llvm-users/nico/python-1.tar.bz2
[mlir][python] Add utils for more pythonic context creation and registration managementusers/nico/python-1
Co-authored-by: Fabian Mora <fmora.dev@gmail.com Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com> Co-authored-by: Tres <tpopp@users.noreply.github.com>
-rw-r--r--mlir/include/mlir-c/IR.h4
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp6
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp4
-rw-r--r--mlir/python/CMakeLists.txt7
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/ir.pyi1
-rw-r--r--mlir/python/mlir/utils.py211
-rw-r--r--mlir/test/python/utils.py58
7 files changed, 291 insertions, 0 deletions
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 81299c7..877aa73 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -143,6 +143,10 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context,
bool enable);
+/// Retrieve threading mode current value as controlled by
+/// mlirContextEnableMultithreading.
+MLIR_CAPI_EXPORTED bool mlirContextIsMultithreadingEnabled(MlirContext context);
+
/// Eagerly loads all available dialects registered with a context, making
/// them available for use for IR construction.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d961482..002923b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2939,6 +2939,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
ss << pool.ptr;
return ss.str();
})
+ .def_prop_ro(
+ "is_multithreading_enabled",
+ [](PyMlirContext &self) {
+ return mlirContextIsMultithreadingEnabled(self.get());
+ },
+ "Returns true if multithreading is enabled for this context.")
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index fbc66bc..1cc555a 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -101,6 +101,10 @@ bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
return unwrap(context)->isOperationRegistered(unwrap(name));
}
+bool mlirContextIsMultithreadingEnabled(MlirContext context) {
+ return unwrap(context)->isMultithreadingEnabled();
+}
+
void mlirContextEnableMultithreading(MlirContext context, bool enable) {
return unwrap(context)->enableMultithreading(enable);
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b2daabb..b4e0ab2 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -48,6 +48,13 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
runtime/*.py
)
+declare_mlir_python_sources(MLIRPythonSources.Utils
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ ADD_TO_PARENT MLIRPythonSources
+ SOURCES
+ utils.py
+)
+
declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources
ROOT_DIR "${MLIR_SOURCE_DIR}/include"
SOURCES_GLOB "mlir-c/*.h"
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 70bca3c..56b9f17 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -986,6 +986,7 @@ class ComplexType(Type):
class Context:
current: ClassVar[Context] = ... # read-only
allow_unregistered_dialects: bool
+ is_multithreading_enabled: bool
@staticmethod
def _get_live_count() -> int: ...
def _CAPICreate(self) -> object: ...
diff --git a/mlir/python/mlir/utils.py b/mlir/python/mlir/utils.py
new file mode 100644
index 0000000..c6e9b57
--- /dev/null
+++ b/mlir/python/mlir/utils.py
@@ -0,0 +1,211 @@
+from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import (
+ Any,
+ Callable,
+ Concatenate,
+ Iterator,
+ Optional,
+ ParamSpec,
+ Sequence,
+ TypeVar,
+)
+
+from mlir import ir
+from mlir._mlir_libs import get_dialect_registry
+from mlir.dialects import func
+from mlir.dialects.transform import interpreter
+from mlir.passmanager import PassManager
+
+RT = TypeVar("RT")
+Param = ParamSpec("Param")
+
+
+@contextmanager
+def using_mlir_context(
+ *,
+ required_dialects: Optional[Sequence[str]] = None,
+ required_extension_operations: Optional[Sequence[str]] = None,
+ registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+ """Ensure a valid context exists by creating one if necessary.
+
+ NOTE: If values that are attached to a Context should outlive this
+ contextmanager, use caller_mlir_context!
+
+ This can be used as a function decorator or managed context in a with statement.
+ The context will throw an error if the required dialects have not been registered,
+ and a context is guaranteed to exist in this scope.
+
+ This only works on dialects and not dialect extensions currently.
+
+ Parameters
+ ------------
+ required_dialects:
+ Dialects that need to be registered in the context
+ required_extension_operations:
+ Required operations by their fully specified name. These are a proxy for detecting needed dialect extensions.
+ registration_funcs:
+ Functions that should be called to register all missing dialects/operations if they have not been registered.
+ """
+ dialects = required_dialects or []
+ extension_operations = required_extension_operations or []
+ registrations = registration_funcs or []
+ new_context = nullcontext if ir.Context.current else ir.Context
+ with new_context(), ir.Location.unknown():
+ context = ir.Context.current
+ # Attempt to disable multithreading. This could fail if currently being
+ # used in multiple threads. This must be done before checking for
+ # dialects or registering dialects as both will assert fail in a
+ # multithreaded situation.
+ multithreading = context.is_multithreading_enabled
+ if multithreading:
+ context.enable_multithreading(False)
+
+ def attempt_registration():
+ """Register everything from registration_funcs."""
+ nonlocal context, registrations
+
+ # Gather dialects and extensions then add them to the context.
+ registry = ir.DialectRegistry()
+ for rf in registrations:
+ rf(registry)
+
+ context.append_dialect_registry(registry)
+
+ # See if any dialects are missing, register if they are, and then assert they are all registered.
+ try:
+ for dialect in dialects:
+ # If the dialect is registered, continue checking
+ context.get_dialect_descriptor(dialect)
+ except Exception:
+ attempt_registration()
+
+ for dialect in dialects:
+ # If the dialect is registered, continue checking
+ assert context.get_dialect_descriptor(
+ dialect
+ ), f"required dialect {dialect} not registered by registration_funcs"
+
+ # See if any operations are missing and register if they are. We cannot
+ # assert the operations exist in the registry after for some reason.
+ #
+ # TODO: Make this work for dialect extensions specifically
+ for operation in extension_operations:
+ # If the operation is registered, attempt to register and then strongly assert it was added
+ if not context.is_registered_operation(operation):
+ attempt_registration()
+ break
+ for operation in extension_operations:
+ # First get the dialect descriptior which loads the dialect as a side effect
+ dialect = operation.split(".")[0]
+ assert context.get_dialect_descriptor(dialect), f"Never loaded {dialect}"
+ assert context.is_registered_operation(
+ operation
+ ), f"expected {operation} to be registered in its dialect"
+ context.enable_multithreading(multithreading)
+
+ # Context manager related yield
+ try:
+ yield
+ finally:
+ pass
+
+
+@contextmanager
+def caller_mlir_context(
+ *,
+ required_dialects: Optional[Sequence[str]] = None,
+ required_extension_operations: Optional[Sequence[str]] = None,
+ registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+ """Requires an enclosing context from the caller and ensures relevant operations are loaded.
+
+ NOTE: If the Context is only needed inside of this contextmanager and returned values
+ don't need to the Context, use using_mlir_context!
+
+ A context must already exist before this frame is executed to ensure that any values
+ continue to live on exit. Conceptually, this prevents use-after-free issues and
+ makes the intention clear when one intends to return values tied to a Context.
+ """
+ assert (
+ ir.Context.current
+ ), "Caller must have a context so it outlives this function call."
+ with using_mlir_context(
+ required_dialects=required_dialects,
+ required_extension_operations=required_extension_operations,
+ registration_funcs=registration_funcs,
+ ):
+ # Context manager related yield
+ try:
+ yield
+ finally:
+ pass
+
+
+def with_toplevel_context(f: Callable[Param, RT]) -> Callable[Param, RT]:
+ """Decorate the function to be executed with a fresh MLIR context.
+
+ This decorator will ensure the function is executed inside a context manager for a
+ new MLIR context with upstream and IREE dialects registered. Note that each call to
+ such a function has a new context, meaning that context-owned objects from these
+ functions will not be equal to each other. All arguments and keyword arguments are
+ forwarded.
+
+ The context is destroyed before the function exits so any result from the function
+ must not depend on the context.
+ """
+
+ @wraps(f)
+ def decorator(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+ # Appending dialect registry and loading all available dialects occur on
+ # context creation because of the "_site_initialize" call.
+ with ir.Context(), ir.Location.unknown():
+ results = f(*args, **kwargs)
+ return results
+
+ return decorator
+
+
+def with_toplevel_context_create_module(
+ f: Callable[Concatenate[ir.Module, Param], RT],
+) -> Callable[Param, RT]:
+ """Decorate function to be executed in a fresh MLIR context and give it a module.
+
+ The decorated function will receive, as its leading argument, a fresh MLIR module.
+ The context manager is set up to insert operations into this module. All other
+ arguments and keyword arguments are forwarded.
+
+ The module and context are destroyed before the function exists so any result from
+ the function must not depend on either.
+ """
+
+ @with_toplevel_context
+ @wraps(f)
+ def internal(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ results = f(module, *args, **kwargs)
+ return results
+
+ return internal
+
+
+def call_with_toplevel_context(f: Callable[[], RT]) -> Callable[[], RT]:
+ """Immediately call the function in a fresh MLIR context."""
+ decorated = with_toplevel_context(f)
+ decorated()
+ return decorated
+
+
+def call_with_toplevel_context_create_module(
+ f: Callable[[ir.Module], RT],
+) -> Callable[[], RT]:
+ """Immediately call the function in a fresh MLIR context and give it a module.
+
+ The decorated function will receive, as its only argument, a fresh MLIR module. The
+ context manager is set up to insert operations into this module.
+ """
+ decorated = with_toplevel_context_create_module(f)
+ decorated()
+ return decorated
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py
new file mode 100644
index 0000000..8435fdd
--- /dev/null
+++ b/mlir/test/python/utils.py
@@ -0,0 +1,58 @@
+# RUN: %python %s | FileCheck %s
+
+import unittest
+
+from mlir import ir
+from mlir.dialects import arith, builtin
+from mlir.extras import types as T
+from mlir.utils import (
+ call_with_toplevel_context_create_module,
+ caller_mlir_context,
+ using_mlir_context,
+)
+
+
+class TestRequiredContext(unittest.TestCase):
+ def test_shared_context(self):
+ """Test that the context is reused, so values can be passed/returned between functions."""
+
+ @using_mlir_context()
+ def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+ return arith.AddFOp(
+ lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
+ ).result
+
+ @using_mlir_context()
+ def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+ return create_add(create_add(lhs, rhs), create_add(lhs, rhs))
+
+ @call_with_toplevel_context_create_module
+ def _(module) -> None:
+ c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
+ multiple_adds(c, c)
+
+ # CHECK: constant
+ # CHECK-NEXT: arith.addf
+ # CHECK-NEXT: arith.addf
+ # CHECK-NEXT: arith.addf
+ print(module)
+
+ def test_unregistered_op_asserts(self):
+ """Confirm that with_mlir_context fails if an operation is still not registered."""
+ with self.assertRaises(AssertionError), using_mlir_context(
+ required_extension_operations=["func.fake_extension_op"],
+ registration_funcs=[],
+ ):
+ pass
+
+ def test_required_op_asserts(self):
+ """Confirm that with_mlir_context fails if an operation is still not registered."""
+ with self.assertRaises(AssertionError), caller_mlir_context(
+ required_extension_operations=["func.fake_extension_op"],
+ registration_funcs=[],
+ ):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()