diff options
Diffstat (limited to 'mlir/lib/Bindings/Python')
-rw-r--r-- | mlir/lib/Bindings/Python/Globals.h | 25 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/Pass.cpp | 20 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/Rewrite.cpp | 121 |
3 files changed, 159 insertions, 7 deletions
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 71a051c..1e81f53 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -17,6 +17,7 @@ #include "NanobindUtils.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" @@ -151,6 +152,29 @@ public: TracebackLoc &getTracebackLoc() { return tracebackLoc; } + class TypeIDAllocator { + public: + TypeIDAllocator() : allocator(mlirTypeIDAllocatorCreate()) {} + ~TypeIDAllocator() { + if (allocator.ptr) + mlirTypeIDAllocatorDestroy(allocator); + } + TypeIDAllocator(const TypeIDAllocator &) = delete; + TypeIDAllocator(TypeIDAllocator &&other) : allocator(other.allocator) { + other.allocator.ptr = nullptr; + } + + MlirTypeIDAllocator get() { return allocator; } + MlirTypeID allocate() { + return mlirTypeIDAllocatorAllocateTypeID(allocator); + } + + private: + MlirTypeIDAllocator allocator; + }; + + MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); } + private: static PyGlobals *instance; @@ -173,6 +197,7 @@ private: llvm::StringSet<> loadedDialectModules; TracebackLoc tracebackLoc; + TypeIDAllocator typeIDAllocator; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e489585..572afa9 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,6 +8,7 @@ #include "Pass.h" +#include "Globals.h" #include "IRModule.h" #include "mlir-c/Pass.h" // clang-format off @@ -57,6 +58,13 @@ private: /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- + // Mapping of enumerated types + //---------------------------------------------------------------------------- + nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode") + .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) + .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); + + //---------------------------------------------------------------------------- // Mapping of MlirExternalPass //---------------------------------------------------------------------------- nb::class_<MlirExternalPass>(m, "ExternalPass") @@ -138,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableTiming(passManager.get()); }, "Enable pass timing.") + .def( + "enable_statistics", + [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics(passManager.get(), displayMode); + }, + "displayMode"_a = + MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "Enable pass statistics.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -181,9 +197,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { name = nb::cast<std::string>( nb::borrow<nb::str>(run.attr("__name__"))); } - MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); - MlirTypeID passID = - mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); + MlirTypeID passID = PyGlobals::get().allocateTypeID(); MlirExternalPassCallbacks callbacks; callbacks.construct = [](void *obj) { (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref(); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9e3d970..d506b7f 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -45,6 +45,16 @@ public: return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + void replaceOp(MlirOperation op, MlirOperation newOp) { + mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); + } + + void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) { + mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); + } + + void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -165,13 +175,116 @@ private: MlirFrozenRewritePatternSet set; }; +class PyRewritePatternSet { +public: + PyRewritePatternSet(MlirContext ctx) + : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} + ~PyRewritePatternSet() { + if (set.ptr) + mlirRewritePatternSetDestroy(set); + } + + void add(MlirStringRef rootName, unsigned benefit, + const nb::callable &matchAndRewrite) { + MlirRewritePatternCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).dec_ref(); + }; + callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, + MlirPatternRewriter rewriter, + void *userData) -> MlirLogicalResult { + nb::handle f(static_cast<PyObject *>(userData)); + nb::object res = f(op, PyPatternRewriter(rewriter)); + return logicalResultFromObject(res); + }; + MlirRewritePattern pattern = mlirOpRewritePattenCreate( + rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), + /* nGeneratedNames */ 0, + /* generatedNames */ nullptr); + mlirRewritePatternSetAdd(set, pattern); + } + + PyFrozenRewritePatternSet freeze() { + MlirRewritePatternSet s = set; + set.ptr = nullptr; + return mlirFreezeRewritePattern(s); + } + +private: + MlirRewritePatternSet set; + MlirContext ctx; +}; + } // namespace /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { - nb::class_<PyPatternRewriter>(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter."); + //---------------------------------------------------------------------------- + // Mapping of the PatternRewriter + //---------------------------------------------------------------------------- + nb:: + class_<PyPatternRewriter>(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }, + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ) + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + const std::vector<MlirValue> &values) { + self.replaceOp(op, values); + }, + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + // clang-format on + ) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op"), + // clang-format off + nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ); + + //---------------------------------------------------------------------------- + // Mapping of the RewritePatternSet + //---------------------------------------------------------------------------- + nb::class_<MlirRewritePattern>(m, "RewritePattern"); + nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") + .def( + "__init__", + [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { + new (&self) PyRewritePatternSet(context.get()->get()); + }, + "context"_a = nb::none()) + .def( + "add", + [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, + unsigned benefit) { + std::string opName = + nb::cast<std::string>(root.attr("OPERATION_NAME")); + self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, + fn); + }, + "root"_a, "fn"_a, "benefit"_a = 1, + "Add a new rewrite pattern on the given root operation with the " + "callable as the matching and rewriting function and the given " + "benefit.") + .def("freeze", &PyRewritePatternSet::freeze, + "Freeze the pattern set into a frozen one."); + //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- @@ -237,7 +350,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def( "freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) |