diff options
Diffstat (limited to 'mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 291 |
1 files changed, 146 insertions, 145 deletions
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 80b3d85..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -21,19 +21,17 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -53,7 +51,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { assert(llvm::isa<IntegerType>(type) && "expected an integer Value"); if (type.getIntOrFloatBitWidth() <= 32) return value; - return b.create<LLVM::TruncOp>(b.getI32Type(), value); + return LLVM::TruncOp::create(b, b.getI32Type(), value); } /// Returns the type for the intrinsic given the vectorResultType of the @@ -113,8 +111,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type f32x1Ty = VectorType::get(1, f32Ty); auto makeConst = [&](int32_t index) -> Value { - return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), - rewriter.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32), + rewriter.getI32IntegerAttr(index)); }; if (arrayType) { @@ -126,7 +124,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, arrayType.getElementType() == f32x1Ty) { for (unsigned i = 0; i < structType.getBody().size(); i++) { Value el = - rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i); el = rewriter.createOrFold<LLVM::BitcastOp>( loc, arrayType.getElementType(), el); elements.push_back(el); @@ -143,24 +141,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { Value vec = - rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType()); + LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType()); Value x1 = - rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2); - Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, - i * 2 + 1); - vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, - x1, makeConst(0)); - vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, - x2, makeConst(1)); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2); + Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, + i * 2 + 1); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x1, makeConst(0)); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x2, makeConst(1)); elements.push_back(vec); } } // Create the final vectorized result. - Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType); + Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType); for (const auto &el : llvm::enumerate(elements)) { - result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(), - el.index()); + result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(), + el.index()); } return result; } @@ -187,7 +185,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { - Value toUse = b.create<LLVM::ExtractValueOp>(operand, i); + Value toUse = LLVM::ExtractValueOp::create(b, operand, i); // For 4xi8 vectors, the intrinsic expects these to be provided as i32 // scalar types. @@ -195,7 +193,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, arrayTy.getElementType() == i4x8Ty || (arrayTy.getElementType() == f32x1Ty && operandPtxType == NVVM::MMATypes::tf32)) { - result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse)); + result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse)); continue; } @@ -208,9 +206,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, innerArrayTy.getElementType() == f32Ty)) { for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); idx < innerSize; idx++) { - result.push_back(b.create<LLVM::ExtractElementOp>( - toUse, - b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx)))); + result.push_back(LLVM::ExtractElementOp::create( + b, toUse, + LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx)))); } continue; } @@ -285,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); - Value ldMatrixResult = b.create<NVVM::LdMatrixOp>( - ldMatrixResultType, srcPtr, + Value ldMatrixResult = NVVM::LdMatrixOp::create( + b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row); @@ -296,13 +294,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { // actual vector type (still of width 32b) and repack them into a result // struct. Type finalResultType = typeConverter->convertType(vectorResultType); - Value result = b.create<LLVM::PoisonOp>(finalResultType); + Value result = LLVM::PoisonOp::create(b, finalResultType); for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { Value i32Register = - num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i) + num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i) : ldMatrixResult; - Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register); - result = b.create<LLVM::InsertValueOp>(result, casted, i); + Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register); + result = LLVM::InsertValueOp::create(b, result, casted, i); } rewriter.replaceOp(op, result); @@ -375,16 +373,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); Type intrinsicResTy = inferIntrinsicResultType( typeConverter->convertType(op->getResultTypes()[0])); - Value intrinsicResult = b.create<NVVM::MmaOp>( - intrinsicResTy, matA, matB, matC, - /*shape=*/gemmShape, - /*b1Op=*/std::nullopt, - /*intOverflow=*/overflow, - /*multiplicandPtxTypes=*/ - std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, - /*multiplicandLayouts=*/ - std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, - NVVM::MMALayout::col}); + Value intrinsicResult = + NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC, + /*shape=*/gemmShape, + /*b1Op=*/std::nullopt, + /*intOverflow=*/overflow, + /*multiplicandPtxTypes=*/ + std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, + /*multiplicandLayouts=*/ + std::array<NVVM::MMALayout, 2>{ + NVVM::MMALayout::row, NVVM::MMALayout::col}); rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, intrinsicResult, rewriter)); @@ -565,15 +563,16 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm( llvm::append_range(asmVals, args); asmVals.push_back(indexData); - return b.create<LLVM::InlineAsmOp>( - /*resultTypes=*/intrinsicResultType, - /*operands=*/asmVals, - /*asm_string=*/asmStr, - /*constraints=*/constraintStr, - /*has_side_effects=*/true, - /*is_align_stack=*/false, LLVM::TailCallKind::None, - /*asm_dialect=*/asmDialectAttr, - /*operand_attrs=*/ArrayAttr()); + return LLVM::InlineAsmOp::create(b, + /*resultTypes=*/intrinsicResultType, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/constraintStr, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::TailCallKind::None, + /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); } /// Lowers `nvgpu.mma.sp.sync` to inline assembly. @@ -631,7 +630,7 @@ struct NVGPUMmaSparseSyncLowering return op->emitOpError() << "Expected metadata type to be LLVM " "VectorType of 2 i16 elements"; sparseMetadata = - b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata); + LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata); FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm( b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, @@ -682,7 +681,7 @@ struct NVGPUAsyncCopyLowering // Intrinsics takes a global pointer so we need an address space cast. auto srcPointerGlobalType = LLVM::LLVMPointerType::get( op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); - scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr); + scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr); int64_t dstElements = adaptor.getDstElements().getZExtValue(); int64_t sizeInBytes = (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; @@ -697,13 +696,13 @@ struct NVGPUAsyncCopyLowering // The rest of the DstElements in the destination (shared memory) are // filled with zeros. Value c3I32 = - b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3)); - Value bitwidth = b.create<LLVM::ConstantOp>( - b.getI32Type(), + LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3)); + Value bitwidth = LLVM::ConstantOp::create( + b, b.getI32Type(), b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); - Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes); - srcBytes = b.create<LLVM::LShrOp>( - b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32); + Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes); + srcBytes = LLVM::LShrOp::create( + b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32); } // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than // 16 dst bytes. @@ -712,14 +711,15 @@ struct NVGPUAsyncCopyLowering ? NVVM::LoadCacheModifierKind::CG : NVVM::LoadCacheModifierKind::CA; - b.create<NVVM::CpAsyncOp>( - dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), + NVVM::CpAsyncOp::create( + b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), srcBytes); // Drop the result token. - Value zero = b.create<LLVM::ConstantOp>( - IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); + Value zero = + LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -733,11 +733,11 @@ struct NVGPUAsyncCreateGroupLowering LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc()); + NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc()); // Drop the result token. - Value zero = rewriter.create<LLVM::ConstantOp>( - op->getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -753,7 +753,7 @@ struct NVGPUAsyncWaitLowering ConversionPatternRewriter &rewriter) const override { // If numGroup is not present pick 0 as a conservative correct value. int32_t numGroups = adaptor.getNumGroups().value_or(0); - rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); + NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups); rewriter.eraseOp(op); return success(); } @@ -771,8 +771,8 @@ struct NVGPUMBarrierCreateLowering SymbolTable symbolTable(moduleOp); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(&moduleOp.front()); - auto global = rewriter.create<memref::GlobalOp>( - funcOp->getLoc(), "__mbarrier", + auto global = memref::GlobalOp::create( + rewriter, funcOp->getLoc(), "__mbarrier", /*sym_visibility=*/rewriter.getStringAttr("private"), /*type=*/barrierType, /*initial_value=*/ElementsAttr(), @@ -974,7 +974,7 @@ struct NVGPUMBarrierTryWaitParityLowering adaptor.getMbarId(), rewriter); Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = - b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity()); + LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); if (isMbarrierShared(op.getBarriers().getType())) { rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( @@ -1063,16 +1063,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering auto ti64 = b.getIntegerType(64); auto makeConst = [&](uint64_t index) -> Value { - return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index)); + return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index)); }; auto shiftLeft = [&](Value value, unsigned shift) -> Value { - return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift)); + return LLVM::ShlOp::create(b, ti64, value, makeConst(shift)); }; auto shiftRight = [&](Value value, unsigned shift) -> Value { - return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift)); + return LLVM::LShrOp::create(b, ti64, value, makeConst(shift)); }; auto insertBit = [&](Value desc, Value val, int startBit) { - return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit)); + return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit)); }; int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); @@ -1086,7 +1086,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering Value baseAddr = getStridedElementPtr( rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()), adaptor.getTensor(), {}); - Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr); + Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr); // Just use 14 bits for base address Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); @@ -1104,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1118,8 +1118,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering }; static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { - return b.create<LLVM::ConstantOp>(b.getIntegerType(64), - b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, b.getIntegerType(64), + b.getI32IntegerAttr(index)); } /// Returns a Value that holds data type enum that is expected by CUDA driver. @@ -1182,12 +1182,12 @@ struct NVGPUTmaCreateDescriptorOpLowering auto promotedOperands = getTypeConverter()->promoteOperands( b.getLoc(), op->getOperands(), adaptor.getOperands(), b); - Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, - makeI64Const(b, 5)); + Value boxArrayPtr = LLVM::AllocaOp::create( + b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5)); for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { - Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType, - boxArrayPtr, makeI64Const(b, index)); - b.create<LLVM::StoreOp>(value, gep); + Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType, + boxArrayPtr, makeI64Const(b, index)); + LLVM::StoreOp::create(b, value, gep); } nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); @@ -1280,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1337,7 +1337,7 @@ struct NVGPUWarpgroupMmaOpLowering /// Basic function to generate Add Value makeAdd(Value lhs, Value rhs) { - return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. @@ -1365,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1390,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1399,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1430,29 +1429,30 @@ struct NVGPUWarpgroupMmaOpLowering auto overflow = NVVM::MMAIntOverflowAttr::get( op->getContext(), NVVM::MMAIntOverflow::wrapped); - return b.create<NVVM::WgmmaMmaAsyncOp>( - matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, - itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, + return NVVM::WgmmaMmaAsyncOp::create( + b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape, + itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); } /// Generates multiple wgmma instructions to complete the given GEMM shape Value generateWgmmaGroup() { Value wgmmaResult = - b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType()); + LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType()); // Perform GEMM SmallVector<Value> wgmmaResults; for (int i = 0; i < iterationM; ++i) { - Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); + Value matrixC = + LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i); for (int j = 0; j < iterationN; ++j) for (int k = 0; k < iterationK; ++k) matrixC = generateWgmma(i, j, k, matrixC); wgmmaResults.push_back(matrixC); } for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { - wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), - wgmmaResult, matrix, idx); + wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(), + wgmmaResult, matrix, idx); } return wgmmaResult; } @@ -1465,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( @@ -1486,10 +1486,10 @@ struct NVGPUWarpgroupMmaOpLowering /// (WgmmaGroupSyncAlignedOp) for group synchronization /// (WgmmaWaitGroupSyncOp) after the instructions. Value generateWarpgroupMma() { - b.create<NVVM::WgmmaFenceAlignedOp>(); + NVVM::WgmmaFenceAlignedOp::create(b); Value wgmmaResult = generateWgmmaGroup(); - b.create<NVVM::WgmmaGroupSyncAlignedOp>(); - b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); + NVVM::WgmmaGroupSyncAlignedOp::create(b); + NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup()); return wgmmaResult; } }; @@ -1557,7 +1557,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering Type i32 = b.getI32Type(); auto makeConst = [&](int32_t index) -> Value { - return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index)); }; Value c1 = makeConst(1); Value c2 = makeConst(2); @@ -1567,29 +1567,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering Value warpSize = makeConst(kWarpSize); auto makeMul = [&](Value lhs, Value rhs) -> Value { - return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); + return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs); }; auto makeAdd = [&](Value lhs, Value rhs) -> Value { - return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, TypedValue<::mlir::MemRefType> memref) { Type it = b.getIndexType(); - Value idx = b.create<arith::IndexCastOp>(it, x); - Value idy0 = b.create<arith::IndexCastOp>(it, y); - Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); - Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); - Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); - b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); - b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); + Value idx = arith::IndexCastOp::create(b, it, x); + Value idy0 = arith::IndexCastOp::create(b, it, y); + Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1)); + Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i); + Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1); + memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0}); + memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1}); }; - Value tidx = b.create<NVVM::ThreadIdXOp>(i32); - Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); - Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); - Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); - Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); + Value tidx = NVVM::ThreadIdXOp::create(b, i32); + Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize); + Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize); + Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4); + Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4); Value tj = makeMul(lane4modId, c2); Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); @@ -1626,7 +1626,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType()); for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { auto structType = cast<LLVM::LLVMStructType>(matrixD); - Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx); + Value innerStructValue = + LLVM::ExtractValueOp::create(b, matriDValue, idx); storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); offset += structType.getBody().size(); } @@ -1648,23 +1649,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front()) .getBody() .front(); - Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType)); - Value packStruct = b.create<LLVM::PoisonOp>(packStructType); + Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType)); + Value packStruct = LLVM::PoisonOp::create(b, packStructType); SmallVector<Value> innerStructs; // Unpack the structs and set all values to zero for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { auto structType = cast<LLVM::LLVMStructType>(s); - Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx); + Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx); for (unsigned i = 0; i < structType.getBody().size(); ++i) { - structValue = b.create<LLVM::InsertValueOp>( - structType, structValue, zero, ArrayRef<int64_t>({i})); + structValue = LLVM::InsertValueOp::create(b, structType, structValue, + zero, ArrayRef<int64_t>({i})); } innerStructs.push_back(structValue); } // Pack the inner structs into a single struct for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { - packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(), - packStruct, matrix, idx); + packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(), + packStruct, matrix, idx); } rewriter.replaceOp(op, packStruct); return success(); @@ -1681,7 +1682,7 @@ struct NVGPUTmaFenceOpLowering ImplicitLocOpBuilder b(op->getLoc(), rewriter); auto i32Ty = b.getI32Type(); Value tensormapSize = - b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128)); + LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128)); auto memscope = NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS); @@ -1716,13 +1717,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> { VectorType inTy = op.getIn().getType(); // apply rcp.approx.ftz.f on each element in vector. auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { - Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy); + Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy); int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements(); for (int i = 0; i < numElems; i++) { - Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i)); - Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx); - Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem); - ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx); + Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i)); + Value elem = LLVM::ExtractElementOp::create(b, inVec, idx); + Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem); + ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx); } return ret1DVec; }; |