diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 64 |
1 files changed, 33 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index a73afbc..2285d26 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -20,20 +20,20 @@ using namespace mlir; -LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, - Location loc, OpBuilder &b, - StringRef name, +LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc, + OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type) { - LLVM::LLVMFuncOp ret; - if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); - ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); - } - return ret; + auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>( + SymbolTable::lookupSymbolIn(moduleOp, name)); + if (existing) + return existing; + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); + return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } -static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, +static SmallString<16> getUniqueSymbolName(Operation *moduleOp, StringRef prefix) { // Get a unique global name. unsigned stringNumber = 0; @@ -41,15 +41,16 @@ static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, do { stringConstName.clear(); (prefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); + } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName)); return stringConstName; } -LLVM::GlobalOp -mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, - gpu::GPUModuleOp moduleOp, Type llvmI8, - StringRef namePrefix, StringRef str, - uint64_t alignment, unsigned addrSpace) { +LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, + Operation *moduleOp, Type llvmI8, + StringRef namePrefix, + StringRef str, + uint64_t alignment, + unsigned addrSpace) { llvm::SmallString<20> nullTermStr(str); nullTermStr.push_back('\0'); // Null terminate for C auto globalType = @@ -57,7 +58,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, StringAttr attr = b.getStringAttr(nullTermStr); // Try to find existing global. - for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) + for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>()) if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && globalOp.getValueAttr() == attr && globalOp.getAlignment().value_or(0) == alignment && @@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, // Not found: create new global. OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); + b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); return LLVM::GlobalOp::create(b, loc, globalType, /*isConstant=*/true, LLVM::Linkage::Internal, @@ -396,10 +397,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); - // Note: this is the GPUModule op, not the ModuleOp that surrounds it - // This ensures that global constants and declarations are placed within - // the device code, not the host code - auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); + + Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>(); + if (!moduleOp) + return rewriter.notifyMatchFailure(gpuPrintfOp, + "Couldn't find a parent module"); auto ocklBegin = getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", @@ -496,10 +498,10 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); - // Note: this is the GPUModule op, not the ModuleOp that surrounds it - // This ensures that global constants and declarations are placed within - // the device code, not the host code - auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); + Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>(); + if (!moduleOp) + return rewriter.notifyMatchFailure(gpuPrintfOp, + "Couldn't find a parent module"); auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, @@ -541,10 +543,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - // Note: this is the GPUModule op, not the ModuleOp that surrounds it - // This ensures that global constants and declarations are placed within - // the device code, not the host code - auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); + Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>(); + if (!moduleOp) + return rewriter.notifyMatchFailure(gpuPrintfOp, + "Couldn't find a parent module"); // Create a valid global location removing any metadata attached to the // location as debug info metadata inside of a function cannot be used outside |