author     Fabian Mora <fmora.dev@gmail.com>    2025-01-13 16:11:33 -0500
committer  GitHub <noreply@github.com>          2025-01-13 16:11:33 -0500
commit     0c1c49f0ff8003aee22c3f26fca03c2f5385f355 (patch)
tree       15cb6fbbe49c8261f39c7c6eed9d2653873f1fca
parent     ec3525f7844878767b70b78753affbe44acfa9ed (diff)
[mlir][AMDGPU] Fix raw buffer ptr ops lowering (#122293)
This patch fixes several bugs in the lowering of AMDGPU raw buffer
operations. These bugs include:
- Incorrect handling of the memref offset, causing errors when lowering
  subviews (illustrated in the example below).
- Use of MaximumOp (a float-specific op) to compute the number of
  records.
- An incorrect number of records in the static-shape case.
- Broken lowering when the index bitwidth is i64.
Furthermore, this patch switches to MLIR's data layout to obtain type
sizes.
---------
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
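An illustration of the offset fix (a hypothetical example, not taken from this patch or its tests): a memref.subview carries its offset in the strided layout, and the old lowering built the buffer resource from the aligned pointer while adding the element offset to the SGPR offset without scaling it to bytes. The fixed lowering instead folds the offset into the base pointer via MemRefDescriptor::bufferPtr.

    func.func @load_through_subview(%buf: memref<64x64xi32>, %i: i32, %j: i32) -> i32 {
      // The subview's layout is strided<[64, 1], offset: 520>, where 520 = 8 * 64 + 8.
      %sub = memref.subview %buf[8, 8] [32, 32] [1, 1]
          : memref<64x64xi32> to memref<32x32xi32, strided<[64, 1], offset: 520>>
      // With the fix, the pointer passed to rocdl.make.buffer.rsrc is the
      // aligned pointer advanced by the 520-element descriptor offset.
      %v = amdgpu.raw_buffer_load {boundsCheck = true} %sub[%i, %j]
          : memref<32x32xi32, strided<[64, 1], offset: 520>>, i32, i32 -> i32
      func.return %v : i32
    }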
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp       123
-rw-r--r--  mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir    46
2 files changed, 103 insertions, 66 deletions
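A worked example of the static-shape fix (hypothetical IR, not part of the patch): the static path now also accepts non-identity layouts with static strides, sizing the buffer as the maximum over all dimensions of shape[i] * strides[i], scaled by the element byte width. For a 16x16 i32 memref with row-major strides [16, 1], that is max(16 * 16, 16 * 1) = 256 elements, so numRecords = 256 * 4 = 1024 bytes.

    func.func @load_static(%buf: memref<16x16xi32>, %i: i32, %j: i32) -> i32 {
      // Expected to lower with numRecords emitted as an i32 constant of 1024.
      %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i, %j]
          : memref<16x16xi32>, i32, i32 -> i32
      func.return %v : i32
    }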
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4100b08..1564e41 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -30,10 +30,23 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+/// Convert an unsigned number `val` to i32.
+static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value val) {
+  IntegerType i32 = rewriter.getI32Type();
+  // Force check that `val` is of int type.
+  auto valTy = cast<IntegerType>(val.getType());
+  if (i32 == valTy)
+    return val;
+  return valTy.getWidth() > 32
+             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
+             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+}
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
-  Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  Type i32 = rewriter.getI32Type();
+  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
 }
 
 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,27 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
 }
 
+/// Returns the linear index used to access an element in the memref.
+static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
+                               Location loc, MemRefDescriptor &memRefDescriptor,
+                               ValueRange indices, ArrayRef<int64_t> strides) {
+  IntegerType i32 = rewriter.getI32Type();
+  Value index;
+  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
+    if (stride != 1) { // Skip if stride is 1.
+      Value strideValue =
+          ShapedType::isDynamic(stride)
+              ? convertUnsignedToI32(rewriter, loc,
+                                     memRefDescriptor.stride(rewriter, loc, i))
+              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+  return index ? index : createI32Constant(rewriter, loc, 0);
+}
+
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
@@ -88,17 +122,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
     Type i32 = rewriter.getI32Type();
-    Type llvmI32 = this->typeConverter->convertType(i32);
-    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
+    Type i16 = rewriter.getI16Type();
 
-    auto toI32 = [&](Value val) -> Value {
-      if (val.getType() == llvmI32)
-        return val;
-
-      return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
-    };
-
-    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
+    // Get the type size in bytes.
+    DataLayout dataLayout = DataLayout::closest(gpuOp);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
 
     // If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +143,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     }
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t vecLen = dataVector.getNumElements();
-      uint32_t elemBits = dataVector.getElementTypeBitWidth();
+      uint32_t elemBits =
+          dataLayout.getTypeSizeInBits(dataVector.getElementType());
       uint32_t totalBits = elemBits * vecLen;
       bool usePackedFp16 =
           isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
@@ -167,28 +197,36 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     MemRefDescriptor memrefDescriptor(memref);
 
-    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+    Value ptr = memrefDescriptor.bufferPtr(
+        rewriter, loc, *this->getTypeConverter(), memrefType);
     // The stride value is always 0 for raw buffers. This also disables
     // swizling.
     Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmI16, rewriter.getI16IntegerAttr(0));
+        loc, i16, rewriter.getI16IntegerAttr(0));
+    // Get the number of elements.
    Value numRecords;
-    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
-      numRecords = createI32Constant(
-          rewriter, loc,
-          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+    if (memrefType.hasStaticShape() &&
+        !llvm::any_of(strides, ShapedType::isDynamic)) {
+      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+      ArrayRef<int64_t> shape = memrefType.getShape();
+      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+        size = std::max(shape[i] * strides[i], size);
+      size = size * elementByteWidth;
+      assert(size < std::numeric_limits<uint32_t>::max() &&
+             "the memref buffer is too large");
+      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
     } else {
       Value maxIndex;
       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
-        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
-        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
+        Value size = memrefDescriptor.size(rewriter, loc, i);
+        Value stride = memrefDescriptor.stride(rewriter, loc, i);
         Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
-                                                               maxThisDim)
-                            : maxThisDim;
+        maxIndex =
+            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                     : maxThisDim;
       }
-      numRecords = maxIndex;
+      numRecords = rewriter.create<LLVM::MulOp>(
+          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
     }
 
     // Flag word:
@@ -218,40 +256,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     args.push_back(resource);
 
     // Indexing (voffset)
-    Value voffset = createI32Constant(rewriter, loc, 0);
-    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
-      size_t i = pair.index();
-      Value index = pair.value();
-      Value strideOp;
-      if (ShapedType::isDynamic(strides[i])) {
-        strideOp = rewriter.create<LLVM::MulOp>(
-            loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
-            byteWidthConst);
-      } else {
-        strideOp =
-            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
-      }
-      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
-      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
-    }
-    if (adaptor.getIndexOffset()) {
-      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
-      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
+    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
+                                      adaptor.getIndices(), strides);
+    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
+        indexOffset && *indexOffset > 0) {
+      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
       voffset = voffset ? rewriter.create<LLVM::AddOp>(loc, voffset,
                                                        extraOffsetConst)
                         : extraOffsetConst;
     }
+    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
     args.push_back(voffset);
 
+    // SGPR offset.
     Value sgprOffset = adaptor.getSgprOffset();
     if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
-    if (ShapedType::isDynamic(offset))
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
-    else if (offset > 0)
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
+    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
     args.push_back(sgprOffset);
 
     // bit 0: GLC = 0 (atomics drop value, less coherency)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 4c7515d..af63316 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -31,21 +31,37 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
 }
 
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
-func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
-  // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
-  // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
-  // CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
-  // CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
-  // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
-  // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
-  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
-  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
-  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
-  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
-  // CHECK: return %[[ret]]
-  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32, strided<[?], offset: ?>>, i32 -> i32
+func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[?, ?], offset: ?>>, %i: i32, %j: i32) -> i32 {
+  // CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<16x16xi32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64
+  // CHECK: %[[sz_j:.*]] = llvm.extractvalue %[[descriptor]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_j:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
+  // CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
+  // CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
+  // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size]] : i32
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
+  // CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
+  // CHECK: %[[t_0:.*]] = llvm.mul %{{.*}}, %[[stride_i_i32]] : i32
+  // CHECK: %[[stride_j_1:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_j_i32:.*]] = llvm.trunc %[[stride_j_1]] : i64 to i32
+  // CHECK: %[[t_1:.*]] = llvm.mul %{{.*}}, %[[stride_j_i32]] : i32
+  // CHECK: %[[index:.*]] = llvm.add %[[t_0]], %[[t_1]] : i32
+  // CHECK: %[[vgpr_off:.*]] = llvm.mul %[[index]], %[[elem_size]] : i32
+  // CHECK: %[[zero_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[sgpr_off:.*]] = llvm.mul %[[zero_0]], %[[elem_size]] : i32
+  // CHECK: %[[zero_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[v:.*]] = rocdl.raw.ptr.buffer.load %[[rsrc]], %[[vgpr_off]], %[[sgpr_off]], %[[zero_1]] : i32
+  // CHECK: return %[[v]] : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i, %j] : memref<16x16xi32, strided<[?, ?], offset: ?>>, i32, i32 -> i32
   func.return %0 : i32
 }
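A design note on the rewritten indexing, distilled from the CHECK lines above: the linear index is now computed in elements and scaled to bytes with a single multiply at the end, rather than scaling every per-dimension stride by the element byte width as the old code did (an approach that could also mix i32 and i64 values when the index bitwidth was 64). Schematically, the emitted LLVM-dialect fragment looks like this (SSA names are illustrative):

    %t_0      = llvm.mul %i, %stride_i_i32 : i32    // index * stride, in elements
    %t_1      = llvm.mul %j, %stride_j_i32 : i32
    %index    = llvm.add %t_0, %t_1 : i32           // linear element index
    %vgpr_off = llvm.mul %index, %elem_size : i32   // one byte-scaling multiply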