diff options
Diffstat (limited to 'mlir/lib/Bindings/Python/IRModules.cpp')
-rw-r--r-- | mlir/lib/Bindings/Python/IRModules.cpp | 178 |
1 files changed, 152 insertions, 26 deletions
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index d7a0bd8..66e975e 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -174,13 +174,12 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) { // PyMlirContext //------------------------------------------------------------------------------ -PyMlirContext *PyMlirContextRef::release() { - object.release(); - return &referrent; +PyMlirContext::PyMlirContext(MlirContext context) : context(context) { + py::gil_scoped_acquire acquire; + auto &liveContexts = getLiveContexts(); + liveContexts[context.ptr] = this; } -PyMlirContext::PyMlirContext(MlirContext context) : context(context) {} - PyMlirContext::~PyMlirContext() { // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into @@ -190,6 +189,11 @@ PyMlirContext::~PyMlirContext() { mlirContextDestroy(context); } +PyMlirContext *PyMlirContext::createNewContextForInit() { + MlirContext context = mlirContextCreate(); + return new PyMlirContext(context); +} + PyMlirContextRef PyMlirContext::forContext(MlirContext context) { py::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); @@ -198,14 +202,13 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { // Create. PyMlirContext *unownedContextWrapper = new PyMlirContext(context); py::object pyRef = py::cast(unownedContextWrapper); - unownedContextWrapper->handle = pyRef; - liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper); - return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef)); - } else { - // Use existing. - py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); - return PyMlirContextRef(*it->second.second, std::move(pyRef)); + assert(pyRef && "cast to py::object failed"); + liveContexts[context.ptr] = unownedContextWrapper; + return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } + // Use existing. + py::object pyRef = py::cast(it->second); + return PyMlirContextRef(it->second, std::move(pyRef)); } PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { @@ -215,8 +218,99 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } +size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } + +//------------------------------------------------------------------------------ +// PyModule +//------------------------------------------------------------------------------ + +PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) { + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + py::object pyRef = + py::cast(unownedModule, py::return_value_policy::take_ownership); + unownedModule->handle = pyRef; + return PyModuleRef(unownedModule, std::move(pyRef)); +} + +//------------------------------------------------------------------------------ +// PyOperation +//------------------------------------------------------------------------------ + +PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) + : BaseContextObject(std::move(contextRef)), operation(operation) {} + +PyOperation::~PyOperation() { + auto &liveOperations = getContext()->liveOperations; + assert(liveOperations.count(operation.ptr) == 1 && + "destroying operation not in live map"); + liveOperations.erase(operation.ptr); + if (!isAttached()) { + mlirOperationDestroy(operation); + } +} + +PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, + MlirOperation operation, + py::object parentKeepAlive) { + auto &liveOperations = contextRef->liveOperations; + // Create. + PyOperation *unownedOperation = + new PyOperation(std::move(contextRef), operation); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + py::object pyRef = + py::cast(unownedOperation, py::return_value_policy::take_ownership); + unownedOperation->handle = pyRef; + if (parentKeepAlive) { + unownedOperation->parentKeepAlive = std::move(parentKeepAlive); + } + liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); + return PyOperationRef(unownedOperation, std::move(pyRef)); +} + +PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, + MlirOperation operation, + py::object parentKeepAlive) { + auto &liveOperations = contextRef->liveOperations; + auto it = liveOperations.find(operation.ptr); + if (it == liveOperations.end()) { + // Create. + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + } + // Use existing. + PyOperation *existing = it->second.second; + assert(existing->parentKeepAlive.is(parentKeepAlive)); + py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); + return PyOperationRef(existing, std::move(pyRef)); +} + +PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, + MlirOperation operation, + py::object parentKeepAlive) { + auto &liveOperations = contextRef->liveOperations; + assert(liveOperations.count(operation.ptr) == 0 && + "cannot create detached operation that already exists"); + (void)liveOperations; + + PyOperationRef created = createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + created->attached = false; + return created; +} + +void PyOperation::checkValid() { + if (!valid) { + throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); + } +} + //------------------------------------------------------------------------------ -// PyBlock, PyRegion, and PyOperation. +// PyBlock, PyRegion. //------------------------------------------------------------------------------ void PyRegion::attachToParent() { @@ -865,29 +959,27 @@ public: void mlir::python::populateIRSubmodule(py::module &m) { // Mapping of MlirContext py::class_<PyMlirContext>(m, "Context") - .def(py::init<>([]() { - MlirContext context = mlirContextCreate(); - auto contextRef = PyMlirContext::forContext(context); - return contextRef.release(); - })) + .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { - auto ref = PyMlirContext::forContext(self.get()); - return ref.release(); + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); }) + .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) .def( "parse_module", - [](PyMlirContext &self, const std::string module) { - auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str()); + [](PyMlirContext &self, const std::string moduleAsm) { + MlirModule module = + mlirModuleCreateParse(self.get(), moduleAsm.c_str()); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. - if (mlirModuleIsNull(moduleRef)) { + if (mlirModuleIsNull(module)) { throw SetPyError( PyExc_ValueError, "Unable to parse module assembly (see diagnostics)"); } - return PyModule(self.getRef(), moduleRef); + return PyModule::create(self.getRef(), module).releaseObject(); }, kContextParseDocstring) .def( @@ -975,16 +1067,25 @@ void mlir::python::populateIRSubmodule(py::module &m) { // Mapping of Module py::class_<PyModule>(m, "Module") + .def_property_readonly( + "operation", + [](PyModule &self) { + return PyOperation::forOperation(self.getContext(), + mlirModuleGetOperation(self.get()), + self.getRef().releaseObject()) + .releaseObject(); + }, + "Accesses the module as an operation") .def( "dump", [](PyModule &self) { - mlirOperationDump(mlirModuleGetOperation(self.module)); + mlirOperationDump(mlirModuleGetOperation(self.get())); }, kDumpDocstring) .def( "__str__", [](PyModule &self) { - auto operation = mlirModuleGetOperation(self.module); + MlirOperation operation = mlirModuleGetOperation(self.get()); PyPrintAccumulator printAccum; mlirOperationPrint(operation, printAccum.getCallback(), printAccum.getUserData()); @@ -992,6 +1093,31 @@ void mlir::python::populateIRSubmodule(py::module &m) { }, kOperationStrDunderDocstring); + // Mapping of Operation. + py::class_<PyOperation>(m, "Operation") + .def_property_readonly( + "first_region", + [](PyOperation &self) { + self.checkValid(); + if (mlirOperationGetNumRegions(self.get()) == 0) { + throw SetPyError(PyExc_IndexError, "Operation has no regions"); + } + return PyRegion(self.getContext()->get(), + mlirOperationGetRegion(self.get(), 0), + /*detached=*/false); + }, + py::keep_alive<0, 1>(), "Gets the operation's first region") + .def( + "__str__", + [](PyOperation &self) { + self.checkValid(); + PyPrintAccumulator printAccum; + mlirOperationPrint(self.get(), printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + kTypeStrDunderDocstring); + // Mapping of PyRegion. py::class_<PyRegion>(m, "Region") .def( |