aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/IRAttributes.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bindings/Python/IRAttributes.cpp')
-rw-r--r--mlir/lib/Bindings/Python/IRAttributes.cpp175
1 files changed, 93 insertions, 82 deletions
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 84a48a8..75d743f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -7,12 +7,15 @@
//===----------------------------------------------------------------------===//
#include <optional>
+#include <string_view>
#include <utility>
#include "IRModule.h"
#include "PybindUtils.h"
+#include "llvm/ADT/ScopeExit.h"
+
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -612,19 +615,20 @@ public:
std::optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
// Request a contiguous view. In exotic cases, this will cause a copy.
- int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
- Py_buffer *view = new Py_buffer();
- if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
- delete view;
+ int flags = PyBUF_ND;
+ if (!explicitType) {
+ flags |= PyBUF_FORMAT;
+ }
+ Py_buffer view;
+ if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
throw py::error_already_set();
}
- py::buffer_info arrayInfo(view);
+ auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
- shape.append(arrayInfo.shape.begin(),
- arrayInfo.shape.begin() + arrayInfo.ndim);
+ shape.append(view.shape, view.shape + view.ndim);
}
MlirAttribute encodingAttr = mlirAttributeGetNull();
@@ -638,85 +642,92 @@ public:
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
- } else if (arrayInfo.format == "f") {
- // f32
- assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (arrayInfo.format == "d") {
- // f64
- assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (arrayInfo.format == "e") {
- // f16
- assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (isSignedIntegerFormat(arrayInfo.format)) {
- if (arrayInfo.itemsize == 4) {
- // i32
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (arrayInfo.itemsize == 8) {
- // i64
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (arrayInfo.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (arrayInfo.itemsize == 2) {
- // i16
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
- if (arrayInfo.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (arrayInfo.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (arrayInfo.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (arrayInfo.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
- }
- if (bulkLoadElementType) {
- MlirType shapedType;
- if (mlirTypeIsAShaped(*bulkLoadElementType)) {
- if (explicitShape) {
- throw std::invalid_argument("Shape can only be specified explicitly "
- "when the type is not a shaped type.");
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
+ }
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // unsigned i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
- shapedType = *bulkLoadElementType;
- } else {
- shapedType = mlirRankedTensorTypeGet(
- shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
}
- size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
- MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
- shapedType, rawBufferSize, arrayInfo.ptr);
- if (mlirAttributeIsNull(attr)) {
+ if (!bulkLoadElementType) {
throw std::invalid_argument(
- "DenseElementsAttr could not be constructed from the given buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
}
- return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
}
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- arrayInfo.format);
+ MlirType shapedType;
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+ if (explicitShape) {
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
+ }
+ shapedType = *bulkLoadElementType;
+ } else {
+ shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+ }
+ size_t rawBufferSize = view.len;
+ MlirAttribute attr =
+ mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
+ }
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
}
static PyDenseElementsAttribute getSplat(const PyType &shapedType,
@@ -852,7 +863,7 @@ public:
}
private:
- static bool isUnsignedIntegerFormat(const std::string &format) {
+ static bool isUnsignedIntegerFormat(std::string_view format) {
if (format.empty())
return false;
char code = format[0];
@@ -860,7 +871,7 @@ private:
code == 'Q';
}
- static bool isSignedIntegerFormat(const std::string &format) {
+ static bool isSignedIntegerFormat(std::string_view format) {
if (format.empty())
return false;
char code = format[0];