Diffstat (limited to 'flang/lib')
54 files changed, 3462 insertions, 406 deletions
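The first hunk below (flang/lib/Evaluate/intrinsics.cpp) introduces a dedicated KindCode for arguments that must be of extensible or unlimited polymorphic type, with a matching diagnostic. A minimal Fortran sketch of the case the new check rejects (hypothetical user code, names illustrative; intrinsics such as EXTENDS_TYPE_OF and SAME_TYPE_AS use this argument pattern):

    program p
      type :: t1              ! ordinary derived type: extensible
      end type
      type :: t2              ! SEQUENCE type: not extensible
        sequence
        integer :: i
      end type
      type(t1) :: a, b
      type(t2) :: c
      print *, extends_type_of(a, b)  ! accepted
      print *, same_type_as(a, c)     ! now diagnosed: 'c' does not have an
                                      ! extensible or unlimited polymorphic type
    end program p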
diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp index f204eef..1de5e6b 100644 --- a/flang/lib/Evaluate/intrinsics.cpp +++ b/flang/lib/Evaluate/intrinsics.cpp @@ -111,6 +111,7 @@ ENUM_CLASS(KindCode, none, defaultIntegerKind, atomicIntKind, // atomic_int_kind from iso_fortran_env atomicIntOrLogicalKind, // atomic_int_kind or atomic_logical_kind sameAtom, // same type and kind as atom + extensibleOrUnlimitedType, // extensible or unlimited polymorphic type ) struct TypePattern { @@ -160,7 +161,8 @@ static constexpr TypePattern AnyChar{CharType, KindCode::any}; static constexpr TypePattern AnyLogical{LogicalType, KindCode::any}; static constexpr TypePattern AnyRelatable{RelatableType, KindCode::any}; static constexpr TypePattern AnyIntrinsic{IntrinsicType, KindCode::any}; -static constexpr TypePattern ExtensibleDerived{DerivedType, KindCode::any}; +static constexpr TypePattern ExtensibleDerived{ + DerivedType, KindCode::extensibleOrUnlimitedType}; static constexpr TypePattern AnyData{AnyType, KindCode::any}; // Type is irrelevant, but not BOZ (for PRESENT(), OPTIONAL(), &c.) @@ -2103,9 +2105,13 @@ std::optional<SpecificCall> IntrinsicInterface::Match( } return std::nullopt; } else if (!d.typePattern.categorySet.test(type->category())) { + const char *expected{ + d.typePattern.kindCode == KindCode::extensibleOrUnlimitedType + ? ", expected extensible or unlimited polymorphic type" + : ""}; messages.Say(arg->sourceLocation(), - "Actual argument for '%s=' has bad type '%s'"_err_en_US, d.keyword, - type->AsFortran()); + "Actual argument for '%s=' has bad type '%s'%s"_err_en_US, d.keyword, + type->AsFortran(), expected); return std::nullopt; // argument has invalid type category } bool argOk{false}; @@ -2244,6 +2250,17 @@ std::optional<SpecificCall> IntrinsicInterface::Match( return std::nullopt; } break; + case KindCode::extensibleOrUnlimitedType: + argOk = type->IsUnlimitedPolymorphic() || + (type->category() == TypeCategory::Derived && + IsExtensibleType(GetDerivedTypeSpec(type))); + if (!argOk) { + messages.Say(arg->sourceLocation(), + "Actual argument for '%s=' has type '%s', but was expected to be an extensible or unlimited polymorphic type"_err_en_US, + d.keyword, type->AsFortran()); + return std::nullopt; + } + break; default: CRASH_NO_CASE; } diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index b927fa3..bd06acc 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -1153,6 +1153,18 @@ bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) { return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0; } +bool IsCUDADeviceSymbol(const Symbol &sym) { + if (const auto *details = + sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) { + return details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Pinned; + } else if (const auto *details = + sym.GetUltimate().detailsIf<semantics::AssocEntityDetails>()) { + return GetNbOfCUDADeviceSymbols(details->expr()) > 0; + } + return false; +} + // HasVectorSubscript() struct HasVectorSubscriptHelper : public AnyTraverse<HasVectorSubscriptHelper, bool, diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp index 53239cb..e7a6c4d 100644 --- a/flang/lib/Lower/Allocatable.cpp +++ b/flang/lib/Lower/Allocatable.cpp @@ -629,6 +629,10 @@ private: unsigned allocatorIdx = Fortran::lower::getAllocatorIdx(alloc.getSymbol()); fir::ExtendedValue exv = isSource ? 
sourceExv : moldExv; + if (const Fortran::semantics::Symbol *sym{GetLastSymbol(sourceExpr)}) + if (Fortran::semantics::IsCUDADevice(*sym)) + TODO(loc, "CUDA Fortran: allocate with device source"); + // Generate a sequence of runtime calls. errorManager.genStatCheck(builder, loc); genAllocateObjectInit(box, allocatorIdx); @@ -767,6 +771,15 @@ private: const fir::MutableBoxValue &box, ErrorManager &errorManager, const Fortran::semantics::Symbol &sym) { + + if (const Fortran::semantics::DeclTypeSpec *declTypeSpec = sym.GetType()) + if (const Fortran::semantics::DerivedTypeSpec *derivedTypeSpec = + declTypeSpec->AsDerived()) + if (derivedTypeSpec->HasDefaultInitialization( + /*ignoreAllocatable=*/true, /*ignorePointer=*/true)) + TODO(loc, + "CUDA Fortran: allocate on device with default initialization"); + Fortran::lower::StatementContext stmtCtx; cuf::DataAttributeAttr cudaAttr = Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(), diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 68adf34..0595ca0 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -4987,11 +4987,8 @@ private: // host = device if (!lhsIsDevice && rhsIsDevice) { - if (Fortran::lower::isTransferWithConversion(rhs)) { + if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) { mlir::OpBuilder::InsertionGuard insertionGuard(builder); - auto elementalOp = - mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()); - assert(elementalOp && "expect elemental op"); auto designateOp = *elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin(); builder.setInsertionPoint(elementalOp); @@ -6079,7 +6076,7 @@ private: if (resTy != wrappedSymTy) { // check size of the pointed to type so we can't overflow by writing // double precision to a single precision allocation, etc - LLVM_ATTRIBUTE_UNUSED auto getBitWidth = [this](mlir::Type ty) { + [[maybe_unused]] auto getBitWidth = [this](mlir::Type ty) { // 15.6.2.6.3: differering result types should be integer, real, // complex or logical if (auto cmplx = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp index bb4bdee..9501b0e 100644 --- a/flang/lib/Lower/CUDA.cpp +++ b/flang/lib/Lower/CUDA.cpp @@ -68,11 +68,26 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute( return cuf::getDataAttribute(mlirContext, cudaAttr); } -bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) { +hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) { + auto isConversionElementalOp = [](hlfir::ElementalOp elOp) { + return llvm::hasSingleElement( + elOp.getBody()->getOps<hlfir::DesignateOp>()) && + llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 && + llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == + 1; + }; + if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) { + if (!declOp.getMemref().getDefiningOp()) + return {}; + if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>( + declOp.getMemref().getDefiningOp())) + if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>( + associateOp.getSource().getDefiningOp())) + if (isConversionElementalOp(elOp)) + return elOp; + } if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp())) - if (llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) && - llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 && - llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == 1) - return 
true; - return false; + if (isConversionElementalOp(elOp)) + return elOp; + return {}; } diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp index d7f94e1..a46d219 100644 --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -5603,7 +5603,7 @@ private: return newIters; }; if (useTripsForSlice) { - LLVM_ATTRIBUTE_UNUSED auto vectorSubscriptShape = + [[maybe_unused]] auto vectorSubscriptShape = getShape(arrayOperands.back()); auto undef = fir::UndefOp::create(builder, loc, idxTy); trips.push_back(undef); diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp index 98dc78f..cd53dc9 100644 --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -524,12 +524,18 @@ getNamelistGroup(Fortran::lower::AbstractConverter &converter, descAddr = builder.createConvert(loc, builder.getRefType(symType), varAddr); } else { + fir::BaseBoxType boxType; const auto expr = Fortran::evaluate::AsGenericExpr(s); fir::ExtendedValue exv = converter.genExprAddr(*expr, stmtCtx); mlir::Type type = fir::getBase(exv).getType(); + bool isClassType = mlir::isa<fir::ClassType>(type); if (mlir::Type baseTy = fir::dyn_cast_ptrOrBoxEleTy(type)) type = baseTy; - fir::BoxType boxType = fir::BoxType::get(fir::PointerType::get(type)); + + if (isClassType) + boxType = fir::ClassType::get(fir::PointerType::get(type)); + else + boxType = fir::BoxType::get(fir::PointerType::get(type)); descAddr = builder.createTemporary(loc, boxType); fir::MutableBoxValue box = fir::MutableBoxValue(descAddr, {}, {}); fir::factory::associateMutableBox(builder, loc, box, exv, @@ -944,7 +950,8 @@ static void genIoLoop(Fortran::lower::AbstractConverter &converter, makeNextConditionalOn(builder, loc, checkResult, ok, inLoop); const auto &itemList = std::get<0>(ioImpliedDo.t); const auto &control = std::get<1>(ioImpliedDo.t); - const auto &loopSym = *control.name.thing.thing.symbol; + const auto &loopSym = + *Fortran::parser::UnwrapRef<Fortran::parser::Name>(control.name).symbol; mlir::Value loopVar = fir::getBase(converter.genExprAddr( Fortran::evaluate::AsGenericExpr(loopSym).value(), stmtCtx)); auto genControlValue = [&](const Fortran::parser::ScalarIntExpr &expr) { diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 62e5c0c..cfb1891 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -978,15 +978,40 @@ static RecipeOp genRecipeOp( auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty); assert(mappableTy && "Expected that all variable types are considered mappable"); + bool needsDestroy = false; auto retVal = mappableTy.generatePrivateInit( builder, loc, mlir::cast<mlir::TypedValue<mlir::acc::MappableType>>( initBlock->getArgument(0)), initName, initBlock->getArguments().take_back(initBlock->getArguments().size() - 1), - initValue); + initValue, needsDestroy); mlir::acc::YieldOp::create(builder, loc, retVal ? retVal : initBlock->getArgument(0)); + // Create destroy region and generate destruction if requested. 
+ if (needsDestroy) { + llvm::SmallVector<mlir::Type> destroyArgsTy; + llvm::SmallVector<mlir::Location> destroyArgsLoc; + // original and privatized/reduction value + destroyArgsTy.push_back(ty); + destroyArgsTy.push_back(ty); + destroyArgsLoc.push_back(loc); + destroyArgsLoc.push_back(loc); + // Append bounds arguments (if any) in the same order as init region + if (argsTy.size() > 1) { + destroyArgsTy.append(argsTy.begin() + 1, argsTy.end()); + destroyArgsLoc.insert(destroyArgsLoc.end(), argsTy.size() - 1, loc); + } + + builder.createBlock(&recipe.getDestroyRegion(), + recipe.getDestroyRegion().end(), destroyArgsTy, + destroyArgsLoc); + builder.setInsertionPointToEnd(&recipe.getDestroyRegion().back()); + // Call interface on the privatized/reduction value (2nd argument). + (void)mappableTy.generatePrivateDestroy( + builder, loc, recipe.getDestroyRegion().front().getArgument(1)); + mlir::acc::TerminatorOp::create(builder, loc); + } return recipe; } diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index ff82a36..3ab8a58 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -20,6 +20,7 @@ #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Parser/parse-tree.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/type.h" #include "flang/Support/Fortran.h" @@ -183,12 +184,8 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) { // scope. // For safety, traverse all enclosing scopes and check if their symbol // contains REQUIRES. - for (const auto *sc{&scope}; sc->kind() != semantics::Scope::Kind::Global; - sc = &sc->parent()) { - const semantics::Symbol *sym = sc->symbol(); - if (!sym) - continue; - + const semantics::Scope &unitScope = semantics::omp::GetProgramUnit(scope); + if (auto *symbol = unitScope.symbol()) { const common::OmpMemoryOrderType *admo = common::visit( [](auto &&s) { using WithOmpDeclarative = semantics::WithOmpDeclarative; @@ -198,7 +195,8 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) { } return static_cast<const common::OmpMemoryOrderType *>(nullptr); }, - sym->details()); + symbol->details()); + if (admo) return getMemoryOrderKind(*admo); } @@ -214,19 +212,83 @@ getDefaultAtomicMemOrder(semantics::SemanticsContext &semaCtx) { return std::nullopt; } -static std::optional<mlir::omp::ClauseMemoryOrderKind> +static std::pair<std::optional<mlir::omp::ClauseMemoryOrderKind>, bool> getAtomicMemoryOrder(semantics::SemanticsContext &semaCtx, const omp::List<omp::Clause> &clauses, const semantics::Scope &scope) { for (const omp::Clause &clause : clauses) { if (auto maybeKind = getMemoryOrderKind(clause.id)) - return *maybeKind; + return std::make_pair(*maybeKind, /*canOverride=*/false); } if (auto maybeKind = getMemoryOrderFromRequires(scope)) - return *maybeKind; + return std::make_pair(*maybeKind, /*canOverride=*/true); - return getDefaultAtomicMemOrder(semaCtx); + return std::make_pair(getDefaultAtomicMemOrder(semaCtx), + /*canOverride=*/false); +} + +static std::optional<mlir::omp::ClauseMemoryOrderKind> +makeValidForAction(std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder, + int action0, int action1, unsigned version) { + // When the atomic default memory order specified on a REQUIRES directive is + // disallowed on a given ATOMIC operation, and it's not ACQ_REL, the order + // reverts to RELAXED. 
ACQ_REL decays to either ACQUIRE or RELEASE, depending + // on the operation. + + if (!memOrder) { + return memOrder; + } + + using Analysis = parser::OpenMPAtomicConstruct::Analysis; + // Figure out the main action (i.e. disregard a potential capture operation) + int action = action0; + if (action1 != Analysis::None) + action = action0 == Analysis::Read ? action1 : action0; + + // Available orderings: acquire, acq_rel, relaxed, release, seq_cst + + if (action == Analysis::Read) { + // "acq_rel" decays to "acquire" + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) + return mlir::omp::ClauseMemoryOrderKind::Acquire; + } else if (action == Analysis::Write) { + // "acq_rel" decays to "release" + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) + return mlir::omp::ClauseMemoryOrderKind::Release; + } + + if (version > 50) { + if (action == Analysis::Read) { + // "release" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } + if (action == Analysis::Write) { + // "acquire" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } + } else { + if (action == Analysis::Read) { + // "release" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } else { + if (action & Analysis::Write) { // include "update" + // "acquire" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + if (action == Analysis::Update) { + // "acq_rel" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } + } + } + } + + return memOrder; } static mlir::omp::ClauseMemoryOrderKindAttr @@ -449,16 +511,19 @@ void Fortran::lower::omp::lowerAtomic( mlir::Value atomAddr = fir::getBase(converter.genExprAddr(atom, stmtCtx, &loc)); mlir::IntegerAttr hint = getAtomicHint(converter, clauses); - std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder = - getAtomicMemoryOrder(semaCtx, clauses, - semaCtx.FindScope(construct.source)); + auto [memOrder, canOverride] = getAtomicMemoryOrder( + semaCtx, clauses, semaCtx.FindScope(construct.source)); + + unsigned version = semaCtx.langOptions().OpenMPVersion; + int action0 = analysis.op0.what & analysis.Action; + int action1 = analysis.op1.what & analysis.Action; + if (canOverride) + memOrder = makeValidForAction(memOrder, action0, action1, version); if (auto *cond = get(analysis.cond)) { (void)cond; TODO(loc, "OpenMP ATOMIC COMPARE"); } else { - int action0 = analysis.op0.what & analysis.Action; - int action1 = analysis.op1.what & analysis.Action; mlir::Operation *captureOp = nullptr; fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint(); fir::FirOpBuilder::InsertPoint atomicAt, postAt; diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 55eda7e..85398be 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1343,8 +1343,10 @@ bool ClauseProcessor::processMap( const parser::CharBlock &source) { using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); - const auto &[mapType, typeMods, refMod, mappers, iterator, objects] = - clause.t; + const auto &[mapType, typeMods, attachMod, refMod, mappers, iterator, + objects] = clause.t; + if (attachMod) +
TODO(currentLocation, "ATTACH modifier is not implemented yet"); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; std::string mapperIdName = "__implicit_mapper"; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index fac37a3..ba34212 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -219,7 +219,6 @@ MAKE_EMPTY_CLASS(AcqRel, AcqRel); MAKE_EMPTY_CLASS(Acquire, Acquire); MAKE_EMPTY_CLASS(Capture, Capture); MAKE_EMPTY_CLASS(Compare, Compare); -MAKE_EMPTY_CLASS(DynamicAllocators, DynamicAllocators); MAKE_EMPTY_CLASS(Full, Full); MAKE_EMPTY_CLASS(Inbranch, Inbranch); MAKE_EMPTY_CLASS(Mergeable, Mergeable); @@ -235,13 +234,9 @@ MAKE_EMPTY_CLASS(OmpxBare, OmpxBare); MAKE_EMPTY_CLASS(Read, Read); MAKE_EMPTY_CLASS(Relaxed, Relaxed); MAKE_EMPTY_CLASS(Release, Release); -MAKE_EMPTY_CLASS(ReverseOffload, ReverseOffload); MAKE_EMPTY_CLASS(SeqCst, SeqCst); -MAKE_EMPTY_CLASS(SelfMaps, SelfMaps); MAKE_EMPTY_CLASS(Simd, Simd); MAKE_EMPTY_CLASS(Threads, Threads); -MAKE_EMPTY_CLASS(UnifiedAddress, UnifiedAddress); -MAKE_EMPTY_CLASS(UnifiedSharedMemory, UnifiedSharedMemory); MAKE_EMPTY_CLASS(Unknown, Unknown); MAKE_EMPTY_CLASS(Untied, Untied); MAKE_EMPTY_CLASS(Weak, Weak); @@ -775,7 +770,18 @@ Doacross make(const parser::OmpClause::Doacross &inp, semantics::SemanticsContext &semaCtx) { return makeDoacross(inp.v.v, semaCtx); } -// DynamicAllocators: empty +DynamicAllocators make(const parser::OmpClause::DynamicAllocators &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpDynamicAllocatorsClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpDynamicAllocatorsClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return DynamicAllocators{/*Required=*/std::move(maybeRequired)}; +} + DynGroupprivate make(const parser::OmpClause::DynGroupprivate &inp, semantics::SemanticsContext &semaCtx) { @@ -1069,6 +1075,15 @@ Map make(const parser::OmpClause::Map &inp, ); CLAUSET_ENUM_CONVERT( // + convertAttachMod, parser::OmpAttachModifier::Value, Map::AttachModifier, + // clang-format off + MS(Always, Always) + MS(Auto, Auto) + MS(Never, Never) + // clang-format on + ); + + CLAUSET_ENUM_CONVERT( // convertRefMod, parser::OmpRefModifier::Value, Map::RefModifier, // clang-format off MS(Ref_Ptee, RefPtee) @@ -1115,6 +1130,13 @@ Map make(const parser::OmpClause::Map &inp, if (!modSet.empty()) maybeTypeMods = Map::MapTypeModifiers(modSet.begin(), modSet.end()); + auto attachMod = [&]() -> std::optional<Map::AttachModifier> { + if (auto *t = + semantics::OmpGetUniqueModifier<parser::OmpAttachModifier>(mods)) + return convertAttachMod(t->v); + return std::nullopt; + }(); + auto refMod = [&]() -> std::optional<Map::RefModifier> { if (auto *t = semantics::OmpGetUniqueModifier<parser::OmpRefModifier>(mods)) return convertRefMod(t->v); @@ -1135,6 +1157,7 @@ Map make(const parser::OmpClause::Map &inp, return Map{{/*MapType=*/std::move(type), /*MapTypeModifiers=*/std::move(maybeTypeMods), + /*AttachModifier=*/std::move(attachMod), /*RefModifier=*/std::move(refMod), /*Mapper=*/std::move(mappers), /*Iterator=*/std::move(iterator), /*LocatorList=*/makeObjects(t2, semaCtx)}}; @@ -1321,7 +1344,18 @@ Reduction make(const parser::OmpClause::Reduction &inp, // Relaxed: empty // Release: empty -// ReverseOffload: empty + +ReverseOffload make(const parser::OmpClause::ReverseOffload &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpReverseOffloadClause> + auto
&&maybeRequired = maybeApply( + [&](const parser::OmpReverseOffloadClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return ReverseOffload{/*Required=*/std::move(maybeRequired)}; +} Safelen make(const parser::OmpClause::Safelen &inp, semantics::SemanticsContext &semaCtx) { @@ -1374,6 +1408,18 @@ Schedule make(const parser::OmpClause::Schedule &inp, // SeqCst: empty +SelfMaps make(const parser::OmpClause::SelfMaps &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpSelfMapsClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpSelfMapsClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return SelfMaps{/*Required=*/std::move(maybeRequired)}; +} + Severity make(const parser::OmpClause::Severity &inp, semantics::SemanticsContext &semaCtx) { // inp -> empty @@ -1463,8 +1509,29 @@ To make(const parser::OmpClause::To &inp, /*LocatorList=*/makeObjects(t3, semaCtx)}}; } -// UnifiedAddress: empty -// UnifiedSharedMemory: empty +UnifiedAddress make(const parser::OmpClause::UnifiedAddress &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpUnifiedAddressClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpUnifiedAddressClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return UnifiedAddress{/*Required=*/std::move(maybeRequired)}; +} + +UnifiedSharedMemory make(const parser::OmpClause::UnifiedSharedMemory &inp, + semantics::SemanticsContext &semaCtx) { + // inp.v -> std::optional<parser::OmpUnifiedSharedMemoryClause> + auto &&maybeRequired = maybeApply( + [&](const parser::OmpUnifiedSharedMemoryClause &c) { + return makeExpr(c.v, semaCtx); + }, + inp.v); + + return UnifiedSharedMemory{/*Required=*/std::move(maybeRequired)}; +} Uniform make(const parser::OmpClause::Uniform &inp, semantics::SemanticsContext &semaCtx) { diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index bd94651..f86ee01 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -3383,7 +3383,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter, } } - switch (llvm::omp::Directive dir = item->id) { + llvm::omp::Directive dir = item->id; + switch (dir) { case llvm::omp::Directive::OMPD_barrier: newOp = genBarrierOp(converter, symTable, semaCtx, eval, loc, queue, item); break; @@ -4207,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions( void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, const semantics::Symbol *symbol) { using MlirRequires = mlir::omp::ClauseRequires; - using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag; if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) { - semantics::WithOmpDeclarative::RequiresFlags semaFlags; + semantics::WithOmpDeclarative::RequiresClauses reqs; if (symbol) { common::visit( [&](const auto &details) { if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative, std::decay_t<decltype(details)>>) { if (details.has_ompRequires()) - semaFlags = *details.ompRequires(); + reqs = *details.ompRequires(); } }, symbol->details()); @@ -4227,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, // Use pre-populated omp.requires module attribute if it was set, so that // the "-fopenmp-force-usm" compiler option is honored. 
MlirRequires mlirFlags = offloadMod.getRequires(); - if (semaFlags.test(SemaRequires::ReverseOffload)) + if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators)) + mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; + if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload)) mlirFlags = mlirFlags | MlirRequires::reverse_offload; - if (semaFlags.test(SemaRequires::UnifiedAddress)) + if (reqs.test(llvm::omp::Clause::OMPC_unified_address)) mlirFlags = mlirFlags | MlirRequires::unified_address; - if (semaFlags.test(SemaRequires::UnifiedSharedMemory)) + if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory)) mlirFlags = mlirFlags | MlirRequires::unified_shared_memory; - if (semaFlags.test(SemaRequires::DynamicAllocators)) - mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; offloadMod.setRequires(mlirFlags); } diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp index a096099..155bc0f 100644 --- a/flang/lib/Optimizer/Builder/Character.cpp +++ b/flang/lib/Optimizer/Builder/Character.cpp @@ -92,7 +92,7 @@ getCompileTimeLength(const fir::CharBoxValue &box) { /// Detect the precondition that the value `str` does not reside in memory. Such /// values will have a type `!fir.array<...x!fir.char<N>>` or `!fir.char<N>`. -LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) { +[[maybe_unused]] static bool needToMaterialize(mlir::Value str) { return mlir::isa<fir::SequenceType>(str.getType()) || fir::isa_char(str.getType()); } diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 2c21868..0195178 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -346,6 +346,14 @@ static constexpr IntrinsicHandler handlers[]{ &I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>, {{{"mask", asValue}, {"pred", asValue}}}, /*isElemental=*/false}, + {"barrier_arrive", + &I::genBarrierArrive, + {{{"barrier", asAddr}}}, + /*isElemental=*/false}, + {"barrier_arrive_cnt", + &I::genBarrierArriveCnt, + {{{"barrier", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, {"barrier_init", &I::genBarrierInit, {{{"barrier", asAddr}, {"count", asValue}}}, @@ -494,6 +502,10 @@ static constexpr IntrinsicHandler handlers[]{ &I::genExtendsTypeOf, {{{"a", asBox}, {"mold", asBox}}}, /*isElemental=*/false}, + {"fence_proxy_async", + &I::genFenceProxyAsync, + {}, + /*isElemental=*/false}, {"findloc", &I::genFindloc, {{{"array", asBox}, @@ -1004,6 +1016,25 @@ static constexpr IntrinsicHandler handlers[]{ {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false}, {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false}, {"time", &I::genTime, {}, /*isElemental=*/false}, + {"tma_bulk_commit_group", + &I::genTMABulkCommitGroup, + {{}}, + /*isElemental=*/false}, + {"tma_bulk_g2s", + &I::genTMABulkG2S, + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nbytes", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_s2g", + &I::genTMABulkS2G, + {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_wait_group", + &I::genTMABulkWaitGroup, + {{}}, + /*isElemental=*/false}, {"trailz", &I::genTrailz}, {"transfer", &I::genTransfer, @@ -2138,7 +2169,8 @@ IntrinsicLibrary::genElementalCall<IntrinsicLibrary::ExtendedGenerator>( for (const fir::ExtendedValue &arg : args) { auto *box = arg.getBoxOf<fir::BoxValue>(); if (!arg.getUnboxed() && !arg.getCharBox() && - !(box && 
fir::isScalarBoxedRecordType(fir::getBase(*box).getType())) + !(box && (fir::isScalarBoxedRecordType(fir::getBase(*box).getType()) || + fir::isClassStarType(fir::getBase(*box).getType())))) fir::emitFatalError(loc, "nonscalar intrinsic argument"); } if (outline) @@ -3180,20 +3212,61 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType, return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox); } -// BARRIER_INIT (CUDA) -void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 2); - auto llvmPtr = fir::ConvertOp::create( +static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value barrier, + mlir::NVVM::NVVMMemorySpace space) { + mlir::Value llvmPtr = fir::ConvertOp::create( builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), - fir::getBase(args[0])); - auto addrCast = mlir::LLVM::AddrSpaceCastOp::create( + barrier); + mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create( builder, loc, - mlir::LLVM::LLVMPointerType::get( - builder.getContext(), - static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)), + mlir::LLVM::LLVMPointerType::get(builder.getContext(), + static_cast<unsigned>(space)), llvmPtr); - mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, addrCast, + return addrCast; +} + +// BARRIER_ARRIVE (CUDA) +mlir::Value +IntrinsicLibrary::genBarrierArrive(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 1); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared); + return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType, + barrier) + .getResult(); +} + +// BARRIER_ARRIVE_CNT (CUDA) +mlir::Value +IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared); + mlir::Value token = fir::AllocaOp::create(builder, loc, resultType); + // TODO: the MBarrierArriveExpectTxOp does not yet take the state argument; + // currently just the sink symbol `_` is used.
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive + mlir::NVVM::MBarrierArriveExpectTxOp::create(builder, loc, barrier, args[1], + {}); + return fir::LoadOp::create(builder, loc, token); +} + +// BARRIER_INIT (CUDA) +void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared); + mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier, fir::getBase(args[1]), {}); + auto kind = mlir::NVVM::ProxyKindAttr::get( + builder.getContext(), mlir::NVVM::ProxyKind::async_shared); + auto space = mlir::NVVM::SharedSpaceAttr::get( + builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); + mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); } // BESSEL_JN @@ -4312,6 +4385,17 @@ IntrinsicLibrary::genExtendsTypeOf(mlir::Type resultType, fir::getBase(args[1]))); } +// FENCE_PROXY_ASYNC (CUDA) +void IntrinsicLibrary::genFenceProxyAsync( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + auto kind = mlir::NVVM::ProxyKindAttr::get( + builder.getContext(), mlir::NVVM::ProxyKind::async_shared); + auto space = mlir::NVVM::SharedSpaceAttr::get( + builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); + mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); +} + // FINDLOC fir::ExtendedValue IntrinsicLibrary::genFindloc(mlir::Type resultType, @@ -9127,6 +9211,46 @@ mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType, fir::runtime::genTime(builder, loc)); } +// TMA_BULK_COMMIT_GROUP (CUDA) +void IntrinsicLibrary::genTMABulkCommitGroup( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc); +} + +// TMA_BULK_G2S (CUDA) +void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared); + mlir::Value dst = + convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]), + mlir::NVVM::NVVMMemorySpace::SharedCluster); + mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]), + mlir::NVVM::NVVMMemorySpace::Global); + mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create( + builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {}); +} + +// TMA_BULK_S2G (CUDA) +void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]), + mlir::NVVM::NVVMMemorySpace::Shared); + mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]), + mlir::NVVM::NVVMMemorySpace::Global); + mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create( + builder, loc, dst, src, fir::getBase(args[2]), {}, {}); +} + +// TMA_BULK_WAIT_GROUP (CUDA) +void IntrinsicLibrary::genTMABulkWaitGroup( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + auto group = builder.getIntegerAttr(builder.getI32Type(), 0); + mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {}); +} + // TRIM fir::ExtendedValue IntrinsicLibrary::genTrim(mlir::Type resultType, diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 4a05cd9..70bb43a2 100644 --- 
a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -176,6 +176,19 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> { llvm::LogicalResult matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { + + if (auto gpuMod = addr->getParentOfType<mlir::gpu::GPUModuleOp>()) { + auto global = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol()); + replaceWithAddrOfOrASCast( + rewriter, addr->getLoc(), + global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter), + getProgramAddressSpace(rewriter), + global ? global.getSymName() + : addr.getSymbol().getRootReference().getValue(), + convertType(addr.getType()), addr); + return mlir::success(); + } + auto global = addr->getParentOfType<mlir::ModuleOp>() .lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol()); replaceWithAddrOfOrASCast( @@ -3229,6 +3242,11 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> { g.setAddrSpace( static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)); + if (global.getDataAttr() && + *global.getDataAttr() == cuf::DataAttribute::Constant) + g.setAddrSpace( + static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant)); + rewriter.eraseOp(global); return mlir::success(); } diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index 4a9579c..48e1622 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -336,6 +336,17 @@ bool isBoxedRecordType(mlir::Type ty) { return false; } +// CLASS(*) +bool isClassStarType(mlir::Type ty) { + if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) { + if (mlir::isa<mlir::NoneType>(clTy.getEleTy())) + return true; + mlir::Type innerType = clTy.unwrapInnerType(); + return innerType && mlir::isa<mlir::NoneType>(innerType); + } + return false; +} + bool isScalarBoxedRecordType(mlir::Type ty) { if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) ty = refTy; @@ -398,12 +409,8 @@ bool isPolymorphicType(mlir::Type ty) { bool isUnlimitedPolymorphicType(mlir::Type ty) { // CLASS(*) - if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) { - if (mlir::isa<mlir::NoneType>(clTy.getEleTy())) - return true; - mlir::Type innerType = clTy.unwrapInnerType(); - return innerType && mlir::isa<mlir::NoneType>(innerType); - } + if (isClassStarType(ty)) + return true; // TYPE(*) return isAssumedType(ty); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp index a48b7ba..63a5803 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp @@ -21,24 +21,27 @@ //===----------------------------------------------------------------------===// /// Log RAW or WAW conflict. -static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os, - mlir::Value writtenOrReadVarA, - mlir::Value writtenVarB); +[[maybe_unused]] static void logConflict(llvm::raw_ostream &os, + mlir::Value writtenOrReadVarA, + mlir::Value writtenVarB); /// Log when an expression evaluation must be saved. 
-static void LLVM_ATTRIBUTE_UNUSED logSaveEvaluation(llvm::raw_ostream &os, - unsigned runid, - mlir::Region &yieldRegion, - bool anyWrite); +[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os, + unsigned runid, + mlir::Region &yieldRegion, + bool anyWrite); /// Log when an assignment is scheduled. -static void LLVM_ATTRIBUTE_UNUSED logAssignmentEvaluation( - llvm::raw_ostream &os, unsigned runid, hlfir::RegionAssignOp assign); +[[maybe_unused]] static void +logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid, + hlfir::RegionAssignOp assign); /// Log when starting to schedule an order assignment tree. -static void LLVM_ATTRIBUTE_UNUSED logStartScheduling( - llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root); +[[maybe_unused]] static void +logStartScheduling(llvm::raw_ostream &os, + hlfir::OrderedAssignmentTreeOpInterface root); /// Log op if effect value is not known. -static void LLVM_ATTRIBUTE_UNUSED logIfUnkownEffectValue( - llvm::raw_ostream &os, mlir::MemoryEffects::EffectInstance effect, - mlir::Operation &op); +[[maybe_unused]] static void +logIfUnkownEffectValue(llvm::raw_ostream &os, + mlir::MemoryEffects::EffectInstance effect, + mlir::Operation &op); //===----------------------------------------------------------------------===// // Scheduling Implementation @@ -701,23 +704,24 @@ static llvm::raw_ostream &printRegionPath(llvm::raw_ostream &os, return printRegionId(os, yieldRegion); } -static void LLVM_ATTRIBUTE_UNUSED logSaveEvaluation(llvm::raw_ostream &os, - unsigned runid, - mlir::Region &yieldRegion, - bool anyWrite) { +[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os, + unsigned runid, + mlir::Region &yieldRegion, + bool anyWrite) { os << "run " << runid << " save " << (anyWrite ? 
"(w)" : " ") << ": "; printRegionPath(os, yieldRegion) << "\n"; } -static void LLVM_ATTRIBUTE_UNUSED logAssignmentEvaluation( - llvm::raw_ostream &os, unsigned runid, hlfir::RegionAssignOp assign) { +[[maybe_unused]] static void +logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid, + hlfir::RegionAssignOp assign) { os << "run " << runid << " evaluate: "; printNodePath(os, assign.getOperation()) << "\n"; } -static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os, - mlir::Value writtenOrReadVarA, - mlir::Value writtenVarB) { +[[maybe_unused]] static void logConflict(llvm::raw_ostream &os, + mlir::Value writtenOrReadVarA, + mlir::Value writtenVarB) { auto printIfValue = [&](mlir::Value var) -> llvm::raw_ostream & { if (!var) return os << "<unknown>"; @@ -728,8 +732,9 @@ static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os, printIfValue(writtenVarB) << "\n"; } -static void LLVM_ATTRIBUTE_UNUSED logStartScheduling( - llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root) { +[[maybe_unused]] static void +logStartScheduling(llvm::raw_ostream &os, + hlfir::OrderedAssignmentTreeOpInterface root) { os << "------------ scheduling "; printNodePath(os, root.getOperation()); if (auto funcOp = root->getParentOfType<mlir::func::FuncOp>()) @@ -737,9 +742,10 @@ static void LLVM_ATTRIBUTE_UNUSED logStartScheduling( os << "------------\n"; } -static void LLVM_ATTRIBUTE_UNUSED logIfUnkownEffectValue( - llvm::raw_ostream &os, mlir::MemoryEffects::EffectInstance effect, - mlir::Operation &op) { +[[maybe_unused]] static void +logIfUnkownEffectValue(llvm::raw_ostream &os, + mlir::MemoryEffects::EffectInstance effect, + mlir::Operation &op) { if (effect.getValue() != nullptr) return; os << "unknown effected value ("; diff --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt index fc23e64..790b9fd 100644 --- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Support) +add_subdirectory(Transforms) diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp index 89aa010..ed9e41c 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp @@ -21,6 +21,7 @@ #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/Support/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/BuiltinOps.h" @@ -352,6 +353,14 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) { // calculation op. mlir::Value baseRef = llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op) + .Case<fir::DeclareOp>([&](auto op) { + // If this declare binds a view with an underlying storage operand, + // treat that storage as the base reference. Otherwise, fall back + // to the declared memref. + if (auto storage = op.getStorage()) + return storage; + return mlir::Value(varPtr); + }) .Case<hlfir::DesignateOp>([&](auto op) { // Get the base object. 
return op.getMemref(); @@ -548,14 +557,27 @@ template <typename Ty> mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName, - mlir::ValueRange extents, mlir::Value initVal) const { + mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const { + needsDestroy = false; mlir::Value retVal; mlir::Type unwrappedTy = fir::unwrapRefType(type); mlir::ModuleOp mod = builder.getInsertionBlock() ->getParent() ->getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder firBuilder(builder, mod); + if (auto recType = llvm::dyn_cast<fir::RecordType>( + fir::getFortranElementType(unwrappedTy))) { + // Need to make deep copies of allocatable components. + if (fir::isRecordWithAllocatableMember(recType)) + TODO(loc, + "OpenACC: privatizing derived type with allocatable components"); + // Need to decide if user assignment/final routine should be called. + if (fir::isRecordWithFinalRoutine(recType, mod).value_or(false)) + TODO(loc, "OpenACC: privatizing derived type with user assignment or " + "final routine "); + } + + fir::FirOpBuilder firBuilder(builder, mod); auto getDeclareOpForType = [&](mlir::Type ty) -> hlfir::DeclareOp { auto alloca = fir::AllocaOp::create(firBuilder, loc, ty); return hlfir::DeclareOp::create(firBuilder, loc, alloca, varName); @@ -615,9 +637,11 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( mlir::Value firClass = fir::EmboxOp::create(builder, loc, boxTy, allocatedScalar); fir::StoreOp::create(builder, loc, firClass, retVal); + needsDestroy = true; } else if (mlir::isa<fir::SequenceType>(innerTy)) { hlfir::Entity source = hlfir::Entity{var}; - auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source); + auto [temp, cleanupFlag] = + hlfir::createTempFromMold(loc, firBuilder, source); if (fir::isa_ref_type(type)) { // When the temp is created - it is not a reference - thus we can // end up with a type inconsistency. Therefore ensure storage is created @@ -636,6 +660,9 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( } else { retVal = temp; } + // If heap was allocated, a destroy is required later. 
+ if (cleanupFlag) + needsDestroy = true; } else { TODO(loc, "Unsupported boxed type for OpenACC private-like recipe"); } @@ -667,23 +694,302 @@ template mlir::Value OpenACCMappableModel<fir::BaseBoxType>::generatePrivateInit( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName, - mlir::ValueRange extents, mlir::Value initVal) const; + mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const; template mlir::Value OpenACCMappableModel<fir::ReferenceType>::generatePrivateInit( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName, - mlir::ValueRange extents, mlir::Value initVal) const; + mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const; template mlir::Value OpenACCMappableModel<fir::HeapType>::generatePrivateInit( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName, - mlir::ValueRange extents, mlir::Value initVal) const; + mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const; template mlir::Value OpenACCMappableModel<fir::PointerType>::generatePrivateInit( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName, - mlir::ValueRange extents, mlir::Value initVal) const; + mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const; + +template <typename Ty> +bool OpenACCMappableModel<Ty>::generatePrivateDestroy( + mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value privatized) const { + mlir::Type unwrappedTy = fir::unwrapRefType(type); + // For boxed scalars allocated with AllocMem during init, free the heap. + if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(unwrappedTy)) { + mlir::Value boxVal = privatized; + if (fir::isa_ref_type(boxVal.getType())) + boxVal = fir::LoadOp::create(builder, loc, boxVal); + mlir::Value addr = fir::BoxAddrOp::create(builder, loc, boxVal); + // FreeMem only accepts fir.heap and this may not be represented in the box + // type if the privatized entity is not an allocatable. + mlir::Type heapType = + fir::HeapType::get(fir::unwrapRefType(addr.getType())); + if (heapType != addr.getType()) + addr = fir::ConvertOp::create(builder, loc, heapType, addr); + fir::FreeMemOp::create(builder, loc, addr); + return true; + } + + // Nothing to do for other categories by default, they are stack allocated. 
+ return true; +} + +template bool OpenACCMappableModel<fir::BaseBoxType>::generatePrivateDestroy( + mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value privatized) const; +template bool OpenACCMappableModel<fir::ReferenceType>::generatePrivateDestroy( + mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value privatized) const; +template bool OpenACCMappableModel<fir::HeapType>::generatePrivateDestroy( + mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value privatized) const; +template bool OpenACCMappableModel<fir::PointerType>::generatePrivateDestroy( + mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value privatized) const; + +template <typename Ty> +mlir::Value OpenACCPointerLikeModel<Ty>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const { + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types are descriptors that contain both metadata and a pointer to data. + // The `genAllocate` API is designed for simple allocations and cannot + // properly handle the dual nature of boxes. Using `generatePrivateInit` + // instead can allocate both the descriptor and its referenced data. For use + // cases that require an empty descriptor storage, potentially this could be + // implemented here. + if (fir::isa_box_type(pointeeTy)) + return {}; + + // Unlimited polymorphic (class(*)) cannot be handled - size unknown + if (fir::isUnlimitedPolymorphicType(pointeeTy)) + return {}; + + // Return null for dynamic size types because the size of the + // allocation cannot be determined simply from the type. + if (fir::hasDynamicSize(pointeeTy)) + return {}; + + // Use heap allocation for fir.heap, stack allocation for others (fir.ref, + // fir.ptr, fir.llvm_ptr). For fir.ptr, which is supposed to represent a + // Fortran pointer type, it feels a bit odd to "allocate" since it is meant + // to point to an existing entity - but one can imagine where a pointee is + // privatized - thus it makes sense to issue an allocate. + mlir::Value allocation; + if (std::is_same_v<Ty, fir::HeapType>) { + needsFree = true; + allocation = fir::AllocMemOp::create(builder, loc, pointeeTy); + } else { + needsFree = false; + allocation = fir::AllocaOp::create(builder, loc, pointeeTy); + } + + // Convert to the requested pointer type if needed. + // This means converting from a fir.ref to either a fir.llvm_ptr or a fir.ptr. + // fir.heap is already correct type in this case. 
+ if (allocation.getType() != pointer) { + assert(!(std::is_same_v<Ty, fir::HeapType>) && + "fir.heap is already correct type because of allocmem"); + return fir::ConvertOp::create(builder, loc, pointer, allocation); + } + + return allocation; +} + +template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +static mlir::Value stripCasts(mlir::Value value, bool stripDeclare = true) { + mlir::Value currentValue = value; + + while (currentValue) { + auto *definingOp = currentValue.getDefiningOp(); + if (!definingOp) + break; + + if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) { + currentValue = convertOp.getValue(); + continue; + } + + if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) { + currentValue = viewLike.getViewSource(); + continue; + } + + if (stripDeclare) { + if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) { + currentValue = declareOp.getMemref(); + continue; + } + + if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) { + currentValue = declareOp.getMemref(); + continue; + } + } + break; + } + + return currentValue; +} + +template <typename Ty> +bool OpenACCPointerLikeModel<Ty>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const { + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types contain both a descriptor and data. The `genFree` API + // handles simple deallocations and cannot properly manage both parts. + // Using `generatePrivateDestroy` instead can free both the descriptor and + // its referenced data. + if (fir::isa_box_type(pointeeTy)) + return false; + + // If pointer type is HeapType, assume it's a heap allocation + if (std::is_same_v<Ty, fir::HeapType>) { + fir::FreeMemOp::create(builder, loc, varToFree); + return true; + } + + // Use allocRes if provided to determine the allocation type + mlir::Value valueToInspect = allocRes ? 
allocRes : varToFree; + + // Strip casts and declare operations to find the original allocation + mlir::Value strippedValue = stripCasts(valueToInspect); + mlir::Operation *originalAlloc = strippedValue.getDefiningOp(); + + // If we found an AllocMemOp (heap allocation), free it + if (mlir::isa_and_nonnull<fir::AllocMemOp>(originalAlloc)) { + mlir::Value toFree = varToFree; + if (!mlir::isa<fir::HeapType>(valueToInspect.getType())) + toFree = fir::ConvertOp::create( + builder, loc, + fir::HeapType::get(varToFree.getType().getElementType()), toFree); + fir::FreeMemOp::create(builder, loc, toFree); + return true; + } + + // If we found an AllocaOp (stack allocation), no deallocation needed + if (mlir::isa_and_nonnull<fir::AllocaOp>(originalAlloc)) + return true; + + // Unable to determine allocation type + return false; +} + +template bool OpenACCPointerLikeModel<fir::ReferenceType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::PointerType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::HeapType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template <typename Ty> +bool OpenACCPointerLikeModel<Ty>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const { + + // Check that source and destination types match + if (source.getType() != destination.getType()) + return false; + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types contain both a descriptor and referenced data. The genCopy API + // handles simple copies and cannot properly manage both parts. + if (fir::isa_box_type(pointeeTy)) + return false; + + // Unlimited polymorphic (class(*)) cannot be handled because source and + // destination types are not known. + if (fir::isUnlimitedPolymorphicType(pointeeTy)) + return false; + + // Return false for dynamic size types because the copy logic + // cannot be determined simply from the type. 
+ if (fir::hasDynamicSize(pointeeTy)) + return false; + + if (fir::isa_trivial(pointeeTy)) { + auto loadVal = fir::LoadOp::create(builder, loc, source); + fir::StoreOp::create(builder, loc, loadVal, destination); + } else { + hlfir::AssignOp::create(builder, loc, source, destination); + } + return true; +} + +template bool OpenACCPointerLikeModel<fir::ReferenceType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::PointerType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::HeapType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp new file mode 100644 index 0000000..4840a99 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp @@ -0,0 +1,191 @@ +//===- ACCRecipeBufferization.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Bufferize OpenACC recipes that yield fir.box<T> to operate on +// fir.ref<fir.box<T>> and update uses accordingly. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace fir::acc {
+#define GEN_PASS_DEF_ACCRECIPEBUFFERIZATION
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace fir::acc
+
+namespace {
+
+class BufferizeInterface {
+public:
+  static std::optional<mlir::Type> mustBufferize(mlir::Type recipeType) {
+    if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(recipeType))
+      return fir::ReferenceType::get(boxTy);
+    return std::nullopt;
+  }
+
+  static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc,
+                               mlir::Value value) {
+    return builder.create<fir::LoadOp>(loc, value);
+  }
+
+  static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc,
+                                   mlir::Value value) {
+    auto alloca = builder.create<fir::AllocaOp>(loc, value.getType());
+    builder.create<fir::StoreOp>(loc, value, alloca);
+    return alloca;
+  }
+};
+
+static void bufferizeRegionArgsAndYields(mlir::Region &region,
+                                         mlir::Location loc, mlir::Type oldType,
+                                         mlir::Type newType) {
+  if (region.empty())
+    return;
+
+  mlir::OpBuilder builder(&region);
+  for (mlir::BlockArgument arg : region.getArguments()) {
+    if (arg.getType() == oldType) {
+      arg.setType(newType);
+      if (!arg.use_empty()) {
+        mlir::Operation *loadOp = BufferizeInterface::load(builder, loc, arg);
+        arg.replaceAllUsesExcept(loadOp->getResult(0), loadOp);
+      }
+    }
+  }
+  if (auto yield =
+          llvm::dyn_cast<mlir::acc::YieldOp>(region.back().getTerminator())) {
+    llvm::SmallVector<mlir::Value> newOperands;
+    newOperands.reserve(yield.getNumOperands());
+    bool changed = false;
+    for (mlir::Value oldYieldArg : yield.getOperands()) {
+      if (oldYieldArg.getType() == oldType) {
+        builder.setInsertionPoint(yield);
+        mlir::Value alloca =
+            BufferizeInterface::placeInMemory(builder, loc, oldYieldArg);
+        newOperands.push_back(alloca);
+        changed = true;
+      } else {
+        newOperands.push_back(oldYieldArg);
+      }
+    }
+    if (changed)
+      yield->setOperands(newOperands);
+  }
+}
+
+static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands,
+                            llvm::StringRef recipeSymName,
+                            mlir::Operation *computeOp) {
+  if (!recipes)
+    return;
+  for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) {
+    if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() !=
+        recipeSymName)
+      continue;
+
+    mlir::Operation *dataOp = oldRes.getDefiningOp();
+    assert(dataOp && "dataOp must be paired with computeOp");
+    mlir::Location loc = dataOp->getLoc();
+    mlir::OpBuilder builder(dataOp);
+    llvm::TypeSwitch<mlir::Operation *, void>(dataOp)
+        .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp,
+              mlir::acc::ReductionOp>([&](auto privateOp) {
+          builder.setInsertionPointAfterValue(privateOp.getVar());
+          mlir::Value alloca = BufferizeInterface::placeInMemory(
+              builder, loc, privateOp.getVar());
+          privateOp.getVarMutable().assign(alloca);
+          privateOp.getAccVar().setType(alloca.getType());
+        });
+
+    llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(),
+                                               oldRes.getUsers().end());
+    for (mlir::Operation *useOp : users) {
+      if (useOp == computeOp)
+        continue;
+      builder.setInsertionPoint(useOp);
+      mlir::Operation *load = BufferizeInterface::load(builder, loc,
oldRes); + useOp->replaceUsesOfWith(oldRes, load->getResult(0)); + } + } +} + +class ACCRecipeBufferization + : public fir::acc::impl::ACCRecipeBufferizationBase< + ACCRecipeBufferization> { +public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + + llvm::SmallVector<llvm::StringRef> recipeNames; + module.walk([&](mlir::Operation *recipe) { + llvm::TypeSwitch<mlir::Operation *, void>(recipe) + .Case<mlir::acc::PrivateRecipeOp, mlir::acc::FirstprivateRecipeOp, + mlir::acc::ReductionRecipeOp>([&](auto recipe) { + mlir::Type oldType = recipe.getType(); + auto bufferizedType = + BufferizeInterface::mustBufferize(recipe.getType()); + if (!bufferizedType) + return; + recipe.setTypeAttr(mlir::TypeAttr::get(*bufferizedType)); + mlir::Location loc = recipe.getLoc(); + using RecipeOp = decltype(recipe); + bufferizeRegionArgsAndYields(recipe.getInitRegion(), loc, oldType, + *bufferizedType); + if constexpr (std::is_same_v<RecipeOp, + mlir::acc::FirstprivateRecipeOp>) + bufferizeRegionArgsAndYields(recipe.getCopyRegion(), loc, oldType, + *bufferizedType); + if constexpr (std::is_same_v<RecipeOp, + mlir::acc::ReductionRecipeOp>) + bufferizeRegionArgsAndYields(recipe.getCombinerRegion(), loc, + oldType, *bufferizedType); + bufferizeRegionArgsAndYields(recipe.getDestroyRegion(), loc, + oldType, *bufferizedType); + recipeNames.push_back(recipe.getSymName()); + }); + }); + if (recipeNames.empty()) + return; + + module.walk([&](mlir::Operation *op) { + llvm::TypeSwitch<mlir::Operation *, void>(op) + .Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>( + [&](auto computeOp) { + for (llvm::StringRef recipeName : recipeNames) { + if (computeOp.getPrivatizationRecipes()) + updateRecipeUse(computeOp.getPrivatizationRecipesAttr(), + computeOp.getPrivateOperands(), recipeName, + op); + if (computeOp.getFirstprivatizationRecipes()) + updateRecipeUse( + computeOp.getFirstprivatizationRecipesAttr(), + computeOp.getFirstprivateOperands(), recipeName, op); + if (computeOp.getReductionRecipes()) + updateRecipeUse(computeOp.getReductionRecipesAttr(), + computeOp.getReductionOperands(), + recipeName, op); + } + }); + }); + } +}; + +} // namespace + +std::unique_ptr<mlir::Pass> fir::acc::createACCRecipeBufferizationPass() { + return std::make_unique<ACCRecipeBufferization>(); +} diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt new file mode 100644 index 0000000..2427da0 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_flang_library(FIROpenACCTransforms + ACCRecipeBufferization.cpp + + DEPENDS + FIROpenACCPassesIncGen + + LINK_LIBS + MLIRIR + MLIRPass + FIRDialect + MLIROpenACCDialect +) diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index b85ee7e..23a7dc8 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms MapsForPrivatizedSymbols.cpp MapInfoFinalization.cpp MarkDeclareTarget.cpp + LowerWorkdistribute.cpp LowerWorkshare.cpp LowerNontemporal.cpp SimdOnly.cpp diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp new file mode 100644 index 0000000..9278e17 --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -0,0 +1,1852 @@ +//===- LowerWorkdistribute.cpp +//-------------------------------------------------===// +// +// 
Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+// Fortran array statements are lowered to fir as fir.do_loop unordered.
+// The lower-workdistribute pass mainly identifies fir.do_loop unordered ops
+// nested in target{teams{workdistribute{fir.do_loop unordered}}} and lowers
+// them to target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
+// It hoists all other ops out of the target region.
+// It replaces heap allocation on the target with omp.target_allocmem and
+// deallocation with omp.target_freemem, both issued from the host. It also
+// replaces the runtime function "Assign" with omp_target_memcpy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// This string is used to identify the Fortran-specific runtime FortranAAssign.
+static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign";
+
+/// The isRuntimeCall function is a utility designed to determine
+/// if a given operation is a call to a Fortran-specific runtime function.
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func && func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+/// Parallelize here refers to dividing into units of work.
+static bool shouldParallelize(Operation *op) {
+  // True if the op is a runtime call to Assign
+  if (isRuntimeCall(op)) {
+    fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+    auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+    if (funcName == FortranAssignStr) {
+      return true;
+    }
+  }
+  // We cannot parallelize ops with side effects.
+  // Parallelizable operations should not produce
+  // values that other operations depend on.
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  // We cannot parallelize anything else.
+  return false;
+}
+
+/// The getPerfectlyNested function is a generic utility for finding
+/// a single, "perfectly nested" operation within a parent operation.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// verifyTargetTeamsWorkdistribute method verifies that
+/// omp.target { teams { workdistribute { ... } } } is well formed
+/// and fails for function calls that don't have lowering implemented yet.
+static LogicalResult
+verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return failure();
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return failure();
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return failure();
+  }
+
+  bool foundWorkdistribute = false;
+  for (auto &op : teams.getOps()) {
+    if (isa<omp::WorkdistributeOp>(op)) {
+      if (foundWorkdistribute) {
+        emitError(loc, "teams has multiple workdistribute ops.\n");
+        return failure();
+      }
+      foundWorkdistribute = true;
+      continue;
+    }
+    // Identify any omp dialect ops present before/after workdistribute.
+    if (op.getDialect() && isa<omp::OpenMPDialect>(op.getDialect()) &&
+        !isa<omp::TerminatorOp>(op)) {
+      emitError(loc, "teams has omp ops other than workdistribute. Lowering "
+                     "not implemented yet.\n");
+      return failure();
+    }
+  }
+
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  // return if not omp.target
+  if (!targetOp)
+    return success();
+
+  for (auto &op : workdistribute.getOps()) {
+    if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+      if (isRuntimeCall(&op)) {
+        auto funcName = (*callOp.getCallee()).getRootReference().getValue();
+        // _FortranAAssign is handled. Other runtime calls are not supported
+        // in omp.workdistribute yet.
+        if (funcName == FortranAssignStr)
+          continue;
+        else {
+          emitError(loc, "Runtime call " + funcName +
+                             " lowering not supported for workdistribute yet.");
+          return failure();
+        }
+      }
+    }
+  }
+  return success();
+}
+
+/// fissionWorkdistribute method finds the parallelizable ops
+/// within teams {workdistribute} region and moves them to their
+/// own teams{workdistribute} region.
+///
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+static FailureOr<bool>
+fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc,
+                "teams has parallelizable ops before the first workdistribute\n");
+      return failure();
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute.
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// Generate omp.parallel operation with an empty region.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+  return;
+}
+
+/// Generate omp.distribute operation with an empty region.
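+/// The builder's insertion point is left inside the freshly created region
+/// so that the caller can build the nested ops (see genWsLoopOp below) in
+/// place.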
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands distributeClauseOps;
+  auto distributeOp =
+      rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+  distributeOp.setComposite(composite);
+  auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+  rewriter.setInsertionPointToStart(distributeBlock);
+  return;
+}
+
+/// Generate loop nest clause operands from fir.do_loop operation.
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Generate omp.wsloop operation with an empty region and
+/// clone the body of fir.do_loop operation inside the loop nest region.
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase fir.result op of do loop and create yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+}
+
+/// workdistributeDoLower method finds the fir.do_loop unordered
+/// nested in teams {workdistribute{fir.do_loop unordered}} and
+/// lowers it to teams {parallel { distribute {wsloop {loop_nest}}}}.
+///
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+static bool
+workdistributeDoLower(omp::WorkdistributeOp workdistribute,
+                      SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (doLoop && shouldParallelize(doLoop)) {
+    assert(doLoop.getReduceOperands().empty());
+
+    // Record the target ops to process later
+    if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp) {
+        targetOpsToProcess.insert(targetOp);
+      }
+    }
+    // Generate the nested parallel, distribute, wsloop and loop_nest ops.
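+    // Sketch of the nest generated below for a loop such as
+    // `fir.do_loop %i = %lb to %ub step %st unordered` (names illustrative):
+    //
+    //   omp.parallel {
+    //     omp.distribute {
+    //       omp.wsloop {
+    //         omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%st) {
+    //           ... cloned loop body ...
+    //           omp.yield
+    //         }
+    //       }
+    //     }
+    //   }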
+ genParallelOp(wdLoc, rewriter, true); + genDistributeOp(wdLoc, rewriter, true); + mlir::omp::LoopNestOperands loopNestClauseOps; + genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); + genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); + workdistribute.erase(); + return true; + } + return false; +} + +/// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array +static bool isEnclosedTypeRefToBoxArray(Type type) { + // Check if it's a reference type + if (auto refType = dyn_cast<fir::ReferenceType>(type)) { + // Get the referenced type (should be fir.box) + auto referencedType = refType.getEleTy(); + // Check if referenced type is a box + if (auto boxType = dyn_cast<fir::BoxType>(referencedType)) { + // Get the boxed type and check if it's an array + auto boxedType = boxType.getEleTy(); + // Check if boxed type is a sequence (array) + return isa<fir::SequenceType>(boxedType); + } + } + return false; +} + +/// Check if the enclosed type in fir.box is scalar (not array) +static bool isEnclosedTypeBoxScalar(Type type) { + // Check if it's a box type + if (auto boxType = dyn_cast<fir::BoxType>(type)) { + // Get the boxed type + auto boxedType = boxType.getEleTy(); + // Check if boxed type is NOT a sequence (array) + return !isa<fir::SequenceType>(boxedType); + } + return false; +} + +/// Check if the FortranAAssign call has src as scalar and dest as array +static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) { + if (callOp.getNumOperands() < 2) + return false; + auto srcArg = callOp.getOperand(1); + auto destArg = callOp.getOperand(0); + // Both operands should be fir.convert ops + auto srcConvert = srcArg.getDefiningOp<fir::ConvertOp>(); + auto destConvert = destArg.getDefiningOp<fir::ConvertOp>(); + if (!srcConvert || !destConvert) { + emitError(callOp->getLoc(), + "Unimplemented: FortranAssign to OpenMP lowering\n"); + return false; + } + // Get the original types before conversion + auto srcOrigType = srcConvert.getValue().getType(); + auto destOrigType = destConvert.getValue().getType(); + + // Check if src is scalar and dest is array + bool srcIsScalar = isEnclosedTypeBoxScalar(srcOrigType); + bool destIsArray = isEnclosedTypeRefToBoxArray(destOrigType); + return srcIsScalar && destIsArray; +} + +/// Convert a flat index to multi-dimensional indices for an array box +/// Example: 2D array with shape (2,4) +/// Col 1 Col 2 Col 3 Col 4 +/// Row 1: (1,1) (1,2) (1,3) (1,4) +/// Row 2: (2,1) (2,2) (2,3) (2,4) +/// +/// extents: (2,4) +/// +/// flatIdx: 0 1 2 3 4 5 6 7 +/// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4) +static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder, + Location loc, Value flatIdx, + Value arrayBox) { + // Get array type and rank + auto boxType = cast<fir::BoxType>(arrayBox.getType()); + auto seqType = cast<fir::SequenceType>(boxType.getEleTy()); + int rank = seqType.getDimension(); + + // Get all extents + SmallVector<Value> extents; + // Get extents for each dimension + for (int i = 0; i < rank; ++i) { + auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + extents.push_back(boxDims.getResult(1)); + } + + // Convert flat index to multi-dimensional indices + SmallVector<Value> indices(rank); + Value temp = flatIdx; + auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + + // Work backwards through dimensions (row-major order) + for (int i = rank - 1; i >= 0; --i) { + Value zeroBasedIdx = 
builder.create<arith::RemSIOp>(loc, temp, extents[i]); + // Convert to one-based index + indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1); + if (i > 0) { + temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]); + } + } + + return indices; +} + +/// Calculate the total number of elements in the array box +/// (totalElems = extent(1) * extent(2) * ... * extent(n)) +static Value CalculateTotalElements(OpBuilder &builder, Location loc, + Value arrayBox) { + auto boxType = cast<fir::BoxType>(arrayBox.getType()); + auto seqType = cast<fir::SequenceType>(boxType.getEleTy()); + int rank = seqType.getDimension(); + + Value totalElems = nullptr; + for (int i = 0; i < rank; ++i) { + auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + Value extent = boxDims.getResult(1); + if (i == 0) { + totalElems = extent; + } else { + totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent); + } + } + return totalElems; +} + +/// Replace the FortranAAssign runtime call with an unordered do loop +static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, + omp::TeamsOp teamsOp, + omp::WorkdistributeOp workdistribute, + fir::CallOp callOp) { + auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>(); + auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>(); + + Value destBox = destConvert.getValue(); + Value srcBox = srcConvert.getValue(); + + // get defining alloca op of destBox and srcBox + auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>(); + + if (!destAlloca) { + emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n"); + return; + } + + // get the store op that stores to the alloca + for (auto user : destAlloca->getUsers()) { + if (auto storeOp = dyn_cast<fir::StoreOp>(user)) { + destBox = storeOp.getValue(); + break; + } + } + + builder.setInsertionPoint(teamsOp); + // Load destination array box (if it's a reference) + Value arrayBox = destBox; + if (isa<fir::ReferenceType>(destBox.getType())) + arrayBox = builder.create<fir::LoadOp>(loc, destBox); + + auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox); + Value scalar = builder.create<fir::LoadOp>(loc, scalarValue); + + // Calculate total number of elements (flattened) + auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0); + auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + Value totalElems = CalculateTotalElements(builder, loc, arrayBox); + + auto *workdistributeBlock = &workdistribute.getRegion().front(); + builder.setInsertionPointToStart(workdistributeBlock); + // Create single unordered loop for flattened array + auto doLoop = fir::DoLoopOp::create(builder, loc, c0, totalElems, c1, true); + Block *loopBlock = &doLoop.getRegion().front(); + builder.setInsertionPointToStart(doLoop.getBody()); + + auto flatIdx = loopBlock->getArgument(0); + SmallVector<Value> indices = + convertFlatToMultiDim(builder, loc, flatIdx, arrayBox); + // Use fir.array_coor for linear addressing + auto elemPtr = fir::ArrayCoorOp::create( + builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, + nullptr, nullptr, ValueRange{indices}, ValueRange{}); + + builder.create<fir::StoreOp>(loc, scalar, elemPtr); +} + +/// workdistributeRuntimeCallLower method finds the runtime calls +/// nested in teams {workdistribute{}} and +/// lowers FortranAAssign to unordered do loop if src is scalar and dest is +/// array. Other runtime calls are not handled currently. 
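+/// For example (illustrative), the assignment `a = 0.0` for a rank-2 array
+/// `a` reaches this point as a _FortranAAssign call taking boxed dest/src
+/// arguments; it is rewritten into a single flattened fir.do_loop unordered
+/// that stores the scalar through fir.array_coor, which
+/// workdistributeDoLower can then lower further.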
+static FailureOr<bool>
+workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
+                               SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return failure();
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return failure();
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return failure();
+  }
+  bool changed = false;
+  // Get the target op parent of teams
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  SmallVector<Operation *> opsToErase;
+  for (auto &op : workdistribute.getOps()) {
+    if (isRuntimeCall(&op)) {
+      rewriter.setInsertionPoint(&op);
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) {
+          // Record the target ops to process later
+          targetOpsToProcess.insert(targetOp);
+          replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute,
+                                     runtimeCall);
+          opsToErase.push_back(&op);
+          changed = true;
+        }
+      }
+    }
+  }
+  // Erase the runtime calls that have been replaced.
+  for (auto *op : opsToErase) {
+    op->erase();
+  }
+  return changed;
+}
+
+/// teamsWorkdistributeToSingleOp method hoists all the ops inside
+/// teams {workdistribute{}} before the teams op.
+///
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// A()
+/// B()
+///
+/// If only the terminator remains in teams after hoisting, we erase the
+/// teams op.
+static bool
+teamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
+                              SetVector<omp::TargetOp> &targetOpsToProcess) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // Get the block containing teamsOp (the parent block).
+  Block *parentBlock = teamsOp->getBlock();
+  Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  // Record the target ops to process later
+  for (auto &op : workdistributeBlock.getOperations()) {
+    if (shouldParallelize(&op)) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp) {
+        targetOpsToProcess.insert(targetOp);
+      }
+    }
+  }
+  auto insertPoint = Block::iterator(teamsOp);
+  // Get the range of operations to move (excluding the terminator).
+  auto workdistributeBegin = workdistributeBlock.begin();
+  auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+  // Move the operations from workdistribute block to before teamsOp.
+  parentBlock->getOperations().splice(insertPoint,
+                                      workdistributeBlock.getOperations(),
+                                      workdistributeBegin, workdistributeEnd);
+  // Erase the now-empty workdistributeOp.
+  workdistributeOp.erase();
+  Block &teamsBlock = *teamsOp.getRegion().begin();
+  // Check if only the terminator remains and erase teams op.
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr) {
+    teamsOp.erase();
+  }
+  return true;
+}
+
+/// If multiple workdistribute ops are nested in a target region, we will need
+/// to split the target region, but we want to preserve the data semantics of
+/// the original data region and avoid unnecessary data movement at each of
+/// the subkernels - we split the target region into a target_data{target}
+/// nest where only the outer one moves the data.
+FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+                                         RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    emitError(loc, "Target region has no data maps\n");
+    return failure();
+  }
+  // Collect all the mapinfo ops
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+  // Create new mapinfo ops for the inner target region
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType =
+        (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+    // For bycopy, we keep the same map type and capture type
+    // For byref, we change the map type to none and keep the capture type
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      emitError(targetOp->getLoc(), "Unhandled case");
+      return failure();
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  // Create the target data op
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+  // Create the inner target op
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+  rewriter.replaceOp(targetOp, targetDataOp);
+  return newTargetOp;
+}
+
+/// getNestedOpToIsolate is designed to identify the teams op within the body
+/// of an omp::TargetOp that should be "isolated". It returns a tuple of the
+/// op and two flags telling whether that op is the first op or the last op
+/// in the target block.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto &op : *targetBlock) {
+    bool first = &op == &*targetBlock->begin();
+    bool last = op.getNextNode() == targetBlock->getTerminator();
+    if (first && last)
+      return std::nullopt;
+
+    if (isa<omp::TeamsOp>(&op))
+      return {{&op, first, last}};
+  }
+  return std::nullopt;
+}
+
+/// Temporary structure to hold the two mapinfo ops
+struct TempOmpVar {
+  omp::MapInfoOp from, to;
+};
+
+/// isPtr checks if the type is a pointer or reference type.
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+/// getPtrTypeForOmp returns an LLVM pointer type for the given type.
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  else
+    return fir::ReferenceType::get(ty);
+}
+
+/// allocateTempOmpVar allocates a temporary variable for OpenMP mapping
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  // Get the appropriate type for allocation
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    allocType = intTy;
+  } else {
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  // Lambda to create mapinfo ops
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  // Create mapinfo ops.
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+// usedOutsideSplit checks if a value is used outside the split operation.
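+// A use nested inside a region is attributed to its enclosing op in the
+// target block, so a value consumed deep within an op placed after the
+// split point still counts as used outside the split.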
+static bool usedOutsideSplit(Value v, Operation *split) { + if (!split) + return false; + auto targetOp = cast<omp::TargetOp>(split->getParentOp()); + auto *targetBlock = &targetOp.getRegion().front(); + for (auto *user : v.getUsers()) { + while (user->getBlock() != targetBlock) { + user = user->getParentOp(); + } + if (!user->isBeforeInBlock(split)) + return true; + } + return false; +} + +/// isRecomputableAfterFission checks if an operation can be recomputed +static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { + // If the op has side effects, it cannot be recomputed. + // We consider fir.declare as having no side effects. + return isa<fir::DeclareOp>(op) || isMemoryEffectFree(op); +} + +/// collectNonRecomputableDeps collects dependencies that cannot be recomputed +static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, + SetVector<Operation *> &nonRecomputable, + SetVector<Operation *> &toCache, + SetVector<Operation *> &toRecompute) { + Operation *op = v.getDefiningOp(); + // If v is a block argument, it must be from the targetOp. + if (!op) { + assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp); + return; + } + // If the op is in the nonRecomputable set, add it to toCache and return. + if (nonRecomputable.contains(op)) { + toCache.insert(op); + return; + } + // Add the op to toRecompute. + toRecompute.insert(op); + for (auto opr : op->getOperands()) + collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, + toRecompute); +} + +/// createBlockArgsAndMap creates block arguments and maps them +static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, + omp::TargetOp &targetOp, Block *targetBlock, + Block *newTargetBlock, + SmallVector<Value> &hostEvalVars, + SmallVector<Value> &mapOperands, + SmallVector<Value> &allocs, + IRMapping &irMapping) { + // FIRST: Map `host_eval_vars` to block arguments + unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size(); + for (unsigned i = 0; i < hostEvalVars.size(); ++i) { + Value originalValue; + BlockArgument newArg; + if (i < originalHostEvalVarsSize) { + originalValue = targetBlock->getArgument(i); // Host_eval args come first + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } else { + originalValue = hostEvalVars[i]; + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } + irMapping.map(originalValue, newArg); + } + + // SECOND: Map `map_operands` to block arguments + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + for (unsigned i = 0; i < mapOperands.size(); ++i) { + Value originalValue; + BlockArgument newArg; + // Map the new arguments from the original block. + if (i < originalMapVarsSize) { + originalValue = targetBlock->getArgument(originalHostEvalVarsSize + + i); // Offset by host_eval count + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } + // Map the new arguments from the `allocs`. 
+ else { + originalValue = allocs[i - originalMapVarsSize]; + newArg = newTargetBlock->addArgument( + getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc()); + } + irMapping.map(originalValue, newArg); + } + + // THIRD: Map `private_vars` to block arguments (if any) + unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size(); + for (unsigned i = 0; i < originalPrivateVarsSize; ++i) { + auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize + + originalMapVarsSize + i); + auto newArg = newTargetBlock->addArgument(originalArg.getType(), + originalArg.getLoc()); + irMapping.map(originalArg, newArg); + } + return; +} + +/// reloadCacheAndRecompute reloads cached values and recomputes operations +static void reloadCacheAndRecompute( + Location loc, RewriterBase &rewriter, Operation *splitBefore, + omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, + SmallVector<Value> &hostEvalVars, SmallVector<Value> &mapOperands, + SmallVector<Value> &allocs, SetVector<Operation *> &toRecompute, + IRMapping &irMapping) { + // Handle the load operations for the allocs. + rewriter.setInsertionPointToStart(newTargetBlock); + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + unsigned hostEvalVarsSize = hostEvalVars.size(); + // Create load operations for each allocated variable. + for (unsigned i = 0; i < allocs.size(); ++i) { + Value original = allocs[i]; + // Get the new block argument for this specific allocated value. + Value newArg = + newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i); + Value restored; + // If the original value is a pointer or reference, load and convert if + // necessary. + if (isPtr(original.getType())) { + restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg); + if (!isa<LLVM::LLVMPointerType>(original.getType())) + restored = + rewriter.create<fir::ConvertOp>(loc, original.getType(), restored); + } else { + restored = rewriter.create<fir::LoadOp>(loc, newArg); + } + irMapping.map(original, restored); + } + // Clone the operations if they are in the toRecompute set. + for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) { + if (toRecompute.contains(&*it)) + rewriter.clone(*it, irMapping); + } +} + +/// Given a teamsOp, navigate down the nested structure to find the +/// innermost LoopNestOp. The expected nesting is: +/// teams -> parallel -> distribute -> wsloop -> loop_nest +static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) { + if (teamsOp.getRegion().empty()) + return nullptr; + // Ensure the teams region has a single block. 
+  if (teamsOp.getRegion().getBlocks().size() != 1)
+    return nullptr;
+  // Find parallel op inside teams
+  mlir::omp::ParallelOp parallelOp = nullptr;
+  // Look for the parallel op in the teams region
+  for (auto &op : teamsOp.getRegion().front()) {
+    if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
+      parallelOp = parallel;
+      break;
+    }
+  }
+  if (!parallelOp)
+    return nullptr;
+
+  // Find distribute op inside parallel
+  mlir::omp::DistributeOp distributeOp = nullptr;
+  for (auto &op : parallelOp.getRegion().front()) {
+    if (auto distribute = dyn_cast<mlir::omp::DistributeOp>(op)) {
+      distributeOp = distribute;
+      break;
+    }
+  }
+  if (!distributeOp)
+    return nullptr;
+
+  // Find wsloop op inside distribute
+  mlir::omp::WsloopOp wsloopOp = nullptr;
+  for (auto &op : distributeOp.getRegion().front()) {
+    if (auto wsloop = dyn_cast<mlir::omp::WsloopOp>(op)) {
+      wsloopOp = wsloop;
+      break;
+    }
+  }
+  if (!wsloopOp)
+    return nullptr;
+
+  // Find loop_nest op inside wsloop
+  for (auto &op : wsloopOp.getRegion().front()) {
+    if (auto loopNest = dyn_cast<mlir::omp::LoopNestOp>(op)) {
+      return loopNest;
+    }
+  }
+
+  return nullptr;
+}
+
+/// Generate an LLVM constant operation of i32 type.
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+  mlir::Type i32Ty = rewriter.getI32Type();
+  mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+}
+
+/// Given a box descriptor, extract the base address of the data it describes.
+/// If the box descriptor is a reference, load it first.
+/// The base address is returned as an i8* pointer.
+static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
+                                         Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetBaseAddress");
+  auto i8Type = builder.getI8Type();
+  auto unknownArrayType =
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type);
+  auto i8BoxType = fir::BoxType::get(unknownArrayType);
+  auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box);
+  auto rawAddr = fir::BoxAddrOp::create(builder, loc, typedBox);
+  return rawAddr;
+}
+
+/// Given a box descriptor, extract the total number of elements in the array it
+/// describes. If the box descriptor is a reference, load it first.
+/// The total number of elements is returned as an i64 value.
+static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
+                                           Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetTotalElements");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, extract the size of each element in the array it
+/// describes. If the box descriptor is a reference, load it first.
+/// The element size is returned as an i64 value.
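+/// For example, a !fir.box<!fir.array<?xf64>> descriptor yields 8.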
+static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
+                                     Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetEleSize");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, compute the total size in bytes of the data it
+/// describes. This is done by multiplying the total number of elements by the
+/// size of each element. If the box descriptor is a reference, load it first.
+/// The total size in bytes is returned as an i64 value.
+static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+                                             Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetDataSizeInBytes");
+  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+/// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to
+/// retrieve the device pointer corresponding to a given host pointer and device
+/// number. If no mapping exists, the original host pointer is returned.
+/// Signature:
+///   void *omp_get_mapped_ptr(void *host_ptr, int device_num);
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Value hostPtr,
+                                               mlir::Value deviceNum,
+                                               mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto i32Type = builder.getI32Type();
+  auto funcName = "omp_get_mapped_ptr";
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    auto funcType =
+        mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args;
+  args.push_back(fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr));
+  args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+  auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+  auto mappedPtr = callOp.getResult(0);
+  auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+  auto convertedHostPtr =
+      fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+  auto result = arith::SelectOp::create(builder, loc, isNull, convertedHostPtr,
+                                        mappedPtr);
+  return result;
+}
+
+/// Generate a call to the OpenMP runtime function `omp_target_memcpy` to
+/// perform memory copy between host and device or between devices.
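+/// It returns zero on success and a nonzero value on failure; the result is
+/// not checked here.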
+/// Signature: +/// int omp_target_memcpy(void *dst, const void *src, size_t length, +/// size_t dst_offset, size_t src_offset, +/// int dst_device, int src_device); +static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value dst, + mlir::Value src, mlir::Value length, + mlir::Value dstOffset, mlir::Value srcOffset, + mlir::Value device, mlir::ModuleOp module) { + auto *context = builder.getContext(); + auto funcName = "omp_target_memcpy"; + auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type()); + auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit + auto i32Type = builder.getI32Type(); + auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName); + + if (!funcOp) { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + llvm::SmallVector<mlir::Type> argTypes = { + voidPtrType, voidPtrType, sizeTType, sizeTType, + sizeTType, i32Type, i32Type}; + auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type}); + funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType); + funcOp.setPrivate(); + } + + llvm::SmallVector<mlir::Value> args{dst, src, length, dstOffset, + srcOffset, device, device}; + fir::CallOp::create(builder, loc, funcOp, args); + return; +} + +/// Generate code to replace a Fortran array assignment call with OpenMP +/// runtime calls to perform the equivalent operation on the device. +/// This involves extracting the source and destination pointers from the +/// Fortran array descriptors, retrieving their mapped device pointers (if any), +/// and invoking `omp_target_memcpy` to copy the data on the device. +static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, + mlir::Location loc, + fir::CallOp callOp, + mlir::Value device, + mlir::ModuleOp module) { + assert(callOp.getNumResults() == 0 && + "Expected _FortranAAssign to have no results"); + assert(callOp.getNumOperands() >= 2 && + "Expected _FortranAAssign to have at least two operands"); + + // Extract the source and destination pointers from the call operands. + mlir::Value dest = callOp.getOperand(0); + mlir::Value src = callOp.getOperand(1); + + // Get the base addresses of the source and destination arrays. + mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src); + mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest); + + // Get the total size in bytes of the data to be copied. + mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src); + + // Retrieve the mapped device pointers for source and destination. + // If no mapping exists, the original host pointer is used. + Value destPtr = + genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module); + Value srcPtr = + genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); + Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), + builder.getI64IntegerAttr(0)); + + // Generate the call to omp_target_memcpy to perform the data copy on the + // device. + genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero, + device, module); +} + +/// Struct to hold the host eval vars corresponding to loop bounds and steps +struct HostEvalVars { + SmallVector<Value> lbs; + SmallVector<Value> ubs; + SmallVector<Value> steps; +}; + +/// moveToHost method clones all the ops from target region outside of it. +/// It hoists runtime function "_FortranAAssign" and replaces it with omp +/// version. 
It also replaces hoisted fir.allocmem ops with omp.target_allocmem and
+/// fir.freemem ops with omp.target_freemem.
+static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
+                                mlir::ModuleOp module,
+                                struct HostEvalVars &hostEvalVars) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+
+  // Get the parent target_data op
+  auto targetDataOp = dyn_cast<omp::TargetDataOp>(targetOp->getParentOp());
+  if (!targetDataOp) {
+    emitError(targetOp->getLoc(),
+              "Expected target op to be inside target_data op");
+    return failure();
+  }
+  // create mapping for host_eval_vars
+  unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+    Value hostEvalVar = targetOp.getHostEvalVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[i];
+    mapping.map(arg, hostEvalVar);
+  }
+  // create mapping for map_vars
+  for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
+    Value mapInfo = targetOp.getMapVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    // map the block argument to the host-side variable pointer
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  // create mapping for private_vars
+  unsigned mapSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
+    Value privateVar = targetOp.getPrivateVars()[i];
+    // The mapping should link the device-side variable to the host-side one.
+    BlockArgument arg =
+        targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
+    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    mapping.map(arg, privateVar);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToReplace;
+  Value device = targetOp.getDevice();
+
+  // If device is not specified, default to device 0.
+  if (!device) {
+    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+  }
+  // Clone all operations.
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    Operation *clonedOp = rewriter.clone(*op, mapping);
+    // Map the results of the original op to the cloned op.
+    for (unsigned i = 0; i < op->getNumResults(); ++i) {
+      mapping.map(op->getResult(i), clonedOp->getResult(i));
+    }
+    // fir.declare changes its type when hoisted out of omp.target to
+    // omp.target_data. Introduce a load if the original declareOp input is
+    // not of reference type but the cloned declareOp input is.
+    if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+      auto originalDeclareOp = cast<fir::DeclareOp>(op);
+      Type originalInType = originalDeclareOp.getMemref().getType();
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+      fir::ReferenceType originalRefType =
+          dyn_cast<fir::ReferenceType>(originalInType);
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      if (!originalRefType && clonedRefType) {
+        Type clonedEleTy = clonedRefType.getElementType();
+        if (clonedEleTy == originalDeclareOp.getType()) {
+          opsToReplace.push_back(clonedOp);
+        }
+      }
+    }
+    // Collect the ops to be replaced.
+    if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+      opsToReplace.push_back(clonedOp);
+    // Check for runtime calls to be replaced.
+    if (isRuntimeCall(clonedOp)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        opsToReplace.push_back(clonedOp);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    }
+  }
+  // Replace fir.allocmem with omp.target_allocmem.
+  for (Operation *op : opsToReplace) {
+    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+      rewriter.setInsertionPoint(allocOp);
+      auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+          allocOp.getLoc(), rewriter.getI64Type(), device,
+          allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+          allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+          allocOp.getShape());
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          allocOp.getLoc(), allocOp.getResult().getType(),
+          ompAllocmemOp.getResult());
+      rewriter.replaceOp(allocOp, firConvertOp.getResult());
+    }
+    // Replace fir.freemem with omp.target_freemem.
+    else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+      rewriter.setInsertionPoint(freeOp);
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+      rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+                                            firConvertOp.getResult());
+      rewriter.eraseOp(freeOp);
+    }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data. Introduce a load if the original declareOp input is
+    // not of reference type but the cloned declareOp input is.
+    else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      Type clonedEleTy = clonedRefType.getElementType();
+      rewriter.setInsertionPoint(op);
+      Value loadedValue = rewriter.create<fir::LoadOp>(
+          clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+      clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
+    }
+    // Replace runtime calls with omp versions.
+    else if (isRuntimeCall(op)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        rewriter.setInsertionPoint(op);
+        fir::FirOpBuilder builder{rewriter, op};
+
+        mlir::Location loc = runtimeCall.getLoc();
+        genFortranAssignOmpReplacement(builder, loc, runtimeCall, device,
+                                       module);
+        rewriter.eraseOp(op);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    } else {
+      emitError(op->getLoc(), "Unhandled op hoisting.");
+      return failure();
+    }
+  }
+
+  // Update the host_eval_vars to use the mapped values.
+  for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+    hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+    hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+    hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+  }
+  // Finally, erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return success();
+}
+
+/// Result of the isolateOp method.
+struct SplitResult {
+  omp::TargetOp preTargetOp;
+  omp::TargetOp isolatedTargetOp;
+  omp::TargetOp postTargetOp;
+};
+
+/// computeAllocsCacheRecomputable computes the allocs needed to cache the
+/// values that are used beyond the split point, the ops whose results need
+/// to be cached, and the ops that can be recomputed after the split.
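+///
+/// A minimal sketch of the classification (illustrative IR only):
+///   %c = arith.constant 10 : index            // recomputable: cloned again
+///   %m = fir.allocmem !fir.array<?xi32>, %c   // not recomputable: cached
+/// If both values are still needed after the split point, %c is simply
+/// re-cloned there, while %m's result is stored to a mapped temporary before
+/// the split and reloaded after it.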
+static void computeAllocsCacheRecomputable(
+    omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter,
+    SmallVector<Value> &preMapOperands, SmallVector<Value> &postMapOperands,
+    SmallVector<Value> &allocs, SmallVector<Value> &requiredVals,
+    SetVector<Operation *> &nonRecomputable, SetVector<Operation *> &toCache,
+    SetVector<Operation *> &toRecompute) {
+  auto *targetBlock = &targetOp.getRegion().front();
+  // Find all values that are used beyond the split point.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    // Check if any of the results are used beyond the split point.
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp)) {
+        requiredVals.push_back(res);
+      }
+    }
+    // If the op is not recomputable, add it to the nonRecomputable set.
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp)) {
+      nonRecomputable.insert(&*it);
+    }
+  }
+  // For each required value, collect its dependencies.
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+  // For each op in toCache, create an alloc and update the pre and post map
+  // operands.
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+}
+
+/// genPreTargetOp generates the preTargetOp that contains all the ops before
+/// the split point. It creates the block arguments, maps the values
+/// accordingly, and creates the store operations for the allocs.
+static omp::TargetOp
+genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
+               SmallVector<Value> &allocs, Operation *splitBeforeOp,
+               RewriterBase &rewriter, struct HostEvalVars &hostEvalVars,
+               bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()};
+  // Create preTargetOp, reusing the original host_eval vars and using the
+  // pre map operands.
+  omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
+                        preHostEvalVars, preMapOperands, allocs, preMapping);
+
+  // Prepare to create the store operations for the allocs (see below).
+  rewriter.setInsertionPointToStart(preTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  // Clone the original operations.
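+  // (This range covers exactly the ops before the split point; their results
+  // are remapped through preMapping as they are cloned.)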
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    rewriter.clone(*it, preMapping);
+  }
+
+  unsigned originalHostEvalVarsSize = preHostEvalVars.size();
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  // Create stores for the allocs.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value originalResult = allocs[i];
+    Value toStore = preMapping.lookup(originalResult);
+    // Get the new block argument for this specific allocated value.
+    Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize +
+                                               originalMapVarsSize + i);
+    // Create the store operation.
+    if (isPtr(originalResult.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // Update hostEvalVars with the mapped values for the loop bounds if we have
+  // a loopNestOp and we are not generating code for the target device.
+  omp::LoopNestOp loopNestOp =
+      getLoopNestFromTeams(cast<omp::TeamsOp>(splitBeforeOp));
+  if (loopNestOp && !isTargetDevice) {
+    for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) {
+      Value lb = loopNestOp.getLoopLowerBounds()[i];
+      Value ub = loopNestOp.getLoopUpperBounds()[i];
+      Value step = loopNestOp.getLoopSteps()[i];
+
+      hostEvalVars.lbs.push_back(preMapping.lookup(lb));
+      hostEvalVars.ubs.push_back(preMapping.lookup(ub));
+      hostEvalVars.steps.push_back(preMapping.lookup(step));
+    }
+  }
+
+  return preTargetOp;
+}
+
+/// genIsolatedTargetOp generates the isolatedTargetOp that contains only the
+/// op at the split point. It creates the block arguments, maps the values
+/// accordingly, creates the load operations for the allocs, and recomputes
+/// the ops that can be recomputed.
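+///
+/// Schematically (illustrative only, not exact assembly syntax):
+///   omp.target host_eval(...) map_entries(..., %cached_tmp) {
+///     ... loads of cached values and recomputed cheap ops ...
+///     omp.teams { ... omp.loop_nest ... }
+///     omp.terminator
+///   }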
+static omp::TargetOp +genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, + Operation *splitBeforeOp, RewriterBase &rewriter, + SmallVector<Value> &allocs, + SetVector<Operation *> &toRecompute, + struct HostEvalVars &hostEvalVars, bool isTargetDevice) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()}; + // update the hostEvalVars of isolatedTargetOp + if (!hostEvalVars.lbs.empty() && !isTargetDevice) { + isolatedHostEvalVars.append(hostEvalVars.lbs.begin(), + hostEvalVars.lbs.end()); + isolatedHostEvalVars.append(hostEvalVars.ubs.begin(), + hostEvalVars.ubs.end()); + isolatedHostEvalVars.append(hostEvalVars.steps.begin(), + hostEvalVars.steps.end()); + } + // Create the isolated target op + omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *isolatedTargetBlock = + rewriter.createBlock(&isolatedTargetOp.getRegion(), + isolatedTargetOp.getRegion().begin(), {}, {}); + IRMapping isolatedMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, + isolatedTargetBlock, isolatedHostEvalVars, + postMapOperands, allocs, isolatedMapping); + // Handle the load operations for the allocs and recompute ops. + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + isolatedTargetBlock, isolatedHostEvalVars, + postMapOperands, allocs, toRecompute, + isolatedMapping); + + // Clone the original operations. + rewriter.clone(*splitBeforeOp, isolatedMapping); + rewriter.create<omp::TerminatorOp>(loc); + + // update the loop bounds in the isolatedTargetOp if we have host_eval vars + // and we are not generating code for the target device. 
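+  // For example (schematic): a loop nest
+  //   omp.loop_nest (%i) : index = (%lb) to (%ub) step (%st)
+  // has %lb/%ub/%st rewired to the host_eval block arguments appended above,
+  // so the host side can compute the trip count.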
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    omp::TeamsOp teamsOp;
+    for (auto &op : *isolatedTargetBlock) {
+      if (isa<omp::TeamsOp>(&op))
+        teamsOp = cast<omp::TeamsOp>(&op);
+    }
+    assert(teamsOp && "No teamsOp found in isolated target region");
+    // Get the loopNestOp inside the teamsOp.
+    auto loopNestOp = getLoopNestFromTeams(teamsOp);
+    // Get the block arguments corresponding to the host_eval vars and update
+    // the loop_nest bounds to use them.
+    unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+    unsigned index = originalHostEvalVarsSize;
+    // Replace the loop bounds with the block arguments passed down via
+    // host_eval.
+    SmallVector<Value> lbs, ubs, steps;
+
+    // Collect new lb/ub/step values from the target block args.
+    for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i)
+      lbs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i)
+      ubs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.steps.size(); ++i)
+      steps.push_back(isolatedTargetBlock->getArgument(index++));
+
+    // Reset the loop bounds.
+    loopNestOp.getLoopLowerBoundsMutable().assign(lbs);
+    loopNestOp.getLoopUpperBoundsMutable().assign(ubs);
+    loopNestOp.getLoopStepsMutable().assign(steps);
+  }
+
+  return isolatedTargetOp;
+}
+
+/// genPostTargetOp generates the postTargetOp that contains all the ops after
+/// the split point. It creates the block arguments, maps the values
+/// accordingly, creates the load operations for the allocs, and recomputes
+/// the ops that can be recomputed.
+static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
+                                     Operation *splitBeforeOp,
+                                     SmallVector<Value> &postMapOperands,
+                                     RewriterBase &rewriter,
+                                     SmallVector<Value> &allocs,
+                                     SetVector<Operation *> &toRecompute) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()};
+  // Create the post target op.
+  omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  // Create the block for postTargetOp.
+  auto *postTargetBlock = rewriter.createBlock(
+      &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+  IRMapping postMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, postTargetBlock,
+                        postHostEvalVars, postMapOperands, allocs, postMapping);
+  // Handle the load operations for the allocs and recompute ops.
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          postTargetBlock, postHostEvalVars, postMapOperands,
+                          allocs, toRecompute, postMapping);
+  assert(splitBeforeOp->getNumResults() == 0 ||
+         llvm::all_of(splitBeforeOp->getResults(),
+                      [](Value result) { return result.use_empty(); }));
+  // Clone the original operations after the split point.
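+  // For example (illustrative): if the original region held
+  //   { init ops; omp.teams {...}; cleanup ops }
+  // only the cleanup ops are cloned here; values they need are either
+  // reloaded from the mapped temporaries or recomputed above.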
+  for (auto it = std::next(splitBeforeOp->getIterator());
+       it != targetBlock->end(); it++)
+    rewriter.clone(*it, postMapping);
+  return postTargetOp;
+}
+
+/// isolateOp rewrites an omp.target_data { omp.target } into
+/// omp.target_data {
+///   // preTargetOp region contains the ops before splitBeforeOp.
+///   omp.target {}
+///   // isolatedTargetOp region contains splitBeforeOp.
+///   omp.target {}
+///   // postTargetOp region contains the ops after splitBeforeOp.
+///   omp.target {}
+/// }
+/// It also handles the mapping of variables and the caching/recomputing
+/// of values as needed.
+static FailureOr<SplitResult> isolateOp(Operation *splitBeforeOp,
+                                        bool splitAfter, RewriterBase &rewriter,
+                                        mlir::ModuleOp module,
+                                        bool isTargetDevice) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  assert(targetOp);
+  rewriter.setInsertionPoint(targetOp);
+
+  // Prepare the map operands for preTargetOp and postTargetOp.
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  // Vectors to hold the analysis results.
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+  struct HostEvalVars hostEvalVars;
+
+  // Analyze the ops in the target region to determine which ops need to be
+  // cached and which ops need to be recomputed.
+  computeAllocsCacheRecomputable(
+      targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+      allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the preTargetOp that contains all the ops before splitBeforeOp.
+  auto preTargetOp =
+      genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+                     hostEvalVars, isTargetDevice);
+
+  // Move the ops of preTargetOp to the host.
+  auto res = moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+  if (failed(res))
+    return failure();
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the isolatedTargetOp.
+  omp::TargetOp isolatedTargetOp =
+      genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+                          allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+  omp::TargetOp postTargetOp = nullptr;
+  // Generate the postTargetOp that contains all the ops after splitBeforeOp.
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+                                   rewriter, allocs, toRecompute);
+  }
+  // Finally, erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Recursively fission target ops until no more nested ops can be isolated.
+static LogicalResult fissionTarget(omp::TargetOp targetOp,
+                                   RewriterBase &rewriter,
+                                   mlir::ModuleOp module, bool isTargetDevice) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    struct HostEvalVars hostEvalVars;
+    return moveToHost(targetOp, rewriter, module, hostEvalVars);
+  }
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+  // Recursively isolate the target op.
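+  // For example (illustrative): with two isolatable ops A and B in one
+  // region, the first split peels A into its own omp.target; recursing on
+  // the resulting postTargetOp then peels B, and the final remainder is
+  // moved to the host.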
+ if (splitBefore && splitAfter) { + auto res = + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + if (failed(res)) + return failure(); + return fissionTarget((*res).postTargetOp, rewriter, module, isTargetDevice); + } + // Isolate only before the op. + if (splitBefore) { + auto res = + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + if (failed(res)) + return failure(); + } else { + emitError(toIsolate->getLoc(), "Unhandled case in fissionTarget"); + return failure(); + } + return success(); +} + +/// Pass to lower omp.workdistribute ops. +class LowerWorkdistributePass + : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> { +public: + void runOnOperation() override { + MLIRContext &context = getContext(); + auto moduleOp = getOperation(); + bool changed = false; + SetVector<omp::TargetOp> targetOpsToProcess; + auto verify = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + if (failed(verifyTargetTeamsWorkdistribute(workdistribute))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (verify.wasInterrupted()) + return signalPassFailure(); + + auto fission = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + auto res = fissionWorkdistribute(workdistribute); + if (failed(res)) + return WalkResult::interrupt(); + changed |= *res; + return WalkResult::advance(); + }); + if (fission.wasInterrupted()) + return signalPassFailure(); + + auto rtCallLower = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + auto res = workdistributeRuntimeCallLower(workdistribute, + targetOpsToProcess); + if (failed(res)) + return WalkResult::interrupt(); + changed |= *res; + return WalkResult::advance(); + }); + if (rtCallLower.wasInterrupted()) + return signalPassFailure(); + + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= workdistributeDoLower(workdistribute, targetOpsToProcess); + }); + + moduleOp->walk([&](mlir::omp::TeamsOp teams) { + changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess); + }); + if (changed) { + bool isTargetDevice = + llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp) + .getIsTargetDevice(); + IRRewriter rewriter(&context); + for (auto targetOp : targetOpsToProcess) { + auto res = splitTargetData(targetOp, rewriter); + if (failed(res)) + return signalPassFailure(); + if (*res) { + if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice))) + return signalPassFailure(); + } + } + } + } +}; +} // namespace diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 260e525..2bbd803 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -40,6 +40,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" @@ -128,6 +129,17 @@ class MapInfoFinalizationPass } } + /// Return true if the module has an OpenMP requires clause that includes + /// unified_shared_memory. 
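+  ///
+  /// For instance, it returns true for a module like (schematic; the exact
+  /// attribute syntax is abbreviated here):
+  ///   module attributes {omp.requires =
+  ///       #omp<clause_requires unified_shared_memory>} { ... }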
+ static bool moduleRequiresUSM(mlir::ModuleOp module) { + assert(module && "invalid module"); + if (auto req = module->getAttrOfType<mlir::omp::ClauseRequiresAttr>( + "omp.requires")) + return mlir::omp::bitEnumContainsAll( + req.getValue(), mlir::omp::ClauseRequires::unified_shared_memory); + return false; + } + /// Create the member map for coordRef and append it (and its index /// path) to the provided new* vectors, if it is not already present. void appendMemberMapIfNew( @@ -425,8 +437,12 @@ class MapInfoFinalizationPass mapFlags flags = mapFlags::OMP_MAP_TO | (mapFlags(mapTypeFlag) & - (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE | - mapFlags::OMP_MAP_ALWAYS)); + (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS)); + // For unified_shared_memory, we additionally add `CLOSE` on the descriptor + // to ensure device-local placement where required by tests relying on USM + + // close semantics. + if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>())) + flags |= mapFlags::OMP_MAP_CLOSE; return llvm::to_underlying(flags); } @@ -518,6 +534,75 @@ class MapInfoFinalizationPass return newMapInfoOp; } + // Expand mappings of type(C_PTR) to map their `__address` field explicitly + // as a single pointer-sized member (USM-gated at callsite). This helps in + // USM scenarios to ensure the pointer-sized mapping is used. + mlir::omp::MapInfoOp genCptrMemberMap(mlir::omp::MapInfoOp op, + fir::FirOpBuilder &builder) { + if (!op.getMembers().empty()) + return op; + + mlir::Type varTy = fir::unwrapRefType(op.getVarPtr().getType()); + if (!mlir::isa<fir::RecordType>(varTy)) + return op; + auto recTy = mlir::cast<fir::RecordType>(varTy); + // If not a builtin C_PTR record, skip. + if (!recTy.getName().ends_with("__builtin_c_ptr")) + return op; + + // Find the index of the c_ptr address component named "__address". + int32_t fieldIdx = recTy.getFieldIndex("__address"); + if (fieldIdx < 0) + return op; + + mlir::Location loc = op.getVarPtr().getLoc(); + mlir::Type memTy = recTy.getType(fieldIdx); + fir::IntOrValue idxConst = + mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); + mlir::Value coord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(memTy), op.getVarPtr(), + llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); + + // Child for the `__address` member. + llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}}; + mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); + // Force CLOSE in USM paths so the pointer gets device-local placement + // when required by tests relying on USM + close semantics. + uint64_t mapTypeVal = + op.getMapType() | + llvm::to_underlying( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); + mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr( + builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal); + + mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create( + builder, loc, coord.getType(), coord, + mlir::TypeAttr::get(fir::unwrapRefType(coord.getType())), mapTypeAttr, + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( + mlir::omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/mlir::Value{}, + /*members=*/llvm::SmallVector<mlir::Value>{}, + /*member_index=*/mlir::ArrayAttr{}, + /*bounds=*/op.getBounds(), + /*mapperId=*/mlir::FlatSymbolRefAttr(), + /*name=*/op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); + + // Rebuild the parent as a container with the `__address` member. 
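+    // Schematically (illustrative only): the parent becomes an omp.map.info
+    // on the c_ptr record whose members(...) list holds the single
+    // CLOSE-flagged omp.map.info for `__address` created above.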
+ mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create( + builder, op.getLoc(), op.getResult().getType(), op.getVarPtr(), + op.getVarTypeAttr(), mapTypeAttr, op.getMapCaptureTypeAttr(), + /*varPtrPtr=*/mlir::Value{}, + /*members=*/llvm::SmallVector<mlir::Value>{memberMap}, + /*member_index=*/newMembersAttr, + /*bounds=*/llvm::SmallVector<mlir::Value>{}, + /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); + op.replaceAllUsesWith(newParent.getResult()); + op->erase(); + return newParent; + } + mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Operation *target) { @@ -1169,6 +1254,17 @@ class MapInfoFinalizationPass genBoxcharMemberMap(op, builder); }); + // Expand type(C_PTR) only when unified_shared_memory is required, + // to ensure device-visible pointer size/behavior in USM scenarios + // without changing default expectations elsewhere. + func->walk([&](mlir::omp::MapInfoOp op) { + // Only expand C_PTR members when unified_shared_memory is required. + if (!moduleRequiresUSM(func->getParentOfType<mlir::ModuleOp>())) + return; + builder.setInsertionPoint(op); + genCptrMemberMap(op, builder); + }); + func->walk([&](mlir::omp::MapInfoOp op) { // TODO: Currently only supports a single user for the MapInfoOp. This // is fine for the moment, as the Fortran frontend will generate a diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index a83b066..1ecb6d3 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -301,8 +301,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, addNestedPassToAllTopLevelOperations<PassConstructor>( pm, hlfir::createInlineHLFIRAssign); pm.addPass(hlfir::createConvertHLFIRtoFIR()); - if (enableOpenMP != EnableOpenMP::None) + if (enableOpenMP != EnableOpenMP::None) { pm.addPass(flangomp::createLowerWorkshare()); + pm.addPass(flangomp::createLowerWorkdistribute()); + } if (enableOpenMP == EnableOpenMP::Simd) pm.addPass(flangomp::createSimdOnlyPass()); } diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index c71642c..92390e4a 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -51,6 +51,16 @@ std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault( return std::nullopt; } +std::optional<bool> +fir::isRecordWithFinalRoutine(fir::RecordType recordType, mlir::ModuleOp module, + const mlir::SymbolTable *symbolTable) { + fir::TypeInfoOp typeInfo = + fir::lookupTypeInfoOp(recordType, module, symbolTable); + if (!typeInfo) + return std::nullopt; + return !typeInfo.getNoFinal(); +} + mlir::LLVM::ConstantOp fir::genConstantIndex(mlir::Location loc, mlir::Type ity, mlir::ConversionPatternRewriter &rewriter, diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp index 061a7d2..bdc3418 100644 --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -474,7 +474,7 @@ public: mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; loop.dump();); - LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = + [[maybe_unused]] auto loopAnalysis = functionAnalysis.getChildLoopAnalysis(loop); if (!loopAnalysis.canPromoteToAffine()) return rewriter.notifyMatchFailure(loop, "cannot promote to 
affine"); diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 609a1fc..759e3a65d 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -558,6 +558,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, if (srcTy.isInteger(1)) { // i1 is not a supported type in the descriptor and it is actually coming // from a LOGICAL constant. Use the destination type to avoid mismatch. + assert(dstEleTy && "expect dst element type to be set"); srcTy = dstEleTy; src = createConvertOp(rewriter, loc, srcTy, src); addr = builder.createTemporary(loc, srcTy); @@ -652,7 +653,8 @@ struct CUFDataTransferOpConversion // Initialization of an array from a scalar value should be implemented // via a kernel launch. Use the flang runtime via the Assign function // until we have more infrastructure. - mlir::Value src = emboxSrc(rewriter, op, symtab); + mlir::Type dstEleTy = fir::unwrapInnerType(fir::unwrapRefType(dstTy)); + mlir::Value src = emboxSrc(rewriter, op, symtab, dstEleTy); mlir::Value dst = emboxDst(rewriter, op, symtab); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>( @@ -739,6 +741,9 @@ struct CUFDataTransferOpConversion fir::StoreOp::create(builder, loc, val, box); return box; } + if (mlir::isa<fir::BaseBoxType>(val.getType())) + if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp())) + return loadOp.getMemref(); return val; }; diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 80b3f68..8601499 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -561,7 +561,7 @@ static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter, return stack; fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy); - LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy = + [[maybe_unused]] fir::ReferenceType firRefTy = mlir::cast<fir::ReferenceType>(stackTy); assert(firHeapTy.getElementType() == firRefTy.getElementType() && "Allocations must have the same type"); diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 9507021..d677e14 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -548,6 +548,14 @@ TYPE_PARSER(construct<OmpAllocatorSimpleModifier>(scalarIntExpr)) TYPE_PARSER(construct<OmpAlwaysModifier>( // "ALWAYS" >> pure(OmpAlwaysModifier::Value::Always))) +TYPE_PARSER(construct<OmpAttachModifier::Value>( + "ALWAYS" >> pure(OmpAttachModifier::Value::Always) || + "AUTO" >> pure(OmpAttachModifier::Value::Auto) || + "NEVER" >> pure(OmpAttachModifier::Value::Never))) + +TYPE_PARSER(construct<OmpAttachModifier>( // + "ATTACH" >> parenthesized(Parser<OmpAttachModifier::Value>{}))) + TYPE_PARSER(construct<OmpAutomapModifier>( "AUTOMAP" >> pure(OmpAutomapModifier::Value::Automap))) @@ -744,6 +752,7 @@ TYPE_PARSER(sourced( TYPE_PARSER(sourced(construct<OmpMapClause::Modifier>( sourced(construct<OmpMapClause::Modifier>(Parser<OmpAlwaysModifier>{}) || + construct<OmpMapClause::Modifier>(Parser<OmpAttachModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpCloseModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpDeleteModifier>{}) || construct<OmpMapClause::Modifier>(Parser<OmpPresentModifier>{}) || @@ -1085,7 +1094,7 @@ TYPE_PARSER(construct<OmpBindClause>( "TEAMS" >> 
pure(OmpBindClause::Binding::Teams) || "THREAD" >> pure(OmpBindClause::Binding::Thread))) -TYPE_PARSER(construct<OmpAlignClause>(scalarIntExpr)) +TYPE_PARSER(construct<OmpAlignClause>(scalarIntConstantExpr)) TYPE_PARSER(construct<OmpAtClause>( "EXECUTION" >> pure(OmpAtClause::ActionTime::Execution) || @@ -1158,7 +1167,8 @@ TYPE_PARSER( // "DOACROSS" >> construct<OmpClause>(parenthesized(Parser<OmpDoacrossClause>{})) || "DYNAMIC_ALLOCATORS" >> - construct<OmpClause>(construct<OmpClause::DynamicAllocators>()) || + construct<OmpClause>(construct<OmpClause::DynamicAllocators>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "DYN_GROUPPRIVATE" >> construct<OmpClause>(construct<OmpClause::DynGroupprivate>( parenthesized(Parser<OmpDynGroupprivateClause>{}))) || @@ -1270,12 +1280,15 @@ TYPE_PARSER( // "REPLAYABLE" >> construct<OmpClause>(construct<OmpClause::Replayable>( maybe(parenthesized(Parser<OmpReplayableClause>{})))) || "REVERSE_OFFLOAD" >> - construct<OmpClause>(construct<OmpClause::ReverseOffload>()) || + construct<OmpClause>(construct<OmpClause::ReverseOffload>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "SAFELEN" >> construct<OmpClause>(construct<OmpClause::Safelen>( parenthesized(scalarIntConstantExpr))) || "SCHEDULE" >> construct<OmpClause>(construct<OmpClause::Schedule>( parenthesized(Parser<OmpScheduleClause>{}))) || "SEQ_CST" >> construct<OmpClause>(construct<OmpClause::SeqCst>()) || + "SELF_MAPS" >> construct<OmpClause>(construct<OmpClause::SelfMaps>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "SEVERITY" >> construct<OmpClause>(construct<OmpClause::Severity>( parenthesized(Parser<OmpSeverityClause>{}))) || "SHARED" >> construct<OmpClause>(construct<OmpClause::Shared>( @@ -1303,9 +1316,11 @@ TYPE_PARSER( // construct<OmpClause>(construct<OmpClause::UseDeviceAddr>( parenthesized(Parser<OmpObjectList>{}))) || "UNIFIED_ADDRESS" >> - construct<OmpClause>(construct<OmpClause::UnifiedAddress>()) || + construct<OmpClause>(construct<OmpClause::UnifiedAddress>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "UNIFIED_SHARED_MEMORY" >> - construct<OmpClause>(construct<OmpClause::UnifiedSharedMemory>()) || + construct<OmpClause>(construct<OmpClause::UnifiedSharedMemory>( + maybe(parenthesized(scalarLogicalConstantExpr)))) || "UNIFORM" >> construct<OmpClause>(construct<OmpClause::Uniform>( parenthesized(nonemptyList(name)))) || "UNTIED" >> construct<OmpClause>(construct<OmpClause::Untied>()) || diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp index cb30939..8cbaa39 100644 --- a/flang/lib/Parser/parse-tree.cpp +++ b/flang/lib/Parser/parse-tree.cpp @@ -185,7 +185,7 @@ StructureConstructor ArrayElement::ConvertToStructureConstructor( std::list<ComponentSpec> components; for (auto &subscript : subscripts) { components.emplace_back(std::optional<Keyword>{}, - ComponentDataSource{std::move(*Unwrap<Expr>(subscript))}); + ComponentDataSource{std::move(UnwrapRef<Expr>(subscript))}); } DerivedTypeSpec spec{std::move(name), std::list<TypeParamSpec>{}}; spec.derivedTypeSpec = &derived; diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 0511f5b..b172e429 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2384,6 +2384,11 @@ public: Walk(x.v); Put(")"); } + void Unparse(const OmpAttachModifier &x) { + Word("ATTACH("); + Walk(x.v); + Put(")"); + } void Unparse(const OmpOrderClause &x) { using Modifier = OmpOrderClause::Modifier; Walk(std::get<std::optional<std::list<Modifier>>>(x.t), 
":"); @@ -2820,6 +2825,7 @@ public: WALK_NESTED_ENUM(OmpMapType, Value) // OMP map-type WALK_NESTED_ENUM(OmpMapTypeModifier, Value) // OMP map-type-modifier WALK_NESTED_ENUM(OmpAlwaysModifier, Value) + WALK_NESTED_ENUM(OmpAttachModifier, Value) WALK_NESTED_ENUM(OmpCloseModifier, Value) WALK_NESTED_ENUM(OmpDeleteModifier, Value) WALK_NESTED_ENUM(OmpPresentModifier, Value) diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp index f4aa496..1824a7d 100644 --- a/flang/lib/Semantics/assignment.cpp +++ b/flang/lib/Semantics/assignment.cpp @@ -194,7 +194,8 @@ void AssignmentContext::CheckShape(parser::CharBlock at, const SomeExpr *expr) { template <typename A> void AssignmentContext::PushWhereContext(const A &x) { const auto &expr{std::get<parser::LogicalExpr>(x.t)}; - CheckShape(expr.thing.value().source, GetExpr(context_, expr)); + CheckShape( + parser::UnwrapRef<parser::Expr>(expr).source, GetExpr(context_, expr)); ++whereDepth_; } diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp index 823aa4e..e019bbd 100644 --- a/flang/lib/Semantics/check-allocate.cpp +++ b/flang/lib/Semantics/check-allocate.cpp @@ -151,7 +151,9 @@ static std::optional<AllocateCheckerInfo> CheckAllocateOptions( [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var) + .GetSource(), + "ERRMSG="); if (info.gotMsg) { // C943 context.Say( "ERRMSG may not be duplicated in a ALLOCATE statement"_err_en_US); @@ -439,7 +441,7 @@ static bool HaveCompatibleLengths( evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())}; auto v2{ evaluate::ToInt64(type2.characterTypeSpec().length().GetExplicit())}; - return !v1 || !v2 || *v1 == *v2; + return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? *v2 : 0); } else { return true; } @@ -452,7 +454,7 @@ static bool HaveCompatibleLengths( auto v1{ evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())}; auto v2{type2.knownLength()}; - return !v1 || !v2 || *v1 == *v2; + return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? 
*v2 : 0); } else { return true; } @@ -598,7 +600,7 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { std::optional<evaluate::ConstantSubscript> lbound; if (const auto &lb{std::get<0>(shapeSpec.t)}) { lbound.reset(); - const auto &lbExpr{lb->thing.thing.value()}; + const auto &lbExpr{parser::UnwrapRef<parser::Expr>(lb)}; if (const auto *expr{GetExpr(context, lbExpr)}) { auto folded{ evaluate::Fold(context.foldingContext(), SomeExpr(*expr))}; @@ -609,7 +611,8 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { lbound = 1; } if (lbound) { - const auto &ubExpr{std::get<1>(shapeSpec.t).thing.thing.value()}; + const auto &ubExpr{ + parser::UnwrapRef<parser::Expr>(std::get<1>(shapeSpec.t))}; if (const auto *expr{GetExpr(context, ubExpr)}) { auto folded{ evaluate::Fold(context.foldingContext(), SomeExpr(*expr))}; diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp index 5ce143c..7593154 100644 --- a/flang/lib/Semantics/check-case.cpp +++ b/flang/lib/Semantics/check-case.cpp @@ -72,7 +72,7 @@ private: } std::optional<Value> GetValue(const parser::CaseValue &caseValue) { - const parser::Expr &expr{caseValue.thing.thing.value()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(caseValue)}; auto *x{expr.typedExpr.get()}; if (x && x->v) { // C1147 auto type{x->v->GetType()}; diff --git a/flang/lib/Semantics/check-coarray.cpp b/flang/lib/Semantics/check-coarray.cpp index 0e444f1..9113369 100644 --- a/flang/lib/Semantics/check-coarray.cpp +++ b/flang/lib/Semantics/check-coarray.cpp @@ -112,7 +112,7 @@ static void CheckTeamType( static void CheckTeamStat( SemanticsContext &context, const parser::ImageSelectorSpec::Stat &stat) { - const parser::Variable &var{stat.v.thing.thing.value()}; + const auto &var{parser::UnwrapRef<parser::Variable>(stat)}; if (parser::GetCoindexedNamedObject(var)) { context.Say(parser::FindSourceLocation(var), // C931 "Image selector STAT variable must not be a coindexed " @@ -147,7 +147,8 @@ static void CheckSyncStat(SemanticsContext &context, }, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var).GetSource(), + "ERRMSG="); if (gotMsg) { context.Say( // C1172 "The errmsg-variable in a sync-stat-list may not be repeated"_err_en_US); @@ -260,7 +261,9 @@ static void CheckEventWaitSpecList(SemanticsContext &context, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var) + .GetSource(), + "ERRMSG="); if (gotMsg) { context.Say( // C1178 "A errmsg-variable in a event-wait-spec-list may not be repeated"_err_en_US); diff --git a/flang/lib/Semantics/check-data.cpp b/flang/lib/Semantics/check-data.cpp index 5459290..3bcf711 100644 --- a/flang/lib/Semantics/check-data.cpp +++ b/flang/lib/Semantics/check-data.cpp @@ -25,9 +25,10 @@ namespace Fortran::semantics { // Ensures that references to an implied DO loop control variable are // represented as such in the "body" of the implied DO loop. 
void DataChecker::Enter(const parser::DataImpliedDo &x) { - auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>( + std::get<parser::DataImpliedDo::Bounds>(x.t).name)}; int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind}; - if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) { + if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); } @@ -36,7 +37,8 @@ void DataChecker::Enter(const parser::DataImpliedDo &x) { } void DataChecker::Leave(const parser::DataImpliedDo &x) { - auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>( + std::get<parser::DataImpliedDo::Bounds>(x.t).name)}; exprAnalyzer_.RemoveImpliedDo(name.source); } @@ -211,7 +213,7 @@ void DataChecker::Leave(const parser::DataIDoObject &object) { std::get_if<parser::Scalar<common::Indirection<parser::Designator>>>( &object.u)}) { if (MaybeExpr expr{exprAnalyzer_.Analyze(*designator)}) { - auto source{designator->thing.value().source}; + auto source{parser::UnwrapRef<parser::Designator>(*designator).source}; DataVarChecker checker{exprAnalyzer_.context(), source}; if (checker(*expr)) { if (checker.HasComponentWithoutSubscripts()) { // C880 diff --git a/flang/lib/Semantics/check-deallocate.cpp b/flang/lib/Semantics/check-deallocate.cpp index c45b585..c1ebc5f 100644 --- a/flang/lib/Semantics/check-deallocate.cpp +++ b/flang/lib/Semantics/check-deallocate.cpp @@ -114,7 +114,8 @@ void DeallocateChecker::Leave(const parser::DeallocateStmt &deallocateStmt) { }, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context_, - GetExpr(context_, var), var.v.thing.thing.GetSource(), + GetExpr(context_, var), + parser::UnwrapRef<parser::Variable>(var).GetSource(), "ERRMSG="); if (gotMsg) { context_.Say( diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index ea5e2c0..31e246c 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -3622,6 +3622,7 @@ void CheckHelper::CheckDioDtvArg(const Symbol &proc, const Symbol &subp, ioKind == common::DefinedIo::ReadUnformatted ? 
Attr::INTENT_INOUT : Attr::INTENT_IN); + CheckDioDummyIsScalar(subp, *arg); } } @@ -3687,6 +3688,7 @@ void CheckHelper::CheckDioAssumedLenCharacterArg(const Symbol &subp, "Dummy argument '%s' of a defined input/output procedure must be assumed-length CHARACTER of default kind"_err_en_US, arg->name()); } + CheckDioDummyIsScalar(subp, *arg); } } diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp index a2f3685..8a47340 100644 --- a/flang/lib/Semantics/check-do-forall.cpp +++ b/flang/lib/Semantics/check-do-forall.cpp @@ -535,7 +535,8 @@ private: if (const SomeExpr * expr{GetExpr(context_, scalarExpression)}) { if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) { // No warnings or errors for type INTEGER - const parser::CharBlock &loc{scalarExpression.thing.value().source}; + parser::CharBlock loc{ + parser::UnwrapRef<parser::Expr>(scalarExpression).source}; CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real)); } } @@ -552,7 +553,7 @@ private: CheckDoExpression(*bounds.step); if (IsZero(*bounds.step)) { context_.Warn(common::UsageWarning::ZeroDoStep, - bounds.step->thing.value().source, + parser::UnwrapRef<parser::Expr>(bounds.step).source, "DO step expression should not be zero"_warn_en_US); } } @@ -615,7 +616,7 @@ private: // C1121 - procedures in mask must be pure void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const { UnorderedSymbolSet references{ - GatherSymbolsFromExpression(mask.thing.thing.value())}; + GatherSymbolsFromExpression(parser::UnwrapRef<parser::Expr>(mask))}; for (const Symbol &ref : OrderBySourcePosition(references)) { if (IsProcedure(ref) && !IsPureProcedure(ref)) { context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source, @@ -639,32 +640,33 @@ private: } void HasNoReferences(const UnorderedSymbolSet &indexNames, - const parser::ScalarIntExpr &expr) const { - CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), - indexNames, + const parser::ScalarIntExpr &scalarIntExpr) const { + const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames, "%s limit expression may not reference index variable '%s'"_err_en_US, - expr.thing.thing.value().source); + expr.source); } // C1129, names in local locality-specs can't be in mask expressions void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask, const UnorderedSymbolSet &localVars) const { - CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()), - localVars, + const auto &expr{parser::UnwrapRef<parser::Expr>(mask)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars, "%s mask expression references variable '%s'" " in LOCAL locality-spec"_err_en_US, - mask.thing.thing.value().source); + expr.source); } // C1129, names in local locality-specs can't be in limit or step // expressions - void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr, + void CheckExprDoesNotReferenceLocal( + const parser::ScalarIntExpr &scalarIntExpr, const UnorderedSymbolSet &localVars) const { - CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), - localVars, + const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars, "%s expression references variable '%s'" " in LOCAL locality-spec"_err_en_US, - expr.thing.thing.value().source); + expr.source); } // C1130, DEFAULT(NONE) locality requires names to be in locality-specs to @@ 
-772,7 +774,7 @@ private: HasNoReferences(indexNames, std::get<2>(control.t)); if (const auto &intExpr{ std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) { - const parser::Expr &expr{intExpr->thing.thing.value()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(intExpr)}; CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames, "%s step expression may not reference index variable '%s'"_err_en_US, expr.source); @@ -840,7 +842,7 @@ private: } void CheckForImpureCall(const parser::ScalarIntExpr &x, std::optional<IndexVarKind> nesting) const { - const auto &parsedExpr{x.thing.thing.value()}; + const auto &parsedExpr{parser::UnwrapRef<parser::Expr>(x)}; auto oldLocation{context_.location()}; context_.set_location(parsedExpr.source); if (const auto &typedExpr{parsedExpr.typedExpr}) { @@ -1124,7 +1126,8 @@ void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) { const auto *newunit{ std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)}; if (newunit) { - context_.CheckIndexVarRedefine(newunit->v.thing.thing); + context_.CheckIndexVarRedefine( + parser::UnwrapRef<parser::Variable>(newunit)); } } @@ -1166,14 +1169,14 @@ void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) { const auto *intVar{std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)}; if (intVar) { const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)}; - context_.CheckIndexVarRedefine(scalar.thing.thing); + context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(scalar)); } } void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) { const auto *size{std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)}; if (size) { - context_.CheckIndexVarRedefine(size->v.thing.thing); + context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(size)); } } @@ -1190,16 +1193,19 @@ static void CheckIoImpliedDoIndex( void DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) { CheckIoImpliedDoIndex(context_, - std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name.thing.thing); + parser::UnwrapRef<parser::Name>( + std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name)); } void DoForallChecker::Leave(const parser::InputImpliedDo &inputImpliedDo) { CheckIoImpliedDoIndex(context_, - std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name.thing.thing); + parser::UnwrapRef<parser::Name>( + std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name)); } void DoForallChecker::Leave(const parser::StatVariable &statVariable) { - context_.CheckIndexVarRedefine(statVariable.v.thing.thing); + context_.CheckIndexVarRedefine( + parser::UnwrapRef<parser::Variable>(statVariable)); } } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/check-io.cpp b/flang/lib/Semantics/check-io.cpp index a1ff4b9..19059ad 100644 --- a/flang/lib/Semantics/check-io.cpp +++ b/flang/lib/Semantics/check-io.cpp @@ -424,8 +424,8 @@ void IoChecker::Enter(const parser::InquireSpec::CharVar &spec) { specKind = IoSpecKind::Dispose; break; } - const parser::Variable &var{ - std::get<parser::ScalarDefaultCharVariable>(spec.t).thing.thing}; + const auto &var{parser::UnwrapRef<parser::Variable>( + std::get<parser::ScalarDefaultCharVariable>(spec.t))}; std::string what{parser::ToUpperCaseLetters(common::EnumToString(specKind))}; CheckForDefinableVariable(var, what); WarnOnDeferredLengthCharacterScalar( @@ -627,7 +627,7 @@ void IoChecker::Enter(const parser::IoUnit &spec) { } void IoChecker::Enter(const parser::MsgVariable &msgVar) { - const 
parser::Variable &var{msgVar.v.thing.thing}; + const auto &var{parser::UnwrapRef<parser::Variable>(msgVar)}; if (stmt_ == IoStmtKind::None) { // allocate, deallocate, image control CheckForDefinableVariable(var, "ERRMSG"); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index 351af5c..515121a 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -519,8 +519,8 @@ private: /// function references with scalar data pointer result of non-character /// intrinsic type or variables that are non-polymorphic scalar pointers /// and any length type parameter must be constant. -void OmpStructureChecker::CheckAtomicType( - SymbolRef sym, parser::CharBlock source, std::string_view name) { +void OmpStructureChecker::CheckAtomicType(SymbolRef sym, + parser::CharBlock source, std::string_view name, bool checkTypeOnPointer) { const DeclTypeSpec *typeSpec{sym->GetType()}; if (!typeSpec) { return; @@ -547,6 +547,22 @@ void OmpStructureChecker::CheckAtomicType( return; } + // Apply pointer-to-non-intrinsic rule only for intrinsic-assignment paths. + if (checkTypeOnPointer) { + using Category = DeclTypeSpec::Category; + Category cat{typeSpec->category()}; + if (cat != Category::Numeric && cat != Category::Logical) { + std::string details = " has the POINTER attribute"; + if (const auto *derived{typeSpec->AsDerived()}) { + details += " and derived type '"s + derived->name().ToString() + "'"; + } + context_.Say(source, + "ATOMIC operation requires an intrinsic scalar variable; '%s'%s"_err_en_US, + sym->name(), details); + return; + } + } + // Go over all length parameters, if any, and check if they are // explicit. if (const DerivedTypeSpec *derived{typeSpec->AsDerived()}) { @@ -562,7 +578,7 @@ void OmpStructureChecker::CheckAtomicType( } void OmpStructureChecker::CheckAtomicVariable( - const SomeExpr &atom, parser::CharBlock source) { + const SomeExpr &atom, parser::CharBlock source, bool checkTypeOnPointer) { if (atom.Rank() != 0) { context_.Say(source, "Atomic variable %s should be a scalar"_err_en_US, atom.AsFortran()); @@ -572,7 +588,7 @@ void OmpStructureChecker::CheckAtomicVariable( assert(dsgs.size() == 1 && "Should have a single top-level designator"); evaluate::SymbolVector syms{evaluate::GetSymbolVector(dsgs.front())}; - CheckAtomicType(syms.back(), source, atom.AsFortran()); + CheckAtomicType(syms.back(), source, atom.AsFortran(), checkTypeOnPointer); if (IsAllocatable(syms.back()) && !IsArrayElement(atom)) { context_.Say(source, "Atomic variable %s cannot be ALLOCATABLE"_err_en_US, @@ -789,7 +805,8 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); } else { - CheckAtomicVariable(atom, rsrc); + CheckAtomicVariable( + atom, rsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(capture)); // This part should have been checked prior to calling this function. 
assert(*GetConvertInput(capture.rhs) == atom && "This cannot be a capture assignment"); @@ -808,7 +825,8 @@ void OmpStructureChecker::CheckAtomicReadAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); } else { - CheckAtomicVariable(atom, rsrc); + CheckAtomicVariable( + atom, rsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(read)); CheckStorageOverlap(atom, {read.lhs}, source); } } else { @@ -829,7 +847,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); } else { - CheckAtomicVariable(atom, lsrc); + CheckAtomicVariable( + atom, lsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(write)); CheckStorageOverlap(atom, {write.rhs}, source); } } @@ -854,7 +873,8 @@ OmpStructureChecker::CheckAtomicUpdateAssignment( return std::nullopt; } - CheckAtomicVariable(atom, lsrc); + CheckAtomicVariable( + atom, lsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(update)); auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs( atom, update.rhs, source, /*suppressDiagnostics=*/true)}; @@ -1017,7 +1037,8 @@ void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment( return; } - CheckAtomicVariable(atom, alsrc); + CheckAtomicVariable( + atom, alsrc, /*checkTypeOnPointer=*/!IsPointerAssignment(assign)); auto top{GetTopLevelOperationIgnoreResizing(cond)}; // Missing arguments to operations would have been diagnosed by now. diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index d65a89e..ea6fe43 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -1517,19 +1517,42 @@ void OmpStructureChecker::Leave(const parser::OpenMPDepobjConstruct &x) { void OmpStructureChecker::Enter(const parser::OpenMPRequiresConstruct &x) { const auto &dirName{x.v.DirName()}; PushContextAndClauseSets(dirName.source, dirName.v); + unsigned version{context_.langOptions().OpenMPVersion}; - if (visitedAtomicSource_.empty()) { - return; - } for (const parser::OmpClause &clause : x.v.Clauses().v) { llvm::omp::Clause id{clause.Id()}; if (id == llvm::omp::Clause::OMPC_atomic_default_mem_order) { - parser::MessageFormattedText txt( - "REQUIRES directive with '%s' clause found lexically after atomic operation without a memory order clause"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(id))); - parser::Message message(clause.source, txt); - message.Attach(visitedAtomicSource_, "Previous atomic construct"_en_US); - context_.Say(std::move(message)); + if (!visitedAtomicSource_.empty()) { + parser::MessageFormattedText txt( + "REQUIRES directive with '%s' clause found lexically after atomic operation without a memory order clause"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(id))); + parser::Message message(clause.source, txt); + message.Attach(visitedAtomicSource_, "Previous atomic construct"_en_US); + context_.Say(std::move(message)); + } + } else { + bool hasArgument{common::visit( + [&](auto &&s) { + using TypeS = llvm::remove_cvref_t<decltype(s)>; + if constexpr ( // + std::is_same_v<TypeS, parser::OmpClause::DynamicAllocators> || + std::is_same_v<TypeS, parser::OmpClause::ReverseOffload> || + std::is_same_v<TypeS, parser::OmpClause::SelfMaps> || + std::is_same_v<TypeS, parser::OmpClause::UnifiedAddress> || + std::is_same_v<TypeS, parser::OmpClause::UnifiedSharedMemory>) { + return s.v.has_value(); + } else { + return false; + } + }, + clause.u)}; + if (version < 60 && 
hasArgument) { + context_.Say(clause.source, + "An argument to %s is an %s feature, %s"_warn_en_US, + parser::ToUpperCaseLetters( + llvm::omp::getOpenMPClauseName(clause.Id())), + ThisVersion(60), TryVersion(60)); + } } } } @@ -1540,9 +1563,8 @@ void OmpStructureChecker::Leave(const parser::OpenMPRequiresConstruct &) { void OmpStructureChecker::CheckAlignValue(const parser::OmpClause &clause) { if (auto *align{std::get_if<parser::OmpClause::Align>(&clause.u)}) { - if (const auto &v{GetIntValue(align->v)}; !v || *v <= 0) { - context_.Say(clause.source, - "The alignment value should be a constant positive integer"_err_en_US); + if (const auto &v{GetIntValue(align->v)}; v && *v <= 0) { + context_.Say(clause.source, "The alignment should be positive"_err_en_US); } } } @@ -2336,7 +2358,7 @@ private: } if (auto &repl{std::get<parser::OmpClause::Replayable>(clause.u).v}) { // Scalar<Logical<Constant<indirection<Expr>>>> - const parser::Expr &parserExpr{repl->v.thing.thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>(repl)}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { return GetLogicalValue(*expr).value_or(true); } @@ -2350,7 +2372,7 @@ private: bool isTransparent{true}; if (auto &transp{std::get<parser::OmpClause::Transparent>(clause.u).v}) { // Scalar<Integer<indirection<Expr>>> - const parser::Expr &parserExpr{transp->v.thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>(transp)}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { // If the argument is omp_not_impex (defined as 0), then // the task is not transparent, otherwise it is. @@ -2389,8 +2411,8 @@ private: } } // Scalar<Logical<indirection<Expr>>> - auto &parserExpr{ - std::get<parser::ScalarLogicalExpr>(ifc.v.t).thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>( + std::get<parser::ScalarLogicalExpr>(ifc.v.t))}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { // If the value is known to be false, an undeferred task will be // generated. @@ -3017,8 +3039,8 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) { &objs, std::string clause) { for (const auto &obj : objs.v) { - if (const parser::Name * - objName{parser::Unwrap<parser::Name>(obj)}) { + if (const parser::Name *objName{ + parser::Unwrap<parser::Name>(obj)}) { if (&objName->symbol->GetUltimate() == eventHandleSym) { context_.Say(GetContext().clauseSource, "A variable: `%s` that appears in a DETACH clause cannot appear on %s clause on the same construct"_err_en_US, @@ -3637,7 +3659,8 @@ void OmpStructureChecker::CheckReductionModifier( if (modifier.v == ReductionModifier::Value::Task) { // "Task" is only allowed on worksharing or "parallel" directive. 
static llvm::omp::Directive worksharing[]{ - llvm::omp::Directive::OMPD_do, llvm::omp::Directive::OMPD_scope, + llvm::omp::Directive::OMPD_do, // + llvm::omp::Directive::OMPD_scope, // llvm::omp::Directive::OMPD_sections, // There are more worksharing directives, but they do not apply: // "for" is C++ only, @@ -4081,9 +4104,15 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { if (auto *iter{OmpGetUniqueModifier<parser::OmpIterator>(modifiers)}) { CheckIteratorModifier(*iter); } + + using Directive = llvm::omp::Directive; + Directive dir{GetContext().directive}; + llvm::ArrayRef<Directive> leafs{llvm::omp::getLeafConstructsOrSelf(dir)}; + parser::OmpMapType::Value mapType{parser::OmpMapType::Value::Storage}; + if (auto *type{OmpGetUniqueModifier<parser::OmpMapType>(modifiers)}) { - using Directive = llvm::omp::Directive; using Value = parser::OmpMapType::Value; + mapType = type->v; static auto isValidForVersion{ [](parser::OmpMapType::Value t, unsigned version) { @@ -4120,10 +4149,6 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { return result; }()}; - llvm::omp::Directive dir{GetContext().directive}; - llvm::ArrayRef<llvm::omp::Directive> leafs{ - llvm::omp::getLeafConstructsOrSelf(dir)}; - if (llvm::is_contained(leafs, Directive::OMPD_target) || llvm::is_contained(leafs, Directive::OMPD_target_data)) { if (version >= 60) { @@ -4141,6 +4166,43 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) { } } + if (auto *attach{ + OmpGetUniqueModifier<parser::OmpAttachModifier>(modifiers)}) { + bool mapEnteringConstructOrMapper{ + llvm::is_contained(leafs, Directive::OMPD_target) || + llvm::is_contained(leafs, Directive::OMPD_target_data) || + llvm::is_contained(leafs, Directive::OMPD_target_enter_data) || + llvm::is_contained(leafs, Directive::OMPD_declare_mapper)}; + + if (!mapEnteringConstructOrMapper || !IsMapEnteringType(mapType)) { + const auto &desc{OmpGetDescriptor<parser::OmpAttachModifier>()}; + context_.Say(OmpGetModifierSource(modifiers, attach), + "The '%s' modifier can only appear on a map-entering construct or on a DECLARE_MAPPER directive"_err_en_US, + desc.name.str()); + } + + auto hasBasePointer{[&](const SomeExpr &item) { + evaluate::SymbolVector symbols{evaluate::GetSymbolVector(item)}; + return llvm::any_of( + symbols, [](SymbolRef s) { return IsPointer(s.get()); }); + }}; + + evaluate::ExpressionAnalyzer ea{context_}; + const auto &objects{std::get<parser::OmpObjectList>(x.v.t)}; + for (auto &object : objects.v) { + if (const parser::Designator *d{GetDesignatorFromObj(object)}) { + if (auto &&expr{ea.Analyze(*d)}) { + if (hasBasePointer(*expr)) { + continue; + } + } + } + auto source{GetObjectSource(object)}; + context_.Say(source ? 
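// GetObjectSource() may not yield a location for every OmpObject; when it
// does not, the diagnostic falls back to the whole clause's source range.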
*source : GetContext().clauseSource, + "A list-item that appears in a map clause with the ATTACH modifier must have a base-pointer"_err_en_US); + } + } + auto &&typeMods{ OmpGetRepeatableModifier<parser::OmpMapTypeModifier>(modifiers)}; struct Less { diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index f507278..543642ff 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -262,10 +262,10 @@ private: void CheckStorageOverlap(const evaluate::Expr<evaluate::SomeType> &, llvm::ArrayRef<evaluate::Expr<evaluate::SomeType>>, parser::CharBlock); void ErrorShouldBeVariable(const MaybeExpr &expr, parser::CharBlock source); - void CheckAtomicType( - SymbolRef sym, parser::CharBlock source, std::string_view name); - void CheckAtomicVariable( - const evaluate::Expr<evaluate::SomeType> &, parser::CharBlock); + void CheckAtomicType(SymbolRef sym, parser::CharBlock source, + std::string_view name, bool checkTypeOnPointer = true); + void CheckAtomicVariable(const evaluate::Expr<evaluate::SomeType> &, + parser::CharBlock, bool checkTypeOnPointer = true); std::pair<const parser::ExecutionPartConstruct *, const parser::ExecutionPartConstruct *> CheckUpdateCapture(const parser::ExecutionPartConstruct *ec1, diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp index 1e46dab..bbf3b28 100644 --- a/flang/lib/Semantics/data-to-inits.cpp +++ b/flang/lib/Semantics/data-to-inits.cpp @@ -179,13 +179,14 @@ bool DataInitializationCompiler<DSV>::Scan( template <typename DSV> bool DataInitializationCompiler<DSV>::Scan(const parser::DataImpliedDo &ido) { const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)}; - auto name{bounds.name.thing.thing}; - const auto *lowerExpr{ - GetExpr(exprAnalyzer_.context(), bounds.lower.thing.thing)}; - const auto *upperExpr{ - GetExpr(exprAnalyzer_.context(), bounds.upper.thing.thing)}; + const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)}; + const auto *lowerExpr{GetExpr( + exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.lower))}; + const auto *upperExpr{GetExpr( + exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.upper))}; const auto *stepExpr{bounds.step - ? GetExpr(exprAnalyzer_.context(), bounds.step->thing.thing) + ? GetExpr(exprAnalyzer_.context(), + parser::UnwrapRef<parser::Expr>(bounds.step)) : nullptr}; if (lowerExpr && upperExpr) { // Fold the bounds expressions (again) in case any of them depend @@ -240,7 +241,9 @@ bool DataInitializationCompiler<DSV>::Scan( return common::visit( common::visitors{ [&](const parser::Scalar<common::Indirection<parser::Designator>> - &var) { return Scan(var.thing.value()); }, + &var) { + return Scan(parser::UnwrapRef<parser::Designator>(var)); + }, [&](const common::Indirection<parser::DataImpliedDo> &ido) { return Scan(ido.value()); }, diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp index 2feec98..4aeb9a4 100644 --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -176,8 +176,8 @@ public: // Find and return a user-defined operator or report an error. // The provided message is used if there is no such operator. 
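// When checkForNullPointer is false, a bare NULL() operand is not rejected
// up front; the check is deferred so that defined operators whose dummy
// arguments are pointers can still match. The generic-spelling retry later
// in this file passes /*checkForNullPointer=*/false for exactly this reason.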
- MaybeExpr TryDefinedOp( - const char *, parser::MessageFixedText, bool isUserOp = false); + MaybeExpr TryDefinedOp(const char *, parser::MessageFixedText, + bool isUserOp = false, bool checkForNullPointer = true); template <typename E> MaybeExpr TryDefinedOp(E opr, parser::MessageFixedText msg) { return TryDefinedOp( @@ -211,7 +211,8 @@ private: void SayNoMatch( const std::string &, bool isAssignment = false, bool isAmbiguous = false); std::string TypeAsFortran(std::size_t); - bool AnyUntypedOrMissingOperand() const; + bool AnyUntypedOperand() const; + bool AnyMissingOperand() const; ExpressionAnalyzer &context_; ActualArguments actuals_; @@ -1954,9 +1955,10 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) { const auto &control{std::get<parser::AcImpliedDoControl>(impliedDo.t)}; const auto &bounds{std::get<parser::AcImpliedDoControl::Bounds>(control.t)}; exprAnalyzer_.Analyze(bounds.name); - parser::CharBlock name{bounds.name.thing.thing.source}; + const auto &parsedName{parser::UnwrapRef<parser::Name>(bounds.name)}; + parser::CharBlock name{parsedName.source}; int kind{ImpliedDoIntType::kind}; - if (const Symbol * symbol{bounds.name.thing.thing.symbol}) { + if (const Symbol *symbol{parsedName.symbol}) { if (auto dynamicType{DynamicType::From(symbol)}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); @@ -1981,7 +1983,7 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) { auto cUpper{ToInt64(upper)}; auto cStride{ToInt64(stride)}; if (!(messageDisplayedSet_ & 0x10) && cStride && *cStride == 0) { - exprAnalyzer_.SayAt(bounds.step.value().thing.thing.value().source, + exprAnalyzer_.SayAt(parser::UnwrapRef<parser::Expr>(bounds.step).source, "The stride of an implied DO loop must not be zero"_err_en_US); messageDisplayedSet_ |= 0x10; } @@ -2526,7 +2528,7 @@ static const Symbol *GetBindingResolution( auto ExpressionAnalyzer::AnalyzeProcedureComponentRef( const parser::ProcComponentRef &pcr, ActualArguments &&arguments, bool isSubroutine) -> std::optional<CalleeAndArguments> { - const parser::StructureComponent &sc{pcr.v.thing}; + const auto &sc{parser::UnwrapRef<parser::StructureComponent>(pcr)}; if (MaybeExpr base{Analyze(sc.base)}) { if (const Symbol *sym{sc.component.symbol}) { if (context_.HasError(sym)) { @@ -3695,11 +3697,12 @@ std::optional<characteristics::Procedure> ExpressionAnalyzer::CheckCall( MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::Parentheses &x) { if (MaybeExpr operand{Analyze(x.v.value())}) { - if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) { + if (IsNullPointerOrAllocatable(&*operand)) { + Say("NULL() may not be parenthesized"_err_en_US); + } else if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) { if (const semantics::Symbol *result{FindFunctionResult(*symbol)}) { if (semantics::IsProcedurePointer(*result)) { - Say("A function reference that returns a procedure " - "pointer may not be parenthesized"_err_en_US); // C1003 + Say("A function reference that returns a procedure pointer may not be parenthesized"_err_en_US); // C1003 } } } @@ -3788,7 +3791,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) { ArgumentAnalyzer analyzer{*this, name.source}; analyzer.Analyze(std::get<1>(x.t)); return analyzer.TryDefinedOp(name.source.ToString().c_str(), - "No operator %s defined for %s"_err_en_US, true); + "No operator %s defined for %s"_err_en_US, /*isUserOp=*/true); } // Binary (dyadic) operations @@ -3997,7 +4000,9 @@ static 
bool CheckFuncRefToArrayElement(semantics::SemanticsContext &context, auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)}; const auto *name{std::get_if<parser::Name>(&proc.u)}; if (!name) { - name = &std::get<parser::ProcComponentRef>(proc.u).v.thing.component; + name = &parser::UnwrapRef<parser::StructureComponent>( + std::get<parser::ProcComponentRef>(proc.u)) + .component; } if (!name->symbol) { return false; @@ -4047,14 +4052,16 @@ static void FixMisparsedFunctionReference( } } auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)}; - if (Symbol *origSymbol{ - common::visit(common::visitors{ - [&](parser::Name &name) { return name.symbol; }, - [&](parser::ProcComponentRef &pcr) { - return pcr.v.thing.component.symbol; - }, - }, - proc.u)}) { + if (Symbol * + origSymbol{common::visit( + common::visitors{ + [&](parser::Name &name) { return name.symbol; }, + [&](parser::ProcComponentRef &pcr) { + return parser::UnwrapRef<parser::StructureComponent>(pcr) + .component.symbol; + }, + }, + proc.u)}) { Symbol &symbol{origSymbol->GetUltimate()}; if (symbol.has<semantics::ObjectEntityDetails>() || symbol.has<semantics::AssocEntityDetails>()) { @@ -4176,15 +4183,23 @@ MaybeExpr ExpressionAnalyzer::IterativelyAnalyzeSubexpressions( } while (!queue.empty()); // Analyze the collected subexpressions in bottom-up order. // On an error, bail out and leave partial results in place. - MaybeExpr result; - for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) { - const parser::Expr &expr{**riter}; - result = ExprOrVariable(expr, expr.source); - if (!result) { - return result; + if (finish.size() == 1) { + const parser::Expr &expr{DEREF(finish.front())}; + return ExprOrVariable(expr, expr.source); + } else { + // NULL() operand catching is deferred to operation analysis so + // that they can be accepted by defined operators. + auto restorer{AllowNullPointer()}; + MaybeExpr result; + for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) { + const parser::Expr &expr{**riter}; + result = ExprOrVariable(expr, expr.source); + if (!result) { + return result; + } } + return result; // last value was from analysis of "top" } - return result; // last value was from analysis of "top" } MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr &expr) { @@ -4681,7 +4696,7 @@ bool ArgumentAnalyzer::AnyCUDADeviceData() const { // attribute. 
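// The override probe looks the operator up by its generic spelling, e.g.
// "operator(+)" for opr == "+", with messages buffered so that a failed
// lookup stays silent; untyped and missing operands are screened out first
// since neither can resolve to a specific device procedure.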
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride( const char *opr) const { - if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) { + if (AnyCUDADeviceData() && !AnyUntypedOperand() && !AnyMissingOperand()) { std::string oprNameString{"operator("s + opr + ')'}; parser::CharBlock oprName{oprNameString}; parser::Messages buffer; @@ -4709,9 +4724,9 @@ bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride( return false; } -MaybeExpr ArgumentAnalyzer::TryDefinedOp( - const char *opr, parser::MessageFixedText error, bool isUserOp) { - if (AnyUntypedOrMissingOperand()) { +MaybeExpr ArgumentAnalyzer::TryDefinedOp(const char *opr, + parser::MessageFixedText error, bool isUserOp, bool checkForNullPointer) { + if (AnyMissingOperand()) { context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1)); return std::nullopt; } @@ -4790,7 +4805,9 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp( context_.Say( "Operands of %s are not conformable; have rank %d and rank %d"_err_en_US, ToUpperCase(opr), actuals_[0]->Rank(), actuals_[1]->Rank()); - } else if (CheckForNullPointer() && CheckForAssumedRank()) { + } else if (!CheckForAssumedRank()) { + } else if (checkForNullPointer && !CheckForNullPointer()) { + } else { // use the supplied error context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1)); } return result; @@ -4808,15 +4825,16 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp( for (std::size_t i{0}; i < oprs.size(); ++i) { parser::Messages buffer; auto restorer{context_.GetContextualMessages().SetMessages(buffer)}; - if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error)}) { + if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error, /*isUserOp=*/false, + /*checkForNullPointer=*/false)}) { result = std::move(thisResult); hit.push_back(oprs[i]); hitBuffer = std::move(buffer); } } } - if (hit.empty()) { // for the error - result = TryDefinedOp(oprs[0], error); + if (hit.empty()) { // run TryDefinedOp() again just to emit errors + CHECK(!TryDefinedOp(oprs[0], error).has_value()); } else if (hit.size() > 1) { context_.Say( "Matching accessible definitions were found with %zd variant spellings of the generic operator ('%s', '%s')"_err_en_US, @@ -5232,10 +5250,19 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) { } } -bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const { +bool ArgumentAnalyzer::AnyUntypedOperand() const { + for (const auto &actual : actuals_) { + if (actual && !actual->GetType() && + !IsBareNullPointer(actual->UnwrapExpr())) { + return true; + } + } + return false; +} + +bool ArgumentAnalyzer::AnyMissingOperand() const { for (const auto &actual : actuals_) { - if (!actual || - (!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) { + if (!actual) { return true; } } @@ -5268,9 +5295,9 @@ void ExprChecker::Post(const parser::DataStmtObject &obj) { bool ExprChecker::Pre(const parser::DataImpliedDo &ido) { parser::Walk(std::get<parser::DataImpliedDo::Bounds>(ido.t), *this); const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)}; - auto name{bounds.name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)}; int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind}; - if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) { + if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); } diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 
8074c94..556259d 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -17,6 +17,7 @@ #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" +#include "llvm/Frontend/OpenMP/OMP.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" @@ -24,6 +25,7 @@ #include <fstream> #include <set> #include <string_view> +#include <type_traits> #include <variant> #include <vector> @@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) { } } +static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) { + using RequiresClauses = WithOmpDeclarative::RequiresClauses; + using OmpMemoryOrderType = common::OmpMemoryOrderType; + + const auto [reqs, order]{common::visit( + [&](auto &&details) + -> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + return {details.ompRequires(), details.ompAtomicDefaultMemOrder()}; + } else { + return {nullptr, nullptr}; + } + }, + symbol.details())}; + + if (order) { + llvm::omp::Clause admo{llvm::omp::Clause::OMPC_atomic_default_mem_order}; + os << "!$omp requires " + << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(admo)) + << '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n"; + } + if (reqs) { + os << "!$omp requires"; + reqs->IterateOverMembers([&](llvm::omp::Clause f) { + if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) { + os << ' ' + << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f)); + } + }); + os << "\n"; + } +} + // Put out the visible symbols from scope. void ModFileWriter::PutSymbols( const Scope &scope, UnorderedSymbolSet *hermeticModules) { @@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols( for (const Symbol &symbol : uses) { PutUse(symbol); } + PutOpenMPRequirements(decls_, DEREF(scope.symbol())); for (const auto &set : scope.equivalenceSets()) { if (!set.empty() && !set.front().symbol.test(Symbol::Flag::CompilerCreated)) { diff --git a/flang/lib/Semantics/openmp-modifiers.cpp b/flang/lib/Semantics/openmp-modifiers.cpp index af4000c..717fb03 100644 --- a/flang/lib/Semantics/openmp-modifiers.cpp +++ b/flang/lib/Semantics/openmp-modifiers.cpp @@ -157,6 +157,22 @@ const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAlwaysModifier>() { } template <> +const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAttachModifier>() { + static const OmpModifierDescriptor desc{ + /*name=*/"attach-modifier", + /*props=*/ + { + {61, {OmpProperty::Unique}}, + }, + /*clauses=*/ + { + {61, {Clause::OMPC_map}}, + }, + }; + return desc; +} + +template <> const OmpModifierDescriptor &OmpGetDescriptor<parser::OmpAutomapModifier>() { static const OmpModifierDescriptor desc{ /*name=*/"automap-modifier", diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index a8ec4d6..292e73b 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -13,6 +13,7 @@ #include "flang/Semantics/openmp-utils.h" #include "flang/Common/Fortran-consts.h" +#include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Common/reference.h" #include "flang/Common/visit.h" @@ -59,6 +60,26 @@ const Scope &GetScopingUnit(const Scope &scope) { return *iter; } +const Scope &GetProgramUnit(const Scope &scope) { + const Scope *unit{nullptr}; + for (const Scope *iter{&scope}; 
!iter->IsTopLevel(); iter = &iter->parent()) { + switch (iter->kind()) { + case Scope::Kind::BlockData: + case Scope::Kind::MainProgram: + case Scope::Kind::Module: + return *iter; + case Scope::Kind::Subprogram: + // Ignore subprograms that are nested. + unit = iter; + break; + default: + break; + } + } + assert(unit && "Scope not in a program unit"); + return *unit; +} + SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) { if (x == nullptr) { return SourcedActionStmt{}; @@ -202,7 +223,7 @@ std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) { // ForwardOwningPointer typedExpr // `- GenericExprWrapper ^.get() // `- std::optional<Expr> ^->v - return typedExpr.get()->v; + return DEREF(typedExpr.get()).v; } std::optional<evaluate::DynamicType> GetDynamicType( diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 18fc638..7067ed3 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -435,6 +435,22 @@ public: return true; } + bool Pre(const parser::UseStmt &x) { + if (x.moduleName.symbol) { + Scope &thisScope{context_.FindScope(x.moduleName.source)}; + common::visit( + [&](auto &&details) { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + AddOmpRequiresToScope(thisScope, details.ompRequires(), + details.ompAtomicDefaultMemOrder()); + } + }, + x.moduleName.symbol->details()); + } + return true; + } + bool Pre(const parser::OmpMetadirectiveDirective &x) { PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective); return true; @@ -538,38 +554,55 @@ public: void Post(const parser::OpenMPFlushConstruct &) { PopContext(); } bool Pre(const parser::OpenMPRequiresConstruct &x) { - using Flags = WithOmpDeclarative::RequiresFlags; - using Requires = WithOmpDeclarative::RequiresFlag; + using RequiresClauses = WithOmpDeclarative::RequiresClauses; PushContext(x.source, llvm::omp::Directive::OMPD_requires); + auto getArgument{[&](auto &&maybeClause) { + if (maybeClause) { + // Scalar<Logical<Constant<common::Indirection<Expr>>>> + auto &parserExpr{maybeClause->v.thing.thing.thing.value()}; + evaluate::ExpressionAnalyzer ea{context_}; + if (auto &&maybeExpr{ea.Analyze(parserExpr)}) { + if (auto v{omp::GetLogicalValue(*maybeExpr)}) { + return *v; + } + } + } + // If the argument is missing, it is assumed to be true. + return true; + }}; + // Gather information from the clauses. 
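// RequiresClauses is an enum set over llvm::omp::Clause (it supports |=,
// any(), count(), and IterateOverMembers()); each requirement clause whose
// optional argument is true, or absent, contributes its own id via
// RequiresClauses{clause.Id()}.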
- Flags flags; - std::optional<common::OmpMemoryOrderType> memOrder; + RequiresClauses reqs; + const common::OmpMemoryOrderType *memOrder{nullptr}; for (const parser::OmpClause &clause : x.v.Clauses().v) { - flags |= common::visit( + using OmpClause = parser::OmpClause; + reqs |= common::visit( common::visitors{ - [&memOrder]( - const parser::OmpClause::AtomicDefaultMemOrder &atomic) { - memOrder = atomic.v.v; - return Flags{}; - }, - [](const parser::OmpClause::ReverseOffload &) { - return Flags{Requires::ReverseOffload}; + [&](const OmpClause::AtomicDefaultMemOrder &atomic) { + memOrder = &atomic.v.v; + return RequiresClauses{}; }, - [](const parser::OmpClause::UnifiedAddress &) { - return Flags{Requires::UnifiedAddress}; - }, - [](const parser::OmpClause::UnifiedSharedMemory &) { - return Flags{Requires::UnifiedSharedMemory}; - }, - [](const parser::OmpClause::DynamicAllocators &) { - return Flags{Requires::DynamicAllocators}; + [&](auto &&s) { + using TypeS = llvm::remove_cvref_t<decltype(s)>; + if constexpr ( // + std::is_same_v<TypeS, OmpClause::DynamicAllocators> || + std::is_same_v<TypeS, OmpClause::ReverseOffload> || + std::is_same_v<TypeS, OmpClause::SelfMaps> || + std::is_same_v<TypeS, OmpClause::UnifiedAddress> || + std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) { + if (getArgument(s.v)) { + return RequiresClauses{clause.Id()}; + } + } + return RequiresClauses{}; }, - [](const auto &) { return Flags{}; }}, + }, clause.u); } + // Merge clauses into parents' symbols details. - AddOmpRequiresToScope(currScope(), flags, memOrder); + AddOmpRequiresToScope(currScope(), &reqs, memOrder); return true; } void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); } @@ -1001,8 +1034,9 @@ private: std::int64_t ordCollapseLevel{0}; - void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags, - std::optional<common::OmpMemoryOrderType>); + void AddOmpRequiresToScope(Scope &, + const WithOmpDeclarative::RequiresClauses *, + const common::OmpMemoryOrderType *); void IssueNonConformanceWarning(llvm::omp::Directive D, parser::CharBlock source, unsigned EmitFromVersion); @@ -3309,86 +3343,6 @@ void ResolveOmpParts( } } -void ResolveOmpTopLevelParts( - SemanticsContext &context, const parser::Program &program) { - if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { - return; - } - - // Gather REQUIRES clauses from all non-module top-level program unit symbols, - // combine them together ensuring compatibility and apply them to all these - // program units. Modules are skipped because their REQUIRES clauses should be - // propagated via USE statements instead. - WithOmpDeclarative::RequiresFlags combinedFlags; - std::optional<common::OmpMemoryOrderType> combinedMemOrder; - - // Function to go through non-module top level program units and extract - // REQUIRES information to be processed by a function-like argument. - auto processProgramUnits{[&](auto processFn) { - for (const parser::ProgramUnit &unit : program.v) { - if (!std::holds_alternative<common::Indirection<parser::Module>>( - unit.u) && - !std::holds_alternative<common::Indirection<parser::Submodule>>( - unit.u) && - !std::holds_alternative< - common::Indirection<parser::CompilerDirective>>(unit.u)) { - Symbol *symbol{common::visit( - [&context](auto &x) { - Scope *scope = GetScope(context, x.value()); - return scope ? 
scope->symbol() : nullptr; - }, - unit.u)}; - // FIXME There is no symbol defined for MainProgram units in certain - // circumstances, so REQUIRES information has no place to be stored in - // these cases. - if (!symbol) { - continue; - } - common::visit( - [&](auto &details) { - if constexpr (std::is_convertible_v<decltype(&details), - WithOmpDeclarative *>) { - processFn(*symbol, details); - } - }, - symbol->details()); - } - } - }}; - - // Combine global REQUIRES information from all program units except modules - // and submodules. - processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) { - if (const WithOmpDeclarative::RequiresFlags * - flags{details.ompRequires()}) { - combinedFlags |= *flags; - } - if (const common::OmpMemoryOrderType * - memOrder{details.ompAtomicDefaultMemOrder()}) { - if (combinedMemOrder && *combinedMemOrder != *memOrder) { - context.Say(symbol.scope()->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - combinedMemOrder = *memOrder; - } - }); - - // Update all program units except modules and submodules with the combined - // global REQUIRES information. - processProgramUnits([&](Symbol &, WithOmpDeclarative &details) { - if (combinedFlags.any()) { - details.set_ompRequires(combinedFlags); - } - if (combinedMemOrder) { - details.set_ompAtomicDefaultMemOrder(*combinedMemOrder); - } - }); -} - static bool IsSymbolThreadprivate(const Symbol &symbol) { if (const auto *details{symbol.detailsIf<HostAssocDetails>()}) { return details->symbol().test(Symbol::Flag::OmpThreadprivate); @@ -3547,42 +3501,39 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source, } void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope, - WithOmpDeclarative::RequiresFlags flags, - std::optional<common::OmpMemoryOrderType> memOrder) { - Scope *scopeIter = &scope; - do { - if (Symbol * symbol{scopeIter->symbol()}) { - common::visit( - [&](auto &details) { - // Store clauses information into the symbol for the parent and - // enclosing modules, programs, functions and subroutines. - if constexpr (std::is_convertible_v<decltype(&details), - WithOmpDeclarative *>) { - if (flags.any()) { - if (const WithOmpDeclarative::RequiresFlags * - otherFlags{details.ompRequires()}) { - flags |= *otherFlags; - } - details.set_ompRequires(flags); + const WithOmpDeclarative::RequiresClauses *reqs, + const common::OmpMemoryOrderType *memOrder) { + const Scope &programUnit{omp::GetProgramUnit(scope)}; + using RequiresClauses = WithOmpDeclarative::RequiresClauses; + RequiresClauses combinedReqs{reqs ? 
*reqs : RequiresClauses{}}; + + if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) { + common::visit( + [&](auto &details) { + if constexpr (std::is_convertible_v<decltype(&details), + WithOmpDeclarative *>) { + if (combinedReqs.any()) { + if (const RequiresClauses *otherReqs{details.ompRequires()}) { + combinedReqs |= *otherReqs; } - if (memOrder) { - if (details.has_ompAtomicDefaultMemOrder() && - *details.ompAtomicDefaultMemOrder() != *memOrder) { - context_.Say(scopeIter->sourceRange(), - "Conflicting '%s' REQUIRES clauses found in compilation " - "unit"_err_en_US, - parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( - llvm::omp::Clause::OMPC_atomic_default_mem_order) - .str())); - } - details.set_ompAtomicDefaultMemOrder(*memOrder); + details.set_ompRequires(combinedReqs); + } + if (memOrder) { + if (details.has_ompAtomicDefaultMemOrder() && + *details.ompAtomicDefaultMemOrder() != *memOrder) { + context_.Say(programUnit.sourceRange(), + "Conflicting '%s' REQUIRES clauses found in compilation " + "unit"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order) + .str())); } + details.set_ompAtomicDefaultMemOrder(*memOrder); } - }, - symbol->details()); - } - scopeIter = &scopeIter->parent(); - } while (!scopeIter->IsGlobal()); + } + }, + symbol->details()); + } } void OmpAttributeVisitor::IssueNonConformanceWarning(llvm::omp::Directive D, diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h index 5a890c2..36d3ce9 100644 --- a/flang/lib/Semantics/resolve-directives.h +++ b/flang/lib/Semantics/resolve-directives.h @@ -23,7 +23,5 @@ class SemanticsContext; void ResolveAccParts( SemanticsContext &, const parser::ProgramUnit &, Scope *topScope); void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &); -void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &); - } // namespace Fortran::semantics #endif diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp index 742bb74..ac67799 100644 --- a/flang/lib/Semantics/resolve-names-utils.cpp +++ b/flang/lib/Semantics/resolve-names-utils.cpp @@ -492,12 +492,14 @@ bool EquivalenceSets::CheckDesignator(const parser::Designator &designator) { const auto &range{std::get<parser::SubstringRange>(x.t)}; bool ok{CheckDataRef(designator.source, dataRef)}; if (const auto &lb{std::get<0>(range.t)}) { - ok &= CheckSubstringBound(lb->thing.thing.value(), true); + ok &= CheckSubstringBound( + parser::UnwrapRef<parser::Expr>(lb), true); } else { currObject_.substringStart = 1; } if (const auto &ub{std::get<1>(range.t)}) { - ok &= CheckSubstringBound(ub->thing.thing.value(), false); + ok &= CheckSubstringBound( + parser::UnwrapRef<parser::Expr>(ub), false); } return ok; }, @@ -528,7 +530,8 @@ bool EquivalenceSets::CheckDataRef( return false; }, [&](const parser::IntExpr &y) { - return CheckArrayBound(y.thing.value()); + return CheckArrayBound( + parser::UnwrapRef<parser::Expr>(y)); }, }, subscript.u); diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 86121880..699de41 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1140,7 +1140,7 @@ protected: std::optional<SourceName> BeginCheckOnIndexUseInOwnBounds( const parser::DoVariable &name) { std::optional<SourceName> result{checkIndexUseInOwnBounds_}; - checkIndexUseInOwnBounds_ = name.thing.thing.source; + 
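// parser::UnwrapRef<T>() reaches through wrapper layers such as Scalar<>,
// Integer<>, Constant<>, and common::Indirection<> to the underlying T,
// replacing hand-written member chains such as name.thing.thing.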
checkIndexUseInOwnBounds_ = parser::UnwrapRef<parser::Name>(name).source; return result; } void EndCheckOnIndexUseInOwnBounds(const std::optional<SourceName> &restore) { @@ -2130,7 +2130,7 @@ public: void Post(const parser::SubstringInquiry &); template <typename A, typename B> void Post(const parser::LoopBounds<A, B> &x) { - ResolveName(*parser::Unwrap<parser::Name>(x.name)); + ResolveName(parser::UnwrapRef<parser::Name>(x.name)); } void Post(const parser::ProcComponentRef &); bool Pre(const parser::FunctionReference &); @@ -2560,7 +2560,7 @@ KindExpr DeclTypeSpecVisitor::GetKindParamExpr( CHECK(!state_.originalKindParameter); // Save a pointer to the KIND= expression in the parse tree // in case we need to reanalyze it during PDT instantiation. - state_.originalKindParameter = &expr->thing.thing.thing.value(); + state_.originalKindParameter = parser::Unwrap<parser::Expr>(expr); } } // Inhibit some errors now that will be caught later during instantiations. @@ -5649,6 +5649,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) { if (details->init() || symbol.test(Symbol::Flag::InDataStmt)) { Say(name, "Named constant '%s' already has a value"_err_en_US); } + parser::CharBlock at{parser::UnwrapRef<parser::Expr>(expr).source}; if (inOldStyleParameterStmt_) { // non-standard extension PARAMETER statement (no parentheses) Walk(expr); @@ -5657,7 +5658,6 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) { SayWithDecl(name, symbol, "Alternative style PARAMETER '%s' must not already have an explicit type"_err_en_US); } else if (folded) { - auto at{expr.thing.value().source}; if (evaluate::IsActuallyConstant(*folded)) { if (const auto *type{currScope().GetType(*folded)}) { if (type->IsPolymorphic()) { @@ -5682,8 +5682,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) { // standard-conforming PARAMETER statement (with parentheses) ApplyImplicitRules(symbol); Walk(expr); - if (auto converted{EvaluateNonPointerInitializer( - symbol, expr, expr.thing.value().source)}) { + if (auto converted{EvaluateNonPointerInitializer(symbol, expr, at)}) { details->set_init(std::move(*converted)); } } @@ -6149,7 +6148,7 @@ bool DeclarationVisitor::Pre(const parser::KindParam &x) { if (const auto *kind{std::get_if< parser::Scalar<parser::Integer<parser::Constant<parser::Name>>>>( &x.u)}) { - const parser::Name &name{kind->thing.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>(kind)}; if (!FindSymbol(name)) { Say(name, "Parameter '%s' not found"_err_en_US); } @@ -7460,7 +7459,7 @@ void DeclarationVisitor::DeclareLocalEntity( Symbol *DeclarationVisitor::DeclareStatementEntity( const parser::DoVariable &doVar, const std::optional<parser::IntegerTypeSpec> &type) { - const parser::Name &name{doVar.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>(doVar)}; const DeclTypeSpec *declTypeSpec{nullptr}; if (auto *prev{FindSymbol(name)}) { if (prev->owner() == currScope()) { @@ -7893,13 +7892,14 @@ bool ConstructVisitor::Pre(const parser::DataIDoObject &x) { common::visit( common::visitors{ [&](const parser::Scalar<Indirection<parser::Designator>> &y) { - Walk(y.thing.value()); - const parser::Name &first{parser::GetFirstName(y.thing.value())}; + const auto &designator{parser::UnwrapRef<parser::Designator>(y)}; + Walk(designator); + const parser::Name &first{parser::GetFirstName(designator)}; if (first.symbol) { first.symbol->set(Symbol::Flag::InDataStmt); } }, - [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y.value()); }, + [&](const 
Indirection<parser::DataImpliedDo> &y) { Walk(y); }, }, x.u); return false; @@ -8582,8 +8582,7 @@ public: void Post(const parser::WriteStmt &) { inAsyncIO_ = false; } void Post(const parser::IoControlSpec::Size &size) { if (const auto *designator{ - std::get_if<common::Indirection<parser::Designator>>( - &size.v.thing.thing.u)}) { + parser::Unwrap<common::Indirection<parser::Designator>>(size)}) { NoteAsyncIODesignator(designator->value()); } } @@ -9175,16 +9174,17 @@ bool DeclarationVisitor::CheckNonPointerInitialization( } void DeclarationVisitor::NonPointerInitialization( - const parser::Name &name, const parser::ConstantExpr &expr) { + const parser::Name &name, const parser::ConstantExpr &constExpr) { if (CheckNonPointerInitialization( name, /*inLegacyDataInitialization=*/false)) { Symbol &ultimate{name.symbol->GetUltimate()}; auto &details{ultimate.get<ObjectEntityDetails>()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(constExpr)}; if (ultimate.owner().IsParameterizedDerivedType()) { // Save the expression for per-instantiation analysis. - details.set_unanalyzedPDTComponentInit(&expr.thing.value()); + details.set_unanalyzedPDTComponentInit(&expr); } else if (MaybeExpr folded{EvaluateNonPointerInitializer( - ultimate, expr, expr.thing.value().source)}) { + ultimate, constExpr, expr.source)}) { details.set_init(std::move(*folded)); ultimate.set(Symbol::Flag::InDataStmt, false); } @@ -10687,9 +10687,6 @@ void ResolveNamesVisitor::Post(const parser::Program &x) { CHECK(!attrs_); CHECK(!cudaDataAttr_); CHECK(!GetDeclTypeSpec()); - // Top-level resolution to propagate information across program units after - // each of them has been resolved separately. - ResolveOmpTopLevelParts(context(), x); } // A singleton instance of the scope -> IMPLICIT rules mapping is diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp index 69169469..0ec44b7 100644 --- a/flang/lib/Semantics/symbol.cpp +++ b/flang/lib/Semantics/symbol.cpp @@ -70,6 +70,32 @@ static void DumpList(llvm::raw_ostream &os, const char *label, const T &list) { } } +llvm::raw_ostream &operator<<( + llvm::raw_ostream &os, const WithOmpDeclarative &x) { + if (x.has_ompRequires() || x.has_ompAtomicDefaultMemOrder()) { + os << " OmpRequirements:("; + if (const common::OmpMemoryOrderType *admo{x.ompAtomicDefaultMemOrder()}) { + os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order)) + << '(' << parser::ToLowerCaseLetters(EnumToString(*admo)) << ')'; + if (x.has_ompRequires()) { + os << ','; + } + } + if (const WithOmpDeclarative::RequiresClauses *reqs{x.ompRequires()}) { + size_t num{0}, size{reqs->count()}; + reqs->IterateOverMembers([&](llvm::omp::Clause f) { + os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f)); + if (++num < size) { + os << ','; + } + }); + } + os << ')'; + } + return os; +} + void SubprogramDetails::set_moduleInterface(Symbol &symbol) { CHECK(!moduleInterface_); moduleInterface_ = &symbol; @@ -150,6 +176,7 @@ llvm::raw_ostream &operator<<( os << x; } } + os << static_cast<const WithOmpDeclarative &>(x); return os; } @@ -580,7 +607,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) { common::visit( // common::visitors{ [&](const UnknownDetails &) {}, - [&](const MainProgramDetails &) {}, + [&](const MainProgramDetails &x) { + os << static_cast<const WithOmpDeclarative &>(x); + }, [&](const ModuleDetails &x) { if (x.isSubmodule()) { os << " ("; @@ -599,6 +628,7 @@ llvm::raw_ostream 
&operator<<(llvm::raw_ostream &os, const Details &details) { if (x.isDefaultPrivate()) { os << " isDefaultPrivate"; } + os << static_cast<const WithOmpDeclarative &>(x); }, [&](const SubprogramNameDetails &x) { os << ' ' << EnumToString(x.kind());
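Taken together, the REQUIRES information travels with the program unit's symbol: PutOpenMPRequirements() re-emits it into module files, and the symbol dump code above prints it. A minimal sketch of both outputs for a module requiring atomic_default_mem_order(seq_cst) and unified_shared_memory, assuming the usual lower-case clause and enum spellings:

!$omp requires atomic_default_mem_order(seq_cst)
!$omp requires unified_shared_memory

and, in the symbol dump, a fragment of the form:

OmpRequirements:(atomic_default_mem_order(seq_cst),unified_shared_memory)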