diff options
Diffstat (limited to 'flang/lib/Lower/OpenMP')
| -rw-r--r-- | flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 242 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/ClauseProcessor.h | 8 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Clauses.cpp | 46 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/DataSharingProcessor.cpp | 5 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 453 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Utils.cpp | 191 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Utils.h | 10 |
7 files changed, 783 insertions, 172 deletions
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 1c163e6..a81ba37 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -13,10 +13,12 @@ #include "ClauseProcessor.h" #include "Utils.h" +#include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/OpenMP/Clauses.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/Support/ReductionProcessor.h" +#include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Parser/tools.h" #include "flang/Semantics/tools.h" #include "flang/Utils/OpenMP.h" @@ -42,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) { return mlir::omp::ReductionModifier::defaultmod; } -/// Check for unsupported map operand types. -static void checkMapType(mlir::Location location, mlir::Type type) { - if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type)) - type = refType.getElementType(); - if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type)) - if (!mlir::isa<fir::PointerType>(boxType.getElementType())) - TODO(location, "OMPD_target_data MapOperand BoxType"); -} - static mlir::omp::ScheduleModifier translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) { switch (m) { @@ -209,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter, ifVal); } -static void addUseDeviceClause( - lower::AbstractConverter &converter, const omp::ObjectList &objects, - llvm::SmallVectorImpl<mlir::Value> &operands, - llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) { - genObjectList(objects, converter, operands); - for (mlir::Value &operand : operands) - checkMapType(operand.getLoc(), operand.getType()); - - for (const omp::Object &object : objects) - useDeviceSyms.push_back(object.sym()); -} - //===----------------------------------------------------------------------===// // ClauseProcessor unique clauses //===----------------------------------------------------------------------===// @@ -401,11 +382,75 @@ bool ClauseProcessor::processInclusive( return false; } +bool ClauseProcessor::processInitializer( + lower::SymMap &symMap, const parser::OmpClause::Initializer &inp, + ReductionProcessor::GenInitValueCBTy &genInitValueCB) const { + if (auto *clause = findUniqueClause<omp::clause::Initializer>()) { + genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value ompOrig) { + lower::SymMapScope scope(symMap); + const parser::OmpInitializerExpression &iexpr = inp.v.v; + const parser::OmpStylizedInstance &styleInstance = iexpr.v.front(); + const std::list<parser::OmpStylizedDeclaration> &declList = + std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t); + mlir::Value ompPrivVar; + for (const parser::OmpStylizedDeclaration &decl : declList) { + auto &name = std::get<parser::ObjectName>(decl.var.t); + assert(name.symbol && "Name does not have a symbol"); + mlir::Value addr = builder.createTemporary(loc, ompOrig.getType()); + fir::StoreOp::create(builder, loc, ompOrig, addr); + fir::FortranVariableFlagsEnum extraFlags = {}; + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(builder.getContext(), + *name.symbol, extraFlags); + auto declareOp = hlfir::DeclareOp::create( + builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr, + 0, attributes); + if (name.ToString() == "omp_priv") + ompPrivVar = declareOp.getResult(0); + symMap.addVariableDefinition(*name.symbol, declareOp); + } + // Lower the expression/function call + lower::StatementContext stmtCtx; + mlir::Value result = common::visit( + common::visitors{ + [&](const evaluate::ProcedureRef &procRef) -> mlir::Value { + convertCallToHLFIR(loc, converter, procRef, std::nullopt, + symMap, stmtCtx); + auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar); + return privVal; + }, + [&](const auto &expr) -> mlir::Value { + mlir::Value exprResult = fir::getBase(convertExprToValue( + loc, converter, clause->v, symMap, stmtCtx)); + // Conversion can either give a value or a refrence to a value, + // we need to return the reduction type, so an optional load may + // be generated. + if (auto refType = llvm::dyn_cast<fir::ReferenceType>( + exprResult.getType())) + if (ompPrivVar.getType() == refType) + exprResult = fir::LoadOp::create(builder, loc, exprResult); + return exprResult; + }}, + clause->v.u); + stmtCtx.finalizeAndPop(); + return result; + }; + return true; + } + return false; +} + bool ClauseProcessor::processMergeable( mlir::omp::MergeableClauseOps &result) const { return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable); } +bool ClauseProcessor::processNogroup( + mlir::omp::NogroupClauseOps &result) const { + return markClauseOccurrence<omp::clause::Nogroup>(result.nogroup); +} + bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const { return markClauseOccurrence<omp::clause::Nowait>(result.nowait); } @@ -1159,14 +1204,26 @@ bool ClauseProcessor::processInReduction( } bool ClauseProcessor::processIsDevicePtr( - mlir::omp::IsDevicePtrClauseOps &result, + lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const { - return findRepeatableClause<omp::clause::IsDevicePtr>( - [&](const omp::clause::IsDevicePtr &devPtrClause, - const parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, - isDeviceSyms); + std::map<Object, OmpMapParentAndMemberData> parentMemberIndices; + bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>( + [&](const omp::clause::IsDevicePtr &clause, + const parser::CharBlock &source) { + mlir::Location location = converter.genLocation(source); + // Force a map so the descriptor is materialized on the device with the + // device address inside. + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::is_device_ptr | + mlir::omp::ClauseMapFlags::to; + processMapObjects(stmtCtx, location, clause.v, mapTypeBits, + parentMemberIndices, result.isDevicePtrVars, + isDeviceSyms); }); + + insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, + result.isDevicePtrVars, isDeviceSyms); + return clauseFound; } bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { @@ -1175,11 +1232,20 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { omp::clause::Linear>([&](const omp::clause::Linear &clause, const parser::CharBlock &) { auto &objects = std::get<omp::ObjectList>(clause.t); + static std::vector<mlir::Attribute> typeAttrs; + + if (!result.linearVars.size()) + typeAttrs.clear(); + for (const omp::Object &object : objects) { semantics::Symbol *sym = object.sym(); const mlir::Value variable = converter.getSymbolAddress(*sym); result.linearVars.push_back(variable); + mlir::Type ty = converter.genType(*sym); + typeAttrs.push_back(mlir::TypeAttr::get(ty)); } + result.linearVarTypes = + mlir::ArrayAttr::get(&converter.getMLIRContext(), typeAttrs); if (objects.size()) { if (auto &mod = std::get<std::optional<omp::clause::Linear::StepComplexModifier>>( @@ -1223,26 +1289,67 @@ void ClauseProcessor::processMapObjects( llvm::StringRef mapperIdNameRef) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto getDefaultMapperID = [&](const omp::Object &object, - std::string &mapperIdName) { - if (!mlir::isa<mlir::omp::DeclareMapperOp>( - firOpBuilder.getRegion().getParentOp())) { - const semantics::DerivedTypeSpec *typeSpec = nullptr; + auto getSymbolDerivedType = [](const semantics::Symbol &symbol) + -> const semantics::DerivedTypeSpec * { + const semantics::Symbol &ultimate = symbol.GetUltimate(); + if (const semantics::DeclTypeSpec *declType = ultimate.GetType()) + if (const auto *derived = declType->AsDerived()) + return derived; + return nullptr; + }; + + auto addImplicitMapper = [&](const omp::Object &object, + std::string &mapperIdName, + bool allowGenerate) -> mlir::FlatSymbolRefAttr { + if (mapperIdName.empty()) + return mlir::FlatSymbolRefAttr(); - if (object.sym()->owner().IsDerivedType()) - typeSpec = object.sym()->owner().derivedTypeSpec(); - else if (object.sym()->GetType() && - object.sym()->GetType()->category() == - semantics::DeclTypeSpec::TypeDerived) - typeSpec = &object.sym()->GetType()->derivedTypeSpec(); - - if (typeSpec) { - mapperIdName = - typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) - mapperIdName = converter.mangleName(mapperIdName, sym->owner()); - } + if (converter.getModuleOp().lookupSymbol(mapperIdName)) + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperIdName); + + if (!allowGenerate) + return mlir::FlatSymbolRefAttr(); + + const semantics::DerivedTypeSpec *typeSpec = + getSymbolDerivedType(*object.sym()); + if (!typeSpec && object.sym()->owner().IsDerivedType()) + typeSpec = object.sym()->owner().derivedTypeSpec(); + + if (!typeSpec) + return mlir::FlatSymbolRefAttr(); + + mlir::Type type = converter.genType(*typeSpec); + auto recordType = mlir::dyn_cast<fir::RecordType>(type); + if (!recordType) + return mlir::FlatSymbolRefAttr(); + + return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation, + recordType, mapperIdName); + }; + + auto getDefaultMapperID = + [&](const semantics::DerivedTypeSpec *typeSpec) -> std::string { + if (mlir::isa<mlir::omp::DeclareMapperOp>( + firOpBuilder.getRegion().getParentOp()) || + !typeSpec) + return {}; + + std::string mapperIdName = + typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName; + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) { + mapperIdName = + converter.mangleName(mapperIdName, sym->GetUltimate().owner()); + } else { + mapperIdName = converter.mangleName(mapperIdName, *typeSpec->GetScope()); } + + // Make sure we don't return a mapper to self. + if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>( + firOpBuilder.getRegion().getParentOp())) + if (mapperIdName == declMapOp.getSymName()) + return {}; + return mapperIdName; }; // Create the mapper symbol from its name, if specified. @@ -1251,8 +1358,13 @@ void ClauseProcessor::processMapObjects( mapperIdNameRef != "__implicit_mapper") { std::string mapperIdName = mapperIdNameRef.str(); const omp::Object &object = objects.front(); - if (mapperIdNameRef == "default") - getDefaultMapperID(object, mapperIdName); + if (mapperIdNameRef == "default") { + const semantics::DerivedTypeSpec *typeSpec = + getSymbolDerivedType(*object.sym()); + if (!typeSpec && object.sym()->owner().IsDerivedType()) + typeSpec = object.sym()->owner().derivedTypeSpec(); + mapperIdName = getDefaultMapperID(typeSpec); + } assert(converter.getModuleOp().lookupSymbol(mapperIdName) && "mapper not found"); mapperId = @@ -1290,13 +1402,25 @@ void ClauseProcessor::processMapObjects( } } + const semantics::DerivedTypeSpec *objectTypeSpec = + getSymbolDerivedType(*object.sym()); + if (mapperIdNameRef == "__implicit_mapper") { - std::string mapperIdName; - getDefaultMapperID(object, mapperIdName); - mapperId = converter.getModuleOp().lookupSymbol(mapperIdName) - ? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), - mapperIdName) - : mlir::FlatSymbolRefAttr(); + if (parentObj.has_value()) { + mapperId = mlir::FlatSymbolRefAttr(); + } else if (objectTypeSpec) { + std::string mapperIdName = getDefaultMapperID(objectTypeSpec); + bool needsDefaultMapper = + semantics::IsAllocatableOrObjectPointer(object.sym()) || + requiresImplicitDefaultDeclareMapper(*objectTypeSpec); + if (!mapperIdName.empty()) + mapperId = addImplicitMapper(object, mapperIdName, + /*allowGenerate=*/needsDefaultMapper); + else + mapperId = mlir::FlatSymbolRefAttr(); + } else { + mapperId = mlir::FlatSymbolRefAttr(); + } } // Explicit map captures are captured ByRef by default, @@ -1392,10 +1516,14 @@ bool ClauseProcessor::processMap( } if (mappers) { assert(mappers->size() == 1 && "more than one mapper"); - mapperIdName = mappers->front().v.id().symbol->name().ToString(); - if (mapperIdName != "default") - mapperIdName = converter.mangleName( - mapperIdName, mappers->front().v.id().symbol->owner()); + const semantics::Symbol *mapperSym = mappers->front().v.id().symbol; + mapperIdName = mapperSym->name().ToString(); + if (mapperIdName != "default") { + // Mangle with the ultimate owner so that use-associated mapper + // identifiers resolve to the same symbol as their defining scope. + const semantics::Symbol &ultimate = mapperSym->GetUltimate(); + mapperIdName = converter.mangleName(mapperIdName, ultimate.owner()); + } } processMapObjects(stmtCtx, clauseLocation, diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 6452e39..3485a4e 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -18,8 +18,8 @@ #include "flang/Lower/Bridge.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/OpenMP/Clauses.h" +#include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Optimizer/Builder/Todo.h" -#include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/parse-tree.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" @@ -88,7 +88,11 @@ public: bool processHint(mlir::omp::HintClauseOps &result) const; bool processInclusive(mlir::Location currentLocation, mlir::omp::InclusiveClauseOps &result) const; + bool processInitializer( + lower::SymMap &symMap, const parser::OmpClause::Initializer &inp, + ReductionProcessor::GenInitValueCBTy &genInitValueCB) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; + bool processNogroup(mlir::omp::NogroupClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTasks(lower::StatementContext &stmtCtx, mlir::omp::NumTasksClauseOps &result) const; @@ -130,7 +134,7 @@ public: mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; bool processIsDevicePtr( - mlir::omp::IsDevicePtrClauseOps &result, + lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; bool processLinear(mlir::omp::LinearClauseOps &result) const; bool diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 0f60b47..61430fc 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -10,7 +10,6 @@ #include "flang/Common/idioms.h" #include "flang/Evaluate/expression.h" -#include "flang/Optimizer/Builder/Todo.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/openmp-modifiers.h" @@ -249,8 +248,10 @@ MAKE_EMPTY_CLASS(Groupprivate, Groupprivate); MAKE_INCOMPLETE_CLASS(AdjustArgs, AdjustArgs); MAKE_INCOMPLETE_CLASS(AppendArgs, AppendArgs); +MAKE_INCOMPLETE_CLASS(Collector, Collector); MAKE_INCOMPLETE_CLASS(GraphId, GraphId); MAKE_INCOMPLETE_CLASS(GraphReset, GraphReset); +MAKE_INCOMPLETE_CLASS(Inductor, Inductor); MAKE_INCOMPLETE_CLASS(Replayable, Replayable); MAKE_INCOMPLETE_CLASS(Transparent, Transparent); @@ -394,8 +395,6 @@ makePrescriptiveness(parser::OmpPrescriptiveness::Value v) { switch (v) { case parser::OmpPrescriptiveness::Value::Strict: return clause::Prescriptiveness::Strict; - case parser::OmpPrescriptiveness::Value::Fallback: - return clause::Prescriptiveness::Fallback; } llvm_unreachable("Unexpected prescriptiveness"); } @@ -797,21 +796,31 @@ DynGroupprivate make(const parser::OmpClause::DynGroupprivate &inp, semantics::SemanticsContext &semaCtx) { // imp.v -> OmpDyngroupprivateClause CLAUSET_ENUM_CONVERT( // - convert, parser::OmpAccessGroup::Value, DynGroupprivate::AccessGroup, + makeAccessGroup, parser::OmpAccessGroup::Value, + DynGroupprivate::AccessGroup, // clang-format off MS(Cgroup, Cgroup) // clang-format on ); + CLAUSET_ENUM_CONVERT( // + makeFallback, parser::OmpFallbackModifier::Value, + DynGroupprivate::Fallback, + // clang-format off + MS(Abort, Abort) + MS(Default_Mem, Default_Mem) + MS(Null, Null) + // clang-format on + ); + auto &mods = semantics::OmpGetModifiers(inp.v); auto *m0 = semantics::OmpGetUniqueModifier<parser::OmpAccessGroup>(mods); - auto *m1 = semantics::OmpGetUniqueModifier<parser::OmpPrescriptiveness>(mods); + auto *m1 = semantics::OmpGetUniqueModifier<parser::OmpFallbackModifier>(mods); auto &size = std::get<parser::ScalarIntExpr>(inp.v.t); - return DynGroupprivate{ - {/*AccessGroup=*/maybeApplyToV(convert, m0), - /*Prescriptiveness=*/maybeApplyToV(makePrescriptiveness, m1), - /*Size=*/makeExpr(size, semaCtx)}}; + return DynGroupprivate{{/*AccessGroup=*/maybeApplyToV(makeAccessGroup, m0), + /*Fallback=*/maybeApplyToV(makeFallback, m1), + /*Size=*/makeExpr(size, semaCtx)}}; } Enter make(const parser::OmpClause::Enter &inp, @@ -972,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp, Initializer make(const parser::OmpClause::Initializer &inp, semantics::SemanticsContext &semaCtx) { - llvm_unreachable("Empty: initializer"); + const parser::OmpInitializerExpression &iexpr = inp.v.v; + const parser::OmpStylizedInstance &styleInstance = iexpr.v.front(); + const parser::OmpStylizedInstance::Instance &instance = + std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t); + if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) { + auto &expr = std::get<parser::Expr>(as->t); + return Initializer{makeExpr(expr, semaCtx)}; + } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) { + if (call->typedCall) { + const auto &procRef = *call->typedCall; + semantics::SomeExpr evalProcRef{procRef}; + return Initializer{evalProcRef}; + } + } + + llvm_unreachable("Unexpected initializer"); } InReduction make(const parser::OmpClause::InReduction &inp, @@ -1052,7 +1076,7 @@ Link make(const parser::OmpClause::Link &inp, return Link{/*List=*/makeObjects(inp.v, semaCtx)}; } -LoopRange make(const parser::OmpClause::Looprange &inp, +Looprange make(const parser::OmpClause::Looprange &inp, semantics::SemanticsContext &semaCtx) { llvm_unreachable("Unimplemented: looprange"); } diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index 146a252..83c2eda 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -342,7 +342,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { if (!hasLastPrivate) return; - if (mlir::isa<mlir::omp::WsloopOp>(op) || mlir::isa<mlir::omp::SimdOp>(op)) { + if (mlir::isa<mlir::omp::WsloopOp>(op) || mlir::isa<mlir::omp::SimdOp>(op) || + mlir::isa<mlir::omp::TaskloopOp>(op)) { mlir::omp::LoopRelatedClauseOps result; llvm::SmallVector<const semantics::Symbol *> iv; collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval, @@ -408,7 +409,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { } else { TODO(converter.getCurrentLocation(), "lastprivate clause in constructs other than " - "simd/worksharing-loop"); + "simd/worksharing-loop/taskloop"); } } diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 7106728..9c25c19 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -18,12 +18,17 @@ #include "Decomposer.h" #include "Utils.h" #include "flang/Common/idioms.h" +#include "flang/Evaluate/type.h" #include "flang/Lower/Bridge.h" +#include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertExpr.h" +#include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertVariable.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/OpenMP/Clauses.h" +#include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" +#include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Lower/SymbolMap.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" @@ -565,14 +570,9 @@ getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { if (collapseValue == 0) return &eval; - lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation(); - for (int i = 1; i < collapseValue; i++) { - // The nested evaluations should be DoConstructs (i.e. they should form - // a loop nest). Each DoConstruct is a tuple <NonLabelDoStmt, Block, - // EndDoStmt>. - assert(curEval->isA<parser::DoConstruct>()); - curEval = &*std::next(curEval->getNestedEvaluations().begin()); - } + lower::pft::Evaluation *curEval = &eval; + for (int i = 0; i < collapseValue; i++) + curEval = getNestedDoConstruct(*curEval); return curEval; } @@ -1008,9 +1008,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Firstprivate: - case DefMap::ImplicitBehavior::None: - TODO(loc, "Firstprivate and None are currently unsupported defaultmap " - "behaviour"); + TODO(loc, "Firstprivate is currently unsupported defaultmap behaviour"); break; case DefMap::ImplicitBehavior::From: return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from, @@ -1032,8 +1030,9 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Default: + case DefMap::ImplicitBehavior::None: llvm_unreachable( - "Implicit None Behaviour Should Have Been Handled Earlier"); + "Implicit None and Default behaviour should have been handled earlier"); break; } @@ -1203,7 +1202,7 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, // Start with privatization, so that the lowering of the nested // code will use the right symbols. bool isLoop = llvm::omp::getDirectiveAssociation(info.dir) == - llvm::omp::Association::Loop; + llvm::omp::Association::LoopNest; bool privatize = info.clauses && info.privatize; firOpBuilder.setInsertionPoint(marker); @@ -1637,8 +1636,7 @@ static void genSimdClauses( cp.processReduction(loc, clauseOps, reductionSyms); cp.processSafelen(clauseOps); cp.processSimdlen(clauseOps); - - cp.processTODO<clause::Linear>(loc, llvm::omp::Directive::OMPD_simd); + cp.processLinear(clauseOps); } static void genSingleClauses(lower::AbstractConverter &converter, @@ -1673,7 +1671,7 @@ static void genTargetClauses( hostEvalInfo->collectValues(clauseOps.hostEvalVars); } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); + cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown, &mapSyms); cp.processNowait(clauseOps); @@ -1763,21 +1761,25 @@ static void genTaskgroupClauses( cp.processTaskReduction(loc, clauseOps, taskReductionSyms); } -static void genTaskloopClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, - const List<Clause> &clauses, mlir::Location loc, - mlir::omp::TaskloopOperands &clauseOps) { +static void genTaskloopClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, const List<Clause> &clauses, + mlir::Location loc, mlir::omp::TaskloopOperands &clauseOps, + llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms, + llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processFinal(stmtCtx, clauseOps); cp.processGrainsize(stmtCtx, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_taskloop, clauseOps); + cp.processInReduction(loc, clauseOps, inReductionSyms); + cp.processMergeable(clauseOps); + cp.processNogroup(clauseOps); cp.processNumTasks(stmtCtx, clauseOps); - - cp.processTODO<clause::Allocate, clause::Collapse, clause::Default, - clause::Final, clause::If, clause::InReduction, - clause::Lastprivate, clause::Mergeable, clause::Nogroup, - clause::Priority, clause::Reduction, clause::Shared, - clause::Untied>(loc, llvm::omp::Directive::OMPD_taskloop); + cp.processPriority(stmtCtx, clauseOps); + cp.processReduction(loc, clauseOps, reductionSyms); + cp.processUntied(clauseOps); } static void genTaskwaitClauses(lower::AbstractConverter &converter, @@ -1828,9 +1830,9 @@ static void genWsloopClauses( cp.processOrdered(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); + cp.processLinear(clauseOps); - cp.processTODO<clause::Allocate, clause::Linear>( - loc, llvm::omp::Directive::OMPD_do); + cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do); } //===----------------------------------------------------------------------===// @@ -2485,13 +2487,15 @@ static bool isDuplicateMappedSymbol( const semantics::Symbol &sym, const llvm::SetVector<const semantics::Symbol *> &privatizedSyms, const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms, - const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) { + const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms, + const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) { llvm::SmallVector<const semantics::Symbol *> concatSyms; concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() + - mappedSyms.size()); + mappedSyms.size() + isDevicePtrSyms.size()); concatSyms.append(privatizedSyms.begin(), privatizedSyms.end()); concatSyms.append(hasDevSyms.begin(), hasDevSyms.end()); concatSyms.append(mappedSyms.begin(), mappedSyms.end()); + concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end()); auto checkSymbol = [&](const semantics::Symbol &checkSym) { return std::any_of(concatSyms.begin(), concatSyms.end(), @@ -2531,6 +2535,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, loc, clauseOps, defaultMaps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); + if (!isDevicePtrSyms.empty()) { + // is_device_ptr maps get duplicated so the clause and synthesized + // has_device_addr entry each own a unique MapInfoOp user, keeping + // MapInfoFinalization happy while still wiring the symbol into + // has_device_addr when the user didn’t spell it explicitly. + auto insertionPt = firOpBuilder.saveInsertionPoint(); + auto alreadyPresent = [&](const semantics::Symbol *sym) { + return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) { + return s && sym && s->GetUltimate() == sym->GetUltimate(); + }); + }; + + for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) { + mlir::Value mapVal = clauseOps.isDevicePtrVars[idx]; + assert(sym && "expected symbol for is_device_ptr"); + assert(mapVal && "expected map value for is_device_ptr"); + auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>(); + assert(mapInfo && "expected map info op"); + + if (!alreadyPresent(sym)) { + clauseOps.hasDeviceAddrVars.push_back(mapVal); + hasDeviceAddrSyms.push_back(sym); + } + + firOpBuilder.setInsertionPointAfter(mapInfo); + mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation()); + auto clonedMapInfo = mlir::cast<mlir::omp::MapInfoOp>(clonedOp); + clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult(); + } + firOpBuilder.restoreInsertionPoint(insertionPt); + } + DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ lower::omp::isLastItemInQueue(item, queue), @@ -2570,7 +2606,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return; if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(), - hasDeviceAddrSyms, mapSyms)) { + hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) { if (const auto *details = sym.template detailsIf<semantics::HostAssocDetails>()) converter.copySymbolBinding(details->symbol(), sym); @@ -2578,18 +2614,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym); name << sym.name().ToString(); - mlir::FlatSymbolRefAttr mapperId; - if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) { - auto &typeSpec = sym.GetType()->derivedTypeSpec(); - std::string mapperIdName = - typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) - mapperIdName = converter.mangleName(mapperIdName, sym->owner()); - if (converter.getModuleOp().lookupSymbol(mapperIdName)) - mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), - mapperIdName); - } - fir::factory::AddrAndBoundsInfo info = Fortran::lower::getDataOperandBaseAddr( converter, firOpBuilder, sym.GetUltimate(), @@ -2609,6 +2633,44 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mapFlagAndKind = getImplicitMapTypeAndKind( firOpBuilder, converter, defaultMaps, eleType, loc, sym); + mlir::FlatSymbolRefAttr mapperId; + if (defaultMaps.empty()) { + // TODO: Honor user-provided defaultmap clauses (aggregates/pointers) + // instead of blanket-disabling implicit mapper generation whenever any + // explicit default map is present. + const semantics::DerivedTypeSpec *typeSpec = + sym.GetType() ? sym.GetType()->AsDerived() : nullptr; + if (typeSpec) { + std::string mapperIdName = + typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName; + if (auto *mapperSym = + converter.getCurrentScope().FindSymbol(mapperIdName)) + mapperIdName = converter.mangleName( + mapperIdName, mapperSym->GetUltimate().owner()); + else + mapperIdName = + converter.mangleName(mapperIdName, *typeSpec->GetScope()); + + if (!mapperIdName.empty()) { + bool allowImplicitMapper = + semantics::IsAllocatableOrObjectPointer(&sym); + bool hasDefaultMapper = + converter.getModuleOp().lookupSymbol(mapperIdName); + if (hasDefaultMapper || allowImplicitMapper) { + if (!hasDefaultMapper) { + if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>( + converter.genType(*typeSpec))) + mapperId = getOrGenImplicitDefaultDeclareMapper( + converter, loc, recordType, mapperIdName); + } else { + mapperId = mlir::FlatSymbolRefAttr::get( + &converter.getMLIRContext(), mapperIdName); + } + } + } + } + } + mlir::Value mapOp = createMapInfoOp( firOpBuilder, converter.getCurrentLocation(), baseOp, /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{}, @@ -2818,7 +2880,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // TODO: Add private syms and vars. args.reduction.syms = reductionSyms; args.reduction.vars = clauseOps.reductionVars; - return genOpWithBody<mlir::omp::TeamsOp>( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_teams) @@ -2979,8 +3040,11 @@ static mlir::omp::TaskloopOp genStandaloneTaskloop( lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { mlir::omp::TaskloopOperands taskloopClauseOps; + llvm::SmallVector<const semantics::Symbol *> reductionSyms; + llvm::SmallVector<const semantics::Symbol *> inReductionSyms; + genTaskloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - taskloopClauseOps); + taskloopClauseOps, reductionSyms, inReductionSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, enableDelayedPrivatization, symTable); @@ -2994,6 +3058,10 @@ static mlir::omp::TaskloopOp genStandaloneTaskloop( EntryBlockArgs taskloopArgs; taskloopArgs.priv.syms = dsp.getDelayedPrivSymbols(); taskloopArgs.priv.vars = taskloopClauseOps.privateVars; + taskloopArgs.reduction.syms = reductionSyms; + taskloopArgs.reduction.vars = taskloopClauseOps.reductionVars; + taskloopArgs.inReduction.syms = inReductionSyms; + taskloopArgs.inReduction.vars = taskloopClauseOps.inReductionVars; auto taskLoopOp = genWrapperOp<mlir::omp::TaskloopOp>( converter, loc, taskloopClauseOps, taskloopArgs); @@ -3246,17 +3314,12 @@ static mlir::omp::WsloopOp genCompositeDoSimd( genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps, simdReductionSyms); - DataSharingProcessor wsloopItemDSP( - converter, semaCtx, doItem->clauses, eval, - /*shouldCollectPreDeterminedSymbols=*/false, - /*useDelayedPrivatization=*/true, symTable); + DataSharingProcessor wsloopItemDSP(converter, semaCtx, doItem->clauses, eval, + /*shouldCollectPreDeterminedSymbols=*/true, + /*useDelayedPrivatization=*/true, + symTable); wsloopItemDSP.processStep1(&wsloopClauseOps); - DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval, - /*shouldCollectPreDeterminedSymbols=*/true, - /*useDelayedPrivatization=*/true, symTable); - simdItemDSP.processStep1(&simdClauseOps, simdItem->id); - // Pass the innermost leaf construct's clauses because that's where COLLAPSE // is placed by construct decomposition. mlir::omp::LoopNestOperands loopNestClauseOps; @@ -3275,8 +3338,9 @@ static mlir::omp::WsloopOp genCompositeDoSimd( wsloopOp.setComposite(/*val=*/true); EntryBlockArgs simdArgs; - simdArgs.priv.syms = simdItemDSP.getDelayedPrivSymbols(); - simdArgs.priv.vars = simdClauseOps.privateVars; + // For composite 'do simd', privatization is handled by the wsloop. + // The simd does not create separate private storage for variables already + // privatized by the worksharing construct. simdArgs.reduction.syms = simdReductionSyms; simdArgs.reduction.vars = simdClauseOps.reductionVars; auto simdOp = @@ -3286,7 +3350,7 @@ static mlir::omp::WsloopOp genCompositeDoSimd( genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem, loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}, {simdOp, simdArgs}}, - llvm::omp::Directive::OMPD_do_simd, simdItemDSP); + llvm::omp::Directive::OMPD_do_simd, wsloopItemDSP); return wsloopOp; } @@ -3362,7 +3426,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter, }; bool loopLeaf = llvm::omp::getDirectiveAssociation(item->id) == - llvm::omp::Association::Loop; + llvm::omp::Association::LoopNest; if (loopLeaf) { symTable.pushScope(); if (genOMPCompositeDispatch(converter, symTable, stmtCtx, semaCtx, eval, @@ -3471,6 +3535,13 @@ static void genOMPDispatch(lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_tile: genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; + case llvm::omp::Directive::OMPD_fuse: { + unsigned version = semaCtx.langOptions().OpenMPVersion; + if (!semaCtx.langOptions().OpenMPSimd) + TODO(loc, "Unhandled loop directive (" + + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); + break; + } case llvm::omp::Directive::OMPD_unroll: genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; @@ -3503,12 +3574,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, const parser::OpenMPUtilityConstruct &); -static void -genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - const parser::OpenMPDeclarativeAllocate &declarativeAllocate) { +static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, + const parser::OmpAllocateDirective &allocate) { if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); + TODO(converter.getCurrentLocation(), "OmpAllocateDirective"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, @@ -3527,12 +3598,186 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); } +static ReductionProcessor::GenCombinerCBTy +processReductionCombiner(lower::AbstractConverter &converter, + lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + const parser::OmpReductionSpecifier &specifier) { + ReductionProcessor::GenCombinerCBTy genCombinerCB; + const auto &combinerExpression = + std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) + .value(); + const parser::OmpStylizedInstance &combinerInstance = + combinerExpression.v.front(); + const parser::OmpStylizedInstance::Instance &instance = + std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t); + + std::optional<semantics::SomeExpr> evalExprOpt; + if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) { + auto &expr = std::get<parser::Expr>(as->t); + evalExprOpt = makeExpr(expr, semaCtx); + } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) { + if (call->typedCall) { + const auto &procRef = *call->typedCall; + evalExprOpt = semantics::SomeExpr{procRef}; + } else { + TODO(converter.getCurrentLocation(), + "CallStmt without typedCall is not yet supported"); + } + } else { + TODO(converter.getCurrentLocation(), "Unsupported combiner instance type"); + } + + assert(evalExprOpt.has_value() && "evalExpr must be initialized"); + semantics::SomeExpr evalExpr = *evalExprOpt; + + genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value lhs, + mlir::Value rhs, bool isByRef) { + lower::SymMapScope scope(symTable); + const std::list<parser::OmpStylizedDeclaration> &declList = + std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); + mlir::Value ompOutVar; + for (const parser::OmpStylizedDeclaration &decl : declList) { + auto &name = std::get<parser::ObjectName>(decl.var.t); + mlir::Value addr = lhs; + mlir::Type type = lhs.getType(); + bool isRhs = name.ToString() == std::string("omp_in"); + if (isRhs) { + addr = rhs; + type = rhs.getType(); + } + + assert(name.symbol && "Reduction object name does not have a symbol"); + if (!fir::conformsWithPassByRef(type)) { + addr = builder.createTemporary(loc, type); + fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr); + } + fir::FortranVariableFlagsEnum extraFlags = {}; + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(builder.getContext(), + *name.symbol, extraFlags); + auto declareOp = + hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr, + {}, nullptr, nullptr, 0, attributes); + if (name.ToString() == "omp_out") + ompOutVar = declareOp.getResult(0); + symTable.addVariableDefinition(*name.symbol, declareOp); + } + + lower::StatementContext stmtCtx; + mlir::Value result = common::visit( + common::visitors{ + [&](const evaluate::ProcedureRef &procRef) -> mlir::Value { + convertCallToHLFIR(loc, converter, procRef, std::nullopt, + symTable, stmtCtx); + auto outVal = fir::LoadOp::create(builder, loc, ompOutVar); + return outVal; + }, + [&](const auto &expr) -> mlir::Value { + mlir::Value exprResult = fir::getBase(convertExprToValue( + loc, converter, evalExpr, symTable, stmtCtx)); + // Optional load may be generated if we get a reference to the + // reduction type. + if (auto refType = + llvm::dyn_cast<fir::ReferenceType>(exprResult.getType())) + if (lhs.getType() == refType.getElementType()) + exprResult = fir::LoadOp::create(builder, loc, exprResult); + return exprResult; + }}, + evalExpr.u); + stmtCtx.finalizeAndPop(); + if (isByRef) { + fir::StoreOp::create(builder, loc, result, lhs); + mlir::omp::YieldOp::create(builder, loc, lhs); + } else { + mlir::omp::YieldOp::create(builder, loc, result); + } + }; + return genCombinerCB; +} + +// Checks that the reduction type is either a trivial type or a derived type of +// trivial types. +static bool isSimpleReductionType(mlir::Type reductionType) { + if (fir::isa_trivial(reductionType)) + return true; + if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) { + for (auto [_, fieldType] : recordTy.getTypeList()) { + if (!fir::isa_trivial(fieldType)) + return false; + } + } + return true; +} + +// Getting the type from a symbol compared to a DeclSpec is simpler since we do +// not need to consider derived vs intrinsic types. Semantics is guaranteed to +// generate these symbols. +static mlir::Type +getReductionType(lower::AbstractConverter &converter, + const parser::OmpReductionSpecifier &specifier) { + const auto &combinerExpression = + std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) + .value(); + const parser::OmpStylizedInstance &combinerInstance = + combinerExpression.v.front(); + const std::list<parser::OmpStylizedDeclaration> &declList = + std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); + const parser::OmpStylizedDeclaration &decl = declList.front(); + const auto &name = std::get<parser::ObjectName>(decl.var.t); + const auto &symbol = semantics::SymbolRef(*name.symbol); + mlir::Type reductionType = converter.genType(symbol); + + if (!isSimpleReductionType(reductionType)) + TODO(converter.getCurrentLocation(), + "declare reduction currently only supports trival types or derived " + "types containing trivial types"); + return reductionType; +} + static void genOMP( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) { - if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); + if (semaCtx.langOptions().OpenMPSimd) + return; + + const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()}; + const parser::OmpArgument &arg{args.v.front()}; + const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u); + + if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1) + TODO(converter.getCurrentLocation(), + "multiple types in declare reduction is not yet supported"); + + mlir::Type reductionType = getReductionType(converter, specifier); + ReductionProcessor::GenCombinerCBTy genCombinerCB = + processReductionCombiner(converter, symTable, semaCtx, specifier); + const parser::OmpClauseList &initializer = + declareReductionConstruct.v.Clauses(); + if (initializer.v.size() > 0) { + List<Clause> clauses = makeClauses(initializer, semaCtx); + ReductionProcessor::GenInitValueCBTy genInitValueCB; + ClauseProcessor cp(converter, semaCtx, clauses); + const parser::OmpClause::Initializer &iclause{ + std::get<parser::OmpClause::Initializer>(initializer.v.front().u)}; + cp.processInitializer(symTable, iclause, genInitValueCB); + const auto &identifier = + std::get<parser::OmpReductionIdentifier>(specifier.t); + const auto &designator = + std::get<parser::ProcedureDesignator>(identifier.u); + const auto &reductionName = std::get<parser::Name>(designator.u); + bool isByRef = ReductionProcessor::doReductionByRef(reductionType); + ReductionProcessor::createDeclareReductionHelper< + mlir::omp::DeclareReductionOp>( + converter, reductionName.ToString(), reductionType, + converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB); + } else { + TODO(converter.getCurrentLocation(), + "declare reduction without an initializer clause is not yet " + "supported"); + } } static void @@ -3543,10 +3788,10 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); } -static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, - const parser::OpenMPDeclareMapperConstruct &construct) { +static void genOpenMPDeclareMapperImpl( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const parser::OpenMPDeclareMapperConstruct &construct, + const semantics::Symbol *mapperSymOpt = nullptr) { mlir::Location loc = converter.genLocation(construct.source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); const parser::OmpArgumentList &args = construct.v.Arguments(); @@ -3562,8 +3807,17 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, "Expected derived type"); std::string mapperNameStr = mapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperNameStr)) + if (mapperSymOpt && mapperNameStr != "default") { + mapperNameStr = converter.mangleName(mapperNameStr, mapperSymOpt->owner()); + } else if (auto *sym = + converter.getCurrentScope().FindSymbol(mapperNameStr)) { mapperNameStr = converter.mangleName(mapperNameStr, sym->owner()); + } + + // If the mapper op already exists (e.g., created by regular lowering or by + // materialization of imported mappers), do not recreate it. + if (converter.getModuleOp().lookupSymbol(mapperNameStr)) + return; // Save current insertion point before moving to the module scope to create // the DeclareMapperOp @@ -3586,6 +3840,13 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseOps.mapVars); } +static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, + const parser::OpenMPDeclareMapperConstruct &construct) { + genOpenMPDeclareMapperImpl(converter, semaCtx, construct); +} + static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, @@ -3902,14 +4163,6 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - const parser::OpenMPExecutableAllocate &execAllocConstruct) { - if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); -} - -static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, const parser::OpenMPLoopConstruct &loopConstruct) { const parser::OmpDirectiveSpecification &beginSpec = loopConstruct.BeginDir(); List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx); @@ -3918,12 +4171,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::Location currentLocation = converter.genLocation(beginSpec.source); - auto &optLoopCons = - std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t); - if (optLoopCons.has_value()) { - if (auto *ompNestedLoopCons{ - std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( - &*optLoopCons)}) { + for (auto &construct : std::get<parser::Block>(loopConstruct.t)) { + if (const parser::OpenMPLoopConstruct *ompNestedLoopCons = + parser::omp::GetOmpLoop(construct)) { llvm::omp::Directive nestedDirective = parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v; switch (nestedDirective) { @@ -4229,3 +4479,36 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, offloadMod.setRequires(mlirFlags); } } + +// Walk scopes and materialize omp.declare_mapper ops for mapper declarations +// found in imported modules. If \p scope is null, start from the global scope. +void Fortran::lower::materializeOpenMPDeclareMappers( + Fortran::lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, const semantics::Scope *scope) { + const semantics::Scope &root = scope ? *scope : semaCtx.globalScope(); + + // Recurse into child scopes first (modules, submodules, etc.). + for (const semantics::Scope &child : root.children()) + materializeOpenMPDeclareMappers(converter, semaCtx, &child); + + // Only consider module scopes to avoid duplicating local constructs. + if (!root.IsModule()) + return; + + // Only materialize for modules coming from mod files to avoid duplicates. + if (!root.symbol() || !root.symbol()->test(semantics::Symbol::Flag::ModFile)) + return; + + // Scan symbols in this module scope for MapperDetails. + for (auto &it : root) { + const semantics::Symbol &sym = *it.second; + if (auto *md = sym.detailsIf<semantics::MapperDetails>()) { + for (const auto *decl : md->GetDeclList()) { + if (const auto *mapperDecl = + std::get_if<parser::OpenMPDeclareMapperConstruct>(&decl->u)) { + genOpenMPDeclareMapperImpl(converter, semaCtx, *mapperDecl, &sym); + } + } + } + } +} diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 6487f59..a818d63 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -14,22 +14,28 @@ #include "ClauseFinder.h" #include "flang/Evaluate/fold.h" +#include "flang/Evaluate/tools.h" #include <flang/Lower/AbstractConverter.h> #include <flang/Lower/ConvertType.h> #include <flang/Lower/DirectivesCommon.h> #include <flang/Lower/OpenMP/Clauses.h> #include <flang/Lower/PFTBuilder.h> #include <flang/Lower/Support/PrivateReductionUtils.h> +#include <flang/Optimizer/Builder/BoxValue.h> #include <flang/Optimizer/Builder/FIRBuilder.h> #include <flang/Optimizer/Builder/Todo.h> +#include <flang/Optimizer/HLFIR/HLFIROps.h> #include <flang/Parser/openmp-utils.h> #include <flang/Parser/parse-tree.h> #include <flang/Parser/tools.h> #include <flang/Semantics/tools.h> #include <flang/Semantics/type.h> #include <flang/Utils/OpenMP.h> +#include <llvm/ADT/SmallPtrSet.h> +#include <llvm/ADT/StringRef.h> #include <llvm/Support/CommandLine.h> +#include <functional> #include <iterator> template <typename T> @@ -61,6 +67,142 @@ namespace Fortran { namespace lower { namespace omp { +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + lower::AbstractConverter &converter, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr) { + if (mapperNameStr.empty()) + return {}; + + if (converter.getModuleOp().lookupSymbol(mapperNameStr)) + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperNameStr); + + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::OpBuilder::InsertionGuard guard(firOpBuilder); + + firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody()); + auto declMapperOp = mlir::omp::DeclareMapperOp::create( + firOpBuilder, loc, mapperNameStr, recordType); + auto ®ion = declMapperOp.getRegion(); + firOpBuilder.createBlock(®ion); + auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); + + auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, + /*uniq_name=*/""); + + const auto genBoundsOps = [&](mlir::Value mapVal, + llvm::SmallVectorImpl<mlir::Value> &bounds) { + fir::ExtendedValue extVal = + hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder, + hlfir::Entity{mapVal}, + /*contiguousHint=*/true) + .first; + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); + bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>( + firOpBuilder, info, extVal, + /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); + }; + + const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, + mlir::Type fieldTy, mlir::Type recType) { + mlir::Value field = fir::FieldIndexOp::create( + firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName, + recType, fir::getTypeParams(rec)); + return fir::CoordinateOp::create( + firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); + }; + + llvm::SmallVector<mlir::Value> clauseMapVars; + llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; + llvm::SmallVector<mlir::Value> memberMapOps; + + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::from | + mlir::omp::ClauseMapFlags::implicit; + mlir::omp::VariableCaptureKind captureKind = + mlir::omp::VariableCaptureKind::ByRef; + + for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { + const auto &memberName = entry.value().first; + const auto &memberType = entry.value().second; + mlir::FlatSymbolRefAttr mapperId; + if (auto recType = mlir::dyn_cast<fir::RecordType>( + fir::getFortranElementType(memberType))) { + std::string mapperIdName = + recType.getName().str() + llvm::omp::OmpDefaultMapperName; + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) + mapperIdName = converter.mangleName(mapperIdName, sym->owner()); + else if (auto *memberSym = + converter.getCurrentScope().FindSymbol(memberName)) + mapperIdName = converter.mangleName(mapperIdName, memberSym->owner()); + + mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType, + mapperIdName); + } + + auto ref = + getFieldRef(declareOp.getBase(), memberName, memberType, recordType); + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(ref, bounds); + mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp( + firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", + bounds, + /*members=*/{}, + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), + /*partialMap=*/false, mapperId); + memberMapOps.emplace_back(mapOp); + memberPlacementIndices.emplace_back( + llvm::SmallVector<int64_t>{(int64_t)entry.index()}); + } + + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(declareOp.getOriginalBase(), bounds); + mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; + mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( + firOpBuilder, loc, declareOp.getOriginalBase(), + /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, + firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, + captureKind, declareOp.getType(0), + /*partialMap=*/true); + + clauseMapVars.emplace_back(mapOp); + mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperNameStr); +} + +bool requiresImplicitDefaultDeclareMapper( + const semantics::DerivedTypeSpec &typeSpec) { + // ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit + // default mappers available so that OpenMP offloading can correctly map them. + if (semantics::IsIsoCType(&typeSpec)) + return true; + + llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited; + + std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper = + [&](const semantics::DerivedTypeSpec &spec) -> bool { + if (!visited.insert(&spec).second) + return false; + + semantics::DirectComponentIterator directComponents{spec}; + for (const semantics::Symbol &component : directComponents) { + if (component.attrs().test(semantics::Attr::ALLOCATABLE)) + return true; + + if (const semantics::DeclTypeSpec *declType = component.GetType()) + if (const auto *nested = declType->AsDerived()) + if (requiresMapper(*nested)) + return true; + } + return false; + }; + + return requiresMapper(typeSpec); +} + int64_t getCollapseValue(const List<Clause> &clauses) { auto iter = llvm::find_if(clauses, [](const Clause &clause) { return clause.id == llvm::omp::Clause::OMPC_collapse; @@ -537,6 +679,12 @@ void insertChildMapInfoIntoParent( mapOperands[std::distance(mapSyms.begin(), parentIter)] .getDefiningOp()); + // Once explicit members are attached to a parent map, do not also invoke + // a declare mapper on it, otherwise the mapper would remap the same + // components leading to duplicate mappings at runtime. + if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr()) + mapOp.setMapperIdAttr(nullptr); + // NOTE: To maintain appropriate SSA ordering, we move the parent map // which will now have references to its children after the last // of its members to be generated. This is necessary when a user @@ -631,17 +779,9 @@ static void processTileSizesFromOpenMPConstruct( if (!ompCons) return; if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) { - const auto &nestedOptional = - std::get<std::optional<parser::NestedConstruct>>(ompLoop->t); - assert(nestedOptional.has_value() && - "Expected a DoConstruct or OpenMPLoopConstruct"); - const auto *innerConstruct = - std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( - &(nestedOptional.value())); - if (innerConstruct) { - const auto &innerLoopDirective = innerConstruct->value(); + if (auto *innerConstruct = ompLoop->GetNestedConstruct()) { const parser::OmpDirectiveSpecification &innerBeginSpec = - innerLoopDirective.BeginDir(); + innerConstruct->BeginDir(); if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) { // Get the size values from parse tree and convert to a vector. for (const auto &clause : innerBeginSpec.Clauses().v) { @@ -656,6 +796,28 @@ static void processTileSizesFromOpenMPConstruct( } } +pft::Evaluation *getNestedDoConstruct(pft::Evaluation &eval) { + for (pft::Evaluation &nested : eval.getNestedEvaluations()) { + // In an OpenMPConstruct there can be compiler directives: + // 1 <<OpenMPConstruct>> + // 2 CompilerDirective: !unroll + // <<DoConstruct>> -> 8 + if (nested.getIf<parser::CompilerDirective>()) + continue; + // Within a DoConstruct, there can be compiler directives, plus + // there is a DoStmt before the body: + // <<DoConstruct>> -> 8 + // 3 NonLabelDoStmt -> 7: do i = 1, n + // <<DoConstruct>> -> 7 + if (nested.getIf<parser::NonLabelDoStmt>()) + continue; + assert(nested.getIf<parser::DoConstruct>() && + "Unexpected construct in the nested evaluations"); + return &nested; + } + llvm_unreachable("Expected do loop to be in the nested evaluations"); +} + /// Populates the sizes vector with values if the given OpenMPConstruct /// contains a loop construct with an inner tiling construct. void collectTileSizesFromOpenMPConstruct( @@ -678,7 +840,7 @@ int64_t collectLoopRelatedInfo( int64_t numCollapse = 1; // Collect the loops to collapse. - lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); + lower::pft::Evaluation *doConstructEval = getNestedDoConstruct(eval); if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); } @@ -704,7 +866,7 @@ void collectLoopRelatedInfo( fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // Collect the loops to collapse. - lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); + lower::pft::Evaluation *doConstructEval = getNestedDoConstruct(eval); if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); } @@ -745,9 +907,8 @@ void collectLoopRelatedInfo( iv.push_back(bounds->name.thing.symbol); loopVarTypeSize = std::max(loopVarTypeSize, bounds->name.thing.symbol->GetUltimate().size()); - collapseValue--; - doConstructEval = - &*std::next(doConstructEval->getNestedEvaluations().begin()); + if (--collapseValue) + doConstructEval = getNestedDoConstruct(*doConstructEval); } while (collapseValue > 0); convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index ef1f37a..8a68ff8 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -20,6 +20,7 @@ extern llvm::cl::opt<bool> treatIndexAsSection; namespace fir { class FirOpBuilder; +class RecordType; } // namespace fir namespace Fortran { @@ -136,6 +137,13 @@ mlir::Value createParentSymAndGenIntermediateMaps( OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, mlir::omp::ClauseMapFlags mapTypeBits); +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + Fortran::lower::AbstractConverter &converter, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr); + +bool requiresImplicitDefaultDeclareMapper( + const semantics::DerivedTypeSpec &typeSpec); + omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember, semantics::SemanticsContext &semaCtx); @@ -159,6 +167,8 @@ void genObjectList(const ObjectList &objects, void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp, mlir::Location loc); +pft::Evaluation *getNestedDoConstruct(pft::Evaluation &eval); + int64_t collectLoopRelatedInfo( lower::AbstractConverter &converter, mlir::Location currentLocation, lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses, |
