diff options
author | Guray Ozen <guray.ozen@gmail.com> | 2023-12-08 15:55:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-08 15:55:14 +0100 |
commit | c65d8c71878361d441008a85f0c99305d9e3aff8 (patch) | |
tree | f71f883841557acbfc5e7a3819ca5404823f540e | |
parent | 3894a11acd1c8907c4d8a19a96d1ea398463af65 (diff) | |
download | llvm-c65d8c71878361d441008a85f0c99305d9e3aff8.zip llvm-c65d8c71878361d441008a85f0c99305d9e3aff8.tar.gz llvm-c65d8c71878361d441008a85f0c99305d9e3aff8.tar.bz2 |
[mlir][memref] extract_strided_metadata for zero-sized memref (#74835)
-rw-r--r-- | mlir/lib/Dialect/Utils/IndexingUtils.cpp | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/MemRef/expand-strided-metadata.mlir | 20 |
2 files changed, 21 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp index f4e29539..bb8a0d5 100644 --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -70,7 +70,7 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex, //===----------------------------------------------------------------------===// SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) { - assert(llvm::all_of(sizes, [](int64_t s) { return s > 0; }) && + assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) && "sizes must be nonnegative"); int64_t unit = 1; return ::computeSuffixProductImpl(sizes, unit); diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index ab0c78a..28b7004 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1494,3 +1494,23 @@ func.func @extract_strided_metadata_of_cast_unranked( index, index, index, index } + + +// ----- +memref.global "private" @dynamicShmem : memref<0xf16,3> + +// CHECK-LABEL: func @zero_sized_memred +func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index) { + %c0 = arith.constant 0 : index + %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3> + + // CHECK: %[[BASE:.*]] = memref.get_global @dynamicShmem : memref<0xf16, 3> + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [], strides: [] : memref<0xf16, 3> to memref<f16, 3> + // CHECK: return %[[CAST]] + + %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %dynamicMem : memref<0xf16, 3> -> memref<f16, 3>, index, index, index + return %base_buffer, %offset, + %sizes, %strides : + memref<f16,3>, index, + index, index +}
\ No newline at end of file |