//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  // We lower through inline-ptx when getPredicate() is true.
// a) Only TILE mode is supported // b) Cache-hint is not supported if (getPredicate()) { if (mode != TMAStoreMode::TILE) return emitError("Inline-ptx lowering supported only for Tile mode."); if (getL2CacheHint()) return emitError("Inline-ptx lowering unsupported with L2 cache-hint."); } size_t dims = getCoordinates().size(); switch (mode) { case TMAStoreMode::TILE: return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc()); case TMAStoreMode::IM2COL: return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc()); case TMAStoreMode::TILE_SCATTER4: if (dims != 5) return emitError("Scatter4 mode expects 5 coordinates"); } return success(); } LogicalResult CpAsyncOp::verify() { if (getModifier() != LoadCacheModifierKind::CG && getModifier() != LoadCacheModifierKind::CA) return emitError("Only CG and CA cache modifiers are supported."); if (getSize() != 4 && getSize() != 8 && getSize() != 16) return emitError("expected byte size to be either 4, 8 or 16."); if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16) return emitError("CG cache modifier is only support for 16 bytes copy."); return success(); } // This verify params can be shared across TMA Load and Prefetch Ops. static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc) { if (tensorDims < 1 || tensorDims > 5) return emitError(loc, "expects coordinates between 1 to 5 dimension"); auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col, size_t expectedIm2colOff) -> LogicalResult { if (isIm2col && (tensorDims < 3)) return emitError(loc) << "to use " << stringifyEnum(mode) << " mode, the tensor has to be at least 3-dimensional"; if (numIm2colOff != expectedIm2colOff) return emitError(loc) << " im2col offsets expected " << expectedIm2colOff << " (provided " << numIm2colOff << ")"; return success(); }; switch (mode) { case TMALoadMode::TILE: return checkTMALoadParams(mode, false, 0); case TMALoadMode::IM2COL: return checkTMALoadParams(mode, true, tensorDims - 2); case TMALoadMode::IM2COL_W: case TMALoadMode::IM2COL_W_128: return checkTMALoadParams(mode, true, 2); case TMALoadMode::TILE_GATHER4: return (tensorDims == 5) ? checkTMALoadParams(mode, false, 0) : emitError(loc, "Gather4 mode expects 5 coordinates"); } return success(); } LogicalResult CpAsyncBulkTensorPrefetchOp::verify() { return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(), getMode(), getLoc()); } LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() { TMALoadMode mode = getMode(); bool isCTAOnly = getIsCTAOnly(); if (getPredicate()) { // Inline-asm based lowering if (isCTAOnly) return emitError("Predicate is supported only for shared::cluster mode."); if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL) return emitError( "Predicate is supported only for Tile and Im2col modes."); } else { // Intrinsics-based lowering NVVMMemorySpace expectedAS = isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster; unsigned AS = llvm::cast(getDstMem().getType()) .getAddressSpace(); if (AS != expectedAS) return emitError() << (isCTAOnly ? "Shared::cta destination requires address-space 3." 
: "Shared::cluster destination requires address-space 7."); // Checks specific to shared::cta mode if (isCTAOnly) { if (getMulticastMask()) return emitError("Multicast is not supported with shared::cta mode."); if (getGroup()) return emitError("CTAGroup is not supported with shared::cta mode."); } } return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(), getMode(), getLoc()); } LogicalResult CpAsyncBulkTensorReduceOp::verify() { TMAStoreMode mode = getMode(); size_t dims = getCoordinates().size(); switch (mode) { case TMAStoreMode::TILE: return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc()); case TMAStoreMode::IM2COL: return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc()); case TMAStoreMode::TILE_SCATTER4: return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp"); } return success(); } LogicalResult ConvertFloatToTF32Op::verify() { using RndMode = NVVM::FPRoundingMode; switch (getRnd()) { case RndMode::RNA: if (getRelu()) return emitError("Relu not supported with rna rounding mode."); break; case RndMode::RN: case RndMode::RZ: break; default: return emitError( "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op."); } return success(); } LogicalResult ConvertF32x2ToF6x2Op::verify() { mlir::MLIRContext *ctx = getContext(); if (!llvm::isa(getDstTy())) { return emitOpError("Only ") << mlir::Float6E2M3FNType::get(ctx) << " and " << mlir::Float6E3M2FNType::get(ctx) << " types are supported for conversions from f32x2 to f6x2."; } return success(); } LogicalResult ConvertF32x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; using SatMode = NVVM::SaturationMode; bool isRoundingModeRN = getRnd() == RndMode::RN; bool isRoundingModeRZ = getRnd() == RndMode::RZ; bool isRoundingModeRP = getRnd() == RndMode::RP; bool isSatFinite = getSat() == SatMode::SATFINITE; bool hasRelu = getRelu(); mlir::MLIRContext *ctx = getContext(); return llvm::TypeSwitch(getDstTy()) .Case( [&](mlir::Type) -> LogicalResult { if (!isRoundingModeRN) { return emitOpError("Only RN rounding mode is supported for " "conversions from f32x2 to ") << mlir::Float8E4M3FNType::get(ctx) << " and " << mlir::Float8E5M2Type::get(ctx) << " types"; } if (!isSatFinite) { return emitOpError("Only SATFINITE saturation mode is supported " "for conversions " "from f32x2 to ") << mlir::Float8E4M3FNType::get(ctx) << " and " << mlir::Float8E5M2Type::get(ctx) << " types"; } return success(); }) .Case([&](mlir::Type) -> LogicalResult { if (!(isRoundingModeRZ || isRoundingModeRP)) { return emitOpError("Only RZ and RP rounding modes are supported for " "conversions from f32x2 to ") << mlir::Float8E8M0FNUType::get(ctx) << " type"; } if (hasRelu) { return emitOpError("relu not supported for conversions to ") << mlir::Float8E8M0FNUType::get(ctx) << " type"; } return success(); }) .Default([&](mlir::Type) { return emitOpError("Only ") << mlir::Float8E4M3FNType::get(ctx) << ", " << mlir::Float8E5M2Type::get(ctx) << ", and " << mlir::Float8E8M0FNUType::get(ctx) << " types are " "supported for conversions from f32x2 to f8x2"; }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { mlir::MLIRContext *ctx = getContext(); if (!llvm::isa(getDstTy())) { return emitOpError("Only ") << mlir::Float8E4M3FNType::get(ctx) << " and " << mlir::Float8E5M2Type::get(ctx) << " types are supported for conversions from f16x2 to f8x2."; } return success(); } LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; if (!llvm::isa(getDstTy())) return emitOpError("Only ") << 
mlir::Float8E8M0FNUType::get(getContext()) << " type is supported for conversions from " "bf16x2 to f8x2."; auto rnd = getRnd(); if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) return emitOpError("Only RZ and RP rounding modes are supported for " "conversions from bf16x2 to f8x2."); return success(); } LogicalResult ConvertF32x2ToF4x2Op::verify() { mlir::MLIRContext *ctx = getContext(); if (!llvm::isa(getDstTy())) return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx) << " type is supported for conversions from f32x2 to f4x2."; return success(); } LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); return success(); } LogicalResult PMEventOp::verify() { auto eventId = getEventId(); auto maskedEventId = getMaskedEventId(); if (!maskedEventId && !eventId) { return emitOpError() << "either `id` or `mask` must be set"; } if (maskedEventId && eventId) { return emitOpError() << "`id` and `mask` cannot be set at the same time"; } if (eventId) { if (eventId < 0 || eventId > 15) { return emitOpError() << "`id` must be between 0 and 15"; } } return llvm::success(); } // Given the element type of an operand and whether or not it is an accumulator, // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the // operand's element type. std::optional MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) { auto half2Type = VectorType::get(2, Float16Type::get(operandElType.getContext())); if (operandElType.isF64()) return NVVM::MMATypes::f64; if (operandElType.isF16() || operandElType == half2Type) return NVVM::MMATypes::f16; if (operandElType.isF32() && isAccumulator) return NVVM::MMATypes::f32; if (operandElType.isF32() && !isAccumulator) return NVVM::MMATypes::tf32; if (llvm::isa(operandElType)) { if (isAccumulator) return NVVM::MMATypes::s32; return std::nullopt; } if (auto structType = llvm::dyn_cast(operandElType)) { if (structType.getBody().empty()) return std::nullopt; return inferOperandMMAType(structType.getBody()[0], isAccumulator); } return std::nullopt; } static bool isInt4PtxType(MMATypes type) { return (type == MMATypes::u4 || type == MMATypes::s4); } static bool isInt8PtxType(MMATypes type) { return (type == MMATypes::u8 || type == MMATypes::s8); } static bool isIntegerPtxType(MMATypes type) { return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 || type == MMATypes::s32; } MMATypes MmaOp::accumPtxType() { std::optional val = inferOperandMMAType( getODSOperands(2).getTypes().front(), /*isAccumulator=*/true); assert(val.has_value() && "accumulator PTX type should always be inferrable"); return val.value(); } MMATypes MmaOp::resultPtxType() { std::optional val = inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true); assert(val.has_value() && "result PTX type should always be inferrable"); return val.value(); } void MmaOp::print(OpAsmPrinter &p) { SmallVector regTypes; struct OperandFragment { StringRef operandName; StringRef ptxTypeAttr; SmallVector regs; explicit OperandFragment(StringRef name, StringRef ptxTypeName) : operandName(name), ptxTypeAttr(ptxTypeName) {} }; std::array frags{ OperandFragment("A", getMultiplicandAPtxTypeAttrName()), OperandFragment("B", getMultiplicandBPtxTypeAttrName()), OperandFragment("C", "")}; SmallVector ignoreAttrNames{ mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()}; for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) { auto &frag = frags[fragIdx]; auto varOperandSpec = 
getODSOperandIndexAndLength(fragIdx); for (auto operandIdx = varOperandSpec.first; operandIdx < varOperandSpec.first + varOperandSpec.second; operandIdx++) { frag.regs.push_back(this->getOperand(operandIdx)); if (operandIdx == 0) { regTypes.push_back(this->getOperand(operandIdx).getType()); } } std::optional inferredType = inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2); if (inferredType) ignoreAttrNames.push_back(frag.ptxTypeAttr); } auto printMmaOperand = [&](const OperandFragment &frag) -> void { p << " " << frag.operandName; p << "["; p.printOperands(frag.regs); p << "] "; }; for (const auto &frag : frags) { printMmaOperand(frag); } p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); // Print the types of the operands and result. p << " : " << "("; llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), frags[1].regs[0].getType(), frags[2].regs[0].getType()}, p); p << ")"; p.printArrowTypeList(TypeRange{this->getRes().getType()}); } void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType, ValueRange operandA, ValueRange operandB, ValueRange operandC, ArrayRef shape, std::optional b1Op, std::optional intOverflow, std::optional> multiplicandPtxTypes, std::optional> multiplicandLayouts) { assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); MLIRContext *ctx = builder.getContext(); result.addAttribute( "shape", builder.getAttr(shape[0], shape[1], shape[2])); result.addOperands(operandA); result.addOperands(operandB); result.addOperands(operandC); if (multiplicandPtxTypes) { result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); } else { if (auto res = inferOperandMMAType(operandA[0].getType(), false)) result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); if (auto res = inferOperandMMAType(operandB[0].getType(), false)) result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); } if (multiplicandLayouts) { result.addAttribute("layoutA", MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0])); result.addAttribute("layoutB", MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1])); } else { result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row)); result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col)); } if (intOverflow.has_value()) result.addAttribute("intOverflowBehavior", MMAIntOverflowAttr::get(ctx, *intOverflow)); if (b1Op.has_value()) result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op)); result.addTypes(resultType); result.addAttribute( MmaOp::getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr({static_cast(operandA.size()), static_cast(operandB.size()), static_cast(operandC.size())})); } // := // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]` // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0])) // `->` type($res) ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) { struct OperandFragment { std::optional elemtype; SmallVector regs; SmallVector regTypes; }; Builder &builder = parser.getBuilder(); std::array frags; NamedAttrList namedAttributes; // A helper to parse the operand segments. 
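  // Illustrative example only (SSA names and exact attribute syntax are
  // hypothetical, shown to make the grammar comment above concrete); an
  // m16n8k8 f16 instance would read roughly as:
  //   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
  //          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
  //           shape = #nvvm.shape<m = 16, n = 8, k = 8>}
  //        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
  //          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>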
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs,
                              OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator*/ iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
    if (resultPtxType() != accumPtxType())
      return emitOpError("ctype does not match dtype");
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
std::array operandNames{"A", "B", "C"}; for (const auto &iter : llvm::enumerate( SmallVector{expectedA, expectedB, expectedC})) { auto spec = this->getODSOperandIndexAndLength(iter.index()); SmallVector operandTySeg(operand_type_begin() + spec.first, operand_type_begin() + spec.first + spec.second); bool match = llvm::is_contained(iter.value(), operandTySeg); if (!match) { errorStream << "Could not match types for the " << operandNames[iter.index()] << " operands; expected one of "; for (const auto &x : iter.value()) { errorStream << x.size() << "x" << x[0] << " "; } errorStream << "but got "; llvm::interleaveComma(operandTySeg, errorStream); return emitOpError(errorMessage); } } // Check the result type if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { return expectedResultType == getResult().getType(); })) { errorStream << "Could not match allowed types for the result; expected one of "; llvm::interleaveComma(expectedResult, errorStream); errorStream << " but got " << getResult().getType(); return emitOpError(errorMessage); } // Ensure that binary MMA variants have a b1 MMA operation defined. if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) { return emitOpError("op requires " + getB1OpAttrName().strref() + " attribute"); } // Ensure int4/int8 MMA variants specify the accum overflow behavior // attribute. if (isInt4PtxType(*getMultiplicandAPtxType()) || isInt8PtxType(*getMultiplicandAPtxType())) { if (!getIntOverflowBehavior()) return emitOpError("op requires " + getIntOverflowBehaviorAttrName().strref() + " attribute"); } // Validate layout combinations. According to the operation description, most // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16 // can use other layout combinations. bool isM8N8K4_F16 = (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 && getMultiplicandAPtxType() == MMATypes::f16); if (!isM8N8K4_F16) { // For all other shapes/types, layoutA must be row and layoutB must be col if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) { return emitOpError("requires layoutA = #nvvm.mma_layout and " "layoutB = #nvvm.mma_layout for shape <") << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2] << "> with element types " << stringifyEnum(*getMultiplicandAPtxType()) << " and " << stringifyEnum(*getMultiplicandBPtxType()) << ". Only m8n8k4 with f16 supports other layouts."; } } return success(); } LogicalResult ShflOp::verify() { if (!(*this)->getAttrOfType("return_value_and_is_valid")) return success(); auto type = llvm::dyn_cast(getType()); auto elementType = (type && type.getBody().size() == 2) ? 
llvm::dyn_cast(type.getBody()[1]) : nullptr; if (!elementType || elementType.getWidth() != 1) return emitError("expected return type to be a two-element struct with " "i1 as the second element"); return success(); } std::pair NVVM::inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, int nRow, int nCol, MLIRContext *context) { unsigned numberElements = 0; Type elementType; OpBuilder builder(context); Type f16x2 = VectorType::get(2, builder.getF16Type()); if (type == NVVM::MMATypes::f16) { elementType = f16x2; if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) numberElements = 8; else numberElements = 4; } else if (type == NVVM::MMATypes::f32) { elementType = builder.getF32Type(); numberElements = 8; } else if (type == NVVM::MMATypes::tf32) { elementType = builder.getI32Type(); numberElements = 4; } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) { elementType = builder.getI32Type(); int parallelSize = 0; if (frag == NVVM::MMAFrag::a) parallelSize = nRow; if (frag == NVVM::MMAFrag::b) parallelSize = nCol; // m == 16 && n == 16 && k == 16 if (parallelSize == 16) numberElements = 2; // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16 else if (parallelSize == 8) numberElements = 1; else if (parallelSize == 32) numberElements = 4; } else if (type == NVVM::MMATypes::s32) { elementType = builder.getI32Type(); numberElements = 8; } assert(numberElements != 0 && elementType != nullptr); return std::make_pair(elementType, numberElements); } static std::pair inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context) { int nRow, nCol; if (frag == NVVM::MMAFrag::a) { nRow = m; nCol = k; } else if (frag == NVVM::MMAFrag::b) { nRow = k; nCol = n; } else { nRow = m; nCol = n; } assert(nRow && nCol); return inferMMAType(type, frag, nRow, nCol, context); } LogicalResult NVVM::WMMALoadOp::verify() { unsigned addressSpace = llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global && addressSpace != NVVMMemorySpace::Shared) return emitOpError("expected source pointer in memory " "space 0, 1, 3"); if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), getEltype(), getFrag()) == 0) return emitOpError() << "invalid attribute combination"; std::pair typeInfo = inferMMATypeFromMNK( getEltype(), getFrag(), getM(), getN(), getK(), getContext()); Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector(typeInfo.second, typeInfo.first)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") << typeInfo.second << " elements of type " << typeInfo.first; return success(); } LogicalResult NVVM::WMMAStoreOp::verify() { unsigned addressSpace = llvm::cast(getPtr().getType()).getAddressSpace(); if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global && addressSpace != NVVMMemorySpace::Shared) return emitOpError("expected operands to be a source pointer in memory " "space 0, 1, 3"); if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), getEltype()) == 0) return emitOpError() << "invalid attribute combination"; std::pair typeInfo = inferMMATypeFromMNK( getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); if (getArgs().size() != typeInfo.second) return emitOpError() << "expected " << typeInfo.second << " data operands"; if (llvm::any_of(getArgs(), [&typeInfo](Value operands) { return operands.getType() != typeInfo.first; })) return emitOpError() << "expected data 
operands of type " << typeInfo.first; return success(); } LogicalResult NVVM::WMMAMmaOp::verify() { if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(), getLayoutB(), getEltypeA(), getEltypeB()) == 0) return emitOpError() << "invalid attribute combination"; std::pair typeInfoA = inferMMATypeFromMNK( getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext()); std::pair typeInfoB = inferMMATypeFromMNK( getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext()); std::pair typeInfoC = inferMMATypeFromMNK( getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); SmallVector arguments; arguments.append(typeInfoA.second, typeInfoA.first); arguments.append(typeInfoB.second, typeInfoB.first); arguments.append(typeInfoC.second, typeInfoC.first); unsigned numArgs = arguments.size(); if (getArgs().size() != numArgs) return emitOpError() << "expected " << numArgs << " arguments"; for (unsigned i = 0; i < numArgs; i++) { if (getArgs()[i].getType() != arguments[i]) return emitOpError() << "expected argument " << i << " to be of type " << arguments[i]; } Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector(typeInfoC.second, typeInfoC.first)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") << typeInfoC.second << " elements of type " << typeInfoC.first; return success(); } LogicalResult NVVM::LdMatrixOp::verify() { uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); if (m == 8 && n == 8) { if (num != 1 && num != 2 && num != 4) { return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " "matrix"); } if (getEltType() != LdStMatrixEltType::B16) { return emitOpError("expected element type to be b16 for 8x8 matrix"); } } else if (m == 8 && n == 16) { if (num != 1 && num != 2 && num != 4) { return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " "matrix"); } if (getLayout() != MMALayout::row) { return emitOpError("expected layout to be row for 8x16 matrix"); } if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { return emitOpError("expected element type to be b8x16.b4x16_p64 or " "b8x16.b6x16_p32 for 8x16 matrix"); } } else if (m == 16 && n == 16) { if (num != 1 && num != 2) { return emitOpError("expected num attribute to be 1 or 2 for 16x16 " "matrix"); } if (getLayout() != MMALayout::col) { return emitOpError("expected layout to be col for 16x16 matrix"); } if (getEltType() != LdStMatrixEltType::B8 && getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " "b8x16.b6x16_p32 for 16x16 matrix"); } } else { return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); } Type i32 = IntegerType::get(getContext(), 32); uint32_t numElements = (m == 16 && n == 16 ? 
num * 2 : num); if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") << numElements << " elements of type i32"; } return success(); } LogicalResult NVVM::StMatrixOp::verify() { int numMatrix = getSources().size(); if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) return emitOpError("expected num attribute to be 1, 2 or 4"); int m = getShape().getM(), n = getShape().getN(); if (m == 8 && n == 8) { if (getEltType() != NVVM::LdStMatrixEltType::B16) { return emitOpError("expected element type to be B16 for 8x8 matrix"); } } else if (m == 16 && n == 8) { if (getEltType() != NVVM::LdStMatrixEltType::B8) { return emitOpError("expected element type to be B8 for 16x8 matrix"); } if (getLayout() != NVVM::MMALayout::col) { return emitOpError("expected layout to be col for 16x8 matrix"); } } else { return emitOpError("expected shape to be 8x8 or 16x8"); } return success(); } static FailureOr getAllowedSizeK(NVVM::WGMMATypes typeA) { if (typeA == NVVM::WGMMATypes::tf32) return 8; if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16) return 16; if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8) return 32; if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2) return 32; if (typeA == NVVM::WGMMATypes::b1) return 256; return failure(); } static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB) { switch (typeA) { case NVVM::WGMMATypes::f16: if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && typeB == NVVM::WGMMATypes::f16) return success(); break; case NVVM::WGMMATypes::tf32: if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32) return success(); break; case NVVM::WGMMATypes::u8: case NVVM::WGMMATypes::s8: if (typeD == NVVM::WGMMATypes::s32 && (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8)) return success(); break; case NVVM::WGMMATypes::b1: if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1) return success(); break; case NVVM::WGMMATypes::bf16: if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && typeB == NVVM::WGMMATypes::bf16) return success(); break; case NVVM::WGMMATypes::e4m3: case NVVM::WGMMATypes::e5m2: if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3)) return success(); break; case WGMMATypes::f32: case WGMMATypes::s32: llvm_unreachable("unsupported input types"); break; } return failure(); } static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) { SmallVector allowedN = {8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256}; SmallVector allowedNshort = {8, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256}; switch (typeA) { case WGMMATypes::f16: case WGMMATypes::tf32: case WGMMATypes::bf16: case WGMMATypes::e4m3: case WGMMATypes::e5m2: if (llvm::is_contained(allowedN, sizeN)) return success(); break; case WGMMATypes::u8: case WGMMATypes::s8: case WGMMATypes::b1: if (llvm::is_contained(allowedNshort, sizeN)) return success(); break; case WGMMATypes::f32: case WGMMATypes::s32: 
llvm_unreachable("unsupported input types"); break; } return failure(); } LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { Value outValue = getResults(); auto stype = dyn_cast(outValue.getType()); if (!stype) return emitOpError() << "expected results to be struct"; int outputSize = stype.getBody().size(); WGMMATypes typeD = getTypeD(); WGMMATypes typeA = getTypeA(); WGMMATypes typeB = getTypeB(); for (Type t : stype.getBody()) { if (t != stype.getBody().front()) return emitOpError() << "all elements in struct must be same type but there is " << t; } if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 && typeD != WGMMATypes::s32) { return emitOpError() << "does not support the given output type " << NVVM::stringifyWGMMATypes(typeD); } if (typeD == WGMMATypes::s32 && (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) { return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg"; } if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) { return emitOpError() << NVVM::stringifyWGMMATypes(typeD) << " += " << NVVM::stringifyWGMMATypes(typeA) << " * " << NVVM::stringifyWGMMATypes(typeB) << ", it is not supported."; } // Check M if (getShape().getM() != 64) return emitOpError() << "shape 'm' must be 64"; // Check K FailureOr allowedK = getAllowedSizeK(typeA); if (failed(allowedK) || allowedK.value() != getShape().getK()) return emitOpError() << "shape 'k' must be " << allowedK.value() << " for input type " << NVVM::stringifyWGMMATypes(typeA); // Check N if (failed(isAllowedSizeN(getShape().getN(), typeA))) { return emitOpError() << "has input type " << NVVM::stringifyWGMMATypes(typeA) << " n is set to " << getShape().getN() << ", it is not supported."; } // Check transpose (only available for f16/bf16) // Matrices A should be stored in row-major and B in column-major. // Only f16/bf16 matrices can be stored in either column-major or row-major // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code. if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) && (getLayoutA() == mlir::NVVM::MMALayout::col || getLayoutB() == mlir::NVVM::MMALayout::row)) { return emitOpError() << "given layouts layout_a = " << stringifyMMALayout(getLayoutA()) << " and layout_b = " << stringifyMMALayout(getLayoutB()) << " for input types " << stringifyWGMMATypes(typeA) << " and " << stringifyWGMMATypes(typeB) << " requires transpose. 
However, this is only supported for: " << stringifyMMATypes(MMATypes::f16) << " and " << stringifyMMATypes(MMATypes::bf16); } // Check result registers int expectedOutput = 0; if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32) expectedOutput = getShape().getN() / 2; if (typeD == WGMMATypes::f16) expectedOutput = getShape().getN() / 4; if (outputSize != expectedOutput) { return emitOpError() << "results " << expectedOutput << ", however output struct has " << outputSize << " elements"; } // Check satfinite (only available for s32 accumulator) if (typeD != WGMMATypes::s32 && getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == NVVM::MMAIntOverflow::satfinite) { return emitOpError() << " `satfinite` can be only used with s32 accumulator, however " "the current accumulator is " << NVVM::stringifyWGMMATypes(typeD); } return success(); } std::string NVVM::WgmmaMmaAsyncOp::getPtx() { int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; StringRef outputTypeName = stringifyWGMMATypes(getTypeD()); int expectedOutputRegisters = 0; if (getTypeD() == WGMMATypes::f16) expectedOutputRegisters = getShape().getN() / 4; else expectedOutputRegisters = getShape().getN() / 2; std::string ptx; llvm::raw_string_ostream ss(ptx); ss << "{\n" ".reg .pred p;\n" "setp.ne.b32 p, $" << ((expectedOutputRegisters * 2) + 2) << ", 0;\n" "wgmma.mma_async.sync.aligned.m" << m << "n" << n << "k" << k << "." << outputTypeName << "." << stringifyWGMMATypes(getTypeA()) << "." << stringifyWGMMATypes(getTypeB()); if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == NVVM::MMAIntOverflow::satfinite) ss << ".satfinite"; ss << " {"; int regCnt = 0; for (; regCnt < expectedOutputRegisters; ++regCnt) { ss << "$" << regCnt; if (regCnt != expectedOutputRegisters - 1) ss << ", "; } ss << "},"; // Need to map read/write registers correctly. regCnt = (regCnt * 2); ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; if (getTypeD() != WGMMATypes::s32) { ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); } // Don't add transpose parameters unless needed. if (isF16) { ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6); } ss << ";\n" << "}\n"; return ptx; } bool NVVM::WgmmaMmaAsyncOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl> &asmValues) { bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; if (getResults()) asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); if (getInouts()) asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite}); asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read}); asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read}); asmValues.push_back({makeConstantI32(rewriter, static_cast(getScaleD())), mlir::NVVM::PTXRegisterMod::Read}); if (getTypeD() != WGMMATypes::s32) { asmValues.push_back( {makeConstantI32(rewriter, getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1), mlir::NVVM::PTXRegisterMod::Read}); asmValues.push_back( {makeConstantI32(rewriter, getScaleB() == NVVM::WGMMAScaleIn::neg ? 
-1 : 1), mlir::NVVM::PTXRegisterMod::Read}); } if (isF16) { asmValues.push_back( {makeConstantI32(rewriter, static_cast(getLayoutA())), mlir::NVVM::PTXRegisterMod::Read}); asmValues.push_back( {makeConstantI32(rewriter, 1 - static_cast(getLayoutB())), mlir::NVVM::PTXRegisterMod::Read}); } return true; // Has manual mapping } LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; if (getKind() == NVVM::ProxyKind::GENERIC) return emitOpError() << "generic proxy not a supported proxy kind"; if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) { return emitOpError() << "async_shared fence requires space attribute"; } if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) { return emitOpError() << "only async_shared fence can have space attribute"; } return success(); } LogicalResult NVVM::FenceProxyAcquireOp::verify() { if (getFromProxy() != NVVM::ProxyKind::GENERIC) return emitOpError("uni-directional proxies only support generic for " "from_proxy attribute"); if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); return success(); } LogicalResult NVVM::FenceProxyReleaseOp::verify() { if (getFromProxy() != NVVM::ProxyKind::GENERIC) return emitOpError("uni-directional proxies only support generic for " "from_proxy attribute"); if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); return success(); } LogicalResult NVVM::SetMaxRegisterOp::verify() { if (getRegCount() % 8) return emitOpError("new register size must be multiple of 8"); if (getRegCount() < 24 || getRegCount() > 256) return emitOpError("new register size must be in between 24 to 256"); return success(); } LogicalResult NVVM::BarrierOp::verify() { if (getNumberOfThreads() && !getBarrierId()) return emitOpError( "barrier id is missing, it should be set between 0 to 15"); return success(); } LogicalResult NVVM::Tcgen05CpOp::verify() { auto mc = getMulticast(); using SH = Tcgen05CpShape; using MC = Tcgen05CpMulticast; switch (getShape()) { case SH::SHAPE_128x256b: case SH::SHAPE_128x128b: case SH::SHAPE_4x256b: if (mc != MC::NONE) return emitError("Invalid multicast type for tcgen05.cp Op"); break; case SH::SHAPE_64x128b: if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13) return emitError("Shape 64x128b requires multicast warpx2_01_23 or " "warpx2_02_13 for tcgen05.cp Op"); break; case SH::SHAPE_32x128b: if (mc != MC::WARPX4) return emitError( "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op"); break; } return success(); } LogicalResult NVVM::MatchSyncOp::verify() { if (getKind() == NVVM::MatchSyncKind::all) { auto type = llvm::dyn_cast(getType()); if (!type || type.getBody().size() != 2 || !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) { return emitOpError("match.sync 'all' returns a two element struct with " "first element as i32 and second element as i1"); } } else { if (!getType().isInteger(32)) { return emitOpError("match.sync 'any' returns an i32"); } } return success(); } LogicalResult NVVM::VoteSyncOp::verify() { if (getKind() == NVVM::VoteSyncKind::ballot) { if (!getType().isInteger(32)) { return emitOpError("vote.sync 'ballot' returns an i32"); } } else { if (!getType().isInteger(1)) { return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1"); } } return 
success(); } LogicalResult NVVM::PrefetchOp::verify() { using MemSpace = NVVM::NVVMMemorySpace; using CacheLevel = NVVM::PrefetchCacheLevel; unsigned addressSpace = llvm::cast(getAddr().getType()).getAddressSpace(); std::optional evictPriority = getEvictPriority(); std::optional cacheLevel = getCacheLevel(); if (getTensormap() && cacheLevel) return emitOpError("cannot specify both tensormap and cache level"); if (getTensormap()) { if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Constant) { return emitOpError( "prefetch tensormap requires a generic or constant pointer"); } if (evictPriority) { return emitOpError( "prefetch tensormap does not support eviction priority"); } if (getInParamSpace() && addressSpace != MemSpace::Generic) { return emitOpError( "in_param_space can only be specified for a generic pointer"); } } else if (cacheLevel) { if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global && addressSpace != MemSpace::Local) { return emitOpError("prefetch to cache level requires a generic, global, " "or local pointer"); } if (getUniform()) { if (*cacheLevel != CacheLevel::L1) { return emitOpError( "unsupported cache level, the only supported uniform " "cache level is L1"); } if (addressSpace != MemSpace::Generic) { return emitOpError( "prefetch to uniform cache requires a generic pointer"); } } if (evictPriority) { if (*cacheLevel != CacheLevel::L2) return emitOpError( "cache eviction priority supported only for cache level L2"); if (addressSpace != MemSpace::Global) return emitOpError("cache eviction priority requires a global pointer"); if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal && *evictPriority != NVVM::CacheEvictionPriority::EvictLast) return emitOpError( "unsupported cache eviction priority, only evict_last and " "evict_normal are supported"); } if (getPredicate()) return emitOpError("predicate supported only on prefetch tensormap"); } else { return emitOpError( "requires specification of either cache level or tensormap"); } return success(); } LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() { switch (getQueryType()) { case NVVM::ClusterLaunchControlQueryType::IS_CANCELED: if (!getType().isInteger(1)) return emitOpError("is_canceled query type returns an i1"); break; case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X: case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y: case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z: if (!getType().isInteger(32)) { return emitOpError("get_first_cta_id_x, get_first_cta_id_y, " "get_first_cta_id_z query types return an i32"); } break; } return success(); } /// Packs the given `field` into the `result`. /// The `result` is 64-bits and each `field` can be 32-bits or narrower. static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, // the `result` (unset bits are zero) llvm::Value *field, // `field` to pack into `result` unsigned sizeInBits, // Size of `field` in bits unsigned start) { // Starting bit within `result` field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty()); unsigned mask = (sizeInBits < 32 ? 
((1u << sizeInBits) - 1) : 0xffffffffu); if (mask != 0xffffffffu) field = builder.CreateAnd(field, builder.getInt32(mask)); field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty()); field = builder.CreateShl(field, start); return builder.CreateOr(result, field); } void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::Value *smemDesc = builder.getInt64(0); smemDesc = packValInto64Bits(builder, smemDesc, mt.lookupValue(thisOp.getStartAddr()), 14, 0); smemDesc = packValInto64Bits( builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16); smemDesc = packValInto64Bits( builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32); smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46); smemDesc = packValInto64Bits(builder, smemDesc, mt.lookupValue(thisOp.getBaseOffset()), 3, 49); smemDesc = packValInto64Bits( builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52); smemDesc = packValInto64Bits(builder, smemDesc, mt.lookupValue(thisOp.getSwizzleMode()), 3, 61); mt.mapValue(thisOp.getRes()) = smemDesc; } //===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix #define GET_CP_ASYNC_ID(mod, size, has_cpsize) \ has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, ) llvm::Intrinsic::ID CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, llvm::SmallVector &args) { llvm::Intrinsic::ID id; auto cpAsyncOp = cast(op); bool hasCpSize = static_cast(cpAsyncOp.getCpSize()); switch (cpAsyncOp.getSize()) { case 4: id = GET_CP_ASYNC_ID(ca, 4, hasCpSize); break; case 8: id = GET_CP_ASYNC_ID(ca, 8, hasCpSize); break; case 16: id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG) ? GET_CP_ASYNC_ID(cg, 16, hasCpSize) : GET_CP_ASYNC_ID(ca, 16, hasCpSize); break; default: llvm_unreachable("Invalid copy size in CpAsyncOp."); } // Fill the Intrinsic Args args.push_back(mt.lookupValue(cpAsyncOp.getDst())); args.push_back(mt.lookupValue(cpAsyncOp.getSrc())); if (hasCpSize) args.push_back(mt.lookupValue(cpAsyncOp.getCpSize())); return id; } mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2; // Fill the Intrinsic Args args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); mlir::Value cacheHint = thisOp.getL2CacheHint(); const bool hasCacheHint = static_cast(cacheHint); llvm::Value *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); args.push_back(builder.getInt1(hasCacheHint)); return {id, std::move(args)}; } mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; // Fill the Intrinsic Args: dst, mbar, src, size. 
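  // The optional operands follow below in the order the intrinsic expects:
  // the multicast mask and the L2 cache hint (with zero placeholders when
  // absent), then two i1 flags telling the intrinsic whether each optional
  // value was actually provided.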
args.push_back(mt.lookupValue(thisOp.getDstMem())); args.push_back(mt.lookupValue(thisOp.getMbar())); args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); // Multicast mask, if available. mlir::Value multicastMask = thisOp.getMulticastMask(); const bool hasMulticastMask = static_cast(multicastMask); llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); // Cache hint, if available. mlir::Value cacheHint = thisOp.getL2CacheHint(); const bool hasCacheHint = static_cast(cacheHint); llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0); args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); // Flag arguments for multicast and cachehint. args.push_back(builder.getInt1(hasMulticastMask)); args.push_back(builder.getInt1(hasCacheHint)); llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; return {id, std::move(args)}; } mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global; // Fill the Intrinsic Args args.push_back(mt.lookupValue(thisOp.getDstMem())); args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); mlir::Value cacheHint = thisOp.getL2CacheHint(); const bool hasCacheHint = static_cast(cacheHint); llvm::Value *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); args.push_back(builder.getInt1(hasCacheHint)); // Choose the bytemask variant if (mlir::Value byteMask = thisOp.getByteMask()) { args.push_back(mt.lookupValue(byteMask)); id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask; } return {id, std::move(args)}; } bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl> &asmValues) { // Add all the operands but not the attrs to the asmValues list. // The attrs here are used to generate the right variants for // intrinsics-lowering. So, we ignore them while generating inline-PTX. 
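  // Each operand becomes a read-only register operand of the generated PTX.
  // Returning false indicates that no special-case mapping is provided here
  // (in contrast to WgmmaMmaAsyncOp::getAsmValues above, which returns true
  // for its manual mapping).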
for (auto val : getOperands()) asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read}); return false; } mlir::NVVM::IDArgPair CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); const bool isCTAOnly = thisOp.getIsCTAOnly(); llvm::SmallVector args; // Fill the Intrinsic Args args.push_back(mt.lookupValue(thisOp.getDstMem())); args.push_back(mt.lookupValue(thisOp.getMbar())); args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); // Coordinates and im2col-offsets for (mlir::Value v : thisOp.getCoordinates()) args.push_back(mt.lookupValue(v)); for (mlir::Value v : thisOp.getIm2colOffsets()) args.push_back(mt.lookupValue(v)); // MulticastMask, if available mlir::Value mcMask = thisOp.getMulticastMask(); const bool hasMC = static_cast(mcMask); llvm::Value *i16Zero = llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0); // CacheHint, if available mlir::Value cacheHint = thisOp.getL2CacheHint(); const bool hasCacheHint = static_cast(cacheHint); llvm::Value *i64Zero = llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); // Flag argument CTAGroup // CTA_1/2 is mapped to values 1 and 2 for the intrinsics. // Hence, the +1 to getGroup(). const int32_t val = thisOp.getGroup() ? (static_cast(*thisOp.getGroup()) + 1) : 0; llvm::Value *cg = llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val); if (!isCTAOnly) { // For shared::cluster, all the arguments that we build are applicable. args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero); args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero); args.push_back(builder.getInt1(hasMC)); args.push_back(builder.getInt1(hasCacheHint)); args.push_back(cg); } else { // For shared::cta, only cache-hint is applicable. args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Zero); args.push_back(builder.getInt1(hasCacheHint)); } constexpr size_t numDims = 5; // 1D to 5D constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4 using rowTy = std::array; using TableTy = std::array; static constexpr TableTy IDTable{ {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}}; static constexpr TableTy IDTableCTA{ {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d}, {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}}; static_assert( (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) && (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1), "TMALoadModes must match number of rows in IDTable and IDTableCTA"); size_t mode = static_cast(thisOp.getMode()); size_t dim = thisOp.getCoordinates().size(); auto id = isCTAOnly ? 
mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<CpAsyncBulkTensorPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");

  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};

  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");

  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable(
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");

  return {id, std::move(args)};
}
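// The TMA store-reduce lowering below uses the same table-driven lookup,
// with the table additionally indexed by the reduction kind:
// IDTable[redKind][layout][dim], where layout 0 is TILE and 1 is IM2COL.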
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();

  llvm::SmallVector<llvm::Value *> args;

  // Arguments to the intrinsic:
  // shared_mem_ptr, tmaDesc, tensorDims
  // cache_hint(if applicable) and flag(boolean)
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));

  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;

  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
  constexpr unsigned numLayouts = 2;  // TILE, IM2COL
  constexpr unsigned maxDim = 5;      // 1D to 5D
  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
  using layoutTable = std::array<row, numLayouts>;
  using fullTable = std::array<layoutTable, numRedKinds>;
  static constexpr fullTable IDTable{
      {// RedTy::ADD
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
       // RedTy::MIN
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
       // RedTy::MAX
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
       // RedTy::INC
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
       // RedTy::DEC
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
       // RedTy::AND
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
       // RedTy::OR
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
       // RedTy::XOR
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};

  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
                "TMAReduxKinds must match number of rows in IDTable");

  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();

  assert(redKind < IDTable.size() &&
         "Invalid redKind for CpAsyncBulkTensorReduceOp");
  assert(mode < IDTable[redKind].size() &&
         "Invalid mode for CpAsyncBulkTensorReduceOp");
  assert(dim < IDTable[redKind][mode].size() &&
         "Invalid dim for CpAsyncBulkTensorReduceOp");

  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
  assert(intrinsicID != notIntrinsic &&
         "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");

  return {intrinsicID, std::move(args)};
}

#define _none

#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
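// For reference, GET_CVT_F2TF32_ID(rn, _relu, _satfinite) picks one of
// nvvm_f2tf32_rn, nvvm_f2tf32_rn_relu, nvvm_f2tf32_rn_satfinite, or
// nvvm_f2tf32_rn_relu_satfinite depending on hasRelu/hasSatFinite; the _none
// token expands to nothing for the RNA case, which has no relu variant.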
NVVM::IDArgPair
ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getA()));
  args.push_back(mt.lookupValue(op.getB()));

  bool hasRelu = op.getRelu();

  llvm::Intrinsic::ID intId =
      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;

  return {intId, std::move(args)};
}

#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case([&](mlir::Float6E2M3FNType) {
        return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
      })
      .Case([&](mlir::Float6E3M2FNType) {
        return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case([&](mlir::Float8E4M3FNType) {
        return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
      })
      .Case([&](mlir::Float8E5M2Type) {
        return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
      })
      .Case([&](mlir::Float8E8M0FNUType) {
        if (hasRoundingModeRZ)
          return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
        else if (hasRoundingModeRP)
          return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case([&](mlir::Float8E4M3FNType) {
        return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
      })
      .Case([&](mlir::Float8E5M2Type) {
        return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return id;
}
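// Note: the macros above assemble the intrinsic name from its suffixes; for
// example, with isShared and hasMulticast both true, GET_TCGEN05_COMMIT_ID
// with cta_group cg2 evaluates to llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2.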
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
  [&]() -> auto { \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
  }()

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
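// As an example of the expansion above, GET_TCGEN05_CP_ID(_128x256b, srcFmt,
// is2CTA) with srcFmt == B6x16_P32 and is2CTA == true evaluates to
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2.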
// Returns whether the given vector length is valid for the given shape; this
// models the table in the tcgen05.{ld, st} Op description.
static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                unsigned vecLen) {
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return vecLen >= 4;
  return true;
}

LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}

LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

/// Verify the range attribute satisfies LLVM ConstantRange constructor
/// requirements for NVVM SpecialRangeableRegisterOp.
static LogicalResult
verifyConstantRangeAttr(Operation *op,
                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
  if (!rangeAttr)
    return success();

  const llvm::APInt &lower = rangeAttr->getLower();
  const llvm::APInt &upper = rangeAttr->getUpper();

  // Check LLVM ConstantRange constructor condition
  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
    unsigned bitWidth = lower.getBitWidth();
    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
    return op->emitOpError(
               "invalid range attribute: Lower == Upper, but they aren't min (")
           << llvm::toString(minVal, 10, false) << ") or max ("
           << llvm::toString(maxVal, 10, false)
           << ") value! This is an invalid constant range.";
  }

  return success();
}

static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}

NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}

NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
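// Both dot-accumulate lowerings above index their ids[] table with a 2-bit
// key built from operand signedness, (isASigned << 1) | isBSigned:
// 0 -> u_u, 1 -> u_s, 2 -> s_u, 3 -> s_s.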
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}

NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::SmallVector<llvm::Value *> args;
  llvm::Value *addr = mt.lookupValue(op.getAddr());
  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
                                      : addr);

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}

bool NVVM::InlinePtxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (auto arg : getReadWriteArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
  for (auto arg : getResults())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
  for (auto arg : getReadOnlyArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
  if (getPredicate())
    asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
  return false; // No manual mapping needed
}

NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<ClusterLaunchControlTryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}

NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<ClusterLaunchControlQueryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;
  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }
  return {intrinsicID, args};
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }

  // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
  // dim
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }

  // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
  // attribute
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }

  // blocksareclusters must be used along with reqntid and cluster_dim
  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError()
             << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// NVVM Address Space Attr
//===----------------------------------------------------------------------===//

unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}

bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
                                            dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
                                            dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once `ptr.atomic_rmw` is implemented.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
  // dialect.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidPtrIntCast(
    Type intLikeTy, Type ptrLikeTy,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once the int-cast ops are added to the `ptr`
  // dialect.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"