diff options
author | max <maksim.levental@gmail.com> | 2023-05-22 11:12:53 -0500 |
---|---|---|
committer | max <maksim.levental@gmail.com> | 2023-05-22 13:19:54 -0500 |
commit | d39a7844028bcdd28f72b0e69becc9c49b8fd283 (patch) | |
tree | 7ba86484bf023757d0f8a76c76d1b30036a64485 /mlir/lib/Bindings/Python/IRModule.h | |
parent | a7c5cf226024e246501aa2b66350c3f922acc0cb (diff) | |
download | llvm-d39a7844028bcdd28f72b0e69becc9c49b8fd283.zip llvm-d39a7844028bcdd28f72b0e69becc9c49b8fd283.tar.gz llvm-d39a7844028bcdd28f72b0e69becc9c49b8fd283.tar.bz2 |
[MLIR][python bindings] Expose TypeIDs in python
This diff adds python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from python APIs (see D150927) and then further along building type "conscious" `Value` APIs (see D150413).
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D150839
Diffstat (limited to 'mlir/lib/Bindings/Python/IRModule.h')
-rw-r--r-- | mlir/lib/Bindings/Python/IRModule.h | 50 |
1 files changed, 49 insertions, 1 deletions
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index ade790b..fa529c4 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -20,6 +20,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -826,6 +827,29 @@ private: MlirType type; }; +/// A TypeID provides an efficient and unique identifier for a specific C++ +/// type. This allows for a C++ type to be compared, hashed, and stored in an +/// opaque context. This class wraps around the generic MlirTypeID. +class PyTypeID { +public: + PyTypeID(MlirTypeID typeID) : typeID(typeID) {} + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the PyTypeID objects are the same (i.e., PyTypeID is a value type). + bool operator==(const PyTypeID &other) const; + operator MlirTypeID() const { return typeID; } + MlirTypeID get() { return typeID; } + + /// Gets a capsule wrapping the void* within the MlirTypeID. + pybind11::object getCapsule(); + + /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. + static PyTypeID createFromCapsule(pybind11::object capsule); + +private: + MlirTypeID typeID; +}; + /// CRTP base classes for Python types that subclass Type and should be /// castable from it (i.e. via something like IntegerType(t)). /// By default, type class hierarchies are one level deep (i.e. a @@ -839,10 +863,14 @@ public: // const char *pyClassName using ClassTy = pybind11::class_<DerivedTy, BaseTy>; using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} + : BaseTy(std::move(contextRef), t) { + pybind11::implicitly_convertible<PyType, DerivedTy>(); + } PyConcreteType(PyType &orig) : PyConcreteType(orig.getContext(), castFrom(orig)) {} @@ -866,6 +894,26 @@ public: return DerivedTy::isaFunction(otherType); }, pybind11::arg("other")); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw SetPyError(PyExc_AttributeError, + DerivedTy::pyClassName + + llvm::Twine(" has no typeid.")); + }); + cls.def_property_readonly("typeid", [](PyType &self) { + return py::cast(self).attr("typeid").cast<MlirTypeID>(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + DerivedTy::bindDerived(cls); } |