diff options
Diffstat (limited to 'flang/lib/Semantics')
-rw-r--r-- | flang/lib/Semantics/check-declarations.cpp | 3 | ||||
-rw-r--r-- | flang/lib/Semantics/resolve-directives.cpp | 5 | ||||
-rw-r--r-- | flang/lib/Semantics/resolve-names.cpp | 65 |
3 files changed, 71 insertions, 2 deletions
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index 1049a6d2..7b88100 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity( } } else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module && symbol.owner().kind() != Scope::Kind::MainProgram && - symbol.owner().kind() != Scope::Kind::BlockConstruct) { + symbol.owner().kind() != Scope::Kind::BlockConstruct && + symbol.owner().kind() != Scope::Kind::OpenACCConstruct) { messages_.Say( "ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US, parser::ToUpperCaseLetters(common::EnumToString(attr))); diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index b1eaaa8..624b890 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -328,6 +328,11 @@ public: return false; } + bool Pre(const parser::AccClause::UseDevice &x) { + ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice); + return false; + } + void Post(const parser::Name &); private: diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index d1150a9..5041a6a 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1387,6 +1387,8 @@ private: // Create scopes for OpenACC constructs class AccVisitor : public virtual DeclarationVisitor { public: + explicit AccVisitor(SemanticsContext &context) : context_{context} {} + void AddAccSourceRange(const parser::CharBlock &); static bool NeedsScope(const parser::OpenACCBlockConstruct &); @@ -1395,6 +1397,7 @@ public: void Post(const parser::OpenACCBlockConstruct &); bool Pre(const parser::OpenACCCombinedConstruct &); void Post(const parser::OpenACCCombinedConstruct &); + bool Pre(const parser::AccClause::UseDevice &x); bool Pre(const parser::AccBeginBlockDirective &x) { AddAccSourceRange(x.source); return true; @@ -1430,6 +1433,11 @@ public: void Post(const parser::AccBeginLoopDirective &x) { messageHandler().set_currStmtSource(std::nullopt); } + + void CopySymbolWithDevice(const parser::Name *name); + +private: + SemanticsContext &context_; }; bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) { @@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) { return true; } +void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { + // When CUDA Fortran is enabled together with OpenACC, new + // symbols are created for the one appearing in the use_device + // clause. These new symbols have the CUDA Fortran device + // attribute. + if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) { + name->symbol = currScope().CopySymbol(*name->symbol); + if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) { + object->set_cudaDataAttr(common::CUDADataAttr::Device); + } + } +} + +bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { + for (const auto &accObject : x.v.v) { + common::visit( + common::visitors{ + [&](const parser::Designator &designator) { + if (const auto *name{ + semantics::getDesignatorNameIfDataRef(designator)}) { + Symbol *prev{currScope().FindSymbol(name->source)}; + if (prev != name->symbol) { + name->symbol = prev; + } + CopySymbolWithDevice(name); + } else { + if (const auto *dataRef{ + std::get_if<parser::DataRef>(&designator.u)}) { + using ElementIndirection = + common::Indirection<parser::ArrayElement>; + if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) { + const parser::ArrayElement &arrayElement{ind->value()}; + Walk(arrayElement.subscripts); + const parser::DataRef &base{arrayElement.base}; + if (auto *name{std::get_if<parser::Name>(&base.u)}) { + Symbol *prev{currScope().FindSymbol(name->source)}; + if (prev != name->symbol) { + name->symbol = prev; + } + CopySymbolWithDevice(name); + } + } + } + } + }, + [&](const parser::Name &name) { + // TODO: common block in use_device? + }, + }, + accObject.u); + } + return false; +} + void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) { if (NeedsScope(x)) { PopScope(); @@ -2038,7 +2100,8 @@ public: ResolveNamesVisitor( SemanticsContext &context, ImplicitRulesMap &rules, Scope &top) - : BaseVisitor{context, *this, rules}, topScope_{top} { + : BaseVisitor{context, *this, rules}, AccVisitor(context), + topScope_{top} { PushScope(top); } |