aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower')
-rw-r--r--flang/lib/Lower/Bridge.cpp2
-rw-r--r--flang/lib/Lower/OpenACC.cpp34
-rw-r--r--flang/lib/Lower/SymbolMap.cpp10
3 files changed, 36 insertions, 10 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b..780d56f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ private:
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
- *this, bridge.getSemanticsContext(), getEval(), acc);
+ *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0ada..f9b9b850 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
@@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
+ // When CUDA Fotran is enabled, extra symbols are used in the host_data
+ // region. Look for them and bind their values with the symbols in the
+ // outer scope.
+ if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+ const Fortran::parser::AccObjectList &objectList{useDevice->v};
+ for (const auto &accObject : objectList.v) {
+ Fortran::semantics::Symbol &symbol =
+ getSymbolFromAccObject(accObject);
+ const Fortran::semantics::Symbol *baseSym =
+ localSymbols.lookupSymbolByName(symbol.name().ToString());
+ localSymbols.copySymbolBinding(*baseSym, symbol);
+ }
+ }
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+ Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
@@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
- stmtCtx, accClauseList);
+ stmtCtx, accClauseList, localSymbols);
}
}
@@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCConstruct &accConstruct) {
+ const Fortran::parser::OpenACCConstruct &accConstruct,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
- genACC(converter, semanticsContext, eval, blockConstruct);
+ genACC(converter, semanticsContext, eval, blockConstruct,
+ localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21e..78529e0 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
return SymbolBox::None{};
}
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+ for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+ jmap != jend; ++jmap)
+ for (auto const &[sym, symBox] : *jmap)
+ if (sym->name().ToString() == symName)
+ return sym;
+ return nullptr;
+}
+
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();