diff options
author | Nicolas Vasilache <nico.vasilache@amd.com> | 2025-07-04 10:32:39 +0200 |
---|---|---|
committer | Nicolas Vasilache <nico.vasilache@amd.com> | 2025-07-04 10:51:43 +0200 |
commit | 2b8f82b2bad6b2ada988fb2b874d676aa748a35b (patch) | |
tree | cfb669d83bbe5ad73c4378a2a272c254c485bfb7 | |
parent | 34f124b06ffd3a4e5befafe3cf5daf7753f415ff (diff) | |
download | llvm-users/nico/python-1.zip llvm-users/nico/python-1.tar.gz llvm-users/nico/python-1.tar.bz2 |
[mlir][python] Add utils for more pythonic context creation and registration managementusers/nico/python-1
Co-authored-by: Fabian Mora <fmora.dev@gmail.com
Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
-rw-r--r-- | mlir/include/mlir-c/IR.h | 4 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/IRCore.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/CAPI/IR/IR.cpp | 4 | ||||
-rw-r--r-- | mlir/python/CMakeLists.txt | 7 | ||||
-rw-r--r-- | mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 1 | ||||
-rw-r--r-- | mlir/python/mlir/utils.py | 211 | ||||
-rw-r--r-- | mlir/test/python/utils.py | 58 |
7 files changed, 291 insertions, 0 deletions
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 81299c7..877aa73 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -143,6 +143,10 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable); +/// Retrieve threading mode current value as controlled by +/// mlirContextEnableMultithreading. +MLIR_CAPI_EXPORTED bool mlirContextIsMultithreadingEnabled(MlirContext context); + /// Eagerly loads all available dialects registered with a context, making /// them available for use for IR construction. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d961482..002923b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2939,6 +2939,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { ss << pool.ptr; return ss.str(); }) + .def_prop_ro( + "is_multithreading_enabled", + [](PyMlirContext &self) { + return mlirContextIsMultithreadingEnabled(self.get()); + }, + "Returns true if multithreading is enabled for this context.") .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index fbc66bc..1cc555a 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -101,6 +101,10 @@ bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { return unwrap(context)->isOperationRegistered(unwrap(name)); } +bool mlirContextIsMultithreadingEnabled(MlirContext context) { + return unwrap(context)->isMultithreadingEnabled(); +} + void mlirContextEnableMultithreading(MlirContext context, bool enable) { return unwrap(context)->enableMultithreading(enable); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index b2daabb..b4e0ab2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -48,6 +48,13 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine runtime/*.py ) +declare_mlir_python_sources(MLIRPythonSources.Utils + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources + SOURCES + utils.py +) + declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources ROOT_DIR "${MLIR_SOURCE_DIR}/include" SOURCES_GLOB "mlir-c/*.h" diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 70bca3c..56b9f17 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -986,6 +986,7 @@ class ComplexType(Type): class Context: current: ClassVar[Context] = ... # read-only allow_unregistered_dialects: bool + is_multithreading_enabled: bool @staticmethod def _get_live_count() -> int: ... def _CAPICreate(self) -> object: ... diff --git a/mlir/python/mlir/utils.py b/mlir/python/mlir/utils.py new file mode 100644 index 0000000..c6e9b57 --- /dev/null +++ b/mlir/python/mlir/utils.py @@ -0,0 +1,211 @@ +from contextlib import contextmanager, nullcontext +from functools import wraps +from typing import ( + Any, + Callable, + Concatenate, + Iterator, + Optional, + ParamSpec, + Sequence, + TypeVar, +) + +from mlir import ir +from mlir._mlir_libs import get_dialect_registry +from mlir.dialects import func +from mlir.dialects.transform import interpreter +from mlir.passmanager import PassManager + +RT = TypeVar("RT") +Param = ParamSpec("Param") + + +@contextmanager +def using_mlir_context( + *, + required_dialects: Optional[Sequence[str]] = None, + required_extension_operations: Optional[Sequence[str]] = None, + registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None, +) -> Iterator[None]: + """Ensure a valid context exists by creating one if necessary. + + NOTE: If values that are attached to a Context should outlive this + contextmanager, use caller_mlir_context! + + This can be used as a function decorator or managed context in a with statement. + The context will throw an error if the required dialects have not been registered, + and a context is guaranteed to exist in this scope. + + This only works on dialects and not dialect extensions currently. + + Parameters + ------------ + required_dialects: + Dialects that need to be registered in the context + required_extension_operations: + Required operations by their fully specified name. These are a proxy for detecting needed dialect extensions. + registration_funcs: + Functions that should be called to register all missing dialects/operations if they have not been registered. + """ + dialects = required_dialects or [] + extension_operations = required_extension_operations or [] + registrations = registration_funcs or [] + new_context = nullcontext if ir.Context.current else ir.Context + with new_context(), ir.Location.unknown(): + context = ir.Context.current + # Attempt to disable multithreading. This could fail if currently being + # used in multiple threads. This must be done before checking for + # dialects or registering dialects as both will assert fail in a + # multithreaded situation. + multithreading = context.is_multithreading_enabled + if multithreading: + context.enable_multithreading(False) + + def attempt_registration(): + """Register everything from registration_funcs.""" + nonlocal context, registrations + + # Gather dialects and extensions then add them to the context. + registry = ir.DialectRegistry() + for rf in registrations: + rf(registry) + + context.append_dialect_registry(registry) + + # See if any dialects are missing, register if they are, and then assert they are all registered. + try: + for dialect in dialects: + # If the dialect is registered, continue checking + context.get_dialect_descriptor(dialect) + except Exception: + attempt_registration() + + for dialect in dialects: + # If the dialect is registered, continue checking + assert context.get_dialect_descriptor( + dialect + ), f"required dialect {dialect} not registered by registration_funcs" + + # See if any operations are missing and register if they are. We cannot + # assert the operations exist in the registry after for some reason. + # + # TODO: Make this work for dialect extensions specifically + for operation in extension_operations: + # If the operation is registered, attempt to register and then strongly assert it was added + if not context.is_registered_operation(operation): + attempt_registration() + break + for operation in extension_operations: + # First get the dialect descriptior which loads the dialect as a side effect + dialect = operation.split(".")[0] + assert context.get_dialect_descriptor(dialect), f"Never loaded {dialect}" + assert context.is_registered_operation( + operation + ), f"expected {operation} to be registered in its dialect" + context.enable_multithreading(multithreading) + + # Context manager related yield + try: + yield + finally: + pass + + +@contextmanager +def caller_mlir_context( + *, + required_dialects: Optional[Sequence[str]] = None, + required_extension_operations: Optional[Sequence[str]] = None, + registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None, +) -> Iterator[None]: + """Requires an enclosing context from the caller and ensures relevant operations are loaded. + + NOTE: If the Context is only needed inside of this contextmanager and returned values + don't need to the Context, use using_mlir_context! + + A context must already exist before this frame is executed to ensure that any values + continue to live on exit. Conceptually, this prevents use-after-free issues and + makes the intention clear when one intends to return values tied to a Context. + """ + assert ( + ir.Context.current + ), "Caller must have a context so it outlives this function call." + with using_mlir_context( + required_dialects=required_dialects, + required_extension_operations=required_extension_operations, + registration_funcs=registration_funcs, + ): + # Context manager related yield + try: + yield + finally: + pass + + +def with_toplevel_context(f: Callable[Param, RT]) -> Callable[Param, RT]: + """Decorate the function to be executed with a fresh MLIR context. + + This decorator will ensure the function is executed inside a context manager for a + new MLIR context with upstream and IREE dialects registered. Note that each call to + such a function has a new context, meaning that context-owned objects from these + functions will not be equal to each other. All arguments and keyword arguments are + forwarded. + + The context is destroyed before the function exits so any result from the function + must not depend on the context. + """ + + @wraps(f) + def decorator(*args: Param.args, **kwargs: Param.kwargs) -> RT: + # Appending dialect registry and loading all available dialects occur on + # context creation because of the "_site_initialize" call. + with ir.Context(), ir.Location.unknown(): + results = f(*args, **kwargs) + return results + + return decorator + + +def with_toplevel_context_create_module( + f: Callable[Concatenate[ir.Module, Param], RT], +) -> Callable[Param, RT]: + """Decorate function to be executed in a fresh MLIR context and give it a module. + + The decorated function will receive, as its leading argument, a fresh MLIR module. + The context manager is set up to insert operations into this module. All other + arguments and keyword arguments are forwarded. + + The module and context are destroyed before the function exists so any result from + the function must not depend on either. + """ + + @with_toplevel_context + @wraps(f) + def internal(*args: Param.args, **kwargs: Param.kwargs) -> RT: + module = ir.Module.create() + with ir.InsertionPoint(module.body): + results = f(module, *args, **kwargs) + return results + + return internal + + +def call_with_toplevel_context(f: Callable[[], RT]) -> Callable[[], RT]: + """Immediately call the function in a fresh MLIR context.""" + decorated = with_toplevel_context(f) + decorated() + return decorated + + +def call_with_toplevel_context_create_module( + f: Callable[[ir.Module], RT], +) -> Callable[[], RT]: + """Immediately call the function in a fresh MLIR context and give it a module. + + The decorated function will receive, as its only argument, a fresh MLIR module. The + context manager is set up to insert operations into this module. + """ + decorated = with_toplevel_context_create_module(f) + decorated() + return decorated diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py new file mode 100644 index 0000000..8435fdd --- /dev/null +++ b/mlir/test/python/utils.py @@ -0,0 +1,58 @@ +# RUN: %python %s | FileCheck %s + +import unittest + +from mlir import ir +from mlir.dialects import arith, builtin +from mlir.extras import types as T +from mlir.utils import ( + call_with_toplevel_context_create_module, + caller_mlir_context, + using_mlir_context, +) + + +class TestRequiredContext(unittest.TestCase): + def test_shared_context(self): + """Test that the context is reused, so values can be passed/returned between functions.""" + + @using_mlir_context() + def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value: + return arith.AddFOp( + lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf + ).result + + @using_mlir_context() + def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value: + return create_add(create_add(lhs, rhs), create_add(lhs, rhs)) + + @call_with_toplevel_context_create_module + def _(module) -> None: + c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result + multiple_adds(c, c) + + # CHECK: constant + # CHECK-NEXT: arith.addf + # CHECK-NEXT: arith.addf + # CHECK-NEXT: arith.addf + print(module) + + def test_unregistered_op_asserts(self): + """Confirm that with_mlir_context fails if an operation is still not registered.""" + with self.assertRaises(AssertionError), using_mlir_context( + required_extension_operations=["func.fake_extension_op"], + registration_funcs=[], + ): + pass + + def test_required_op_asserts(self): + """Confirm that with_mlir_context fails if an operation is still not registered.""" + with self.assertRaises(AssertionError), caller_mlir_context( + required_extension_operations=["func.fake_extension_op"], + registration_funcs=[], + ): + pass + + +if __name__ == "__main__": + unittest.main() |