aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
diff options
context:
space:
mode:
authorThomas Raoux <thomasraoux@google.com>2023-01-05 21:20:45 +0000
committerThomas Raoux <thomasraoux@google.com>2023-01-06 17:29:30 +0000
commit7efdc117b1518bb11a74cf315b17d4cbb751de6c (patch)
tree41c4bf3742c16a3b7c2d9330c7e73f3ed1ccb255 /mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
parent9b5f62685ab447ba9d3ea8ac2616e0c76a44d21b (diff)
downloadllvm-7efdc117b1518bb11a74cf315b17d4cbb751de6c.zip
llvm-7efdc117b1518bb11a74cf315b17d4cbb751de6c.tar.gz
llvm-7efdc117b1518bb11a74cf315b17d4cbb751de6c.tar.bz2
[mlir][nvvm] Add lowering of gpu.printf to nvvm
When converting to nvvm lowering gpu.printf to vprintf allows us to support printing when running on cuda. Differential Revision: https://reviews.llvm.org/D141049
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp104
1 files changed, 89 insertions, 15 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 48effe2..668b443 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -172,7 +172,17 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}
-static const char formatStringPrefix[] = "printfFormat_";
+static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
+ const char formatStringPrefix[] = "printfFormat_";
+ // Get a unique global name.
+ unsigned stringNumber = 0;
+ SmallString<16> stringConstName;
+ do {
+ stringConstName.clear();
+ (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
+ } while (moduleOp.lookupSymbol(stringConstName));
+ return stringConstName;
+}
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
@@ -225,13 +235,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
- // Create a global constant for the format string
- unsigned stringNumber = 0;
- SmallString<16> stringConstName;
- do {
- stringConstName.clear();
- (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
+ // Get a unique global name for the format.
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C
@@ -320,13 +325,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFuncOp printfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
- // Create a global constant for the format string
- unsigned stringNumber = 0;
- SmallString<16> stringConstName;
- do {
- stringConstName.clear();
- (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
+ // Get a unique global name for the format.
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C
@@ -359,6 +359,80 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
return success();
}
+LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
+ gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = gpuPrintfOp->getLoc();
+
+ mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
+ mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
+
+ // Note: this is the GPUModule op, not the ModuleOp that surrounds it
+ // This ensures that global constants and declarations are placed within
+ // the device code, not the host code
+ auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+
+ auto vprintfType =
+ LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
+ LLVM::LLVMFuncOp vprintfDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
+
+ // Get a unique global name for the format.
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
+
+ llvm::SmallString<20> formatString(adaptor.getFormat());
+ formatString.push_back('\0'); // Null terminate for C
+ auto globalType =
+ LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
+ LLVM::GlobalOp global;
+ {
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(moduleOp.getBody());
+ global = rewriter.create<LLVM::GlobalOp>(
+ loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
+ rewriter.getStringAttr(formatString), /*allignment=*/0);
+ }
+
+ // Get a pointer to the format string's first element
+ Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
+ Value stringStart = rewriter.create<LLVM::GEPOp>(
+ loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ SmallVector<Type> types;
+ SmallVector<Value> args;
+ // Promote and pack the arguments into a stack allocation.
+ for (Value arg : adaptor.getArgs()) {
+ Type type = arg.getType();
+ Value promotedArg = arg;
+ assert(type.isIntOrFloat());
+ if (type.isa<FloatType>()) {
+ type = rewriter.getF64Type();
+ promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
+ }
+ types.push_back(type);
+ args.push_back(promotedArg);
+ }
+ Type structType =
+ LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
+ Type structPtrType = LLVM::LLVMPointerType::get(structType);
+ Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(1));
+ Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
+ /*alignment=*/0);
+ for (auto [index, arg] : llvm::enumerate(args)) {
+ Value ptr = rewriter.create<LLVM::GEPOp>(
+ loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
+ ArrayRef<LLVM::GEPArg>{0, index});
+ rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
+ }
+ tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
+ std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
+
+ rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
+ rewriter.eraseOp(gpuPrintfOp);
+ return success();
+}
+
/// Unrolls op if it's operating on vectors.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,