aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/MainModule.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bindings/Python/MainModule.cpp')
-rw-r--r--mlir/lib/Bindings/Python/MainModule.cpp148
1 files changed, 36 insertions, 112 deletions
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a14f09f..88f58d4 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,129 +6,39 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void populateIRAffine(nb::module_ &m);
+void populateIRAttributes(nb::module_ &m);
+void populateIRInterfaces(nb::module_ &m);
+void populateIRTypes(nb::module_ &m);
+void populateIRCore(nb::module_ &m);
+void populateRoot(nb::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-
NB_MODULE(_mlir, m) {
- m.doc() = "MLIR Python Native Extension";
-
- nb::class_<PyGlobals>(m, "_Globals")
- .def_prop_rw("dialect_search_modules",
- &PyGlobals::getDialectSearchPrefixes,
- &PyGlobals::setDialectSearchPrefixes)
- .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
- "module_name"_a)
- .def(
- "_check_dialect_module_loaded",
- [](PyGlobals &self, const std::string &dialectNamespace) {
- return self.loadDialectModule(dialectNamespace);
- },
- "dialect_namespace"_a)
- .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
- "dialect_namespace"_a, "dialect_class"_a,
- "Testing hook for directly registering a dialect")
- .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a, nb::kw_only(),
- "replace"_a = false,
- "Testing hook for directly registering an operation")
- .def("loc_tracebacks_enabled",
- [](PyGlobals &self) {
- return self.getTracebackLoc().locTracebacksEnabled();
- })
- .def("set_loc_tracebacks_enabled",
- [](PyGlobals &self, bool enabled) {
- self.getTracebackLoc().setLocTracebacksEnabled(enabled);
- })
- .def("loc_tracebacks_frame_limit",
- [](PyGlobals &self) {
- return self.getTracebackLoc().locTracebackFramesLimit();
- })
- .def("set_loc_tracebacks_frame_limit",
- [](PyGlobals &self, std::optional<int> n) {
- self.getTracebackLoc().setLocTracebackFramesLimit(
- n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
- })
- .def("register_traceback_file_inclusion",
- [](PyGlobals &self, const std::string &filename) {
- self.getTracebackLoc().registerTracebackFileInclusion(filename);
- })
- .def("register_traceback_file_exclusion",
- [](PyGlobals &self, const std::string &filename) {
- self.getTracebackLoc().registerTracebackFileExclusion(filename);
- });
-
- // Aside from making the globals accessible to python, having python manage
- // it is necessary to make sure it is destroyed (and releases its python
- // resources) properly.
- m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
-
- // Registration decorators.
- m.def(
- "register_dialect",
- [](nb::type_object pyClass) {
- std::string dialectNamespace =
- nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
- PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
- return pyClass;
- },
- "dialect_class"_a,
- "Class decorator for registering a custom Dialect wrapper");
- m.def(
- "register_operation",
- [](const nb::type_object &dialectClass, bool replace) -> nb::object {
- return nb::cpp_function(
- [dialectClass,
- replace](nb::type_object opClass) -> nb::type_object {
- std::string operationName =
- nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
- PyGlobals::get().registerOperationImpl(operationName, opClass,
- replace);
- // Dict-stuff the new opClass by name onto the dialect class.
- nb::object opClassName = opClass.attr("__name__");
- dialectClass.attr(opClassName) = opClass;
- return opClass;
- });
- },
- "dialect_class"_a, nb::kw_only(), "replace"_a = false,
- "Produce a class decorator for registering an Operation class as part of "
- "a dialect");
- m.def(
- MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
- return nb::cpp_function([mlirTypeID, replace](
- nb::callable typeCaster) -> nb::object {
- PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
- return typeCaster;
- });
- },
- "typeid"_a, nb::kw_only(), "replace"_a = false,
- "Register a type caster for casting MLIR types to custom user types.");
- m.def(
- MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
- return nb::cpp_function(
- [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
- PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
- replace);
- return valueCaster;
- });
- },
- "typeid"_a, nb::kw_only(), "replace"_a = false,
- "Register a value caster for casting MLIR values to custom user values.");
+ // disable leak warnings which tend to be false positives.
+ nb::set_leak_warnings(false);
+ m.doc() = "MLIR Python Native Extension";
+ populateRoot(m);
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRCore(irModule);
@@ -144,4 +54,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
+ nanobind::register_exception_translator(
+ [](const std::exception_ptr &p, void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nanobind::object obj =
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}