aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp64
1 files changed, 33 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a73afbc..2285d26 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -20,20 +20,20 @@
using namespace mlir;
-LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
- Location loc, OpBuilder &b,
- StringRef name,
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
+ OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type) {
- LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
- ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
- }
- return ret;
+ auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name));
+ if (existing)
+ return existing;
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
+ return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
-static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
StringRef prefix) {
// Get a unique global name.
unsigned stringNumber = 0;
@@ -41,15 +41,16 @@ static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
do {
stringConstName.clear();
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
+ } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
return stringConstName;
}
-LLVM::GlobalOp
-mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
- gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef namePrefix, StringRef str,
- uint64_t alignment, unsigned addrSpace) {
+LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+ Operation *moduleOp, Type llvmI8,
+ StringRef namePrefix,
+ StringRef str,
+ uint64_t alignment,
+ unsigned addrSpace) {
llvm::SmallString<20> nullTermStr(str);
nullTermStr.push_back('\0'); // Null terminate for C
auto globalType =
@@ -57,7 +58,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
StringAttr attr = b.getStringAttr(nullTermStr);
// Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
@@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
return LLVM::GlobalOp::create(b, loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
@@ -396,10 +397,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
- // Note: this is the GPUModule op, not the ModuleOp that surrounds it
- // This ensures that global constants and declarations are placed within
- // the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+
+ Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
auto ocklBegin =
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
@@ -496,10 +498,10 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
mlir::Type ptrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
- // Note: this is the GPUModule op, not the ModuleOp that surrounds it
- // This ensures that global constants and declarations are placed within
- // the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
auto printfType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
@@ -541,10 +543,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- // Note: this is the GPUModule op, not the ModuleOp that surrounds it
- // This ensures that global constants and declarations are placed within
- // the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
// Create a valid global location removing any metadata attached to the
// location as debug info metadata inside of a function cannot be used outside