diff options
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r-- | flang/lib/Lower/OpenACC.cpp | 229 |
1 files changed, 178 insertions, 51 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index b3e8b69..af4f420 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -718,6 +718,84 @@ static void genDataOperandOperations( } } +template <typename GlobalCtorOrDtorOp, typename EntryOp, typename DeclareOp, + typename ExitOp> +static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, + fir::FirOpBuilder &builder, + mlir::Location loc, fir::GlobalOp globalOp, + mlir::acc::DataClause clause, + const std::string &declareGlobalName, + bool implicit, std::stringstream &asFortran) { + GlobalCtorOrDtorOp declareGlobalOp = + GlobalCtorOrDtorOp::create(modBuilder, loc, declareGlobalName); + builder.createBlock(&declareGlobalOp.getRegion(), + declareGlobalOp.getRegion().end(), {}, {}); + builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back()); + + fir::AddrOfOp addrOp = fir::AddrOfOp::create( + builder, loc, fir::ReferenceType::get(globalOp.getType()), + globalOp.getSymbol()); + addDeclareAttr(builder, addrOp, clause); + + llvm::SmallVector<mlir::Value> bounds; + EntryOp entryOp = createDataEntryOp<EntryOp>( + builder, loc, addrOp.getResTy(), asFortran, bounds, + /*structured=*/false, implicit, clause, addrOp.getResTy().getType(), + /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); + if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>) + DeclareOp::create(builder, loc, + mlir::acc::DeclareTokenType::get(entryOp.getContext()), + mlir::ValueRange(entryOp.getAccVar())); + else + DeclareOp::create(builder, loc, mlir::Value{}, + mlir::ValueRange(entryOp.getAccVar())); + if constexpr (std::is_same_v<GlobalCtorOrDtorOp, + mlir::acc::GlobalDestructorOp>) { + if constexpr (std::is_same_v<ExitOp, mlir::acc::DeclareLinkOp>) { + // No destructor emission for declare link in this path to avoid + // complex var/varType/varPtrPtr signatures. The ctor registers the link. + } else if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> || + std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>) { + ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), + entryOp.getVar(), entryOp.getVarType(), + entryOp.getBounds(), entryOp.getAsyncOperands(), + entryOp.getAsyncOperandsDeviceTypeAttr(), + entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), + /*structured=*/false, /*implicit=*/false, + builder.getStringAttr(*entryOp.getName())); + } else { + ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), + entryOp.getBounds(), entryOp.getAsyncOperands(), + entryOp.getAsyncOperandsDeviceTypeAttr(), + entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), + /*structured=*/false, /*implicit=*/false, + builder.getStringAttr(*entryOp.getName())); + } + } + mlir::acc::TerminatorOp::create(builder, loc); + modBuilder.setInsertionPointAfter(declareGlobalOp); +} + +template <typename EntryOp, typename ExitOp> +static void +emitCtorDtorPair(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, + mlir::Location operandLocation, fir::GlobalOp globalOp, + mlir::acc::DataClause clause, std::stringstream &asFortran, + const std::string &ctorName) { + createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp, + mlir::acc::DeclareEnterOp, ExitOp>( + modBuilder, builder, operandLocation, globalOp, clause, ctorName, + /*implicit=*/false, asFortran); + + std::stringstream dtorName; + dtorName << globalOp.getSymName().str() << "_acc_dtor"; + createDeclareGlobalOp<mlir::acc::GlobalDestructorOp, + mlir::acc::GetDevicePtrOp, mlir::acc::DeclareExitOp, + ExitOp>(modBuilder, builder, operandLocation, globalOp, + clause, dtorName.str(), + /*implicit=*/false, asFortran); +} + template <typename EntryOp, typename ExitOp> static void genDeclareDataOperandOperations( const Fortran::parser::AccObjectList &objectList, @@ -733,6 +811,37 @@ static void genDeclareDataOperandOperations( std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + // Handle COMMON/global symbols via module-level ctor/dtor path. + if (symbol.detailsIf<Fortran::semantics::CommonBlockDetails>() || + Fortran::semantics::FindCommonBlockContaining(symbol)) { + emitCommonGlobal( + converter, builder, accObject, dataClause, + [&](mlir::OpBuilder &modBuilder, mlir::Location loc, + fir::GlobalOp globalOp, mlir::acc::DataClause clause, + std::stringstream &asFortranStr, const std::string &ctorName) { + if constexpr (std::is_same_v<EntryOp, mlir::acc::DeclareLinkOp>) { + createDeclareGlobalOp< + mlir::acc::GlobalConstructorOp, mlir::acc::DeclareLinkOp, + mlir::acc::DeclareEnterOp, mlir::acc::DeclareLinkOp>( + modBuilder, builder, loc, globalOp, clause, ctorName, + /*implicit=*/false, asFortranStr); + } else if constexpr (std::is_same_v<EntryOp, mlir::acc::CreateOp> || + std::is_same_v<EntryOp, mlir::acc::CopyinOp> || + std::is_same_v< + EntryOp, + mlir::acc::DeclareDeviceResidentOp> || + std::is_same_v<ExitOp, mlir::acc::CopyoutOp>) { + emitCtorDtorPair<EntryOp, ExitOp>(modBuilder, builder, loc, + globalOp, clause, asFortranStr, + ctorName); + } else { + // No module-level ctor/dtor for this clause (e.g., deviceptr, + // present). Handled via structured declare region only. + return; + } + }); + continue; + } Fortran::semantics::MaybeExpr designator = Fortran::common::visit( [&](auto &&s) { return ea.Analyze(s); }, accObject.u); fir::factory::AddrAndBoundsInfo info = @@ -4098,49 +4207,6 @@ static void genACC(Fortran::lower::AbstractConverter &converter, waitOp.setAsyncAttr(firOpBuilder.getUnitAttr()); } -template <typename GlobalOp, typename EntryOp, typename DeclareOp, - typename ExitOp> -static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, - fir::FirOpBuilder &builder, - mlir::Location loc, fir::GlobalOp globalOp, - mlir::acc::DataClause clause, - const std::string &declareGlobalName, - bool implicit, std::stringstream &asFortran) { - GlobalOp declareGlobalOp = - GlobalOp::create(modBuilder, loc, declareGlobalName); - builder.createBlock(&declareGlobalOp.getRegion(), - declareGlobalOp.getRegion().end(), {}, {}); - builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back()); - - fir::AddrOfOp addrOp = fir::AddrOfOp::create( - builder, loc, fir::ReferenceType::get(globalOp.getType()), - globalOp.getSymbol()); - addDeclareAttr(builder, addrOp, clause); - - llvm::SmallVector<mlir::Value> bounds; - EntryOp entryOp = createDataEntryOp<EntryOp>( - builder, loc, addrOp.getResTy(), asFortran, bounds, - /*structured=*/false, implicit, clause, addrOp.getResTy().getType(), - /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); - if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>) - DeclareOp::create(builder, loc, - mlir::acc::DeclareTokenType::get(entryOp.getContext()), - mlir::ValueRange(entryOp.getAccVar())); - else - DeclareOp::create(builder, loc, mlir::Value{}, - mlir::ValueRange(entryOp.getAccVar())); - if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) { - ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), - entryOp.getBounds(), entryOp.getAsyncOperands(), - entryOp.getAsyncOperandsDeviceTypeAttr(), - entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), - /*structured=*/false, /*implicit=*/false, - builder.getStringAttr(*entryOp.getName())); - } - mlir::acc::TerminatorOp::create(builder, loc); - modBuilder.setInsertionPointAfter(declareGlobalOp); -} - template <typename EntryOp> static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, @@ -4317,6 +4383,66 @@ genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter, dataClause); } +static fir::GlobalOp +lookupGlobalBySymbolOrEquivalence(Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &builder, + const Fortran::semantics::Symbol &sym) { + const Fortran::semantics::Symbol *commonBlock = + Fortran::semantics::FindCommonBlockContaining(sym); + std::string globalName = commonBlock ? converter.mangleName(*commonBlock) + : converter.mangleName(sym); + if (fir::GlobalOp g = builder.getNamedGlobal(globalName)) { + return g; + } + // Not found: if not a COMMON member, try equivalence members + if (!commonBlock) { + if (const Fortran::semantics::EquivalenceSet *eqSet = + Fortran::semantics::FindEquivalenceSet(sym)) { + for (const Fortran::semantics::EquivalenceObject &eqObj : *eqSet) { + std::string eqName = converter.mangleName(eqObj.symbol); + if (fir::GlobalOp g = builder.getNamedGlobal(eqName)) + return g; + } + } + } + return {}; +} + +template <typename EmitterFn> +static void emitCommonGlobal(Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &builder, + const Fortran::parser::AccObject &obj, + mlir::acc::DataClause clause, + EmitterFn &&emitCtorDtor) { + Fortran::semantics::Symbol &sym = getSymbolFromAccObject(obj); + if (!(sym.detailsIf<Fortran::semantics::CommonBlockDetails>() || + Fortran::semantics::FindCommonBlockContaining(sym))) + return; + + fir::GlobalOp globalOp = + lookupGlobalBySymbolOrEquivalence(converter, builder, sym); + if (!globalOp) + llvm::report_fatal_error("could not retrieve global symbol"); + + std::stringstream ctorName; + ctorName << globalOp.getSymName().str() << "_acc_ctor"; + if (builder.getModule().lookupSymbol<mlir::acc::GlobalConstructorOp>( + ctorName.str())) + return; + + mlir::Location operandLocation = genOperandLocation(converter, obj); + addDeclareAttr(builder, globalOp.getOperation(), clause); + mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); + modBuilder.setInsertionPointAfter(globalOp); + std::stringstream asFortran; + asFortran << sym.name().ToString(); + + auto savedIP = builder.saveInsertionPoint(); + emitCtorDtor(modBuilder, operandLocation, globalOp, clause, asFortran, + ctorName.str()); + builder.restoreInsertionPoint(savedIP); +} + static void genDeclareInFunction(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, @@ -4342,11 +4468,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, dataClauseOperands.end()); } else if (const auto *createClause = std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - createClause->v; - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); auto crtDataStart = dataClauseOperands.size(); + const auto &accObjectList = + std::get<Fortran::parser::AccObjectList>(createClause->v.t); genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_create, @@ -4378,11 +4502,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, } else if (const auto *copyoutClause = std::get_if<Fortran::parser::AccClause::Copyout>( &clause.u)) { - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - copyoutClause->v; - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); auto crtDataStart = dataClauseOperands.size(); + const auto &accObjectList = + std::get<Fortran::parser::AccObjectList>(copyoutClause->v.t); genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>( accObjectList, converter, semanticsContext, stmtCtx, @@ -4423,6 +4545,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, } } + // If no structured operands were generated (all objects were COMMON), + // do not create a declare region. + if (dataClauseOperands.empty()) + return; + mlir::func::FuncOp funcOp = builder.getFunction(); auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>(); mlir::Value declareToken; |