diff options
Diffstat (limited to 'mlir/lib/Bindings/Python/Pass.cpp')
-rw-r--r-- | mlir/lib/Bindings/Python/Pass.cpp | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e489585..572afa9 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,6 +8,7 @@ #include "Pass.h" +#include "Globals.h" #include "IRModule.h" #include "mlir-c/Pass.h" // clang-format off @@ -57,6 +58,13 @@ private: /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- + // Mapping of enumerated types + //---------------------------------------------------------------------------- + nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode") + .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) + .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); + + //---------------------------------------------------------------------------- // Mapping of MlirExternalPass //---------------------------------------------------------------------------- nb::class_<MlirExternalPass>(m, "ExternalPass") @@ -138,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableTiming(passManager.get()); }, "Enable pass timing.") + .def( + "enable_statistics", + [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics(passManager.get(), displayMode); + }, + "displayMode"_a = + MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "Enable pass statistics.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -181,9 +197,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { name = nb::cast<std::string>( nb::borrow<nb::str>(run.attr("__name__"))); } - MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); - MlirTypeID passID = - mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); + MlirTypeID passID = PyGlobals::get().allocateTypeID(); MlirExternalPassCallbacks callbacks; callbacks.construct = [](void *obj) { (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref(); |