aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/IRModules.cpp
diff options
context:
space:
mode:
authorStella Laurenzo <stellaraccident@gmail.com>2020-09-19 22:02:32 -0700
committerStella Laurenzo <stellaraccident@gmail.com>2020-09-23 07:57:50 -0700
commit4cf754c4bca94e957b634a854f57f4c7ec9151fb (patch)
tree951b1973397c61a6918148b32a57768a8135fd57 /mlir/lib/Bindings/Python/IRModules.cpp
parent7abb0ff7e0419a9554d77e9108cb7da670b7471c (diff)
downloadllvm-4cf754c4bca94e957b634a854f57f4c7ec9151fb.zip
llvm-4cf754c4bca94e957b634a854f57f4c7ec9151fb.tar.gz
llvm-4cf754c4bca94e957b634a854f57f4c7ec9151fb.tar.bz2
Implement python iteration over the operation/region/block hierarchy.
* Removes the half-completed prior attempt at region/block mutation in favor of new approach to ownership. * Will re-add mutation more correctly in a follow-on. * Eliminates the detached state on blocks and regions, simplifying the ownership hierarchy. * Adds both iterator and index based access at each level. Differential Revision: https://reviews.llvm.org/D87982
Diffstat (limited to 'mlir/lib/Bindings/Python/IRModules.cpp')
-rw-r--r--mlir/lib/Bindings/Python/IRModules.cpp440
1 files changed, 298 insertions, 142 deletions
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 66e975e..8eab7da 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -46,45 +46,6 @@ static const char kContextGetUnknownLocationDocstring[] =
static const char kContextGetFileLocationDocstring[] =
R"(Gets a Location representing a file, line and column)";
-static const char kContextCreateBlockDocstring[] =
- R"(Creates a detached block)";
-
-static const char kContextCreateRegionDocstring[] =
- R"(Creates a detached region)";
-
-static const char kRegionAppendBlockDocstring[] =
- R"(Appends a block to a region.
-
-Raises:
- ValueError: If the block is already attached to another region.
-)";
-
-static const char kRegionInsertBlockDocstring[] =
- R"(Inserts a block at a postiion in a region.
-
-Raises:
- ValueError: If the block is already attached to another region.
-)";
-
-static const char kRegionFirstBlockDocstring[] =
- R"(Gets the first block in a region.
-
-Blocks can also be accessed via the `blocks` container.
-
-Raises:
- IndexError: If the region has no blocks.
-)";
-
-static const char kBlockNextInRegionDocstring[] =
- R"(Gets the next block in the enclosing region.
-
-Blocks can also be accessed via the `blocks` container of the owning region.
-This method exists to mirror the lower level API and should not be preferred.
-
-Raises:
- IndexError: If there are no further blocks.
-)";
-
static const char kOperationStrDunderDocstring[] =
R"(Prints the assembly form of the operation with default options.
@@ -171,6 +132,241 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
} // namespace
//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+namespace {
+
+class PyRegionIterator {
+public:
+ PyRegionIterator(PyOperationRef operation)
+ : operation(std::move(operation)) {}
+
+ PyRegionIterator &dunderIter() { return *this; }
+
+ PyRegion dunderNext() {
+ operation->checkValid();
+ if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+ throw py::stop_iteration();
+ }
+ MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+ return PyRegion(operation, region);
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyRegionIterator>(m, "RegionIterator")
+ .def("__iter__", &PyRegionIterator::dunderIter)
+ .def("__next__", &PyRegionIterator::dunderNext);
+ }
+
+private:
+ PyOperationRef operation;
+ int nextIndex = 0;
+};
+
+/// Regions of an op are fixed length and indexed numerically so are represented
+/// with a sequence-like container.
+class PyRegionList {
+public:
+ PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+
+ intptr_t dunderLen() {
+ operation->checkValid();
+ return mlirOperationGetNumRegions(operation->get());
+ }
+
+ PyRegion dunderGetItem(intptr_t index) {
+ // dunderLen checks validity.
+ if (index < 0 || index >= dunderLen()) {
+ throw SetPyError(PyExc_IndexError,
+ "attempt to access out of bounds region");
+ }
+ MlirRegion region = mlirOperationGetRegion(operation->get(), index);
+ return PyRegion(operation, region);
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyRegionList>(m, "ReqionSequence")
+ .def("__len__", &PyRegionList::dunderLen)
+ .def("__getitem__", &PyRegionList::dunderGetItem);
+ }
+
+private:
+ PyOperationRef operation;
+};
+
+class PyBlockIterator {
+public:
+ PyBlockIterator(PyOperationRef operation, MlirBlock next)
+ : operation(std::move(operation)), next(next) {}
+
+ PyBlockIterator &dunderIter() { return *this; }
+
+ PyBlock dunderNext() {
+ operation->checkValid();
+ if (mlirBlockIsNull(next)) {
+ throw py::stop_iteration();
+ }
+
+ PyBlock returnBlock(operation, next);
+ next = mlirBlockGetNextInRegion(next);
+ return returnBlock;
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyBlockIterator>(m, "BlockIterator")
+ .def("__iter__", &PyBlockIterator::dunderIter)
+ .def("__next__", &PyBlockIterator::dunderNext);
+ }
+
+private:
+ PyOperationRef operation;
+ MlirBlock next;
+};
+
+/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
+/// we present them as a more full-featured list-like container but optimzie
+/// it for forward iteration. Blocks are always owned by a region.
+class PyBlockList {
+public:
+ PyBlockList(PyOperationRef operation, MlirRegion region)
+ : operation(std::move(operation)), region(region) {}
+
+ PyBlockIterator dunderIter() {
+ operation->checkValid();
+ return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+ }
+
+ intptr_t dunderLen() {
+ operation->checkValid();
+ intptr_t count = 0;
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ while (!mlirBlockIsNull(block)) {
+ count += 1;
+ block = mlirBlockGetNextInRegion(block);
+ }
+ return count;
+ }
+
+ PyBlock dunderGetItem(intptr_t index) {
+ operation->checkValid();
+ if (index < 0) {
+ throw SetPyError(PyExc_IndexError,
+ "attempt to access out of bounds block");
+ }
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ while (!mlirBlockIsNull(block)) {
+ if (index == 0) {
+ return PyBlock(operation, block);
+ }
+ block = mlirBlockGetNextInRegion(block);
+ index -= 1;
+ }
+ throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyBlockList>(m, "BlockList")
+ .def("__getitem__", &PyBlockList::dunderGetItem)
+ .def("__iter__", &PyBlockList::dunderIter)
+ .def("__len__", &PyBlockList::dunderLen);
+ }
+
+private:
+ PyOperationRef operation;
+ MlirRegion region;
+};
+
+class PyOperationIterator {
+public:
+ PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+ : parentOperation(std::move(parentOperation)), next(next) {}
+
+ PyOperationIterator &dunderIter() { return *this; }
+
+ py::object dunderNext() {
+ parentOperation->checkValid();
+ if (mlirOperationIsNull(next)) {
+ throw py::stop_iteration();
+ }
+
+ PyOperationRef returnOperation =
+ PyOperation::forOperation(parentOperation->getContext(), next);
+ next = mlirOperationGetNextInBlock(next);
+ return returnOperation.releaseObject();
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyOperationIterator>(m, "OperationIterator")
+ .def("__iter__", &PyOperationIterator::dunderIter)
+ .def("__next__", &PyOperationIterator::dunderNext);
+ }
+
+private:
+ PyOperationRef parentOperation;
+ MlirOperation next;
+};
+
+/// Operations are exposed by the C-API as a forward-only linked list. In
+/// Python, we present them as a more full-featured list-like container but
+/// optimzie it for forward iteration. Iterable operations are always owned
+/// by a block.
+class PyOperationList {
+public:
+ PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+ : parentOperation(std::move(parentOperation)), block(block) {}
+
+ PyOperationIterator dunderIter() {
+ parentOperation->checkValid();
+ return PyOperationIterator(parentOperation,
+ mlirBlockGetFirstOperation(block));
+ }
+
+ intptr_t dunderLen() {
+ parentOperation->checkValid();
+ intptr_t count = 0;
+ MlirOperation childOp = mlirBlockGetFirstOperation(block);
+ while (!mlirOperationIsNull(childOp)) {
+ count += 1;
+ childOp = mlirOperationGetNextInBlock(childOp);
+ }
+ return count;
+ }
+
+ py::object dunderGetItem(intptr_t index) {
+ parentOperation->checkValid();
+ if (index < 0) {
+ throw SetPyError(PyExc_IndexError,
+ "attempt to access out of bounds operation");
+ }
+ MlirOperation childOp = mlirBlockGetFirstOperation(block);
+ while (!mlirOperationIsNull(childOp)) {
+ if (index == 0) {
+ return PyOperation::forOperation(parentOperation->getContext(), childOp)
+ .releaseObject();
+ }
+ childOp = mlirOperationGetNextInBlock(childOp);
+ index -= 1;
+ }
+ throw SetPyError(PyExc_IndexError,
+ "attempt to access out of bounds operation");
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyOperationList>(m, "OperationList")
+ .def("__getitem__", &PyOperationList::dunderGetItem)
+ .def("__iter__", &PyOperationList::dunderIter)
+ .def("__len__", &PyOperationList::dunderLen);
+ }
+
+private:
+ PyOperationRef parentOperation;
+ MlirBlock block;
+};
+
+} // namespace
+
+//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -310,24 +506,6 @@ void PyOperation::checkValid() {
}
//------------------------------------------------------------------------------
-// PyBlock, PyRegion.
-//------------------------------------------------------------------------------
-
-void PyRegion::attachToParent() {
- if (!detached) {
- throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
- }
- detached = false;
-}
-
-void PyBlock::attachToParent() {
- if (!detached) {
- throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
- }
- detached = false;
-}
-
-//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
@@ -967,6 +1145,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+ .def_property(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ })
.def(
"parse_module",
[](PyMlirContext &self, const std::string moduleAsm) {
@@ -1026,37 +1212,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
self.get(), filename.c_str(), line, col));
},
kContextGetFileLocationDocstring, py::arg("filename"),
- py::arg("line"), py::arg("col"))
- .def(
- "create_region",
- [](PyMlirContext &self) {
- // The creating context is explicitly captured on regions to
- // facilitate illegal assemblies of objects from multiple contexts
- // that would invalidate the memory model.
- return PyRegion(self.get(), mlirRegionCreate(),
- /*detached=*/true);
- },
- py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
- .def(
- "create_block",
- [](PyMlirContext &self, std::vector<PyType> pyTypes) {
- // In order for the keep_alive extend the proper lifetime, all
- // types must be from the same context.
- for (auto pyType : pyTypes) {
- if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
- self.get())) {
- throw SetPyError(
- PyExc_ValueError,
- "All types used to construct a block must be from "
- "the same context as the block");
- }
- }
- llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
- pyTypes.end());
- return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
- /*detached=*/true);
- },
- py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
+ py::arg("line"), py::arg("col"));
py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
@@ -1096,17 +1252,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Operation.
py::class_<PyOperation>(m, "Operation")
.def_property_readonly(
- "first_region",
- [](PyOperation &self) {
- self.checkValid();
- if (mlirOperationGetNumRegions(self.get()) == 0) {
- throw SetPyError(PyExc_IndexError, "Operation has no regions");
- }
- return PyRegion(self.getContext()->get(),
- mlirOperationGetRegion(self.get(), 0),
- /*detached=*/false);
- },
- py::keep_alive<0, 1>(), "Gets the operation's first region")
+ "regions",
+ [](PyOperation &self) { return PyRegionList(self.getRef()); })
+ .def("__iter__",
+ [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
.def(
"__str__",
[](PyOperation &self) {
@@ -1120,63 +1269,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of PyRegion.
py::class_<PyRegion>(m, "Region")
- .def(
- "append_block",
- [](PyRegion &self, PyBlock &block) {
- if (!mlirContextEqual(self.context, block.context)) {
- throw SetPyError(
- PyExc_ValueError,
- "Block must have been created from the same context as "
- "this region");
- }
-
- block.attachToParent();
- mlirRegionAppendOwnedBlock(self.region, block.block);
+ .def_property_readonly(
+ "blocks",
+ [](PyRegion &self) {
+ return PyBlockList(self.getParentOperation(), self.get());
},
- kRegionAppendBlockDocstring)
+ "Returns a forward-optimized sequence of blocks.")
.def(
- "insert_block",
- [](PyRegion &self, int pos, PyBlock &block) {
- if (!mlirContextEqual(self.context, block.context)) {
- throw SetPyError(
- PyExc_ValueError,
- "Block must have been created from the same context as "
- "this region");
- }
- block.attachToParent();
- // TODO: Make this return a failure and raise if out of bounds.
- mlirRegionInsertOwnedBlock(self.region, pos, block.block);
- },
- kRegionInsertBlockDocstring)
- .def_property_readonly(
- "first_block",
+ "__iter__",
[](PyRegion &self) {
- MlirBlock block = mlirRegionGetFirstBlock(self.region);
- if (mlirBlockIsNull(block)) {
- throw SetPyError(PyExc_IndexError, "Region has no blocks");
- }
- return PyBlock(self.context, block, /*detached=*/false);
+ self.checkValid();
+ MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
+ return PyBlockIterator(self.getParentOperation(), firstBlock);
},
- kRegionFirstBlockDocstring);
+ "Iterates over blocks in the region.")
+ .def("__eq__", [](PyRegion &self, py::object &other) {
+ try {
+ PyRegion *otherRegion = other.cast<PyRegion *>();
+ return self.get().ptr == otherRegion->get().ptr;
+ } catch (std::exception &e) {
+ return false;
+ }
+ });
// Mapping of PyBlock.
py::class_<PyBlock>(m, "Block")
.def_property_readonly(
- "next_in_region",
+ "operations",
[](PyBlock &self) {
- MlirBlock block = mlirBlockGetNextInRegion(self.block);
- if (mlirBlockIsNull(block)) {
- throw SetPyError(PyExc_IndexError,
- "Attempt to read past last block");
- }
- return PyBlock(self.context, block, /*detached=*/false);
+ return PyOperationList(self.getParentOperation(), self.get());
},
- py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
+ "Returns a forward-optimized sequence of operations.")
+ .def(
+ "__iter__",
+ [](PyBlock &self) {
+ self.checkValid();
+ MlirOperation firstOperation =
+ mlirBlockGetFirstOperation(self.get());
+ return PyOperationIterator(self.getParentOperation(),
+ firstOperation);
+ },
+ "Iterates over operations in the block.")
+ .def("__eq__",
+ [](PyBlock &self, py::object &other) {
+ try {
+ PyBlock *otherBlock = other.cast<PyBlock *>();
+ return self.get().ptr == otherBlock->get().ptr;
+ } catch (std::exception &e) {
+ return false;
+ }
+ })
.def(
"__str__",
[](PyBlock &self) {
+ self.checkValid();
PyPrintAccumulator printAccum;
- mlirBlockPrint(self.block, printAccum.getCallback(),
+ mlirBlockPrint(self.get(), printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
@@ -1310,4 +1458,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyMemRefType::bind(m);
PyUnrankedMemRefType::bind(m);
PyTupleType::bind(m);
+
+ // Container bindings.
+ PyBlockIterator::bind(m);
+ PyBlockList::bind(m);
+ PyOperationIterator::bind(m);
+ PyOperationList::bind(m);
+ PyRegionIterator::bind(m);
+ PyRegionList::bind(m);
}