diff options
Diffstat (limited to 'flang/lib/Lower')
-rw-r--r-- | flang/lib/Lower/Bridge.cpp | 2 | ||||
-rw-r--r-- | flang/lib/Lower/OpenACC.cpp | 34 | ||||
-rw-r--r-- | flang/lib/Lower/SymbolMap.cpp | 10 |
3 files changed, 36 insertions, 10 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 149e51b..780d56f 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3182,7 +3182,7 @@ private: mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); localSymbols.pushScope(); mlir::Value exitCond = genOpenACCConstruct( - *this, bridge.getSemanticsContext(), getEval(), acc); + *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols); const Fortran::parser::OpenACCLoopConstruct *accLoop = std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u); diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 95d0ada..f9b9b850 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::AccClauseList &accClauseList) { + const Fortran::parser::AccClauseList &accClauseList, + Fortran::lower::SymMap &localSymbols) { mlir::Value ifCond; llvm::SmallVector<mlir::Value> dataOperands; bool addIfPresentAttr = false; @@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *useDevice = std::get_if<Fortran::parser::AccClause::UseDevice>( &clause.u)) { + // When CUDA Fotran is enabled, extra symbols are used in the host_data + // region. Look for them and bind their values with the symbols in the + // outer scope. + if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) { + const Fortran::parser::AccObjectList &objectList{useDevice->v}; + for (const auto &accObject : objectList.v) { + Fortran::semantics::Symbol &symbol = + getSymbolFromAccObject(accObject); + const Fortran::semantics::Symbol *baseSym = + localSymbols.lookupSymbolByName(symbol.name().ToString()); + localSymbols.copySymbolBinding(*baseSym, symbol); + } + } genDataOperandOperations<mlir::acc::UseDeviceOp>( useDevice->v, converter, semanticsContext, stmtCtx, dataOperands, mlir::acc::DataClause::acc_use_device, @@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, hostDataOp.setIfPresentAttr(builder.getUnitAttr()); } -static void -genACC(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { +static void genACC(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCBlockConstruct &blockConstruct, + Fortran::lower::SymMap &localSymbols) { const auto &beginBlockDirective = std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t); const auto &blockDirective = @@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter, accClauseList); } else if (blockDirective.v == llvm::acc::ACCD_host_data) { genACCHostDataOp(converter, currentLocation, eval, semanticsContext, - stmtCtx, accClauseList); + stmtCtx, accClauseList, localSymbols); } } @@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCConstruct &accConstruct) { + const Fortran::parser::OpenACCConstruct &accConstruct, + Fortran::lower::SymMap &localSymbols) { mlir::Value exitCond; Fortran::common::visit( common::visitors{ [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { - genACC(converter, semanticsContext, eval, blockConstruct); + genACC(converter, semanticsContext, eval, blockConstruct, + localSymbols); }, [&](const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) { diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp index 080f21e..78529e0 100644 --- a/flang/lib/Lower/SymbolMap.cpp +++ b/flang/lib/Lower/SymbolMap.cpp @@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) { return SymbolBox::None{}; } +const Fortran::semantics::Symbol * +Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) { + for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend(); + jmap != jend; ++jmap) + for (auto const &[sym, symBox] : *jmap) + if (sym->name().ToString() == symName) + return sym; + return nullptr; +} + Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol( Fortran::semantics::SymbolRef symRef) { auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate(); |