aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/Pass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bindings/Python/Pass.cpp')
-rw-r--r--mlir/lib/Bindings/Python/Pass.cpp20
1 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index e489585..572afa9 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,6 +8,7 @@
#include "Pass.h"
+#include "Globals.h"
#include "IRModule.h"
#include "mlir-c/Pass.h"
// clang-format off
@@ -57,6 +58,13 @@ private:
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
+ // Mapping of enumerated types
+ //----------------------------------------------------------------------------
+ nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
+ .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
+ .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
+
+ //----------------------------------------------------------------------------
// Mapping of MlirExternalPass
//----------------------------------------------------------------------------
nb::class_<MlirExternalPass>(m, "ExternalPass")
@@ -138,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
mlirPassManagerEnableTiming(passManager.get());
},
"Enable pass timing.")
+ .def(
+ "enable_statistics",
+ [](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
+ mlirPassManagerEnableStatistics(passManager.get(), displayMode);
+ },
+ "displayMode"_a =
+ MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
+ "Enable pass statistics.")
.def_static(
"parse",
[](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -181,9 +197,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
name = nb::cast<std::string>(
nb::borrow<nb::str>(run.attr("__name__")));
}
- MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
- MlirTypeID passID =
- mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirTypeID passID = PyGlobals::get().allocateTypeID();
MlirExternalPassCallbacks callbacks;
callbacks.construct = [](void *obj) {
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();