aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h')
-rw-r--r--mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h13
1 files changed, 8 insertions, 5 deletions
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
index 7671db9..d4b6e1f 100644
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -89,7 +89,7 @@ template <typename T> struct UnrankedMemRefType {
template <typename StreamType, typename T, int N>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
static_assert(N > 0, "Expected N > 0");
- os << "Memref base@ = " << V.data << " rank = " << N
+ os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N
<< " offset = " << V.offset << " sizes = [" << V.sizes[0];
for (unsigned i = 1; i < N; ++i)
os << ", " << V.sizes[i];
@@ -101,14 +101,14 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
template <typename StreamType, typename T>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
- os << "Memref base@ = " << V.data << " rank = 0"
+ os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0"
<< " offset = " << V.offset;
}
template <typename T, typename StreamType>
void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
os << "Unranked Memref rank = " << V.rank << " "
- << "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " ";
+ << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n";
}
template <typename T, int Dim, int... Dims> struct Vector {
@@ -258,8 +258,8 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
}
template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
- std::cout << "\nMemref base@ = " << M.data << " rank = " << 0
- << " offset = " << M.offset << " data = " << std::endl;
+ printMemRefMetaData(std::cout, M);
+ std::cout << " data = " << std::endl;
std::cout << "[";
MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
std::cout << "]" << std::endl;
@@ -270,7 +270,10 @@ template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
// Currently exposed C API.
////////////////////////////////////////////////////////////////////////////////
extern "C" MLIR_RUNNER_UTILS_EXPORT void
+print_memref_i8(UnrankedMemRefType<int8_t> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_f32(UnrankedMemRefType<float> *M);
+
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void