Diffstat (limited to 'mlir/lib/Dialect/LLVMIR')
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt                       |    1
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp             |   11
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp                     |   19
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp                   |   30
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp                     |    4
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                   | 2067
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp  |   45
 7 files changed, 2145 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index cc66fac..a73f0c1 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa3..160b6ae 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kPrintApFloat,
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
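[Editor's note] The new helper above registers a reserved printApFloat(i32, i64) -> void runtime function. A minimal sketch of how a lowering might use it (illustrative only; the emitted call, the semantics/bits values, and passing a null SymbolTableCollection are assumptions of this sketch, not part of the patch):

static LogicalResult emitApFloatPrint(OpBuilder &b, Location loc,
                                      Operation *moduleOp, Value semantics,
                                      Value bits) {
  // Look up (or declare) the reserved printApFloat(i32, i64) -> void symbol.
  FailureOr<LLVM::LLVMFuncOp> fn =
      LLVM::lookupOrCreateApFloatPrintFn(b, moduleOp, /*symbolTables=*/nullptr);
  if (failed(fn))
    return failure();
  // Forward the packed APFloat semantics id and bit pattern to the runtime.
  b.create<LLVM::CallOp>(loc, *fn, ValueRange{semantics, bits});
  return success();
}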
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index b8331e0..9f87e50 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) {
MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context,
ArrayRef<ModRefInfo> memInfoArgs) {
if (memInfoArgs.empty())
- return MemoryEffectsAttr::get(context, ModRefInfo::ModRef,
- ModRefInfo::ModRef, ModRefInfo::ModRef);
- if (memInfoArgs.size() == 3)
+ return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef,
+ /*argMem=*/ModRefInfo::ModRef,
+ /*inaccessibleMem=*/ModRefInfo::ModRef,
+ /*errnoMem=*/ModRefInfo::ModRef,
+ /*targetMem0=*/ModRefInfo::ModRef,
+ /*targetMem1=*/ModRefInfo::ModRef);
+ if (memInfoArgs.size() == 6)
return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1],
- memInfoArgs[2]);
+ memInfoArgs[2], memInfoArgs[3],
+ memInfoArgs[4], memInfoArgs[5]);
return {};
}
@@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() {
return false;
if (this->getOther() != ModRefInfo::ModRef)
return false;
+ if (this->getErrnoMem() != ModRefInfo::ModRef)
+ return false;
+ if (this->getTargetMem0() != ModRefInfo::ModRef)
+ return false;
+ if (this->getTargetMem1() != ModRefInfo::ModRef)
+ return false;
return true;
}
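[Editor's note] With the attribute widened to six ModRefInfo slots, callers now pass either an empty list or exactly six entries. A minimal sketch (illustrative; the chosen ModRefInfo values are arbitrary):

static LLVM::MemoryEffectsAttr makeArgOnlyReadEffects(MLIRContext *ctx) {
  using LLVM::ModRefInfo;
  // Order matches the six-argument getter above: other, argMem,
  // inaccessibleMem, errnoMem, targetMem0, targetMem1. Any other non-empty
  // size yields a null attribute.
  ModRefInfo infos[] = {ModRefInfo::NoModRef, ModRefInfo::Ref,
                        ModRefInfo::NoModRef, ModRefInfo::NoModRef,
                        ModRefInfo::NoModRef, ModRefInfo::NoModRef};
  return LLVM::MemoryEffectsAttr::get(ctx, infos);
}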
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2731069..5b81948 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -640,8 +640,6 @@ SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
// Code for LLVM::GEPOp.
//===----------------------------------------------------------------------===//
-constexpr int32_t GEPOp::kDynamicIndex;
-
GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
getDynamicIndices());
@@ -4226,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() {
}
//===----------------------------------------------------------------------===//
+// UDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability UDivOp::getSpeculatability() {
+ // X / 0 => UB
+ Value divisor = getRhs();
+ if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
+ return Speculation::Speculatable;
+
+ return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
+// SDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability SDivOp::getSpeculatability() {
+ // This function conservatively assumes that all signed division by -1 are
+ // not speculatable.
+ // X / 0 => UB
+ // INT_MIN / -1 => UB
+ Value divisor = getRhs();
+ if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
+ matchPattern(divisor, m_IntRangeWithoutNegOneS()))
+ return Speculation::Speculatable;
+
+ return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
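[Editor's note] These hooks plug into the ConditionallySpeculatable interface, so transforms such as LICM can hoist a division once the divisor's inferred integer range rules out the trapping cases. A minimal sketch (illustrative; assumes the SideEffectInterfaces helpers are visible in the caller):

static bool isDivisionHoistable(Operation *op) {
  // isSpeculatable() dispatches to the getSpeculatability() overrides above:
  // a udiv whose divisor range excludes 0 (and an sdiv whose range also
  // excludes -1) reports Speculatable. Divisions have no memory effects, but
  // the second check keeps the sketch valid for arbitrary ops.
  return isSpeculatable(op) && isMemoryEffectFree(op);
}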
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ce93d18..5dc4fa2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
+static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
// See llvm/lib/IR/Type.cpp for reference.
@@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
properties |=
(LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
+ if (getExtTypeName() == kAMDGCNNamedBarrier)
+ properties |= LLVMTargetExtType::CanBeGlobal;
+
return (properties & prop) == prop;
}
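[Editor's note] A minimal sketch of the new behaviour (illustrative; whether the real amdgcn.named.barrier type carries extra type/int parameters is orthogonal to the property query and is left out here):

static bool namedBarrierCanBeGlobal(MLIRContext *ctx) {
  auto barrierTy = LLVM::LLVMTargetExtType::get(
      ctx, "amdgcn.named.barrier", /*typeParams=*/{}, /*intParams=*/{});
  // With the change above this now reports true, allowing globals of this
  // target extension type.
  return barrierTy.hasProperty(LLVM::LLVMTargetExtType::CanBeGlobal);
}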
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4db..5ce56e6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
@@ -48,6 +49,47 @@ using namespace NVVM;
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
//===----------------------------------------------------------------------===//
+// Helper/Utility methods
+//===----------------------------------------------------------------------===//
+
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInGenericSpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
+static bool isPtrInSharedClusterSpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
+}
+
+static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
+ llvm::Value *ptr,
+ NVVMMemorySpace targetAS) {
+ unsigned AS = static_cast<unsigned>(targetAS);
+ return builder.CreateAddrSpaceCast(
+ ptr, llvm::PointerType::get(builder.getContext(), AS));
+}
+
+// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
+static llvm::nvvm::CTAGroupKind
+getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
+ switch (ctaGroup) {
+ case NVVM::CTAGroupKind::CTA_1:
+ return llvm::nvvm::CTAGroupKind::CG_1;
+ case NVVM::CTAGroupKind::CTA_2:
+ return llvm::nvvm::CTAGroupKind::CG_2;
+ }
+ llvm_unreachable("unsupported cta_group value");
+}
+
+//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
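[Editor's note] The address-space helpers above are used throughout the lowerings that follow. A minimal sketch of the recurring pattern (illustrative only), where a generic-space mbarrier pointer is addrspace-cast to shared::cta before being passed to an intrinsic:

static llvm::Value *lookupAsSharedPtr(LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder,
                                      mlir::Value addr) {
  llvm::Value *ptr = mt.lookupValue(addr);
  // The intrinsics expect a shared::cta pointer; insert a cast only when the
  // MLIR-level pointer still lives in the generic address space.
  if (isPtrInGenericSpace(addr))
    ptr = castPtrToAddrSpace(builder, ptr, NVVMMemorySpace::Shared);
  return ptr;
}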
@@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
return success();
}
+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+ bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+ if (isSharedCTA && getMulticastMask())
+ return emitError("Multicast is not supported with shared::cta mode.");
+
+ return success();
+}
+
+static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
+ NVVM::MemScopeKind scope,
+ Value retVal = nullptr) {
+ if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
+ return op->emitError("mbarrier scope must be either CTA or Cluster");
+
+ bool isSharedCluster = isPtrInSharedClusterSpace(addr);
+ bool hasRetValue = static_cast<bool>(retVal);
+ if (isSharedCluster && hasRetValue)
+ return op->emitError(
+ "mbarrier in shared_cluster space cannot return any value");
+
+ return success();
+}
+
+LogicalResult MBarrierArriveOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveDropOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveExpectTxOp::verify() {
+ // The inline-ptx version of this Op does not support all features.
+ // With predicate, this Op lowers to inline-ptx. So, verify and
+ // error-out if there are unsupported features.
+ if (getPredicate()) {
+ if (getScope() != NVVM::MemScopeKind::CTA)
+ return emitError("mbarrier scope must be CTA when using predicate");
+
+ if (isPtrInSharedClusterSpace(getAddr()))
+ return emitError("mbarrier in shared_cluster space is not supported when "
+ "using predicate");
+
+ if (getRes())
+ return emitError("return-value is not supported when using predicate");
+
+ if (getRelaxed())
+ return emitError("mbarrier with relaxed semantics is not supported when "
+ "using predicate");
+ }
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierExpectTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierCompleteTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierTestWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierTryWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
LogicalResult ConvertFloatToTF32Op::verify() {
using RndMode = NVVM::FPRoundingMode;
switch (getRnd()) {
@@ -365,6 +484,108 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+LogicalResult PermuteOp::verify() {
+ using Mode = NVVM::PermuteMode;
+ bool hasHi = static_cast<bool>(getHi());
+
+ switch (getMode()) {
+ case Mode::DEFAULT:
+ case Mode::F4E:
+ case Mode::B4E:
+ if (!hasHi)
+ return emitError("mode '")
+ << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
+ break;
+ case Mode::RC8:
+ case Mode::ECL:
+ case Mode::ECR:
+ case Mode::RC16:
+ if (hasHi)
+ return emitError("mode '") << stringifyPermuteMode(getMode())
+ << "' does not accept 'hi' operand.";
+ break;
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Stochastic Rounding Conversion Ops
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
+ FPRoundingMode rnd,
+ bool hasRandomBits,
+ Operation *op) {
+ static constexpr FPRoundingMode validRndModes[] = {
+ FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
+
+ if (!llvm::is_contained(validRndModes, rnd)) {
+ return op->emitOpError(
+ "Only RN, RZ, and RS rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << dstType << ".";
+ }
+
+ if (rnd == FPRoundingMode::RS) {
+ if (!hasRandomBits) {
+ return op->emitOpError("random_bits is required for RS rounding mode.");
+ }
+ } else {
+ if (hasRandomBits) {
+ return op->emitOpError(
+ "random_bits not supported for RN and RZ rounding modes.");
+ }
+ }
+
+ return success();
+}
+
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
+ static_cast<bool>(getRandomBits()), *this);
+}
+
+LogicalResult ConvertF32x2ToBF16x2Op::verify() {
+ return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
+ static_cast<bool>(getRandomBits()), *this);
+}
+
+LogicalResult ConvertF32x4ToF8x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f32x4 to f8x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF6x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x4 to f6x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF4x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from "
+ "f32x4 to f4x4.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -866,16 +1087,517 @@ LogicalResult MmaOp::verify() {
return success();
}
-LogicalResult ShflOp::verify() {
- if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
+MMATypes MmaSpOp::accumPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
+ getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
+ assert(val.has_value() && "accumulator PTX type should always be inferrable");
+ return val.value();
+}
+
+MMATypes MmaSpOp::resultPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val =
+ MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
+ assert(val.has_value() && "result PTX type should always be inferrable");
+ return val.value();
+}
+
+mlir::NVVM::IDArgPair
+MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MmaSpOp>(op);
+
+ // Get operands
+ llvm::SmallVector<llvm::Value *> args;
+ for (mlir::Value v : thisOp.getOperands())
+ args.push_back(mt.lookupValue(v));
+
+ // Get intrinsic ID using the existing getIntrinsicID method
+ auto intId = MmaSpOp::getIntrinsicID(
+ thisOp.getShape().getM(), thisOp.getShape().getN(),
+ thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
+ thisOp.getOrderedMetadata(), thisOp.getKind(),
+ *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
+ thisOp.accumPtxType(), thisOp.resultPtxType());
+
+ return {intId, args};
+}
+
+void MmaSpOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+ };
+
+ std::array<OperandFragment, 5> frags{
+ OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
+ OperandFragment("selector", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
+
+ // Handle variadic operands A, B, C
+ for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(this->getOperand(operandIdx));
+ if (operandIdx == varOperandSpec.first) {
+ regTypes.push_back(this->getOperand(operandIdx).getType());
+ }
+ }
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+
+ // Handle sparse metadata and selector (single operands)
+ frags[3].regs.push_back(getSparseMetadata());
+ frags[4].regs.push_back(getSparsitySelector());
+
+ auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+ p << " " << frag.operandName;
+ p << "[";
+ p.printOperands(frag.regs);
+ p << "]";
+ };
+
+ for (const auto &frag : frags)
+ printMmaSpOperand(frag);
+
+ p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+ p << " : ";
+ p << "(";
+ for (int i = 0; i < 3; ++i) {
+ p << regTypes[i];
+ if (i < 2)
+ p << ", ";
+ }
+ p << ") -> " << getResult().getType();
+}
+
+void MmaSpOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
+ std::optional<MMAIntOverflow> intOverflow,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+ MLIRContext *ctx = builder.getContext();
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+ result.addOperands(sparseMetadata);
+ result.addOperands(sparsitySelector);
+
+ if (multiplicandPtxTypes) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+ if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+ }
+
+ if (intOverflow.has_value())
+ result.addAttribute("intOverflowBehavior",
+ MMAIntOverflowAttr::get(ctx, *intOverflow));
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()), 1,
+ 1})); // sparseMetadata and sparsitySelector
+}
+
+ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
+ struct OperandFragment {
+ std::optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ SmallVector<Type> regTypes;
+ };
+
+ Builder &builder = parser.getBuilder();
+ std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector,
+                                       // and the result fragment
+
+ NamedAttrList namedAttributes;
+
+ // A helper to parse the operand segments.
+ auto parseMmaSpOperand = [&](StringRef operandName,
+ OperandFragment &frag) -> LogicalResult {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser
+ .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
+ return failure();
return success();
- auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
- auto elementType = (type && type.getBody().size() == 2)
- ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
- : nullptr;
- if (!elementType || elementType.getWidth() != 1)
- return emitError("expected return type to be a two-element struct with "
- "i1 as the second element");
+ };
+
+ // Parse the operand segments.
+ if (parseMmaSpOperand("A", frags[0]).failed())
+ return failure();
+ if (parseMmaSpOperand("B", frags[1]).failed())
+ return failure();
+ if (parseMmaSpOperand("C", frags[2]).failed())
+ return failure();
+ if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
+ return failure();
+ if (parseMmaSpOperand("selector", frags[4]).failed())
+ return failure();
+
+ if (parser.parseOptionalAttrDict(namedAttributes).failed())
+ return failure();
+
+ // Parse the type specification and resolve operands.
+ SmallVector<Type, 3> operandTypes;
+ if (failed(parser.parseColon()))
+ return failure();
+ if (failed(parser.parseLParen()))
+ return failure();
+ if (failed(parser.parseTypeList(operandTypes)))
+ return failure();
+ if (failed(parser.parseRParen()))
+ return failure();
+ if (operandTypes.size() != 3)
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expected one type for each operand segment but got " +
+ Twine(operandTypes.size()) + " types");
+ for (const auto &iter : llvm::enumerate(operandTypes)) {
+ auto &frag = frags[iter.index()];
+ frag.regTypes.resize(frag.regs.size(), iter.value());
+ if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+ parser.getNameLoc(), result.operands)))
+ return failure();
+ frag.elemtype =
+ MmaOp::inferOperandMMAType(frag.regTypes[0],
+ /*isAccumulator*/ iter.index() >= 2);
+ }
+
+ Type resultType;
+ if (parser.parseArrow() || parser.parseType(resultType))
+ return failure();
+ frags[5].elemtype =
+ MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
+
+ // Resolve sparse metadata and selector (assume i32 type)
+ Type i32Type = builder.getIntegerType(32);
+ if (parser
+ .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+ if (parser
+ .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+
+ std::array<StringRef, 2> names{"multiplicandAPtxType",
+ "multiplicandBPtxType"};
+ for (unsigned idx = 0; idx < names.size(); idx++) {
+ const auto &frag = frags[idx];
+ std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+ if (!frag.elemtype.has_value() && !attr.has_value()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "attribute " + names[idx] +
+ " is not provided explicitly and cannot be inferred");
+ }
+ if (!attr.has_value())
+ result.addAttribute(
+ names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+ }
+
+ result.addTypes(resultType);
+ if (!namedAttributes.empty())
+ result.addAttributes(namedAttributes);
+ result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // sparseMetadata
+ 1 // sparsitySelector
+ }));
+ return success();
+}
+
+LogicalResult MmaSpOp::verify() {
+ MLIRContext *context = getContext();
+ auto f16Ty = Float16Type::get(context);
+ auto i32Ty = IntegerType::get(context, 32);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f32Ty = Float32Type::get(context);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+ auto s32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+ auto f32x8StructTy =
+ LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+ auto f16x2x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+ auto f32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+ auto s32x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+ std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
+ getShapeAttr().getK()};
+
+ // These variables define the set of allowed data types for matrices A, B, C,
+ // and result.
+ using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+ using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+ AllowedShapes allowedShapes;
+ AllowedTypes expectedA;
+ AllowedTypes expectedB;
+ AllowedTypes expectedC;
+ SmallVector<Type> expectedResult;
+
+ // When M = 16, we just need to calculate the number of 8xk tiles, where
+ // k is a factor that depends on the data type.
+ if (mmaShape[0] == 16) {
+ int64_t kFactor;
+ Type multiplicandFragType;
+ switch (*getMultiplicandAPtxType()) {
+ case MMATypes::tf32:
+ kFactor = 4;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
+ allowedShapes.push_back({16, 8, 8});
+ allowedShapes.push_back({16, 8, 16});
+ break;
+ case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::f16:
+ kFactor = 8;
+ multiplicandFragType = f16x2Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k16 and m16n8k32 for f16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::s4:
+ case MMATypes::u4:
+ kFactor = 32;
+ // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
+ allowedShapes.push_back({16, 8, 64});
+ allowedShapes.push_back({16, 8, 128});
+ break;
+ case MMATypes::s8:
+ case MMATypes::u8:
+ kFactor = 16;
+ // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
+ allowedShapes.push_back({16, 8, 32});
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ case MMATypes::e4m3:
+ case MMATypes::e5m2:
+ case MMATypes::e3m2:
+ case MMATypes::e2m3:
+ case MMATypes::e2m1:
+ kFactor = 32;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k64 for FP8 types
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ default:
+ return emitError("invalid shape or multiplicand type: " +
+ stringifyEnum(getMultiplicandAPtxType().value()));
+ }
+
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedResult.push_back(s32x4StructTy);
+ expectedC.emplace_back(4, i32Ty);
+ multiplicandFragType = i32Ty;
+ } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
+ *getMultiplicandAPtxType() <= MMATypes::e2m1) {
+ // FP8 types
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ } else {
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ }
+
+ // For sparse MMA, A operand is compressed (2:4 sparsity means half the
+ // elements)
+ int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
+ int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+ expectedA.emplace_back(unitA, multiplicandFragType);
+ expectedB.emplace_back(unitB, multiplicandFragType);
+
+ if (resultPtxType() != accumPtxType())
+ return emitOpError("ctype does not match dtype");
+ }
+
+ // In the M=8 case, there is only 1 possible case per data type.
+ if (mmaShape[0] == 8) {
+ if (*getMultiplicandAPtxType() == MMATypes::f16) {
+ expectedA.emplace_back(2, f16x2Ty);
+ expectedB.emplace_back(2, f16x2Ty);
+ expectedResult.push_back(f16x2x4StructTy);
+ expectedResult.push_back(f32x8StructTy);
+ expectedC.emplace_back(4, f16x2Ty);
+ expectedC.emplace_back(8, f32Ty);
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (*getMultiplicandAPtxType() == MMATypes::f64) {
+ Type f64Ty = Float64Type::get(context);
+ expectedA.emplace_back(1, f64Ty);
+ expectedB.emplace_back(1, f64Ty);
+ expectedC.emplace_back(2, f64Ty);
+ expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(2, f64Ty)));
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedA.push_back({i32Ty});
+ expectedB.push_back({i32Ty});
+ expectedC.push_back({i32Ty, i32Ty});
+ expectedResult.push_back(s32x2StructTy);
+ if (isInt4PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 32});
+ if (isInt8PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 16});
+ }
+ }
+
+ std::string errorMessage;
+ llvm::raw_string_ostream errorStream(errorMessage);
+
+ // Check that we matched an existing shape/dtype combination.
+ if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+ !llvm::is_contained(allowedShapes, mmaShape)) {
+ errorStream << "unimplemented variant for MMA shape <";
+ llvm::interleaveComma(mmaShape, errorStream);
+ errorStream << ">";
+ return emitOpError(errorMessage);
+ }
+
+ // Verify the operand types for segments of A, B, and C operands.
+ std::array<StringRef, 3> operandNames{"A", "B", "C"};
+ for (const auto &iter : llvm::enumerate(
+ SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+ auto spec = this->getODSOperandIndexAndLength(iter.index());
+ SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+ operand_type_begin() + spec.first +
+ spec.second);
+ bool match = llvm::is_contained(iter.value(), operandTySeg);
+
+ if (!match) {
+ errorStream << "Could not match types for the "
+ << operandNames[iter.index()]
+ << " operands; expected one of ";
+ for (const auto &x : iter.value()) {
+ errorStream << x.size() << "x" << x[0] << " ";
+ }
+ errorStream << "but got ";
+ llvm::interleaveComma(operandTySeg, errorStream);
+ return emitOpError(errorMessage);
+ }
+ }
+
+ // Check the result type
+ if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+ return expectedResultType == getResult().getType();
+ })) {
+ errorStream
+ << "Could not match allowed types for the result; expected one of ";
+ llvm::interleaveComma(expectedResult, errorStream);
+ errorStream << " but got " << getResult().getType();
+ return emitOpError(errorMessage);
+ }
+
+ // Ensure int4/int8 MMA variants specify the accum overflow behavior
+ // attribute.
+ if (isInt4PtxType(*getMultiplicandAPtxType()) ||
+ isInt8PtxType(*getMultiplicandAPtxType())) {
+ if (!getIntOverflowBehavior())
+ return emitOpError("op requires " +
+ getIntOverflowBehaviorAttrName().strref() +
+ " attribute");
+ }
+
+ // Validate sparse metadata type (should be i32)
+ if (!getSparseMetadata().getType().isInteger(32)) {
+ return emitOpError() << "sparse metadata must be i32 type";
+ }
+
+ // Validate sparsity selector type (should be i32)
+ if (!getSparsitySelector().getType().isInteger(32)) {
+ return emitOpError() << "sparsity selector must be i32 type";
+ }
+
+ return success();
+}
+
+LogicalResult ShflOp::verify() {
+ auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+
+ auto verifyTypeError = [&](Twine desc, Type expectedType,
+ Type actualType) -> LogicalResult {
+ return emitOpError("expected " + desc + " to be of type ")
+ << expectedType << " but got " << actualType << " instead";
+ };
+
+ if (returnStructType) {
+ if (!getReturnValueAndIsValid())
+ return emitOpError("\"return_value_and_is_valid\" attribute must be "
+ "specified when the return type is a struct type");
+
+ if (returnStructType.getBody().size() != 2)
+ return emitOpError("expected return type to be a two-element struct");
+
+ llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
+ auto resultType = returnStruct[0];
+ if (resultType != getVal().getType())
+ return verifyTypeError("first element in the returned struct",
+ getVal().getType(), resultType);
+
+ auto predicateType = returnStruct[1];
+ if (!predicateType.isInteger(1))
+ return verifyTypeError("second element in the returned struct",
+ mlir::IntegerType::get(getContext(), 1),
+ predicateType);
+ } else {
+ if (getReturnValueAndIsValid())
+ return emitOpError("expected return type to be a two-element struct");
+
+ if (getType() != getVal().getType())
+ return verifyTypeError("return type", getVal().getType(), getType());
+ }
return success();
}
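[Editor's note] A minimal sketch of driving the new MmaSpOp::build() overload from C++ (illustrative; the m16n8k32 f16 shape, the prepared fragment values, and the result struct type are assumptions of the example, not requirements of the op):

static NVVM::MmaSpOp buildSparseMmaF16(OpBuilder &b, Location loc,
                                       Type resultStructTy, ValueRange fragA,
                                       ValueRange fragB, ValueRange fragC,
                                       Value metadata, Value selector) {
  // Multiplicand PTX types are inferred from the fragment element types, so
  // no explicit multiplicandPtxTypes array is passed; the integer-overflow
  // attribute is only required for the int4/int8 variants.
  return b.create<NVVM::MmaSpOp>(
      loc, resultStructTy, fragA, fragB, fragC, metadata, selector,
      /*shape=*/ArrayRef<int64_t>{16, 8, 32},
      /*intOverflow=*/std::nullopt, /*multiplicandPtxTypes=*/std::nullopt);
}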
@@ -896,6 +1618,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
} else if (type == NVVM::MMATypes::f32) {
elementType = builder.getF32Type();
numberElements = 8;
+ } else if (type == NVVM::MMATypes::f64) {
+ elementType = builder.getF64Type();
+ if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
+ numberElements = 1;
+ else
+ numberElements = 2;
} else if (type == NVVM::MMATypes::tf32) {
elementType = builder.getI32Type();
numberElements = 4;
@@ -954,6 +1682,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+ // Special case for f64 fragments
+ Type f64Ty = Float64Type::get(getContext());
+ if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+ if (getType() != f64Ty)
+ return emitOpError("expected destination type to be f64");
+ return success();
+ }
+ // Everything else is a struct
Type dstType = LLVM::LLVMStructType::getLiteral(
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
if (getType() != dstType)
@@ -1362,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
return true; // Has manual mapping
}
+LogicalResult NVVM::FenceSyncRestrictOp::verify() {
+ if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+ getOrder() != NVVM::MemOrderKind::RELEASE)
+ return emitOpError("only acquire and release semantics are supported");
+ return success();
+}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1384,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() {
if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
return emitOpError("uni-directional proxies only support tensormap "
"for to_proxy attribute");
-
return success();
}
@@ -1396,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() {
if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
return emitOpError("uni-directional proxies only support tensormap "
"for to_proxy attribute");
+ return success();
+}
+
+LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
+ if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+ getOrder() != NVVM::MemOrderKind::RELEASE)
+ return emitOpError("only acquire and release semantics are supported");
+ if (getFromProxy() != NVVM::ProxyKind::GENERIC)
+ return emitOpError("only generic is support for from_proxy attribute");
+
+ if (getToProxy() != NVVM::ProxyKind::async)
+ return emitOpError("only async is supported for to_proxy attribute");
return success();
}
@@ -1412,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() {
if (getNumberOfThreads() && !getBarrierId())
return emitOpError(
"barrier id is missing, it should be set between 0 to 15");
+
+ if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
+ return emitOpError("reduction are only available when id is 0");
+
+ if ((getReductionOp() && !getReductionPredicate()) ||
+ (!getReductionOp() && getReductionPredicate()))
+ return emitOpError("reduction predicate and reduction operation must be "
+ "specified together");
+
return success();
}
@@ -1563,6 +2326,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
return success();
}
+LogicalResult NVVM::ReduxOp::verify() {
+ mlir::Type reduxType = getType();
+
+ if (!reduxType.isF32()) {
+ if (getAbs())
+ return emitOpError("abs attribute is supported only for f32 type");
+ if (getNan())
+ return emitOpError("nan attribute is supported only for f32 type");
+ }
+
+ NVVM::ReduxKind kind = getKind();
+ switch (kind) {
+ case NVVM::ReduxKind::ADD:
+ case NVVM::ReduxKind::AND:
+ case NVVM::ReduxKind::OR:
+ case NVVM::ReduxKind::XOR:
+ case NVVM::ReduxKind::MAX:
+ case NVVM::ReduxKind::MIN:
+ case NVVM::ReduxKind::UMAX:
+ case NVVM::ReduxKind::UMIN:
+ if (!reduxType.isInteger(32))
+ return emitOpError("'")
+ << stringifyEnum(kind) << "' redux kind unsupported with "
+ << reduxType << " type. Only supported type is 'i32'.";
+ break;
+ case NVVM::ReduxKind::FMIN:
+ case NVVM::ReduxKind::FMAX:
+ if (!reduxType.isF32())
+ return emitOpError("'")
+ << stringifyEnum(kind) << "' redux kind unsupported with "
+ << reduxType << " type. Only supported type is 'f32'.";
+ break;
+ }
+
+ return success();
+}
+
/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
@@ -1608,9 +2408,439 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
}
//===----------------------------------------------------------------------===//
+// getPtx methods
+//===----------------------------------------------------------------------===//
+
+std::string NVVM::MBarrierInitOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
+}
+
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared
+ ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
+ : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
+}
+
+std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ llvm::StringRef space = isShared ? ".shared" : "";
+
+ return llvm::formatv("{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}",
+ space);
+}
+
+//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
+mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::BarrierOp>(op);
+ llvm::Value *barrierId = thisOp.getBarrierId()
+ ? mt.lookupValue(thisOp.getBarrierId())
+ : builder.getInt32(0);
+ llvm::Intrinsic::ID id;
+ llvm::SmallVector<llvm::Value *> args;
+ if (thisOp.getNumberOfThreads()) {
+ id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
+ args.push_back(barrierId);
+ args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
+ } else if (thisOp.getReductionOp()) {
+ switch (*thisOp.getReductionOp()) {
+ case NVVM::BarrierReduction::AND:
+ id = llvm::Intrinsic::nvvm_barrier0_and;
+ break;
+ case NVVM::BarrierReduction::OR:
+ id = llvm::Intrinsic::nvvm_barrier0_or;
+ break;
+ case NVVM::BarrierReduction::POPC:
+ id = llvm::Intrinsic::nvvm_barrier0_popc;
+ break;
+ }
+ args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
+ } else {
+ id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
+ args.push_back(barrierId);
+ }
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierInitOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
+ : llvm::Intrinsic::nvvm_mbarrier_init;
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared
+ ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
+ : llvm::Intrinsic::nvvm_mbarrier_inval;
+
+ return {id, {mt.lookupValue(thisOp.getAddr())}};
+}
+
+mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getTxcount()));
+
+ return {IDs[index], std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getTxcount()));
+
+ return {IDs[index], std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // When count is not explicitly specified, the default is 1.
+ llvm::LLVMContext &ctx = mt.getLLVMContext();
+ bool hasCount = static_cast<bool>(thisOp.getCount());
+ llvm::Value *count =
+ hasCount ? mt.lookupValue(thisOp.getCount())
+ : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
+
+ return {id, {mbar, count}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // When count is not explicitly specified, the default is 1.
+ llvm::LLVMContext &ctx = mt.getLLVMContext();
+ bool hasCount = static_cast<bool>(thisOp.getCount());
+ llvm::Value *count =
+ hasCount ? mt.lookupValue(thisOp.getCount())
+ : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
+
+ return {id, {mbar, count}};
+}
+
+bool MBarrierArriveExpectTxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ // Add all the operands but not the attrs to the asmValues list.
+ // The attrs here are used to generate the right variants for
+ // intrinsics-lowering. So, we ignore them while generating inline-PTX.
+ for (auto val : getOperands())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+ return false;
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, txcount}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, txcount}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id =
+ isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id =
+ isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: isPhaseParity
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, input}};
+}
+
+mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ bool hasTicks = static_cast<bool>(thisOp.getTicks());
+ // bit-0: isPhaseParity
+ // bit-1: Scope
+ // bit-2: hasTicks
+ size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
+ (isPhaseParity ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the mbarrier pointer
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mbar);
+ args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
+ if (hasTicks)
+ args.push_back(mt.lookupValue(thisOp.getTicks()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+
+ llvm::Intrinsic::ID id;
+ if (thisOp.getNoinc()) {
+ id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
+ : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
+ } else {
+ id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
+ : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
+ }
+
+ return {id, {mt.lookupValue(thisOp.getAddr())}};
+}
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
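[Editor's note] The mbarrier lowerings above all select among four {scope, space} intrinsic variants through the same small bit-packed index; written out standalone (illustrative only):

static size_t mbarrierVariantIndex(bool isClusterScope, bool isClusterSpace) {
  // bit-0 encodes the address space, bit-1 the scope, so the lookup tables
  // are ordered: cta/cta, cta/cluster, cluster/cta, cluster/cluster.
  return (static_cast<size_t>(isClusterScope) << 1) |
         static_cast<size_t>(isClusterSpace);
}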
@@ -1680,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
args.push_back(mt.lookupValue(thisOp.getSize()));
- // Multicast mask, if available.
+ // Multicast mask for shared::cluster only, if available.
mlir::Value multicastMask = thisOp.getMulticastMask();
const bool hasMulticastMask = static_cast<bool>(multicastMask);
- llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
- args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+ const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+ if (!isSharedCTA) {
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+ : i16Unused);
+ }
// Cache hint, if available.
mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1693,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
// Flag arguments for multicast and cachehint.
- args.push_back(builder.getInt1(hasMulticastMask));
+ if (!isSharedCTA)
+ args.push_back(builder.getInt1(hasMulticastMask));
args.push_back(builder.getInt1(hasCacheHint));
llvm::Intrinsic::ID id =
- llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+ isSharedCTA
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+ : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
return {id, std::move(args)};
}
@@ -2412,6 +3649,155 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
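+ // For example, relu together with satfinite yields idx 3, which under RN
+ // rounding selects nvvm_ff2f16x2_rn_relu_satfinite.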
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
+}
+
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
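+ // For example, relu without satfinite yields idx 1, which under RZ rounding
+ // selects nvvm_ff2bf16x2_rz_relu.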
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+ }
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
@@ -2451,6 +3837,9 @@ LogicalResult Tcgen05LdOp::verify() {
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
result = emitError("shape 16x32bx2 requires offset argument");
+ if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
+ result = emitError("offset argument is only supported for shape 16x32bx2");
+
auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
@@ -2694,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}
+mlir::NVVM::IDArgPair
+PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::PermuteOp>(op);
+ NVVM::PermuteMode mode = thisOp.getMode();
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
+ llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
+ llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
+ llvm::Intrinsic::nvvm_prmt_rc16};
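+ // The table order must match the PermuteMode enumerators; e.g. mode b4e
+ // (index 2) lowers to nvvm_prmt_b4e and takes {lo, hi, selector}, while
+ // rc16 (index 6) takes only {lo, selector}.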
+
+ unsigned modeIndex = static_cast<unsigned>(mode);
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getLo()));
+
+ // Only the first three modes (default, f4e, b4e) take the hi operand.
+ if (modeIndex < 3)
+ args.push_back(mt.lookupValue(thisOp.getHi()));
+
+ args.push_back(mt.lookupValue(thisOp.getSelector()));
+
+ return {IDs[modeIndex], args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ const bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
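+ // Illustrative lookup: [0][0][1][1][1] (no disable-output-lane, no scale-d,
+ // tensor-memory A, cta_group::2, ashift) resolves to
+ // nvvm_tcgen05_mma_tensor_ashift.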
+ static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
+ { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ const unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
+ NVVM::CTAGroupKind ctaGroup, bool hasAShift,
+ NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
+
+ if (disableOutputLane) {
+ mlir::VectorType disableOutputLaneType =
+ cast<mlir::VectorType>(disableOutputLane.getType());
+ if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
+ disableOutputLaneType.getNumElements() != 4) ||
+ (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
+ disableOutputLaneType.getNumElements() != 8))
+ return emitError(loc) << "Disable Output Lane of length "
+ << disableOutputLaneType.getNumElements()
+ << " is incompatible with CtaGroupAttr";
+ }
+
+ if (hasAShift && !isATensor)
+ return emitError(
+ loc, "A-shift can be applied only when matrix A is in tensor memory");
+
+ if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
+ collectorOp == Tcgen05MMACollectorOp::USE))
+ return emitError(
+ loc, "Cannot use collector buffer operation fill or use with ashift");
+
+ return success();
+}
+
+LogicalResult Tcgen05MMAOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
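+ // Illustrative lookup: [1][0][0][0][0] (disable-output-lane mask, no scale-d,
+ // shared A, cta_group::1, no ashift) resolves to
+ // nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1.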
+ static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
+ { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
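+ // Supported combinations: mxf8f6f4 and mxf4 accept {default, block32};
+ // mxf4nvf4 accepts {block32, block16}. Other combinations are rejected by
+ // verifyTcgen05MMABlockScaleOp.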
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
+ : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
+
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
+ NVVM::Tcgen05MMABlockScaleKind kind,
+ NVVM::Tcgen05MMABlockScale blockScale,
+ Location loc) {
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
+ kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
+ return emitError(loc, "mxf4nvf4 requires block scale attribute");
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
+ kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
+ return emitError(loc,
+ llvm::formatv("{} kind does not support block16 attribute",
+ stringifyEnum(kind)));
+
+ return success();
+}
+
+LogicalResult Tcgen05MMABlockScaleOp::verify() {
+ return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
+ getBlockScale(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
+
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
+ return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
+ getBlockScale(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ mlir::Value ZeroColMask = thisOp.getZeroColMask();
+ llvm::Intrinsic::ID ID = notIntrinsic;
+ if (ZeroColMask) {
+ args.push_back(mt.lookupValue(ZeroColMask));
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
+ } else {
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
+ }
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+ mlir::Value ZeroColMask = thisOp.getZeroColMask();
+ llvm::Intrinsic::ID ID = notIntrinsic;
+ if (ZeroColMask) {
+ args.push_back(mt.lookupValue(ZeroColMask));
+ ID = isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
+ } else {
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
+ }
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -2897,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
"Minimum NVVM target SM version is sm_20");
}
- gpuModuleOp->walk([&](Operation *op) {
- if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
- const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
- if (!requirement.isCompatibleWith(targetSMVersion)) {
- op->emitOpError() << "is not supported on " << getChip();
- return WalkResult::interrupt();
- }
- }
- return WalkResult::advance();
- });
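+ // Propagate a walk interruption (an op requiring a newer SM than the
+ // target chip provides) as a verification failure.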
+ if (gpuModuleOp
+ ->walk([&](Operation *op) {
+ if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
+ const NVVMCheckSMVersion requirement =
+ reqOp.getRequiredMinSMVersion();
+ if (!requirement.isCompatibleWith(targetSMVersion)) {
+ op->emitOpError() << "is not supported on " << getChip();
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ })
+ .wasInterrupted())
+ return failure();
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
index 67573c4..12dd225 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
@@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr,
return FusedLoc::get(context, {loc}, lexicalBlockFileAttr);
}
+/// Adds a DILexicalBlockFileAttr for operations with a CallSiteLoc and for
+/// operations whose source file differs from that of their containing function.
static void setLexicalBlockFileAttr(Operation *op) {
- if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) {
+ Location opLoc = op->getLoc();
+
+ if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) {
auto callerLoc = callSiteLoc.getCaller();
auto calleeLoc = callSiteLoc.getCallee();
LLVM::DIScopeAttr scopeAttr;
@@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) {
op->setLoc(
CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc));
}
+
+ return;
+ }
+
+ auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ if (!funcOp)
+ return;
+
+ FileLineColLoc opFileLoc = extractFileLoc(opLoc);
+ if (!opFileLoc)
+ return;
+
+ FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc());
+ if (!funcFileLoc)
+ return;
+
+ StringRef opFile = opFileLoc.getFilename().getValue();
+ StringRef funcFile = funcFileLoc.getFilename().getValue();
+
+ // Handle cross-file operations: add a DILexicalBlockFileAttr when the
+ // operation's source file differs from that of its containing function.
+ if (opFile != funcFile) {
+ auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc());
+ if (!funcOpLoc)
+ return;
+ auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata());
+ if (!scopeAttr)
+ return;
+
+ auto *context = op->getContext();
+ LLVM::DIFileAttr opFileAttr =
+ LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile),
+ llvm::sys::path::parent_path(opFile));
+
+ LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr =
+ LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0);
+
+ Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr);
+ op->setLoc(newLoc);
}
}