aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuray Ozen <guray.ozen@gmail.com>2023-12-08 15:55:14 +0100
committerGitHub <noreply@github.com>2023-12-08 15:55:14 +0100
commitc65d8c71878361d441008a85f0c99305d9e3aff8 (patch)
treef71f883841557acbfc5e7a3819ca5404823f540e
parent3894a11acd1c8907c4d8a19a96d1ea398463af65 (diff)
downloadllvm-c65d8c71878361d441008a85f0c99305d9e3aff8.zip
llvm-c65d8c71878361d441008a85f0c99305d9e3aff8.tar.gz
llvm-c65d8c71878361d441008a85f0c99305d9e3aff8.tar.bz2
[mlir][memref] extract_strided_metadata for zero-sized memref (#74835)
-rw-r--r--mlir/lib/Dialect/Utils/IndexingUtils.cpp2
-rw-r--r--mlir/test/Dialect/MemRef/expand-strided-metadata.mlir20
2 files changed, 21 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index f4e29539..bb8a0d5 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -70,7 +70,7 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
//===----------------------------------------------------------------------===//
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
- assert(llvm::all_of(sizes, [](int64_t s) { return s > 0; }) &&
+ assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
"sizes must be nonnegative");
int64_t unit = 1;
return ::computeSuffixProductImpl(sizes, unit);
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index ab0c78a..28b7004 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1494,3 +1494,23 @@ func.func @extract_strided_metadata_of_cast_unranked(
index, index,
index, index
}
+
+
+// -----
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+
+// CHECK-LABEL: func @zero_sized_memred
+func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index) {
+ %c0 = arith.constant 0 : index
+ %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+
+ // CHECK: %[[BASE:.*]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+ // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [], strides: [] : memref<0xf16, 3> to memref<f16, 3>
+ // CHECK: return %[[CAST]]
+
+ %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %dynamicMem : memref<0xf16, 3> -> memref<f16, 3>, index, index, index
+ return %base_buffer, %offset,
+ %sizes, %strides :
+ memref<f16,3>, index,
+ index, index
+} \ No newline at end of file