Diffstat (limited to 'mlir/lib/Dialect/LLVMIR')
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt                      |    1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp            |   11
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp                    |   19
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp                  |   30
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp                    |    4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                  | 2067
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp |   45
7 files changed, 2145 insertions, 32 deletions
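The LLVMAttrs.cpp hunk below widens LLVM::MemoryEffectsAttr from three ModRefInfo fields to six (other, argMem, inaccessibleMem, errnoMem, targetMem0, targetMem1): an empty argument list still defaults every category to ModRef, a six-entry list maps positionally, and any other length now returns a null attribute. A minimal caller sketch, assuming the dialect headers from this tree (the helper name is hypothetical):

    // Hypothetical caller of the widened MemoryEffectsAttr::get.
    #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
    #include "mlir/IR/MLIRContext.h"

    static mlir::LLVM::MemoryEffectsAttr
    buildReadOnlyArgEffects(mlir::MLIRContext *ctx) {
      using mlir::LLVM::ModRefInfo;
      // Positional order: other, argMem, inaccessibleMem, errnoMem,
      // targetMem0, targetMem1.
      ModRefInfo info[] = {ModRefInfo::NoModRef, ModRefInfo::Ref,
                           ModRefInfo::NoModRef, ModRefInfo::NoModRef,
                           ModRefInfo::NoModRef, ModRefInfo::NoModRef};
      return mlir::LLVM::MemoryEffectsAttr::get(ctx, info);
      // A three-entry list (the pre-change arity) would now yield a null
      // attribute, so existing callers must pass all six categories.
    }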
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index cc66fac..a73f0c1 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRFunctionInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa3..160b6ae 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr<LLVM::LLVMFuncOp> +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index b8331e0..9f87e50 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) { MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context, ArrayRef<ModRefInfo> memInfoArgs) { if (memInfoArgs.empty()) - return MemoryEffectsAttr::get(context, ModRefInfo::ModRef, - ModRefInfo::ModRef, ModRefInfo::ModRef); - if (memInfoArgs.size() == 3) + return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef, + /*argMem=*/ModRefInfo::ModRef, + /*inaccessibleMem=*/ModRefInfo::ModRef, + /*errnoMem=*/ModRefInfo::ModRef, + /*targetMem0=*/ModRefInfo::ModRef, + /*targetMem1=*/ModRefInfo::ModRef); + if (memInfoArgs.size() == 6) return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1], - memInfoArgs[2]); + memInfoArgs[2], memInfoArgs[3], + memInfoArgs[4], memInfoArgs[5]); return {}; } @@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() { return false; if (this->getOther() != ModRefInfo::ModRef) return false; + if (this->getErrnoMem() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem0() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem1() != ModRefInfo::ModRef) + return false; return true; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 2731069..5b81948 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -640,8 +640,6 @@ SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { // Code for LLVM::GEPOp. 
//===----------------------------------------------------------------------===// -constexpr int32_t GEPOp::kDynamicIndex; - GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() { return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(), getDynamicIndices()); @@ -4226,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() { } //===----------------------------------------------------------------------===// +// UDivOp +//===----------------------------------------------------------------------===// +Speculation::Speculatability UDivOp::getSpeculatability() { + // X / 0 => UB + Value divisor = getRhs(); + if (matchPattern(divisor, m_IntRangeWithoutZeroU())) + return Speculation::Speculatable; + + return Speculation::NotSpeculatable; +} + +//===----------------------------------------------------------------------===// +// SDivOp +//===----------------------------------------------------------------------===// +Speculation::Speculatability SDivOp::getSpeculatability() { + // This function conservatively assumes that all signed division by -1 are + // not speculatable. + // X / 0 => UB + // INT_MIN / -1 => UB + Value divisor = getRhs(); + if (matchPattern(divisor, m_IntRangeWithoutZeroS()) && + matchPattern(divisor, m_IntRangeWithoutNegOneS())) + return Speculation::Speculatable; + + return Speculation::NotSpeculatable; +} + +//===----------------------------------------------------------------------===// // LLVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ce93d18..5dc4fa2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, static constexpr llvm::StringRef kSpirvPrefix = "spirv."; static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount"; +static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier"; bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { // See llvm/lib/IR/Type.cpp for reference. 
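The UDivOp/SDivOp speculatability hooks in the LLVMDialect.cpp hunk above report Speculatable only when integer-range analysis has proven the divisor non-zero (and, for sdiv, also not -1, since INT_MIN / -1 is UB). A hedged sketch of how a generic transform consumes these hooks through MLIR's ConditionallySpeculatable interface; the helper below is illustrative and not part of the patch:

    // Illustrative consumer of the speculatability hooks; not in this diff.
    #include "mlir/IR/Operation.h"
    #include "mlir/Interfaces/SideEffectInterfaces.h"

    // True if, from a UB standpoint alone, `op` may be hoisted out of a
    // conditionally executed region (memory effects not considered here).
    static bool maySpeculateDivision(mlir::Operation *op) {
      auto speculatable = llvm::dyn_cast<mlir::ConditionallySpeculatable>(op);
      if (!speculatable)
        return false;
      // llvm.udiv: Speculatable only if the divisor's range excludes 0.
      // llvm.sdiv: the range must also exclude -1.
      return speculatable.getSpeculatability() ==
             mlir::Speculation::Speculatable;
    }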
@@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { properties |= (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal); + if (getExtTypeName() == kAMDGCNNamedBarrier) + properties |= LLVMTargetExtType::CanBeGlobal; + return (properties & prop) == prop; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4db..5ce56e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/NVPTXAddrSpace.h" @@ -48,6 +49,47 @@ using namespace NVVM; static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic; //===----------------------------------------------------------------------===// +// Helper/Utility methods +//===----------------------------------------------------------------------===// + +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInGenericSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + +static bool isPtrInSharedClusterSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster); +} + +static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder, + llvm::Value *ptr, + NVVMMemorySpace targetAS) { + unsigned AS = static_cast<unsigned>(targetAS); + return builder.CreateAddrSpaceCast( + ptr, llvm::PointerType::get(builder.getContext(), AS)); +} + +// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM +static llvm::nvvm::CTAGroupKind +getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) { + switch (ctaGroup) { + case NVVM::CTAGroupKind::CTA_1: + return llvm::nvvm::CTAGroupKind::CG_1; + case NVVM::CTAGroupKind::CTA_2: + return llvm::nvvm::CTAGroupKind::CG_2; + } + llvm_unreachable("unsupported cta_group value"); +} + +//===----------------------------------------------------------------------===// // Verifier methods //===----------------------------------------------------------------------===// @@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() { return success(); } +LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() { + bool isSharedCTA = isPtrInSharedCTASpace(getDstMem()); + if (isSharedCTA && getMulticastMask()) + return emitError("Multicast is not supported with shared::cta mode."); + + return success(); +} + +static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, + NVVM::MemScopeKind scope, + Value retVal = nullptr) { + if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER) + return op->emitError("mbarrier scope must be either CTA or Cluster"); + + bool isSharedCluster = isPtrInSharedClusterSpace(addr); + bool hasRetValue = static_cast<bool>(retVal); + if (isSharedCluster && hasRetValue) + return op->emitError( + "mbarrier in shared_cluster space cannot return any value"); + + return success(); +} + +LogicalResult MBarrierArriveOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + 
getRes()); +} + +LogicalResult MBarrierArriveDropOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveExpectTxOp::verify() { + // The inline-ptx version of this Op does not support all features. + // With predicate, this Op lowers to inline-ptx. So, verify and + // error-out if there are unsupported features. + if (getPredicate()) { + if (getScope() != NVVM::MemScopeKind::CTA) + return emitError("mbarrier scope must be CTA when using predicate"); + + if (isPtrInSharedClusterSpace(getAddr())) + return emitError("mbarrier in shared_cluster space is not supported when " + "using predicate"); + + if (getRes()) + return emitError("return-value is not supported when using predicate"); + + if (getRelaxed() == true) + return emitError("mbarrier with relaxed semantics is not supported when " + "using predicate"); + } + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveDropExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierCompleteTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTestWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTryWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + LogicalResult ConvertFloatToTF32Op::verify() { using RndMode = NVVM::FPRoundingMode; switch (getRnd()) { @@ -365,6 +484,108 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { return success(); } +LogicalResult PermuteOp::verify() { + using Mode = NVVM::PermuteMode; + bool hasHi = static_cast<bool>(getHi()); + + switch (getMode()) { + case Mode::DEFAULT: + case Mode::F4E: + case Mode::B4E: + if (!hasHi) + return emitError("mode '") + << stringifyPermuteMode(getMode()) << "' requires 'hi' operand."; + break; + case Mode::RC8: + case Mode::ECL: + case Mode::ECR: + case Mode::RC16: + if (hasHi) + return emitError("mode '") << stringifyPermuteMode(getMode()) + << "' does not accept 'hi' operand."; + break; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Stochastic Rounding Conversion Ops +//===----------------------------------------------------------------------===// + +static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, + FPRoundingMode rnd, + bool hasRandomBits, + Operation *op) { + static constexpr FPRoundingMode validRndModes[] = { + FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS}; + + if (!llvm::is_contained(validRndModes, rnd)) { + return op->emitOpError( + "Only RN, RZ, and RS rounding modes are supported for " + "conversions from f32x2 to ") + << dstType << "."; + } + + if (rnd == FPRoundingMode::RS) { + if (!hasRandomBits) { + return op->emitOpError("random_bits is required for RS rounding mode."); + } + } else { + if (hasRandomBits) { + return op->emitOpError( + "random_bits not supported for RN and RZ rounding modes."); + } + } + + return success(); +} + +LogicalResult ConvertF32x2ToF16x2Op::verify() { + return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), + getRandomBits() ? 
true : false, *this); +} + +LogicalResult ConvertF32x2ToBF16x2Op::verify() { + return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), + getRandomBits() ? true : false, *this); +} + +LogicalResult ConvertF32x4ToF8x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f32x4 to f8x4."; + + return success(); +} + +LogicalResult ConvertF32x4ToF6x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f32x4 to f6x4."; + + return success(); +} + +LogicalResult ConvertF32x4ToF4x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy())) + return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from " + "f32x4 to f4x4."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -866,16 +1087,517 @@ LogicalResult MmaOp::verify() { return success(); } -LogicalResult ShflOp::verify() { - if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) +MMATypes MmaSpOp::accumPtxType() { + std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType( + getODSOperands(2).getTypes().front(), /*isAccumulator=*/true); + assert(val.has_value() && "accumulator PTX type should always be inferrable"); + return val.value(); +} + +MMATypes MmaSpOp::resultPtxType() { + std::optional<mlir::NVVM::MMATypes> val = + MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true); + assert(val.has_value() && "result PTX type should always be inferrable"); + return val.value(); +} + +mlir::NVVM::IDArgPair +MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MmaSpOp>(op); + + // Get operands + llvm::SmallVector<llvm::Value *> args; + for (mlir::Value v : thisOp.getOperands()) + args.push_back(mt.lookupValue(v)); + + // Get intrinsic ID using the existing getIntrinsicID method + auto intId = MmaSpOp::getIntrinsicID( + thisOp.getShape().getM(), thisOp.getShape().getN(), + thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(), + thisOp.getOrderedMetadata(), thisOp.getKind(), + *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(), + thisOp.accumPtxType(), thisOp.resultPtxType()); + + return {intId, args}; +} + +void MmaSpOp::print(OpAsmPrinter &p) { + SmallVector<Type, 4> regTypes; + struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector<Value, 4> regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} + }; + + std::array<OperandFragment, 5> frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", ""), OperandFragment("sparseMetadata", ""), + OperandFragment("selector", "")}; + SmallVector<StringRef, 4> ignoreAttrNames{ + mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()}; + + // Handle variadic operands A, B, C + for 
(unsigned fragIdx = 0; fragIdx < 3; fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(this->getOperand(operandIdx)); + if (operandIdx == varOperandSpec.first) { + regTypes.push_back(this->getOperand(operandIdx).getType()); + } + } + std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } + + // Handle sparse metadata and selector (single operands) + frags[3].regs.push_back(getSparseMetadata()); + frags[4].regs.push_back(getSparsitySelector()); + + auto printMmaSpOperand = [&](const OperandFragment &frag) -> void { + p << " " << frag.operandName; + p << "["; + p.printOperands(frag.regs); + p << "]"; + }; + + for (const auto &frag : frags) + printMmaSpOperand(frag); + + p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames); + p << " : "; + p << "("; + for (int i = 0; i < 3; ++i) { + p << regTypes[i]; + if (i < 2) + p << ", "; + } + p << ") -> " << getResult().getType(); +} + +void MmaSpOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape, + std::optional<MMAIntOverflow> intOverflow, + std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) { + + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + MLIRContext *ctx = builder.getContext(); + result.addAttribute( + "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2])); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands(sparseMetadata); + result.addOperands(sparsitySelector); + + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); + } + + if (intOverflow.has_value()) + result.addAttribute("intOverflowBehavior", + MMAIntOverflowAttr::get(ctx, *intOverflow)); + + result.addTypes(resultType); + result.addAttribute( + MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()), + static_cast<int32_t>(operandB.size()), + static_cast<int32_t>(operandC.size()), 1, + 1})); // sparseMetadata and sparsitySelector +} + +ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) { + struct OperandFragment { + std::optional<MMATypes> elemtype; + SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; + SmallVector<Type> regTypes; + }; + + Builder &builder = parser.getBuilder(); + std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector + + NamedAttrList namedAttributes; + + // A helper to parse the operand segments. 
+ auto parseMmaSpOperand = [&](StringRef operandName, + OperandFragment &frag) -> LogicalResult { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser + .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) + return failure(); return success(); - auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); - auto elementType = (type && type.getBody().size() == 2) - ? llvm::dyn_cast<IntegerType>(type.getBody()[1]) - : nullptr; - if (!elementType || elementType.getWidth() != 1) - return emitError("expected return type to be a two-element struct with " - "i1 as the second element"); + }; + + // Parse the operand segments. + if (parseMmaSpOperand("A", frags[0]).failed()) + return failure(); + if (parseMmaSpOperand("B", frags[1]).failed()) + return failure(); + if (parseMmaSpOperand("C", frags[2]).failed()) + return failure(); + if (parseMmaSpOperand("sparseMetadata", frags[3]).failed()) + return failure(); + if (parseMmaSpOperand("selector", frags[4]).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse the type specification and resolve operands. + SmallVector<Type, 3> operandTypes; + if (failed(parser.parseColon())) + return failure(); + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseTypeList(operandTypes))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + if (operandTypes.size() != 3) + return parser.emitError( + parser.getNameLoc(), + "expected one type for each operand segment but got " + + Twine(operandTypes.size()) + " types"); + for (const auto &iter : llvm::enumerate(operandTypes)) { + auto &frag = frags[iter.index()]; + frag.regTypes.resize(frag.regs.size(), iter.value()); + if (failed(parser.resolveOperands(frag.regs, frag.regTypes, + parser.getNameLoc(), result.operands))) + return failure(); + frag.elemtype = + MmaOp::inferOperandMMAType(frag.regTypes[0], + /*isAccumulator*/ iter.index() >= 2); + } + + Type resultType; + if (parser.parseArrow() || parser.parseType(resultType)) + return failure(); + frags[5].elemtype = + MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true); + + // Resolve sparse metadata and selector (assume i32 type) + Type i32Type = builder.getIntegerType(32); + if (parser + .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + if (parser + .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + + std::array<StringRef, 2> names{"multiplicandAPtxType", + "multiplicandBPtxType"}; + for (unsigned idx = 0; idx < names.size(); idx++) { + const auto &frag = frags[idx]; + std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); + if (!frag.elemtype.has_value() && !attr.has_value()) { + return parser.emitError( + parser.getNameLoc(), + "attribute " + names[idx] + + " is not provided explicitly and cannot be inferred"); + } + if (!attr.has_value()) + result.addAttribute( + names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); + } + + result.addTypes(resultType); + if (!namedAttributes.empty()) + result.addAttributes(namedAttributes); + result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast<int32_t>(frags[0].regs.size()), + static_cast<int32_t>(frags[1].regs.size()), + static_cast<int32_t>(frags[2].regs.size()), + 1, // sparseMetadata + 1 // sparsitySelector 
+ })); + return success(); +} + +LogicalResult MmaSpOp::verify() { + MLIRContext *context = getContext(); + auto f16Ty = Float16Type::get(context); + auto i32Ty = IntegerType::get(context, 32); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f32Ty = Float32Type::get(context); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + auto s32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty}); + auto f32x8StructTy = + LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty)); + auto f16x2x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty}); + auto f32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty}); + auto s32x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); + + std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(), + getShapeAttr().getK()}; + + // These variables define the set of allowed data types for matrices A, B, C, + // and result. + using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>; + using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>; + AllowedShapes allowedShapes; + AllowedTypes expectedA; + AllowedTypes expectedB; + AllowedTypes expectedC; + SmallVector<Type> expectedResult; + + // When M = 16, we just need to calculate the number of 8xk tiles, where + // k is a factor that depends on the data type. + if (mmaShape[0] == 16) { + int64_t kFactor; + Type multiplicandFragType; + switch (*getMultiplicandAPtxType()) { + case MMATypes::tf32: + kFactor = 4; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k8 and m16n8k16 for tf32 + allowedShapes.push_back({16, 8, 8}); + allowedShapes.push_back({16, 8, 16}); + break; + case MMATypes::bf16: + kFactor = 8; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k16 and m16n8k32 for bf16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::f16: + kFactor = 8; + multiplicandFragType = f16x2Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k16 and m16n8k32 for f16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::s4: + case MMATypes::u4: + kFactor = 32; + // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4 + allowedShapes.push_back({16, 8, 64}); + allowedShapes.push_back({16, 8, 128}); + break; + case MMATypes::s8: + case MMATypes::u8: + kFactor = 16; + // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8 + allowedShapes.push_back({16, 8, 32}); + allowedShapes.push_back({16, 8, 64}); + break; + case MMATypes::e4m3: + case MMATypes::e5m2: + case MMATypes::e3m2: + case MMATypes::e2m3: + case MMATypes::e2m1: + kFactor = 32; + multiplicandFragType = i32Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k64 for FP8 types + allowedShapes.push_back({16, 8, 64}); + break; + default: + return emitError("invalid shape or multiplicand type: " + + stringifyEnum(getMultiplicandAPtxType().value())); + } + + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedResult.push_back(s32x4StructTy); + expectedC.emplace_back(4, i32Ty); + 
multiplicandFragType = i32Ty; + } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 && + *getMultiplicandAPtxType() <= MMATypes::e2m1) { + // FP8 types + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } else { + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } + + // For sparse MMA, A operand is compressed (2:4 sparsity means half the + // elements) + int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2; + int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); + expectedA.emplace_back(unitA, multiplicandFragType); + expectedB.emplace_back(unitB, multiplicandFragType); + + if (resultPtxType() != accumPtxType()) + return emitOpError("ctype does not match dtype"); + } + + // In the M=8 case, there is only 1 possible case per data type. + if (mmaShape[0] == 8) { + if (*getMultiplicandAPtxType() == MMATypes::f16) { + expectedA.emplace_back(2, f16x2Ty); + expectedB.emplace_back(2, f16x2Ty); + expectedResult.push_back(f16x2x4StructTy); + expectedResult.push_back(f32x8StructTy); + expectedC.emplace_back(4, f16x2Ty); + expectedC.emplace_back(8, f32Ty); + allowedShapes.push_back({8, 8, 4}); + } + if (*getMultiplicandAPtxType() == MMATypes::f64) { + Type f64Ty = Float64Type::get(context); + expectedA.emplace_back(1, f64Ty); + expectedB.emplace_back(1, f64Ty); + expectedC.emplace_back(2, f64Ty); + expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( + context, SmallVector<Type>(2, f64Ty))); + allowedShapes.push_back({8, 8, 4}); + } + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedA.push_back({i32Ty}); + expectedB.push_back({i32Ty}); + expectedC.push_back({i32Ty, i32Ty}); + expectedResult.push_back(s32x2StructTy); + if (isInt4PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 32}); + if (isInt8PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 16}); + } + } + + std::string errorMessage; + llvm::raw_string_ostream errorStream(errorMessage); + + // Check that we matched an existing shape/dtype combination. + if (expectedA.empty() || expectedB.empty() || expectedC.empty() || + !llvm::is_contained(allowedShapes, mmaShape)) { + errorStream << "unimplemented variant for MMA shape <"; + llvm::interleaveComma(mmaShape, errorStream); + errorStream << ">"; + return emitOpError(errorMessage); + } + + // Verify the operand types for segments of A, B, and C operands. 
+ std::array<StringRef, 3> operandNames{"A", "B", "C"}; + for (const auto &iter : llvm::enumerate( + SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) { + auto spec = this->getODSOperandIndexAndLength(iter.index()); + SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first, + operand_type_begin() + spec.first + + spec.second); + bool match = llvm::is_contained(iter.value(), operandTySeg); + + if (!match) { + errorStream << "Could not match types for the " + << operandNames[iter.index()] + << " operands; expected one of "; + for (const auto &x : iter.value()) { + errorStream << x.size() << "x" << x[0] << " "; + } + errorStream << "but got "; + llvm::interleaveComma(operandTySeg, errorStream); + return emitOpError(errorMessage); + } + } + + // Check the result type + if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { + return expectedResultType == getResult().getType(); + })) { + errorStream + << "Could not match allowed types for the result; expected one of "; + llvm::interleaveComma(expectedResult, errorStream); + errorStream << " but got " << getResult().getType(); + return emitOpError(errorMessage); + } + + // Ensure int4/int8 MMA variants specify the accum overflow behavior + // attribute. + if (isInt4PtxType(*getMultiplicandAPtxType()) || + isInt8PtxType(*getMultiplicandAPtxType())) { + if (!getIntOverflowBehavior()) + return emitOpError("op requires " + + getIntOverflowBehaviorAttrName().strref() + + " attribute"); + } + + // Validate sparse metadata type (should be i32) + if (!getSparseMetadata().getType().isInteger(32)) { + return emitOpError() << "sparse metadata must be i32 type"; + } + + // Validate sparsity selector type (should be i32) + if (!getSparsitySelector().getType().isInteger(32)) { + return emitOpError() << "sparsity selector must be i32 type"; + } + + return success(); +} + +LogicalResult ShflOp::verify() { + auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); + + auto verifyTypeError = [&](Twine desc, Type expectedType, + Type actualType) -> LogicalResult { + return emitOpError("expected " + desc + " to be of type ") + << expectedType << " but got " << actualType << " instead"; + }; + + if (returnStructType) { + if (!getReturnValueAndIsValid()) + return emitOpError("\"return_value_and_is_valid\" attribute must be " + "specified when the return type is a struct type"); + + if (returnStructType.getBody().size() != 2) + return emitOpError("expected return type to be a two-element struct"); + + llvm::ArrayRef<Type> returnStruct = returnStructType.getBody(); + auto resultType = returnStruct[0]; + if (resultType != getVal().getType()) + return verifyTypeError("first element in the returned struct", + getVal().getType(), resultType); + + auto predicateType = returnStruct[1]; + if (!predicateType.isInteger(1)) + return verifyTypeError("second element in the returned struct", + mlir::IntegerType::get(getContext(), 1), + predicateType); + } else { + if (getReturnValueAndIsValid()) + return emitOpError("expected return type to be a two-element struct"); + + if (getType() != getVal().getType()) + return verifyTypeError("return type", getVal().getType(), getType()); + } return success(); } @@ -896,6 +1618,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type, } else if (type == NVVM::MMATypes::f32) { elementType = builder.getF32Type(); numberElements = 8; + } else if (type == NVVM::MMATypes::f64) { + elementType = builder.getF64Type(); + if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) + 
numberElements = 1; + else + numberElements = 2; } else if (type == NVVM::MMATypes::tf32) { elementType = builder.getI32Type(); numberElements = 4; @@ -954,6 +1682,14 @@ LogicalResult NVVM::WMMALoadOp::verify() { return emitOpError() << "invalid attribute combination"; std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( getEltype(), getFrag(), getM(), getN(), getK(), getContext()); + // Special case for f64 fragments + Type f64Ty = Float64Type::get(getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + if (getType() != f64Ty) + return emitOpError("expected destination type to be f64"); + return success(); + } + // Everything else is a struct Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); if (getType() != dstType) @@ -1362,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues( return true; // Has manual mapping } +LogicalResult NVVM::FenceSyncRestrictOp::verify() { + if (getOrder() != NVVM::MemOrderKind::ACQUIRE && + getOrder() != NVVM::MemOrderKind::RELEASE) + return emitOpError("only acquire and release semantics are supported"); + return success(); +} + LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; @@ -1384,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() { if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); - return success(); } @@ -1396,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() { if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); + return success(); +} + +LogicalResult NVVM::FenceProxySyncRestrictOp::verify() { + if (getOrder() != NVVM::MemOrderKind::ACQUIRE && + getOrder() != NVVM::MemOrderKind::RELEASE) + return emitOpError("only acquire and release semantics are supported"); + if (getFromProxy() != NVVM::ProxyKind::GENERIC) + return emitOpError("only generic is support for from_proxy attribute"); + + if (getToProxy() != NVVM::ProxyKind::async) + return emitOpError("only async is supported for to_proxy attribute"); return success(); } @@ -1412,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() { if (getNumberOfThreads() && !getBarrierId()) return emitOpError( "barrier id is missing, it should be set between 0 to 15"); + + if (getBarrierId() && (getReductionOp() || getReductionPredicate())) + return emitOpError("reduction are only available when id is 0"); + + if ((getReductionOp() && !getReductionPredicate()) || + (!getReductionOp() && getReductionPredicate())) + return emitOpError("reduction predicate and reduction operation must be " + "specified together"); + return success(); } @@ -1563,6 +2326,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() { return success(); } +LogicalResult NVVM::ReduxOp::verify() { + mlir::Type reduxType = getType(); + + if (!reduxType.isF32()) { + if (getAbs()) + return emitOpError("abs attribute is supported only for f32 type"); + if (getNan()) + return emitOpError("nan attribute is supported only for f32 type"); + } + + NVVM::ReduxKind kind = getKind(); + switch (kind) { + case NVVM::ReduxKind::ADD: + case NVVM::ReduxKind::AND: + case NVVM::ReduxKind::OR: + case NVVM::ReduxKind::XOR: + case NVVM::ReduxKind::MAX: + case NVVM::ReduxKind::MIN: + case NVVM::ReduxKind::UMAX: + case NVVM::ReduxKind::UMIN: + 
if (!reduxType.isInteger(32)) + return emitOpError("'") + << stringifyEnum(kind) << "' redux kind unsupported with " + << reduxType << " type. Only supported type is 'i32'."; + break; + case NVVM::ReduxKind::FMIN: + case NVVM::ReduxKind::FMAX: + if (!reduxType.isF32()) + return emitOpError("'") + << stringifyEnum(kind) << "' redux kind unsupported with " + << reduxType << " type. Only supported type is 'f32'."; + break; + } + + return success(); +} + /// Packs the given `field` into the `result`. /// The `result` is 64-bits and each `field` can be 32-bits or narrower. static llvm::Value * @@ -1608,9 +2408,439 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, } //===----------------------------------------------------------------------===// +// getPtx methods +//===----------------------------------------------------------------------===// + +std::string NVVM::MBarrierInitOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;") + : std::string("mbarrier.init.b64 [%0], %1;"); +} + +std::string NVVM::MBarrierArriveExpectTxOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared + ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;") + : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); +} + +std::string NVVM::MBarrierTryWaitParityOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + llvm::StringRef space = isShared ? ".shared" : ""; + + return llvm::formatv("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}", + space); +} + +//===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::BarrierOp>(op); + llvm::Value *barrierId = thisOp.getBarrierId() + ? mt.lookupValue(thisOp.getBarrierId()) + : builder.getInt32(0); + llvm::Intrinsic::ID id; + llvm::SmallVector<llvm::Value *> args; + if (thisOp.getNumberOfThreads()) { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count; + args.push_back(barrierId); + args.push_back(mt.lookupValue(thisOp.getNumberOfThreads())); + } else if (thisOp.getReductionOp()) { + switch (*thisOp.getReductionOp()) { + case NVVM::BarrierReduction::AND: + id = llvm::Intrinsic::nvvm_barrier0_and; + break; + case NVVM::BarrierReduction::OR: + id = llvm::Intrinsic::nvvm_barrier0_or; + break; + case NVVM::BarrierReduction::POPC: + id = llvm::Intrinsic::nvvm_barrier0_popc; + break; + } + args.push_back(mt.lookupValue(thisOp.getReductionPredicate())); + } else { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all; + args.push_back(barrierId); + } + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierInitOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? 
llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierInvalOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? llvm::Intrinsic::nvvm_mbarrier_inval_shared + : llvm::Intrinsic::nvvm_mbarrier_inval; + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + +mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 
1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +bool MBarrierArriveExpectTxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> + &asmValues) { + // Add all the operands but not the attrs to the asmValues list. + // The attrs here are used to generate the right variants for + // intrinsics-lowering. So, we ignore them while generating inline-PTX. 
+ for (auto val : getOperands()) + asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read}); + + return false; +} + +mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? 
relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: isPhaseParity + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? 
relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, input}}; +} + +mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + bool hasTicks = static_cast<bool>(thisOp.getTicks()); + // bit-0: isPhaseParity + // bit-1: Scope + // bit-2: hasTicks + size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) | + (isPhaseParity ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the mbarrier pointer + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mbar); + args.push_back(mt.lookupValue(thisOp.getStateOrPhase())); + if (hasTicks) + args.push_back(mt.lookupValue(thisOp.getTicks())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + + llvm::Intrinsic::ID id; + if (thisOp.getNoinc()) { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc; + } else { + id = isShared ? 
llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive; + } + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix @@ -1680,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); - // Multicast mask, if available. + // Multicast mask for shared::cluster only, if available. mlir::Value multicastMask = thisOp.getMulticastMask(); const bool hasMulticastMask = static_cast<bool>(multicastMask); - llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); - args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem()); + if (!isSharedCTA) { + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) + : i16Unused); + } // Cache hint, if available. mlir::Value cacheHint = thisOp.getL2CacheHint(); @@ -1693,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); // Flag arguments for multicast and cachehint. - args.push_back(builder.getInt1(hasMulticastMask)); + if (!isSharedCTA) + args.push_back(builder.getInt1(hasMulticastMask)); args.push_back(builder.getInt1(hasCacheHint)); llvm::Intrinsic::ID id = - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + isSharedCTA + ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta + : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; return {id, std::move(args)}; } @@ -2412,6 +3649,155 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() +NVVM::IDArgPair +ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rn, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rz, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rs, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, + }; + + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 
1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; + + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op"); + } +} + +NVVM::IDArgPair +ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rn, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rz, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rs, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, + }; + + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; + + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op"); + } +} + +llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite; + }) + .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite; + }) + .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) { + return hasRelu ? 
llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { auto curOp = cast<NVVM::Tcgen05CpOp>(op); bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; @@ -2451,6 +3837,9 @@ LogicalResult Tcgen05LdOp::verify() { if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset()) result = emitError("shape 16x32bx2 requires offset argument"); + if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset()) + result = emitError("offset argument is only supported for shape 16x32bx2"); + auto resTy = getRes().getType(); unsigned resLen = isa<VectorType>(resTy) ? llvm::cast<VectorType>(resTy).getNumElements() @@ -2694,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs( return {intrinsicID, args}; } +mlir::NVVM::IDArgPair +PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::PermuteOp>(op); + NVVM::PermuteMode mode = thisOp.getMode(); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e, + llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8, + llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr, + llvm::Intrinsic::nvvm_prmt_rc16}; + + unsigned modeIndex = static_cast<unsigned>(mode); + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getLo())); + + // Only first 3 modes (Default, f4e, b4e) need the hi operand. 
+ if (modeIndex < 3)
+ args.push_back(mt.lookupValue(thisOp.getHi()));
+
+ args.push_back(mt.lookupValue(thisOp.getSelector()));
+
+ return {IDs[modeIndex], args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ const bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
+ static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
+ { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic:: +
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + const unsigned ctaGroup = + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, + NVVM::CTAGroupKind ctaGroup, bool hasAShift, + NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) { + + if (disableOutputLane) { + mlir::VectorType disableOutputLaneType = + cast<mlir::VectorType>(disableOutputLane.getType()); + if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 && + disableOutputLaneType.getNumElements() != 4) || + (ctaGroup == NVVM::CTAGroupKind::CTA_2 && + disableOutputLaneType.getNumElements() != 8)) + return emitError(loc) << "Disable Output Lane of length " + << disableOutputLaneType.getNumElements() + << " is incompatible with CtaGroupAttr"; + } + + if (hasAShift && !isATensor) + return emitError( + loc, "A-shift can be applied only when matrix A is in tensor memory"); + + if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL || + collectorOp == Tcgen05MMACollectorOp::USE)) + return emitError( + loc, "Cannot use collector buffer operation fill or use with ashift"); + + return success(); +} + +LogicalResult Tcgen05MMAOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static 
constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
+ { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale + : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? 
llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.block_scale attributes"); + }(); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, + NVVM::Tcgen05MMABlockScaleKind kind, + NVVM::Tcgen05MMABlockScale blockScale, + Location loc) { + + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT && + kind == Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, "mxf4nvf4 requires block scale attribute"); + + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 && + kind != Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, + llvm::formatv("{} kind does not support block16 attribute", + stringifyEnum(kind))); + + return success(); +} + +LogicalResult Tcgen05MMABlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? 
llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes"); + }(); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseBlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask; + } else + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask; + } else + ID = isATensor ? 
llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// @@ -2897,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) { "Minimum NVVM target SM version is sm_20"); } - gpuModuleOp->walk([&](Operation *op) { - if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { - const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion(); - if (!requirement.isCompatibleWith(targetSMVersion)) { - op->emitOpError() << "is not supported on " << getChip(); - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); + if (gpuModuleOp + ->walk([&](Operation *op) { + if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { + const NVVMCheckSMVersion requirement = + reqOp.getRequiredMinSMVersion(); + if (!requirement.isCompatibleWith(targetSMVersion)) { + op->emitOpError() << "is not supported on " << getChip(); + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return failure(); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp index 67573c4..12dd225 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); } +/// Adds DILexicalBlockFileAttr for operations with CallSiteLoc and operations +/// from different files than their containing function. static void setLexicalBlockFileAttr(Operation *op) { - if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) { + Location opLoc = op->getLoc(); + + if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) { auto callerLoc = callSiteLoc.getCaller(); auto calleeLoc = callSiteLoc.getCallee(); LLVM::DIScopeAttr scopeAttr; @@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) { op->setLoc( CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc)); } + + return; + } + + auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); + if (!funcOp) + return; + + FileLineColLoc opFileLoc = extractFileLoc(opLoc); + if (!opFileLoc) + return; + + FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc()); + if (!funcFileLoc) + return; + + StringRef opFile = opFileLoc.getFilename().getValue(); + StringRef funcFile = funcFileLoc.getFilename().getValue(); + + // Handle cross-file operations: add DILexicalBlockFileAttr when the + // operation's source file differs from its containing function. 
+ if (opFile != funcFile) { + auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc()); + if (!funcOpLoc) + return; + auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata()); + if (!scopeAttr) + return; + + auto *context = op->getContext(); + LLVM::DIFileAttr opFileAttr = + LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile), + llvm::sys::path::parent_path(opFile)); + + LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr = + LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0); + + Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr); + op->setLoc(newLoc); } }
