aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:31:57 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:33:27 +0900
commitdf025ebf872052c0761d44a3ef9b65e9675af8a8 (patch)
tree9b4e94583e2536546d6606270bcdf846c95e1ba2 /mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
parent4428c9d0b1344179f85a72e183a44796976521e3 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
downloadllvm-users/chapuni/cov/single/loop.zip
llvm-users/chapuni/cov/single/loop.tar.gz
llvm-users/chapuni/cov/single/loop.tar.bz2
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/loopusers/chapuni/cov/single/loop
Conflicts: clang/lib/CodeGen/CoverageMappingGen.cpp
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp122
1 files changed, 60 insertions, 62 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4..544fc57 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
using namespace mlir;
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+ Location loc, OpBuilder &b,
+ StringRef name,
+ LLVM::LLVMFunctionType type) {
+ LLVM::LLVMFuncOp ret;
+ if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+ }
+ return ret;
+}
+
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+ StringRef prefix) {
+ // Get a unique global name.
+ unsigned stringNumber = 0;
+ SmallString<16> stringConstName;
+ do {
+ stringConstName.clear();
+ (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
+ } while (moduleOp.lookupSymbol(stringConstName));
+ return stringConstName;
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+ gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef namePrefix, StringRef str,
+ uint64_t alignment, unsigned addrSpace) {
+ llvm::SmallString<20> nullTermStr(str);
+ nullTermStr.push_back('\0'); // Null terminate for C
+ auto globalType =
+ LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
+ StringAttr attr = b.getStringAttr(nullTermStr);
+
+ // Try to find existing global.
+ for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+ globalOp.getValueAttr() == attr &&
+ globalOp.getAlignment().value_or(0) == alignment &&
+ globalOp.getAddrSpace() == addrSpace)
+ return globalOp;
+
+ // Not found: create new global.
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
+ return b.create<LLVM::GlobalOp>(loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
+}
+
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
- const char formatStringPrefix[] = "printfFormat_";
- // Get a unique global name.
- unsigned stringNumber = 0;
- SmallString<16> stringConstName;
- do {
- stringConstName.clear();
- (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
- return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
- OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
- llvm::SmallString<20> formatString(str);
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- StringAttr attr = b.getStringAttr(formatString);
-
- // Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
- if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
- globalOp.getValueAttr() == attr &&
- globalOp.getAlignment().value_or(0) == alignment &&
- globalOp.getAddrSpace() == addrSpace)
- return globalOp;
-
- // Not found: create new global.
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
- SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
- ConversionPatternRewriter &rewriter,
- StringRef name,
- LLVM::LLVMFunctionType type) {
- LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
- LLVM::Linkage::External);
- }
- return ret;
-}
-
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
- addressSpace);
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+ /*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);