aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/Rewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bindings/Python/Rewrite.cpp')
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp34
1 files changed, 31 insertions, 3 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 836f44fd..9e3d970 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -26,6 +26,30 @@ using namespace mlir::python;
namespace {
+class PyPatternRewriter {
+public:
+ PyPatternRewriter(MlirPatternRewriter rewriter)
+ : base(mlirPatternRewriterAsBase(rewriter)),
+ ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+ PyInsertionPoint getInsertionPoint() const {
+ MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+ MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+ if (mlirOperationIsNull(op)) {
+ MlirOperation owner = mlirBlockGetParentOperation(block);
+ auto parent = PyOperation::forOperation(ctx, owner);
+ return PyInsertionPoint(PyBlock(parent, block));
+ }
+
+ return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+ }
+
+private:
+ MlirRewriterBase base;
+ PyMlirContextRef ctx;
+};
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +108,8 @@ public:
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ f(PyPatternRewriter(rewriter), results,
+ objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -98,7 +123,8 @@ public:
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ f(PyPatternRewriter(rewriter), results,
+ objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -143,7 +169,9 @@ private:
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
+ nb::class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.");
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------