diff options
Diffstat (limited to 'mlir/test/mlir-cpu-runner')
-rw-r--r-- | mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h | 15 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp | 35 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/unranked_memref.mlir | 43 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/utils.mlir | 15 |
4 files changed, 93 insertions, 15 deletions
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h index 934faf8..8a58493 100644 --- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h +++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h @@ -78,6 +78,12 @@ template <typename T> struct StridedMemRefType<T, 0> { int64_t offset; }; +// Unranked MemRef +struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + template <typename StreamType, typename T, int N> void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) { static_assert(N > 0, "Expected N > 0"); @@ -98,6 +104,15 @@ void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) { } extern "C" MLIR_RUNNER_UTILS_EXPORT void +print_memref_f32(UnrankedMemRefType *M); + +template <typename StreamType> +void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) { + os << "Unranked Memref rank = " << V.rank << " " + << "descriptor@ = " << reinterpret_cast<float *>(V.descriptor) << " "; +} + +extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_0d_f32(StridedMemRefType<float, 0> *M); extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_1d_f32(StridedMemRefType<float, 1> *M); diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp index 05c34f8..056ff65 100644 --- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp +++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp @@ -148,15 +148,41 @@ template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) { std::cout << std::endl; } -template <typename T> void printZeroDMemRef(StridedMemRefType<T, 0> &M) { +template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) { std::cout << "\nMemref base@ = " << M.data << " rank = " << 0 << " offset = " << M.offset << " data = ["; MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset); std::cout << "]" << std::endl; } +extern "C" void +print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) { + printMemRef(*M); +} + +extern "C" void print_memref_f32(UnrankedMemRefType *M) { + printUnrankedMemRefMetaData(std::cout, *M); + int rank = M->rank; + void *ptr = M->descriptor; + +#define MEMREF_CASE(RANK) \ + case RANK: \ + printMemRef(*(static_cast<StridedMemRefType<float, RANK> *>(ptr))); \ + break + + switch (rank) { + MEMREF_CASE(0); + MEMREF_CASE(1); + MEMREF_CASE(2); + MEMREF_CASE(3); + MEMREF_CASE(4); + default: + assert(0 && "Unsupported rank to print"); + } +} + extern "C" void print_memref_0d_f32(StridedMemRefType<float, 0> *M) { - printZeroDMemRef(*M); + printMemRef(*M); } extern "C" void print_memref_1d_f32(StridedMemRefType<float, 1> *M) { printMemRef(*M); @@ -170,8 +196,3 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) { extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) { printMemRef(*M); } - -extern "C" void -print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M) { - printMemRef(*M); -} diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir new file mode 100644 index 0000000..4e721be --- /dev/null +++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s + +// CHECK: rank = 2 +// CHECK-SAME: sizes = [10, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-10: [10, 10, 10] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [10, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-10: [5, 5, 5] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [10, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-10: [2, 2, 2] +func @main() -> () { + %A = alloc() : memref<10x3xf32, 0> + %f2 = constant 2.00000e+00 : f32 + %f5 = constant 5.00000e+00 : f32 + %f10 = constant 10.00000e+00 : f32 + + %V = memref_cast %A : memref<10x3xf32, 0> to memref<?x?xf32> + linalg.fill(%V, %f10) : memref<?x?xf32, 0>, f32 + %U = memref_cast %A : memref<10x3xf32, 0> to memref<*xf32> + call @print_memref_f32(%U) : (memref<*xf32>) -> () + + %V2 = memref_cast %U : memref<*xf32> to memref<?x?xf32> + linalg.fill(%V2, %f5) : memref<?x?xf32, 0>, f32 + %U2 = memref_cast %V2 : memref<?x?xf32, 0> to memref<*xf32> + call @print_memref_f32(%U2) : (memref<*xf32>) -> () + + %V3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32> + %V4 = memref_cast %V3 : memref<*xf32> to memref<?x?xf32> + linalg.fill(%V4, %f2) : memref<?x?xf32, 0>, f32 + %U3 = memref_cast %V2 : memref<?x?xf32> to memref<*xf32> + call @print_memref_f32(%U3) : (memref<*xf32>) -> () + + dealloc %A : memref<10x3xf32, 0> + return +} + +func @print_memref_f32(memref<*xf32>) diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir index ed54b90..099b856 100644 --- a/mlir/test/mlir-cpu-runner/utils.mlir +++ b/mlir/test/mlir-cpu-runner/utils.mlir @@ -7,7 +7,8 @@ func @print_0d() { %f = constant 2.00000e+00 : f32 %A = alloc() : memref<f32> store %f, %A[]: memref<f32> - call @print_memref_0d_f32(%A): (memref<f32>) -> () + %U = memref_cast %A : memref<f32> to memref<*xf32> + call @print_memref_f32(%U): (memref<*xf32>) -> () dealloc %A : memref<f32> return } @@ -18,7 +19,8 @@ func @print_1d() { %A = alloc() : memref<16xf32> %B = memref_cast %A: memref<16xf32> to memref<?xf32> linalg.fill(%B, %f) : memref<?xf32>, f32 - call @print_memref_1d_f32(%B): (memref<?xf32>) -> () + %U = memref_cast %B : memref<?xf32> to memref<*xf32> + call @print_memref_f32(%U): (memref<*xf32>) -> () dealloc %A : memref<16xf32> return } @@ -34,8 +36,8 @@ func @print_3d() { %c2 = constant 2 : index store %f4, %B[%c2, %c2, %c2]: memref<?x?x?xf32> - - call @print_memref_3d_f32(%B): (memref<?x?x?xf32>) -> () + %U = memref_cast %B : memref<?x?x?xf32> to memref<*xf32> + call @print_memref_f32(%U): (memref<*xf32>) -> () dealloc %A : memref<3x4x5xf32> return } @@ -46,10 +48,7 @@ func @print_3d() { // PRINT-3D-NEXT: 2, 2, 4, 2, 2 // PRINT-3D-NEXT: 2, 2, 2, 2, 2 -func @print_memref_0d_f32(memref<f32>) -func @print_memref_1d_f32(memref<?xf32>) -func @print_memref_3d_f32(memref<?x?x?xf32>) - +func @print_memref_f32(memref<*xf32>) !vector_type_C = type vector<4x4xf32> !matrix_type_CC = type memref<1x1x!vector_type_C> |