//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#include <limits>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
    }
    index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef strides, uint32_t elementByteWidth) { if (memrefType.hasStaticShape() && !llvm::any_of(strides, ShapedType::isDynamic)) { int64_t size = memrefType.getRank() == 0 ? 1 : 0; ArrayRef shape = memrefType.getShape(); for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) size = std::max(shape[i] * strides[i], size); size = size * elementByteWidth; assert(size < std::numeric_limits::max() && "the memref buffer is too large"); return createI32Constant(rewriter, loc, static_cast(size)); } Value maxIndex; for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); maxIndex = maxIndex ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) : maxThisDim; } Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst); } static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride = nullptr, unsigned addressSpace = 8) { // The stride value is generally 0. However, on MI-300 and onward, you can // enable a cache swizzling mode by setting bit 14 of the stride field // and setting that stride to a cache stride. Type i16 = rewriter.getI16Type(); Value stride; if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { Value cacheStrideZext = LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); Value swizzleBit = LLVM::ConstantOp::create( rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, /*isDisjoint=*/true); } else { stride = LLVM::ConstantOp::create(rewriter, loc, i16, rewriter.getI16IntegerAttr(0)); } // Get the number of elements. // Flag word: // bits 0-11: dst sel, ignored by these intrinsics // bits 12-14: data format (ignored, must be nonzero, 7=float) // bits 15-18: data format (ignored, must be nonzero, 4=32bit) // bit 19: In nested heap (0 here) // bit 20: Behavior on unmap (0 means "return 0 / ignore") // bits 21-22: Index stride for swizzles (N/A) // bit 23: Add thread ID (0) // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) // bits 25-26: Reserved (0) // bit 27: Buffer is non-volatile (CDNA only) // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = // none, 3 = either swizzles or testing against offset field) RDNA only // bits 30-31: Type (must be 0) uint32_t flags = (7 << 12) | (4 << 15); if (chipset.majorVersion >= 10) { flags |= (1 << 24); uint32_t oob = boundsCheck ? 
3 : 2; flags |= (oob << 28); } Value flagsConst = createI32Constant(rewriter, loc, flags); Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); Value resource = rewriter.createOrFold( loc, rsrcType, basePointer, stride, numRecords, flagsConst); return resource; } namespace { struct FatRawBufferCastLowering : public ConvertOpToLLVMPattern { FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value memRef = adaptor.getSource(); Value unconvertedMemref = op.getSource(); MemRefType memrefType = cast(unconvertedMemref.getType()); MemRefDescriptor descriptor(memRef); DataLayout dataLayout = DataLayout::closest(op); int64_t elementByteWidth = dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; int64_t unusedOffset = 0; SmallVector strideVals; if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset))) return op.emitOpError("Can't lower non-stride-offset memrefs"); Value numRecords = adaptor.getValidBytes(); if (!numRecords) numRecords = getNumRecords(rewriter, loc, memrefType, descriptor, strideVals, elementByteWidth); Value basePointer = adaptor.getResetOffset() ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType) : descriptor.alignedPtr(rewriter, loc); Value offset = adaptor.getResetOffset() ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), rewriter.getIndexAttr(0)) : descriptor.offset(rewriter, loc); bool hasSizes = memrefType.getRank() > 0; // No need to unpack() and pack() all the individual sizes and strides, // so we'll just extract the arrays. Value sizes = hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, kSizePosInMemRefDescriptor) : Value{}; Value strides = hasSizes ? 
LLVM::ExtractValueOp::create(rewriter, loc, descriptor, kStridePosInMemRefDescriptor) : Value{}; Value fatPtr = makeBufferRsrc( rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7); Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); SmallVector pos{kAllocatedPtrPosInMemRefDescriptor}; result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, kOffsetPosInMemRefDescriptor); if (hasSizes) { result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, kSizePosInMemRefDescriptor); result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, kStridePosInMemRefDescriptor); } rewriter.replaceOp(op, result); return success(); } }; /// Define lowering patterns for raw buffer ops template struct RawBufferOpLowering : public ConvertOpToLLVMPattern { RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; static constexpr uint32_t maxVectorOpWidth = 128; LogicalResult matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = gpuOp.getLoc(); Value memref = adaptor.getMemref(); Value unconvertedMemref = gpuOp.getMemref(); MemRefType memrefType = cast(unconvertedMemref.getType()); if (chipset.majorVersion < 9) return gpuOp.emitOpError("raw buffer ops require GCN or higher"); Value storeData = adaptor.getODSOperands(0)[0]; if (storeData == memref) // no write component to this op storeData = Value(); Type wantedDataType; if (storeData) wantedDataType = storeData.getType(); else wantedDataType = gpuOp.getODSResults(0)[0].getType(); Value atomicCmpData = Value(); // Operand index 1 of a load is the indices, trying to read them can crash. if (storeData) { Value maybeCmpData = adaptor.getODSOperands(1)[0]; if (maybeCmpData != memref) atomicCmpData = maybeCmpData; } Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); Type i32 = rewriter.getI32Type(); // Get the type size in bytes. DataLayout dataLayout = DataLayout::closest(gpuOp); int64_t elementByteWidth = dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); // If we want to load a vector with total size <= 32 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 // and the total load size is >= 32, use a vector load of N / (bitsize(T) / // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, // so bitcast any floats to integers. Type llvmBufferValType = llvmWantedDataType; if (atomicCmpData) { if (auto floatType = dyn_cast(wantedDataType)) llvmBufferValType = this->getTypeConverter()->convertType( rewriter.getIntegerType(floatType.getWidth())); } if (auto dataVector = dyn_cast(wantedDataType)) { uint32_t vecLen = dataVector.getNumElements(); uint32_t elemBits = dataLayout.getTypeSizeInBits(dataVector.getElementType()); uint32_t totalBits = elemBits * vecLen; bool usePackedFp16 = isa_and_present(*gpuOp) && vecLen == 2; if (totalBits > maxVectorOpWidth) return gpuOp.emitOpError( "Total width of loads or stores must be no more than " + Twine(maxVectorOpWidth) + " bits, but we call for " + Twine(totalBits) + " bits. 
This should've been caught in validation"); if (!usePackedFp16 && elemBits < 32) { if (totalBits > 32) { if (totalBits % 32 != 0) return gpuOp.emitOpError("Load or store of more than 32-bits that " "doesn't fit into words. Can't happen\n"); llvmBufferValType = this->typeConverter->convertType( VectorType::get(totalBits / 32, i32)); } else { llvmBufferValType = this->typeConverter->convertType( rewriter.getIntegerType(totalBits)); } } } if (auto vecType = dyn_cast(llvmBufferValType)) { // Buffer intrinsics doesn't support 1-element vectors, cast them to // scalars. if (vecType.getNumElements() == 1) llvmBufferValType = vecType.getElementType(); } SmallVector args; if (storeData) { if (llvmBufferValType != llvmWantedDataType) { Value castForStore = LLVM::BitcastOp::create( rewriter, loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); } } if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { Value castForCmp = LLVM::BitcastOp::create( rewriter, loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { args.push_back(atomicCmpData); } } // Construct buffer descriptor from memref, attributes int64_t offset = 0; SmallVector strides; if (failed(memrefType.getStridesAndOffset(strides, offset))) return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); MemRefDescriptor memrefDescriptor(memref); Value ptr = memrefDescriptor.bufferPtr( rewriter, loc, *this->getTypeConverter(), memrefType); Value numRecords = getNumRecords( rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth); Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords, adaptor.getBoundsCheck(), chipset); args.push_back(resource); // Indexing (voffset) Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor, adaptor.getIndices(), strides); if (std::optional indexOffset = adaptor.getIndexOffset(); indexOffset && *indexOffset > 0) { Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset, extraOffsetConst) : extraOffsetConst; } voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); args.push_back(voffset); // SGPR offset. Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) // bits 1-2: SLC, DLC = 0 (similarly) // bit 3: swizzled (0 for raw) args.push_back(createI32Constant(rewriter, loc, 0)); llvm::SmallVector resultTypes(gpuOp->getNumResults(), llvmBufferValType); Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, ArrayRef()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, replacement); } rewriter.replaceOp(gpuOp, replacement); } else { rewriter.eraseOp(gpuOp); } return success(); } }; // TODO: AMDGPU backend already have all this bitpacking logic, we should move // it to some common place. 
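
// Illustrative check (added commentary, not part of the original lowering):
// with the bit layout documented in makeBufferRsrc above, the flag word for a
// bounds-checked buffer on an RDNA (gfx10+) target works out to
//   (7 << 12) | (4 << 15) | (1 << 24) | (3 << 28) = 0x31027000.
static_assert(((7u << 12) | (4u << 15) | (1u << 24) | (3u << 28)) ==
                  0x31027000u,
              "worked example of the gfx10+ bounds-checked rsrc flag word");
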
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
///     Vmcnt = Waitcnt[3:0]        (pre-gfx9)
///     Vmcnt = Waitcnt[15:14,3:0]  (gfx9,10)
///     Vmcnt = Waitcnt[15:10]      (gfx11)
///     Expcnt = Waitcnt[6:4]       (pre-gfx11)
///     Expcnt = Waitcnt[2:0]       (gfx11)
///     Lgkmcnt = Waitcnt[11:8]     (pre-gfx10)
///     Lgkmcnt = Waitcnt[13:8]     (gfx10)
///     Lgkmcnt = Waitcnt[9:4]      (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}

struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (auto ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (auto load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (auto store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (auto exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      rewriter.eraseOp(op);
      return success();
    }

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
return 1024; }; unsigned ds = getVal(adaptor.getDsAttr()); unsigned exp = getVal(adaptor.getExpAttr()); unsigned vmcnt = 1024; Attribute load = adaptor.getLoadAttr(); Attribute store = adaptor.getStoreAttr(); if (load && store) { vmcnt = getVal(load) + getVal(store); } else if (load) { vmcnt = getVal(load); } else if (store) { vmcnt = getVal(store); } FailureOr waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds); if (failed(waitcnt)) return op.emitOpError("unsupported chipset"); rewriter.replaceOpWithNewOp(op, *waitcnt); return success(); } }; struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11; if (requiresInlineAsm) { auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier"; const char *constraints = ""; rewriter.replaceOpWithNewOp( op, /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, /*is_align_stack=*/false, LLVM::TailCallKind::None, /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr()); return success(); } if (chipset.majorVersion < 12) { constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8); constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8); // Left in place in case someone disables the inline ASM path or future // chipsets use the same bit pattern. constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4); int32_t ldsOnlyBits; if (chipset.majorVersion == 11) ldsOnlyBits = ldsOnlyBitsGfx11; else if (chipset.majorVersion == 10) ldsOnlyBits = ldsOnlyBitsGfx10; else if (chipset.majorVersion <= 9) ldsOnlyBits = ldsOnlyBitsGfx6789; else return op.emitOpError( "don't know how to lower this for chipset major version") << chipset.majorVersion; Location loc = op->getLoc(); ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits); rewriter.replaceOpWithNewOp(op); } else { Location loc = op->getLoc(); ROCDL::WaitDscntOp::create(rewriter, loc, 0); ROCDL::BarrierSignalOp::create(rewriter, loc, -1); rewriter.replaceOpWithNewOp(op, -1); } return success(); } }; struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern { SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, (uint32_t)op.getOpts()); return success(); } }; } // namespace /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL /// and LLVM AMDGPU intrinsics convention. /// /// Specifically: /// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic /// allows bf16. Newer MFMAs support bf16 types on operand, check /// IntrinsicsAMDGPU.td file for reference. /// 2. If instead we have a more than 64-bit quantity, use a /// instead, which is what the f8f6f4 intrinsics use. /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit /// integer. 
/// /// Note that the type of `input` has already been LLVM type converted: /// therefore 8-bit and smaller floats are represented as their corresponding /// `iN` integers. static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16 = true) { Type inputType = input.getType(); if (auto vectorType = dyn_cast(inputType)) { if (vectorType.getElementType().isBF16() && !allowBf16) return LLVM::BitcastOp::create( rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && vectorType.getNumElements() <= 8) return LLVM::BitcastOp::create( rewriter, loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); if (isa(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); return LLVM::BitcastOp::create( rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input); } } return input; } /// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU /// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention. /// /// Specifically: /// 1. If `input` is a i8 value, zero extend it to i32 /// 2. If `input` is a vector of length 4 and type i8, cast it to i32 /// /// Note that the type of `input` has already been LLVM type converted: /// therefore 8-bit and smaller floats are represented as their corresponding /// `iN` integers. static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input) { Type inputType = input.getType(); Type outputType = rewriter.getI32Type(); if (auto intType = dyn_cast(inputType)) return LLVM::ZExtOp::create(rewriter, loc, outputType, input); return LLVM::BitcastOp::create(rewriter, loc, outputType, input); } /// Push an input operand. If it is a float type, nothing to do. If it is /// an integer type, then we need to also push its signdness (1 for signed, 0 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 /// vector (or the 8xi8 vector into a 2xi32 one for gfx12+). /// We also need to convert bfloat inputs to i16 to account for the bfloat /// intrinsics having been defined before the AMD backend supported bfloat. We /// similarly need to pack 8-bit float types into integers as if they were i8 /// (which they are for the backend's purposes). static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector &operands) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast(inputType); if (!vectorType) { operands.push_back(llvmInput); return; } Type elemType = vectorType.getElementType(); if (elemType.isBF16()) llvmInput = LLVM::BitcastOp::create( rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; } // We need to check the type of the input before conversion to properly test // for int8. This is because, in LLVM, fp8 type is converted to int8, so the // fp8/int8 information is lost during the conversion process. 
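  // Added illustration (not from the original): after LLVM type conversion, an
  // MLIR vector<16xf8E4M3FN> and a vector<16xi8> both arrive here as the same
  // vector<16xi8> value, so only the unconverted `mlirInput` type still tells
  // an fp8 operand apart from a true integer operand.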
auto mlirInputType = cast(mlirInput.getType()); bool isInputInteger = mlirInputType.getElementType().isInteger(); if (isInputInteger) { // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag bool localIsUnsigned = isUnsigned; if (elemType.isUnsignedInteger()) { localIsUnsigned = true; } else if (elemType.isSignedInteger()) { localIsUnsigned = false; } Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); operands.push_back(sign); } int64_t numBits = vectorType.getNumElements() * elemType.getIntOrFloatBitWidth(); Type i32 = rewriter.getI32Type(); Type intrinsicInType = numBits <= 32 ? (Type)rewriter.getIntegerType(numBits) : (Type)VectorType::get(numBits / 32, i32); auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType); Value castInput = rewriter.createOrFold( loc, llvmIntrinsicInType, llvmInput); // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. // Add in the zeros here. if (numBits < 32) castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); operands.push_back(castInput); } /// Push the output operand. For many cases this is only pushing the output in /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, /// since the same numbers of VGPRs is used, we need to decide if to store the /// result in the upper 16 bits of the VGPRs or in the lower part. To store the /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will /// be stored it in the upper part. The subwordOffset must not be set for gfx12, /// as the instructions have been changed to return fewer registers instead. static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector &operands) { Type inputType = output.getType(); auto vectorType = dyn_cast(inputType); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) output = LLVM::BitcastOp::create( rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); } else if (elemType.isInteger(32)) { operands.push_back(createI1Constant(rewriter, loc, clamp)); } } /// Return true if `type` is the E5M2 variant of an 8-bit float that is /// supported by the `_bf8` instructions on the given `chipset`. static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) { return (chipset == kGfx942 && isa(type)) || (hasOcpFp8(chipset) && isa(type)); } /// Return true if `type` is the E4M3FN variant of an 8-bit float that is /// supported by the `_fp8` instructions on the given `chipset`. static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) { return (chipset == kGfx942 && isa(type)) || (hasOcpFp8(chipset) && isa(type)); } /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. 
static std::optional mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset) { uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), b = mfma.getBlocks(); Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType()); Type destElem = getElementTypeOrSelf(mfma.getDestC().getType()); if (sourceElem.isF32() && destElem.isF32()) { if (mfma.getReducePrecision() && chipset >= kGfx942) { if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) return ROCDL::mfma_f32_16x16x8_xf32::getOperationName(); } if (m == 32 && n == 32 && k == 1 && b == 2) return ROCDL::mfma_f32_32x32x1f32::getOperationName(); if (m == 16 && n == 16 && k == 1 && b == 4) return ROCDL::mfma_f32_16x16x1f32::getOperationName(); if (m == 4 && n == 4 && k == 1 && b == 16) return ROCDL::mfma_f32_4x4x1f32::getOperationName(); if (m == 32 && n == 32 && k == 2 && b == 1) return ROCDL::mfma_f32_32x32x2f32::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f32_16x16x4f32::getOperationName(); } if (sourceElem.isF16() && destElem.isF32()) { if (chipset >= kGfx950) { if (m == 32 && n == 32 && k == 16 && b == 1) return ROCDL::mfma_f32_32x32x16_f16::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1) return ROCDL::mfma_f32_16x16x32_f16::getOperationName(); } if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4f16::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_f32_16x16x4f16::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_f32_4x4x4f16::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_f32_32x32x8f16::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_f32_16x16x16f16::getOperationName(); } if (sourceElem.isBF16() && destElem.isF32()) { if (chipset >= kGfx950) { if (m == 32 && n == 32 && k == 16 && b == 1) return ROCDL::mfma_f32_32x32x16_bf16::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1) return ROCDL::mfma_f32_16x16x32_bf16::getOperationName(); } if (chipset >= kGfx90a) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName(); } if (m == 32 && n == 32 && k == 2 && b == 2) return ROCDL::mfma_f32_32x32x2bf16::getOperationName(); if (m == 16 && n == 16 && k == 2 && b == 4) return ROCDL::mfma_f32_16x16x2bf16::getOperationName(); if (m == 4 && n == 4 && k == 2 && b == 16) return ROCDL::mfma_f32_4x4x2bf16::getOperationName(); if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4bf16::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); } if (sourceElem.isInteger(8) && destElem.isInteger(32)) { if (chipset >= kGfx950) { if (m == 32 && n == 32 && k == 32 && b == 1) return ROCDL::mfma_i32_32x32x32_i8::getOperationName(); if (m == 16 && n == 16 && k == 64 && b == 1) return ROCDL::mfma_i32_16x16x64_i8::getOperationName(); } if (m == 32 && n == 32 && k == 4 && b == 2) return 
ROCDL::mfma_i32_32x32x4i8::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_i32_16x16x4i8::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_i32_4x4x4i8::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_i32_32x32x8i8::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_i32_16x16x16i8::getOperationName(); if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942) return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942) return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); } if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) { if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f64_16x16x4f64::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 4) return ROCDL::mfma_f64_4x4x4f64::getOperationName(); } if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) { // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); } } if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) { Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { if (typeIsExpectedBf8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); if (typeIsExpectedFp8ForChipset(chipset, sourceBElem)) return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); } } return std::nullopt; } static std::optional mfmaTypeSelectCode(Type mlirElemType) { return llvm::TypeSwitch>(mlirElemType) .Case([](Float8E4M3FNType) { return 0u; }) .Case([](Float8E5M2Type) { return 1u; }) .Case([](Float6E2M3FNType) { return 2u; }) .Case([](Float6E3M2FNType) { return 3u; }) .Case([](Float4E2M1FNType) { return 4u; }) .Default([](Type) { return std::nullopt; }); } /// If there is a scaled MFMA instruction for the input element types `aType` /// and `bType`, output type `destType`, problem size M, N, K, and B (number of /// blocks) on the given `chipset`, return a tuple consisting of the /// OperationName of the intrinsic and the type codes that need to be passed to /// that intrinsic. Note that this is also used to implement some un-scaled /// MFMAs, since the compiler represents the ordinary instruction as a "scaled" /// MFMA with a scale of 0. 
static std::optional> mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset) { aType = getElementTypeOrSelf(aType); bType = getElementTypeOrSelf(bType); destType = getElementTypeOrSelf(destType); if (chipset < kGfx950) return std::nullopt; if (!isa(destType)) return std::nullopt; std::optional aTypeCode = mfmaTypeSelectCode(aType); std::optional bTypeCode = mfmaTypeSelectCode(bType); if (!aTypeCode || !bTypeCode) return std::nullopt; if (m == 32 && n == 32 && k == 64 && b == 1) return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), *aTypeCode, *bTypeCode}; if (m == 16 && n == 16 && k == 128 && b == 1) return std::tuple{ ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode, *bTypeCode}; return std::nullopt; } static std::optional> mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) { return mfmaOpToScaledIntrinsic( mfma.getSourceA().getType(), mfma.getSourceB().getType(), mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(), mfma.getBlocks(), chipset); } static std::optional> mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(), smfma.getSourceB().getType(), smfma.getDestC().getType(), smfma.getM(), smfma.getN(), smfma.getK(), 1u, chipset); } /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. static std::optional wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { auto sourceVectorType = dyn_cast(wmma.getSourceA().getType()); auto sourceBVectorType = dyn_cast(wmma.getSourceB().getType()); auto destVectorType = dyn_cast(wmma.getDestC().getType()); auto elemSourceType = sourceVectorType.getElementType(); auto elemBSourceType = sourceBVectorType.getElementType(); auto elemDestType = destVectorType.getElementType(); if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); if (elemSourceType.isF16() && elemDestType.isF16()) return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isBF16()) return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); if (chipset.majorVersion == 11) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); } if (chipset.majorVersion >= 12) { if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) { bool isWave64 = destVectorType.getNumElements() == 4; // This is the ambiguous case. 8 inputs to the wave64 version means that // we want the 16x16x32 version, but for wave32 they mean the short form. 
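      // Added illustration: e.g. vector<8xi4> A/B operands with a 4-element
      // accumulator (wave64) select wmma.i32.16x16x32.iu4, while the same
      // 8-element inputs with an 8-element accumulator (wave32) select the
      // 16x16x16 form; the check below encodes exactly this pairing.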
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
    // allow bf16 as the input. For reference, check IntrinsicsAMDGPU.td.
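    // Added illustration (an assumption about the intrinsic signatures, see
    // IntrinsicsAMDGPU.td): the gfx950 mfma.f32.16x16x32.bf16 intrinsic takes
    // vector<8xbf16> operands directly, whereas the older *bf16_1k intrinsics
    // expect the same data bitcast to i16 vectors; `allowBf16` below selects
    // between these two conventions.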
bool allowBf16 = [&]() { if (chipset < kGfx950) return false; if (isScaled) return true; return intrinsicName.contains("16x16x32.bf16") || intrinsicName.contains("32x32x16.bf16"); }(); OperationState loweredOp(loc, intrinsicName); loweredOp.addTypes(intrinsicOutType); loweredOp.addOperands({convertMFMAVectorOperand( rewriter, loc, adaptor.getSourceA(), allowBf16), convertMFMAVectorOperand( rewriter, loc, adaptor.getSourceB(), allowBf16), adaptor.getDestC()}); if (isScaled) { Value zero = createI32Constant(rewriter, loc, 0); auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode), createI32Constant(rewriter, loc, bTypeCode), /*scale A byte=*/zero, /*scale A=*/zero, /*scale B byte=*/zero, /*scale B=*/zero}); } else { loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()), createI32Constant(rewriter, loc, op.getAbid()), createI32Constant(rewriter, loc, getBlgpField)}); }; Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } }; struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern { ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType()); if (chipset.majorVersion != 9 || chipset < kGfx950) return op->emitOpError("scaled MFMA only supported on gfx908+"); std::optional> maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset); if (!maybeScaledIntrinsic.has_value()) return op.emitOpError( "no intrinsic matching scaled MFMA size on given chipset"); auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; OperationState loweredOp(loc, intrinsicName); loweredOp.addTypes(intrinsicOutType); loweredOp.addOperands( {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()), convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()), adaptor.getDestC()}); Value scalesIdxA = createI32Constant(rewriter, loc, adaptor.getScalesIdxA()); Value scalesIdxB = createI32Constant(rewriter, loc, adaptor.getScalesIdxB()); loweredOp.addOperands( {createI32Constant(rewriter, loc, aTypeCode), createI32Constant(rewriter, loc, bTypeCode), /*scales idx A=*/scalesIdxA, /*scales A*/ castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()), /*scales idx B=*/scalesIdxB, /*scales B*/ castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())}); Value lowered = rewriter.create(loweredOp)->getResult(0); rewriter.replaceOp(op, lowered); return success(); } }; struct WMMAOpLowering : public ConvertOpToLLVMPattern { WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto outType = typeConverter->convertType(op.getDestD().getType()); if (!outType) return rewriter.notifyMatchFailure(op, "type conversion failed"); if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); // The WMMA operations represent vectors of bf16s as vectors of i16s, so 
we // need to bitcast bfloats to i16 and then bitcast them back. VectorType rawOutType = outType; if (outType.getElementType().isBF16()) rawOutType = outType.clone(rewriter.getI16Type()); std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); if (!maybeIntrinsic.has_value()) return op.emitOpError("no intrinsic matching WMMA on the given chipset"); if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) return op.emitOpError("subwordOffset not supported on gfx12+"); OperationState loweredOp(loc, *maybeIntrinsic); loweredOp.addTypes(rawOutType); SmallVector operands; wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), adaptor.getSourceA(), op.getSourceA(), operands); wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), adaptor.getSourceB(), op.getSourceB(), operands); wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), op.getSubwordOffset(), op.getClamp(), operands); loweredOp.addOperands(operands); Operation *lowered = rewriter.create(loweredOp); Operation *maybeCastBack = lowered; if (rawOutType != outType) maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType, lowered->getResult(0)); rewriter.replaceOp(op, maybeCastBack->getResults()); return success(); } }; struct TransposeLoadOpLowering : public ConvertOpToLLVMPattern { TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset != kGfx950) return op.emitOpError("Non-gfx950 chipset not supported"); Location loc = op.getLoc(); auto srcMemRefType = cast(op.getSrc().getType()); // Elements in subbyte memrefs are stored non-contiguously, // reject if source is sub-byte memref. Use emulated memrefs instead. size_t srcElementSize = srcMemRefType.getElementType().getIntOrFloatBitWidth(); if (srcElementSize < 8) return op.emitOpError("Expect source memref to have at least 8 bits " "element size, got ") << srcElementSize; auto resultType = cast(op.getResult().getType()); Value srcPtr = getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), (adaptor.getSrcIndices())); size_t numElements = resultType.getNumElements(); size_t elementTypeSize = resultType.getElementType().getIntOrFloatBitWidth(); // ROCDL transpose load intrinsics return vectors of 32-bit integers, if // the element size is smaller than 16 bits. 
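    // Added illustration: 16 x 4-bit elements = 64 bits -> vector<2xi32>
    // (ds_read_tr4_b64), 16 x 6-bit elements = 96 bits -> vector<3xi32>
    // (ds_read_tr6_b96), 8 x 8-bit elements = 64 bits -> vector<2xi32>
    // (ds_read_tr8_b64).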
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32, rewriter.getIntegerType(32)); Type llvmResultType = typeConverter->convertType(resultType); switch (elementTypeSize) { case 4: { assert(numElements == 16); auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 6: { assert(numElements == 16); auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 8: { assert(numElements == 8); auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc, rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 16: { assert(numElements == 4); rewriter.replaceOpWithNewOp(op, llvmResultType, srcPtr); break; } default: return op.emitOpError("Unsupported element size for transpose load"); } return success(); } }; struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern { GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (chipset.majorVersion < 9 || chipset.majorVersion > 10) return op.emitOpError("pre-gfx9 and post-gfx10 not supported"); Location loc = op.getLoc(); auto srcMemRefType = cast(op.getSrc().getType()); auto dstMemRefType = cast(op.getDst().getType()); // TODO: instead of only transfering one element per thread, we could // augment it to transfer multiple elements per thread by issuing multiple // `global_load_lds` instructions. Type transferType = op.getTransferType(); int loadWidth = [&]() -> int { if (auto transferVectorType = dyn_cast(transferType)) { return (transferVectorType.getNumElements() * transferVectorType.getElementTypeBitWidth()) / 8; } return transferType.getIntOrFloatBitWidth() / 8; }(); // Currently only 1, 2, 4, 12 and 16 byte loads are supported. 
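    // Added illustration: a transfer type of vector<4xf16> works out to
    // (4 * 16) / 8 = 8 bytes and is rejected below, while f32 (4 bytes) or
    // vector<4xf32> (16 bytes, gfx950 only) are accepted.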
if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth)) return op.emitOpError("chipset unsupported element size"); if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth)) return op.emitOpError("Gather to LDS instructions with 12-byte and " "16-byte load widths are only supported on gfx950"); Value srcPtr = getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), (adaptor.getSrcIndices())); Value dstPtr = getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(), (adaptor.getDstIndices())); rewriter.replaceOpWithNewOp( op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth), /*offset=*/rewriter.getI32IntegerAttr(0), /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{}, ArrayAttr{}); return success(); } }; namespace { struct ExtPackedFp8OpLowering final : public ConvertOpToLLVMPattern { ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern { PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedStochRoundFp8OpLowering final : public ConvertOpToLLVMPattern { PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct ScaledExtPackedOpLowering final : public ConvertOpToLLVMPattern { ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedScaledTruncOpLowering final : public ConvertOpToLLVMPattern { PackedScaledTruncOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // end namespace LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (!(chipset == kGfx942 || hasOcpFp8(chipset))) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type v4i8 = getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type f32 = getTypeConverter()->convertType(op.getResult().getType()); Value source = adaptor.getSource(); auto sourceVecType = dyn_cast(op.getSource().getType()); auto resultVecType = dyn_cast(op.getResult().getType()); Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { Value 
longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); if (!sourceVecType) { longVec = LLVM::InsertElementOp::create( rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (resultVecType) { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } } else { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, op.getIndex()); } } return success(); } LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (chipset != kGfx950) return rewriter.notifyMatchFailure( loc, "Scaled fp conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Value source = adaptor.getSource(); Value scale = adaptor.getScale(); VectorType sourceVecType = cast(op.getSource().getType()); Type sourceElemType = sourceVecType.getElementType(); VectorType destVecType = cast(op.getResult().getType()); Type destElemType = destVecType.getElementType(); VectorType packedVecType; if (isa(sourceElemType)) { VectorType v4i8 = VectorType::get(4, rewriter.getI8Type()); packedVecType = cast(getTypeConverter()->convertType(v4i8)); } else if (isa(sourceElemType)) { VectorType v8i4 = VectorType::get(8, rewriter.getI4Type()); packedVecType = cast(getTypeConverter()->convertType(v8i4)); } else { llvm_unreachable("invalid element type for scaled ext"); } // Extend to a packedVectorType if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType); if (!sourceVecType) { longVec = LLVM::InsertElementOp::create( rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (isa(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isF16()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isBF16()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isF16()) rewriter.replaceOpWithNewOp( op, 
destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isBF16()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isF16()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else if (isa(sourceElemType) && destElemType.isBF16()) rewriter.replaceOpWithNewOp( op, destVecType, i32Source, scale, op.getIndex()); else return failure(); return success(); } LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (chipset != kGfx950) return rewriter.notifyMatchFailure( loc, "Scaled fp conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type v2i16 = getTypeConverter()->convertType( VectorType::get(2, rewriter.getI16Type())); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type resultType = op.getResult().getType(); Type resultElemType = getElementTypeOrSelf(resultType); VectorType sourceVecType = cast(op.getSource().getType()); Type sourceElemType = sourceVecType.getElementType(); Type intResultType = isa(resultElemType) ? i32 : v2i16; Value source = adaptor.getSource(); Value scale = adaptor.getScale(); Value existing = adaptor.getExisting(); if (existing) existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); else existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); if (sourceVecType.getNumElements() < 2) { Value c0 = createI32Constant(rewriter, loc, 0); Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); VectorType v2 = VectorType::get(2, sourceElemType); source = LLVM::ZeroOp::create(rewriter, loc, v2); source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); } Value sourceA, sourceB; if (sourceElemType.isF32()) { Value c0 = createI32Constant(rewriter, loc, 0); Value c1 = createI32Constant(rewriter, loc, 1); sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); } Value result; if (sourceElemType.isF32() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkBf8F16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkBf8Bf16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp8F16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp8Bf16Op::create( rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, existing, 
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Implement the AMDGPUDPPLowering class that converts the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
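// As a rough illustration (assembly syntax abbreviated here, not copied from
// the dialect definitions), a wave-level permutation such as
//   %r = amdgpu.dpp %old %src row_shl(1) { bound_ctrl = true } : f32
// lowers to a rocdl.update.dpp call whose dpp_ctrl immediate encodes the
// permutation (ROW_SHL0 + 1 = 0x101 in the table below), with row_mask,
// bank_mask, and bound_ctrl forwarded as immediates. Operands narrower than
// 32 bits are widened to a 32-bit value before the intrinsic and truncated
// back afterwards.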
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the operand is narrower than 32 bits, widen it to a 32-bit value:
    // bitcast floats to a same-width integer, insert that into a vector that
    // is 32 bits wide, and bitcast the vector to the 32-bit LLVM type.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // These values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Read the row_mask, bank_mask, and bound_ctrl attributes so they can be
    // passed along as constants.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL DPPUpdateOp (rocdl.update.dpp) with the appropriate
    // attributes.
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    // Replace the amdgpu.dpp operation with the newly created
    // rocdl.update.dpp operation.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

    // Bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }
    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
}

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  populateAMDGPUMemorySpaceAttributeConversions(converter);
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
           WMMAOpLowering,
           ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
           TransposeLoadOpLowering>(converter, chipset);
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
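// A minimal usage sketch: the surrounding function name below is hypothetical;
// the populate* entry points and the conversion pass above are the actual API.
// A downstream pipeline can reuse these patterns directly,
//
//   void buildMyGpuLoweringPatterns(LLVMTypeConverter &converter,
//                                   RewritePatternSet &patterns) {
//     FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse("gfx942");
//     if (succeeded(chipset))
//       populateAMDGPUToROCDLConversionPatterns(converter, patterns, *chipset);
//   }
//
// or run the standalone pass from the command line, roughly:
//
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx942 input.mlir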