aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/LLVMIR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR')
-rw-r--r-- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2
-rw-r--r-- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 2
-rw-r--r-- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 2
-rw-r--r-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 57
4 files changed, 60 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3eae67f..2731069 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
return structType.getBody()[memberIndex];
return nullptr;
})
- .Default(Type(nullptr));
+ .Default(nullptr);
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index cee943d..7d9058c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
.Case<IntegerType, FloatType>([](auto type) {
return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
- .Default([](Type) { return false; });
+ .Default(false);
if (!canConvertType)
return false;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ac35eea..ce93d18 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
// clang-format on
.Case<PtrLikeTypeInterface>(
[](Type type) { return isCompatiblePtrType(type); })
- .Default([](Type) { return false; });
+ .Default(false);
if (!result)
compatibleTypes.erase(type);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4db..a5ffb9e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
} else if (type == NVVM::MMATypes::f32) {
elementType = builder.getF32Type();
numberElements = 8;
+ } else if (type == NVVM::MMATypes::f64) {
+ elementType = builder.getF64Type();
+ if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
+ numberElements = 1;
+ else
+ numberElements = 2;
} else if (type == NVVM::MMATypes::tf32) {
elementType = builder.getI32Type();
numberElements = 4;
@@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+ // Special case for f64 fragments
+ Type f64Ty = Float64Type::get(getContext());
+ if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+ if (getType() != f64Ty)
+ return emitOpError("expected destination type to be f64");
+ return success();
+ }
+ // Everything else is a struct
Type dstType = LLVM::LLVMStructType::getLiteral(
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
if (getType() != dstType)
@@ -1608,9 +1622,52 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
}
//===----------------------------------------------------------------------===//
+// getPtx methods
+//===----------------------------------------------------------------------===//
+
+/// Returns the inline-PTX template for this op. The `.shared` form is chosen
+/// when the address operand's pointer type is in NVVMMemorySpace::Shared;
+/// every other address space gets the unqualified form.
+std::string NVVM::MBarrierInitOp::getPtx() {
+  auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getAddr().getType());
+  const unsigned ptrAddrSpace = ptrType.getAddressSpace();
+  if (ptrAddrSpace == NVVMMemorySpace::Shared)
+    return std::string("mbarrier.init.shared.b64 [%0], %1;");
+  return std::string("mbarrier.init.b64 [%0], %1;");
+}
+
+//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
+/// Builds the (intrinsic ID, argument list) pair for an MBarrierInitOp.
+/// The `_shared` intrinsic is used when the address operand's pointer type is
+/// in NVVMMemorySpace::Shared; the arguments are the translated address and
+/// count operands, in that order.
+mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto initOp = cast<NVVM::MBarrierInitOp>(op);
+  auto ptrType = llvm::cast<LLVM::LLVMPointerType>(initOp.getAddr().getType());
+  const unsigned ptrAddrSpace = ptrType.getAddressSpace();
+
+  // Start from the generic intrinsic and switch to the shared variant only
+  // when the pointer is in the shared address space.
+  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_mbarrier_init;
+  if (ptrAddrSpace == NVVMMemorySpace::Shared)
+    id = llvm::Intrinsic::nvvm_mbarrier_init_shared;
+
+  // Operands are looked up in the same order they are passed: addr, count.
+  llvm::SmallVector<llvm::Value *> args;
+  llvm::Value *addrArg = mt.lookupValue(initOp.getAddr());
+  args.push_back(addrArg);
+  llvm::Value *countArg = mt.lookupValue(initOp.getCount());
+  args.push_back(countArg);
+
+  return {id, std::move(args)};
+}
+
+/// Builds the (intrinsic ID, argument list) pair for an MBarrierInvalOp.
+/// The `_shared` intrinsic is used when the address operand's pointer type is
+/// in NVVMMemorySpace::Shared; the only argument is the translated address.
+mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto invalOp = cast<NVVM::MBarrierInvalOp>(op);
+  auto ptrType = llvm::cast<LLVM::LLVMPointerType>(invalOp.getAddr().getType());
+  const unsigned ptrAddrSpace = ptrType.getAddressSpace();
+
+  // Start from the generic intrinsic and switch to the shared variant only
+  // when the pointer is in the shared address space.
+  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_mbarrier_inval;
+  if (ptrAddrSpace == NVVMMemorySpace::Shared)
+    id = llvm::Intrinsic::nvvm_mbarrier_inval_shared;
+
+  return {id, {mt.lookupValue(invalOp.getAddr())}};
+}
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix