aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenMP/OpenMP.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenMP/OpenMP.cpp')
-rw-r--r--flang/lib/Lower/OpenMP/OpenMP.cpp57
1 files changed, 23 insertions, 34 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fc5fef9..12089d6 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -31,6 +31,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/characters.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
@@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
lower::pft::Evaluation &eval,
mlir::Location loc);
-static llvm::omp::Directive
-getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) {
- return beginStatment.v;
-}
-
-static llvm::omp::Directive getOpenMPDirectiveEnum(
- const parser::OmpBeginLoopDirective &beginLoopDirective) {
- return getOpenMPDirectiveEnum(
- std::get<parser::OmpLoopDirective>(beginLoopDirective.t));
-}
-
-static llvm::omp::Directive
-getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) {
- return getOpenMPDirectiveEnum(
- std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t));
-}
-
-static llvm::omp::Directive getOpenMPDirectiveEnum(
- const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
- return getOpenMPDirectiveEnum(ompLoopConstruct.value());
-}
-
namespace {
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
@@ -432,8 +411,12 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
beginClauseList =
&std::get<parser::OmpClauseList>(beginDirective.t);
- endClauseList = &std::get<parser::OmpClauseList>(
- std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+ if (auto &endDirective =
+ std::get<std::optional<parser::OmpEndBlockDirective>>(
+ ompConstruct.t)) {
+ endClauseList =
+ &std::get<parser::OmpClauseList>(endDirective->t);
+ }
},
[&](const parser::OpenMPLoopConstruct &ompConstruct) {
const auto &beginDirective =
@@ -443,9 +426,10 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
if (auto &endDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
- ompConstruct.t))
+ ompConstruct.t)) {
endClauseList =
&std::get<parser::OmpClauseList>(endDirective->t);
+ }
},
[&](const auto &) {}},
ompEval->u);
@@ -468,7 +452,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
llvm::omp::Directive dir;
auto &nested = parent.getFirstNestedEvaluation();
if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
- dir = extractOmpDirective(*ompEval);
+ dir = parser::omp::GetOmpDirectiveName(*ompEval).v;
else
return std::nullopt;
@@ -508,7 +492,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
assert(hostInfo && "expected HOST_EVAL info structure");
- switch (extractOmpDirective(*ompEval)) {
+ switch (parser::omp::GetOmpDirectiveName(*ompEval).v) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
cp.processThreadLimit(stmtCtx, hostInfo->ops);
@@ -569,7 +553,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
- llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+ llvm::omp::allTargetSet.test(
+ parser::omp::GetOmpDirectiveName(*ompEval).v) &&
"expected TARGET construct evaluation");
(void)ompEval;
@@ -3733,16 +3718,19 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
const parser::OpenMPBlockConstruct &blockConstruct) {
const auto &beginBlockDirective =
std::get<parser::OmpBeginBlockDirective>(blockConstruct.t);
- const auto &endBlockDirective =
- std::get<parser::OmpEndBlockDirective>(blockConstruct.t);
mlir::Location currentLocation =
converter.genLocation(beginBlockDirective.source);
const auto origDirective =
std::get<parser::OmpBlockDirective>(beginBlockDirective.t).v;
List<Clause> clauses = makeClauses(
std::get<parser::OmpClauseList>(beginBlockDirective.t), semaCtx);
- clauses.append(makeClauses(
- std::get<parser::OmpClauseList>(endBlockDirective.t), semaCtx));
+
+ if (const auto &endBlockDirective =
+ std::get<std::optional<parser::OmpEndBlockDirective>>(
+ blockConstruct.t)) {
+ clauses.append(makeClauses(
+ std::get<parser::OmpClauseList>(endBlockDirective->t), semaCtx));
+ }
assert(llvm::omp::blockConstructSet.test(origDirective) &&
"Expected block construct");
@@ -3872,7 +3860,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
llvm::omp::Directive nestedDirective =
- getOpenMPDirectiveEnum(*ompNestedLoopCons);
+ parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Emit the omp.loop_nest with annotation for tiling
@@ -3889,7 +3877,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
}
}
- llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective);
+ llvm::omp::Directive directive =
+ parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
const parser::CharBlock &source =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
ConstructQueue queue{