diff options
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp')
-rw-r--r-- | mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 390 |
1 files changed, 252 insertions, 138 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ef35ee2..64720bf 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, if (i32 == valTy) return val; return valTy.getWidth() > 32 - ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val)) - : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val)); + ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val)) + : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val)); } static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value) { Type i32 = rewriter.getI32Type(); - return rewriter.create<LLVM::ConstantOp>(loc, i32, value); + return LLVM::ConstantOp::create(rewriter, loc, i32, value); } static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value) { Type llvmI1 = rewriter.getI1Type(); - return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value); + return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); } /// Returns the linear index used to access an element in the memref. @@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, ShapedType::isDynamic(stride) ? convertUnsignedToI32(rewriter, loc, memRefDescriptor.stride(rewriter, loc, i)) - : rewriter.create<LLVM::ConstantOp>(loc, i32, stride); - increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue); + : LLVM::ConstantOp::create(rewriter, loc, i32, stride); + increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue); } - index = - index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; + index = index ? LLVM::AddOp::create(rewriter, loc, index, increment) + : increment; } return index ? index : createI32Constant(rewriter, loc, 0); } @@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); - Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride); + Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); maxIndex = maxIndex - ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim) + ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) : maxThisDim; } Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); - return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst); + return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst); } static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, @@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value stride; if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { Value cacheStrideZext = - rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride); - Value swizzleBit = rewriter.create<LLVM::ConstantOp>( - loc, i16, rewriter.getI16IntegerAttr(1 << 14)); - stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit, - /*isDisjoint=*/true); + LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); + Value swizzleBit = LLVM::ConstantOp::create( + rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); + stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, + /*isDisjoint=*/true); } else { - stride = rewriter.create<LLVM::ConstantOp>(loc, i16, - rewriter.getI16IntegerAttr(0)); + stride = LLVM::ConstantOp::create(rewriter, loc, i16, + rewriter.getI16IntegerAttr(0)); } // Get the number of elements. // Flag word: @@ -209,20 +209,21 @@ struct FatRawBufferCastLowering : descriptor.alignedPtr(rewriter, loc); Value offset = adaptor.getResetOffset() - ? rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(0)) + ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0)) : descriptor.offset(rewriter, loc); bool hasSizes = memrefType.getRank() > 0; // No need to unpack() and pack() all the individual sizes and strides, // so we'll just extract the arrays. - Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>( - loc, descriptor, kSizePosInMemRefDescriptor) - : Value{}; - Value strides = hasSizes - ? rewriter.create<LLVM::ExtractValueOp>( - loc, descriptor, kStridePosInMemRefDescriptor) - : Value{}; + Value sizes = hasSizes + ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, + kSizePosInMemRefDescriptor) + : Value{}; + Value strides = + hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, + kStridePosInMemRefDescriptor) + : Value{}; Value fatPtr = makeBufferRsrc( rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), @@ -231,17 +232,17 @@ struct FatRawBufferCastLowering Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); - result = rewriter.create<LLVM::InsertValueOp>( - loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor); - result = rewriter.create<LLVM::InsertValueOp>( - loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); - result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset, - kOffsetPosInMemRefDescriptor); + SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor}; + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, + kAlignedPtrPosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, + kOffsetPosInMemRefDescriptor); if (hasSizes) { - result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes, - kSizePosInMemRefDescriptor); - result = rewriter.create<LLVM::InsertValueOp>( - loc, result, strides, kStridePosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, + kSizePosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, + kStridePosInMemRefDescriptor); } rewriter.replaceOp(op, result); return success(); @@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { SmallVector<Value, 6> args; if (storeData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForStore = - rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData); + Value castForStore = LLVM::BitcastOp::create( + rewriter, loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); @@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForCmp = rewriter.create<LLVM::BitcastOp>( - loc, llvmBufferValType, atomicCmpData); + Value castForCmp = LLVM::BitcastOp::create( + rewriter, loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { args.push_back(atomicCmpData); @@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset(); indexOffset && *indexOffset > 0) { Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); - voffset = - voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst) - : extraOffsetConst; + voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset, + extraOffsetConst) + : extraOffsetConst; } - voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst); + voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); args.push_back(voffset); // SGPR offset. Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); - sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst); + sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) @@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(), llvmBufferValType); - Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { - replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType, - replacement); + replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, + replacement); } rewriter.replaceOp(gpuOp, replacement); } else { @@ -419,6 +420,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { } }; +// TODO: AMDGPU backend already have all this bitpacking logic, we should move +// it to some common place. +/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows: +/// Vmcnt = Waitcnt[3:0] (pre-gfx9) +/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10) +/// Vmcnt = Waitcnt[15:10] (gfx11) +/// Expcnt = Waitcnt[6:4] (pre-gfx11) +/// Expcnt = Waitcnt[2:0] (gfx11) +/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10) +/// Lgkmcnt = Waitcnt[13:8] (gfx10) +/// Lgkmcnt = Waitcnt[9:4] (gfx11) +static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt, + unsigned expcnt, unsigned lgkmcnt) { + if (chipset.majorVersion < 9) { + vmcnt = std::min(15u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(15u, lgkmcnt); + return vmcnt | (expcnt << 4) | (lgkmcnt << 8); + } + if (chipset.majorVersion == 9) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(15u, lgkmcnt); + unsigned lowBits = vmcnt & 0xF; + unsigned highBits = (vmcnt >> 4) << 14; + unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); + return lowBits | highBits | otherCnts; + } + if (chipset.majorVersion == 10) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(63u, lgkmcnt); + unsigned lowBits = vmcnt & 0xF; + unsigned highBits = (vmcnt >> 4) << 14; + unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); + return lowBits | highBits | otherCnts; + } + if (chipset.majorVersion == 11) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(63u, lgkmcnt); + return (vmcnt << 10) | expcnt | (lgkmcnt << 4); + } + return failure(); +} + +struct MemoryCounterWaitOpLowering + : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> { + MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter), + chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset.majorVersion >= 12) { + Location loc = op.getLoc(); + if (std::optional<int> ds = adaptor.getDs()) + ROCDL::WaitDscntOp::create(rewriter, loc, *ds); + + if (std::optional<int> load = adaptor.getLoad()) + ROCDL::WaitLoadcntOp::create(rewriter, loc, *load); + + if (std::optional<int> store = adaptor.getStore()) + ROCDL::WaitStorecntOp::create(rewriter, loc, *store); + + if (std::optional<int> exp = adaptor.getExp()) + ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); + + rewriter.eraseOp(op); + return success(); + } + + auto getVal = [](Attribute attr) -> unsigned { + if (attr) + return cast<IntegerAttr>(attr).getInt(); + + // This value will be clamped to the maximum value for the chipset. + return 1024; + }; + unsigned ds = getVal(adaptor.getDsAttr()); + unsigned exp = getVal(adaptor.getExpAttr()); + + unsigned vmcnt = 1024; + Attribute load = adaptor.getLoadAttr(); + Attribute store = adaptor.getStoreAttr(); + if (load && store) { + vmcnt = getVal(load) + getVal(store); + } else if (load) { + vmcnt = getVal(load); + } else if (store) { + vmcnt = getVal(store); + } + + FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds); + if (failed(waitcnt)) + return op.emitOpError("unsupported chipset"); + + rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt); + return success(); + } +}; + struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {} @@ -465,12 +572,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { << chipset.majorVersion; Location loc = op->getLoc(); - rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits); + ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits); rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op); } else { Location loc = op->getLoc(); - rewriter.create<ROCDL::WaitDscntOp>(loc, 0); - rewriter.create<ROCDL::BarrierSignalOp>(loc, -1); + ROCDL::WaitDscntOp::create(rewriter, loc, 0); + ROCDL::BarrierSignalOp::create(rewriter, loc, -1); rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1); } @@ -516,19 +623,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); if (auto vectorType = dyn_cast<VectorType>(inputType)) { if (vectorType.getElementType().isBF16() && !allowBf16) - return rewriter.create<LLVM::BitcastOp>( - loc, vectorType.clone(rewriter.getI16Type()), input); + return LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && vectorType.getNumElements() <= 8) - return rewriter.create<LLVM::BitcastOp>( - loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); + return LLVM::BitcastOp::create( + rewriter, loc, + rewriter.getIntegerType(vectorType.getNumElements() * 8), input); if (isa<IntegerType>(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); - return rewriter.create<LLVM::BitcastOp>( - loc, VectorType::get(numWords, rewriter.getI32Type()), input); + return LLVM::BitcastOp::create( + rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), + input); } } return input; @@ -549,8 +658,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); Type outputType = rewriter.getI32Type(); if (auto intType = dyn_cast<IntegerType>(inputType)) - return rewriter.create<LLVM::ZExtOp>(loc, outputType, input); - return rewriter.create<LLVM::BitcastOp>(loc, outputType, input); + return LLVM::ZExtOp::create(rewriter, loc, outputType, input); + return LLVM::BitcastOp::create(rewriter, loc, outputType, input); } /// Push an input operand. If it is a float type, nothing to do. If it is @@ -576,8 +685,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - llvmInput = rewriter.create<LLVM::BitcastOp>( - loc, vectorType.clone(rewriter.getI16Type()), llvmInput); + llvmInput = LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -613,7 +722,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. // Add in the zeros here. if (numBits < 32) - castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput); + castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); operands.push_back(castInput); } @@ -633,8 +742,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, auto vectorType = dyn_cast<VectorType>(inputType); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - output = rewriter.create<LLVM::BitcastOp>( - loc, vectorType.clone(rewriter.getI16Type()), output); + output = LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); @@ -992,7 +1101,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { }; Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) - lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered); + lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } @@ -1092,8 +1201,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { Operation *maybeCastBack = lowered; if (rawOutType != outType) - maybeCastBack = - rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0)); + maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType, + lowered->getResult(0)); rewriter.replaceOp(op, maybeCastBack->getResults()); return success(); @@ -1143,22 +1252,22 @@ struct TransposeLoadOpLowering switch (elementTypeSize) { case 4: { assert(numElements == 16); - auto rocdlOp = - rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); break; } case 6: { assert(numElements == 16); - auto rocdlOp = - rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); break; } case 8: { assert(numElements == 8); - auto rocdlOp = - rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); break; } @@ -1316,21 +1425,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { - Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8); + Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); if (!sourceVecType) { - longVec = rewriter.create<LLVM::InsertElementOp>( - loc, longVec, source, createI32Constant(rewriter, loc, 0)); + longVec = LLVM::InsertElementOp::create( + rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (resultVecType) { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source, @@ -1382,21 +1491,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( // Extend to a packedVectorType if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { - Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType); + Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType); if (!sourceVecType) { - longVec = rewriter.create<LLVM::InsertElementOp>( - loc, longVec, source, createI32Constant(rewriter, loc, 0)); + longVec = LLVM::InsertElementOp::create( + rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>( @@ -1454,54 +1563,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( Value scale = adaptor.getScale(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); else - existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType); + existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); if (sourceVecType.getNumElements() < 2) { Value c0 = createI32Constant(rewriter, loc, 0); - Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0); + Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); VectorType v2 = VectorType::get(2, sourceElemType); - source = rewriter.create<LLVM::ZeroOp>(loc, v2); - source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0); + source = LLVM::ZeroOp::create(rewriter, loc, v2); + source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); } Value sourceA, sourceB; if (sourceElemType.isF32()) { Value c0 = createI32Constant(rewriter, loc, 0); Value c1 = createI32Constant(rewriter, loc, 1); - sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0); - sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1); + sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); + sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); } Value result; if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else return failure(); @@ -1526,20 +1638,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( Value sourceA = adaptor.getSourceA(); Value sourceB = adaptor.getSourceB(); if (!sourceB) - sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType()); + sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType()); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create<LLVM::UndefOp>(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, - existing, op.getWordIndex()); + result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB, + existing, op.getWordIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, - existing, op.getWordIndex()); + result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB, + existing, op.getWordIndex()); result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( op, getTypeConverter()->convertType(resultType), result); @@ -1563,17 +1675,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( Value stoch = adaptor.getStochiasticParam(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create<LLVM::UndefOp>(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtSrBf8F32Op>( - loc, i32, source, stoch, existing, op.getStoreIndex()); + result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch, + existing, op.getStoreIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtSrFp8F32Op>( - loc, i32, source, stoch, existing, op.getStoreIndex()); + result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch, + existing, op.getStoreIndex()); result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( op, getTypeConverter()->convertType(resultType), result); @@ -1617,14 +1729,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { if (operandType.getIntOrFloatBitWidth() <= 16) { if (llvm::isa<FloatType>(operandType)) { operand = - rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand); + LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand); } auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); - Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType); - operand = rewriter.create<LLVM::InsertElementOp>( - loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); - operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand); + Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType); + operand = + LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand, + createI32Constant(rewriter, loc, 0)); + operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand); } return operand; }; @@ -1711,14 +1824,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue(); // create a ROCDL_DPPMovOp instruction with the appropriate attributes - auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>( - loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl); + auto dppMovOp = + ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl, + rowMask, bankMask, boundCtrl); Value result = dppMovOp.getRes(); if (srcType.getIntOrFloatBitWidth() < 32) { - result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result); + result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result); if (!llvm::isa<IntegerType>(srcType)) { - result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result); + result = LLVM::BitcastOp::create(rewriter, loc, srcType, result); } } @@ -1752,7 +1866,7 @@ struct AMDGPUSwizzleBitModeLowering SmallVector<Value> swizzled; for (Value v : decomposed) { Value res = - rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue); + ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue); swizzled.emplace_back(res); } @@ -1825,9 +1939,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, ROCDL::RawPtrBufferAtomicUminOp>, RawBufferOpLowering<RawBufferAtomicCmpswapOp, ROCDL::RawPtrBufferAtomicCmpSwap>, - AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, - MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, - ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, + AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, TransposeLoadOpLowering>(converter, chipset); |