aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/mlir-cpu-runner
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/mlir-cpu-runner')
-rw-r--r--mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h15
-rw-r--r--mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp35
-rw-r--r--mlir/test/mlir-cpu-runner/unranked_memref.mlir43
-rw-r--r--mlir/test/mlir-cpu-runner/utils.mlir15
4 files changed, 93 insertions, 15 deletions
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
index 934faf8..8a58493 100644
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -78,6 +78,12 @@ template <typename T> struct StridedMemRefType<T, 0> {
int64_t offset;
};
+// Unranked MemRef
+struct UnrankedMemRefType {
+ int64_t rank;
+ void *descriptor;
+};
+
template <typename StreamType, typename T, int N>
void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
static_assert(N > 0, "Expected N > 0");
@@ -98,6 +104,15 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
}
extern "C" MLIR_RUNNER_UTILS_EXPORT void
+print_memref_f32(UnrankedMemRefType *M);
+
+template <typename StreamType>
+void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) {
+ os << "Unranked Memref rank = " << V.rank << " "
+ << "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " ";
+}
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_0d_f32(StridedMemRefType<float, 0> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_1d_f32(StridedMemRefType<float, 1> *M);
diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
index 05c34f8..056ff65 100644
--- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
+++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
@@ -148,15 +148,41 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
std::cout << std::endl;
}
-template <typename T> void printZeroDMemRef(StridedMemRefType<T, 0> &M) {
+template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
std::cout << "\nMemref base@ = " << M.data << " rank = " << 0
<< " offset = " << M.offset << " data = [";
MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
std::cout << "]" << std::endl;
}
+extern "C" void
+print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
+ printMemRef(*M);
+}
+
+extern "C" void print_memref_f32(UnrankedMemRefType *M) {
+ printUnrankedMemRefMetaData(std::cout, *M);
+ int rank = M->rank;
+ void *ptr = M->descriptor;
+
+#define MEMREF_CASE(RANK) \
+ case RANK: \
+ printMemRef(*(static_cast<StridedMemRefType<float, RANK> *>(ptr))); \
+ break
+
+ switch (rank) {
+ MEMREF_CASE(0);
+ MEMREF_CASE(1);
+ MEMREF_CASE(2);
+ MEMREF_CASE(3);
+ MEMREF_CASE(4);
+ default:
+ assert(0 && "Unsupported rank to print");
+ }
+}
+
extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) {
- printZeroDMemRef(*M);
+ printMemRef(*M);
}
extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) {
printMemRef(*M);
@@ -170,8 +196,3 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
printMemRef(*M);
}
-
-extern "C" void
-print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) {
- printMemRef(*M);
-}
diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
new file mode 100644
index 0000000..4e721be
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [10, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-10: [10, 10, 10]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [10, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-10: [5, 5, 5]
+//
+// CHECK: rank = 2
+// CHECK-SAME: sizes = [10, 3]
+// CHECK-SAME: strides = [3, 1]
+// CHECK-COUNT-10: [2, 2, 2]
+func @main() -> () {
+ %A = alloc() : memref<10x3xf32, 0>
+ %f2 = constant 2.00000e+00 : f32
+ %f5 = constant 5.00000e+00 : f32
+ %f10 = constant 10.00000e+00 : f32
+
+ %V = memref_cast %A : memref<10x3xf32, 0> to memref<?x?xf32>
+ linalg.fill(%V, %f10) : memref<?x?xf32, 0>, f32
+ %U = memref_cast %A : memref<10x3xf32, 0> to memref<*xf32>
+ call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+
+ %V2 = memref_cast %U : memref<*xf32> to memref<?x?xf32>
+ linalg.fill(%V2, %f5) : memref<?x?xf32, 0>, f32
+ %U2 = memref_cast %V2 : memref<?x?xf32, 0> to memref<*xf32>
+ call @print_memref_f32(%U2) : (memref<*xf32>) -> ()
+
+ %V3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
+ %V4 = memref_cast %V3 : memref<*xf32> to memref<?x?xf32>
+ linalg.fill(%V4, %f2) : memref<?x?xf32, 0>, f32
+ %U3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32>
+ call @print_memref_f32(%U3) : (memref<*xf32>) -> ()
+
+ dealloc %A : memref<10x3xf32, 0>
+ return
+}
+
+func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index ed54b90..099b856 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -7,7 +7,8 @@ func @print_0d() {
%f = constant 2.00000e+00 : f32
%A = alloc() : memref<f32>
store %f, %A[]: memref<f32>
- call @print_memref_0d_f32(%A): (memref<f32>) -> ()
+ %U = memref_cast %A : memref<f32> to memref<*xf32>
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
dealloc %A : memref<f32>
return
}
@@ -18,7 +19,8 @@ func @print_1d() {
%A = alloc() : memref<16xf32>
%B = memref_cast %A: memref<16xf32> to memref<?xf32>
linalg.fill(%B, %f) : memref<?xf32>, f32
- call @print_memref_1d_f32(%B): (memref<?xf32>) -> ()
+ %U = memref_cast %B : memref<?xf32> to memref<*xf32>
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
dealloc %A : memref<16xf32>
return
}
@@ -34,8 +36,8 @@ func @print_3d() {
%c2 = constant 2 : index
store %f4, %B[%c2, %c2, %c2]: memref<?x?x?xf32>
-
- call @print_memref_3d_f32(%B): (memref<?x?x?xf32>) -> ()
+ %U = memref_cast %B : memref<?x?x?xf32> to memref<*xf32>
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
dealloc %A : memref<3x4x5xf32>
return
}
@@ -46,10 +48,7 @@ func @print_3d() {
// PRINT-3D-NEXT: 2, 2, 4, 2, 2
// PRINT-3D-NEXT: 2, 2, 2, 2, 2
-func @print_memref_0d_f32(memref<f32>)
-func @print_memref_1d_f32(memref<?xf32>)
-func @print_memref_3d_f32(memref<?x?x?xf32>)
-
+func @print_memref_f32(memref<*xf32>)
!vector_type_C = type vector<4x4xf32>
!matrix_type_CC = type memref<1x1x!vector_type_C>