diff options
-rw-r--r-- | flang/include/flang/Parser/openmp-utils.h | 161 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/DataSharingProcessor.cpp | 4 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 35 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/Utils.cpp | 84 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/Utils.h | 2 |
5 files changed, 173 insertions, 113 deletions
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h new file mode 100644 index 0000000..579ea7d --- /dev/null +++ b/flang/include/flang/Parser/openmp-utils.h @@ -0,0 +1,161 @@ +//===-- flang/Parser/openmp-utils.h ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Common OpenMP utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_PARSER_OPENMP_UTILS_H +#define FORTRAN_PARSER_OPENMP_UTILS_H + +#include "flang/Common/indirection.h" +#include "flang/Parser/parse-tree.h" +#include "llvm/Frontend/OpenMP/OMP.h" + +#include <cassert> +#include <tuple> +#include <type_traits> +#include <utility> +#include <variant> + +namespace Fortran::parser::omp { + +namespace detail { +using D = llvm::omp::Directive; + +template <typename Construct> // +struct ConstructId { + static constexpr llvm::omp::Directive id{D::OMPD_unknown}; +}; + +#define MAKE_CONSTR_ID(Construct, Id) \ + template <> struct ConstructId<Construct> { \ + static constexpr llvm::omp::Directive id{Id}; \ + } + +MAKE_CONSTR_ID(OmpAssumeDirective, D::OMPD_assume); +MAKE_CONSTR_ID(OmpCriticalDirective, D::OMPD_critical); +MAKE_CONSTR_ID(OmpDeclareVariantDirective, D::OMPD_declare_variant); +MAKE_CONSTR_ID(OmpErrorDirective, D::OMPD_error); +MAKE_CONSTR_ID(OmpMetadirectiveDirective, D::OMPD_metadirective); +MAKE_CONSTR_ID(OpenMPDeclarativeAllocate, D::OMPD_allocate); +MAKE_CONSTR_ID(OpenMPDeclarativeAssumes, D::OMPD_assumes); +MAKE_CONSTR_ID(OpenMPDeclareMapperConstruct, D::OMPD_declare_mapper); +MAKE_CONSTR_ID(OpenMPDeclareReductionConstruct, D::OMPD_declare_reduction); +MAKE_CONSTR_ID(OpenMPDeclareSimdConstruct, D::OMPD_declare_simd); +MAKE_CONSTR_ID(OpenMPDeclareTargetConstruct, D::OMPD_declare_target); +MAKE_CONSTR_ID(OpenMPExecutableAllocate, D::OMPD_allocate); +MAKE_CONSTR_ID(OpenMPRequiresConstruct, D::OMPD_requires); +MAKE_CONSTR_ID(OpenMPThreadprivate, D::OMPD_threadprivate); + +#undef MAKE_CONSTR_ID + +struct DirectiveNameScope { + static OmpDirectiveName MakeName(CharBlock source = {}, + llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) { + OmpDirectiveName name; + name.source = source; + name.v = id; + return name; + } + + static OmpDirectiveName GetOmpDirectiveName(const OmpNothingDirective &x) { + return MakeName(x.source, llvm::omp::Directive::OMPD_nothing); + } + + static OmpDirectiveName GetOmpDirectiveName(const OmpBeginBlockDirective &x) { + auto &dir{std::get<OmpBlockDirective>(x.t)}; + return MakeName(dir.source, dir.v); + } + + static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) { + auto &dir{std::get<OmpLoopDirective>(x.t)}; + return MakeName(dir.source, dir.v); + } + + static OmpDirectiveName GetOmpDirectiveName( + const OmpBeginSectionsDirective &x) { + auto &dir{std::get<OmpSectionsDirective>(x.t)}; + return MakeName(dir.source, dir.v); + } + + template <typename T> + static OmpDirectiveName GetOmpDirectiveName(const T &x) { + if constexpr (WrapperTrait<T>) { + if constexpr (std::is_same_v<T, OpenMPCancelConstruct> || + std::is_same_v<T, OpenMPCancellationPointConstruct> || + std::is_same_v<T, OpenMPDepobjConstruct> || + std::is_same_v<T, OpenMPFlushConstruct> || + std::is_same_v<T, OpenMPInteropConstruct> || + std::is_same_v<T, OpenMPSimpleStandaloneConstruct>) { + return x.v.DirName(); + } else { + return GetOmpDirectiveName(x.v); + } + } else if constexpr (TupleTrait<T>) { + if constexpr (std::is_same_v<T, OpenMPAllocatorsConstruct> || + std::is_same_v<T, OpenMPAtomicConstruct> || + std::is_same_v<T, OpenMPDispatchConstruct>) { + return std::get<OmpDirectiveSpecification>(x.t).DirName(); + } else if constexpr (std::is_same_v<T, OmpAssumeDirective> || + std::is_same_v<T, OmpCriticalDirective> || + std::is_same_v<T, OmpDeclareVariantDirective> || + std::is_same_v<T, OmpErrorDirective> || + std::is_same_v<T, OmpMetadirectiveDirective> || + std::is_same_v<T, OpenMPDeclarativeAllocate> || + std::is_same_v<T, OpenMPDeclarativeAssumes> || + std::is_same_v<T, OpenMPDeclareMapperConstruct> || + std::is_same_v<T, OpenMPDeclareReductionConstruct> || + std::is_same_v<T, OpenMPDeclareSimdConstruct> || + std::is_same_v<T, OpenMPDeclareTargetConstruct> || + std::is_same_v<T, OpenMPExecutableAllocate> || + std::is_same_v<T, OpenMPRequiresConstruct> || + std::is_same_v<T, OpenMPThreadprivate>) { + return MakeName(std::get<Verbatim>(x.t).source, ConstructId<T>::id); + } else { + return GetFromTuple( + x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{}); + } + } else if constexpr (UnionTrait<T>) { + return common::visit( + [](auto &&s) { return GetOmpDirectiveName(s); }, x.u); + } else { + return MakeName(); + } + } + + template <typename... Ts, size_t... Is> + static OmpDirectiveName GetFromTuple( + const std::tuple<Ts...> &t, std::index_sequence<Is...>) { + OmpDirectiveName name = MakeName(); + auto accumulate = [&](const OmpDirectiveName &n) { + if (name.v == llvm::omp::Directive::OMPD_unknown) { + name = n; + } else { + assert( + n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names"); + } + }; + (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...); + return name; + } + + template <typename T> + static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) { + return GetOmpDirectiveName(x.value()); + } +}; +} // namespace detail + +template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) { + return detail::DirectiveNameScope::GetOmpDirectiveName(x); +} + +} // namespace Fortran::parser::omp + +#endif // FORTRAN_PARSER_OPENMP_UTILS_H diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index 11e4883..2ac4d95 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -24,6 +24,7 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Semantics/attr.h" #include "flang/Semantics/tools.h" #include "llvm/ADT/Sequence.h" @@ -465,7 +466,8 @@ bool DataSharingProcessor::isOpenMPPrivatizingConstruct( // allow a privatizing clause) are: dispatch, distribute, do, for, loop, // parallel, scope, sections, simd, single, target, target_data, task, // taskgroup, taskloop, and teams. - return llvm::is_contained(privatizing, extractOmpDirective(omp)); + return llvm::is_contained(privatizing, + parser::omp::GetOmpDirectiveName(omp).v); } bool DataSharingProcessor::isOpenMPPrivatizingEvaluation( diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index fc5fef9..4c2d7bad 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -31,6 +31,7 @@ #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Parser/characters.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" @@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, lower::pft::Evaluation &eval, mlir::Location loc); -static llvm::omp::Directive -getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) { - return beginStatment.v; -} - -static llvm::omp::Directive getOpenMPDirectiveEnum( - const parser::OmpBeginLoopDirective &beginLoopDirective) { - return getOpenMPDirectiveEnum( - std::get<parser::OmpLoopDirective>(beginLoopDirective.t)); -} - -static llvm::omp::Directive -getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) { - return getOpenMPDirectiveEnum( - std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t)); -} - -static llvm::omp::Directive getOpenMPDirectiveEnum( - const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) { - return getOpenMPDirectiveEnum(ompLoopConstruct.value()); -} - namespace { /// Structure holding information that is needed to pass host-evaluated /// information to later lowering stages. @@ -468,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, llvm::omp::Directive dir; auto &nested = parent.getFirstNestedEvaluation(); if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>()) - dir = extractOmpDirective(*ompEval); + dir = parser::omp::GetOmpDirectiveName(*ompEval).v; else return std::nullopt; @@ -508,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); assert(hostInfo && "expected HOST_EVAL info structure"); - switch (extractOmpDirective(*ompEval)) { + switch (parser::omp::GetOmpDirectiveName(*ompEval).v) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: cp.processThreadLimit(stmtCtx, hostInfo->ops); @@ -569,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); assert(ompEval && - llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + llvm::omp::allTargetSet.test( + parser::omp::GetOmpDirectiveName(*ompEval).v) && "expected TARGET construct evaluation"); (void)ompEval; @@ -3872,7 +3852,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( &*optLoopCons)}) { llvm::omp::Directive nestedDirective = - getOpenMPDirectiveEnum(*ompNestedLoopCons); + parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v; switch (nestedDirective) { case llvm::omp::Directive::OMPD_tile: // Emit the omp.loop_nest with annotation for tiling @@ -3889,7 +3869,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, } } - llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective); + llvm::omp::Directive directive = + parser::omp::GetOmpDirectiveName(beginLoopDirective).v; const parser::CharBlock &source = std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source; ConstructQueue queue{ diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b1716d6..13fda97 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -20,6 +20,7 @@ #include <flang/Lower/PFTBuilder.h> #include <flang/Optimizer/Builder/FIRBuilder.h> #include <flang/Optimizer/Builder/Todo.h> +#include <flang/Parser/openmp-utils.h> #include <flang/Parser/parse-tree.h> #include <flang/Parser/tools.h> #include <flang/Semantics/tools.h> @@ -663,89 +664,6 @@ bool collectLoopRelatedInfo( return found; } -/// Get the directive enumeration value corresponding to the given OpenMP -/// construct PFT node. -llvm::omp::Directive -extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { - return common::visit( - common::visitors{ - [](const parser::OpenMPAllocatorsConstruct &c) { - return llvm::omp::OMPD_allocators; - }, - [](const parser::OpenMPAssumeConstruct &c) { - return llvm::omp::OMPD_assume; - }, - [](const parser::OpenMPAtomicConstruct &c) { - return llvm::omp::OMPD_atomic; - }, - [](const parser::OpenMPBlockConstruct &c) { - return std::get<parser::OmpBlockDirective>( - std::get<parser::OmpBeginBlockDirective>(c.t).t) - .v; - }, - [](const parser::OpenMPCriticalConstruct &c) { - return llvm::omp::OMPD_critical; - }, - [](const parser::OpenMPDeclarativeAllocate &c) { - return llvm::omp::OMPD_allocate; - }, - [](const parser::OpenMPDispatchConstruct &c) { - return llvm::omp::OMPD_dispatch; - }, - [](const parser::OpenMPExecutableAllocate &c) { - return llvm::omp::OMPD_allocate; - }, - [](const parser::OpenMPLoopConstruct &c) { - return std::get<parser::OmpLoopDirective>( - std::get<parser::OmpBeginLoopDirective>(c.t).t) - .v; - }, - [](const parser::OpenMPSectionConstruct &c) { - return llvm::omp::OMPD_section; - }, - [](const parser::OpenMPSectionsConstruct &c) { - return std::get<parser::OmpSectionsDirective>( - std::get<parser::OmpBeginSectionsDirective>(c.t).t) - .v; - }, - [](const parser::OpenMPStandaloneConstruct &c) { - return common::visit( - common::visitors{ - [](const parser::OpenMPSimpleStandaloneConstruct &c) { - return c.v.DirId(); - }, - [](const parser::OpenMPFlushConstruct &c) { - return llvm::omp::OMPD_flush; - }, - [](const parser::OpenMPCancelConstruct &c) { - return llvm::omp::OMPD_cancel; - }, - [](const parser::OpenMPCancellationPointConstruct &c) { - return llvm::omp::OMPD_cancellation_point; - }, - [](const parser::OmpMetadirectiveDirective &c) { - return llvm::omp::OMPD_metadirective; - }, - [](const parser::OpenMPDepobjConstruct &c) { - return llvm::omp::OMPD_depobj; - }, - [](const parser::OpenMPInteropConstruct &c) { - return llvm::omp::OMPD_interop; - }}, - c.u); - }, - [](const parser::OpenMPUtilityConstruct &c) { - return common::visit( - common::visitors{[](const parser::OmpErrorDirective &c) { - return llvm::omp::OMPD_error; - }, - [](const parser::OmpNothingDirective &c) { - return llvm::omp::OMPD_nothing; - }}, - c.u); - }}, - ompConstruct.u); -} } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 8e3ad5c..11641ba 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -167,8 +167,6 @@ bool collectLoopRelatedInfo( mlir::omp::LoopRelatedClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &iv); -llvm::omp::Directive -extractOmpDirective(const parser::OpenMPConstruct &ompConstruct); } // namespace omp } // namespace lower } // namespace Fortran |