aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/IRModule.h
diff options
context:
space:
mode:
authormax <maksim.levental@gmail.com>2023-05-22 11:12:53 -0500
committermax <maksim.levental@gmail.com>2023-05-22 13:19:54 -0500
commitd39a7844028bcdd28f72b0e69becc9c49b8fd283 (patch)
tree7ba86484bf023757d0f8a76c76d1b30036a64485 /mlir/lib/Bindings/Python/IRModule.h
parenta7c5cf226024e246501aa2b66350c3f922acc0cb (diff)
downloadllvm-d39a7844028bcdd28f72b0e69becc9c49b8fd283.zip
llvm-d39a7844028bcdd28f72b0e69becc9c49b8fd283.tar.gz
llvm-d39a7844028bcdd28f72b0e69becc9c49b8fd283.tar.bz2
[MLIR][python bindings] Expose TypeIDs in python
This diff adds python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from python APIs (see D150927) and then further along building type "conscious" `Value` APIs (see D150413). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150839
Diffstat (limited to 'mlir/lib/Bindings/Python/IRModule.h')
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h50
1 files changed, 49 insertions, 1 deletions
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ade790b..fa529c4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -20,6 +20,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@@ -826,6 +827,29 @@ private:
MlirType type;
};
+/// A TypeID provides an efficient and unique identifier for a specific C++
+/// type. This allows for a C++ type to be compared, hashed, and stored in an
+/// opaque context. This class wraps around the generic MlirTypeID.
+class PyTypeID {
+public:
+ PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the PyTypeID objects are the same (i.e., PyTypeID is a value type).
+ bool operator==(const PyTypeID &other) const;
+ operator MlirTypeID() const { return typeID; }
+ MlirTypeID get() { return typeID; }
+
+ /// Gets a capsule wrapping the void* within the MlirTypeID.
+ pybind11::object getCapsule();
+
+ /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
+ static PyTypeID createFromCapsule(pybind11::object capsule);
+
+private:
+ MlirTypeID typeID;
+};
+
/// CRTP base classes for Python types that subclass Type and should be
/// castable from it (i.e. via something like IntegerType(t)).
/// By default, type class hierarchies are one level deep (i.e. a
@@ -839,10 +863,14 @@ public:
// const char *pyClassName
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirType);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
- : BaseTy(std::move(contextRef), t) {}
+ : BaseTy(std::move(contextRef), t) {
+ pybind11::implicitly_convertible<PyType, DerivedTy>();
+ }
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
@@ -866,6 +894,26 @@ public:
return DerivedTy::isaFunction(otherType);
},
pybind11::arg("other"));
+ cls.def_property_readonly_static(
+ "static_typeid", [](py::object & /*class*/) -> MlirTypeID {
+ if (DerivedTy::getTypeIdFunction)
+ return DerivedTy::getTypeIdFunction();
+ throw SetPyError(PyExc_AttributeError,
+ DerivedTy::pyClassName +
+ llvm::Twine(" has no typeid."));
+ });
+ cls.def_property_readonly("typeid", [](PyType &self) {
+ return py::cast(self).attr("typeid").cast<MlirTypeID>();
+ });
+ cls.def("__repr__", [](DerivedTy &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append(DerivedTy::pyClassName);
+ printAccum.parts.append("(");
+ mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ });
+
DerivedTy::bindDerived(cls);
}