aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp6
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h21
2 files changed, 21 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 2285d26..eb662a1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
/*isVarArg=*/true);
LLVM::LLVMFuncOp printfDecl =
- getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
+ getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
+ printfDecl.setCConv(callingConvention);
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateStringConstant(
@@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
- LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+ auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+ call.setCConv(callingConvention);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 66d3bb4..ec74787 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
namespace mlir {
@@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
/// This pass will add a declaration of printf() to the GPUModule if needed
/// and separate out the format strings into global constants. For some
/// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
-/// will lower printf calls to appropriate device-side code
+/// will lower printf calls to appropriate device-side code.
+/// However, not all backends use the same calling convention and function
+/// naming.
+/// For example, the LLVM SPIR-V backend requires the calling convention
+/// LLVM::cconv::CConv::SPIR_FUNC and requires the function name to be
+/// mangled as "_Z6printfPU3AS2Kcz".
+/// By default, callingConvention is LLVM::cconv::CConv::C and
+/// funcName is "printf", but both can be customized as needed.
struct GPUPrintfOpToLLVMCallLowering
: public ConvertOpToLLVMPattern<gpu::PrintfOp> {
- GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
- int addressSpace = 0)
+ GPUPrintfOpToLLVMCallLowering(
+ const LLVMTypeConverter &converter, int addressSpace = 0,
+ LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
+ StringRef funcName = "printf")
: ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
- addressSpace(addressSpace) {}
+ addressSpace(addressSpace), callingConvention(callingConvention),
+ funcName(funcName) {}
LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering
private:
int addressSpace;
+ LLVM::cconv::CConv callingConvention;
+ StringRef funcName;
};
/// Lowering of gpu.printf to a vprintf standard library.