aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/IRModules.cpp
diff options
context:
space:
mode:
authorStella Laurenzo <stellaraccident@gmail.com>2020-09-18 18:38:21 -0700
committerStella Laurenzo <stellaraccident@gmail.com>2020-09-23 07:57:50 -0700
commit7abb0ff7e0419a9554d77e9108cb7da670b7471c (patch)
tree404fc37e1df9dde9bfb37b48057d3d4bbfb9159c /mlir/lib/Bindings/Python/IRModules.cpp
parentbd8b50cd7f5dd5237ec9187ef2fcea3adc15b61a (diff)
downloadllvm-7abb0ff7e0419a9554d77e9108cb7da670b7471c.zip
llvm-7abb0ff7e0419a9554d77e9108cb7da670b7471c.tar.gz
llvm-7abb0ff7e0419a9554d77e9108cb7da670b7471c.tar.bz2
Add Operation to python bindings.
* Fixes a rather egregious bug with respect to the inability to return arbitrary objects from py::init (was causing aliasing of multiple py::object -> native instance). * Makes Modules and Operations referencable types so that they can be reliably depended on. * Uniques python operation instances within a context. Opens the door for further accounting. * Next I will retrofit region and block to be dependent on the operation, and I will attempt to model the API to avoid detached regions/blocks, which will simplify things a lot (in that world, only operations can be detached). * Added quite a bit of test coverage to check for leaks and reference issues. * Supercedes: https://reviews.llvm.org/D87213 Differential Revision: https://reviews.llvm.org/D87958
Diffstat (limited to 'mlir/lib/Bindings/Python/IRModules.cpp')
-rw-r--r--mlir/lib/Bindings/Python/IRModules.cpp178
1 files changed, 152 insertions, 26 deletions
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index d7a0bd8..66e975e 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -174,13 +174,12 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
// PyMlirContext
//------------------------------------------------------------------------------
-PyMlirContext *PyMlirContextRef::release() {
- object.release();
- return &referrent;
+PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
+ py::gil_scoped_acquire acquire;
+ auto &liveContexts = getLiveContexts();
+ liveContexts[context.ptr] = this;
}
-PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}
-
PyMlirContext::~PyMlirContext() {
// Note that the only public way to construct an instance is via the
// forContext method, which always puts the associated handle into
@@ -190,6 +189,11 @@ PyMlirContext::~PyMlirContext() {
mlirContextDestroy(context);
}
+PyMlirContext *PyMlirContext::createNewContextForInit() {
+ MlirContext context = mlirContextCreate();
+ return new PyMlirContext(context);
+}
+
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
py::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
@@ -198,14 +202,13 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
// Create.
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
py::object pyRef = py::cast(unownedContextWrapper);
- unownedContextWrapper->handle = pyRef;
- liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
- return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
- } else {
- // Use existing.
- py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
- return PyMlirContextRef(*it->second.second, std::move(pyRef));
+ assert(pyRef && "cast to py::object failed");
+ liveContexts[context.ptr] = unownedContextWrapper;
+ return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
}
+ // Use existing.
+ py::object pyRef = py::cast(it->second);
+ return PyMlirContextRef(it->second, std::move(pyRef));
}
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
@@ -215,8 +218,99 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+
+//------------------------------------------------------------------------------
+// PyModule
+//------------------------------------------------------------------------------
+
+PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ py::object pyRef =
+ py::cast(unownedModule, py::return_value_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ return PyModuleRef(unownedModule, std::move(pyRef));
+}
+
+//------------------------------------------------------------------------------
+// PyOperation
+//------------------------------------------------------------------------------
+
+PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
+ : BaseContextObject(std::move(contextRef)), operation(operation) {}
+
+PyOperation::~PyOperation() {
+ auto &liveOperations = getContext()->liveOperations;
+ assert(liveOperations.count(operation.ptr) == 1 &&
+ "destroying operation not in live map");
+ liveOperations.erase(operation.ptr);
+ if (!isAttached()) {
+ mlirOperationDestroy(operation);
+ }
+}
+
+PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
+ MlirOperation operation,
+ py::object parentKeepAlive) {
+ auto &liveOperations = contextRef->liveOperations;
+ // Create.
+ PyOperation *unownedOperation =
+ new PyOperation(std::move(contextRef), operation);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ py::object pyRef =
+ py::cast(unownedOperation, py::return_value_policy::take_ownership);
+ unownedOperation->handle = pyRef;
+ if (parentKeepAlive) {
+ unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
+ }
+ liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
+ return PyOperationRef(unownedOperation, std::move(pyRef));
+}
+
+PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
+ MlirOperation operation,
+ py::object parentKeepAlive) {
+ auto &liveOperations = contextRef->liveOperations;
+ auto it = liveOperations.find(operation.ptr);
+ if (it == liveOperations.end()) {
+ // Create.
+ return createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
+ }
+ // Use existing.
+ PyOperation *existing = it->second.second;
+ assert(existing->parentKeepAlive.is(parentKeepAlive));
+ py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+ return PyOperationRef(existing, std::move(pyRef));
+}
+
+PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
+ MlirOperation operation,
+ py::object parentKeepAlive) {
+ auto &liveOperations = contextRef->liveOperations;
+ assert(liveOperations.count(operation.ptr) == 0 &&
+ "cannot create detached operation that already exists");
+ (void)liveOperations;
+
+ PyOperationRef created = createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
+ created->attached = false;
+ return created;
+}
+
+void PyOperation::checkValid() {
+ if (!valid) {
+ throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
+ }
+}
+
//------------------------------------------------------------------------------
-// PyBlock, PyRegion, and PyOperation.
+// PyBlock, PyRegion.
//------------------------------------------------------------------------------
void PyRegion::attachToParent() {
@@ -865,29 +959,27 @@ public:
void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of MlirContext
py::class_<PyMlirContext>(m, "Context")
- .def(py::init<>([]() {
- MlirContext context = mlirContextCreate();
- auto contextRef = PyMlirContext::forContext(context);
- return contextRef.release();
- }))
+ .def(py::init<>(&PyMlirContext::createNewContextForInit))
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
[](PyMlirContext &self) {
- auto ref = PyMlirContext::forContext(self.get());
- return ref.release();
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
})
+ .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def(
"parse_module",
- [](PyMlirContext &self, const std::string module) {
- auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
+ [](PyMlirContext &self, const std::string moduleAsm) {
+ MlirModule module =
+ mlirModuleCreateParse(self.get(), moduleAsm.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
- if (mlirModuleIsNull(moduleRef)) {
+ if (mlirModuleIsNull(module)) {
throw SetPyError(
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
- return PyModule(self.getRef(), moduleRef);
+ return PyModule::create(self.getRef(), module).releaseObject();
},
kContextParseDocstring)
.def(
@@ -975,16 +1067,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Module
py::class_<PyModule>(m, "Module")
+ .def_property_readonly(
+ "operation",
+ [](PyModule &self) {
+ return PyOperation::forOperation(self.getContext(),
+ mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject())
+ .releaseObject();
+ },
+ "Accesses the module as an operation")
.def(
"dump",
[](PyModule &self) {
- mlirOperationDump(mlirModuleGetOperation(self.module));
+ mlirOperationDump(mlirModuleGetOperation(self.get()));
},
kDumpDocstring)
.def(
"__str__",
[](PyModule &self) {
- auto operation = mlirModuleGetOperation(self.module);
+ MlirOperation operation = mlirModuleGetOperation(self.get());
PyPrintAccumulator printAccum;
mlirOperationPrint(operation, printAccum.getCallback(),
printAccum.getUserData());
@@ -992,6 +1093,31 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
kOperationStrDunderDocstring);
+ // Mapping of Operation.
+ py::class_<PyOperation>(m, "Operation")
+ .def_property_readonly(
+ "first_region",
+ [](PyOperation &self) {
+ self.checkValid();
+ if (mlirOperationGetNumRegions(self.get()) == 0) {
+ throw SetPyError(PyExc_IndexError, "Operation has no regions");
+ }
+ return PyRegion(self.getContext()->get(),
+ mlirOperationGetRegion(self.get(), 0),
+ /*detached=*/false);
+ },
+ py::keep_alive<0, 1>(), "Gets the operation's first region")
+ .def(
+ "__str__",
+ [](PyOperation &self) {
+ self.checkValid();
+ PyPrintAccumulator printAccum;
+ mlirOperationPrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ kTypeStrDunderDocstring);
+
// Mapping of PyRegion.
py::class_<PyRegion>(m, "Region")
.def(