diff options
Diffstat (limited to 'flang/lib')
-rw-r--r-- | flang/lib/Common/OpenMP-utils.cpp | 9 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 448 |
2 files changed, 437 insertions, 20 deletions
diff --git a/flang/lib/Common/OpenMP-utils.cpp b/flang/lib/Common/OpenMP-utils.cpp index f5115f4..47e89fe 100644 --- a/flang/lib/Common/OpenMP-utils.cpp +++ b/flang/lib/Common/OpenMP-utils.cpp @@ -18,10 +18,10 @@ mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args, llvm::SmallVector<mlir::Type> types; llvm::SmallVector<mlir::Location> locs; - unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() + - args.priv.vars.size() + args.reduction.vars.size() + - args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() + - args.useDevicePtr.vars.size(); + unsigned numVars = args.hostEvalVars.size() + args.inReduction.vars.size() + + args.map.vars.size() + args.priv.vars.size() + + args.reduction.vars.size() + args.taskReduction.vars.size() + + args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size(); types.reserve(numVars); locs.reserve(numVars); @@ -34,6 +34,7 @@ mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args, // Populate block arguments in clause name alphabetical order to match // expected order by the BlockArgOpenMPOpInterface. + extractTypeLoc(args.hostEvalVars); extractTypeLoc(args.inReduction.vars); extractTypeLoc(args.map.vars); extractTypeLoc(args.priv.vars); diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 8a10294..826edf7 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -55,6 +55,149 @@ static void genOMPDispatch(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item); +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc); + +namespace { +/// Structure holding information that is needed to pass host-evaluated +/// information to later lowering stages. +class HostEvalInfo { +public: + // Allow this function access to private members in order to initialize them. + friend void ::processHostEvalClauses(lower::AbstractConverter &, + semantics::SemanticsContext &, + lower::StatementContext &, + lower::pft::Evaluation &, + mlir::Location); + + /// Fill \c vars with values stored in \c ops. + /// + /// The order in which values are stored matches the one expected by \see + /// bindOperands(). + void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const { + vars.append(ops.loopLowerBounds); + vars.append(ops.loopUpperBounds); + vars.append(ops.loopSteps); + + if (ops.numTeamsLower) + vars.push_back(ops.numTeamsLower); + + if (ops.numTeamsUpper) + vars.push_back(ops.numTeamsUpper); + + if (ops.numThreads) + vars.push_back(ops.numThreads); + + if (ops.threadLimit) + vars.push_back(ops.threadLimit); + } + + /// Update \c ops, replacing all values with the corresponding block argument + /// in \c args. + /// + /// The order in which values are stored in \c args is the same as the one + /// used by \see collectValues(). + void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) { + assert(args.size() == + ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + + (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + + (ops.threadLimit ? 1 : 0) && + "invalid block argument list"); + int argIndex = 0; + for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) + ops.loopLowerBounds[i] = args[argIndex++]; + + for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i) + ops.loopUpperBounds[i] = args[argIndex++]; + + for (size_t i = 0; i < ops.loopSteps.size(); ++i) + ops.loopSteps[i] = args[argIndex++]; + + if (ops.numTeamsLower) + ops.numTeamsLower = args[argIndex++]; + + if (ops.numTeamsUpper) + ops.numTeamsUpper = args[argIndex++]; + + if (ops.numThreads) + ops.numThreads = args[argIndex++]; + + if (ops.threadLimit) + ops.threadLimit = args[argIndex++]; + } + + /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated + /// values and Fortran symbols, respectively, if they have already been + /// initialized but not yet applied. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::LoopNestOperands &clauseOps, + llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) { + if (iv.empty() || loopNestApplied) { + loopNestApplied = true; + return false; + } + + loopNestApplied = true; + clauseOps.loopLowerBounds = ops.loopLowerBounds; + clauseOps.loopUpperBounds = ops.loopUpperBounds; + clauseOps.loopSteps = ops.loopSteps; + ivOut.append(iv); + return true; + } + + /// Update \p clauseOps with the corresponding host-evaluated values if they + /// have already been initialized but not yet applied. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::ParallelOperands &clauseOps) { + if (!ops.numThreads || parallelApplied) { + parallelApplied = true; + return false; + } + + parallelApplied = true; + clauseOps.numThreads = ops.numThreads; + return true; + } + + /// Update \p clauseOps with the corresponding host-evaluated values if they + /// have already been initialized. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::TeamsOperands &clauseOps) { + if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit) + return false; + + clauseOps.numTeamsLower = ops.numTeamsLower; + clauseOps.numTeamsUpper = ops.numTeamsUpper; + clauseOps.threadLimit = ops.threadLimit; + return true; + } + +private: + mlir::omp::HostEvaluatedOperands ops; + llvm::SmallVector<const semantics::Symbol *> iv; + bool loopNestApplied = false, parallelApplied = false; +}; +} // namespace + +/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target +/// operations being created. +/// +/// The current implementation prevents nested 'target' regions from breaking +/// the handling of the outer region by keeping a stack of information +/// structures, but it will probably still require some further work to support +/// reverse offloading. +static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo; + /// Bind symbols to their corresponding entry block arguments. /// /// The binding will be performed inside of the current block, which does not @@ -176,6 +319,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter, }; // Process in clause name alphabetical order to match block arguments order. + // Do not bind host_eval variables because they cannot be used inside of the + // corresponding region, except for very specific cases handled separately. bindPrivateLike(args.inReduction.syms, args.inReduction.vars, op.getInReductionBlockArgs()); bindMapLike(args.map.syms, op.getMapBlockArgs()); @@ -213,6 +358,256 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars, }); } +/// Get the directive enumeration value corresponding to the given OpenMP +/// construct PFT node. +llvm::omp::Directive +extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { + return common::visit( + common::visitors{ + [](const parser::OpenMPAllocatorsConstruct &c) { + return llvm::omp::OMPD_allocators; + }, + [](const parser::OpenMPAtomicConstruct &c) { + return llvm::omp::OMPD_atomic; + }, + [](const parser::OpenMPBlockConstruct &c) { + return std::get<parser::OmpBlockDirective>( + std::get<parser::OmpBeginBlockDirective>(c.t).t) + .v; + }, + [](const parser::OpenMPCriticalConstruct &c) { + return llvm::omp::OMPD_critical; + }, + [](const parser::OpenMPDeclarativeAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPExecutableAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPLoopConstruct &c) { + return std::get<parser::OmpLoopDirective>( + std::get<parser::OmpBeginLoopDirective>(c.t).t) + .v; + }, + [](const parser::OpenMPSectionConstruct &c) { + return llvm::omp::OMPD_section; + }, + [](const parser::OpenMPSectionsConstruct &c) { + return std::get<parser::OmpSectionsDirective>( + std::get<parser::OmpBeginSectionsDirective>(c.t).t) + .v; + }, + [](const parser::OpenMPStandaloneConstruct &c) { + return common::visit( + common::visitors{ + [](const parser::OpenMPSimpleStandaloneConstruct &c) { + return std::get<parser::OmpSimpleStandaloneDirective>(c.t) + .v; + }, + [](const parser::OpenMPFlushConstruct &c) { + return llvm::omp::OMPD_flush; + }, + [](const parser::OpenMPCancelConstruct &c) { + return llvm::omp::OMPD_cancel; + }, + [](const parser::OpenMPCancellationPointConstruct &c) { + return llvm::omp::OMPD_cancellation_point; + }, + [](const parser::OpenMPDepobjConstruct &c) { + return llvm::omp::OMPD_depobj; + }}, + c.u); + }, + [](const parser::OpenMPUtilityConstruct &c) { + return common::visit( + common::visitors{[](const parser::OmpErrorDirective &c) { + return llvm::omp::OMPD_error; + }, + [](const parser::OmpNothingDirective &c) { + return llvm::omp::OMPD_nothing; + }}, + c.u); + }}, + ompConstruct.u); +} + +/// Populate the global \see hostEvalInfo after processing clauses for the given +/// \p eval OpenMP target construct, or nested constructs, if these must be +/// evaluated outside of the target region per the spec. +/// +/// In particular, this will ensure that in 'target teams' and equivalent nested +/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated +/// in the host. Additionally, loop bounds, steps and the \c num_threads clause +/// will also be evaluated in the host if a target SPMD construct is detected +/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting). +/// +/// The result, stored as a global, is intended to be used to populate the \c +/// host_eval operands of the associated \c omp.target operation, and also to be +/// checked and used by later lowering steps to populate the corresponding +/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest +/// operations. +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc) { + // Obtain the list of clauses of the given OpenMP block or loop construct + // evaluation. Other evaluations passed to this lambda keep `clauses` + // unchanged. + auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval, + List<Clause> &clauses) { + const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); + if (!ompEval) + return; + + const parser::OmpClauseList *beginClauseList = nullptr; + const parser::OmpClauseList *endClauseList = nullptr; + common::visit( + common::visitors{ + [&](const parser::OpenMPBlockConstruct &ompConstruct) { + const auto &beginDirective = + std::get<parser::OmpBeginBlockDirective>(ompConstruct.t); + beginClauseList = + &std::get<parser::OmpClauseList>(beginDirective.t); + endClauseList = &std::get<parser::OmpClauseList>( + std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t); + }, + [&](const parser::OpenMPLoopConstruct &ompConstruct) { + const auto &beginDirective = + std::get<parser::OmpBeginLoopDirective>(ompConstruct.t); + beginClauseList = + &std::get<parser::OmpClauseList>(beginDirective.t); + + if (auto &endDirective = + std::get<std::optional<parser::OmpEndLoopDirective>>( + ompConstruct.t)) + endClauseList = + &std::get<parser::OmpClauseList>(endDirective->t); + }, + [&](const auto &) {}}, + ompEval->u); + + assert(beginClauseList && "expected begin directive"); + clauses.append(makeClauses(*beginClauseList, semaCtx)); + + if (endClauseList) + clauses.append(makeClauses(*endClauseList, semaCtx)); + }; + + // Return the directive that is immediately nested inside of the given + // `parent` evaluation, if it is its only non-end-statement nested evaluation + // and it represents an OpenMP construct. + auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent) + -> std::optional<llvm::omp::Directive> { + if (!parent.hasNestedEvaluations()) + return std::nullopt; + + llvm::omp::Directive dir; + auto &nested = parent.getFirstNestedEvaluation(); + if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>()) + dir = extractOmpDirective(*ompEval); + else + return std::nullopt; + + for (auto &sibling : parent.getNestedEvaluations()) + if (&sibling != &nested && !sibling.isEndStmt()) + return std::nullopt; + + return dir; + }; + + // Process the given evaluation assuming it's part of a 'target' construct or + // captured by one, and store results in the global `hostEvalInfo`. + std::function<void(lower::pft::Evaluation &, const List<Clause> &)> + processEval; + processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) { + using namespace llvm::omp; + ClauseProcessor cp(converter, semaCtx, clauses); + + // Call `processEval` recursively with the immediately nested evaluation and + // its corresponding clauses if there is a single nested evaluation + // representing an OpenMP directive that passes the given test. + auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) { + std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval); + if (!nestedDir || !test(*nestedDir)) + return; + + lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation(); + List<lower::omp::Clause> nestedClauses; + extractClauses(nestedEval, nestedClauses); + processEval(nestedEval, nestedClauses); + }; + + const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); + if (!ompEval) + return; + + HostEvalInfo &hostInfo = hostEvalInfo.back(); + + switch (extractOmpDirective(*ompEval)) { + // Cases where 'teams' and target SPMD clauses might be present. + case OMPD_teams_distribute_parallel_do: + case OMPD_teams_distribute_parallel_do_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute_parallel_do: + case OMPD_target_teams_distribute_parallel_do_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_distribute_parallel_do: + case OMPD_distribute_parallel_do_simd: + cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processNumThreads(stmtCtx, hostInfo.ops); + break; + + // Cases where 'teams' clauses might be present, and target SPMD is + // possible by looking at nested evaluations. + case OMPD_teams: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams: + cp.processNumTeams(stmtCtx, hostInfo.ops); + processSingleNestedIf([](Directive nestedDir) { + return nestedDir == OMPD_distribute_parallel_do || + nestedDir == OMPD_distribute_parallel_do_simd; + }); + break; + + // Cases where only 'teams' host-evaluated clauses might be present. + case OMPD_teams_distribute: + case OMPD_teams_distribute_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute: + case OMPD_target_teams_distribute_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + break; + + // Standalone 'target' case. + case OMPD_target: { + processSingleNestedIf( + [](Directive nestedDir) { return topTeamsSet.test(nestedDir); }); + break; + } + default: + break; + } + }; + + assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure"); + + const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); + assert(ompEval && + llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + "expected TARGET construct evaluation"); + + // Use the whole list of clauses passed to the construct here, rather than the + // ones only applied to omp.target. + List<lower::omp::Clause> clauses; + extractClauses(eval, clauses); + processEval(eval, clauses); +} + static lower::pft::Evaluation * getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -913,6 +1308,8 @@ static void genBodyOfTargetOp( mlir::Region ®ion = targetOp.getRegion(); mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region); bindEntryBlockArgs(converter, targetOp, args); + if (!hostEvalInfo.empty()) + hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1126,7 +1523,10 @@ genLoopNestClauses(lower::AbstractConverter &converter, mlir::Location loc, mlir::omp::LoopNestOperands &clauseOps, llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - cp.processCollapse(loc, eval, clauseOps, iv); + + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv)) + cp.processCollapse(loc, eval, clauseOps, iv); + clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); } @@ -1168,7 +1568,10 @@ static void genParallelClauses( ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - cp.processNumThreads(stmtCtx, clauseOps); + + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) + cp.processNumThreads(stmtCtx, clauseOps); + cp.processProcBind(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); } @@ -1215,8 +1618,8 @@ static void genSingleClauses(lower::AbstractConverter &converter, static void genTargetClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, const List<Clause> &clauses, - mlir::Location loc, bool processHostOnlyClauses, + lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::TargetOperands &clauseOps, llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms, @@ -1226,13 +1629,15 @@ static void genTargetClauses( cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); + if (!hostEvalInfo.empty()) { + // Only process host_eval if compiling for the host device. + processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); + hostEvalInfo.back().collectValues(clauseOps.hostEvalVars); + } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); - - if (processHostOnlyClauses) - cp.processNowait(clauseOps); - + cp.processNowait(clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate, @@ -1344,8 +1749,12 @@ static void genTeamsClauses( ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - cp.processNumTeams(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); + + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) { + cp.processNumTeams(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + } + cp.processReduction(loc, clauseOps, reductionSyms); // TODO Support delayed privatization. } @@ -1720,17 +2129,19 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); lower::StatementContext stmtCtx; + bool isTargetDevice = + llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp()) + .getIsTargetDevice(); - bool processHostOnlyClauses = - !llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp()) - .getIsTargetDevice(); + // Introduce a new host_eval information structure for this target region. + if (!isTargetDevice) + hostEvalInfo.emplace_back(); mlir::omp::TargetOperands clauseOps; llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms, hasDeviceAddrSyms; - genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, - isDevicePtrSyms, mapSyms); + genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc, + clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ @@ -1840,6 +2251,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, extractMappedBaseValues(clauseOps.mapVars, mapBaseValues); EntryBlockArgs args; + args.hostEvalVars = clauseOps.hostEvalVars; // TODO: Add in_reduction syms and vars. args.map.syms = mapSyms; args.map.vars = mapBaseValues; @@ -1848,6 +2260,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, queue, item, dsp); + + // Remove the host_eval information structure created for this target region. + if (!isTargetDevice) + hostEvalInfo.pop_back(); return targetOp; } |