diff options
Diffstat (limited to 'mlir/lib/Bindings/Python/IRModules.cpp')
-rw-r--r-- | mlir/lib/Bindings/Python/IRModules.cpp | 440 |
1 files changed, 298 insertions, 142 deletions
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 66e975e..8eab7da 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -46,45 +46,6 @@ static const char kContextGetUnknownLocationDocstring[] = static const char kContextGetFileLocationDocstring[] = R"(Gets a Location representing a file, line and column)"; -static const char kContextCreateBlockDocstring[] = - R"(Creates a detached block)"; - -static const char kContextCreateRegionDocstring[] = - R"(Creates a detached region)"; - -static const char kRegionAppendBlockDocstring[] = - R"(Appends a block to a region. - -Raises: - ValueError: If the block is already attached to another region. -)"; - -static const char kRegionInsertBlockDocstring[] = - R"(Inserts a block at a postiion in a region. - -Raises: - ValueError: If the block is already attached to another region. -)"; - -static const char kRegionFirstBlockDocstring[] = - R"(Gets the first block in a region. - -Blocks can also be accessed via the `blocks` container. - -Raises: - IndexError: If the region has no blocks. -)"; - -static const char kBlockNextInRegionDocstring[] = - R"(Gets the next block in the enclosing region. - -Blocks can also be accessed via the `blocks` container of the owning region. -This method exists to mirror the lower level API and should not be preferred. - -Raises: - IndexError: If there are no further blocks. -)"; - static const char kOperationStrDunderDocstring[] = R"(Prints the assembly form of the operation with default options. @@ -171,6 +132,241 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) { } // namespace //------------------------------------------------------------------------------ +// Collections. +//------------------------------------------------------------------------------ + +namespace { + +class PyRegionIterator { +public: + PyRegionIterator(PyOperationRef operation) + : operation(std::move(operation)) {} + + PyRegionIterator &dunderIter() { return *this; } + + PyRegion dunderNext() { + operation->checkValid(); + if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { + throw py::stop_iteration(); + } + MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); + return PyRegion(operation, region); + } + + static void bind(py::module &m) { + py::class_<PyRegionIterator>(m, "RegionIterator") + .def("__iter__", &PyRegionIterator::dunderIter) + .def("__next__", &PyRegionIterator::dunderNext); + } + +private: + PyOperationRef operation; + int nextIndex = 0; +}; + +/// Regions of an op are fixed length and indexed numerically so are represented +/// with a sequence-like container. +class PyRegionList { +public: + PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} + + intptr_t dunderLen() { + operation->checkValid(); + return mlirOperationGetNumRegions(operation->get()); + } + + PyRegion dunderGetItem(intptr_t index) { + // dunderLen checks validity. + if (index < 0 || index >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds region"); + } + MlirRegion region = mlirOperationGetRegion(operation->get(), index); + return PyRegion(operation, region); + } + + static void bind(py::module &m) { + py::class_<PyRegionList>(m, "ReqionSequence") + .def("__len__", &PyRegionList::dunderLen) + .def("__getitem__", &PyRegionList::dunderGetItem); + } + +private: + PyOperationRef operation; +}; + +class PyBlockIterator { +public: + PyBlockIterator(PyOperationRef operation, MlirBlock next) + : operation(std::move(operation)), next(next) {} + + PyBlockIterator &dunderIter() { return *this; } + + PyBlock dunderNext() { + operation->checkValid(); + if (mlirBlockIsNull(next)) { + throw py::stop_iteration(); + } + + PyBlock returnBlock(operation, next); + next = mlirBlockGetNextInRegion(next); + return returnBlock; + } + + static void bind(py::module &m) { + py::class_<PyBlockIterator>(m, "BlockIterator") + .def("__iter__", &PyBlockIterator::dunderIter) + .def("__next__", &PyBlockIterator::dunderNext); + } + +private: + PyOperationRef operation; + MlirBlock next; +}; + +/// Blocks are exposed by the C-API as a forward-only linked list. In Python, +/// we present them as a more full-featured list-like container but optimzie +/// it for forward iteration. Blocks are always owned by a region. +class PyBlockList { +public: + PyBlockList(PyOperationRef operation, MlirRegion region) + : operation(std::move(operation)), region(region) {} + + PyBlockIterator dunderIter() { + operation->checkValid(); + return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); + } + + intptr_t dunderLen() { + operation->checkValid(); + intptr_t count = 0; + MlirBlock block = mlirRegionGetFirstBlock(region); + while (!mlirBlockIsNull(block)) { + count += 1; + block = mlirBlockGetNextInRegion(block); + } + return count; + } + + PyBlock dunderGetItem(intptr_t index) { + operation->checkValid(); + if (index < 0) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds block"); + } + MlirBlock block = mlirRegionGetFirstBlock(region); + while (!mlirBlockIsNull(block)) { + if (index == 0) { + return PyBlock(operation, block); + } + block = mlirBlockGetNextInRegion(block); + index -= 1; + } + throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); + } + + static void bind(py::module &m) { + py::class_<PyBlockList>(m, "BlockList") + .def("__getitem__", &PyBlockList::dunderGetItem) + .def("__iter__", &PyBlockList::dunderIter) + .def("__len__", &PyBlockList::dunderLen); + } + +private: + PyOperationRef operation; + MlirRegion region; +}; + +class PyOperationIterator { +public: + PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) + : parentOperation(std::move(parentOperation)), next(next) {} + + PyOperationIterator &dunderIter() { return *this; } + + py::object dunderNext() { + parentOperation->checkValid(); + if (mlirOperationIsNull(next)) { + throw py::stop_iteration(); + } + + PyOperationRef returnOperation = + PyOperation::forOperation(parentOperation->getContext(), next); + next = mlirOperationGetNextInBlock(next); + return returnOperation.releaseObject(); + } + + static void bind(py::module &m) { + py::class_<PyOperationIterator>(m, "OperationIterator") + .def("__iter__", &PyOperationIterator::dunderIter) + .def("__next__", &PyOperationIterator::dunderNext); + } + +private: + PyOperationRef parentOperation; + MlirOperation next; +}; + +/// Operations are exposed by the C-API as a forward-only linked list. In +/// Python, we present them as a more full-featured list-like container but +/// optimzie it for forward iteration. Iterable operations are always owned +/// by a block. +class PyOperationList { +public: + PyOperationList(PyOperationRef parentOperation, MlirBlock block) + : parentOperation(std::move(parentOperation)), block(block) {} + + PyOperationIterator dunderIter() { + parentOperation->checkValid(); + return PyOperationIterator(parentOperation, + mlirBlockGetFirstOperation(block)); + } + + intptr_t dunderLen() { + parentOperation->checkValid(); + intptr_t count = 0; + MlirOperation childOp = mlirBlockGetFirstOperation(block); + while (!mlirOperationIsNull(childOp)) { + count += 1; + childOp = mlirOperationGetNextInBlock(childOp); + } + return count; + } + + py::object dunderGetItem(intptr_t index) { + parentOperation->checkValid(); + if (index < 0) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds operation"); + } + MlirOperation childOp = mlirBlockGetFirstOperation(block); + while (!mlirOperationIsNull(childOp)) { + if (index == 0) { + return PyOperation::forOperation(parentOperation->getContext(), childOp) + .releaseObject(); + } + childOp = mlirOperationGetNextInBlock(childOp); + index -= 1; + } + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds operation"); + } + + static void bind(py::module &m) { + py::class_<PyOperationList>(m, "OperationList") + .def("__getitem__", &PyOperationList::dunderGetItem) + .def("__iter__", &PyOperationList::dunderIter) + .def("__len__", &PyOperationList::dunderLen); + } + +private: + PyOperationRef parentOperation; + MlirBlock block; +}; + +} // namespace + +//------------------------------------------------------------------------------ // PyMlirContext //------------------------------------------------------------------------------ @@ -310,24 +506,6 @@ void PyOperation::checkValid() { } //------------------------------------------------------------------------------ -// PyBlock, PyRegion. -//------------------------------------------------------------------------------ - -void PyRegion::attachToParent() { - if (!detached) { - throw SetPyError(PyExc_ValueError, "Region is already attached to an op"); - } - detached = false; -} - -void PyBlock::attachToParent() { - if (!detached) { - throw SetPyError(PyExc_ValueError, "Block is already attached to an op"); - } - detached = false; -} - -//------------------------------------------------------------------------------ // PyAttribute. //------------------------------------------------------------------------------ @@ -967,6 +1145,14 @@ void mlir::python::populateIRSubmodule(py::module &m) { return ref.releaseObject(); }) .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def_property( + "allow_unregistered_dialects", + [](PyMlirContext &self) -> bool { + return mlirContextGetAllowUnregisteredDialects(self.get()); + }, + [](PyMlirContext &self, bool value) { + mlirContextSetAllowUnregisteredDialects(self.get(), value); + }) .def( "parse_module", [](PyMlirContext &self, const std::string moduleAsm) { @@ -1026,37 +1212,7 @@ void mlir::python::populateIRSubmodule(py::module &m) { self.get(), filename.c_str(), line, col)); }, kContextGetFileLocationDocstring, py::arg("filename"), - py::arg("line"), py::arg("col")) - .def( - "create_region", - [](PyMlirContext &self) { - // The creating context is explicitly captured on regions to - // facilitate illegal assemblies of objects from multiple contexts - // that would invalidate the memory model. - return PyRegion(self.get(), mlirRegionCreate(), - /*detached=*/true); - }, - py::keep_alive<0, 1>(), kContextCreateRegionDocstring) - .def( - "create_block", - [](PyMlirContext &self, std::vector<PyType> pyTypes) { - // In order for the keep_alive extend the proper lifetime, all - // types must be from the same context. - for (auto pyType : pyTypes) { - if (!mlirContextEqual(mlirTypeGetContext(pyType.type), - self.get())) { - throw SetPyError( - PyExc_ValueError, - "All types used to construct a block must be from " - "the same context as the block"); - } - } - llvm::SmallVector<MlirType, 4> types(pyTypes.begin(), - pyTypes.end()); - return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]), - /*detached=*/true); - }, - py::keep_alive<0, 1>(), kContextCreateBlockDocstring); + py::arg("line"), py::arg("col")); py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; @@ -1096,17 +1252,10 @@ void mlir::python::populateIRSubmodule(py::module &m) { // Mapping of Operation. py::class_<PyOperation>(m, "Operation") .def_property_readonly( - "first_region", - [](PyOperation &self) { - self.checkValid(); - if (mlirOperationGetNumRegions(self.get()) == 0) { - throw SetPyError(PyExc_IndexError, "Operation has no regions"); - } - return PyRegion(self.getContext()->get(), - mlirOperationGetRegion(self.get(), 0), - /*detached=*/false); - }, - py::keep_alive<0, 1>(), "Gets the operation's first region") + "regions", + [](PyOperation &self) { return PyRegionList(self.getRef()); }) + .def("__iter__", + [](PyOperation &self) { return PyRegionIterator(self.getRef()); }) .def( "__str__", [](PyOperation &self) { @@ -1120,63 +1269,62 @@ void mlir::python::populateIRSubmodule(py::module &m) { // Mapping of PyRegion. py::class_<PyRegion>(m, "Region") - .def( - "append_block", - [](PyRegion &self, PyBlock &block) { - if (!mlirContextEqual(self.context, block.context)) { - throw SetPyError( - PyExc_ValueError, - "Block must have been created from the same context as " - "this region"); - } - - block.attachToParent(); - mlirRegionAppendOwnedBlock(self.region, block.block); + .def_property_readonly( + "blocks", + [](PyRegion &self) { + return PyBlockList(self.getParentOperation(), self.get()); }, - kRegionAppendBlockDocstring) + "Returns a forward-optimized sequence of blocks.") .def( - "insert_block", - [](PyRegion &self, int pos, PyBlock &block) { - if (!mlirContextEqual(self.context, block.context)) { - throw SetPyError( - PyExc_ValueError, - "Block must have been created from the same context as " - "this region"); - } - block.attachToParent(); - // TODO: Make this return a failure and raise if out of bounds. - mlirRegionInsertOwnedBlock(self.region, pos, block.block); - }, - kRegionInsertBlockDocstring) - .def_property_readonly( - "first_block", + "__iter__", [](PyRegion &self) { - MlirBlock block = mlirRegionGetFirstBlock(self.region); - if (mlirBlockIsNull(block)) { - throw SetPyError(PyExc_IndexError, "Region has no blocks"); - } - return PyBlock(self.context, block, /*detached=*/false); + self.checkValid(); + MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); + return PyBlockIterator(self.getParentOperation(), firstBlock); }, - kRegionFirstBlockDocstring); + "Iterates over blocks in the region.") + .def("__eq__", [](PyRegion &self, py::object &other) { + try { + PyRegion *otherRegion = other.cast<PyRegion *>(); + return self.get().ptr == otherRegion->get().ptr; + } catch (std::exception &e) { + return false; + } + }); // Mapping of PyBlock. py::class_<PyBlock>(m, "Block") .def_property_readonly( - "next_in_region", + "operations", [](PyBlock &self) { - MlirBlock block = mlirBlockGetNextInRegion(self.block); - if (mlirBlockIsNull(block)) { - throw SetPyError(PyExc_IndexError, - "Attempt to read past last block"); - } - return PyBlock(self.context, block, /*detached=*/false); + return PyOperationList(self.getParentOperation(), self.get()); }, - py::keep_alive<0, 1>(), kBlockNextInRegionDocstring) + "Returns a forward-optimized sequence of operations.") + .def( + "__iter__", + [](PyBlock &self) { + self.checkValid(); + MlirOperation firstOperation = + mlirBlockGetFirstOperation(self.get()); + return PyOperationIterator(self.getParentOperation(), + firstOperation); + }, + "Iterates over operations in the block.") + .def("__eq__", + [](PyBlock &self, py::object &other) { + try { + PyBlock *otherBlock = other.cast<PyBlock *>(); + return self.get().ptr == otherBlock->get().ptr; + } catch (std::exception &e) { + return false; + } + }) .def( "__str__", [](PyBlock &self) { + self.checkValid(); PyPrintAccumulator printAccum; - mlirBlockPrint(self.block, printAccum.getCallback(), + mlirBlockPrint(self.get(), printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, @@ -1310,4 +1458,12 @@ void mlir::python::populateIRSubmodule(py::module &m) { PyMemRefType::bind(m); PyUnrankedMemRefType::bind(m); PyTupleType::bind(m); + + // Container bindings. + PyBlockIterator::bind(m); + PyBlockList::bind(m); + PyOperationIterator::bind(m); + PyOperationList::bind(m); + PyRegionIterator::bind(m); + PyRegionList::bind(m); } |