diff options
Diffstat (limited to 'mlir/lib/Bindings/Python')
-rw-r--r-- | mlir/lib/Bindings/Python/IRCore.cpp | 3 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/IRModule.h | 2 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/Rewrite.cpp | 34 |
3 files changed, 36 insertions, 3 deletions
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 32b2b0c..7b17106 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) : refOperation(beforeOperationBase.getOperation().getRef()), block((*refOperation)->getBlock()) {} +PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef) + : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {} + void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index edbd73e..e706be3b 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -841,6 +841,8 @@ public: PyInsertionPoint(const PyBlock &block); /// Creates an insertion point positioned before a reference operation. PyInsertionPoint(PyOperationBase &beforeOperationBase); + /// Creates an insertion point positioned before a reference operation. + PyInsertionPoint(PyOperationRef beforeOperationRef); /// Shortcut to create an insertion point at the beginning of the block. static PyInsertionPoint atBlockBegin(PyBlock &block); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 836f44fd..9e3d970 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -26,6 +26,30 @@ using namespace mlir::python; namespace { +class PyPatternRewriter { +public: + PyPatternRewriter(MlirPatternRewriter rewriter) + : base(mlirPatternRewriterAsBase(rewriter)), + ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {} + + PyInsertionPoint getInsertionPoint() const { + MlirBlock block = mlirRewriterBaseGetInsertionBlock(base); + MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base); + + if (mlirOperationIsNull(op)) { + MlirOperation owner = mlirBlockGetParentOperation(block); + auto parent = PyOperation::forOperation(ctx, owner); + return PyInsertionPoint(PyBlock(parent, block)); + } + + return PyInsertionPoint(PyOperation::forOperation(ctx, op)); + } + +private: + MlirRewriterBase base; + PyMlirContextRef ctx; +}; + #if MLIR_ENABLE_PDL_IN_PATTERNMATCH static nb::object objectFromPDLValue(MlirPDLValue value) { if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) @@ -84,7 +108,8 @@ public: void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast<PyObject *>(userData)); return logicalResultFromObject( - f(rewriter, results, objectsFromPDLValues(nValues, values))); + f(PyPatternRewriter(rewriter), results, + objectsFromPDLValues(nValues, values))); }, fn.ptr()); } @@ -98,7 +123,8 @@ public: void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast<PyObject *>(userData)); return logicalResultFromObject( - f(rewriter, results, objectsFromPDLValues(nValues, values))); + f(PyPatternRewriter(rewriter), results, + objectsFromPDLValues(nValues, values))); }, fn.ptr()); } @@ -143,7 +169,9 @@ private: /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { - nb::class_<MlirPatternRewriter>(m, "PatternRewriter"); + nb::class_<PyPatternRewriter>(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter."); //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- |