diff options
Diffstat (limited to 'flang/lib')
30 files changed, 767 insertions, 300 deletions
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index b927fa3..bd06acc 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -1153,6 +1153,18 @@ bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) { return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0; } +bool IsCUDADeviceSymbol(const Symbol &sym) { + if (const auto *details = + sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) { + return details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Pinned; + } else if (const auto *details = + sym.GetUltimate().detailsIf<semantics::AssocEntityDetails>()) { + return GetNbOfCUDADeviceSymbols(details->expr()) > 0; + } + return false; +} + // HasVectorSubscript() struct HasVectorSubscriptHelper : public AnyTraverse<HasVectorSubscriptHelper, bool, diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp index 604b137..cd53dc9 100644 --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -950,7 +950,8 @@ static void genIoLoop(Fortran::lower::AbstractConverter &converter, makeNextConditionalOn(builder, loc, checkResult, ok, inLoop); const auto &itemList = std::get<0>(ioImpliedDo.t); const auto &control = std::get<1>(ioImpliedDo.t); - const auto &loopSym = *control.name.thing.thing.symbol; + const auto &loopSym = + *Fortran::parser::UnwrapRef<Fortran::parser::Name>(control.name).symbol; mlir::Value loopVar = fir::getBase(converter.genExprAddr( Fortran::evaluate::AsGenericExpr(loopSym).value(), stmtCtx)); auto genControlValue = [&](const Fortran::parser::ScalarIntExpr &expr) { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 55eda7e..85398be 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1343,8 +1343,10 @@ bool ClauseProcessor::processMap( const parser::CharBlock &source) { using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); - const auto &[mapType, typeMods, refMod, mappers, iterator, objects] = - clause.t; + const auto &[mapType, typeMods, attachMod, refMod, mappers, iterator, + objects] = clause.t; + if (attachMod) + TODO(currentLocation, "ATTACH modifier is not implemented yet"); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; std::string mapperIdName = "__implicit_mapper"; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index fac37a3..ba34212 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -219,7 +219,6 @@ MAKE_EMPTY_CLASS(AcqRel, AcqRel); MAKE_EMPTY_CLASS(Acquire, Acquire); MAKE_EMPTY_CLASS(Capture, Capture); MAKE_EMPTY_CLASS(Compare, Compare); -MAKE_EMPTY_CLASS(DynamicAllocators, DynamicAllocators); MAKE_EMPTY_CLASS(Full, Full); MAKE_EMPTY_CLASS(Inbranch, Inbranch); MAKE_EMPTY_CLASS(Mergeable, Mergeable); @@ -235,13 +234,9 @@ MAKE_EMPTY_CLASS(OmpxBare, OmpxBare); MAKE_EMPTY_CLASS(Read, Read); MAKE_EMPTY_CLASS(Relaxed, Relaxed); MAKE_EMPTY_CLASS(Release, Release); -MAKE_EMPTY_CLASS(ReverseOffload, ReverseOffload); MAKE_EMPTY_CLASS(SeqCst, SeqCst); -MAKE_EMPTY_CLASS(SelfMaps, SelfMaps); MAKE_EMPTY_CLASS(Simd, Simd); MAKE_EMPTY_CLASS(Threads, Threads); -MAKE_EMPTY_CLASS(UnifiedAddress, UnifiedAddress); -MAKE_EMPTY_CLASS(UnifiedSharedMemory, UnifiedSharedMemory); MAKE_EMPTY_CLASS(Unknown, Unknown); MAKE_EMPTY_CLASS(Untied, Untied); MAKE_EMPTY_CLASS(Weak, Weak); @@ -775,7 +770,18 @@ Doacross make(const parser::OmpClause::Doacross &inp, return makeDoacross(inp.v.v, semaCtx); } -// DynamicAllocators: empty +DynamicAllocators make(const parser::OmpClause::DynamicAllocators &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> td::optional<arser::OmpDynamicAllocatorsClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpDynamicAllocatorsClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return DynamicAllocators{/*Required=*/std::move(maybeRequired)}; +} + DynGroupprivate make(const parser::OmpClause::DynGroupprivate &inp, semantics::SemanticsContext &semaCtx) { @@ -1069,6 +1075,15 @@ Map make(const parser::OmpClause::Map &inp, ); CLAUSET_ENUM_CONVERT( // + convertAttachMod, parser::OmpAttachModifier::Value, Map::AttachModifier, + // clang-format off + MS(Always, Always) + MS(Auto, Auto) + MS(Never, Never) + // clang-format on + ); + + CLAUSET_ENUM_CONVERT( // convertRefMod, parser::OmpRefModifier::Value, Map::RefModifier, // clang-format off MS(Ref_Ptee, RefPtee) @@ -1115,6 +1130,13 @@ Map make(const parser::OmpClause::Map &inp, if (!modSet.empty()) maybeTypeMods = Map::MapTypeModifiers(modSet.begin(), modSet.end()); + auto attachMod = [&]() -> std::optional<Map::AttachModifier> { + if (auto *t = + semantics::OmpGetUniqueModifier<parser::OmpAttachModifier>(mods)) + return convertAttachMod(t->v); + return std::nullopt; + }(); + auto refMod = [&]() -> std::optional<Map::RefModifier> { if (auto *t = semantics::OmpGetUniqueModifier<parser::OmpRefModifier>(mods)) return convertRefMod(t->v); @@ -1135,6 +1157,7 @@ Map make(const parser::OmpClause::Map &inp, return Map{{/*MapType=*/std::move(type), /*MapTypeModifiers=*/std::move(maybeTypeMods), + /*AttachModifier=*/std::move(attachMod), /*RefModifier=*/std::move(refMod), /*Mapper=*/std::move(mappers), /*Iterator=*/std::move(iterator), /*LocatorList=*/makeObjects(t2, semaCtx)}}; @@ -1321,7 +1344,18 @@ Reduction make(const parser::OmpClause::Reduction &inp, // Relaxed: empty // Release: empty -// ReverseOffload: empty + +ReverseOffload make(const parser::OmpClause::ReverseOffload &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpReverseOffloadClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpReverseOffloadClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return ReverseOffload{/*Required=*/std::move(maybeRequired)}; +} Safelen make(const parser::OmpClause::Safelen &inp, semantics::SemanticsContext &semaCtx) { @@ -1374,6 +1408,18 @@ Schedule make(const parser::OmpClause::Schedule &inp, // SeqCst: empty +SelfMaps make(const parser::OmpClause::SelfMaps &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpSelfMapsClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpSelfMapsClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return SelfMaps{/*Required=*/std::move(maybeRequired)}; +} + Severity make(const parser::OmpClause::Severity &inp, semantics::SemanticsContext &semaCtx) { // inp -> empty @@ -1463,8 +1509,29 @@ To make(const parser::OmpClause::To &inp, /*LocatorList=*/makeObjects(t3, semaCtx)}}; } -// UnifiedAddress: empty -// UnifiedSharedMemory: empty +UnifiedAddress make(const parser::OmpClause::UnifiedAddress &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpUnifiedAddressClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpUnifiedAddressClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return UnifiedAddress{/*Required=*/std::move(maybeRequired)}; +} + +UnifiedSharedMemory make(const parser::OmpClause::UnifiedSharedMemory &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpUnifiedSharedMemoryClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpUnifiedSharedMemoryClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return UnifiedSharedMemory{/*Required=*/std::move(maybeRequired)}; +} Uniform make(const parser::OmpClause::Uniform &inp, semantics::SemanticsContext &semaCtx) { diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 444f274..f86ee01 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -4208,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions( void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, const semantics::Symbol *symbol) { using MlirRequires = mlir::omp::ClauseRequires; - using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag; if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) { - semantics::WithOmpDeclarative::RequiresFlags semaFlags; + semantics::WithOmpDeclarative::RequiresClauses reqs; if (symbol) { common::visit( [&](const auto &details) { if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative, std::decay_t<decltype(details)>>) { if (details.has_ompRequires()) - semaFlags = *details.ompRequires(); + reqs = *details.ompRequires(); } }, symbol->details()); @@ -4228,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, // Use pre-populated omp.requires module attribute if it was set, so that // the "-fopenmp-force-usm" compiler option is honored. MlirRequires mlirFlags = offloadMod.getRequires(); - if (semaFlags.test(SemaRequires::ReverseOffload)) + if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators)) + mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; + if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload)) mlirFlags = mlirFlags | MlirRequires::reverse_offload; - if (semaFlags.test(SemaRequires::UnifiedAddress)) + if (reqs.test(llvm::omp::Clause::OMPC_unified_address)) mlirFlags = mlirFlags | MlirRequires::unified_address; - if (semaFlags.test(SemaRequires::UnifiedSharedMemory)) + if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory)) mlirFlags = mlirFlags | MlirRequires::unified_shared_memory; - if (semaFlags.test(SemaRequires::DynamicAllocators)) - mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; offloadMod.setRequires(mlirFlags); } diff --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt index fc23e64..790b9fd 100644 --- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Support) +add_subdirectory(Transforms) diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp new file mode 100644 index 0000000..4840a99 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp @@ -0,0 +1,191 @@ +//===- ACCRecipeBufferization.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Bufferize OpenACC recipes that yield fir.box<T> to operate on +// fir.ref<fir.box<T>> and update uses accordingly. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/OpenACC/Passes.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace fir::acc { +#define GEN_PASS_DEF_ACCRECIPEBUFFERIZATION +#include "flang/Optimizer/OpenACC/Passes.h.inc" +} // namespace fir::acc + +namespace { + +class BufferizeInterface { +public: + static std::optional<mlir::Type> mustBufferize(mlir::Type recipeType) { + if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(recipeType)) + return fir::ReferenceType::get(boxTy); + return std::nullopt; + } + + static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value value) { + return builder.create<fir::LoadOp>(loc, value); + } + + static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value value) { + auto alloca = builder.create<fir::AllocaOp>(loc, value.getType()); + builder.create<fir::StoreOp>(loc, value, alloca); + return alloca; + } +}; + +static void bufferizeRegionArgsAndYields(mlir::Region ®ion, + mlir::Location loc, mlir::Type oldType, + mlir::Type newType) { + if (region.empty()) + return; + + mlir::OpBuilder builder(®ion); + for (mlir::BlockArgument arg : region.getArguments()) { + if (arg.getType() == oldType) { + arg.setType(newType); + if (!arg.use_empty()) { + mlir::Operation *loadOp = BufferizeInterface::load(builder, loc, arg); + arg.replaceAllUsesExcept(loadOp->getResult(0), loadOp); + } + } + } + if (auto yield = + llvm::dyn_cast<mlir::acc::YieldOp>(region.back().getTerminator())) { + llvm::SmallVector<mlir::Value> newOperands; + newOperands.reserve(yield.getNumOperands()); + bool changed = false; + for (mlir::Value oldYieldArg : yield.getOperands()) { + if (oldYieldArg.getType() == oldType) { + builder.setInsertionPoint(yield); + mlir::Value alloca = + BufferizeInterface::placeInMemory(builder, loc, oldYieldArg); + newOperands.push_back(alloca); + changed = true; + } else { + newOperands.push_back(oldYieldArg); + } + } + if (changed) + yield->setOperands(newOperands); + } +} + +static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands, + llvm::StringRef recipeSymName, + mlir::Operation *computeOp) { + if (!recipes) + return; + for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) { + if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() != + recipeSymName) + continue; + + mlir::Operation *dataOp = oldRes.getDefiningOp(); + assert(dataOp && "dataOp must be paired with computeOp"); + mlir::Location loc = dataOp->getLoc(); + mlir::OpBuilder builder(dataOp); + llvm::TypeSwitch<mlir::Operation *, void>(dataOp) + .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp, + mlir::acc::ReductionOp>([&](auto privateOp) { + builder.setInsertionPointAfterValue(privateOp.getVar()); + mlir::Value alloca = BufferizeInterface::placeInMemory( + builder, loc, privateOp.getVar()); + privateOp.getVarMutable().assign(alloca); + privateOp.getAccVar().setType(alloca.getType()); + }); + + llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(), + oldRes.getUsers().end()); + for (mlir::Operation *useOp : users) { + if (useOp == computeOp) + continue; + builder.setInsertionPoint(useOp); + mlir::Operation *load = BufferizeInterface::load(builder, loc, oldRes); + useOp->replaceUsesOfWith(oldRes, load->getResult(0)); + } + } +} + +class ACCRecipeBufferization + : public fir::acc::impl::ACCRecipeBufferizationBase< + ACCRecipeBufferization> { +public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + + llvm::SmallVector<llvm::StringRef> recipeNames; + module.walk([&](mlir::Operation *recipe) { + llvm::TypeSwitch<mlir::Operation *, void>(recipe) + .Case<mlir::acc::PrivateRecipeOp, mlir::acc::FirstprivateRecipeOp, + mlir::acc::ReductionRecipeOp>([&](auto recipe) { + mlir::Type oldType = recipe.getType(); + auto bufferizedType = + BufferizeInterface::mustBufferize(recipe.getType()); + if (!bufferizedType) + return; + recipe.setTypeAttr(mlir::TypeAttr::get(*bufferizedType)); + mlir::Location loc = recipe.getLoc(); + using RecipeOp = decltype(recipe); + bufferizeRegionArgsAndYields(recipe.getInitRegion(), loc, oldType, + *bufferizedType); + if constexpr (std::is_same_v<RecipeOp, + mlir::acc::FirstprivateRecipeOp>) + bufferizeRegionArgsAndYields(recipe.getCopyRegion(), loc, oldType, + *bufferizedType); + if constexpr (std::is_same_v<RecipeOp, + mlir::acc::ReductionRecipeOp>) + bufferizeRegionArgsAndYields(recipe.getCombinerRegion(), loc, + oldType, *bufferizedType); + bufferizeRegionArgsAndYields(recipe.getDestroyRegion(), loc, + oldType, *bufferizedType); + recipeNames.push_back(recipe.getSymName()); + }); + }); + if (recipeNames.empty()) + return; + + module.walk([&](mlir::Operation *op) { + llvm::TypeSwitch<mlir::Operation *, void>(op) + .Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>( + [&](auto computeOp) { + for (llvm::StringRef recipeName : recipeNames) { + if (computeOp.getPrivatizationRecipes()) + updateRecipeUse(computeOp.getPrivatizationRecipesAttr(), + computeOp.getPrivateOperands(), recipeName, + op); + if (computeOp.getFirstprivatizationRecipes()) + updateRecipeUse( + computeOp.getFirstprivatizationRecipesAttr(), + computeOp.getFirstprivateOperands(), recipeName, op); + if (computeOp.getReductionRecipes()) + updateRecipeUse(computeOp.getReductionRecipesAttr(), + computeOp.getReductionOperands(), + recipeName, op); + } + }); + }); + } +}; + +} // namespace + +std::unique_ptr<mlir::Pass> fir::acc::createACCRecipeBufferizationPass() { + return std::make_unique<ACCRecipeBufferization>(); +} diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt new file mode 100644 index 0000000..2427da0 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_flang_library(FIROpenACCTransforms + ACCRecipeBufferization.cpp + + DEPENDS + FIROpenACCPassesIncGen + + LINK_LIBS + MLIRIR + MLIRPass + FIRDialect + MLIROpenACCDialect +) diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 9507021..d677e14 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -548,6 +548,14 @@ TYPE_PARSER(construct<OmpAllocatorSimpleModifier>(scalarIntExpr)) TYPE_PARSER(construct<OmpAlwaysModifier>( // "ALWAYS" >> pure(OmpAlwaysModifier::Value::Always))) +TYPE_PARSER(construct<OmpAttachModifier::Value>( + "ALWAYS" >> pure(OmpAttachModifier::Value::Always) || + "AUTO" >> pure(OmpAttachModifier::Value::Auto) || + "NEVER" >> pure(OmpAttachModifier::Value::Never))) + +TYPE_PARSER(construct<OmpAttachModifier>( // + "ATTACH" >> parenthesized(Parser<OmpAttachModifier::Value>{}))) + TYPE_PARSER(construct<OmpAutomapModifier>( "AUTOMAP" >> pure(OmpAutomapModifier::Value::Automap))) @@ -744,6 +752,7 @@ TYPE_PARSER(sourced( TYPE_PARSER(sourced(construct<OmpMapClause::Modifier>( sourced(construct<OmpMapClause::Modifier>(Parser<OmpAlwaysModifier>{}) || + construct<OmpMapClause::Modifier>(Parser<OmpAttachModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpCloseModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpDeleteModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpPresentModifier>{}) || @@ -1085,7 +1094,7 @@ TYPE_PARSER(construct<OmpBindClause>( "TEAMS" >> pure(OmpBindClause::Binding::Teams) || "THREAD" >> pure(OmpBindClause::Binding::Thread))) -TYPE_PARSER(construct<OmpAlignClause>(scalarIntExpr)) +TYPE_PARSER(construct<OmpAlignClause>(scalarIntConstantExpr)) TYPE_PARSER(construct<OmpAtClause>( "EXECUTION" >> pure(OmpAtClause::ActionTime::Execution) || @@ -1158,7 +1167,8 @@ TYPE_PARSER( // "DOACROSS" >> construct<OmpClause>(parenthesized(Parser<OmpDoacrossClause>{})) || "DYNAMIC_ALLOCATORS" >> - construct<OmpClause>(construct<OmpClause::DynamicAllocators>()) || + construct<OmpClause>(construct<OmpClause::DynamicAllocators>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "DYN_GROUPPRIVATE" >> construct<OmpClause>(construct<OmpClause::DynGroupprivate>( parenthesized(Parser<OmpDynGroupprivateClause>{}))) || @@ -1270,12 +1280,15 @@ TYPE_PARSER( // "REPLAYABLE" >> construct<OmpClause>(construct<OmpClause::Replayable>( maybe(parenthesized(Parser<OmpReplayableClause>{})))) || "REVERSE_OFFLOAD" >> - construct<OmpClause>(construct<OmpClause::ReverseOffload>()) || + construct<OmpClause>(construct<OmpClause::ReverseOffload>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "SAFELEN" >> construct<OmpClause>(construct<OmpClause::Safelen>( parenthesized(scalarIntConstantExpr))) || "SCHEDULE" >> construct<OmpClause>(construct<OmpClause::Schedule>( parenthesized(Parser<OmpScheduleClause>{}))) || "SEQ_CST" >> construct<OmpClause>(construct<OmpClause::SeqCst>()) || + "SELF_MAPS" >> construct<OmpClause>(construct<OmpClause::SelfMaps>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "SEVERITY" >> construct<OmpClause>(construct<OmpClause::Severity>( parenthesized(Parser<OmpSeverityClause>{}))) || "SHARED" >> construct<OmpClause>(construct<OmpClause::Shared>( @@ -1303,9 +1316,11 @@ TYPE_PARSER( // construct<OmpClause>(construct<OmpClause::UseDeviceAddr>( parenthesized(Parser<OmpObjectList>{}))) || "UNIFIED_ADDRESS" >> - construct<OmpClause>(construct<OmpClause::UnifiedAddress>()) || + construct<OmpClause>(construct<OmpClause::UnifiedAddress>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "UNIFIED_SHARED_MEMORY" >> - construct<OmpClause>(construct<OmpClause::UnifiedSharedMemory>()) || + construct<OmpClause>(construct<OmpClause::UnifiedSharedMemory>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "UNIFORM" >> construct<OmpClause>(construct<OmpClause::Uniform>( parenthesized(nonemptyList(name)))) || "UNTIED" >> construct<OmpClause>(construct<OmpClause::Untied>()) || diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp index cb30939..8cbaa39 100644 --- a/flang/lib/Parser/parse-tree.cpp +++ b/flang/lib/Parser/parse-tree.cpp @@ -185,7 +185,7 @@ StructureConstructor ArrayElement::ConvertToStructureConstructor( std::list<ComponentSpec> components; for (auto &subscript : subscripts) { components.emplace_back(std::optional<Keyword>{}, - ComponentDataSource{std::move(*Unwrap<Expr>(subscript))}); + ComponentDataSource{std::move(UnwrapRef<Expr>(subscript))}); } DerivedTypeSpec spec{std::move(name), std::list<TypeParamSpec>{}}; spec.derivedTypeSpec = &derived; diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 0511f5b..b172e429 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2384,6 +2384,11 @@ public: Walk(x.v); Put(")"); } + void Unparse(const OmpAttachModifier &x) { + Word("ATTACH("); + Walk(x.v); + Put(")"); + } void Unparse(const OmpOrderClause &x) { using Modifier = OmpOrderClause::Modifier; Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ":"); @@ -2820,6 +2825,7 @@ public: WALK_NESTED_ENUM(OmpMapType, Value) // OMP map-type WALK_NESTED_ENUM(OmpMapTypeModifier, Value) // OMP map-type-modifier WALK_NESTED_ENUM(OmpAlwaysModifier, Value) + WALK_NESTED_ENUM(OmpAttachModifier, Value) WALK_NESTED_ENUM(OmpCloseModifier, Value) WALK_NESTED_ENUM(OmpDeleteModifier, Value) WALK_NESTED_ENUM(OmpPresentModifier, Value) diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp index f4aa496..1824a7d 100644 --- a/flang/lib/Semantics/assignment.cpp +++ b/flang/lib/Semantics/assignment.cpp @@ -194,7 +194,8 @@ void AssignmentContext::CheckShape(parser::CharBlock at, const SomeExpr *expr) { template <typename A> void AssignmentContext::PushWhereContext(const A &x) { const auto &expr{std::get<parser::LogicalExpr>(x.t)}; - CheckShape(expr.thing.value().source, GetExpr(context_, expr)); + CheckShape( + parser::UnwrapRef<parser::Expr>(expr).source, GetExpr(context_, expr)); ++whereDepth_; } diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp index 823aa4e..e019bbd 100644 --- a/flang/lib/Semantics/check-allocate.cpp +++ b/flang/lib/Semantics/check-allocate.cpp @@ -151,7 +151,9 @@ static std::optional<AllocateCheckerInfo> CheckAllocateOptions( [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var) + .GetSource(), + "ERRMSG="); if (info.gotMsg) { // C943 context.Say( "ERRMSG may not be duplicated in a ALLOCATE statement"_err_en_US); @@ -439,7 +441,7 @@ static bool HaveCompatibleLengths( evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())}; auto v2{ evaluate::ToInt64(type2.characterTypeSpec().length().GetExplicit())}; - return !v1 || !v2 || *v1 == *v2; + return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? *v2 : 0); } else { return true; } @@ -452,7 +454,7 @@ static bool HaveCompatibleLengths( auto v1{ evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())}; auto v2{type2.knownLength()}; - return !v1 || !v2 || *v1 == *v2; + return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? *v2 : 0); } else { return true; } @@ -598,7 +600,7 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { std::optional<evaluate::ConstantSubscript> lbound; if (const auto &lb{std::get<0>(shapeSpec.t)}) { lbound.reset(); - const auto &lbExpr{lb->thing.thing.value()}; + const auto &lbExpr{parser::UnwrapRef<parser::Expr>(lb)}; if (const auto *expr{GetExpr(context, lbExpr)}) { auto folded{ evaluate::Fold(context.foldingContext(), SomeExpr(*expr))}; @@ -609,7 +611,8 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { lbound = 1; } if (lbound) { - const auto &ubExpr{std::get<1>(shapeSpec.t).thing.thing.value()}; + const auto &ubExpr{ + parser::UnwrapRef<parser::Expr>(std::get<1>(shapeSpec.t))}; if (const auto *expr{GetExpr(context, ubExpr)}) { auto folded{ evaluate::Fold(context.foldingContext(), SomeExpr(*expr))}; diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp index 5ce143c..7593154 100644 --- a/flang/lib/Semantics/check-case.cpp +++ b/flang/lib/Semantics/check-case.cpp @@ -72,7 +72,7 @@ private: } std::optional<Value> GetValue(const parser::CaseValue &caseValue) { - const parser::Expr &expr{caseValue.thing.thing.value()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(caseValue)}; auto *x{expr.typedExpr.get()}; if (x && x->v) { // C1147 auto type{x->v->GetType()}; diff --git a/flang/lib/Semantics/check-coarray.cpp b/flang/lib/Semantics/check-coarray.cpp index 0e444f1..9113369 100644 --- a/flang/lib/Semantics/check-coarray.cpp +++ b/flang/lib/Semantics/check-coarray.cpp @@ -112,7 +112,7 @@ static void CheckTeamType( static void CheckTeamStat( SemanticsContext &context, const parser::ImageSelectorSpec::Stat &stat) { - const parser::Variable &var{stat.v.thing.thing.value()}; + const auto &var{parser::UnwrapRef<parser::Variable>(stat)}; if (parser::GetCoindexedNamedObject(var)) { context.Say(parser::FindSourceLocation(var), // C931 "Image selector STAT variable must not be a coindexed " @@ -147,7 +147,8 @@ static void CheckSyncStat(SemanticsContext &context, }, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var).GetSource(), + "ERRMSG="); if (gotMsg) { context.Say( // C1172 "The errmsg-variable in a sync-stat-list may not be repeated"_err_en_US); @@ -260,7 +261,9 @@ static void CheckEventWaitSpecList(SemanticsContext &context, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var) + .GetSource(), + "ERRMSG="); if (gotMsg) { context.Say( // C1178 "A errmsg-variable in a event-wait-spec-list may not be repeated"_err_en_US); diff --git a/flang/lib/Semantics/check-data.cpp b/flang/lib/Semantics/check-data.cpp index 5459290..3bcf711 100644 --- a/flang/lib/Semantics/check-data.cpp +++ b/flang/lib/Semantics/check-data.cpp @@ -25,9 +25,10 @@ namespace Fortran::semantics { // Ensures that references to an implied DO loop control variable are // represented as such in the "body" of the implied DO loop. void DataChecker::Enter(const parser::DataImpliedDo &x) { - auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>( + std::get<parser::DataImpliedDo::Bounds>(x.t).name)}; int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind}; - if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) { + if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); } @@ -36,7 +37,8 @@ void DataChecker::Enter(const parser::DataImpliedDo &x) { } void DataChecker::Leave(const parser::DataImpliedDo &x) { - auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>( + std::get<parser::DataImpliedDo::Bounds>(x.t).name)}; exprAnalyzer_.RemoveImpliedDo(name.source); } @@ -211,7 +213,7 @@ void DataChecker::Leave(const parser::DataIDoObject &object) { std::get_if<parser::Scalar<common::Indirection<parser::Designator>>>( &object.u)}) { if (MaybeExpr expr{exprAnalyzer_.Analyze(*designator)}) { - auto source{designator->thing.value().source}; + auto source{parser::UnwrapRef<parser::Designator>(*designator).source}; DataVarChecker checker{exprAnalyzer_.context(), source}; if (checker(*expr)) { if (checker.HasComponentWithoutSubscripts()) { // C880 diff --git a/flang/lib/Semantics/check-deallocate.cpp b/flang/lib/Semantics/check-deallocate.cpp index c45b585..c1ebc5f 100644 --- a/flang/lib/Semantics/check-deallocate.cpp +++ b/flang/lib/Semantics/check-deallocate.cpp @@ -114,7 +114,8 @@ void DeallocateChecker::Leave(const parser::DeallocateStmt &deallocateStmt) { }, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context_, - GetExpr(context_, var), var.v.thing.thing.GetSource(), + GetExpr(context_, var), + parser::UnwrapRef<parser::Variable>(var).GetSource(), "ERRMSG="); if (gotMsg) { context_.Say( diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp index a2f3685..8a47340 100644 --- a/flang/lib/Semantics/check-do-forall.cpp +++ b/flang/lib/Semantics/check-do-forall.cpp @@ -535,7 +535,8 @@ private: if (const SomeExpr * expr{GetExpr(context_, scalarExpression)}) { if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) { // No warnings or errors for type INTEGER - const parser::CharBlock &loc{scalarExpression.thing.value().source}; + parser::CharBlock loc{ + parser::UnwrapRef<parser::Expr>(scalarExpression).source}; CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real)); } } @@ -552,7 +553,7 @@ private: CheckDoExpression(*bounds.step); if (IsZero(*bounds.step)) { context_.Warn(common::UsageWarning::ZeroDoStep, - bounds.step->thing.value().source, + parser::UnwrapRef<parser::Expr>(bounds.step).source, "DO step expression should not be zero"_warn_en_US); } } @@ -615,7 +616,7 @@ private: // C1121 - procedures in mask must be pure void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const { UnorderedSymbolSet references{ - GatherSymbolsFromExpression(mask.thing.thing.value())}; + GatherSymbolsFromExpression(parser::UnwrapRef<parser::Expr>(mask))}; for (const Symbol &ref : OrderBySourcePosition(references)) { if (IsProcedure(ref) && !IsPureProcedure(ref)) { context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source, @@ -639,32 +640,33 @@ private: } void HasNoReferences(const UnorderedSymbolSet &indexNames, - const parser::ScalarIntExpr &expr) const { - CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), - indexNames, + const parser::ScalarIntExpr &scalarIntExpr) const { + const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames, "%s limit expression may not reference index variable '%s'"_err_en_US, - expr.thing.thing.value().source); + expr.source); } // C1129, names in local locality-specs can't be in mask expressions void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask, const UnorderedSymbolSet &localVars) const { - CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()), - localVars, + const auto &expr{parser::UnwrapRef<parser::Expr>(mask)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars, "%s mask expression references variable '%s'" " in LOCAL locality-spec"_err_en_US, - mask.thing.thing.value().source); + expr.source); } // C1129, names in local locality-specs can't be in limit or step // expressions - void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr, + void CheckExprDoesNotReferenceLocal( + const parser::ScalarIntExpr &scalarIntExpr, const UnorderedSymbolSet &localVars) const { - CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), - localVars, + const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars, "%s expression references variable '%s'" " in LOCAL locality-spec"_err_en_US, - expr.thing.thing.value().source); + expr.source); } // C1130, DEFAULT(NONE) locality requires names to be in locality-specs to @@ -772,7 +774,7 @@ private: HasNoReferences(indexNames, std::get<2>(control.t)); if (const auto &intExpr{ std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) { - const parser::Expr &expr{intExpr->thing.thing.value()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(intExpr)}; CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames, "%s step expression may not reference index variable '%s'"_err_en_US, expr.source); @@ -840,7 +842,7 @@ private: } void CheckForImpureCall(const parser::ScalarIntExpr &x, std::optional<IndexVarKind> nesting) const { - const auto &parsedExpr{x.thing.thing.value()}; + const auto &parsedExpr{parser::UnwrapRef<parser::Expr>(x)}; auto oldLocation{context_.location()}; context_.set_location(parsedExpr.source); if (const auto &typedExpr{parsedExpr.typedExpr}) { @@ -1124,7 +1126,8 @@ void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) { const auto *newunit{ std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)}; if (newunit) { - context_.CheckIndexVarRedefine(newunit->v.thing.thing); + context_.CheckIndexVarRedefine( + parser::UnwrapRef<parser::Variable>(newunit)); } } @@ -1166,14 +1169,14 @@ void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) { const auto *intVar{std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)}; if (intVar) { const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)}; - context_.CheckIndexVarRedefine(scalar.thing.thing); + context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(scalar)); } } void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) { const auto *size{std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)}; if (size) { - context_.CheckIndexVarRedefine(size->v.thing.thing); + context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(size)); } } @@ -1190,16 +1193,19 @@ static void CheckIoImpliedDoIndex( void DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) { CheckIoImpliedDoIndex(context_, - std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name.thing.thing); + parser::UnwrapRef<parser::Name>( + std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name)); } void DoForallChecker::Leave(const parser::InputImpliedDo &inputImpliedDo) { CheckIoImpliedDoIndex(context_, - std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name.thing.thing); + parser::UnwrapRef<parser::Name>( + std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name)); } void DoForallChecker::Leave(const parser::StatVariable &statVariable) { - context_.CheckIndexVarRedefine(statVariable.v.thing.thing); + context_.CheckIndexVarRedefine( + parser::UnwrapRef<parser::Variable>(statVariable)); } } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/check-io.cpp b/flang/lib/Semantics/check-io.cpp index a1ff4b9..19059ad 100644 --- a/flang/lib/Semantics/check-io.cpp +++ b/flang/lib/Semantics/check-io.cpp @@ -424,8 +424,8 @@ void IoChecker::Enter(const parser::InquireSpec::CharVar &spec) { specKind = IoSpecKind::Dispose; break; } - const parser::Variable &var{ - std::get<parser::ScalarDefaultCharVariable>(spec.t).thing.thing}; + const auto &var{parser::UnwrapRef<parser::Variable>( + std::get<parser::ScalarDefaultCharVariable>(spec.t))}; std::string what{parser::ToUpperCaseLetters(common::EnumToString(specKind))}; CheckForDefinableVariable(var, what); WarnOnDeferredLengthCharacterScalar( @@ -627,7 +627,7 @@ void IoChecker::Enter(const parser::IoUnit &spec) { } void IoChecker::Enter(const parser::MsgVariable &msgVar) { - const parser::Variable &var{msgVar.v.thing.thing}; + const auto &var{parser::UnwrapRef<parser::Variable>(msgVar)}; if (stmt_ == IoStmtKind::None) { // allocate, deallocate, image control CheckForDefinableVariable(var, "ERRMSG"); diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index d65a89e..ea6fe43 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -1517,19 +1517,42 @@ void OmpStructureChecker::Leave(const parser::OpenMPDepobjConstruct &x) { void OmpStructureChecker::Enter(const parser::OpenMPRequiresConstruct &x) { const auto &dirName{x.v.DirName()}; PushContextAndClauseSets(dirName.source, dirName.v); + unsigned version{context_.langOptions().OpenMPVersion}; - if (visitedAtomicSource_.empty()) { - return; - } for (const parser::OmpClause &clause : x.v.Clauses().v) { llvm::omp::Clause id{clause.Id()}; if (id == llvm::omp::Clause::OMPC_atomic_default_mem_order) { - parser::MessageFormattedText txt( - "REQUIRES directive with '%s' clause found lexically after atomic operation without a memory order clause"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(id))); - parser::Message message(clause.source, txt); - message.Attach(visitedAtomicSource_, "Previous atomic construct"_en_US); - context_.Say(std::move(message)); + if (!visitedAtomicSource_.empty()) { + parser::MessageFormattedText txt( + "REQUIRES directive with '%s' clause found lexically after atomic operation without a memory order clause"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(id))); + parser::Message message(clause.source, txt); + message.Attach(visitedAtomicSource_, "Previous atomic construct"_en_US); + context_.Say(std::move(message)); + } + } else { + bool hasArgument{common::visit( + [&](auto &&s) { + using TypeS = llvm::remove_cvref_t<decltype(s)>; + if constexpr ( // + std::is_same_v<TypeS, parser::OmpClause::DynamicAllocators> || + std::is_same_v<TypeS, parser::OmpClause::ReverseOffload> || + std::is_same_v<TypeS, parser::OmpClause::SelfMaps> || + std::is_same_v<TypeS, parser::OmpClause::UnifiedAddress> || + std::is_same_v<TypeS, parser::OmpClause::UnifiedSharedMemory>) { + return s.v.has_value(); + } else { + return false; + } + }, + clause.u)}; + if (version < 60 && hasArgument) { + context_.Say(clause.source, + "An argument to %s is an %s feature, %s"_warn_en_US, + parser::ToUpperCaseLetters( + llvm::omp::getOpenMPClauseName(clause.Id())), + ThisVersion(60), TryVersion(60)); + } } } } @@ -1540,9 +1563,8 @@ void OmpStructureChecker::Leave(const parser::OpenMPRequiresConstruct &) { void OmpStructureChecker::CheckAlignValue(const parser::OmpClause &clause) { if (auto *align{std::get_if<parser::OmpClause::Align>(&clause.u)}) { - if (const auto &v{GetIntValue(align->v)}; !v || *v <= 0) { - context_.Say(clause.source, - "The alignment value should be a constant positive integer"_err_en_US); + if (const auto &v{GetIntValue(align->v)}; v && *v <= 0) { + context_.Say(clause.source, "The alignment should be positive"_err_en_US); } } } @@ -2336,7 +2358,7 @@ private: } if (auto &repl{std::get<parser::OmpClause::Replayable>(clause.u).v}) { // Scalar<Logical<Constant<indirection<Expr>>>> - const parser::Expr &parserExpr{repl->v.thing.thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>(repl)}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { return GetLogicalValue(*expr).value_or(true); } @@ -2350,7 +2372,7 @@ private: bool isTransparent{true}; if (auto &transp{std::get<parser::OmpClause::Transparent>(clause.u).v}) { // Scalar<Integer<indirection<Expr>>> - const parser::Expr &parserExpr{transp->v.thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>(transp)}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { // If the argument is omp_not_impex (defined as 0), then // the task is not transparent, otherwise it is. @@ -2389,8 +2411,8 @@ private: } } // Scalar<Logical<indirection<Expr>>> - auto &parserExpr{ - std::get<parser::ScalarLogicalExpr>(ifc.v.t).thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>( + std::get<parser::ScalarLogicalExpr>(ifc.v.t))}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { // If the value is known to be false, an undeferred task will be // generated. @@ -3017,8 +3039,8 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) { &objs, std::string clause) { for (const auto &obj : objs.v) { - if (const parser::Name * - objName{parser::Unwrap<parser::Name>(obj)}) { + if (const parser::Name *objName{ + parser::Unwrap<parser::Name>(obj)}) { if (&objName->symbol->GetUltimate() == eventHandleSym) { context_.Say(GetContext().clauseSource, "A variable: `%s` that appears in a DETACH clause cannot appear on %s clause on the same construct"_err_en_US, @@ -3637,7 +3659,8 @@ void OmpStructureChecker::CheckReductionModifier( if (modifier.v == ReductionModifier::Value::Task) { // "Task" is only allowed on worksharing or "parallel" directive. static llvm::omp::Directive worksharing[]{ - llvm::omp::Directive::OMPD_do, llvm::omp::Directive::OMPD_scope, + llvm::omp::Directive::OMPD_do, // + llvm::omp::Directive::OMPD_scope, // llvm::omp::Directive::OMPD_sections, // There are more worksharing directives, but they do not apply: // "for" is C++ only, @@ -4081,9 +4104,15 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { if (auto *iter{OmpGetUniqueModifier<parser::OmpIterator>(modifiers)}) { CheckIteratorModifier(*iter); } + + using Directive = llvm::omp::Directive; + Directive dir{GetContext().directive}; + llvm::ArrayRef<Directive> leafs{llvm::omp::getLeafConstructsOrSelf(dir)}; + parser::OmpMapType::Value mapType{parser::OmpMapType::Value::Storage}; + if (auto *type{OmpGetUniqueModifier<parser::OmpMapType>(modifiers)}) { - using Directive = llvm::omp::Directive; using Value = parser::OmpMapType::Value; + mapType = type->v; static auto isValidForVersion{ [](parser::OmpMapType::Value t, unsigned version) { @@ -4120,10 +4149,6 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { return result; }()}; - llvm::omp::Directive dir{GetContext().directive}; - llvm::ArrayRef<llvm::omp::Directive> leafs{ - llvm::omp::getLeafConstructsOrSelf(dir)}; - if (llvm::is_contained(leafs, Directive::OMPD_target) || llvm::is_contained(leafs, Directive::OMPD_target_data)) { if (version >= 60) { @@ -4141,6 +4166,43 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { } } + if (auto *attach{ + OmpGetUniqueModifier<parser::OmpAttachModifier>(modifiers)}) { + bool mapEnteringConstructOrMapper{ + llvm::is_contained(leafs, Directive::OMPD_target) || + llvm::is_contained(leafs, Directive::OMPD_target_data) || + llvm::is_contained(leafs, Directive::OMPD_target_enter_data) || + llvm::is_contained(leafs, Directive::OMPD_declare_mapper)}; + + if (!mapEnteringConstructOrMapper || !IsMapEnteringType(mapType)) { + const auto &desc{OmpGetDescriptor<parser::OmpAttachModifier>()}; + context_.Say(OmpGetModifierSource(modifiers, attach), + "The '%s' modifier can only appear on a map-entering construct or on a DECLARE_MAPPER directive"_err_en_US, + desc.name.str()); + } + + auto hasBasePointer{[&](const SomeExpr &item) { + evaluate::SymbolVector symbols{evaluate::GetSymbolVector(item)}; + return llvm::any_of( + symbols, [](SymbolRef s) { return IsPointer(s.get()); }); + }}; + + evaluate::ExpressionAnalyzer ea{context_}; + const auto &objects{std::get<parser::OmpObjectList>(x.v.t)}; + for (auto &object : objects.v) { + if (const parser::Designator *d{GetDesignatorFromObj(object)}) { + if (auto &&expr{ea.Analyze(*d)}) { + if (hasBasePointer(*expr)) { + continue; + } + } + } + auto source{GetObjectSource(object)}; + context_.Say(source ? *source : GetContext().clauseSource, + "A list-item that appears in a map clause with the ATTACH modifier must have a base-pointer"_err_en_US); + } + } + auto &&typeMods{ OmpGetRepeatableModifier<parser::OmpMapTypeModifier>(modifiers)}; struct Less { diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp index 1e46dab..bbf3b28 100644 --- a/flang/lib/Semantics/data-to-inits.cpp +++ b/flang/lib/Semantics/data-to-inits.cpp @@ -179,13 +179,14 @@ bool DataInitializationCompiler<DSV>::Scan( template <typename DSV> bool DataInitializationCompiler<DSV>::Scan(const parser::DataImpliedDo &ido) { const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)}; - auto name{bounds.name.thing.thing}; - const auto *lowerExpr{ - GetExpr(exprAnalyzer_.context(), bounds.lower.thing.thing)}; - const auto *upperExpr{ - GetExpr(exprAnalyzer_.context(), bounds.upper.thing.thing)}; + const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)}; + const auto *lowerExpr{GetExpr( + exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.lower))}; + const auto *upperExpr{GetExpr( + exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.upper))}; const auto *stepExpr{bounds.step - ? GetExpr(exprAnalyzer_.context(), bounds.step->thing.thing) + ? GetExpr(exprAnalyzer_.context(), + parser::UnwrapRef<parser::Expr>(bounds.step)) : nullptr}; if (lowerExpr && upperExpr) { // Fold the bounds expressions (again) in case any of them depend @@ -240,7 +241,9 @@ bool DataInitializationCompiler<DSV>::Scan( return common::visit( common::visitors{ [&](const parser::Scalar<common::Indirection<parser::Designator>> - &var) { return Scan(var.thing.value()); }, + &var) { + return Scan(parser::UnwrapRef<parser::Designator>(var)); + }, [&](const common::Indirection<parser::DataImpliedDo> &ido) { return Scan(ido.value()); }, diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp index 2feec98..4aeb9a4 100644 --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -176,8 +176,8 @@ public: // Find and return a user-defined operator or report an error. // The provided message is used if there is no such operator. - MaybeExpr TryDefinedOp( - const char *, parser::MessageFixedText, bool isUserOp = false); + MaybeExpr TryDefinedOp(const char *, parser::MessageFixedText, + bool isUserOp = false, bool checkForNullPointer = true); template <typename E> MaybeExpr TryDefinedOp(E opr, parser::MessageFixedText msg) { return TryDefinedOp( @@ -211,7 +211,8 @@ private: void SayNoMatch( const std::string &, bool isAssignment = false, bool isAmbiguous = false); std::string TypeAsFortran(std::size_t); - bool AnyUntypedOrMissingOperand() const; + bool AnyUntypedOperand() const; + bool AnyMissingOperand() const; ExpressionAnalyzer &context_; ActualArguments actuals_; @@ -1954,9 +1955,10 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) { const auto &control{std::get<parser::AcImpliedDoControl>(impliedDo.t)}; const auto &bounds{std::get<parser::AcImpliedDoControl::Bounds>(control.t)}; exprAnalyzer_.Analyze(bounds.name); - parser::CharBlock name{bounds.name.thing.thing.source}; + const auto &parsedName{parser::UnwrapRef<parser::Name>(bounds.name)}; + parser::CharBlock name{parsedName.source}; int kind{ImpliedDoIntType::kind}; - if (const Symbol * symbol{bounds.name.thing.thing.symbol}) { + if (const Symbol *symbol{parsedName.symbol}) { if (auto dynamicType{DynamicType::From(symbol)}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); @@ -1981,7 +1983,7 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) { auto cUpper{ToInt64(upper)}; auto cStride{ToInt64(stride)}; if (!(messageDisplayedSet_ & 0x10) && cStride && *cStride == 0) { - exprAnalyzer_.SayAt(bounds.step.value().thing.thing.value().source, + exprAnalyzer_.SayAt(parser::UnwrapRef<parser::Expr>(bounds.step).source, "The stride of an implied DO loop must not be zero"_err_en_US); messageDisplayedSet_ |= 0x10; } @@ -2526,7 +2528,7 @@ static const Symbol *GetBindingResolution( auto ExpressionAnalyzer::AnalyzeProcedureComponentRef( const parser::ProcComponentRef &pcr, ActualArguments &&arguments, bool isSubroutine) -> std::optional<CalleeAndArguments> { - const parser::StructureComponent &sc{pcr.v.thing}; + const auto &sc{parser::UnwrapRef<parser::StructureComponent>(pcr)}; if (MaybeExpr base{Analyze(sc.base)}) { if (const Symbol *sym{sc.component.symbol}) { if (context_.HasError(sym)) { @@ -3695,11 +3697,12 @@ std::optional<characteristics::Procedure> ExpressionAnalyzer::CheckCall( MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::Parentheses &x) { if (MaybeExpr operand{Analyze(x.v.value())}) { - if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) { + if (IsNullPointerOrAllocatable(&*operand)) { + Say("NULL() may not be parenthesized"_err_en_US); + } else if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) { if (const semantics::Symbol *result{FindFunctionResult(*symbol)}) { if (semantics::IsProcedurePointer(*result)) { - Say("A function reference that returns a procedure " - "pointer may not be parenthesized"_err_en_US); // C1003 + Say("A function reference that returns a procedure pointer may not be parenthesized"_err_en_US); // C1003 } } } @@ -3788,7 +3791,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) { ArgumentAnalyzer analyzer{*this, name.source}; analyzer.Analyze(std::get<1>(x.t)); return analyzer.TryDefinedOp(name.source.ToString().c_str(), - "No operator %s defined for %s"_err_en_US, true); + "No operator %s defined for %s"_err_en_US, /*isUserOp=*/true); } // Binary (dyadic) operations @@ -3997,7 +4000,9 @@ static bool CheckFuncRefToArrayElement(semantics::SemanticsContext &context, auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)}; const auto *name{std::get_if<parser::Name>(&proc.u)}; if (!name) { - name = &std::get<parser::ProcComponentRef>(proc.u).v.thing.component; + name = &parser::UnwrapRef<parser::StructureComponent>( + std::get<parser::ProcComponentRef>(proc.u)) + .component; } if (!name->symbol) { return false; @@ -4047,14 +4052,16 @@ static void FixMisparsedFunctionReference( } } auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)}; - if (Symbol *origSymbol{ - common::visit(common::visitors{ - [&](parser::Name &name) { return name.symbol; }, - [&](parser::ProcComponentRef &pcr) { - return pcr.v.thing.component.symbol; - }, - }, - proc.u)}) { + if (Symbol * + origSymbol{common::visit( + common::visitors{ + [&](parser::Name &name) { return name.symbol; }, + [&](parser::ProcComponentRef &pcr) { + return parser::UnwrapRef<parser::StructureComponent>(pcr) + .component.symbol; + }, + }, + proc.u)}) { Symbol &symbol{origSymbol->GetUltimate()}; if (symbol.has<semantics::ObjectEntityDetails>() || symbol.has<semantics::AssocEntityDetails>()) { @@ -4176,15 +4183,23 @@ MaybeExpr ExpressionAnalyzer::IterativelyAnalyzeSubexpressions( } while (!queue.empty()); // Analyze the collected subexpressions in bottom-up order. // On an error, bail out and leave partial results in place. - MaybeExpr result; - for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) { - const parser::Expr &expr{**riter}; - result = ExprOrVariable(expr, expr.source); - if (!result) { - return result; + if (finish.size() == 1) { + const parser::Expr &expr{DEREF(finish.front())}; + return ExprOrVariable(expr, expr.source); + } else { + // NULL() operand catching is deferred to operation analysis so + // that they can be accepted by defined operators. + auto restorer{AllowNullPointer()}; + MaybeExpr result; + for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) { + const parser::Expr &expr{**riter}; + result = ExprOrVariable(expr, expr.source); + if (!result) { + return result; + } } + return result; // last value was from analysis of "top" } - return result; // last value was from analysis of "top" } MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr &expr) { @@ -4681,7 +4696,7 @@ bool ArgumentAnalyzer::AnyCUDADeviceData() const { // attribute. bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride( const char *opr) const { - if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) { + if (AnyCUDADeviceData() && !AnyUntypedOperand() && !AnyMissingOperand()) { std::string oprNameString{"operator("s + opr + ')'}; parser::CharBlock oprName{oprNameString}; parser::Messages buffer; @@ -4709,9 +4724,9 @@ bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride( return false; } -MaybeExpr ArgumentAnalyzer::TryDefinedOp( - const char *opr, parser::MessageFixedText error, bool isUserOp) { - if (AnyUntypedOrMissingOperand()) { +MaybeExpr ArgumentAnalyzer::TryDefinedOp(const char *opr, + parser::MessageFixedText error, bool isUserOp, bool checkForNullPointer) { + if (AnyMissingOperand()) { context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1)); return std::nullopt; } @@ -4790,7 +4805,9 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp( context_.Say( "Operands of %s are not conformable; have rank %d and rank %d"_err_en_US, ToUpperCase(opr), actuals_[0]->Rank(), actuals_[1]->Rank()); - } else if (CheckForNullPointer() && CheckForAssumedRank()) { + } else if (!CheckForAssumedRank()) { + } else if (checkForNullPointer && !CheckForNullPointer()) { + } else { // use the supplied error context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1)); } return result; @@ -4808,15 +4825,16 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp( for (std::size_t i{0}; i < oprs.size(); ++i) { parser::Messages buffer; auto restorer{context_.GetContextualMessages().SetMessages(buffer)}; - if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error)}) { + if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error, /*isUserOp=*/false, + /*checkForNullPointer=*/false)}) { result = std::move(thisResult); hit.push_back(oprs[i]); hitBuffer = std::move(buffer); } } } - if (hit.empty()) { // for the error - result = TryDefinedOp(oprs[0], error); + if (hit.empty()) { // run TryDefinedOp() again just to emit errors + CHECK(!TryDefinedOp(oprs[0], error).has_value()); } else if (hit.size() > 1) { context_.Say( "Matching accessible definitions were found with %zd variant spellings of the generic operator ('%s', '%s')"_err_en_US, @@ -5232,10 +5250,19 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) { } } -bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const { +bool ArgumentAnalyzer::AnyUntypedOperand() const { + for (const auto &actual : actuals_) { + if (actual && !actual->GetType() && + !IsBareNullPointer(actual->UnwrapExpr())) { + return true; + } + } + return false; +} + +bool ArgumentAnalyzer::AnyMissingOperand() const { for (const auto &actual : actuals_) { - if (!actual || - (!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) { + if (!actual) { return true; } } @@ -5268,9 +5295,9 @@ void ExprChecker::Post(const parser::DataStmtObject &obj) { bool ExprChecker::Pre(const parser::DataImpliedDo &ido) { parser::Walk(std::get<parser::DataImpliedDo::Bounds>(ido.t), *this); const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)}; - auto name{bounds.name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)}; int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind}; - if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) { + if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); } diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 8074c94..556259d 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -17,6 +17,7 @@ #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" +#include "llvm/Frontend/OpenMP/OMP.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" @@ -24,6 +25,7 @@ #include <fstream> #include <set> #include <string_view> +#include <type_traits> #include <variant> #include <vector> @@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) { } } +static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) { + using RequiresClauses = WithOmpDeclarative::RequiresClauses; + using OmpMemoryOrderType = common::OmpMemoryOrderType; + + const auto [reqs, order]{common::visit( + [&](auto &&details) + -> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + return {details.ompRequires(), details.ompAtomicDefaultMemOrder()}; + } else { + return {nullptr, nullptr}; + } + }, + symbol.details())}; + + if (order) { + llvm::omp::Clause admo{llvm::omp::Clause::OMPC_atomic_default_mem_order}; + os << "!$omp requires " + << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(admo)) + << '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n"; + } + if (reqs) { + os << "!$omp requires"; + reqs->IterateOverMembers([&](llvm::omp::Clause f) { + if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) { + os << ' ' + << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f)); + } + }); + os << "\n"; + } +} + // Put out the visible symbols from scope. void ModFileWriter::PutSymbols( const Scope &scope, UnorderedSymbolSet *hermeticModules) { @@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols( for (const Symbol &symbol : uses) { PutUse(symbol); } + PutOpenMPRequirements(decls_, DEREF(scope.symbol())); for (const auto &set : scope.equivalenceSets()) { if (!set.empty() && !set.front().symbol.test(Symbol::Flag::CompilerCreated)) { diff --git a/flang/lib/Semantics/openmp-modifiers.cpp b/flang/lib/Semantics/openmp-modifiers.cpp index af4000c..717fb03 100644 --- a/flang/lib/Semantics/openmp-modifiers.cpp +++ b/flang/lib/Semantics/openmp-modifiers.cpp @@ -157,6 +157,22 @@ const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAlwaysModifier>() { } template <> +const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAttachModifier>() { + static const OmpModifierDescriptor desc{ + /*name=*/"attach-modifier", + /*props=*/ + { + {61, {OmpProperty::Unique}}, + }, + /*clauses=*/ + { + {61, {Clause::OMPC_map}}, + }, + }; + return desc; +} + +template <> const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAutomapModifier>() { static const OmpModifierDescriptor desc{ /*name=*/"automap-modifier", diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index a8ec4d6..292e73b 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -13,6 +13,7 @@ #include "flang/Semantics/openmp-utils.h" #include "flang/Common/Fortran-consts.h" +#include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Common/reference.h" #include "flang/Common/visit.h" @@ -59,6 +60,26 @@ const Scope &GetScopingUnit(const Scope &scope) { return *iter; } +const Scope &GetProgramUnit(const Scope &scope) { + const Scope *unit{nullptr}; + for (const Scope *iter{&scope}; !iter->IsTopLevel(); iter = &iter->parent()) { + switch (iter->kind()) { + case Scope::Kind::BlockData: + case Scope::Kind::MainProgram: + case Scope::Kind::Module: + return *iter; + case Scope::Kind::Subprogram: + // Ignore subprograms that are nested. + unit = iter; + break; + default: + break; + } + } + assert(unit && "Scope not in a program unit"); + return *unit; +} + SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) { if (x == nullptr) { return SourcedActionStmt{}; @@ -202,7 +223,7 @@ std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) { // ForwardOwningPointer typedExpr // `- GenericExprWrapper ^.get() // `- std::optional<Expr> ^->v - return typedExpr.get()->v; + return DEREF(typedExpr.get()).v; } std::optional<evaluate::DynamicType> GetDynamicType( diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 18fc638..7067ed3 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -435,6 +435,22 @@ public: return true; } + bool Pre(const parser::UseStmt &x) { + if (x.moduleName.symbol) { + Scope &thisScope{context_.FindScope(x.moduleName.source)}; + common::visit( + [&](auto &&details) { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + AddOmpRequiresToScope(thisScope, details.ompRequires(), + details.ompAtomicDefaultMemOrder()); + } + }, + x.moduleName.symbol->details()); + } + return true; + } + bool Pre(const parser::OmpMetadirectiveDirective &x) { PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective); return true; @@ -538,38 +554,55 @@ public: void Post(const parser::OpenMPFlushConstruct &) { PopContext(); } bool Pre(const parser::OpenMPRequiresConstruct &x) { - using Flags = WithOmpDeclarative::RequiresFlags; - using Requires = WithOmpDeclarative::RequiresFlag; + using RequiresClauses = WithOmpDeclarative::RequiresClauses; PushContext(x.source, llvm::omp::Directive::OMPD_requires); + auto getArgument{[&](auto &&maybeClause) { + if (maybeClause) { + // Scalar<Logical<Constant<common::Indirection<Expr>>>> + auto &parserExpr{maybeClause->v.thing.thing.thing.value()}; + evaluate::ExpressionAnalyzer ea{context_}; + if (auto &&maybeExpr{ea.Analyze(parserExpr)}) { + if (auto v{omp::GetLogicalValue(*maybeExpr)}) { + return *v; + } + } + } + // If the argument is missing, it is assumed to be true. + return true; + }}; + // Gather information from the clauses. - Flags flags; - std::optional<common::OmpMemoryOrderType> memOrder; + RequiresClauses reqs; + const common::OmpMemoryOrderType *memOrder{nullptr}; for (const parser::OmpClause &clause : x.v.Clauses().v) { - flags |= common::visit( + using OmpClause = parser::OmpClause; + reqs |= common::visit( common::visitors{ - [&memOrder]( - const parser::OmpClause::AtomicDefaultMemOrder &atomic) { - memOrder = atomic.v.v; - return Flags{}; - }, - [](const parser::OmpClause::ReverseOffload &) { - return Flags{Requires::ReverseOffload}; + [&](const OmpClause::AtomicDefaultMemOrder &atomic) { + memOrder = &atomic.v.v; + return RequiresClauses{}; }, - [](const parser::OmpClause::UnifiedAddress &) { - return Flags{Requires::UnifiedAddress}; - }, - [](const parser::OmpClause::UnifiedSharedMemory &) { - return Flags{Requires::UnifiedSharedMemory}; - }, - [](const parser::OmpClause::DynamicAllocators &) { - return Flags{Requires::DynamicAllocators}; + [&](auto &&s) { + using TypeS = llvm::remove_cvref_t<decltype(s)>; + if constexpr ( // + std::is_same_v<TypeS, OmpClause::DynamicAllocators> || + std::is_same_v<TypeS, OmpClause::ReverseOffload> || + std::is_same_v<TypeS, OmpClause::SelfMaps> || + std::is_same_v<TypeS, OmpClause::UnifiedAddress> || + std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) { + if (getArgument(s.v)) { + return RequiresClauses{clause.Id()}; + } + } + return RequiresClauses{}; }, - [](const auto &) { return Flags{}; }}, + }, clause.u); } + // Merge clauses into parents' symbols details. - AddOmpRequiresToScope(currScope(), flags, memOrder); + AddOmpRequiresToScope(currScope(), &reqs, memOrder); return true; } void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); } @@ -1001,8 +1034,9 @@ private: std::int64_t ordCollapseLevel{0}; - void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags, - std::optional<common::OmpMemoryOrderType>); + void AddOmpRequiresToScope(Scope &, + const WithOmpDeclarative::RequiresClauses *, + const common::OmpMemoryOrderType *); void IssueNonConformanceWarning(llvm::omp::Directive D, parser::CharBlock source, unsigned EmitFromVersion); @@ -3309,86 +3343,6 @@ void ResolveOmpParts( } } -void ResolveOmpTopLevelParts( - SemanticsContext &context, const parser::Program &program) { - if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { - return; - } - - // Gather REQUIRES clauses from all non-module top-level program unit symbols, - // combine them together ensuring compatibility and apply them to all these - // program units. Modules are skipped because their REQUIRES clauses should be - // propagated via USE statements instead. - WithOmpDeclarative::RequiresFlags combinedFlags; - std::optional<common::OmpMemoryOrderType> combinedMemOrder; - - // Function to go through non-module top level program units and extract - // REQUIRES information to be processed by a function-like argument. - auto processProgramUnits{[&](auto processFn) { - for (const parser::ProgramUnit &unit : program.v) { - if (!std::holds_alternative<common::Indirection<parser::Module>>( - unit.u) && - !std::holds_alternative<common::Indirection<parser::Submodule>>( - unit.u) && - !std::holds_alternative< - common::Indirection<parser::CompilerDirective>>(unit.u)) { - Symbol *symbol{common::visit( - [&context](auto &x) { - Scope *scope = GetScope(context, x.value()); - return scope ? scope->symbol() : nullptr; - }, - unit.u)}; - // FIXME There is no symbol defined for MainProgram units in certain - // circumstances, so REQUIRES information has no place to be stored in - // these cases. - if (!symbol) { - continue; - } - common::visit( - [&](auto &details) { - if constexpr (std::is_convertible_v<decltype(&details), - WithOmpDeclarative *>) { - processFn(*symbol, details); - } - }, - symbol->details()); - } - } - }}; - - // Combine global REQUIRES information from all program units except modules - // and submodules. - processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) { - if (const WithOmpDeclarative::RequiresFlags * - flags{details.ompRequires()}) { - combinedFlags |= *flags; - } - if (const common::OmpMemoryOrderType * - memOrder{details.ompAtomicDefaultMemOrder()}) { - if (combinedMemOrder && *combinedMemOrder != *memOrder) { - context.Say(symbol.scope()->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - combinedMemOrder = *memOrder; - } - }); - - // Update all program units except modules and submodules with the combined - // global REQUIRES information. - processProgramUnits([&](Symbol &, WithOmpDeclarative &details) { - if (combinedFlags.any()) { - details.set_ompRequires(combinedFlags); - } - if (combinedMemOrder) { - details.set_ompAtomicDefaultMemOrder(*combinedMemOrder); - } - }); -} - static bool IsSymbolThreadprivate(const Symbol &symbol) { if (const auto *details{symbol.detailsIf<HostAssocDetails>()}) { return details->symbol().test(Symbol::Flag::OmpThreadprivate); @@ -3547,42 +3501,39 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source, } void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope, - WithOmpDeclarative::RequiresFlags flags, - std::optional<common::OmpMemoryOrderType> memOrder) { - Scope *scopeIter = &scope; - do { - if (Symbol * symbol{scopeIter->symbol()}) { - common::visit( - [&](auto &details) { - // Store clauses information into the symbol for the parent and - // enclosing modules, programs, functions and subroutines. - if constexpr (std::is_convertible_v<decltype(&details), - WithOmpDeclarative *>) { - if (flags.any()) { - if (const WithOmpDeclarative::RequiresFlags * - otherFlags{details.ompRequires()}) { - flags |= *otherFlags; - } - details.set_ompRequires(flags); + const WithOmpDeclarative::RequiresClauses *reqs, + const common::OmpMemoryOrderType *memOrder) { + const Scope &programUnit{omp::GetProgramUnit(scope)}; + using RequiresClauses = WithOmpDeclarative::RequiresClauses; + RequiresClauses combinedReqs{reqs ? *reqs : RequiresClauses{}}; + + if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) { + common::visit( + [&](auto &details) { + if constexpr (std::is_convertible_v<decltype(&details), + WithOmpDeclarative *>) { + if (combinedReqs.any()) { + if (const RequiresClauses *otherReqs{details.ompRequires()}) { + combinedReqs |= *otherReqs; } - if (memOrder) { - if (details.has_ompAtomicDefaultMemOrder() && - *details.ompAtomicDefaultMemOrder() != *memOrder) { - context_.Say(scopeIter->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - details.set_ompAtomicDefaultMemOrder(*memOrder); + details.set_ompRequires(combinedReqs); + } + if (memOrder) { + if (details.has_ompAtomicDefaultMemOrder() && + *details.ompAtomicDefaultMemOrder() != *memOrder) { + context_.Say(programUnit.sourceRange(), + "Conflicting '%s' REQUIRES clauses found in compilation " + "unit"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order) + .str())); } + details.set_ompAtomicDefaultMemOrder(*memOrder); } - }, - symbol->details()); - } - scopeIter = &scopeIter->parent(); - } while (!scopeIter->IsGlobal()); + } + }, + symbol->details()); + } } void OmpAttributeVisitor::IssueNonConformanceWarning(llvm::omp::Directive D, diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h index 5a890c2..36d3ce9 100644 --- a/flang/lib/Semantics/resolve-directives.h +++ b/flang/lib/Semantics/resolve-directives.h @@ -23,7 +23,5 @@ class SemanticsContext; void ResolveAccParts( SemanticsContext &, const parser::ProgramUnit &, Scope *topScope); void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &); -void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &); - } // namespace Fortran::semantics #endif diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp index 742bb74..ac67799 100644 --- a/flang/lib/Semantics/resolve-names-utils.cpp +++ b/flang/lib/Semantics/resolve-names-utils.cpp @@ -492,12 +492,14 @@ bool EquivalenceSets::CheckDesignator(const parser::Designator &designator) { const auto &range{std::get<parser::SubstringRange>(x.t)}; bool ok{CheckDataRef(designator.source, dataRef)}; if (const auto &lb{std::get<0>(range.t)}) { - ok &= CheckSubstringBound(lb->thing.thing.value(), true); + ok &= CheckSubstringBound( + parser::UnwrapRef<parser::Expr>(lb), true); } else { currObject_.substringStart = 1; } if (const auto &ub{std::get<1>(range.t)}) { - ok &= CheckSubstringBound(ub->thing.thing.value(), false); + ok &= CheckSubstringBound( + parser::UnwrapRef<parser::Expr>(ub), false); } return ok; }, @@ -528,7 +530,8 @@ bool EquivalenceSets::CheckDataRef( return false; }, [&](const parser::IntExpr &y) { - return CheckArrayBound(y.thing.value()); + return CheckArrayBound( + parser::UnwrapRef<parser::Expr>(y)); }, }, subscript.u); diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 86121880..699de41 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1140,7 +1140,7 @@ protected: std::optional<SourceName> BeginCheckOnIndexUseInOwnBounds( const parser::DoVariable &name) { std::optional<SourceName> result{checkIndexUseInOwnBounds_}; - checkIndexUseInOwnBounds_ = name.thing.thing.source; + checkIndexUseInOwnBounds_ = parser::UnwrapRef<parser::Name>(name).source; return result; } void EndCheckOnIndexUseInOwnBounds(const std::optional<SourceName> &restore) { @@ -2130,7 +2130,7 @@ public: void Post(const parser::SubstringInquiry &); template <typename A, typename B> void Post(const parser::LoopBounds<A, B> &x) { - ResolveName(*parser::Unwrap<parser::Name>(x.name)); + ResolveName(parser::UnwrapRef<parser::Name>(x.name)); } void Post(const parser::ProcComponentRef &); bool Pre(const parser::FunctionReference &); @@ -2560,7 +2560,7 @@ KindExpr DeclTypeSpecVisitor::GetKindParamExpr( CHECK(!state_.originalKindParameter); // Save a pointer to the KIND= expression in the parse tree // in case we need to reanalyze it during PDT instantiation. - state_.originalKindParameter = &expr->thing.thing.thing.value(); + state_.originalKindParameter = parser::Unwrap<parser::Expr>(expr); } } // Inhibit some errors now that will be caught later during instantiations. @@ -5649,6 +5649,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) { if (details->init() || symbol.test(Symbol::Flag::InDataStmt)) { Say(name, "Named constant '%s' already has a value"_err_en_US); } + parser::CharBlock at{parser::UnwrapRef<parser::Expr>(expr).source}; if (inOldStyleParameterStmt_) { // non-standard extension PARAMETER statement (no parentheses) Walk(expr); @@ -5657,7 +5658,6 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) { SayWithDecl(name, symbol, "Alternative style PARAMETER '%s' must not already have an explicit type"_err_en_US); } else if (folded) { - auto at{expr.thing.value().source}; if (evaluate::IsActuallyConstant(*folded)) { if (const auto *type{currScope().GetType(*folded)}) { if (type->IsPolymorphic()) { @@ -5682,8 +5682,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) { // standard-conforming PARAMETER statement (with parentheses) ApplyImplicitRules(symbol); Walk(expr); - if (auto converted{EvaluateNonPointerInitializer( - symbol, expr, expr.thing.value().source)}) { + if (auto converted{EvaluateNonPointerInitializer(symbol, expr, at)}) { details->set_init(std::move(*converted)); } } @@ -6149,7 +6148,7 @@ bool DeclarationVisitor::Pre(const parser::KindParam &x) { if (const auto *kind{std::get_if< parser::Scalar<parser::Integer<parser::Constant<parser::Name>>>>( &x.u)}) { - const parser::Name &name{kind->thing.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>(kind)}; if (!FindSymbol(name)) { Say(name, "Parameter '%s' not found"_err_en_US); } @@ -7460,7 +7459,7 @@ void DeclarationVisitor::DeclareLocalEntity( Symbol *DeclarationVisitor::DeclareStatementEntity( const parser::DoVariable &doVar, const std::optional<parser::IntegerTypeSpec> &type) { - const parser::Name &name{doVar.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>(doVar)}; const DeclTypeSpec *declTypeSpec{nullptr}; if (auto *prev{FindSymbol(name)}) { if (prev->owner() == currScope()) { @@ -7893,13 +7892,14 @@ bool ConstructVisitor::Pre(const parser::DataIDoObject &x) { common::visit( common::visitors{ [&](const parser::Scalar<Indirection<parser::Designator>> &y) { - Walk(y.thing.value()); - const parser::Name &first{parser::GetFirstName(y.thing.value())}; + const auto &designator{parser::UnwrapRef<parser::Designator>(y)}; + Walk(designator); + const parser::Name &first{parser::GetFirstName(designator)}; if (first.symbol) { first.symbol->set(Symbol::Flag::InDataStmt); } }, - [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y.value()); }, + [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y); }, }, x.u); return false; @@ -8582,8 +8582,7 @@ public: void Post(const parser::WriteStmt &) { inAsyncIO_ = false; } void Post(const parser::IoControlSpec::Size &size) { if (const auto *designator{ - std::get_if<common::Indirection<parser::Designator>>( - &size.v.thing.thing.u)}) { + parser::Unwrap<common::Indirection<parser::Designator>>(size)}) { NoteAsyncIODesignator(designator->value()); } } @@ -9175,16 +9174,17 @@ bool DeclarationVisitor::CheckNonPointerInitialization( } void DeclarationVisitor::NonPointerInitialization( - const parser::Name &name, const parser::ConstantExpr &expr) { + const parser::Name &name, const parser::ConstantExpr &constExpr) { if (CheckNonPointerInitialization( name, /*inLegacyDataInitialization=*/false)) { Symbol &ultimate{name.symbol->GetUltimate()}; auto &details{ultimate.get<ObjectEntityDetails>()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(constExpr)}; if (ultimate.owner().IsParameterizedDerivedType()) { // Save the expression for per-instantiation analysis. - details.set_unanalyzedPDTComponentInit(&expr.thing.value()); + details.set_unanalyzedPDTComponentInit(&expr); } else if (MaybeExpr folded{EvaluateNonPointerInitializer( - ultimate, expr, expr.thing.value().source)}) { + ultimate, constExpr, expr.source)}) { details.set_init(std::move(*folded)); ultimate.set(Symbol::Flag::InDataStmt, false); } @@ -10687,9 +10687,6 @@ void ResolveNamesVisitor::Post(const parser::Program &x) { CHECK(!attrs_); CHECK(!cudaDataAttr_); CHECK(!GetDeclTypeSpec()); - // Top-level resolution to propagate information across program units after - // each of them has been resolved separately. - ResolveOmpTopLevelParts(context(), x); } // A singleton instance of the scope -> IMPLICIT rules mapping is diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp index 69169469..0ec44b7 100644 --- a/flang/lib/Semantics/symbol.cpp +++ b/flang/lib/Semantics/symbol.cpp @@ -70,6 +70,32 @@ static void DumpList(llvm::raw_ostream &os, const char *label, const T &list) { } } +llvm::raw_ostream &operator<<( + llvm::raw_ostream &os, const WithOmpDeclarative &x) { + if (x.has_ompRequires() || x.has_ompAtomicDefaultMemOrder()) { + os << " OmpRequirements:("; + if (const common::OmpMemoryOrderType *admo{x.ompAtomicDefaultMemOrder()}) { + os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order)) + << '(' << parser::ToLowerCaseLetters(EnumToString(*admo)) << ')'; + if (x.has_ompRequires()) { + os << ','; + } + } + if (const WithOmpDeclarative::RequiresClauses *reqs{x.ompRequires()}) { + size_t num{0}, size{reqs->count()}; + reqs->IterateOverMembers([&](llvm::omp::Clause f) { + os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f)); + if (++num < size) { + os << ','; + } + }); + } + os << ')'; + } + return os; +} + void SubprogramDetails::set_moduleInterface(Symbol &symbol) { CHECK(!moduleInterface_); moduleInterface_ = &symbol; @@ -150,6 +176,7 @@ llvm::raw_ostream &operator<<( os << x; } } + os << static_cast<const WithOmpDeclarative &>(x); return os; } @@ -580,7 +607,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) { common::visit( // common::visitors{ [&](const UnknownDetails &) {}, - [&](const MainProgramDetails &) {}, + [&](const MainProgramDetails &x) { + os << static_cast<const WithOmpDeclarative &>(x); + }, [&](const ModuleDetails &x) { if (x.isSubmodule()) { os << " ("; @@ -599,6 +628,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) { if (x.isDefaultPrivate()) { os << " isDefaultPrivate"; } + os << static_cast<const WithOmpDeclarative &>(x); }, [&](const SubprogramNameDetails &x) { os << ' ' << EnumToString(x.kind()); |