aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-06-23 09:50:01 +0200
committerGitHub <noreply@github.com>2024-06-23 09:50:01 +0200
commit346c4a88afedcef3da40f68c83f0a5b3e0ac61ea (patch)
treee57dbaaa79cdd8c44c3c0c4ac2b222bca02583e9
parente7622ab4721141d9e6af6041fa7f9bbc1029e9aa (diff)
downloadllvm-346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.zip
llvm-346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.tar.gz
llvm-346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.tar.bz2
[mlir][NVVM] Disallow results on kernel functions (#96399)
Functions that have the `nvvm.kernel` attribute should have 0 results.
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp13
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir.mlir7
2 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 94197e4..3d6a911 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -214,7 +214,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : " << "(";
+ p << " : "
+ << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -955,7 +956,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+ ss << " $" << (regCnt) << ","
+ << " $" << (regCnt + 1) << ","
+ << " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -1053,10 +1056,14 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
StringAttr attrName = attr.getName();
// Kernel function attribute should be attached to functions.
if (attrName == NVVMDialect::getKernelFuncAttrName()) {
- if (!isa<LLVM::LLVMFuncOp>(op)) {
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!funcOp) {
return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
<< "' attribute attached to unexpected op";
}
+ if (!funcOp.getResultTypes().empty()) {
+ return op->emitError() << "kernel function cannot have results";
+ }
}
// If maxntid and reqntid exist, it must be an array with max 3 dim
if (attrName == NVVMDialect::getMaxntidAttrName() ||
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a8ae4d9..26ba80c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -574,3 +574,10 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
+
+// -----
+
+// expected-error @below{{kernel function cannot have results}}
+llvm.func @kernel_with_result(%i: i32) -> i32 attributes {nvvm.kernel} {
+ llvm.return %i : i32
+}