aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp390
1 files changed, 252 insertions, 138 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee2..64720bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
if (i32 == valTy)
return val;
return valTy.getWidth() > 32
- ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
- : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+ ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
+ : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type i32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
bool value) {
Type llvmI1 = rewriter.getI1Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+ return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}
/// Returns the linear index used to access an element in the memref.
@@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
ShapedType::isDynamic(stride)
? convertUnsignedToI32(rewriter, loc,
memRefDescriptor.stride(rewriter, loc, i))
- : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
- increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+ : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
+ increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
}
- index =
- index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
+ : increment;
}
return index ? index : createI32Constant(rewriter, loc, 0);
}
@@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
Value size = memrefDescriptor.size(rewriter, loc, i);
Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
maxIndex = maxIndex
- ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
: maxThisDim;
}
Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
- return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+ return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
Value stride;
if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
Value cacheStrideZext =
- rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
- Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(1 << 14));
- stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
- /*isDisjoint=*/true);
+ LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
+ Value swizzleBit = LLVM::ConstantOp::create(
+ rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
} else {
- stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
- rewriter.getI16IntegerAttr(0));
+ stride = LLVM::ConstantOp::create(rewriter, loc, i16,
+ rewriter.getI16IntegerAttr(0));
}
// Get the number of elements.
// Flag word:
@@ -209,20 +209,21 @@ struct FatRawBufferCastLowering
: descriptor.alignedPtr(rewriter, loc);
Value offset = adaptor.getResetOffset()
- ? rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(0))
+ ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))
: descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
- Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kSizePosInMemRefDescriptor)
- : Value{};
- Value strides = hasSizes
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kStridePosInMemRefDescriptor)
- : Value{};
+ Value sizes = hasSizes
+ ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides =
+ hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kStridePosInMemRefDescriptor)
+ : Value{};
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
@@ -231,17 +232,17 @@ struct FatRawBufferCastLowering
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
- kOffsetPosInMemRefDescriptor);
+ SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAlignedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
if (hasSizes) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
- kSizePosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, strides, kStridePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
+ kStridePosInMemRefDescriptor);
}
rewriter.replaceOp(op, result);
return success();
@@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
SmallVector<Value, 6> args;
if (storeData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForStore =
- rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
+ Value castForStore = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, storeData);
args.push_back(castForStore);
} else {
args.push_back(storeData);
@@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (atomicCmpData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForCmp = rewriter.create<LLVM::BitcastOp>(
- loc, llvmBufferValType, atomicCmpData);
+ Value castForCmp = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, atomicCmpData);
args.push_back(castForCmp);
} else {
args.push_back(atomicCmpData);
@@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
indexOffset && *indexOffset > 0) {
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
- voffset =
- voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
- : extraOffsetConst;
+ voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
+ extraOffsetConst)
+ : extraOffsetConst;
}
- voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+ voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
args.push_back(voffset);
// SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
- sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
+ sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
@@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
llvmBufferValType);
- Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (lowered->getNumResults() == 1) {
Value replacement = lowered->getResult(0);
if (llvmBufferValType != llvmWantedDataType) {
- replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
- replacement);
+ replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
+ replacement);
}
rewriter.replaceOp(gpuOp, replacement);
} else {
@@ -419,6 +420,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
}
};
+// TODO: AMDGPU backend already have all this bitpacking logic, we should move
+// it to some common place.
+/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
+/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
+/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
+/// Vmcnt = Waitcnt[15:10] (gfx11)
+/// Expcnt = Waitcnt[6:4] (pre-gfx11)
+/// Expcnt = Waitcnt[2:0] (gfx11)
+/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
+/// Lgkmcnt = Waitcnt[13:8] (gfx10)
+/// Lgkmcnt = Waitcnt[9:4] (gfx11)
+static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
+ unsigned expcnt, unsigned lgkmcnt) {
+ if (chipset.majorVersion < 9) {
+ vmcnt = std::min(15u, vmcnt);
+ expcnt = std::min(7u, expcnt);
+ lgkmcnt = std::min(15u, lgkmcnt);
+ return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
+ }
+ if (chipset.majorVersion == 9) {
+ vmcnt = std::min(63u, vmcnt);
+ expcnt = std::min(7u, expcnt);
+ lgkmcnt = std::min(15u, lgkmcnt);
+ unsigned lowBits = vmcnt & 0xF;
+ unsigned highBits = (vmcnt >> 4) << 14;
+ unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
+ return lowBits | highBits | otherCnts;
+ }
+ if (chipset.majorVersion == 10) {
+ vmcnt = std::min(63u, vmcnt);
+ expcnt = std::min(7u, expcnt);
+ lgkmcnt = std::min(63u, lgkmcnt);
+ unsigned lowBits = vmcnt & 0xF;
+ unsigned highBits = (vmcnt >> 4) << 14;
+ unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
+ return lowBits | highBits | otherCnts;
+ }
+ if (chipset.majorVersion == 11) {
+ vmcnt = std::min(63u, vmcnt);
+ expcnt = std::min(7u, expcnt);
+ lgkmcnt = std::min(63u, lgkmcnt);
+ return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
+ }
+ return failure();
+}
+
+struct MemoryCounterWaitOpLowering
+ : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
+ MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
+ chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset.majorVersion >= 12) {
+ Location loc = op.getLoc();
+ if (std::optional<int> ds = adaptor.getDs())
+ ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
+
+ if (std::optional<int> load = adaptor.getLoad())
+ ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
+
+ if (std::optional<int> store = adaptor.getStore())
+ ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
+
+ if (std::optional<int> exp = adaptor.getExp())
+ ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ auto getVal = [](Attribute attr) -> unsigned {
+ if (attr)
+ return cast<IntegerAttr>(attr).getInt();
+
+ // This value will be clamped to the maximum value for the chipset.
+ return 1024;
+ };
+ unsigned ds = getVal(adaptor.getDsAttr());
+ unsigned exp = getVal(adaptor.getExpAttr());
+
+ unsigned vmcnt = 1024;
+ Attribute load = adaptor.getLoadAttr();
+ Attribute store = adaptor.getStoreAttr();
+ if (load && store) {
+ vmcnt = getVal(load) + getVal(store);
+ } else if (load) {
+ vmcnt = getVal(load);
+ } else if (store) {
+ vmcnt = getVal(store);
+ }
+
+ FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
+ if (failed(waitcnt))
+ return op.emitOpError("unsupported chipset");
+
+ rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
+ return success();
+ }
+};
+
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -465,12 +572,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;
Location loc = op->getLoc();
- rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
+ ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
- rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
- rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
+ ROCDL::WaitDscntOp::create(rewriter, loc, 0);
+ ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
}
@@ -516,19 +623,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
- return rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
vectorType.getNumElements() <= 8)
- return rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
32);
- return rewriter.create<LLVM::BitcastOp>(
- loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+ input);
}
}
return input;
@@ -549,8 +658,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
Type outputType = rewriter.getI32Type();
if (auto intType = dyn_cast<IntegerType>(inputType))
- return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
- return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+ return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
+ return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
/// Push an input operand. If it is a float type, nothing to do. If it is
@@ -576,8 +685,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- llvmInput = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
+ llvmInput = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -613,7 +722,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
// Add in the zeros here.
if (numBits < 32)
- castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+ castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
operands.push_back(castInput);
}
@@ -633,8 +742,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- output = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), output);
+ output = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
@@ -992,7 +1101,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
- lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
rewriter.replaceOp(op, lowered);
return success();
}
@@ -1092,8 +1201,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Operation *maybeCastBack = lowered;
if (rawOutType != outType)
- maybeCastBack =
- rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+ lowered->getResult(0));
rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
@@ -1143,22 +1252,22 @@ struct TransposeLoadOpLowering
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
@@ -1316,21 +1425,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
- Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (resultVecType) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
@@ -1382,21 +1491,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
- Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+ Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
@@ -1454,54 +1563,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
else
- existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
+ existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
if (sourceVecType.getNumElements() < 2) {
Value c0 = createI32Constant(rewriter, loc, 0);
- Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
+ Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
VectorType v2 = VectorType::get(2, sourceElemType);
- source = rewriter.create<LLVM::ZeroOp>(loc, v2);
- source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
+ source = LLVM::ZeroOp::create(rewriter, loc, v2);
+ source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
}
Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
- sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
- sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
+ sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
+ sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
}
Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
@@ -1526,20 +1638,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value sourceA = adaptor.getSourceA();
Value sourceB = adaptor.getSourceB();
if (!sourceB)
- sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+ sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1563,17 +1675,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value stoch = adaptor.getStochiasticParam();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1617,14 +1729,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
if (operandType.getIntOrFloatBitWidth() <= 16) {
if (llvm::isa<FloatType>(operandType)) {
operand =
- rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
}
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
- Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
- operand = rewriter.create<LLVM::InsertElementOp>(
- loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
- operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
+ Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
+ operand =
+ LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
+ createI32Constant(rewriter, loc, 0));
+ operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
}
return operand;
};
@@ -1711,14 +1824,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
- auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
- loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+ auto dppMovOp =
+ ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
+ rowMask, bankMask, boundCtrl);
Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
- result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+ result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
if (!llvm::isa<IntegerType>(srcType)) {
- result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
}
}
@@ -1752,7 +1866,7 @@ struct AMDGPUSwizzleBitModeLowering
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res =
- rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+ ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
swizzled.emplace_back(res);
}
@@ -1825,9 +1939,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
- AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
- MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
- ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);