diff options
author | Matthias Springer <me@m-sp.org> | 2024-06-23 09:50:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-23 09:50:01 +0200 |
commit | 346c4a88afedcef3da40f68c83f0a5b3e0ac61ea (patch) | |
tree | e57dbaaa79cdd8c44c3c0c4ac2b222bca02583e9 | |
parent | e7622ab4721141d9e6af6041fa7f9bbc1029e9aa (diff) | |
download | llvm-346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.zip llvm-346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.tar.gz llvm-346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.tar.bz2 |
[mlir][NVVM] Disallow results on kernel functions (#96399)
Functions that have the `nvvm.kernel` attribute should have 0 results.
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 13 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/nvvmir.mlir | 7 |
2 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 94197e4..3d6a911 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -214,7 +214,8 @@ void MmaOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); // Print the types of the operands and result. - p << " : " << "("; + p << " : " + << "("; llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(), frags[1].regs[0].getType(), frags[2].regs[0].getType()}, @@ -955,7 +956,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { ss << "},"; // Need to map read/write registers correctly. regCnt = (regCnt * 2); - ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; + ss << " $" << (regCnt) << "," + << " $" << (regCnt + 1) << "," + << " p"; if (getTypeD() != WGMMATypes::s32) { ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); } @@ -1053,10 +1056,14 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, StringAttr attrName = attr.getName(); // Kernel function attribute should be attached to functions. if (attrName == NVVMDialect::getKernelFuncAttrName()) { - if (!isa<LLVM::LLVMFuncOp>(op)) { + auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op); + if (!funcOp) { return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() << "' attribute attached to unexpected op"; } + if (!funcOp.getResultTypes().empty()) { + return op->emitError() << "kernel function cannot have results"; + } } // If maxntid and reqntid exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMaxntidAttrName() || diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index a8ae4d9..26ba80c 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -574,3 +574,10 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} { llvm.return } + +// ----- + +// expected-error @below{{kernel function cannot have results}} +llvm.func @kernel_with_result(%i: i32) -> i32 attributes {nvvm.kernel} { + llvm.return %i : i32 +} |