diff options
Diffstat (limited to 'mlir/lib/Bindings/Python')
-rw-r--r-- | mlir/lib/Bindings/Python/Globals.h | 25 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/Pass.cpp | 20 |
2 files changed, 42 insertions, 3 deletions
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 71a051c..1e81f53 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -17,6 +17,7 @@ #include "NanobindUtils.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" @@ -151,6 +152,29 @@ public: TracebackLoc &getTracebackLoc() { return tracebackLoc; } + class TypeIDAllocator { + public: + TypeIDAllocator() : allocator(mlirTypeIDAllocatorCreate()) {} + ~TypeIDAllocator() { + if (allocator.ptr) + mlirTypeIDAllocatorDestroy(allocator); + } + TypeIDAllocator(const TypeIDAllocator &) = delete; + TypeIDAllocator(TypeIDAllocator &&other) : allocator(other.allocator) { + other.allocator.ptr = nullptr; + } + + MlirTypeIDAllocator get() { return allocator; } + MlirTypeID allocate() { + return mlirTypeIDAllocatorAllocateTypeID(allocator); + } + + private: + MlirTypeIDAllocator allocator; + }; + + MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); } + private: static PyGlobals *instance; @@ -173,6 +197,7 @@ private: llvm::StringSet<> loadedDialectModules; TracebackLoc tracebackLoc; + TypeIDAllocator typeIDAllocator; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e489585..572afa9 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,6 +8,7 @@ #include "Pass.h" +#include "Globals.h" #include "IRModule.h" #include "mlir-c/Pass.h" // clang-format off @@ -57,6 +58,13 @@ private: /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- + // Mapping of enumerated types + //---------------------------------------------------------------------------- + nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode") + .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) + .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); + + //---------------------------------------------------------------------------- // Mapping of MlirExternalPass //---------------------------------------------------------------------------- nb::class_<MlirExternalPass>(m, "ExternalPass") @@ -138,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableTiming(passManager.get()); }, "Enable pass timing.") + .def( + "enable_statistics", + [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics(passManager.get(), displayMode); + }, + "displayMode"_a = + MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "Enable pass statistics.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -181,9 +197,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { name = nb::cast<std::string>( nb::borrow<nb::str>(run.attr("__name__"))); } - MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); - MlirTypeID passID = - mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); + MlirTypeID passID = PyGlobals::get().allocateTypeID(); MlirExternalPassCallbacks callbacks; callbacks.construct = [](void *obj) { (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref(); |