aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenACC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r--flang/lib/Lower/OpenACC.cpp34
1 files changed, 25 insertions, 9 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0ada..f9b9b850 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
@@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
+ // When CUDA Fotran is enabled, extra symbols are used in the host_data
+ // region. Look for them and bind their values with the symbols in the
+ // outer scope.
+ if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+ const Fortran::parser::AccObjectList &objectList{useDevice->v};
+ for (const auto &accObject : objectList.v) {
+ Fortran::semantics::Symbol &symbol =
+ getSymbolFromAccObject(accObject);
+ const Fortran::semantics::Symbol *baseSym =
+ localSymbols.lookupSymbolByName(symbol.name().ToString());
+ localSymbols.copySymbolBinding(*baseSym, symbol);
+ }
+ }
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+ Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
@@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
- stmtCtx, accClauseList);
+ stmtCtx, accClauseList, localSymbols);
}
}
@@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCConstruct &accConstruct) {
+ const Fortran::parser::OpenACCConstruct &accConstruct,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
- genACC(converter, semanticsContext, eval, blockConstruct);
+ genACC(converter, semanticsContext, eval, blockConstruct,
+ localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {