aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/IRAttributes.cpp
diff options
context:
space:
mode:
authorPeter Hawkins <phawkins@google.com>2023-07-14 16:08:15 -0700
committerJacques Pienaar <jpienaar@google.com>2023-07-14 16:08:15 -0700
commit71a254543d44a943dfe8790abc60795b87173f0b (patch)
tree605dbcf1b7f51e97639bb45cf2f50727b4a248db /mlir/lib/Bindings/Python/IRAttributes.cpp
parentc9ef33e1d8a8aeb68a18f24af6d9fc9ab4ecf257 (diff)
downloadllvm-71a254543d44a943dfe8790abc60795b87173f0b.zip
llvm-71a254543d44a943dfe8790abc60795b87173f0b.tar.gz
llvm-71a254543d44a943dfe8790abc60795b87173f0b.tar.bz2
[MLIR:Python] Make DenseElementsAttr.get() only request a buffer format if no explicit type was provided.
Not every NumPy type (e.g., the `ml_dtypes.bfloat16` NumPy extension type) has a type in the Python buffer protocol, so exporting such a buffer with `PyBUF_FORMAT` may fail. However, we don't care about the self-reported type of a buffer if the user provides an explicit type. In the case that an explicit type is provided, don't request the format from the buffer protocol, which allows arrays whose element types are unknown to the buffer protocol to be passed. Reviewed By: jpienaar, ftynse Differential Revision: https://reviews.llvm.org/D155209
Diffstat (limited to 'mlir/lib/Bindings/Python/IRAttributes.cpp')
-rw-r--r--mlir/lib/Bindings/Python/IRAttributes.cpp175
1 files changed, 93 insertions, 82 deletions
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 84a48a8..75d743f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -7,12 +7,15 @@
//===----------------------------------------------------------------------===//
#include <optional>
+#include <string_view>
#include <utility>
#include "IRModule.h"
#include "PybindUtils.h"
+#include "llvm/ADT/ScopeExit.h"
+
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -612,19 +615,20 @@ public:
std::optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
// Request a contiguous view. In exotic cases, this will cause a copy.
- int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
- Py_buffer *view = new Py_buffer();
- if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
- delete view;
+ int flags = PyBUF_ND;
+ if (!explicitType) {
+ flags |= PyBUF_FORMAT;
+ }
+ Py_buffer view;
+ if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
throw py::error_already_set();
}
- py::buffer_info arrayInfo(view);
+ auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
- shape.append(arrayInfo.shape.begin(),
- arrayInfo.shape.begin() + arrayInfo.ndim);
+ shape.append(view.shape, view.shape + view.ndim);
}
MlirAttribute encodingAttr = mlirAttributeGetNull();
@@ -638,85 +642,92 @@ public:
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
- } else if (arrayInfo.format == "f") {
- // f32
- assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (arrayInfo.format == "d") {
- // f64
- assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (arrayInfo.format == "e") {
- // f16
- assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (isSignedIntegerFormat(arrayInfo.format)) {
- if (arrayInfo.itemsize == 4) {
- // i32
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (arrayInfo.itemsize == 8) {
- // i64
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (arrayInfo.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (arrayInfo.itemsize == 2) {
- // i16
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
- if (arrayInfo.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (arrayInfo.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (arrayInfo.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (arrayInfo.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
- }
- if (bulkLoadElementType) {
- MlirType shapedType;
- if (mlirTypeIsAShaped(*bulkLoadElementType)) {
- if (explicitShape) {
- throw std::invalid_argument("Shape can only be specified explicitly "
- "when the type is not a shaped type.");
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
+ }
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // unsigned i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
- shapedType = *bulkLoadElementType;
- } else {
- shapedType = mlirRankedTensorTypeGet(
- shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
}
- size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
- MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
- shapedType, rawBufferSize, arrayInfo.ptr);
- if (mlirAttributeIsNull(attr)) {
+ if (!bulkLoadElementType) {
throw std::invalid_argument(
- "DenseElementsAttr could not be constructed from the given buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
}
- return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
}
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- arrayInfo.format);
+ MlirType shapedType;
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+ if (explicitShape) {
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
+ }
+ shapedType = *bulkLoadElementType;
+ } else {
+ shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+ }
+ size_t rawBufferSize = view.len;
+ MlirAttribute attr =
+ mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
+ }
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
}
static PyDenseElementsAttribute getSplat(const PyType &shapedType,
@@ -852,7 +863,7 @@ public:
}
private:
- static bool isUnsignedIntegerFormat(const std::string &format) {
+ static bool isUnsignedIntegerFormat(std::string_view format) {
if (format.empty())
return false;
char code = format[0];
@@ -860,7 +871,7 @@ private:
code == 'Q';
}
- static bool isSignedIntegerFormat(const std::string &format) {
+ static bool isSignedIntegerFormat(std::string_view format) {
if (format.empty())
return false;
char code = format[0];