diff options
Diffstat (limited to 'flang/lib')
22 files changed, 488 insertions, 350 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index b3e8b69..af4f420 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -718,6 +718,84 @@ static void genDataOperandOperations( } } +template <typename GlobalCtorOrDtorOp, typename EntryOp, typename DeclareOp, + typename ExitOp> +static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, + fir::FirOpBuilder &builder, + mlir::Location loc, fir::GlobalOp globalOp, + mlir::acc::DataClause clause, + const std::string &declareGlobalName, + bool implicit, std::stringstream &asFortran) { + GlobalCtorOrDtorOp declareGlobalOp = + GlobalCtorOrDtorOp::create(modBuilder, loc, declareGlobalName); + builder.createBlock(&declareGlobalOp.getRegion(), + declareGlobalOp.getRegion().end(), {}, {}); + builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back()); + + fir::AddrOfOp addrOp = fir::AddrOfOp::create( + builder, loc, fir::ReferenceType::get(globalOp.getType()), + globalOp.getSymbol()); + addDeclareAttr(builder, addrOp, clause); + + llvm::SmallVector<mlir::Value> bounds; + EntryOp entryOp = createDataEntryOp<EntryOp>( + builder, loc, addrOp.getResTy(), asFortran, bounds, + /*structured=*/false, implicit, clause, addrOp.getResTy().getType(), + /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); + if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>) + DeclareOp::create(builder, loc, + mlir::acc::DeclareTokenType::get(entryOp.getContext()), + mlir::ValueRange(entryOp.getAccVar())); + else + DeclareOp::create(builder, loc, mlir::Value{}, + mlir::ValueRange(entryOp.getAccVar())); + if constexpr (std::is_same_v<GlobalCtorOrDtorOp, + mlir::acc::GlobalDestructorOp>) { + if constexpr (std::is_same_v<ExitOp, mlir::acc::DeclareLinkOp>) { + // No destructor emission for declare link in this path to avoid + // complex var/varType/varPtrPtr signatures. The ctor registers the link. + } else if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> || + std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>) { + ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), + entryOp.getVar(), entryOp.getVarType(), + entryOp.getBounds(), entryOp.getAsyncOperands(), + entryOp.getAsyncOperandsDeviceTypeAttr(), + entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), + /*structured=*/false, /*implicit=*/false, + builder.getStringAttr(*entryOp.getName())); + } else { + ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), + entryOp.getBounds(), entryOp.getAsyncOperands(), + entryOp.getAsyncOperandsDeviceTypeAttr(), + entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), + /*structured=*/false, /*implicit=*/false, + builder.getStringAttr(*entryOp.getName())); + } + } + mlir::acc::TerminatorOp::create(builder, loc); + modBuilder.setInsertionPointAfter(declareGlobalOp); +} + +template <typename EntryOp, typename ExitOp> +static void +emitCtorDtorPair(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, + mlir::Location operandLocation, fir::GlobalOp globalOp, + mlir::acc::DataClause clause, std::stringstream &asFortran, + const std::string &ctorName) { + createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp, + mlir::acc::DeclareEnterOp, ExitOp>( + modBuilder, builder, operandLocation, globalOp, clause, ctorName, + /*implicit=*/false, asFortran); + + std::stringstream dtorName; + dtorName << globalOp.getSymName().str() << "_acc_dtor"; + createDeclareGlobalOp<mlir::acc::GlobalDestructorOp, + mlir::acc::GetDevicePtrOp, mlir::acc::DeclareExitOp, + ExitOp>(modBuilder, builder, operandLocation, globalOp, + clause, dtorName.str(), + /*implicit=*/false, asFortran); +} + template <typename EntryOp, typename ExitOp> static void genDeclareDataOperandOperations( const Fortran::parser::AccObjectList &objectList, @@ -733,6 +811,37 @@ static void genDeclareDataOperandOperations( std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + // Handle COMMON/global symbols via module-level ctor/dtor path. + if (symbol.detailsIf<Fortran::semantics::CommonBlockDetails>() || + Fortran::semantics::FindCommonBlockContaining(symbol)) { + emitCommonGlobal( + converter, builder, accObject, dataClause, + [&](mlir::OpBuilder &modBuilder, mlir::Location loc, + fir::GlobalOp globalOp, mlir::acc::DataClause clause, + std::stringstream &asFortranStr, const std::string &ctorName) { + if constexpr (std::is_same_v<EntryOp, mlir::acc::DeclareLinkOp>) { + createDeclareGlobalOp< + mlir::acc::GlobalConstructorOp, mlir::acc::DeclareLinkOp, + mlir::acc::DeclareEnterOp, mlir::acc::DeclareLinkOp>( + modBuilder, builder, loc, globalOp, clause, ctorName, + /*implicit=*/false, asFortranStr); + } else if constexpr (std::is_same_v<EntryOp, mlir::acc::CreateOp> || + std::is_same_v<EntryOp, mlir::acc::CopyinOp> || + std::is_same_v< + EntryOp, + mlir::acc::DeclareDeviceResidentOp> || + std::is_same_v<ExitOp, mlir::acc::CopyoutOp>) { + emitCtorDtorPair<EntryOp, ExitOp>(modBuilder, builder, loc, + globalOp, clause, asFortranStr, + ctorName); + } else { + // No module-level ctor/dtor for this clause (e.g., deviceptr, + // present). Handled via structured declare region only. + return; + } + }); + continue; + } Fortran::semantics::MaybeExpr designator = Fortran::common::visit( [&](auto &&s) { return ea.Analyze(s); }, accObject.u); fir::factory::AddrAndBoundsInfo info = @@ -4098,49 +4207,6 @@ static void genACC(Fortran::lower::AbstractConverter &converter, waitOp.setAsyncAttr(firOpBuilder.getUnitAttr()); } -template <typename GlobalOp, typename EntryOp, typename DeclareOp, - typename ExitOp> -static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, - fir::FirOpBuilder &builder, - mlir::Location loc, fir::GlobalOp globalOp, - mlir::acc::DataClause clause, - const std::string &declareGlobalName, - bool implicit, std::stringstream &asFortran) { - GlobalOp declareGlobalOp = - GlobalOp::create(modBuilder, loc, declareGlobalName); - builder.createBlock(&declareGlobalOp.getRegion(), - declareGlobalOp.getRegion().end(), {}, {}); - builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back()); - - fir::AddrOfOp addrOp = fir::AddrOfOp::create( - builder, loc, fir::ReferenceType::get(globalOp.getType()), - globalOp.getSymbol()); - addDeclareAttr(builder, addrOp, clause); - - llvm::SmallVector<mlir::Value> bounds; - EntryOp entryOp = createDataEntryOp<EntryOp>( - builder, loc, addrOp.getResTy(), asFortran, bounds, - /*structured=*/false, implicit, clause, addrOp.getResTy().getType(), - /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); - if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>) - DeclareOp::create(builder, loc, - mlir::acc::DeclareTokenType::get(entryOp.getContext()), - mlir::ValueRange(entryOp.getAccVar())); - else - DeclareOp::create(builder, loc, mlir::Value{}, - mlir::ValueRange(entryOp.getAccVar())); - if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) { - ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), - entryOp.getBounds(), entryOp.getAsyncOperands(), - entryOp.getAsyncOperandsDeviceTypeAttr(), - entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), - /*structured=*/false, /*implicit=*/false, - builder.getStringAttr(*entryOp.getName())); - } - mlir::acc::TerminatorOp::create(builder, loc); - modBuilder.setInsertionPointAfter(declareGlobalOp); -} - template <typename EntryOp> static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, @@ -4317,6 +4383,66 @@ genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter, dataClause); } +static fir::GlobalOp +lookupGlobalBySymbolOrEquivalence(Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &builder, + const Fortran::semantics::Symbol &sym) { + const Fortran::semantics::Symbol *commonBlock = + Fortran::semantics::FindCommonBlockContaining(sym); + std::string globalName = commonBlock ? converter.mangleName(*commonBlock) + : converter.mangleName(sym); + if (fir::GlobalOp g = builder.getNamedGlobal(globalName)) { + return g; + } + // Not found: if not a COMMON member, try equivalence members + if (!commonBlock) { + if (const Fortran::semantics::EquivalenceSet *eqSet = + Fortran::semantics::FindEquivalenceSet(sym)) { + for (const Fortran::semantics::EquivalenceObject &eqObj : *eqSet) { + std::string eqName = converter.mangleName(eqObj.symbol); + if (fir::GlobalOp g = builder.getNamedGlobal(eqName)) + return g; + } + } + } + return {}; +} + +template <typename EmitterFn> +static void emitCommonGlobal(Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &builder, + const Fortran::parser::AccObject &obj, + mlir::acc::DataClause clause, + EmitterFn &&emitCtorDtor) { + Fortran::semantics::Symbol &sym = getSymbolFromAccObject(obj); + if (!(sym.detailsIf<Fortran::semantics::CommonBlockDetails>() || + Fortran::semantics::FindCommonBlockContaining(sym))) + return; + + fir::GlobalOp globalOp = + lookupGlobalBySymbolOrEquivalence(converter, builder, sym); + if (!globalOp) + llvm::report_fatal_error("could not retrieve global symbol"); + + std::stringstream ctorName; + ctorName << globalOp.getSymName().str() << "_acc_ctor"; + if (builder.getModule().lookupSymbol<mlir::acc::GlobalConstructorOp>( + ctorName.str())) + return; + + mlir::Location operandLocation = genOperandLocation(converter, obj); + addDeclareAttr(builder, globalOp.getOperation(), clause); + mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); + modBuilder.setInsertionPointAfter(globalOp); + std::stringstream asFortran; + asFortran << sym.name().ToString(); + + auto savedIP = builder.saveInsertionPoint(); + emitCtorDtor(modBuilder, operandLocation, globalOp, clause, asFortran, + ctorName.str()); + builder.restoreInsertionPoint(savedIP); +} + static void genDeclareInFunction(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, @@ -4342,11 +4468,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, dataClauseOperands.end()); } else if (const auto *createClause = std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - createClause->v; - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); auto crtDataStart = dataClauseOperands.size(); + const auto &accObjectList = + std::get<Fortran::parser::AccObjectList>(createClause->v.t); genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_create, @@ -4378,11 +4502,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, } else if (const auto *copyoutClause = std::get_if<Fortran::parser::AccClause::Copyout>( &clause.u)) { - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - copyoutClause->v; - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); auto crtDataStart = dataClauseOperands.size(); + const auto &accObjectList = + std::get<Fortran::parser::AccObjectList>(copyoutClause->v.t); genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>( accObjectList, converter, semanticsContext, stmtCtx, @@ -4423,6 +4545,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, } } + // If no structured operands were generated (all objects were COMMON), + // do not create a declare region. + if (dataClauseOperands.empty()) + return; + mlir::func::FuncOp funcOp = builder.getFunction(); auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>(); mlir::Value declareToken; diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 85398be..1c163e6 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1080,9 +1080,8 @@ bool ClauseProcessor::processHasDeviceAddr( [&](const omp::clause::HasDeviceAddr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::implicit; omp::ObjectList baseObjects; llvm::transform(clause.v, std::back_inserter(baseObjects), [&](const omp::Object &object) { @@ -1217,8 +1216,7 @@ bool ClauseProcessor::processLink( void ClauseProcessor::processMapObjects( lower::StatementContext &stmtCtx, mlir::Location clauseLocation, - const omp::ObjectList &objects, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, + const omp::ObjectList &objects, mlir::omp::ClauseMapFlags mapTypeBits, std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapVars, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, @@ -1310,10 +1308,7 @@ void ClauseProcessor::processMapObjects( mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp( firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds, - /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapTypeBits), + /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, mapTypeBits, mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(), /*partialMap=*/false, mapperId); @@ -1347,8 +1342,7 @@ bool ClauseProcessor::processMap( objects] = clause.t; if (attachMod) TODO(currentLocation, "ATTACH modifier is not implemented yet"); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + mlir::omp::ClauseMapFlags mapTypeBits = mlir::omp::ClauseMapFlags::none; std::string mapperIdName = "__implicit_mapper"; // If the map type is specified, then process it else set the appropriate // default value @@ -1364,36 +1358,32 @@ bool ClauseProcessor::processMap( switch (type) { case Map::MapType::To: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapTypeBits |= mlir::omp::ClauseMapFlags::to; break; case Map::MapType::From: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= mlir::omp::ClauseMapFlags::from; break; case Map::MapType::Tofrom: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= + mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::from; break; case Map::MapType::Storage: - // alloc and release is the default map_type for the Target Data - // Ops, i.e. if no bits for map_type is supplied then alloc/release - // (aka storage in 6.0+) is implicitly assumed based on the target - // directive. Default value for Target Data and Enter Data is alloc - // and for Exit Data it is release. + mapTypeBits |= mlir::omp::ClauseMapFlags::storage; break; } if (typeMods) { // TODO: Still requires "self" modifier, an OpenMP 6.0+ feature if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Always)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + mapTypeBits |= mlir::omp::ClauseMapFlags::always; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Present)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + mapTypeBits |= mlir::omp::ClauseMapFlags::present; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Close)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + mapTypeBits |= mlir::omp::ClauseMapFlags::close; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Delete)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + mapTypeBits |= mlir::omp::ClauseMapFlags::del; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::OmpxHold)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; + mapTypeBits |= mlir::omp::ClauseMapFlags::ompx_hold; } if (iterator) { @@ -1437,12 +1427,12 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, TODO(clauseLocation, "Iterator modifier is not supported yet"); } - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + mlir::omp::ClauseMapFlags mapTypeBits = std::is_same_v<llvm::remove_cvref_t<decltype(clause)>, omp::clause::To> - ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO - : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + ? mlir::omp::ClauseMapFlags::to + : mlir::omp::ClauseMapFlags::from; if (expectation && *expectation == omp::clause::To::Expectation::Present) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + mapTypeBits |= mlir::omp::ClauseMapFlags::present; processMapObjects(stmtCtx, clauseLocation, objects, mapTypeBits, parentMemberIndices, result.mapVars, mapSymbols); }; @@ -1568,8 +1558,8 @@ bool ClauseProcessor::processUseDeviceAddr( [&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::return_param; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, useDeviceSyms); @@ -1589,8 +1579,8 @@ bool ClauseProcessor::processUseDevicePtr( [&](const omp::clause::UseDevicePtr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::return_param; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, useDeviceSyms); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 9e352fa..6452e39 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -194,8 +194,7 @@ private: void processMapObjects( lower::StatementContext &stmtCtx, mlir::Location clauseLocation, - const omp::ObjectList &objects, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, + const omp::ObjectList &objects, mlir::omp::ClauseMapFlags mapTypeBits, std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapVars, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 2a4ebf1..d39f9dd 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -16,8 +16,6 @@ #include "flang/Semantics/openmp-modifiers.h" #include "flang/Semantics/symbol.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" - #include <list> #include <optional> #include <tuple> diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 9495ea6..a49961c 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -45,7 +45,6 @@ #include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" using namespace Fortran::lower::omp; using namespace Fortran::common::openmp; @@ -945,8 +944,7 @@ getDefaultmapIfPresent(const DefaultMapsTy &defaultMaps, mlir::Type varType) { return DefMap::ImplicitBehavior::Default; } -static std::pair<llvm::omp::OpenMPOffloadMappingFlags, - mlir::omp::VariableCaptureKind> +static std::pair<mlir::omp::ClauseMapFlags, mlir::omp::VariableCaptureKind> getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, lower::AbstractConverter &converter, const DefaultMapsTy &defaultMaps, mlir::Type varType, @@ -967,8 +965,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, return size <= ptrSize && align <= ptrAlign; }; - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; auto implicitBehaviour = getDefaultmapIfPresent(defaultMaps, varType); if (implicitBehaviour == DefMap::ImplicitBehavior::Default) { @@ -986,8 +983,8 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, mlir::omp::DeclareTargetCaptureClause::link && declareTargetOp.getDeclareTargetDeviceType() != mlir::omp::DeclareTargetDeviceType::nohost) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } } else if (fir::isa_trivial(varType) || fir::isa_char(varType)) { // Scalars behave as if they were "firstprivate". @@ -996,18 +993,18 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, if (isLiteralType(varType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapFlag |= mlir::omp::ClauseMapFlags::to; } } else if (!fir::isa_builtin_cptr_type(varType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } return std::make_pair(mapFlag, captureKind); } switch (implicitBehaviour) { case DefMap::ImplicitBehavior::Alloc: - return std::make_pair(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, + return std::make_pair(mlir::omp::ClauseMapFlags::storage, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Firstprivate: @@ -1016,26 +1013,22 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, "behaviour"); break; case DefMap::ImplicitBehavior::From: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Present: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::present, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::To: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::to, (fir::isa_trivial(varType) || fir::isa_char(varType)) ? mlir::omp::VariableCaptureKind::ByCopy : mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Tofrom: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from | + mlir::omp::ClauseMapFlags::to, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Default: @@ -1044,9 +1037,8 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, break; } - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from | + mlir::omp::ClauseMapFlags::to, mlir::omp::VariableCaptureKind::ByRef); } @@ -2612,18 +2604,14 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType())) eleType = refType.getElementType(); - std::pair<llvm::omp::OpenMPOffloadMappingFlags, - mlir::omp::VariableCaptureKind> + std::pair<mlir::omp::ClauseMapFlags, mlir::omp::VariableCaptureKind> mapFlagAndKind = getImplicitMapTypeAndKind( firOpBuilder, converter, defaultMaps, eleType, loc, sym); mlir::Value mapOp = createMapInfoOp( firOpBuilder, converter.getCurrentLocation(), baseOp, /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - std::get<0>(mapFlagAndKind)), + /*membersIndex=*/mlir::ArrayAttr{}, std::get<0>(mapFlagAndKind), std::get<1>(mapFlagAndKind), baseOp.getType(), /*partialMap=*/false, mapperId); diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 37b926e..6487f59 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -273,7 +273,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, omp::ObjectList &objectList, llvm::SmallVectorImpl<int64_t> &indices, OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) { + mlir::omp::ClauseMapFlags mapTypeBits) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); /// Checks if an omp::Object is an array expression with a subscript, e.g. @@ -414,11 +414,10 @@ mlir::Value createParentSymAndGenIntermediateMaps( // be safer to just pass OMP_MAP_NONE as the map type, but we may still // need some of the other map types the mapped member utilises, so for // now it's good to keep an eye on this. - llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits; - interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - interimMapType &= - ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mlir::omp::ClauseMapFlags interimMapType = mapTypeBits; + interimMapType &= ~mlir::omp::ClauseMapFlags::to; + interimMapType &= ~mlir::omp::ClauseMapFlags::from; + interimMapType &= ~mlir::omp::ClauseMapFlags::return_param; // Create a map for the intermediate member and insert it and it's // indices into the parentMemberIndices list to track it. @@ -427,10 +426,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( /*varPtrPtr=*/mlir::Value{}, asFortran, /*bounds=*/interimBounds, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - interimMapType), + /*membersIndex=*/mlir::ArrayAttr{}, interimMapType, mlir::omp::VariableCaptureKind::ByRef, curValue.getType()); parentMemberIndices.memberPlacementIndices.push_back(interimIndices); @@ -563,7 +559,8 @@ void insertChildMapInfoIntoParent( // it allows this to work with enter and exit without causing MLIR // verification issues. The more appropriate thing may be to take // the "main" map type clause from the directive being used. - uint64_t mapType = indices.second.memberMap[0].getMapType(); + mlir::omp::ClauseMapFlags mapType = + indices.second.memberMap[0].getMapType(); llvm::SmallVector<mlir::Value> members; members.reserve(indices.second.memberMap.size()); diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 69499f9..ef1f37a 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -134,7 +134,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, omp::ObjectList &objectList, llvm::SmallVectorImpl<int64_t> &indices, OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits); + mlir::omp::ClauseMapFlags mapTypeBits); omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember, semantics::SemanticsContext &semaCtx); diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp index 8b99913..817434f 100644 --- a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp +++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp @@ -20,8 +20,6 @@ #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" - namespace flangomp { #define GEN_PASS_DEF_AUTOMAPTOTARGETDATAPASS #include "flang/Optimizer/OpenMP/Passes.h.inc" @@ -120,12 +118,9 @@ class AutomapToTargetDataPass builder, memOp.getLoc(), memOp.getMemref().getType(), memOp.getMemref(), TypeAttr::get(fir::unwrapRefType(memOp.getMemref().getType())), - builder.getIntegerAttr( - builder.getIntegerType(64, false), - static_cast<unsigned>( - isa<fir::StoreOp>(memOp) - ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO - : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)), + builder.getAttr<omp::ClauseMapFlagsAttr>( + isa<fir::StoreOp>(memOp) ? omp::ClauseMapFlags::to + : omp::ClauseMapFlags::del), builder.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByCopy), /*var_ptr_ptr=*/mlir::Value{}, diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 03ff163..65a23be 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -22,7 +22,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" namespace flangomp { #define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS @@ -568,16 +567,15 @@ private: if (auto refType = mlir::dyn_cast<fir::ReferenceType>(liveInType)) eleType = refType.getElementType(); - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; mlir::omp::VariableCaptureKind captureKind = mlir::omp::VariableCaptureKind::ByRef; if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else if (!fir::isa_builtin_cptr_type(eleType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } llvm::SmallVector<mlir::Value> boundsOps; @@ -587,11 +585,8 @@ private: builder, liveIn.getLoc(), rawAddr, /*varPtrPtr=*/{}, name.str(), boundsOps, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapFlag), - captureKind, rawAddr.getType()); + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, + rawAddr.getType()); } mlir::omp::TargetOp diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 9278e17..8a9b383 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -719,10 +719,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, SmallVector<Value> outerMapInfos; // Create new mapinfo ops for the inner target region for (auto mapInfo : mapInfos) { - auto originalMapType = - (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType()); + mlir::omp::ClauseMapFlags originalMapType = mapInfo.getMapType(); auto originalCaptureType = mapInfo.getMapCaptureType(); - llvm::omp::OpenMPOffloadMappingFlags newMapType; + mlir::omp::ClauseMapFlags newMapType; mlir::omp::VariableCaptureKind newCaptureType; // For bycopy, we keep the same map type and capture type // For byref, we change the map type to none and keep the capture type @@ -730,7 +729,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, newMapType = originalMapType; newCaptureType = originalCaptureType; } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) { - newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + newMapType = mlir::omp::ClauseMapFlags::storage; newCaptureType = originalCaptureType; outerMapInfos.push_back(mapInfo); } else { @@ -738,11 +737,8 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, return failure(); } auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo)); - innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr( - rewriter.getIntegerType(64, false), - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - newMapType))); + innerMapInfo.setMapTypeAttr( + rewriter.getAttr<omp::ClauseMapFlagsAttr>(newMapType)); innerMapInfo.setMapCaptureType(newCaptureType); innerMapInfos.push_back(innerMapInfo.getResult()); } @@ -834,11 +830,11 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, alloc = rewriter.create<fir::AllocaOp>(loc, allocType); } // Lambda to create mapinfo ops - auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { + auto getMapInfo = [&](mlir::omp::ClauseMapFlags mappingFlags, + const char *name) { return rewriter.create<omp::MapInfoOp>( loc, alloc.getType(), alloc, TypeAttr::get(allocType), - rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), - mappingFlags), + rewriter.getAttr<omp::ClauseMapFlagsAttr>(mappingFlags), rewriter.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByRef), /*varPtrPtr=*/Value{}, @@ -849,14 +845,10 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false)); }; // Create mapinfo ops. - uint64_t mapFrom = - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - uint64_t mapTo = - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from"); - auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to"); + auto mapInfoFrom = getMapInfo(mlir::omp::ClauseMapFlags::from, + "__flang_workdistribute_from"); + auto mapInfoTo = + getMapInfo(mlir::omp::ClauseMapFlags::to, "__flang_workdistribute_to"); return TempOmpVar{mapInfoFrom, mapInfoTo}; } diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 2bbd803..566e88b 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -43,7 +43,6 @@ #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringSet.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cstddef> @@ -350,7 +349,7 @@ class MapInfoFinalizationPass /// the descriptor map onto the base address map. mlir::omp::MapInfoOp genBaseAddrMap(mlir::Value descriptor, mlir::OperandRange bounds, - int64_t mapType, + mlir::omp::ClauseMapFlags mapType, fir::FirOpBuilder &builder) { mlir::Location loc = descriptor.getLoc(); mlir::Value baseAddrAddr = fir::BoxOffsetOp::create( @@ -368,7 +367,7 @@ class MapInfoFinalizationPass return mlir::omp::MapInfoOp::create( builder, loc, baseAddrAddr.getType(), descriptor, mlir::TypeAttr::get(underlyingVarType), - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType), builder.getAttr<mlir::omp::VariableCaptureKindAttr>( mlir::omp::VariableCaptureKind::ByRef), baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{}, @@ -428,22 +427,22 @@ class MapInfoFinalizationPass /// allowing `to` mappings, and `target update` not allowing both `to` and /// `from` simultaneously. We currently try to maintain the `implicit` flag /// where necessary, although it does not seem strictly required. - unsigned long getDescriptorMapType(unsigned long mapTypeFlag, - mlir::Operation *target) { - using mapFlags = llvm::omp::OpenMPOffloadMappingFlags; + mlir::omp::ClauseMapFlags + getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag, + mlir::Operation *target) { + using mapFlags = mlir::omp::ClauseMapFlags; if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp, mlir::omp::TargetUpdateOp>(target)) return mapTypeFlag; - mapFlags flags = mapFlags::OMP_MAP_TO | - (mapFlags(mapTypeFlag) & - (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS)); + mapFlags flags = + mapFlags::to | (mapTypeFlag & (mapFlags::implicit | mapFlags::always)); // For unified_shared_memory, we additionally add `CLOSE` on the descriptor // to ensure device-local placement where required by tests relying on USM + // close semantics. if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>())) - flags |= mapFlags::OMP_MAP_CLOSE; - return llvm::to_underlying(flags); + flags |= mapFlags::close; + return flags; } /// Check if the mapOp is present in the HasDeviceAddr clause on @@ -493,11 +492,6 @@ class MapInfoFinalizationPass mlir::Value boxAddr = fir::BoxOffsetOp::create( builder, loc, op.getVarPtr(), fir::BoxFieldAttr::base_addr); - uint64_t mapTypeToImplicit = static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); - mlir::ArrayAttr newMembersAttr; llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}}; newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); @@ -506,8 +500,9 @@ class MapInfoFinalizationPass mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create( builder, op.getLoc(), varPtr.getType(), varPtr, mlir::TypeAttr::get(boxCharType.getEleTy()), - builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), - mapTypeToImplicit), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::implicit), builder.getAttr<mlir::omp::VariableCaptureKindAttr>( mlir::omp::VariableCaptureKind::ByRef), /*varPtrPtr=*/boxAddr, @@ -568,12 +563,9 @@ class MapInfoFinalizationPass mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); // Force CLOSE in USM paths so the pointer gets device-local placement // when required by tests relying on USM + close semantics. - uint64_t mapTypeVal = - op.getMapType() | - llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr( - builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal); + mlir::omp::ClauseMapFlagsAttr mapTypeAttr = + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + op.getMapType() | mlir::omp::ClauseMapFlags::close); mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create( builder, loc, coord.getType(), coord, @@ -683,17 +675,16 @@ class MapInfoFinalizationPass // one place in the code may differ from that address in another place. // The contents of the descriptor (the base address in particular) will // remain unchanged though. - uint64_t mapType = op.getMapType(); + mlir::omp::ClauseMapFlags mapType = op.getMapType(); if (isHasDeviceAddrFlag) { - mapType |= llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); + mapType |= mlir::omp::ClauseMapFlags::always; } mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create( builder, op->getLoc(), op.getResult().getType(), descriptor, mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())), - builder.getIntegerAttr(builder.getIntegerType(64, false), - getDescriptorMapType(mapType, target)), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + getDescriptorMapType(mapType, target)), op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers, newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{}, /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), @@ -896,11 +887,9 @@ class MapInfoFinalizationPass builder.create<mlir::omp::MapInfoOp>( op->getLoc(), op.getResult().getType(), op.getVarPtr(), op.getVarTypeAttr(), - builder.getIntegerAttr( - builder.getIntegerType(64, false), - llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::always), op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, /*bounds=*/mlir::SmallVector<mlir::Value>{}, @@ -1240,9 +1229,8 @@ class MapInfoFinalizationPass // we need to change this check for early return OR live with // over-mapping. bool hasImplicitMap = - (llvm::omp::OpenMPOffloadMappingFlags(op.getMapType()) & - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT) == - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + (op.getMapType() & mlir::omp::ClauseMapFlags::implicit) == + mlir::omp::ClauseMapFlags::implicit; if (hasImplicitMap) return; diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 3032857..0972861 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -35,7 +35,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/Debug.h" #include <type_traits> @@ -70,9 +69,6 @@ class MapsForPrivatizedSymbolsPass return size <= ptrSize && align <= ptrAlign; }; - uint64_t mapTypeTo = static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); Operation *definingOp = var.getDefiningOp(); Value varPtr = var; @@ -122,8 +118,7 @@ class MapsForPrivatizedSymbolsPass builder, loc, varPtr.getType(), varPtr, TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType()) .getElementType()), - builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), - mapTypeTo), + builder.getAttr<omp::ClauseMapFlagsAttr>(omp::ClauseMapFlags::to), builder.getAttr<omp::VariableCaptureKindAttr>(captureKind), /*varPtrPtr=*/Value{}, /*members=*/SmallVector<Value>{}, diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 759e3a65d..8d00272 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -454,6 +454,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { mlir::LogicalResult matchAndRewrite(fir::DeclareOp op, mlir::PatternRewriter &rewriter) const override { + if (op.getResult().getUsers().empty()) + return success(); if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { if (auto global = symTab.lookup<fir::GlobalOp>( addrOfOp.getSymbol().getRootReference().getValue())) { @@ -963,6 +965,8 @@ public: } target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) { + if (op.getResult().getUsers().empty()) + return true; if (inDeviceContext(op)) return true; if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp index ad035e6..0dec5652 100644 --- a/flang/lib/Parser/openacc-parsers.cpp +++ b/flang/lib/Parser/openacc-parsers.cpp @@ -75,21 +75,21 @@ TYPE_PARSER( // tile size is one of: // * (represented as an empty std::optional<ScalarIntExpr>) // constant-int-expr -TYPE_PARSER(construct<AccTileExpr>(scalarIntConstantExpr) || +TYPE_PARSER(sourced(construct<AccTileExpr>(scalarIntConstantExpr) || construct<AccTileExpr>( - "*" >> construct<std::optional<ScalarIntConstantExpr>>())) + "*" >> construct<std::optional<ScalarIntConstantExpr>>()))) TYPE_PARSER(construct<AccTileExprList>(nonemptyList(Parser<AccTileExpr>{}))) // 2.9 (1979-1982) gang-arg is one of : // [num:]int-expr // dim:int-expr // static:size-expr -TYPE_PARSER(construct<AccGangArg>(construct<AccGangArg::Static>( - "STATIC: " >> Parser<AccSizeExpr>{})) || +TYPE_PARSER(sourced(construct<AccGangArg>(construct<AccGangArg::Static>( + "STATIC: " >> Parser<AccSizeExpr>{})) || construct<AccGangArg>( construct<AccGangArg::Dim>("DIM: " >> scalarIntExpr)) || construct<AccGangArg>( - construct<AccGangArg::Num>(maybe("NUM: "_tok) >> scalarIntExpr))) + construct<AccGangArg::Num>(maybe("NUM: "_tok) >> scalarIntExpr)))) // 2.9 gang-arg-list TYPE_PARSER( @@ -101,7 +101,7 @@ TYPE_PARSER(construct<AccCollapseArg>( // 2.5.15 Reduction, F'2023 R1131, and CUF reduction-op // Operator for reduction -TYPE_PARSER(sourced(construct<ReductionOperator>( +TYPE_PARSER(construct<ReductionOperator>( first("+" >> pure(ReductionOperator::Operator::Plus), "*" >> pure(ReductionOperator::Operator::Multiply), "MAX" >> pure(ReductionOperator::Operator::Max), @@ -112,32 +112,32 @@ TYPE_PARSER(sourced(construct<ReductionOperator>( ".AND." >> pure(ReductionOperator::Operator::And), ".OR." >> pure(ReductionOperator::Operator::Or), ".EQV." >> pure(ReductionOperator::Operator::Eqv), - ".NEQV." >> pure(ReductionOperator::Operator::Neqv))))) + ".NEQV." >> pure(ReductionOperator::Operator::Neqv)))) // 2.15.1 Bind clause -TYPE_PARSER(sourced(construct<AccBindClause>(name)) || - sourced(construct<AccBindClause>(scalarDefaultCharExpr))) +TYPE_PARSER(sourced(construct<AccBindClause>(name) || + construct<AccBindClause>(scalarDefaultCharExpr))) // 2.5.16 Default clause -TYPE_PARSER(construct<AccDefaultClause>( +TYPE_PARSER(sourced(construct<AccDefaultClause>( first("NONE" >> pure(llvm::acc::DefaultValue::ACC_Default_none), - "PRESENT" >> pure(llvm::acc::DefaultValue::ACC_Default_present)))) + "PRESENT" >> pure(llvm::acc::DefaultValue::ACC_Default_present))))) // SELF clause is either a simple optional condition for compute construct // or a synonym of the HOST clause for the update directive 2.14.4 holding // an object list. -TYPE_PARSER( +TYPE_PARSER(sourced( construct<AccSelfClause>(Parser<AccObjectList>{}) / lookAhead(")"_tok) || - construct<AccSelfClause>(scalarLogicalExpr / lookAhead(")"_tok)) || + construct<AccSelfClause>(scalarLogicalExpr) / lookAhead(")"_tok) || construct<AccSelfClause>( recovery(fail<std::optional<ScalarLogicalExpr>>( "logical expression or object list expected"_err_en_US), - SkipTo<')'>{} >> pure<std::optional<ScalarLogicalExpr>>()))) + SkipTo<')'>{} >> pure<std::optional<ScalarLogicalExpr>>())))) // Modifier for copyin, copyout, cache and create -TYPE_PARSER(construct<AccDataModifier>( +TYPE_PARSER(sourced(construct<AccDataModifier>( first("ZERO:" >> pure(AccDataModifier::Modifier::Zero), - "READONLY:" >> pure(AccDataModifier::Modifier::ReadOnly)))) + "READONLY:" >> pure(AccDataModifier::Modifier::ReadOnly))))) // Combined directives TYPE_PARSER(sourced(construct<AccCombinedDirective>( @@ -166,14 +166,13 @@ TYPE_PARSER(sourced(construct<AccStandaloneDirective>( TYPE_PARSER(sourced(construct<AccLoopDirective>( first("LOOP" >> pure(llvm::acc::Directive::ACCD_loop))))) -TYPE_PARSER(construct<AccBeginLoopDirective>( - sourced(Parser<AccLoopDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<AccBeginLoopDirective>( + Parser<AccLoopDirective>{}, Parser<AccClauseList>{}))) TYPE_PARSER(construct<AccEndLoop>("END LOOP"_tok)) TYPE_PARSER(construct<OpenACCLoopConstruct>( - sourced(Parser<AccBeginLoopDirective>{} / endAccLine), - maybe(Parser<DoConstruct>{}), + Parser<AccBeginLoopDirective>{} / endAccLine, maybe(Parser<DoConstruct>{}), maybe(startAccLine >> Parser<AccEndLoop>{} / endAccLine))) // 2.15.1 Routine directive @@ -186,8 +185,8 @@ TYPE_PARSER(sourced( parenthesized(Parser<AccObjectListWithModifier>{})))) // 2.11 Combined constructs -TYPE_PARSER(construct<AccBeginCombinedDirective>( - sourced(Parser<AccCombinedDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<AccBeginCombinedDirective>( + Parser<AccCombinedDirective>{}, Parser<AccClauseList>{}))) // 2.12 Atomic constructs TYPE_PARSER(construct<AccEndAtomic>(startAccLine >> "END ATOMIC"_tok)) @@ -213,10 +212,10 @@ TYPE_PARSER("ATOMIC" >> statement(assignmentStmt), Parser<AccEndAtomic>{} / endAccLine)) TYPE_PARSER( - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicRead>{})) || - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicCapture>{})) || - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicWrite>{})) || - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicUpdate>{}))) + sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicRead>{}) || + construct<OpenACCAtomicConstruct>(Parser<AccAtomicCapture>{}) || + construct<OpenACCAtomicConstruct>(Parser<AccAtomicWrite>{}) || + construct<OpenACCAtomicConstruct>(Parser<AccAtomicUpdate>{}))) // 2.13 Declare constructs TYPE_PARSER(sourced(construct<AccDeclarativeDirective>( @@ -250,18 +249,18 @@ TYPE_PARSER(construct<OpenACCBlockConstruct>( pure(llvm::acc::Directive::ACCD_data)))))) // Standalone constructs -TYPE_PARSER(construct<OpenACCStandaloneConstruct>( - sourced(Parser<AccStandaloneDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<OpenACCStandaloneConstruct>( + Parser<AccStandaloneDirective>{}, Parser<AccClauseList>{}))) // Standalone declarative constructs -TYPE_PARSER(construct<OpenACCStandaloneDeclarativeConstruct>( - sourced(Parser<AccDeclarativeDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<OpenACCStandaloneDeclarativeConstruct>( + Parser<AccDeclarativeDirective>{}, Parser<AccClauseList>{}))) TYPE_PARSER(startAccLine >> withMessage("expected OpenACC directive"_err_en_US, - first(sourced(construct<OpenACCDeclarativeConstruct>( - Parser<OpenACCStandaloneDeclarativeConstruct>{})), - sourced(construct<OpenACCDeclarativeConstruct>( + sourced(first(construct<OpenACCDeclarativeConstruct>( + Parser<OpenACCStandaloneDeclarativeConstruct>{}), + construct<OpenACCDeclarativeConstruct>( Parser<OpenACCRoutineConstruct>{}))))) TYPE_PARSER(sourced(construct<OpenACCEndConstruct>( @@ -293,9 +292,9 @@ TYPE_PARSER(startAccLine >> "SERIAL"_tok >> maybe("LOOP"_tok) >> pure(llvm::acc::Directive::ACCD_serial_loop)))))) -TYPE_PARSER(construct<OpenACCCombinedConstruct>( - sourced(Parser<AccBeginCombinedDirective>{} / endAccLine), +TYPE_PARSER(sourced(construct<OpenACCCombinedConstruct>( + Parser<AccBeginCombinedDirective>{} / endAccLine, maybe(Parser<DoConstruct>{}), - maybe(Parser<AccEndCombinedDirective>{} / endAccLine))) + maybe(Parser<AccEndCombinedDirective>{} / endAccLine)))) } // namespace Fortran::parser diff --git a/flang/lib/Parser/prescan.cpp b/flang/lib/Parser/prescan.cpp index 66e5b2c..df0372b 100644 --- a/flang/lib/Parser/prescan.cpp +++ b/flang/lib/Parser/prescan.cpp @@ -140,17 +140,9 @@ void Prescanner::Statement() { CHECK(*at_ == '!'); } std::optional<int> condOffset; - if (InOpenMPConditionalLine()) { + if (InOpenMPConditionalLine()) { // !$ condOffset = 2; - } else if (directiveSentinel_[0] == '@' && directiveSentinel_[1] == 'c' && - directiveSentinel_[2] == 'u' && directiveSentinel_[3] == 'f' && - directiveSentinel_[4] == '\0') { - // CUDA conditional compilation line. - condOffset = 5; - } else if (directiveSentinel_[0] == '@' && directiveSentinel_[1] == 'a' && - directiveSentinel_[2] == 'c' && directiveSentinel_[3] == 'c' && - directiveSentinel_[4] == '\0') { - // OpenACC conditional compilation line. + } else if (InOpenACCOrCUDAConditionalLine()) { // !@acc or !@cuf condOffset = 5; } if (condOffset && !preprocessingOnly_) { @@ -166,7 +158,8 @@ void Prescanner::Statement() { } else { // Compiler directive. Emit normalized sentinel, squash following spaces. // Conditional compilation lines (!$) take this path in -E mode too - // so that -fopenmp only has to appear on the later compilation. + // so that -fopenmp only has to appear on the later compilation + // (ditto for !@cuf and !@acc). EmitChar(tokens, '!'); ++at_, ++column_; for (const char *sp{directiveSentinel_}; *sp != '\0'; @@ -202,7 +195,7 @@ void Prescanner::Statement() { } tokens.CloseToken(); SkipSpaces(); - if (InOpenMPConditionalLine() && inFixedForm_ && !tabInCurrentLine_ && + if (InConditionalLine() && inFixedForm_ && !tabInCurrentLine_ && column_ == 6 && *at_ != '\n') { // !$ 0 - turn '0' into a space // !$ 1 - turn '1' into '&' @@ -347,7 +340,7 @@ void Prescanner::Statement() { while (CompilerDirectiveContinuation(tokens, line.sentinel)) { newlineProvenance = GetCurrentProvenance(); } - if (preprocessingOnly_ && inFixedForm_ && InOpenMPConditionalLine() && + if (preprocessingOnly_ && inFixedForm_ && InConditionalLine() && nextLine_ < limit_) { // In -E mode, when the line after !$ conditional compilation is a // regular fixed form continuation line, append a '&' to the line. @@ -1360,11 +1353,10 @@ const char *Prescanner::FixedFormContinuationLine(bool atNewline) { features_.IsEnabled(LanguageFeature::OldDebugLines))) && nextLine_[1] == ' ' && nextLine_[2] == ' ' && nextLine_[3] == ' ' && nextLine_[4] == ' '}; - if (InCompilerDirective() && - !(InOpenMPConditionalLine() && !preprocessingOnly_)) { + if (InCompilerDirective() && !(InConditionalLine() && !preprocessingOnly_)) { // !$ under -E is not continued, but deferred to later compilation if (IsFixedFormCommentChar(col1) && - !(InOpenMPConditionalLine() && preprocessingOnly_)) { + !(InConditionalLine() && preprocessingOnly_)) { int j{1}; for (; j < 5; ++j) { char ch{directiveSentinel_[j - 1]}; @@ -1443,7 +1435,7 @@ const char *Prescanner::FreeFormContinuationLine(bool ampersand) { } p = SkipWhiteSpaceIncludingEmptyMacros(p); if (InCompilerDirective()) { - if (InOpenMPConditionalLine()) { + if (InConditionalLine()) { if (preprocessingOnly_) { // in -E mode, don't treat !$ as a continuation return nullptr; diff --git a/flang/lib/Parser/prescan.h b/flang/lib/Parser/prescan.h index fc38adb..5e74817 100644 --- a/flang/lib/Parser/prescan.h +++ b/flang/lib/Parser/prescan.h @@ -171,7 +171,17 @@ private: bool InOpenMPConditionalLine() const { return directiveSentinel_ && directiveSentinel_[0] == '$' && !directiveSentinel_[1]; - ; + } + bool InOpenACCOrCUDAConditionalLine() const { + return directiveSentinel_ && directiveSentinel_[0] == '@' && + ((directiveSentinel_[1] == 'a' && directiveSentinel_[2] == 'c' && + directiveSentinel_[3] == 'c') || + (directiveSentinel_[1] == 'c' && directiveSentinel_[2] == 'u' && + directiveSentinel_[3] == 'f')) && + directiveSentinel_[4] == '\0'; + } + bool InConditionalLine() const { + return InOpenMPConditionalLine() || InOpenACCOrCUDAConditionalLine(); } bool InFixedFormSource() const { return inFixedForm_ && !inPreprocessorDirective_ && !InCompilerDirective(); diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index 4cb0b74..b3fd6c8 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -19,7 +19,6 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/semantics.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" using OmpClauseSet = Fortran::common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>; diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 33e9ea5..c410bd4 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -31,15 +31,17 @@ #include <list> #include <map> +namespace Fortran::semantics { + template <typename T> -static Fortran::semantics::Scope *GetScope( - Fortran::semantics::SemanticsContext &context, const T &x) { - std::optional<Fortran::parser::CharBlock> source{GetLastSource(x)}; - return source ? &context.FindScope(*source) : nullptr; +static Scope *GetScope(SemanticsContext &context, const T &x) { + if (auto source{GetLastSource(x)}) { + return &context.FindScope(*source); + } else { + return nullptr; + } } -namespace Fortran::semantics { - template <typename T> class DirectiveAttributeVisitor { public: explicit DirectiveAttributeVisitor(SemanticsContext &context) @@ -361,7 +363,7 @@ private: void ResolveAccObject(const parser::AccObject &, Symbol::Flag); Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &); Symbol *ResolveAcc(Symbol &, Symbol::Flag, Scope &); - Symbol *ResolveName(const parser::Name &, bool parentScope = false); + Symbol *ResolveName(const parser::Name &); Symbol *ResolveFctName(const parser::Name &); Symbol *ResolveAccCommonBlockName(const parser::Name *); Symbol *DeclareOrMarkOtherAccessEntity(const parser::Name &, Symbol::Flag); @@ -1257,31 +1259,22 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCStandaloneConstruct &x) { return true; } -Symbol *AccAttributeVisitor::ResolveName( - const parser::Name &name, bool parentScope) { - Symbol *prev{currScope().FindSymbol(name.source)}; - // Check in parent scope if asked for. - if (!prev && parentScope) { - prev = currScope().parent().FindSymbol(name.source); - } - if (prev != name.symbol) { - name.symbol = prev; - } - return prev; +Symbol *AccAttributeVisitor::ResolveName(const parser::Name &name) { + return name.symbol; } Symbol *AccAttributeVisitor::ResolveFctName(const parser::Name &name) { Symbol *prev{currScope().FindSymbol(name.source)}; - if (!prev || (prev && prev->IsFuncResult())) { + if (prev && prev->IsFuncResult()) { prev = currScope().parent().FindSymbol(name.source); - if (!prev) { - prev = &context_.globalScope().MakeSymbol( - name.source, Attrs{}, ProcEntityDetails{}); - } } - if (prev != name.symbol) { - name.symbol = prev; + if (!prev) { + prev = &*context_.globalScope() + .try_emplace(name.source, ProcEntityDetails{}) + .first->second; } + CHECK(!name.symbol || name.symbol == prev); + name.symbol = prev; return prev; } @@ -1388,9 +1381,8 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) { } else { PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine); } - const auto &optName{std::get<std::optional<parser::Name>>(x.t)}; - if (optName) { - if (Symbol *sym = ResolveFctName(*optName)) { + if (const auto &optName{std::get<std::optional<parser::Name>>(x.t)}) { + if (Symbol * sym{ResolveFctName(*optName)}) { Symbol &ultimate{sym->GetUltimate()}; AddRoutineInfoToSymbol(ultimate, x); } else { @@ -1425,7 +1417,7 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) { case llvm::acc::Directive::ACCD_kernels_loop: case llvm::acc::Directive::ACCD_parallel_loop: case llvm::acc::Directive::ACCD_serial_loop: - PushContext(combinedDir.source, combinedDir.v); + PushContext(x.source, combinedDir.v); break; default: break; @@ -1706,41 +1698,37 @@ void AccAttributeVisitor::Post(const parser::AccDefaultClause &x) { } } -// For OpenACC constructs, check all the data-refs within the constructs -// and adjust the symbol for each Name if necessary void AccAttributeVisitor::Post(const parser::Name &name) { - auto *symbol{name.symbol}; - if (symbol && WithinConstruct()) { - symbol = &symbol->GetUltimate(); - if (!symbol->owner().IsDerivedType() && !symbol->has<ProcEntityDetails>() && - !symbol->has<SubprogramDetails>() && !IsObjectWithVisibleDSA(*symbol)) { + if (name.symbol && WithinConstruct()) { + const Symbol &symbol{name.symbol->GetUltimate()}; + if (!symbol.owner().IsDerivedType() && !symbol.has<ProcEntityDetails>() && + !symbol.has<SubprogramDetails>() && !IsObjectWithVisibleDSA(symbol)) { if (Symbol * found{currScope().FindSymbol(name.source)}) { - if (symbol != found) { - name.symbol = found; // adjust the symbol within region + if (&symbol != found) { + // adjust the symbol within the region + // TODO: why didn't name resolution set the right name originally? + name.symbol = found; } else if (GetContext().defaultDSA == Symbol::Flag::AccNone) { // 2.5.14. context_.Say(name.source, "The DEFAULT(NONE) clause requires that '%s' must be listed in a data-mapping clause"_err_en_US, - symbol->name()); + symbol.name()); } + } else { + // TODO: assertion here? or clear name.symbol? } } - } // within OpenACC construct + } } Symbol *AccAttributeVisitor::ResolveAccCommonBlockName( const parser::Name *name) { - if (auto *prev{name - ? GetContext().scope.parent().FindCommonBlock(name->source) - : nullptr}) { - name->symbol = prev; - return prev; - } - // Check if the Common Block is declared in the current scope - if (auto *commonBlockSymbol{ - name ? GetContext().scope.FindCommonBlock(name->source) : nullptr}) { - name->symbol = commonBlockSymbol; - return commonBlockSymbol; + if (name) { + if (Symbol * + cb{GetContext().scope.FindCommonBlockInVisibleScopes(name->source)}) { + name->symbol = cb; + return cb; + } } return nullptr; } @@ -1790,8 +1778,8 @@ void AccAttributeVisitor::ResolveAccObject( } } else { context_.Say(name.source, - "COMMON block must be declared in the same scoping unit " - "in which the OpenACC directive or clause appears"_err_en_US); + "Could not find COMMON block '%s' used in OpenACC directive"_err_en_US, + name.ToString()); } }, }, @@ -1810,13 +1798,11 @@ Symbol *AccAttributeVisitor::ResolveAcc( Symbol *AccAttributeVisitor::DeclareOrMarkOtherAccessEntity( const parser::Name &name, Symbol::Flag accFlag) { - Symbol *prev{currScope().FindSymbol(name.source)}; - if (!name.symbol || !prev) { + if (name.symbol) { + return DeclareOrMarkOtherAccessEntity(*name.symbol, accFlag); + } else { return nullptr; - } else if (prev != name.symbol) { - name.symbol = prev; } - return DeclareOrMarkOtherAccessEntity(*prev, accFlag); } Symbol *AccAttributeVisitor::DeclareOrMarkOtherAccessEntity( @@ -2990,6 +2976,7 @@ void OmpAttributeVisitor::Post(const parser::Name &name) { } Symbol *OmpAttributeVisitor::ResolveName(const parser::Name *name) { + // TODO: why is the symbol not properly resolved by name resolution? if (auto *resolvedSymbol{ name ? GetContext().scope.FindSymbol(name->source) : nullptr}) { name->symbol = resolvedSymbol; diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 0af1c94..db75437 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1441,6 +1441,30 @@ public: void Post(const parser::AccBeginLoopDirective &x) { messageHandler().set_currStmtSource(std::nullopt); } + bool Pre(const parser::OpenACCStandaloneConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCCacheConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCWaitConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCAtomicConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCEndConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCDeclarativeConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } void CopySymbolWithDevice(const parser::Name *name); @@ -1480,7 +1504,8 @@ void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { // symbols are created for the one appearing in the use_device // clause. These new symbols have the CUDA Fortran device // attribute. - if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) { + if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA) && + name->symbol) { name->symbol = currScope().CopySymbol(*name->symbol); if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) { object->set_cudaDataAttr(common::CUDADataAttr::Device); @@ -1490,15 +1515,12 @@ void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { for (const auto &accObject : x.v.v) { + Walk(accObject); common::visit( common::visitors{ [&](const parser::Designator &designator) { if (const auto *name{ parser::GetDesignatorNameIfDataRef(designator)}) { - Symbol *prev{currScope().FindSymbol(name->source)}; - if (prev != name->symbol) { - name->symbol = prev; - } CopySymbolWithDevice(name); } else { if (const auto *dataRef{ @@ -1507,13 +1529,8 @@ bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { common::Indirection<parser::ArrayElement>; if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) { const parser::ArrayElement &arrayElement{ind->value()}; - Walk(arrayElement.subscripts); const parser::DataRef &base{arrayElement.base}; if (auto *name{std::get_if<parser::Name>(&base.u)}) { - Symbol *prev{currScope().FindSymbol(name->source)}; - if (prev != name->symbol) { - name->symbol = prev; - } CopySymbolWithDevice(name); } } @@ -1537,6 +1554,7 @@ void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) { bool AccVisitor::Pre(const parser::OpenACCCombinedConstruct &x) { PushScope(Scope::Kind::OpenACCConstruct, nullptr); + currScope().AddSourceRange(x.source); return true; } @@ -3627,6 +3645,20 @@ void ModuleVisitor::Post(const parser::UseStmt &x) { } } } + // Go through the list of COMMON block symbols in the module scope and add + // their USE association to the current scope's USE-associated COMMON blocks. + for (const auto &[name, symbol] : useModuleScope_->commonBlocks()) { + if (!currScope().FindCommonBlockInVisibleScopes(name)) { + currScope().AddCommonBlockUse( + name, symbol->attrs(), symbol->GetUltimate()); + } + } + // Go through the list of USE-associated COMMON block symbols in the module + // scope and add USE associations to their ultimate symbols to the current + // scope's USE-associated COMMON blocks. + for (const auto &[name, symbol] : useModuleScope_->commonBlockUses()) { + currScope().AddCommonBlockUse(name, symbol->attrs(), symbol->GetUltimate()); + } useModuleScope_ = nullptr; } @@ -5433,7 +5465,8 @@ void SubprogramVisitor::PushBlockDataScope(const parser::Name &name) { } } -// If name is a generic, return specific subprogram with the same name. +// If name is a generic in the same scope, return its specific subprogram with +// the same name, if any. Symbol *SubprogramVisitor::GetSpecificFromGeneric(const parser::Name &name) { // Search for the name but don't resolve it if (auto *symbol{currScope().FindSymbol(name.source)}) { @@ -5443,6 +5476,9 @@ Symbol *SubprogramVisitor::GetSpecificFromGeneric(const parser::Name &name) { // symbol doesn't inherit it and ruin the ability to check it. symbol->attrs().reset(Attr::MODULE); } + } else if (&symbol->owner() != &currScope() && inInterfaceBlock() && + !isGeneric()) { + // non-generic interface shadows outer definition } else if (auto *details{symbol->detailsIf<GenericDetails>()}) { // found generic, want specific procedure auto *specific{details->specific()}; diff --git a/flang/lib/Semantics/scope.cpp b/flang/lib/Semantics/scope.cpp index 4af371f..ab75d4c 100644 --- a/flang/lib/Semantics/scope.cpp +++ b/flang/lib/Semantics/scope.cpp @@ -144,9 +144,8 @@ void Scope::add_crayPointer(const SourceName &name, Symbol &pointer) { } Symbol &Scope::MakeCommonBlock(SourceName name, SourceName location) { - const auto it{commonBlocks_.find(name)}; - if (it != commonBlocks_.end()) { - return *it->second; + if (auto *cb{FindCommonBlock(name)}) { + return *cb; } else { Symbol &symbol{MakeSymbol( name, Attrs{}, CommonBlockDetails{name.empty() ? location : name})}; @@ -154,9 +153,25 @@ Symbol &Scope::MakeCommonBlock(SourceName name, SourceName location) { return symbol; } } -Symbol *Scope::FindCommonBlock(const SourceName &name) const { - const auto it{commonBlocks_.find(name)}; - return it != commonBlocks_.end() ? &*it->second : nullptr; + +Symbol *Scope::FindCommonBlockInVisibleScopes(const SourceName &name) const { + if (Symbol * cb{FindCommonBlock(name)}) { + return cb; + } else if (Symbol * cb{FindCommonBlockUse(name)}) { + return &cb->GetUltimate(); + } else if (IsSubmodule()) { + if (const Scope *parent{ + symbol_ ? symbol_->get<ModuleDetails>().parent() : nullptr}) { + if (auto *cb{parent->FindCommonBlockInVisibleScopes(name)}) { + return cb; + } + } + } else if (!IsTopLevel() && parent_) { + if (auto *cb{parent_->FindCommonBlockInVisibleScopes(name)}) { + return cb; + } + } + return nullptr; } Scope *Scope::FindSubmodule(const SourceName &name) const { @@ -167,6 +182,31 @@ Scope *Scope::FindSubmodule(const SourceName &name) const { return &*it->second; } } + +bool Scope::AddCommonBlockUse( + const SourceName &name, Attrs attrs, Symbol &cbUltimate) { + CHECK(cbUltimate.has<CommonBlockDetails>()); + // Make a symbol, but don't add it to the Scope, since it needs to + // be added to the USE-associated COMMON blocks + Symbol &useCB{MakeSymbol(name, attrs, UseDetails{name, cbUltimate})}; + return commonBlockUses_.emplace(name, useCB).second; +} + +Symbol *Scope::FindCommonBlock(const SourceName &name) const { + if (const auto it{commonBlocks_.find(name)}; it != commonBlocks_.end()) { + return &*it->second; + } + return nullptr; +} + +Symbol *Scope::FindCommonBlockUse(const SourceName &name) const { + if (const auto it{commonBlockUses_.find(name)}; + it != commonBlockUses_.end()) { + return &*it->second; + } + return nullptr; +} + bool Scope::AddSubmodule(const SourceName &name, Scope &submodule) { return submodules_.emplace(name, submodule).second; } diff --git a/flang/lib/Semantics/semantics.cpp b/flang/lib/Semantics/semantics.cpp index bdb5377..2606d99 100644 --- a/flang/lib/Semantics/semantics.cpp +++ b/flang/lib/Semantics/semantics.cpp @@ -452,6 +452,15 @@ void SemanticsContext::UpdateScopeIndex( } } +void SemanticsContext::DumpScopeIndex(llvm::raw_ostream &out) const { + out << "scopeIndex_:\n"; + for (const auto &[source, scope] : scopeIndex_) { + out << "source '" << source.ToString() << "' -> scope " << scope + << "... whose source range is '" << scope.sourceRange().ToString() + << "'\n"; + } +} + bool SemanticsContext::IsInModuleFile(parser::CharBlock source) const { for (const Scope *scope{&FindScope(source)}; !scope->IsGlobal(); scope = &scope->parent()) { diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp index 2261912..15a42c3 100644 --- a/flang/lib/Utils/OpenMP.cpp +++ b/flang/lib/Utils/OpenMP.cpp @@ -22,8 +22,9 @@ mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr, mlir::Value varPtrPtr, llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex, - uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, - mlir::Type retTy, bool partialMap, mlir::FlatSymbolRefAttr mapperId) { + mlir::omp::ClauseMapFlags mapType, + mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, + bool partialMap, mlir::FlatSymbolRefAttr mapperId) { if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr); @@ -42,7 +43,7 @@ mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder, mlir::omp::MapInfoOp op = mlir::omp::MapInfoOp::create(builder, loc, retTy, baseAddr, varType, - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType), builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), varPtrPtr, members, membersIndex, bounds, mapperId, builder.getStringAttr(name), builder.getBoolAttr(partialMap)); @@ -75,8 +76,7 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, firOpBuilder.setInsertionPoint(targetOp); - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; mlir::omp::VariableCaptureKind captureKind = mlir::omp::VariableCaptureKind::ByRef; @@ -88,16 +88,14 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else if (!fir::isa_builtin_cptr_type(eleType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapFlag |= mlir::omp::ClauseMapFlags::to; } mlir::Value mapOp = createMapInfoOp(firOpBuilder, copyVal.getLoc(), copyVal, /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/llvm::SmallVector<mlir::Value>{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapFlag), - captureKind, copyVal.getType()); + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, + copyVal.getType()); auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); mlir::Region ®ion = targetOp.getRegion(); |