diff options
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r-- | flang/lib/Lower/OpenACC.cpp | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 95d0ada..f9b9b850 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::AccClauseList &accClauseList) { + const Fortran::parser::AccClauseList &accClauseList, + Fortran::lower::SymMap &localSymbols) { mlir::Value ifCond; llvm::SmallVector<mlir::Value> dataOperands; bool addIfPresentAttr = false; @@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *useDevice = std::get_if<Fortran::parser::AccClause::UseDevice>( &clause.u)) { + // When CUDA Fotran is enabled, extra symbols are used in the host_data + // region. Look for them and bind their values with the symbols in the + // outer scope. + if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) { + const Fortran::parser::AccObjectList &objectList{useDevice->v}; + for (const auto &accObject : objectList.v) { + Fortran::semantics::Symbol &symbol = + getSymbolFromAccObject(accObject); + const Fortran::semantics::Symbol *baseSym = + localSymbols.lookupSymbolByName(symbol.name().ToString()); + localSymbols.copySymbolBinding(*baseSym, symbol); + } + } genDataOperandOperations<mlir::acc::UseDeviceOp>( useDevice->v, converter, semanticsContext, stmtCtx, dataOperands, mlir::acc::DataClause::acc_use_device, @@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, hostDataOp.setIfPresentAttr(builder.getUnitAttr()); } -static void -genACC(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { +static void genACC(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCBlockConstruct &blockConstruct, + Fortran::lower::SymMap &localSymbols) { const auto &beginBlockDirective = std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t); const auto &blockDirective = @@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter, accClauseList); } else if (blockDirective.v == llvm::acc::ACCD_host_data) { genACCHostDataOp(converter, currentLocation, eval, semanticsContext, - stmtCtx, accClauseList); + stmtCtx, accClauseList, localSymbols); } } @@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCConstruct &accConstruct) { + const Fortran::parser::OpenACCConstruct &accConstruct, + Fortran::lower::SymMap &localSymbols) { mlir::Value exitCond; Fortran::common::visit( common::visitors{ [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { - genACC(converter, semanticsContext, eval, blockConstruct); + genACC(converter, semanticsContext, eval, blockConstruct, + localSymbols); }, [&](const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) { |