diff options
Diffstat (limited to 'flang/lib/Lower/OpenMP/OpenMP.cpp')
-rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 57 |
1 files changed, 23 insertions, 34 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index fc5fef9..12089d6 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -31,6 +31,7 @@ #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Parser/characters.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" @@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, lower::pft::Evaluation &eval, mlir::Location loc); -static llvm::omp::Directive -getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) { - return beginStatment.v; -} - -static llvm::omp::Directive getOpenMPDirectiveEnum( - const parser::OmpBeginLoopDirective &beginLoopDirective) { - return getOpenMPDirectiveEnum( - std::get<parser::OmpLoopDirective>(beginLoopDirective.t)); -} - -static llvm::omp::Directive -getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) { - return getOpenMPDirectiveEnum( - std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t)); -} - -static llvm::omp::Directive getOpenMPDirectiveEnum( - const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) { - return getOpenMPDirectiveEnum(ompLoopConstruct.value()); -} - namespace { /// Structure holding information that is needed to pass host-evaluated /// information to later lowering stages. @@ -432,8 +411,12 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, std::get<parser::OmpBeginBlockDirective>(ompConstruct.t); beginClauseList = &std::get<parser::OmpClauseList>(beginDirective.t); - endClauseList = &std::get<parser::OmpClauseList>( - std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t); + if (auto &endDirective = + std::get<std::optional<parser::OmpEndBlockDirective>>( + ompConstruct.t)) { + endClauseList = + &std::get<parser::OmpClauseList>(endDirective->t); + } }, [&](const parser::OpenMPLoopConstruct &ompConstruct) { const auto &beginDirective = @@ -443,9 +426,10 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, if (auto &endDirective = std::get<std::optional<parser::OmpEndLoopDirective>>( - ompConstruct.t)) + ompConstruct.t)) { endClauseList = &std::get<parser::OmpClauseList>(endDirective->t); + } }, [&](const auto &) {}}, ompEval->u); @@ -468,7 +452,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, llvm::omp::Directive dir; auto &nested = parent.getFirstNestedEvaluation(); if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>()) - dir = extractOmpDirective(*ompEval); + dir = parser::omp::GetOmpDirectiveName(*ompEval).v; else return std::nullopt; @@ -508,7 +492,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); assert(hostInfo && "expected HOST_EVAL info structure"); - switch (extractOmpDirective(*ompEval)) { + switch (parser::omp::GetOmpDirectiveName(*ompEval).v) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: cp.processThreadLimit(stmtCtx, hostInfo->ops); @@ -569,7 +553,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); assert(ompEval && - llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + llvm::omp::allTargetSet.test( + parser::omp::GetOmpDirectiveName(*ompEval).v) && "expected TARGET construct evaluation"); (void)ompEval; @@ -3733,16 +3718,19 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, const parser::OpenMPBlockConstruct &blockConstruct) { const auto &beginBlockDirective = std::get<parser::OmpBeginBlockDirective>(blockConstruct.t); - const auto &endBlockDirective = - std::get<parser::OmpEndBlockDirective>(blockConstruct.t); mlir::Location currentLocation = converter.genLocation(beginBlockDirective.source); const auto origDirective = std::get<parser::OmpBlockDirective>(beginBlockDirective.t).v; List<Clause> clauses = makeClauses( std::get<parser::OmpClauseList>(beginBlockDirective.t), semaCtx); - clauses.append(makeClauses( - std::get<parser::OmpClauseList>(endBlockDirective.t), semaCtx)); + + if (const auto &endBlockDirective = + std::get<std::optional<parser::OmpEndBlockDirective>>( + blockConstruct.t)) { + clauses.append(makeClauses( + std::get<parser::OmpClauseList>(endBlockDirective->t), semaCtx)); + } assert(llvm::omp::blockConstructSet.test(origDirective) && "Expected block construct"); @@ -3872,7 +3860,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( &*optLoopCons)}) { llvm::omp::Directive nestedDirective = - getOpenMPDirectiveEnum(*ompNestedLoopCons); + parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v; switch (nestedDirective) { case llvm::omp::Directive::OMPD_tile: // Emit the omp.loop_nest with annotation for tiling @@ -3889,7 +3877,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, } } - llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective); + llvm::omp::Directive directive = + parser::omp::GetOmpDirectiveName(beginLoopDirective).v; const parser::CharBlock &source = std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source; ConstructQueue queue{ |