diff options
author | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:15:55 +0900 |
---|---|---|
committer | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:15:55 +0900 |
commit | bdcf47e4bcb92889665825654bb80a8bbe30379e (patch) | |
tree | 4de1d6b4ddc69f4f32daabb11ad5c71ab0cf895e /mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | |
parent | e7fd5cd25334048980ea207a9eff72698724721a (diff) | |
parent | fea7da1b00cc97d742faede2df96c7d327950f49 (diff) | |
download | llvm-users/chapuni/cov/single/base.zip llvm-users/chapuni/cov/single/base.tar.gz llvm-users/chapuni/cov/single/base.tar.bz2 |
Merge branch 'users/chapuni/cov/single/nextcount' into users/chapuni/cov/single/base [users/chapuni/cov/single/base]
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 122 |
1 files changed, 60 insertions, 62 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index b3c3fd4..544fc57 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -19,6 +19,59 @@ using namespace mlir; +LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, + Location loc, OpBuilder &b, + StringRef name, + LLVM::LLVMFunctionType type) { + LLVM::LLVMFuncOp ret; + if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleOp.getBody()); + ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); + } + return ret; +} + +static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, + StringRef prefix) { + // Get a unique global name. + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (prefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + return stringConstName; +} + +LLVM::GlobalOp +mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, + gpu::GPUModuleOp moduleOp, Type llvmI8, + StringRef namePrefix, StringRef str, + uint64_t alignment, unsigned addrSpace) { + llvm::SmallString<20> nullTermStr(str); + nullTermStr.push_back('\0'); // Null terminate for C + auto globalType = + LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); + StringAttr attr = b.getStringAttr(nullTermStr); + + // Try to find existing global. + for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) + if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && + globalOp.getValueAttr() == attr && + globalOp.getAlignment().value_or(0) == alignment && + globalOp.getAddrSpace() == addrSpace) + return globalOp; + + // Not found: create new global. 
+ OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleOp.getBody()); + SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); + return b.create<LLVM::GlobalOp>(loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); +} + LogicalResult GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, return success(); } -static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { - const char formatStringPrefix[] = "printfFormat_"; - // Get a unique global name. - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); - return stringConstName; -} - -/// Create an global that contains the given format string. If a global with -/// the same format string exists already in the module, return that global. -static LLVM::GlobalOp getOrCreateFormatStringConstant( - OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8, - StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) { - llvm::SmallString<20> formatString(str); - formatString.push_back('\0'); // Null terminate for C - auto globalType = - LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); - StringAttr attr = b.getStringAttr(formatString); - - // Try to find existing global. - for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) - if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && - globalOp.getValueAttr() == attr && - globalOp.getAlignment().value_or(0) == alignment && - globalOp.getAddrSpace() == addrSpace) - return globalOp; - - // Not found: create new global. 
- OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); - SmallString<16> name = getUniqueFormatGlobalName(moduleOp); - return b.create<LLVM::GlobalOp>(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); -} - -template <typename T> -static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, - ConversionPatternRewriter &rewriter, - StringRef name, - LLVM::LLVMFunctionType type) { - LLVM::LLVMFuncOp ret; - if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type, - LLVM::Linkage::External); - } - return ret; -} - LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() Value globalPtr = rewriter.create<LLVM::AddressOfOp>( @@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); // Create the global op or find an existing one. 
- LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0, - addressSpace); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(), + /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element Value globalPtr = rewriter.create<LLVM::AddressOfOp>( @@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); // Create the global op or find an existing one. - LLVM::GlobalOp global = getOrCreateFormatStringConstant( - rewriter, loc, moduleOp, llvmI8, adaptor.getFormat()); + LLVM::GlobalOp global = getOrCreateStringConstant( + rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); |