diff options
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4db..a5ffb9e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type, } else if (type == NVVM::MMATypes::f32) { elementType = builder.getF32Type(); numberElements = 8; + } else if (type == NVVM::MMATypes::f64) { + elementType = builder.getF64Type(); + if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) + numberElements = 1; + else + numberElements = 2; } else if (type == NVVM::MMATypes::tf32) { elementType = builder.getI32Type(); numberElements = 4; @@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() { return emitOpError() << "invalid attribute combination"; std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( getEltype(), getFrag(), getM(), getN(), getK(), getContext()); + // Special case for f64 fragments + Type f64Ty = Float64Type::get(getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + if (getType() != f64Ty) + return emitOpError("expected destination type to be f64"); + return success(); + } + // Everything else is a struct Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); if (getType() != dstType) @@ -1608,9 +1622,52 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, } //===----------------------------------------------------------------------===// +// getPtx methods +//===----------------------------------------------------------------------===// + +std::string NVVM::MBarrierInitOp::getPtx() { + unsigned addressSpace = + llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); + return (addressSpace == NVVMMemorySpace::Shared) + ? std::string("mbarrier.init.shared.b64 [%0], %1;") + : std::string("mbarrier.init.b64 [%0], %1;"); +} + +//===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierInitOp>(op); + unsigned addressSpace = + llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) + .getAddressSpace(); + llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + ? llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierInvalOp>(op); + unsigned addressSpace = + llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) + .getAddressSpace(); + llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + ? llvm::Intrinsic::nvvm_mbarrier_inval_shared + : llvm::Intrinsic::nvvm_mbarrier_inval; + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix |
