aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarkus Böck <markus.boeck02@gmail.com>2023-02-05 14:58:06 +0100
committerMarkus Böck <markus.boeck02@gmail.com>2023-02-06 12:10:07 +0100
commiteafca2303769800f5da4bc4cbf9e842c6a8cde9f (patch)
treee0b9252e01660c9769b431ef7f32f61fc3344b8e
parent6fa2abf90886f18472c87bc9bffbcdf4f73c465e (diff)
downloadllvm-eafca2303769800f5da4bc4cbf9e842c6a8cde9f.zip
llvm-eafca2303769800f5da4bc4cbf9e842c6a8cde9f.tar.gz
llvm-eafca2303769800f5da4bc4cbf9e842c6a8cde9f.tar.bz2
[mlir][MemRef] Add required address space cast when lowering alloc to LLVM
alloc uses either `malloc` or a pluggable allocation function for allocating the required memory. Both of these functions always return a `llvm.ptr<i8>`, aka a pointer in the default address space. When allocating for a memref in a different memory space however, no address space cast is created, leading to invalid LLVM IR being generated. This is currently not caught by the verifier since the pointer to the memory is always bitcast which currently lacks a verifier disallowing address space casts. Translating to actual LLVM IR would cause the verifier to go off, since bitcast cannot translate from one address space to another: https://godbolt.org/z/3a1z97rc9 This patch fixes that issue by generating an address space cast if the address space of the allocation function does not match the address space of the resulting memref. Not sure whether this is actually a real-life problem. I found this issue while converting the pass to using opaque pointers which gets rid of all the bitcasts and hence caused type errors without the address space cast. Differential Revision: https://reviews.llvm.org/D143341
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp28
-rw-r--r--mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir5
2 files changed, 28 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 4a5be48..8d99e1f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -50,6 +50,23 @@ Value AllocationOpLLVMLowering::createAligned(
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+ Location loc, Value allocatedPtr,
+ MemRefType memRefType, Type elementPtrType,
+ LLVMTypeConverter &typeConverter) {
+ auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
+ if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt())
+ allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+ loc,
+ LLVM::LLVMPointerType::get(allocatedPtrTy.getElementType(),
+ memRefType.getMemorySpaceAsInt()),
+ allocatedPtr);
+
+ allocatedPtr =
+ rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
+ return allocatedPtr;
+}
+
std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
Operation *op, Value alignment) const {
@@ -64,8 +81,10 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
- Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
- results.getResult());
+
+ Value allocatedPtr =
+ castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+ elementPtrType, *getTypeConverter());
Value alignedPtr = allocatedPtr;
if (alignment) {
@@ -126,10 +145,9 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
- Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
- results.getResult());
- return allocatedPtr;
+ return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+ elementPtrType, *getTypeConverter());
}
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index b6c73e9..3f61f6d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -182,6 +182,11 @@ func.func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
// CHECK-LABEL: func @address_space(
func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
+ // CHECK: %[[MEMORY:.*]] = llvm.call @malloc(%{{.*}})
+ // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[MEMORY]] : !llvm.ptr<i8> to !llvm.ptr<i8, 5>
+ // CHECK: %[[BCAST:.*]] = llvm.bitcast %[[CAST]]
+ // CHECK: llvm.insertvalue %[[BCAST]], %{{[[:alnum:]]+}}[0]
+ // CHECK: llvm.insertvalue %[[BCAST]], %{{[[:alnum:]]+}}[1]
%0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
%1 = arith.constant 7 : index
// CHECK: llvm.load %{{.*}} : !llvm.ptr<f32, 5>