aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bindings/Python')
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp3
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h2
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp34
3 files changed, 36 insertions, 3 deletions
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 32b2b0c..7b17106 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
block((*refOperation)->getBlock()) {}
+PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
+ : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
+
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
PyOperation &operation = operationBase.getOperation();
if (operation.isAttached())
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index edbd73e..e706be3b 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -841,6 +841,8 @@ public:
PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
+ /// Creates an insertion point positioned before a reference operation.
+ PyInsertionPoint(PyOperationRef beforeOperationRef);
/// Shortcut to create an insertion point at the beginning of the block.
static PyInsertionPoint atBlockBegin(PyBlock &block);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 836f44fd..9e3d970 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -26,6 +26,30 @@ using namespace mlir::python;
namespace {
+class PyPatternRewriter {
+public:
+ PyPatternRewriter(MlirPatternRewriter rewriter)
+ : base(mlirPatternRewriterAsBase(rewriter)),
+ ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+ PyInsertionPoint getInsertionPoint() const {
+ MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+ MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+ if (mlirOperationIsNull(op)) {
+ MlirOperation owner = mlirBlockGetParentOperation(block);
+ auto parent = PyOperation::forOperation(ctx, owner);
+ return PyInsertionPoint(PyBlock(parent, block));
+ }
+
+ return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+ }
+
+private:
+ MlirRewriterBase base;
+ PyMlirContextRef ctx;
+};
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +108,8 @@ public:
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ f(PyPatternRewriter(rewriter), results,
+ objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -98,7 +123,8 @@ public:
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ f(PyPatternRewriter(rewriter), results,
+ objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -143,7 +169,9 @@ private:
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
+ nb::class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.");
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------