diff options
author | Thomas Raoux <thomasraoux@google.com> | 2023-01-05 21:20:45 +0000 |
---|---|---|
committer | Thomas Raoux <thomasraoux@google.com> | 2023-01-06 17:29:30 +0000 |
commit | 7efdc117b1518bb11a74cf315b17d4cbb751de6c (patch) | |
tree | 41c4bf3742c16a3b7c2d9330c7e73f3ed1ccb255 /mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | |
parent | 9b5f62685ab447ba9d3ea8ac2616e0c76a44d21b (diff) | |
download | llvm-7efdc117b1518bb11a74cf315b17d4cbb751de6c.zip llvm-7efdc117b1518bb11a74cf315b17d4cbb751de6c.tar.gz llvm-7efdc117b1518bb11a74cf315b17d4cbb751de6c.tar.bz2 |
[mlir][nvvm] Add lowering of gpu.printf to nvvm
When converting to NVVM, lowering gpu.printf to vprintf allows us to
support printing when running on CUDA.
Differential Revision: https://reviews.llvm.org/D141049
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 104 |
1 files changed, 89 insertions, 15 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 48effe2..668b443 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -172,7 +172,17 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, return success(); } -static const char formatStringPrefix[] = "printfFormat_"; +static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { + const char formatStringPrefix[] = "printfFormat_"; + // Get a unique global name. + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + return stringConstName; +} template <typename T> static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, @@ -225,13 +235,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); - // Create a global constant for the format string - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); + // Get a unique global name for the format. 
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); llvm::SmallString<20> formatString(adaptor.getFormat()); formatString.push_back('\0'); // Null terminate for C @@ -320,13 +325,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( LLVM::LLVMFuncOp printfDecl = getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); - // Create a global constant for the format string - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); + // Get a unique global name for the format. + SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); llvm::SmallString<20> formatString(adaptor.getFormat()); formatString.push_back('\0'); // Null terminate for C @@ -359,6 +359,80 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( return success(); } +LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( + gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = gpuPrintfOp->getLoc(); + + mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); + mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8); + + // Note: this is the GPUModule op, not the ModuleOp that surrounds it + // This ensures that global constants and declarations are placed within + // the device code, not the host code + auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); + + auto vprintfType = + LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr}); + LLVM::LLVMFuncOp vprintfDecl = + getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); + + // Get a unique global name for the format. 
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); + + llvm::SmallString<20> formatString(adaptor.getFormat()); + formatString.push_back('\0'); // Null terminate for C + auto globalType = + LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create<LLVM::GlobalOp>( + loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(formatString), /*allignment=*/0); + } + + // Get a pointer to the format string's first element + Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); + Value stringStart = rewriter.create<LLVM::GEPOp>( + loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + SmallVector<Type> types; + SmallVector<Value> args; + // Promote and pack the arguments into a stack allocation. + for (Value arg : adaptor.getArgs()) { + Type type = arg.getType(); + Value promotedArg = arg; + assert(type.isIntOrFloat()); + if (type.isa<FloatType>()) { + type = rewriter.getF64Type(); + promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); + } + types.push_back(type); + args.push_back(promotedArg); + } + Type structType = + LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); + Type structPtrType = LLVM::LLVMPointerType::get(structType); + Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), + rewriter.getIndexAttr(1)); + Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one, + /*alignment=*/0); + for (auto [index, arg] : llvm::enumerate(args)) { + Value ptr = rewriter.create<LLVM::GEPOp>( + loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc, + ArrayRef<LLVM::GEPArg>{0, index}); + rewriter.create<LLVM::StoreOp>(loc, arg, ptr); + } + tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc); + std::array<Value, 2> printfArgs = 
{stringStart, tempAlloc}; + + rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); + rewriter.eraseOp(gpuPrintfOp); + return success(); +} + /// Unrolls op if it's operating on vectors. LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, |