diff options
Diffstat (limited to 'flang/lib')
23 files changed, 650 insertions, 180 deletions
diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp index f204eef..1de5e6b 100644 --- a/flang/lib/Evaluate/intrinsics.cpp +++ b/flang/lib/Evaluate/intrinsics.cpp @@ -111,6 +111,7 @@ ENUM_CLASS(KindCode, none, defaultIntegerKind, atomicIntKind, // atomic_int_kind from iso_fortran_env atomicIntOrLogicalKind, // atomic_int_kind or atomic_logical_kind sameAtom, // same type and kind as atom + extensibleOrUnlimitedType, // extensible or unlimited polymorphic type ) struct TypePattern { @@ -160,7 +161,8 @@ static constexpr TypePattern AnyChar{CharType, KindCode::any}; static constexpr TypePattern AnyLogical{LogicalType, KindCode::any}; static constexpr TypePattern AnyRelatable{RelatableType, KindCode::any}; static constexpr TypePattern AnyIntrinsic{IntrinsicType, KindCode::any}; -static constexpr TypePattern ExtensibleDerived{DerivedType, KindCode::any}; +static constexpr TypePattern ExtensibleDerived{ + DerivedType, KindCode::extensibleOrUnlimitedType}; static constexpr TypePattern AnyData{AnyType, KindCode::any}; // Type is irrelevant, but not BOZ (for PRESENT(), OPTIONAL(), &c.) @@ -2103,9 +2105,13 @@ std::optional<SpecificCall> IntrinsicInterface::Match( } return std::nullopt; } else if (!d.typePattern.categorySet.test(type->category())) { + const char *expected{ + d.typePattern.kindCode == KindCode::extensibleOrUnlimitedType + ? ", expected extensible or unlimited polymorphic type" + : ""}; messages.Say(arg->sourceLocation(), - "Actual argument for '%s=' has bad type '%s'"_err_en_US, d.keyword, - type->AsFortran()); + "Actual argument for '%s=' has bad type '%s'%s"_err_en_US, d.keyword, + type->AsFortran(), expected); return std::nullopt; // argument has invalid type category } bool argOk{false}; @@ -2244,6 +2250,17 @@ std::optional<SpecificCall> IntrinsicInterface::Match( return std::nullopt; } break; + case KindCode::extensibleOrUnlimitedType: + argOk = type->IsUnlimitedPolymorphic() || + (type->category() == TypeCategory::Derived && + IsExtensibleType(GetDerivedTypeSpec(type))); + if (!argOk) { + messages.Say(arg->sourceLocation(), + "Actual argument for '%s=' has type '%s', but was expected to be an extensible or unlimited polymorphic type"_err_en_US, + d.keyword, type->AsFortran()); + return std::nullopt; + } + break; default: CRASH_NO_CASE; } diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index b927fa3..bd06acc 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -1153,6 +1153,18 @@ bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) { return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0; } +bool IsCUDADeviceSymbol(const Symbol &sym) { + if (const auto *details = + sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) { + return details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Pinned; + } else if (const auto *details = + sym.GetUltimate().detailsIf<semantics::AssocEntityDetails>()) { + return GetNbOfCUDADeviceSymbols(details->expr()) > 0; + } + return false; +} + // HasVectorSubscript() struct HasVectorSubscriptHelper : public AnyTraverse<HasVectorSubscriptHelper, bool, diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp index 53239cb..e7a6c4d 100644 --- a/flang/lib/Lower/Allocatable.cpp +++ b/flang/lib/Lower/Allocatable.cpp @@ -629,6 +629,10 @@ private: unsigned allocatorIdx = Fortran::lower::getAllocatorIdx(alloc.getSymbol()); fir::ExtendedValue exv = isSource ? sourceExv : moldExv; + if (const Fortran::semantics::Symbol *sym{GetLastSymbol(sourceExpr)}) + if (Fortran::semantics::IsCUDADevice(*sym)) + TODO(loc, "CUDA Fortran: allocate with device source"); + // Generate a sequence of runtime calls. errorManager.genStatCheck(builder, loc); genAllocateObjectInit(box, allocatorIdx); @@ -767,6 +771,15 @@ private: const fir::MutableBoxValue &box, ErrorManager &errorManager, const Fortran::semantics::Symbol &sym) { + + if (const Fortran::semantics::DeclTypeSpec *declTypeSpec = sym.GetType()) + if (const Fortran::semantics::DerivedTypeSpec *derivedTypeSpec = + declTypeSpec->AsDerived()) + if (derivedTypeSpec->HasDefaultInitialization( + /*ignoreAllocatable=*/true, /*ignorePointer=*/true)) + TODO(loc, + "CUDA Fortran: allocate on device with default initialization"); + Fortran::lower::StatementContext stmtCtx; cuf::DataAttributeAttr cudaAttr = Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(), diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 55eda7e..85398be 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1343,8 +1343,10 @@ bool ClauseProcessor::processMap( const parser::CharBlock &source) { using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); - const auto &[mapType, typeMods, refMod, mappers, iterator, objects] = - clause.t; + const auto &[mapType, typeMods, attachMod, refMod, mappers, iterator, + objects] = clause.t; + if (attachMod) + TODO(currentLocation, "ATTACH modifier is not implemented yet"); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; std::string mapperIdName = "__implicit_mapper"; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index fac37a3..0842c62 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -1069,6 +1069,15 @@ Map make(const parser::OmpClause::Map &inp, ); CLAUSET_ENUM_CONVERT( // + convertAttachMod, parser::OmpAttachModifier::Value, Map::AttachModifier, + // clang-format off + MS(Always, Always) + MS(Auto, Auto) + MS(Never, Never) + // clang-format on + ); + + CLAUSET_ENUM_CONVERT( // convertRefMod, parser::OmpRefModifier::Value, Map::RefModifier, // clang-format off MS(Ref_Ptee, RefPtee) @@ -1115,6 +1124,13 @@ Map make(const parser::OmpClause::Map &inp, if (!modSet.empty()) maybeTypeMods = Map::MapTypeModifiers(modSet.begin(), modSet.end()); + auto attachMod = [&]() -> std::optional<Map::AttachModifier> { + if (auto *t = + semantics::OmpGetUniqueModifier<parser::OmpAttachModifier>(mods)) + return convertAttachMod(t->v); + return std::nullopt; + }(); + auto refMod = [&]() -> std::optional<Map::RefModifier> { if (auto *t = semantics::OmpGetUniqueModifier<parser::OmpRefModifier>(mods)) return convertRefMod(t->v); @@ -1135,6 +1151,7 @@ Map make(const parser::OmpClause::Map &inp, return Map{{/*MapType=*/std::move(type), /*MapTypeModifiers=*/std::move(maybeTypeMods), + /*AttachModifier=*/std::move(attachMod), /*RefModifier=*/std::move(refMod), /*Mapper=*/std::move(mappers), /*Iterator=*/std::move(iterator), /*LocatorList=*/makeObjects(t2, semaCtx)}}; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 444f274..f86ee01 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -4208,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions( void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, const semantics::Symbol *symbol) { using MlirRequires = mlir::omp::ClauseRequires; - using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag; if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) { - semantics::WithOmpDeclarative::RequiresFlags semaFlags; + semantics::WithOmpDeclarative::RequiresClauses reqs; if (symbol) { common::visit( [&](const auto &details) { if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative, std::decay_t<decltype(details)>>) { if (details.has_ompRequires()) - semaFlags = *details.ompRequires(); + reqs = *details.ompRequires(); } }, symbol->details()); @@ -4228,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, // Use pre-populated omp.requires module attribute if it was set, so that // the "-fopenmp-force-usm" compiler option is honored. MlirRequires mlirFlags = offloadMod.getRequires(); - if (semaFlags.test(SemaRequires::ReverseOffload)) + if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators)) + mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; + if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload)) mlirFlags = mlirFlags | MlirRequires::reverse_offload; - if (semaFlags.test(SemaRequires::UnifiedAddress)) + if (reqs.test(llvm::omp::Clause::OMPC_unified_address)) mlirFlags = mlirFlags | MlirRequires::unified_address; - if (semaFlags.test(SemaRequires::UnifiedSharedMemory)) + if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory)) mlirFlags = mlirFlags | MlirRequires::unified_shared_memory; - if (semaFlags.test(SemaRequires::DynamicAllocators)) - mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; offloadMod.setRequires(mlirFlags); } diff --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt index fc23e64..790b9fd 100644 --- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Support) +add_subdirectory(Transforms) diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp new file mode 100644 index 0000000..4840a99 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp @@ -0,0 +1,191 @@ +//===- ACCRecipeBufferization.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Bufferize OpenACC recipes that yield fir.box<T> to operate on +// fir.ref<fir.box<T>> and update uses accordingly. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/OpenACC/Passes.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace fir::acc { +#define GEN_PASS_DEF_ACCRECIPEBUFFERIZATION +#include "flang/Optimizer/OpenACC/Passes.h.inc" +} // namespace fir::acc + +namespace { + +class BufferizeInterface { +public: + static std::optional<mlir::Type> mustBufferize(mlir::Type recipeType) { + if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(recipeType)) + return fir::ReferenceType::get(boxTy); + return std::nullopt; + } + + static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value value) { + return builder.create<fir::LoadOp>(loc, value); + } + + static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value value) { + auto alloca = builder.create<fir::AllocaOp>(loc, value.getType()); + builder.create<fir::StoreOp>(loc, value, alloca); + return alloca; + } +}; + +static void bufferizeRegionArgsAndYields(mlir::Region ®ion, + mlir::Location loc, mlir::Type oldType, + mlir::Type newType) { + if (region.empty()) + return; + + mlir::OpBuilder builder(®ion); + for (mlir::BlockArgument arg : region.getArguments()) { + if (arg.getType() == oldType) { + arg.setType(newType); + if (!arg.use_empty()) { + mlir::Operation *loadOp = BufferizeInterface::load(builder, loc, arg); + arg.replaceAllUsesExcept(loadOp->getResult(0), loadOp); + } + } + } + if (auto yield = + llvm::dyn_cast<mlir::acc::YieldOp>(region.back().getTerminator())) { + llvm::SmallVector<mlir::Value> newOperands; + newOperands.reserve(yield.getNumOperands()); + bool changed = false; + for (mlir::Value oldYieldArg : yield.getOperands()) { + if (oldYieldArg.getType() == oldType) { + builder.setInsertionPoint(yield); + mlir::Value alloca = + BufferizeInterface::placeInMemory(builder, loc, oldYieldArg); + newOperands.push_back(alloca); + changed = true; + } else { + newOperands.push_back(oldYieldArg); + } + } + if (changed) + yield->setOperands(newOperands); + } +} + +static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands, + llvm::StringRef recipeSymName, + mlir::Operation *computeOp) { + if (!recipes) + return; + for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) { + if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() != + recipeSymName) + continue; + + mlir::Operation *dataOp = oldRes.getDefiningOp(); + assert(dataOp && "dataOp must be paired with computeOp"); + mlir::Location loc = dataOp->getLoc(); + mlir::OpBuilder builder(dataOp); + llvm::TypeSwitch<mlir::Operation *, void>(dataOp) + .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp, + mlir::acc::ReductionOp>([&](auto privateOp) { + builder.setInsertionPointAfterValue(privateOp.getVar()); + mlir::Value alloca = BufferizeInterface::placeInMemory( + builder, loc, privateOp.getVar()); + privateOp.getVarMutable().assign(alloca); + privateOp.getAccVar().setType(alloca.getType()); + }); + + llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(), + oldRes.getUsers().end()); + for (mlir::Operation *useOp : users) { + if (useOp == computeOp) + continue; + builder.setInsertionPoint(useOp); + mlir::Operation *load = BufferizeInterface::load(builder, loc, oldRes); + useOp->replaceUsesOfWith(oldRes, load->getResult(0)); + } + } +} + +class ACCRecipeBufferization + : public fir::acc::impl::ACCRecipeBufferizationBase< + ACCRecipeBufferization> { +public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + + llvm::SmallVector<llvm::StringRef> recipeNames; + module.walk([&](mlir::Operation *recipe) { + llvm::TypeSwitch<mlir::Operation *, void>(recipe) + .Case<mlir::acc::PrivateRecipeOp, mlir::acc::FirstprivateRecipeOp, + mlir::acc::ReductionRecipeOp>([&](auto recipe) { + mlir::Type oldType = recipe.getType(); + auto bufferizedType = + BufferizeInterface::mustBufferize(recipe.getType()); + if (!bufferizedType) + return; + recipe.setTypeAttr(mlir::TypeAttr::get(*bufferizedType)); + mlir::Location loc = recipe.getLoc(); + using RecipeOp = decltype(recipe); + bufferizeRegionArgsAndYields(recipe.getInitRegion(), loc, oldType, + *bufferizedType); + if constexpr (std::is_same_v<RecipeOp, + mlir::acc::FirstprivateRecipeOp>) + bufferizeRegionArgsAndYields(recipe.getCopyRegion(), loc, oldType, + *bufferizedType); + if constexpr (std::is_same_v<RecipeOp, + mlir::acc::ReductionRecipeOp>) + bufferizeRegionArgsAndYields(recipe.getCombinerRegion(), loc, + oldType, *bufferizedType); + bufferizeRegionArgsAndYields(recipe.getDestroyRegion(), loc, + oldType, *bufferizedType); + recipeNames.push_back(recipe.getSymName()); + }); + }); + if (recipeNames.empty()) + return; + + module.walk([&](mlir::Operation *op) { + llvm::TypeSwitch<mlir::Operation *, void>(op) + .Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>( + [&](auto computeOp) { + for (llvm::StringRef recipeName : recipeNames) { + if (computeOp.getPrivatizationRecipes()) + updateRecipeUse(computeOp.getPrivatizationRecipesAttr(), + computeOp.getPrivateOperands(), recipeName, + op); + if (computeOp.getFirstprivatizationRecipes()) + updateRecipeUse( + computeOp.getFirstprivatizationRecipesAttr(), + computeOp.getFirstprivateOperands(), recipeName, op); + if (computeOp.getReductionRecipes()) + updateRecipeUse(computeOp.getReductionRecipesAttr(), + computeOp.getReductionOperands(), + recipeName, op); + } + }); + }); + } +}; + +} // namespace + +std::unique_ptr<mlir::Pass> fir::acc::createACCRecipeBufferizationPass() { + return std::make_unique<ACCRecipeBufferization>(); +} diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt new file mode 100644 index 0000000..2427da0 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_flang_library(FIROpenACCTransforms + ACCRecipeBufferization.cpp + + DEPENDS + FIROpenACCPassesIncGen + + LINK_LIBS + MLIRIR + MLIRPass + FIRDialect + MLIROpenACCDialect +) diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 260e525..2bbd803 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -40,6 +40,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" @@ -128,6 +129,17 @@ class MapInfoFinalizationPass } } + /// Return true if the module has an OpenMP requires clause that includes + /// unified_shared_memory. + static bool moduleRequiresUSM(mlir::ModuleOp module) { + assert(module && "invalid module"); + if (auto req = module->getAttrOfType<mlir::omp::ClauseRequiresAttr>( + "omp.requires")) + return mlir::omp::bitEnumContainsAll( + req.getValue(), mlir::omp::ClauseRequires::unified_shared_memory); + return false; + } + /// Create the member map for coordRef and append it (and its index /// path) to the provided new* vectors, if it is not already present. void appendMemberMapIfNew( @@ -425,8 +437,12 @@ class MapInfoFinalizationPass mapFlags flags = mapFlags::OMP_MAP_TO | (mapFlags(mapTypeFlag) & - (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE | - mapFlags::OMP_MAP_ALWAYS)); + (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS)); + // For unified_shared_memory, we additionally add `CLOSE` on the descriptor + // to ensure device-local placement where required by tests relying on USM + + // close semantics. + if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>())) + flags |= mapFlags::OMP_MAP_CLOSE; return llvm::to_underlying(flags); } @@ -518,6 +534,75 @@ class MapInfoFinalizationPass return newMapInfoOp; } + // Expand mappings of type(C_PTR) to map their `__address` field explicitly + // as a single pointer-sized member (USM-gated at callsite). This helps in + // USM scenarios to ensure the pointer-sized mapping is used. + mlir::omp::MapInfoOp genCptrMemberMap(mlir::omp::MapInfoOp op, + fir::FirOpBuilder &builder) { + if (!op.getMembers().empty()) + return op; + + mlir::Type varTy = fir::unwrapRefType(op.getVarPtr().getType()); + if (!mlir::isa<fir::RecordType>(varTy)) + return op; + auto recTy = mlir::cast<fir::RecordType>(varTy); + // If not a builtin C_PTR record, skip. + if (!recTy.getName().ends_with("__builtin_c_ptr")) + return op; + + // Find the index of the c_ptr address component named "__address". + int32_t fieldIdx = recTy.getFieldIndex("__address"); + if (fieldIdx < 0) + return op; + + mlir::Location loc = op.getVarPtr().getLoc(); + mlir::Type memTy = recTy.getType(fieldIdx); + fir::IntOrValue idxConst = + mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); + mlir::Value coord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(memTy), op.getVarPtr(), + llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); + + // Child for the `__address` member. + llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}}; + mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); + // Force CLOSE in USM paths so the pointer gets device-local placement + // when required by tests relying on USM + close semantics. + uint64_t mapTypeVal = + op.getMapType() | + llvm::to_underlying( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); + mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr( + builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal); + + mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create( + builder, loc, coord.getType(), coord, + mlir::TypeAttr::get(fir::unwrapRefType(coord.getType())), mapTypeAttr, + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( + mlir::omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/mlir::Value{}, + /*members=*/llvm::SmallVector<mlir::Value>{}, + /*member_index=*/mlir::ArrayAttr{}, + /*bounds=*/op.getBounds(), + /*mapperId=*/mlir::FlatSymbolRefAttr(), + /*name=*/op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); + + // Rebuild the parent as a container with the `__address` member. + mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create( + builder, op.getLoc(), op.getResult().getType(), op.getVarPtr(), + op.getVarTypeAttr(), mapTypeAttr, op.getMapCaptureTypeAttr(), + /*varPtrPtr=*/mlir::Value{}, + /*members=*/llvm::SmallVector<mlir::Value>{memberMap}, + /*member_index=*/newMembersAttr, + /*bounds=*/llvm::SmallVector<mlir::Value>{}, + /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); + op.replaceAllUsesWith(newParent.getResult()); + op->erase(); + return newParent; + } + mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Operation *target) { @@ -1169,6 +1254,17 @@ class MapInfoFinalizationPass genBoxcharMemberMap(op, builder); }); + // Expand type(C_PTR) only when unified_shared_memory is required, + // to ensure device-visible pointer size/behavior in USM scenarios + // without changing default expectations elsewhere. + func->walk([&](mlir::omp::MapInfoOp op) { + // Only expand C_PTR members when unified_shared_memory is required. + if (!moduleRequiresUSM(func->getParentOfType<mlir::ModuleOp>())) + return; + builder.setInsertionPoint(op); + genCptrMemberMap(op, builder); + }); + func->walk([&](mlir::omp::MapInfoOp op) { // TODO: Currently only supports a single user for the MapInfoOp. This // is fine for the moment, as the Fortran frontend will generate a diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 9507021..b5771eb 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -548,6 +548,14 @@ TYPE_PARSER(construct<OmpAllocatorSimpleModifier>(scalarIntExpr)) TYPE_PARSER(construct<OmpAlwaysModifier>( // "ALWAYS" >> pure(OmpAlwaysModifier::Value::Always))) +TYPE_PARSER(construct<OmpAttachModifier::Value>( + "ALWAYS" >> pure(OmpAttachModifier::Value::Always) || + "AUTO" >> pure(OmpAttachModifier::Value::Auto) || + "NEVER" >> pure(OmpAttachModifier::Value::Never))) + +TYPE_PARSER(construct<OmpAttachModifier>( // + "ATTACH" >> parenthesized(Parser<OmpAttachModifier::Value>{}))) + TYPE_PARSER(construct<OmpAutomapModifier>( "AUTOMAP" >> pure(OmpAutomapModifier::Value::Automap))) @@ -744,6 +752,7 @@ TYPE_PARSER(sourced( TYPE_PARSER(sourced(construct<OmpMapClause::Modifier>( sourced(construct<OmpMapClause::Modifier>(Parser<OmpAlwaysModifier>{}) || + construct<OmpMapClause::Modifier>(Parser<OmpAttachModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpCloseModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpDeleteModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpPresentModifier>{}) || diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 0511f5b..b172e429 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2384,6 +2384,11 @@ public: Walk(x.v); Put(")"); } + void Unparse(const OmpAttachModifier &x) { + Word("ATTACH("); + Walk(x.v); + Put(")"); + } void Unparse(const OmpOrderClause &x) { using Modifier = OmpOrderClause::Modifier; Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ":"); @@ -2820,6 +2825,7 @@ public: WALK_NESTED_ENUM(OmpMapType, Value) // OMP map-type WALK_NESTED_ENUM(OmpMapTypeModifier, Value) // OMP map-type-modifier WALK_NESTED_ENUM(OmpAlwaysModifier, Value) + WALK_NESTED_ENUM(OmpAttachModifier, Value) WALK_NESTED_ENUM(OmpCloseModifier, Value) WALK_NESTED_ENUM(OmpDeleteModifier, Value) WALK_NESTED_ENUM(OmpPresentModifier, Value) diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index ea5e2c0..31e246c 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -3622,6 +3622,7 @@ void CheckHelper::CheckDioDtvArg(const Symbol &proc, const Symbol &subp, ioKind == common::DefinedIo::ReadUnformatted ? Attr::INTENT_INOUT : Attr::INTENT_IN); + CheckDioDummyIsScalar(subp, *arg); } } @@ -3687,6 +3688,7 @@ void CheckHelper::CheckDioAssumedLenCharacterArg(const Symbol &subp, "Dummy argument '%s' of a defined input/output procedure must be assumed-length CHARACTER of default kind"_err_en_US, arg->name()); } + CheckDioDummyIsScalar(subp, *arg); } } diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index 351af5c..515121a 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -519,8 +519,8 @@ private: /// function references with scalar data pointer result of non-character /// intrinsic type or variables that are non-polymorphic scalar pointers /// and any length type parameter must be constant. -void OmpStructureChecker::CheckAtomicType( - SymbolRef sym, parser::CharBlock source, std::string_view name) { +void OmpStructureChecker::CheckAtomicType(SymbolRef sym, + parser::CharBlock source, std::string_view name, bool checkTypeOnPointer) { const DeclTypeSpec *typeSpec{sym->GetType()}; if (!typeSpec) { return; @@ -547,6 +547,22 @@ void OmpStructureChecker::CheckAtomicType( return; } + // Apply pointer-to-non-intrinsic rule only for intrinsic-assignment paths. + if (checkTypeOnPointer) { + using Category = DeclTypeSpec::Category; + Category cat{typeSpec->category()}; + if (cat != Category::Numeric && cat != Category::Logical) { + std::string details = " has the POINTER attribute"; + if (const auto *derived{typeSpec->AsDerived()}) { + details += " and derived type '"s + derived->name().ToString() + "'"; + } + context_.Say(source, + "ATOMIC operation requires an intrinsic scalar variable; '%s'%s"_err_en_US, + sym->name(), details); + return; + } + } + // Go over all length parameters, if any, and check if they are // explicit. if (const DerivedTypeSpec *derived{typeSpec->AsDerived()}) { @@ -562,7 +578,7 @@ void OmpStructureChecker::CheckAtomicType( } void OmpStructureChecker::CheckAtomicVariable( - const SomeExpr &atom, parser::CharBlock source) { + const SomeExpr &atom, parser::CharBlock source, bool checkTypeOnPointer) { if (atom.Rank() != 0) { context_.Say(source, "Atomic variable %s should be a scalar"_err_en_US, atom.AsFortran()); @@ -572,7 +588,7 @@ void OmpStructureChecker::CheckAtomicVariable( assert(dsgs.size() == 1 && "Should have a single top-level designator"); evaluate::SymbolVector syms{evaluate::GetSymbolVector(dsgs.front())}; - CheckAtomicType(syms.back(), source, atom.AsFortran()); + CheckAtomicType(syms.back(), source, atom.AsFortran(), checkTypeOnPointer); if (IsAllocatable(syms.back()) && !IsArrayElement(atom)) { context_.Say(source, "Atomic variable %s cannot be ALLOCATABLE"_err_en_US, @@ -789,7 +805,8 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); } else { - CheckAtomicVariable(atom, rsrc); + CheckAtomicVariable( + atom, rsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(capture)); // This part should have been checked prior to calling this function. assert(*GetConvertInput(capture.rhs) == atom && "This cannot be a capture assignment"); @@ -808,7 +825,8 @@ void OmpStructureChecker::CheckAtomicReadAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); } else { - CheckAtomicVariable(atom, rsrc); + CheckAtomicVariable( + atom, rsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(read)); CheckStorageOverlap(atom, {read.lhs}, source); } } else { @@ -829,7 +847,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); } else { - CheckAtomicVariable(atom, lsrc); + CheckAtomicVariable( + atom, lsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(write)); CheckStorageOverlap(atom, {write.rhs}, source); } } @@ -854,7 +873,8 @@ OmpStructureChecker::CheckAtomicUpdateAssignment( return std::nullopt; } - CheckAtomicVariable(atom, lsrc); + CheckAtomicVariable( + atom, lsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(update)); auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs( atom, update.rhs, source, /*suppressDiagnostics=*/true)}; @@ -1017,7 +1037,8 @@ void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment( return; } - CheckAtomicVariable(atom, alsrc); + CheckAtomicVariable( + atom, alsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(assign)); auto top{GetTopLevelOperationIgnoreResizing(cond)}; // Missing arguments to operations would have been diagnosed by now. diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index d65a89e..4b5610a 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -3017,8 +3017,8 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) { &objs, std::string clause) { for (const auto &obj : objs.v) { - if (const parser::Name * - objName{parser::Unwrap<parser::Name>(obj)}) { + if (const parser::Name *objName{ + parser::Unwrap<parser::Name>(obj)}) { if (&objName->symbol->GetUltimate() == eventHandleSym) { context_.Say(GetContext().clauseSource, "A variable: `%s` that appears in a DETACH clause cannot appear on %s clause on the same construct"_err_en_US, @@ -3637,7 +3637,8 @@ void OmpStructureChecker::CheckReductionModifier( if (modifier.v == ReductionModifier::Value::Task) { // "Task" is only allowed on worksharing or "parallel" directive. static llvm::omp::Directive worksharing[]{ - llvm::omp::Directive::OMPD_do, llvm::omp::Directive::OMPD_scope, + llvm::omp::Directive::OMPD_do, // + llvm::omp::Directive::OMPD_scope, // llvm::omp::Directive::OMPD_sections, // There are more worksharing directives, but they do not apply: // "for" is C++ only, @@ -4081,9 +4082,15 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { if (auto *iter{OmpGetUniqueModifier<parser::OmpIterator>(modifiers)}) { CheckIteratorModifier(*iter); } + + using Directive = llvm::omp::Directive; + Directive dir{GetContext().directive}; + llvm::ArrayRef<Directive> leafs{llvm::omp::getLeafConstructsOrSelf(dir)}; + parser::OmpMapType::Value mapType{parser::OmpMapType::Value::Storage}; + if (auto *type{OmpGetUniqueModifier<parser::OmpMapType>(modifiers)}) { - using Directive = llvm::omp::Directive; using Value = parser::OmpMapType::Value; + mapType = type->v; static auto isValidForVersion{ [](parser::OmpMapType::Value t, unsigned version) { @@ -4120,10 +4127,6 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { return result; }()}; - llvm::omp::Directive dir{GetContext().directive}; - llvm::ArrayRef<llvm::omp::Directive> leafs{ - llvm::omp::getLeafConstructsOrSelf(dir)}; - if (llvm::is_contained(leafs, Directive::OMPD_target) || llvm::is_contained(leafs, Directive::OMPD_target_data)) { if (version >= 60) { @@ -4141,6 +4144,43 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { } } + if (auto *attach{ + OmpGetUniqueModifier<parser::OmpAttachModifier>(modifiers)}) { + bool mapEnteringConstructOrMapper{ + llvm::is_contained(leafs, Directive::OMPD_target) || + llvm::is_contained(leafs, Directive::OMPD_target_data) || + llvm::is_contained(leafs, Directive::OMPD_target_enter_data) || + llvm::is_contained(leafs, Directive::OMPD_declare_mapper)}; + + if (!mapEnteringConstructOrMapper || !IsMapEnteringType(mapType)) { + const auto &desc{OmpGetDescriptor<parser::OmpAttachModifier>()}; + context_.Say(OmpGetModifierSource(modifiers, attach), + "The '%s' modifier can only appear on a map-entering construct or on a DECLARE_MAPPER directive"_err_en_US, + desc.name.str()); + } + + auto hasBasePointer{[&](const SomeExpr &item) { + evaluate::SymbolVector symbols{evaluate::GetSymbolVector(item)}; + return llvm::any_of( + symbols, [](SymbolRef s) { return IsPointer(s.get()); }); + }}; + + evaluate::ExpressionAnalyzer ea{context_}; + const auto &objects{std::get<parser::OmpObjectList>(x.v.t)}; + for (auto &object : objects.v) { + if (const parser::Designator *d{GetDesignatorFromObj(object)}) { + if (auto &&expr{ea.Analyze(*d)}) { + if (hasBasePointer(*expr)) { + continue; + } + } + } + auto source{GetObjectSource(object)}; + context_.Say(source ? *source : GetContext().clauseSource, + "A list-item that appears in a map clause with the ATTACH modifier must have a base-pointer"_err_en_US); + } + } + auto &&typeMods{ OmpGetRepeatableModifier<parser::OmpMapTypeModifier>(modifiers)}; struct Less { diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index f507278..543642ff 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -262,10 +262,10 @@ private: void CheckStorageOverlap(const evaluate::Expr<evaluate::SomeType> &, llvm::ArrayRef<evaluate::Expr<evaluate::SomeType>>, parser::CharBlock); void ErrorShouldBeVariable(const MaybeExpr &expr, parser::CharBlock source); - void CheckAtomicType( - SymbolRef sym, parser::CharBlock source, std::string_view name); - void CheckAtomicVariable( - const evaluate::Expr<evaluate::SomeType> &, parser::CharBlock); + void CheckAtomicType(SymbolRef sym, parser::CharBlock source, + std::string_view name, bool checkTypeOnPointer = true); + void CheckAtomicVariable(const evaluate::Expr<evaluate::SomeType> &, + parser::CharBlock, bool checkTypeOnPointer = true); std::pair<const parser::ExecutionPartConstruct *, const parser::ExecutionPartConstruct *> CheckUpdateCapture(const parser::ExecutionPartConstruct *ec1, diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 8074c94..556259d 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -17,6 +17,7 @@ #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" +#include "llvm/Frontend/OpenMP/OMP.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" @@ -24,6 +25,7 @@ #include <fstream> #include <set> #include <string_view> +#include <type_traits> #include <variant> #include <vector> @@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) { } } +static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) { + using RequiresClauses = WithOmpDeclarative::RequiresClauses; + using OmpMemoryOrderType = common::OmpMemoryOrderType; + + const auto [reqs, order]{common::visit( + [&](auto &&details) + -> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + return {details.ompRequires(), details.ompAtomicDefaultMemOrder()}; + } else { + return {nullptr, nullptr}; + } + }, + symbol.details())}; + + if (order) { + llvm::omp::Clause admo{llvm::omp::Clause::OMPC_atomic_default_mem_order}; + os << "!$omp requires " + << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(admo)) + << '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n"; + } + if (reqs) { + os << "!$omp requires"; + reqs->IterateOverMembers([&](llvm::omp::Clause f) { + if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) { + os << ' ' + << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f)); + } + }); + os << "\n"; + } +} + // Put out the visible symbols from scope. void ModFileWriter::PutSymbols( const Scope &scope, UnorderedSymbolSet *hermeticModules) { @@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols( for (const Symbol &symbol : uses) { PutUse(symbol); } + PutOpenMPRequirements(decls_, DEREF(scope.symbol())); for (const auto &set : scope.equivalenceSets()) { if (!set.empty() && !set.front().symbol.test(Symbol::Flag::CompilerCreated)) { diff --git a/flang/lib/Semantics/openmp-modifiers.cpp b/flang/lib/Semantics/openmp-modifiers.cpp index af4000c..717fb03 100644 --- a/flang/lib/Semantics/openmp-modifiers.cpp +++ b/flang/lib/Semantics/openmp-modifiers.cpp @@ -157,6 +157,22 @@ const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAlwaysModifier>() { } template <> +const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAttachModifier>() { + static const OmpModifierDescriptor desc{ + /*name=*/"attach-modifier", + /*props=*/ + { + {61, {OmpProperty::Unique}}, + }, + /*clauses=*/ + { + {61, {Clause::OMPC_map}}, + }, + }; + return desc; +} + +template <> const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAutomapModifier>() { static const OmpModifierDescriptor desc{ /*name=*/"automap-modifier", diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index a8ec4d6..292e73b 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -13,6 +13,7 @@ #include "flang/Semantics/openmp-utils.h" #include "flang/Common/Fortran-consts.h" +#include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Common/reference.h" #include "flang/Common/visit.h" @@ -59,6 +60,26 @@ const Scope &GetScopingUnit(const Scope &scope) { return *iter; } +const Scope &GetProgramUnit(const Scope &scope) { + const Scope *unit{nullptr}; + for (const Scope *iter{&scope}; !iter->IsTopLevel(); iter = &iter->parent()) { + switch (iter->kind()) { + case Scope::Kind::BlockData: + case Scope::Kind::MainProgram: + case Scope::Kind::Module: + return *iter; + case Scope::Kind::Subprogram: + // Ignore subprograms that are nested. + unit = iter; + break; + default: + break; + } + } + assert(unit && "Scope not in a program unit"); + return *unit; +} + SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) { if (x == nullptr) { return SourcedActionStmt{}; @@ -202,7 +223,7 @@ std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) { // ForwardOwningPointer typedExpr // `- GenericExprWrapper ^.get() // `- std::optional<Expr> ^->v - return typedExpr.get()->v; + return DEREF(typedExpr.get()).v; } std::optional<evaluate::DynamicType> GetDynamicType( diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 18fc638..1228493 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -435,6 +435,22 @@ public: return true; } + bool Pre(const parser::UseStmt &x) { + if (x.moduleName.symbol) { + Scope &thisScope{context_.FindScope(x.moduleName.source)}; + common::visit( + [&](auto &&details) { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + AddOmpRequiresToScope(thisScope, details.ompRequires(), + details.ompAtomicDefaultMemOrder()); + } + }, + x.moduleName.symbol->details()); + } + return true; + } + bool Pre(const parser::OmpMetadirectiveDirective &x) { PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective); return true; @@ -538,38 +554,37 @@ public: void Post(const parser::OpenMPFlushConstruct &) { PopContext(); } bool Pre(const parser::OpenMPRequiresConstruct &x) { - using Flags = WithOmpDeclarative::RequiresFlags; - using Requires = WithOmpDeclarative::RequiresFlag; + using RequiresClauses = WithOmpDeclarative::RequiresClauses; PushContext(x.source, llvm::omp::Directive::OMPD_requires); // Gather information from the clauses. - Flags flags; - std::optional<common::OmpMemoryOrderType> memOrder; + RequiresClauses reqs; + const common::OmpMemoryOrderType *memOrder{nullptr}; for (const parser::OmpClause &clause : x.v.Clauses().v) { - flags |= common::visit( + using OmpClause = parser::OmpClause; + reqs |= common::visit( common::visitors{ - [&memOrder]( - const parser::OmpClause::AtomicDefaultMemOrder &atomic) { - memOrder = atomic.v.v; - return Flags{}; - }, - [](const parser::OmpClause::ReverseOffload &) { - return Flags{Requires::ReverseOffload}; - }, - [](const parser::OmpClause::UnifiedAddress &) { - return Flags{Requires::UnifiedAddress}; + [&](const OmpClause::AtomicDefaultMemOrder &atomic) { + memOrder = &atomic.v.v; + return RequiresClauses{}; }, - [](const parser::OmpClause::UnifiedSharedMemory &) { - return Flags{Requires::UnifiedSharedMemory}; - }, - [](const parser::OmpClause::DynamicAllocators &) { - return Flags{Requires::DynamicAllocators}; + [&](auto &&s) { + using TypeS = llvm::remove_cvref_t<decltype(s)>; + if constexpr ( // + std::is_same_v<TypeS, OmpClause::DynamicAllocators> || + std::is_same_v<TypeS, OmpClause::ReverseOffload> || + std::is_same_v<TypeS, OmpClause::UnifiedAddress> || + std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) { + return RequiresClauses{clause.Id()}; + } else { + return RequiresClauses{}; + } }, - [](const auto &) { return Flags{}; }}, + }, clause.u); } // Merge clauses into parents' symbols details. - AddOmpRequiresToScope(currScope(), flags, memOrder); + AddOmpRequiresToScope(currScope(), &reqs, memOrder); return true; } void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); } @@ -1001,8 +1016,9 @@ private: std::int64_t ordCollapseLevel{0}; - void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags, - std::optional<common::OmpMemoryOrderType>); + void AddOmpRequiresToScope(Scope &, + const WithOmpDeclarative::RequiresClauses *, + const common::OmpMemoryOrderType *); void IssueNonConformanceWarning(llvm::omp::Directive D, parser::CharBlock source, unsigned EmitFromVersion); @@ -3309,86 +3325,6 @@ void ResolveOmpParts( } } -void ResolveOmpTopLevelParts( - SemanticsContext &context, const parser::Program &program) { - if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { - return; - } - - // Gather REQUIRES clauses from all non-module top-level program unit symbols, - // combine them together ensuring compatibility and apply them to all these - // program units. Modules are skipped because their REQUIRES clauses should be - // propagated via USE statements instead. - WithOmpDeclarative::RequiresFlags combinedFlags; - std::optional<common::OmpMemoryOrderType> combinedMemOrder; - - // Function to go through non-module top level program units and extract - // REQUIRES information to be processed by a function-like argument. - auto processProgramUnits{[&](auto processFn) { - for (const parser::ProgramUnit &unit : program.v) { - if (!std::holds_alternative<common::Indirection<parser::Module>>( - unit.u) && - !std::holds_alternative<common::Indirection<parser::Submodule>>( - unit.u) && - !std::holds_alternative< - common::Indirection<parser::CompilerDirective>>(unit.u)) { - Symbol *symbol{common::visit( - [&context](auto &x) { - Scope *scope = GetScope(context, x.value()); - return scope ? scope->symbol() : nullptr; - }, - unit.u)}; - // FIXME There is no symbol defined for MainProgram units in certain - // circumstances, so REQUIRES information has no place to be stored in - // these cases. - if (!symbol) { - continue; - } - common::visit( - [&](auto &details) { - if constexpr (std::is_convertible_v<decltype(&details), - WithOmpDeclarative *>) { - processFn(*symbol, details); - } - }, - symbol->details()); - } - } - }}; - - // Combine global REQUIRES information from all program units except modules - // and submodules. - processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) { - if (const WithOmpDeclarative::RequiresFlags * - flags{details.ompRequires()}) { - combinedFlags |= *flags; - } - if (const common::OmpMemoryOrderType * - memOrder{details.ompAtomicDefaultMemOrder()}) { - if (combinedMemOrder && *combinedMemOrder != *memOrder) { - context.Say(symbol.scope()->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - combinedMemOrder = *memOrder; - } - }); - - // Update all program units except modules and submodules with the combined - // global REQUIRES information. - processProgramUnits([&](Symbol &, WithOmpDeclarative &details) { - if (combinedFlags.any()) { - details.set_ompRequires(combinedFlags); - } - if (combinedMemOrder) { - details.set_ompAtomicDefaultMemOrder(*combinedMemOrder); - } - }); -} - static bool IsSymbolThreadprivate(const Symbol &symbol) { if (const auto *details{symbol.detailsIf<HostAssocDetails>()}) { return details->symbol().test(Symbol::Flag::OmpThreadprivate); @@ -3547,42 +3483,39 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source, } void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope, - WithOmpDeclarative::RequiresFlags flags, - std::optional<common::OmpMemoryOrderType> memOrder) { - Scope *scopeIter = &scope; - do { - if (Symbol * symbol{scopeIter->symbol()}) { - common::visit( - [&](auto &details) { - // Store clauses information into the symbol for the parent and - // enclosing modules, programs, functions and subroutines. - if constexpr (std::is_convertible_v<decltype(&details), - WithOmpDeclarative *>) { - if (flags.any()) { - if (const WithOmpDeclarative::RequiresFlags * - otherFlags{details.ompRequires()}) { - flags |= *otherFlags; - } - details.set_ompRequires(flags); + const WithOmpDeclarative::RequiresClauses *reqs, + const common::OmpMemoryOrderType *memOrder) { + const Scope &programUnit{omp::GetProgramUnit(scope)}; + using RequiresClauses = WithOmpDeclarative::RequiresClauses; + RequiresClauses combinedReqs{reqs ? *reqs : RequiresClauses{}}; + + if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) { + common::visit( + [&](auto &details) { + if constexpr (std::is_convertible_v<decltype(&details), + WithOmpDeclarative *>) { + if (combinedReqs.any()) { + if (const RequiresClauses *otherReqs{details.ompRequires()}) { + combinedReqs |= *otherReqs; } - if (memOrder) { - if (details.has_ompAtomicDefaultMemOrder() && - *details.ompAtomicDefaultMemOrder() != *memOrder) { - context_.Say(scopeIter->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - details.set_ompAtomicDefaultMemOrder(*memOrder); + details.set_ompRequires(combinedReqs); + } + if (memOrder) { + if (details.has_ompAtomicDefaultMemOrder() && + *details.ompAtomicDefaultMemOrder() != *memOrder) { + context_.Say(programUnit.sourceRange(), + "Conflicting '%s' REQUIRES clauses found in compilation " + "unit"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order) + .str())); } + details.set_ompAtomicDefaultMemOrder(*memOrder); } - }, - symbol->details()); - } - scopeIter = &scopeIter->parent(); - } while (!scopeIter->IsGlobal()); + } + }, + symbol->details()); + } } void OmpAttributeVisitor::IssueNonConformanceWarning(llvm::omp::Directive D, diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h index 5a890c2..36d3ce9 100644 --- a/flang/lib/Semantics/resolve-directives.h +++ b/flang/lib/Semantics/resolve-directives.h @@ -23,7 +23,5 @@ class SemanticsContext; void ResolveAccParts( SemanticsContext &, const parser::ProgramUnit &, Scope *topScope); void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &); -void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &); - } // namespace Fortran::semantics #endif diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 86121880..ae0ff9ca 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -10687,9 +10687,6 @@ void ResolveNamesVisitor::Post(const parser::Program &x) { CHECK(!attrs_); CHECK(!cudaDataAttr_); CHECK(!GetDeclTypeSpec()); - // Top-level resolution to propagate information across program units after - // each of them has been resolved separately. - ResolveOmpTopLevelParts(context(), x); } // A singleton instance of the scope -> IMPLICIT rules mapping is diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp index 69169469..0ec44b7 100644 --- a/flang/lib/Semantics/symbol.cpp +++ b/flang/lib/Semantics/symbol.cpp @@ -70,6 +70,32 @@ static void DumpList(llvm::raw_ostream &os, const char *label, const T &list) { } } +llvm::raw_ostream &operator<<( + llvm::raw_ostream &os, const WithOmpDeclarative &x) { + if (x.has_ompRequires() || x.has_ompAtomicDefaultMemOrder()) { + os << " OmpRequirements:("; + if (const common::OmpMemoryOrderType *admo{x.ompAtomicDefaultMemOrder()}) { + os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order)) + << '(' << parser::ToLowerCaseLetters(EnumToString(*admo)) << ')'; + if (x.has_ompRequires()) { + os << ','; + } + } + if (const WithOmpDeclarative::RequiresClauses *reqs{x.ompRequires()}) { + size_t num{0}, size{reqs->count()}; + reqs->IterateOverMembers([&](llvm::omp::Clause f) { + os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f)); + if (++num < size) { + os << ','; + } + }); + } + os << ')'; + } + return os; +} + void SubprogramDetails::set_moduleInterface(Symbol &symbol) { CHECK(!moduleInterface_); moduleInterface_ = &symbol; @@ -150,6 +176,7 @@ llvm::raw_ostream &operator<<( os << x; } } + os << static_cast<const WithOmpDeclarative &>(x); return os; } @@ -580,7 +607,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) { common::visit( // common::visitors{ [&](const UnknownDetails &) {}, - [&](const MainProgramDetails &) {}, + [&](const MainProgramDetails &x) { + os << static_cast<const WithOmpDeclarative &>(x); + }, [&](const ModuleDetails &x) { if (x.isSubmodule()) { os << " ("; @@ -599,6 +628,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) { if (x.isDefaultPrivate()) { os << " isDefaultPrivate"; } + os << static_cast<const WithOmpDeclarative &>(x); }, [&](const SubprogramNameDetails &x) { os << ' ' << EnumToString(x.kind()); |