aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp')
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp291
1 files changed, 146 insertions, 145 deletions
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -21,19 +21,17 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -53,7 +51,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
if (type.getIntOrFloatBitWidth() <= 32)
return value;
- return b.create<LLVM::TruncOp>(b.getI32Type(), value);
+ return LLVM::TruncOp::create(b, b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
@@ -113,8 +111,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
- return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
- rewriter.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
+ rewriter.getI32IntegerAttr(index));
};
if (arrayType) {
@@ -126,7 +124,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
arrayType.getElementType() == f32x1Ty) {
for (unsigned i = 0; i < structType.getBody().size(); i++) {
Value el =
- rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
+ LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
el = rewriter.createOrFold<LLVM::BitcastOp>(
loc, arrayType.getElementType(), el);
elements.push_back(el);
@@ -143,24 +141,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
Value vec =
- rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
+ LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
Value x1 =
- rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
- Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
- i * 2 + 1);
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x1, makeConst(0));
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x2, makeConst(1));
+ LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
+ Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
+ i * 2 + 1);
+ vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
+ x1, makeConst(0));
+ vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
+ x2, makeConst(1));
elements.push_back(vec);
}
}
// Create the final vectorized result.
- Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
for (const auto &el : llvm::enumerate(elements)) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
- el.index());
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
+ el.index());
}
return result;
}
@@ -187,7 +185,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
- Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
+ Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
// For 4xi8 vectors, the intrinsic expects these to be provided as i32
// scalar types.
@@ -195,7 +193,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
- result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
+ result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
continue;
}
@@ -208,9 +206,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
- result.push_back(b.create<LLVM::ExtractElementOp>(
- toUse,
- b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
+ result.push_back(LLVM::ExtractElementOp::create(
+ b, toUse,
+ LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
}
continue;
}
@@ -285,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
- Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
- ldMatrixResultType, srcPtr,
+ Value ldMatrixResult = NVVM::LdMatrixOp::create(
+ b, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
@@ -296,13 +294,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
- Value result = b.create<LLVM::PoisonOp>(finalResultType);
+ Value result = LLVM::PoisonOp::create(b, finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
- num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+ num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
: ldMatrixResult;
- Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
- result = b.create<LLVM::InsertValueOp>(result, casted, i);
+ Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
+ result = LLVM::InsertValueOp::create(b, result, casted, i);
}
rewriter.replaceOp(op, result);
@@ -375,16 +373,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
- Value intrinsicResult = b.create<NVVM::MmaOp>(
- intrinsicResTy, matA, matB, matC,
- /*shape=*/gemmShape,
- /*b1Op=*/std::nullopt,
- /*intOverflow=*/overflow,
- /*multiplicandPtxTypes=*/
- std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
- /*multiplicandLayouts=*/
- std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
- NVVM::MMALayout::col});
+ Value intrinsicResult =
+ NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
+ /*shape=*/gemmShape,
+ /*b1Op=*/std::nullopt,
+ /*intOverflow=*/overflow,
+ /*multiplicandPtxTypes=*/
+ std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
+ /*multiplicandLayouts=*/
+ std::array<NVVM::MMALayout, 2>{
+ NVVM::MMALayout::row, NVVM::MMALayout::col});
rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
desiredRetTy, intrinsicResult,
rewriter));
@@ -565,15 +563,16 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
- return b.create<LLVM::InlineAsmOp>(
- /*resultTypes=*/intrinsicResultType,
- /*operands=*/asmVals,
- /*asm_string=*/asmStr,
- /*constraints=*/constraintStr,
- /*has_side_effects=*/true,
- /*is_align_stack=*/false, LLVM::TailCallKind::None,
- /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
+ return LLVM::InlineAsmOp::create(b,
+ /*resultTypes=*/intrinsicResultType,
+ /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/constraintStr,
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false,
+ LLVM::TailCallKind::None,
+ /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -631,7 +630,7 @@ struct NVGPUMmaSparseSyncLowering
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
- b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
+ LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
@@ -682,7 +681,7 @@ struct NVGPUAsyncCopyLowering
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
- scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
+ scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -697,13 +696,13 @@ struct NVGPUAsyncCopyLowering
// The rest of the DstElements in the destination (shared memory) are
// filled with zeros.
Value c3I32 =
- b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
- Value bitwidth = b.create<LLVM::ConstantOp>(
- b.getI32Type(),
+ LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
+ Value bitwidth = LLVM::ConstantOp::create(
+ b, b.getI32Type(),
b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
- Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
- srcBytes = b.create<LLVM::LShrOp>(
- b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
+ Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
+ srcBytes = LLVM::LShrOp::create(
+ b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
}
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
@@ -712,14 +711,15 @@ struct NVGPUAsyncCopyLowering
? NVVM::LoadCacheModifierKind::CG
: NVVM::LoadCacheModifierKind::CA;
- b.create<NVVM::CpAsyncOp>(
- dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+ NVVM::CpAsyncOp::create(
+ b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
srcBytes);
// Drop the result token.
- Value zero = b.create<LLVM::ConstantOp>(
- IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
+ Value zero =
+ LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
+ rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -733,11 +733,11 @@ struct NVGPUAsyncCreateGroupLowering
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
+ NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
// Drop the result token.
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), IntegerType::get(op.getContext(), 32),
- rewriter.getI32IntegerAttr(0));
+ Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ IntegerType::get(op.getContext(), 32),
+ rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@@ -753,7 +753,7 @@ struct NVGPUAsyncWaitLowering
ConversionPatternRewriter &rewriter) const override {
// If numGroup is not present pick 0 as a conservative correct value.
int32_t numGroups = adaptor.getNumGroups().value_or(0);
- rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
+ NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
rewriter.eraseOp(op);
return success();
}
@@ -771,8 +771,8 @@ struct NVGPUMBarrierCreateLowering
SymbolTable symbolTable(moduleOp);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp.front());
- auto global = rewriter.create<memref::GlobalOp>(
- funcOp->getLoc(), "__mbarrier",
+ auto global = memref::GlobalOp::create(
+ rewriter, funcOp->getLoc(), "__mbarrier",
/*sym_visibility=*/rewriter.getStringAttr("private"),
/*type=*/barrierType,
/*initial_value=*/ElementsAttr(),
@@ -974,7 +974,7 @@ struct NVGPUMBarrierTryWaitParityLowering
adaptor.getMbarId(), rewriter);
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
- b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+ LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@@ -1063,16 +1063,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
auto ti64 = b.getIntegerType(64);
auto makeConst = [&](uint64_t index) -> Value {
- return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
};
auto shiftLeft = [&](Value value, unsigned shift) -> Value {
- return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
+ return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
};
auto shiftRight = [&](Value value, unsigned shift) -> Value {
- return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
+ return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
};
auto insertBit = [&](Value desc, Value val, int startBit) {
- return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
+ return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
};
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@@ -1086,7 +1086,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
Value baseAddr = getStridedElementPtr(
rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {});
- Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
+ Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
@@ -1104,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1118,8 +1118,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
};
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
- return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
- b.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, b.getIntegerType(64),
+ b.getI32IntegerAttr(index));
}
/// Returns a Value that holds data type enum that is expected by CUDA driver.
@@ -1182,12 +1182,12 @@ struct NVGPUTmaCreateDescriptorOpLowering
auto promotedOperands = getTypeConverter()->promoteOperands(
b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
- Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
- makeI64Const(b, 5));
+ Value boxArrayPtr = LLVM::AllocaOp::create(
+ b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
- Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
- boxArrayPtr, makeI64Const(b, index));
- b.create<LLVM::StoreOp>(value, gep);
+ Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
+ boxArrayPtr, makeI64Const(b, index));
+ LLVM::StoreOp::create(b, value, gep);
}
nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@@ -1280,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1337,7 +1337,7 @@ struct NVGPUWarpgroupMmaOpLowering
/// Basic function to generate Add
Value makeAdd(Value lhs, Value rhs) {
- return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+ return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
};
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
@@ -1365,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1390,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1399,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1430,29 +1429,30 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
- return b.create<NVVM::WgmmaMmaAsyncOp>(
- matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
- itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+ return NVVM::WgmmaMmaAsyncOp::create(
+ b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
+ itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
Value generateWgmmaGroup() {
Value wgmmaResult =
- b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());
+ LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
// Perform GEMM
SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
- Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
+ Value matrixC =
+ LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
- wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
- wgmmaResult, matrix, idx);
+ wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
+ wgmmaResult, matrix, idx);
}
return wgmmaResult;
}
@@ -1465,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
@@ -1486,10 +1486,10 @@ struct NVGPUWarpgroupMmaOpLowering
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
Value generateWarpgroupMma() {
- b.create<NVVM::WgmmaFenceAlignedOp>();
+ NVVM::WgmmaFenceAlignedOp::create(b);
Value wgmmaResult = generateWgmmaGroup();
- b.create<NVVM::WgmmaGroupSyncAlignedOp>();
- b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+ NVVM::WgmmaGroupSyncAlignedOp::create(b);
+ NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
return wgmmaResult;
}
};
@@ -1557,7 +1557,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Type i32 = b.getI32Type();
auto makeConst = [&](int32_t index) -> Value {
- return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
+ return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
};
Value c1 = makeConst(1);
Value c2 = makeConst(2);
@@ -1567,29 +1567,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value warpSize = makeConst(kWarpSize);
auto makeMul = [&](Value lhs, Value rhs) -> Value {
- return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
+ return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
};
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
- return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+ return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
};
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
- Value idx = b.create<arith::IndexCastOp>(it, x);
- Value idy0 = b.create<arith::IndexCastOp>(it, y);
- Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
- Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
- Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
- b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
- b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
+ Value idx = arith::IndexCastOp::create(b, it, x);
+ Value idy0 = arith::IndexCastOp::create(b, it, y);
+ Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
+ Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
+ Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
+ memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
+ memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
};
- Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
- Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
- Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
- Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
- Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ Value tidx = NVVM::ThreadIdXOp::create(b, i32);
+ Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
+ Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
+ Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
+ Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
Value tj = makeMul(lane4modId, c2);
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
@@ -1626,7 +1626,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = cast<LLVM::LLVMStructType>(matrixD);
- Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+ Value innerStructValue =
+ LLVM::ExtractValueOp::create(b, matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
@@ -1648,23 +1649,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
.getBody()
.front();
- Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
- Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
+ Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
+ Value packStruct = LLVM::PoisonOp::create(b, packStructType);
SmallVector<Value> innerStructs;
// Unpack the structs and set all values to zero
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
auto structType = cast<LLVM::LLVMStructType>(s);
- Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+ Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
- structValue = b.create<LLVM::InsertValueOp>(
- structType, structValue, zero, ArrayRef<int64_t>({i}));
+ structValue = LLVM::InsertValueOp::create(b, structType, structValue,
+ zero, ArrayRef<int64_t>({i}));
}
innerStructs.push_back(structValue);
}
// Pack the inner structs into a single struct
for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
- packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
- packStruct, matrix, idx);
+ packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
+ packStruct, matrix, idx);
}
rewriter.replaceOp(op, packStruct);
return success();
@@ -1681,7 +1682,7 @@ struct NVGPUTmaFenceOpLowering
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto i32Ty = b.getI32Type();
Value tensormapSize =
- b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128));
+ LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
auto memscope =
NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
@@ -1716,13 +1717,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
VectorType inTy = op.getIn().getType();
// apply rcp.approx.ftz.f on each element in vector.
auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
- Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
+ Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
for (int i = 0; i < numElems; i++) {
- Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
- Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
- Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
- ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
+ Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
+ Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
+ Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
+ ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
}
return ret1DVec;
};