aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/Rewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bindings/Python/Rewrite.cpp')
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp120
1 files changed, 116 insertions, 4 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d970..47685567 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,16 @@ public:
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+ void replaceOp(MlirOperation op, MlirOperation newOp) {
+ mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+ }
+
+ void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+ mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+ }
+
+ void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -165,13 +175,115 @@ private:
MlirFrozenRewritePatternSet set;
};
+class PyRewritePatternSet {
+public:
+ PyRewritePatternSet(MlirContext ctx)
+ : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+ ~PyRewritePatternSet() {
+ if (set.ptr)
+ mlirRewritePatternSetDestroy(set);
+ }
+
+ void add(MlirStringRef rootName, unsigned benefit,
+ const nb::callable &matchAndRewrite) {
+ MlirRewritePatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+ nb::object res = f(op, PyPatternRewriter(rewriter));
+ return logicalResultFromObject(res);
+ };
+ MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(set, pattern);
+ }
+
+ PyFrozenRewritePatternSet freeze() {
+ MlirRewritePatternSet s = set;
+ set.ptr = nullptr;
+ return mlirFreezeRewritePattern(s);
+ }
+
+private:
+ MlirRewritePatternSet set;
+ MlirContext ctx;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.");
+ //----------------------------------------------------------------------------
+ // Mapping of the PatternRewriter
+ //----------------------------------------------------------------------------
+ nb::
+ class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.")
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ MlirOperation newOp) { self.replaceOp(op, newOp); },
+ "Replace an operation with a new operation.", nb::arg("op"),
+ nb::arg("new_op"),
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ )
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ const std::vector<MlirValue> &values) {
+ self.replaceOp(op, values);
+ },
+ "Replace an operation with a list of values.", nb::arg("op"),
+ nb::arg("values"),
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+ // clang-format on
+ )
+ .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+ nb::arg("op"),
+ // clang-format off
+ nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ );
+
+ //----------------------------------------------------------------------------
+ // Mapping of the RewritePatternSet
+ //----------------------------------------------------------------------------
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ unsigned benefit) {
+ std::string opName =
+ nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+ fn);
+ },
+ "root"_a, "fn"_a, "benefit"_a = 1,
+ "Add a new rewrite pattern on the given root operation with the "
+ "callable as the matching and rewriting function and the given "
+ "benefit.")
+ .def("freeze", &PyRewritePatternSet::freeze,
+ "Freeze the pattern set into a frozen one.");
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())