aboutsummaryrefslogtreecommitdiff
path: root/flang/lib
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib')
-rw-r--r--flang/lib/CMakeLists.txt1
-rw-r--r--flang/lib/Evaluate/characteristics.cpp4
-rw-r--r--flang/lib/Evaluate/check-expression.cpp296
-rw-r--r--flang/lib/Evaluate/common.cpp32
-rw-r--r--flang/lib/Evaluate/fold-character.cpp20
-rw-r--r--flang/lib/Evaluate/fold-complex.cpp22
-rw-r--r--flang/lib/Evaluate/fold-implementation.h121
-rw-r--r--flang/lib/Evaluate/fold-integer.cpp127
-rw-r--r--flang/lib/Evaluate/fold-logical.cpp10
-rw-r--r--flang/lib/Evaluate/fold-matmul.h6
-rw-r--r--flang/lib/Evaluate/fold-real.cpp129
-rw-r--r--flang/lib/Evaluate/fold-reduction.h18
-rw-r--r--flang/lib/Evaluate/fold.cpp7
-rw-r--r--flang/lib/Evaluate/formatting.cpp16
-rw-r--r--flang/lib/Evaluate/host.cpp9
-rw-r--r--flang/lib/Evaluate/intrinsics.cpp97
-rw-r--r--flang/lib/Evaluate/real.cpp8
-rw-r--r--flang/lib/Evaluate/shape.cpp12
-rw-r--r--flang/lib/Evaluate/tools.cpp83
-rw-r--r--flang/lib/Evaluate/variable.cpp18
-rw-r--r--flang/lib/Frontend/CompilerInstance.cpp15
-rw-r--r--flang/lib/Frontend/CompilerInvocation.cpp61
-rw-r--r--flang/lib/Frontend/FrontendActions.cpp25
-rw-r--r--flang/lib/Lower/Allocatable.cpp24
-rw-r--r--flang/lib/Lower/Bridge.cpp78
-rw-r--r--flang/lib/Lower/CMakeLists.txt2
-rw-r--r--flang/lib/Lower/CUDA.cpp167
-rw-r--r--flang/lib/Lower/ConvertCall.cpp61
-rw-r--r--flang/lib/Lower/ConvertConstant.cpp22
-rw-r--r--flang/lib/Lower/ConvertExpr.cpp2
-rw-r--r--flang/lib/Lower/ConvertExprToHLFIR.cpp13
-rw-r--r--flang/lib/Lower/ConvertVariable.cpp86
-rw-r--r--flang/lib/Lower/HlfirIntrinsics.cpp78
-rw-r--r--flang/lib/Lower/HostAssociations.cpp4
-rw-r--r--flang/lib/Lower/OpenACC.cpp14
-rw-r--r--flang/lib/Lower/OpenMP/Atomic.cpp271
-rw-r--r--flang/lib/Lower/OpenMP/ClauseProcessor.cpp29
-rw-r--r--flang/lib/Lower/OpenMP/ClauseProcessor.h20
-rw-r--r--flang/lib/Lower/OpenMP/Clauses.cpp23
-rw-r--r--flang/lib/Lower/OpenMP/DataSharingProcessor.cpp97
-rw-r--r--flang/lib/Lower/OpenMP/DataSharingProcessor.h35
-rw-r--r--flang/lib/Lower/OpenMP/OpenMP.cpp172
-rw-r--r--flang/lib/Lower/OpenMP/Utils.cpp42
-rw-r--r--flang/lib/Lower/OpenMP/Utils.h24
-rw-r--r--flang/lib/Lower/PFTBuilder.cpp6
-rw-r--r--flang/lib/Lower/Runtime.cpp3
-rw-r--r--flang/lib/Lower/Support/PrivateReductionUtils.cpp2
-rw-r--r--flang/lib/Lower/Support/Utils.cpp22
-rw-r--r--flang/lib/Optimizer/Builder/CMakeLists.txt2
-rw-r--r--flang/lib/Optimizer/Builder/FIRBuilder.cpp33
-rw-r--r--flang/lib/Optimizer/Builder/HLFIRTools.cpp5
-rw-r--r--flang/lib/Optimizer/Builder/IntrinsicCall.cpp248
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Coarray.cpp86
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp17
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Main.cpp7
-rw-r--r--flang/lib/Optimizer/CodeGen/CodeGen.cpp303
-rw-r--r--flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp49
-rw-r--r--flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp23
-rw-r--r--flang/lib/Optimizer/Dialect/FIROps.cpp1
-rw-r--r--flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp28
-rw-r--r--flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp203
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp6
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt2
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp33
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp3
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp116
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp2
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp2
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp876
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp10
-rw-r--r--flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp159
-rw-r--r--flang/lib/Optimizer/OpenMP/CMakeLists.txt2
-rw-r--r--flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp70
-rw-r--r--flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp5
-rw-r--r--flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp1
-rw-r--r--flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp21
-rw-r--r--flang/lib/Optimizer/OpenMP/SimdOnly.cpp209
-rw-r--r--flang/lib/Optimizer/Passes/Pipelines.cpp34
-rw-r--r--flang/lib/Optimizer/Support/Utils.cpp71
-rw-r--r--flang/lib/Optimizer/Transforms/AffineDemotion.cpp5
-rw-r--r--flang/lib/Optimizer/Transforms/AffinePromotion.cpp33
-rw-r--r--flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp3
-rw-r--r--flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp17
-rw-r--r--flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp32
-rw-r--r--flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp37
-rw-r--r--flang/lib/Optimizer/Transforms/FIRToSCF.cpp155
-rw-r--r--flang/lib/Optimizer/Transforms/FunctionAttr.cpp4
-rw-r--r--flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp19
-rw-r--r--flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp24
-rw-r--r--flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp8
-rw-r--r--flang/lib/Optimizer/Transforms/StackArrays.cpp9
-rw-r--r--flang/lib/Parser/CMakeLists.txt1
-rw-r--r--flang/lib/Parser/characters.cpp3
-rw-r--r--flang/lib/Parser/openmp-parsers.cpp120
-rw-r--r--flang/lib/Parser/openmp-utils.cpp64
-rw-r--r--flang/lib/Parser/parsing.cpp3
-rw-r--r--flang/lib/Parser/preprocessor.cpp53
-rw-r--r--flang/lib/Parser/prescan.h9
-rw-r--r--flang/lib/Parser/unparse.cpp52
-rw-r--r--flang/lib/Semantics/check-acc-structure.cpp38
-rw-r--r--flang/lib/Semantics/check-allocate.cpp2
-rw-r--r--flang/lib/Semantics/check-call.cpp224
-rw-r--r--flang/lib/Semantics/check-declarations.cpp58
-rw-r--r--flang/lib/Semantics/check-omp-atomic.cpp600
-rw-r--r--flang/lib/Semantics/check-omp-loop.cpp4
-rw-r--r--flang/lib/Semantics/check-omp-metadirective.cpp3
-rw-r--r--flang/lib/Semantics/check-omp-structure.cpp661
-rw-r--r--flang/lib/Semantics/check-omp-structure.h21
-rw-r--r--flang/lib/Semantics/check-select-rank.cpp2
-rw-r--r--flang/lib/Semantics/check-select-type.cpp2
-rw-r--r--flang/lib/Semantics/compute-offsets.cpp4
-rw-r--r--flang/lib/Semantics/data-to-inits.cpp9
-rw-r--r--flang/lib/Semantics/expression.cpp66
-rw-r--r--flang/lib/Semantics/openmp-utils.cpp2
-rw-r--r--flang/lib/Semantics/openmp-utils.h81
-rw-r--r--flang/lib/Semantics/pointer-assignment.cpp2
-rw-r--r--flang/lib/Semantics/resolve-directives.cpp461
-rw-r--r--flang/lib/Semantics/resolve-names.cpp151
-rw-r--r--flang/lib/Semantics/rewrite-parse-tree.cpp268
-rw-r--r--flang/lib/Semantics/symbol.cpp2
-rw-r--r--flang/lib/Semantics/tools.cpp24
-rw-r--r--flang/lib/Semantics/unparse-with-symbols.cpp19
-rw-r--r--flang/lib/Support/Fortran-features.cpp2
-rw-r--r--flang/lib/Support/Fortran.cpp7
-rw-r--r--flang/lib/Utils/CMakeLists.txt20
-rw-r--r--flang/lib/Utils/OpenMP.cpp47
126 files changed, 5718 insertions, 2774 deletions
diff --git a/flang/lib/CMakeLists.txt b/flang/lib/CMakeLists.txt
index 8b201d9..528e7b5 100644
--- a/flang/lib/CMakeLists.txt
+++ b/flang/lib/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(Semantics)
add_subdirectory(Support)
add_subdirectory(Frontend)
add_subdirectory(FrontendTool)
+add_subdirectory(Utils)
add_subdirectory(Optimizer)
diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp
index 8954773..37c62c9 100644
--- a/flang/lib/Evaluate/characteristics.cpp
+++ b/flang/lib/Evaluate/characteristics.cpp
@@ -400,7 +400,7 @@ bool DummyDataObject::IsCompatibleWith(const DummyDataObject &actual,
}
if (!attrs.test(Attr::Value) &&
!common::AreCompatibleCUDADataAttrs(cudaDataAttr, actual.cudaDataAttr,
- ignoreTKR, warning,
+ ignoreTKR,
/*allowUnifiedMatchingRule=*/false,
/*=isHostDeviceProcedure*/ false)) {
if (whyNot) {
@@ -1816,7 +1816,7 @@ bool DistinguishUtils::Distinguishable(
x.intent != common::Intent::In) {
return true;
} else if (!common::AreCompatibleCUDADataAttrs(x.cudaDataAttr, y.cudaDataAttr,
- x.ignoreTKR | y.ignoreTKR, nullptr,
+ x.ignoreTKR | y.ignoreTKR,
/*allowUnifiedMatchingRule=*/false,
/*=isHostDeviceProcedure*/ false)) {
return true;
diff --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp
index 3d7f01d..394a033 100644
--- a/flang/lib/Evaluate/check-expression.cpp
+++ b/flang/lib/Evaluate/check-expression.cpp
@@ -405,6 +405,88 @@ bool IsInitialProcedureTarget(const Expr<SomeType> &expr) {
}
}
+// Expression visitor that reports REAL and COMPLEX constants that carry the
+// "from inexact literal conversion" flag and whose kind is smaller than the
+// kind of the context in which they appear.  Each such constant produces a
+// RealConstantWidening usage warning and makes the visit yield true.
+class SuspiciousRealLiteralFinder
+    : public AnyTraverse<SuspiciousRealLiteralFinder> {
+public:
+  using Base = AnyTraverse<SuspiciousRealLiteralFinder>;
+  // "kind" is the kind of the REAL/COMPLEX context being checked; warnings
+  // are emitted through "c".
+  SuspiciousRealLiteralFinder(int kind, FoldingContext &c)
+      : Base{*this}, kind_{kind}, context_{c} {}
+  using Base::operator();
+
+  // REAL constant: suspicious only when it is narrower than the context and
+  // flagged as the product of an inexact literal conversion.
+  template <int KIND>
+  bool operator()(const Constant<Type<TypeCategory::Real, KIND>> &x) const {
+    if (kind_ <= KIND || !x.result().isFromInexactLiteralConversion()) {
+      return false;
+    }
+    context_.Warn(common::UsageWarning::RealConstantWidening,
+        "Default real literal in REAL(%d) context might need a kind suffix, as its rounded value %s is inexact"_warn_en_US,
+        kind_, x.AsFortran());
+    return true;
+  }
+
+  // COMPLEX constant: same test as for REAL, with a COMPLEX-specific message.
+  template <int KIND>
+  bool operator()(const Constant<Type<TypeCategory::Complex, KIND>> &x) const {
+    if (kind_ <= KIND || !x.result().isFromInexactLiteralConversion()) {
+      return false;
+    }
+    context_.Warn(common::UsageWarning::RealConstantWidening,
+        "Default real literal in COMPLEX(%d) context might need a kind suffix, as its rounded value %s is inexact"_warn_en_US,
+        kind_, x.AsFortran());
+    return true;
+  }
+
+  // Conversion: for a REAL/COMPLEX to REAL/COMPLEX conversion, the operand is
+  // not inspected when its type is unknown or its kind is smaller than the
+  // target kind.  Every other conversion is searched through its operand.
+  template <TypeCategory TOCAT, int TOKIND, TypeCategory FROMCAT>
+  bool operator()(const Convert<Type<TOCAT, TOKIND>, FROMCAT> &x) const {
+    constexpr bool toIsFloating{
+        TOCAT == TypeCategory::Real || TOCAT == TypeCategory::Complex};
+    constexpr bool fromIsFloating{
+        FROMCAT == TypeCategory::Real || FROMCAT == TypeCategory::Complex};
+    if constexpr (toIsFloating && fromIsFloating) {
+      const auto operandType{x.left().GetType()};
+      if (!operandType || operandType->kind() < TOKIND) {
+        return false;
+      }
+    }
+    return (*this)(x.left());
+  }
+
+private:
+  int kind_; // kind of the context being checked
+  FoldingContext &context_; // destination for warnings
+};
+
+// Runs SuspiciousRealLiteralFinder over "expr" when both the expression and
+// the destination type "toType" are REAL or COMPLEX and the destination kind
+// is strictly larger than the kind of the expression.  Does nothing when the
+// type of the expression is unknown.
+void CheckRealWidening(const Expr<SomeType> &expr, const DynamicType &toType,
+    FoldingContext &context) {
+  const auto isFloating{[](TypeCategory category) {
+    return category == TypeCategory::Real || category == TypeCategory::Complex;
+  }};
+  if (!isFloating(toType.category())) {
+    return;
+  }
+  const auto fromType{expr.GetType()};
+  if (!fromType || !isFloating(fromType->category())) {
+    return;
+  }
+  if (toType.kind() > fromType->kind()) {
+    SuspiciousRealLiteralFinder{toType.kind(), context}(expr);
+  }
+}
+
+// Convenience overload: performs the widening check only when a destination
+// type is actually present.
+void CheckRealWidening(const Expr<SomeType> &expr,
+    const std::optional<DynamicType> &toType, FoldingContext &context) {
+  if (!toType) {
+    return;
+  }
+  CheckRealWidening(expr, *toType, context);
+}
+
+// Expression visitor that resets the "from inexact literal conversion" flag
+// on each REAL constant it is applied to.  Only REAL constants are handled
+// by this class.
+class InexactLiteralConversionFlagClearer
+    : public AnyTraverse<InexactLiteralConversionFlagClearer> {
+public:
+  using Base = AnyTraverse<InexactLiteralConversionFlagClearer>;
+  InexactLiteralConversionFlagClearer() : Base(*this) {}
+  using Base::operator();
+  template <int KIND>
+  bool operator()(const Constant<Type<TypeCategory::Real, KIND>> &x) const {
+    using RealType = Type<TypeCategory::Real, KIND>;
+    // The constant is received by const reference, so constness of its
+    // result type information is cast away in order to clear the flag.
+    const_cast<RealType &>(x.result())
+        .set_isFromInexactLiteralConversion(false);
+    // NOTE(review): presumably false keeps the traversal from stopping at
+    // this constant -- confirm against AnyTraverse.
+    return false;
+  }
+};
+
// Converts, folds, and then checks type, rank, and shape of an
// initialization expression for a named constant, a non-pointer
// variable static initialization, a component default initializer,
@@ -416,16 +498,14 @@ std::optional<Expr<SomeType>> NonPointerInitializationExpr(const Symbol &symbol,
if (auto symTS{
characteristics::TypeAndShape::Characterize(symbol, context)}) {
auto xType{x.GetType()};
+ CheckRealWidening(x, symTS->type(), context);
auto converted{ConvertToType(symTS->type(), Expr<SomeType>{x})};
if (!converted &&
symbol.owner().context().IsEnabled(
common::LanguageFeature::LogicalIntegerAssignment)) {
converted = DataConstantConversionExtension(context, symTS->type(), x);
- if (converted &&
- symbol.owner().context().ShouldWarn(
- common::LanguageFeature::LogicalIntegerAssignment)) {
- context.messages().Say(
- common::LanguageFeature::LogicalIntegerAssignment,
+ if (converted) {
+ context.Warn(common::LanguageFeature::LogicalIntegerAssignment,
"nonstandard usage: initialization of %s with %s"_port_en_US,
symTS->type().AsFortran(), x.GetType().value().AsFortran());
}
@@ -433,6 +513,7 @@ std::optional<Expr<SomeType>> NonPointerInitializationExpr(const Symbol &symbol,
if (converted) {
auto folded{Fold(context, std::move(*converted))};
if (IsActuallyConstant(folded)) {
+ InexactLiteralConversionFlagClearer{}(folded);
int symRank{symTS->Rank()};
if (IsImpliedShape(symbol)) {
if (folded.Rank() == symRank) {
@@ -579,10 +660,8 @@ public:
// host-associated dummy argument, and that doesn't seem like a
// good idea.
if (!inInquiry_ && hasHostAssociation &&
- ultimate.attrs().test(semantics::Attr::INTENT_OUT) &&
- context_.languageFeatures().ShouldWarn(
- common::UsageWarning::HostAssociatedIntentOutInSpecExpr)) {
- context_.messages().Say(
+ ultimate.attrs().test(semantics::Attr::INTENT_OUT)) {
+ context_.Warn(common::UsageWarning::HostAssociatedIntentOutInSpecExpr,
"specification expression refers to host-associated INTENT(OUT) dummy argument '%s'"_port_en_US,
ultimate.name());
}
@@ -593,13 +672,9 @@ public:
} else if (isInitialized &&
context_.languageFeatures().IsEnabled(
common::LanguageFeature::SavedLocalInSpecExpr)) {
- if (!scope_.IsModuleFile() &&
- context_.languageFeatures().ShouldWarn(
- common::LanguageFeature::SavedLocalInSpecExpr)) {
- context_.messages().Say(common::LanguageFeature::SavedLocalInSpecExpr,
- "specification expression refers to local object '%s' (initialized and saved)"_port_en_US,
- ultimate.name());
- }
+ context_.Warn(common::LanguageFeature::SavedLocalInSpecExpr,
+ "specification expression refers to local object '%s' (initialized and saved)"_port_en_US,
+ ultimate.name());
return std::nullopt;
} else if (const auto *object{
ultimate.detailsIf<semantics::ObjectEntityDetails>()}) {
@@ -917,8 +992,8 @@ public:
} else {
return Base::operator()(ultimate); // use expr
}
- } else if (semantics::IsPointer(ultimate) ||
- semantics::IsAssumedShape(ultimate) || IsAssumedRank(ultimate)) {
+ } else if (semantics::IsPointer(ultimate) || IsAssumedShape(ultimate) ||
+ IsAssumedRank(ultimate)) {
return std::nullopt;
} else if (ultimate.has<semantics::ObjectEntityDetails>()) {
return true;
@@ -1198,9 +1273,21 @@ std::optional<bool> IsContiguous(const A &x, FoldingContext &context,
}
}
+// Determines the contiguity of an actual argument from its expression.
+//
+// Returns false when the argument has no expression to unwrap; otherwise
+// returns exactly what IsContiguous() reports for that expression: true,
+// false, or std::nullopt when contiguity cannot be determined.
+std::optional<bool> IsContiguous(const ActualArgument &actual,
+    FoldingContext &fc, bool namedConstantSectionsAreContiguous,
+    bool firstDimensionStride1) {
+  const auto *expr{actual.UnwrapExpr()};
+  if (!expr) {
+    return false;
+  }
+  // Return the optional itself.  Combining it with "expr &&" would
+  // contextually convert std::optional<bool> to bool, which tests whether a
+  // value is present rather than what the value is; a known-discontiguous
+  // expression would then be reported as contiguous.
+  return IsContiguous(
+      *expr, fc, namedConstantSectionsAreContiguous, firstDimensionStride1);
+}
+
template std::optional<bool> IsContiguous(const Expr<SomeType> &,
FoldingContext &, bool namedConstantSectionsAreContiguous,
bool firstDimensionStride1);
+template std::optional<bool> IsContiguous(const ActualArgument &,
+ FoldingContext &, bool namedConstantSectionsAreContiguous,
+ bool firstDimensionStride1);
template std::optional<bool> IsContiguous(const ArrayRef &, FoldingContext &,
bool namedConstantSectionsAreContiguous, bool firstDimensionStride1);
template std::optional<bool> IsContiguous(const Substring &, FoldingContext &,
@@ -1350,4 +1437,177 @@ std::optional<parser::Message> CheckStatementFunction(
return StmtFunctionChecker{sf, context}(expr);
}
+// Bundles an actual argument with the dummy data object it is associated
+// with and answers questions about how the two differ.  Holds references
+// only; the referenced objects must outlive this helper.
+class CopyInOutExplicitInterface {
+public:
+  explicit CopyInOutExplicitInterface(FoldingContext &fc,
+      const ActualArgument &actual,
+      const characteristics::DummyDataObject &dummyObj)
+      : fc_{fc}, actual_{actual}, dummyObj_{dummyObj} {}
+
+  // True when the dummy requires contiguity that the actual does not provide.
+  bool HaveContiguityDifferences() const {
+    const auto &ignored{dummyObj_.ignoreTKR};
+    // Nothing to reconcile when the dummy ignores contiguity or the actual
+    // is simply contiguous.
+    if (ignored.test(common::IgnoreTKR::Contiguous) ||
+        IsSimplyContiguous(actual_, fc_)) {
+      return false;
+    }
+    const auto &dummyType{dummyObj_.type};
+    const bool rankIsIgnored{ignored.test(common::IgnoreTKR::Rank)};
+    // type(*) with IGNORE_TKR(tkr) is often used to interface with C "void*".
+    // Other languages know nothing about Fortran's discontiguity handling,
+    // so such dummies are treated as requiring contiguity.
+    const bool looksLikeVoidStar{dummyType.type().IsAssumedType() &&
+        ignored.test(common::IgnoreTKR::Type) &&
+        ignored.test(common::IgnoreTKR::Rank) &&
+        ignored.test(common::IgnoreTKR::Kind)};
+    // Explicit-shape and assumed-size dummy arrays must be contiguous, as
+    // must dummies with the CONTIGUOUS attribute.
+    return dummyType.IsExplicitShape() ||
+        dummyType.attrs().test(
+            characteristics::TypeAndShape::Attr::AssumedSize) ||
+        (rankIsIgnored && !dummyType.type().IsPolymorphic()) ||
+        looksLikeVoidStar ||
+        dummyObj_.attrs.test(
+            characteristics::DummyDataObject::Attr::Contiguous);
+  }
+
+  // True when a polymorphic actual is passed to a non-polymorphic dummy in a
+  // situation where that matters.
+  bool HavePolymorphicDifferences() const {
+    const auto &dummyType{dummyObj_.type};
+    const bool bothAssumedRank{semantics::IsAssumedRank(actual_) &&
+        dummyType.attrs().test(
+            characteristics::TypeAndShape::Attr::AssumedRank)};
+    const bool bothAssumedShape{semantics::IsAssumedShape(actual_) &&
+        dummyType.attrs().test(
+            characteristics::TypeAndShape::Attr::AssumedShape)};
+    if (bothAssumedRank || bothAssumedShape) {
+      // Assumed-rank and assumed-shape arrays are represented by
+      // descriptors, so no polymorphic check is needed.
+      return false;
+    }
+    if (dummyObj_.ignoreTKR.test(common::IgnoreTKR::Type)) {
+      return false;
+    }
+    // flang supports limited cases of passing polymorphic to non-polymorphic.
+    // These cases require a temporary of non-polymorphic type.  (For example,
+    // the actual argument could be a polymorphic array of a child type while
+    // the dummy argument is a non-polymorphic array of the parent type.)
+    const bool dummyIsPolymorphic{dummyType.type().IsPolymorphic()};
+    const auto actualType{
+        characteristics::TypeAndShape::Characterize(actual_, fc_)};
+    const bool actualIsPolymorphic{
+        actualType && actualType->type().IsPolymorphic()};
+    return actualIsPolymorphic && !dummyIsPolymorphic;
+  }
+
+  // True when the actual is an array or assumed-rank object and the dummy is
+  // too, or the dummy ignores rank.
+  bool HaveArrayOrAssumedRankArgs() const {
+    if (!IsArrayOrAssumedRank(actual_)) {
+      return false;
+    }
+    return IsArrayOrAssumedRank(dummyObj_) ||
+        dummyObj_.ignoreTKR.test(common::IgnoreTKR::Rank);
+  }
+
+  // True when the dummy has the VALUE attribute.
+  bool PassByValue() const {
+    return dummyObj_.attrs.test(characteristics::DummyDataObject::Attr::Value);
+  }
+
+  // True when a coarray reference can be extracted from the actual while the
+  // dummy has corank zero.
+  bool HaveCoarrayDifferences() const {
+    return ExtractCoarrayRef(actual_) && dummyObj_.type.corank() == 0;
+  }
+
+  // True when the dummy is INTENT(OUT).
+  bool HasIntentOut() const { return dummyObj_.intent == common::Intent::Out; }
+
+  // True when the dummy is INTENT(IN).
+  bool HasIntentIn() const { return dummyObj_.intent == common::Intent::In; }
+
+  // True when the actual argument is assumed-rank or has nonzero rank.
+  static bool IsArrayOrAssumedRank(const ActualArgument &actual) {
+    return semantics::IsAssumedRank(actual) || actual.Rank() > 0;
+  }
+
+  // True when the dummy data object is assumed-rank or has nonzero rank.
+  static bool IsArrayOrAssumedRank(
+      const characteristics::DummyDataObject &dummy) {
+    const auto &dummyType{dummy.type};
+    return dummyType.attrs().test(
+               characteristics::TypeAndShape::Attr::AssumedRank) ||
+        dummyType.Rank() > 0;
+  }
+
+private:
+  FoldingContext &fc_;
+  const ActualArgument &actual_;
+  const characteristics::DummyDataObject &dummyObj_;
+};
+
+// Reports whether an actual/dummy argument combination may require a
+// temporary.  With forCopyOut false the question is about the copy-in
+// operation that fills the temporary; with forCopyOut true it is about the
+// copy-out operation that writes it back.  For procedures with an explicit
+// interface "dummy" is expected to be non-null; for procedures with an
+// implicit interface it may be null.
+//
+// Both checks are made from the caller's perspective: for copy-in the caller
+// needs to do the copy before calling the callee, and for copy-out the
+// caller is expected to do the copy after the callee returns.
+bool MayNeedCopy(const ActualArgument *actual,
+    const characteristics::DummyArgument *dummy, FoldingContext &fc,
+    bool forCopyOut) {
+  // Absent arguments and alternate returns never involve a copy.
+  if (!actual || actual->isAlternateReturn()) {
+    return false;
+  }
+  const bool forCopyIn{!forCopyOut};
+  if (!evaluate::IsVariable(*actual)) {
+    // Actual argument expressions that are not variables are copied in but
+    // never copied out.
+    return forCopyIn;
+  }
+  const characteristics::DummyDataObject *dummyObj{nullptr};
+  if (dummy) {
+    dummyObj = std::get_if<characteristics::DummyDataObject>(&dummy->u);
+  }
+
+  if (!dummyObj) { // Implicit interface
+    if (ExtractCoarrayRef(*actual)) {
+      // Coindexed actual arguments may need copy-in and copy-out with an
+      // implicit interface.
+      return true;
+    }
+    if (IsSimplyContiguous(*actual, fc)) {
+      return false;
+    }
+    // Copy-in: variables are copied in when they are not contiguous.
+    // Copy-out: vector subscripts could refer to duplicate elements, so the
+    // data cannot be copied back out.
+    return forCopyIn || !HasVectorSubscript(*actual);
+  }
+
+  // Explicit interface
+  const CopyInOutExplicitInterface check{fc, *actual, *dummyObj};
+  // INTENT(IN) dummies never need copy-out; INTENT(OUT) dummies never need
+  // copy-in.
+  if (forCopyOut ? check.HasIntentIn() : check.HasIntentOut()) {
+    return false;
+  }
+  if (check.PassByValue()) {
+    // Pass by value: always copy-in, never copy-out.
+    return forCopyIn;
+  }
+  if (check.HaveCoarrayDifferences()) {
+    return true;
+  }
+  // The contiguity and polymorphism checks only apply to array or
+  // assumed-rank arguments.
+  if (!check.HaveArrayOrAssumedRankArgs()) {
+    return false;
+  }
+  return check.HaveContiguityDifferences() ||
+      check.HavePolymorphicDifferences();
+}
+
} // namespace Fortran::evaluate
diff --git a/flang/lib/Evaluate/common.cpp b/flang/lib/Evaluate/common.cpp
index 6a960d4..46c75a5 100644
--- a/flang/lib/Evaluate/common.cpp
+++ b/flang/lib/Evaluate/common.cpp
@@ -16,26 +16,22 @@ namespace Fortran::evaluate {
void RealFlagWarnings(
FoldingContext &context, const RealFlags &flags, const char *operation) {
static constexpr auto warning{common::UsageWarning::FoldingException};
- if (context.languageFeatures().ShouldWarn(warning)) {
- if (flags.test(RealFlag::Overflow)) {
- context.messages().Say(warning, "overflow on %s"_warn_en_US, operation);
- }
- if (flags.test(RealFlag::DivideByZero)) {
- if (std::strcmp(operation, "division") == 0) {
- context.messages().Say(warning, "division by zero"_warn_en_US);
- } else {
- context.messages().Say(
- warning, "division by zero on %s"_warn_en_US, operation);
- }
- }
- if (flags.test(RealFlag::InvalidArgument)) {
- context.messages().Say(
- warning, "invalid argument on %s"_warn_en_US, operation);
- }
- if (flags.test(RealFlag::Underflow)) {
- context.messages().Say(warning, "underflow on %s"_warn_en_US, operation);
+ if (flags.test(RealFlag::Overflow)) {
+ context.Warn(warning, "overflow on %s"_warn_en_US, operation);
+ }
+ if (flags.test(RealFlag::DivideByZero)) {
+ if (std::strcmp(operation, "division") == 0) {
+ context.Warn(warning, "division by zero"_warn_en_US);
+ } else {
+ context.Warn(warning, "division by zero on %s"_warn_en_US, operation);
}
}
+ if (flags.test(RealFlag::InvalidArgument)) {
+ context.Warn(warning, "invalid argument on %s"_warn_en_US, operation);
+ }
+ if (flags.test(RealFlag::Underflow)) {
+ context.Warn(warning, "underflow on %s"_warn_en_US, operation);
+ }
}
ConstantSubscript &FoldingContext::StartImpliedDo(
diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 76ac497..a43742a 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -58,13 +58,10 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
return FoldElementalIntrinsic<T, IntT>(context, std::move(funcRef),
ScalarFunc<T, IntT>([&](const Scalar<IntT> &i) {
if (i.IsNegative() || i.BGE(Scalar<IntT>{0}.IBSET(8 * KIND))) {
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingValueChecks)) {
- context.messages().Say(common::UsageWarning::FoldingValueChecks,
- "%s(I=%jd) is out of range for CHARACTER(KIND=%d)"_warn_en_US,
- parser::ToUpperCaseLetters(name),
- static_cast<std::intmax_t>(i.ToInt64()), KIND);
- }
+ context.Warn(common::UsageWarning::FoldingValueChecks,
+ "%s(I=%jd) is out of range for CHARACTER(KIND=%d)"_warn_en_US,
+ parser::ToUpperCaseLetters(name),
+ static_cast<std::intmax_t>(i.ToInt64()), KIND);
}
return CharacterUtils<KIND>::CHAR(i.ToUInt64());
}));
@@ -106,12 +103,9 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
static_cast<std::intmax_t>(n));
} else if (static_cast<double>(n) * str.size() >
(1 << 20)) { // sanity limit of 1MiB
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingLimit)) {
- context.messages().Say(common::UsageWarning::FoldingLimit,
- "Result of REPEAT() is too large to compute at compilation time (%g characters)"_port_en_US,
- static_cast<double>(n) * str.size());
- }
+ context.Warn(common::UsageWarning::FoldingLimit,
+ "Result of REPEAT() is too large to compute at compilation time (%g characters)"_port_en_US,
+ static_cast<double>(n) * str.size());
} else {
return Expr<T>{Constant<T>{CharacterUtils<KIND>::REPEAT(str, n)}};
}
diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index 3eb8e1f..84066ee 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -29,9 +29,8 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) {
return FoldElementalIntrinsic<T, T>(
context, std::move(funcRef), *callable);
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
+ } else {
+ context.Warn(common::UsageWarning::FoldingFailure,
"%s(complex(kind=%d)) cannot be folded on host"_warn_en_US, name,
KIND);
}
@@ -83,12 +82,21 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(
if (auto array{ApplyElementwise(context, x)}) {
return *array;
}
- using Result = Type<TypeCategory::Complex, KIND>;
+ using ComplexType = Type<TypeCategory::Complex, KIND>;
if (auto folded{OperandsAreConstants(x)}) {
- return Expr<Result>{
- Constant<Result>{Scalar<Result>{folded->first, folded->second}}};
+ using RealType = typename ComplexType::Part;
+ Constant<ComplexType> result{
+ Scalar<ComplexType>{folded->first, folded->second}};
+ if (const auto *re{UnwrapConstantValue<RealType>(x.left())};
+ re && re->result().isFromInexactLiteralConversion()) {
+ result.result().set_isFromInexactLiteralConversion();
+ } else if (const auto *im{UnwrapConstantValue<RealType>(x.right())};
+ im && im->result().isFromInexactLiteralConversion()) {
+ result.result().set_isFromInexactLiteralConversion();
+ }
+ return Expr<ComplexType>{std::move(result)};
}
- return Expr<Result>{std::move(x)};
+ return Expr<ComplexType>{std::move(x)};
}
#ifdef _MSC_VER // disable bogus warning about missing definitions
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 52e954d..3fdf3a6 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -1321,8 +1321,8 @@ public:
*charLength_, std::move(elements_), ConstantSubscripts{n}}};
}
} else {
- return Expr<T>{
- Constant<T>{std::move(elements_), ConstantSubscripts{n}}};
+ return Expr<T>{Constant<T>{
+ std::move(elements_), ConstantSubscripts{n}, resultInfo_}};
}
}
return Expr<T>{std::move(array)};
@@ -1343,6 +1343,11 @@ private:
if (!knownCharLength_) {
charLength_ = std::max(c->LEN(), charLength_.value_or(-1));
}
+ } else if constexpr (T::category == TypeCategory::Real ||
+ T::category == TypeCategory::Complex) {
+ if (c->result().isFromInexactLiteralConversion()) {
+ resultInfo_.set_isFromInexactLiteralConversion();
+ }
}
return true;
} else {
@@ -1395,6 +1400,7 @@ private:
std::vector<Scalar<T>> elements_;
std::optional<ConstantSubscript> charLength_;
bool knownCharLength_{false};
+ typename Constant<T>::Result resultInfo_;
};
template <typename T>
@@ -1779,7 +1785,7 @@ common::IfNoLvalue<std::optional<TO>, FROM> ConvertString(FROM &&s) {
if (static_cast<std::uint64_t>(*iter) > 127) {
return std::nullopt;
}
- str.push_back(*iter);
+ str.push_back(static_cast<typename TO::value_type>(*iter));
}
return std::make_optional<TO>(std::move(str));
}
@@ -1808,10 +1814,8 @@ Expr<TO> FoldOperation(
if constexpr (TO::category == TypeCategory::Integer) {
if constexpr (FromCat == TypeCategory::Integer) {
auto converted{Scalar<TO>::ConvertSigned(*value)};
- if (converted.overflow &&
- msvcWorkaround.context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- ctx.messages().Say(common::UsageWarning::FoldingException,
+ if (converted.overflow) {
+ ctx.Warn(common::UsageWarning::FoldingException,
"conversion of %s_%d to INTEGER(%d) overflowed; result is %s"_warn_en_US,
value->SignedDecimal(), Operand::kind, TO::kind,
converted.value.SignedDecimal());
@@ -1819,10 +1823,8 @@ Expr<TO> FoldOperation(
return ScalarConstantToExpr(std::move(converted.value));
} else if constexpr (FromCat == TypeCategory::Unsigned) {
auto converted{Scalar<TO>::ConvertUnsigned(*value)};
- if ((converted.overflow || converted.value.IsNegative()) &&
- msvcWorkaround.context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- ctx.messages().Say(common::UsageWarning::FoldingException,
+ if ((converted.overflow || converted.value.IsNegative())) {
+ ctx.Warn(common::UsageWarning::FoldingException,
"conversion of %s_U%d to INTEGER(%d) overflowed; result is %s"_warn_en_US,
value->UnsignedDecimal(), Operand::kind, TO::kind,
converted.value.SignedDecimal());
@@ -1830,17 +1832,14 @@ Expr<TO> FoldOperation(
return ScalarConstantToExpr(std::move(converted.value));
} else if constexpr (FromCat == TypeCategory::Real) {
auto converted{value->template ToInteger<Scalar<TO>>()};
- if (msvcWorkaround.context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- if (converted.flags.test(RealFlag::InvalidArgument)) {
- ctx.messages().Say(common::UsageWarning::FoldingException,
- "REAL(%d) to INTEGER(%d) conversion: invalid argument"_warn_en_US,
- Operand::kind, TO::kind);
- } else if (converted.flags.test(RealFlag::Overflow)) {
- ctx.messages().Say(
- "REAL(%d) to INTEGER(%d) conversion overflowed"_warn_en_US,
- Operand::kind, TO::kind);
- }
+ if (converted.flags.test(RealFlag::InvalidArgument)) {
+ ctx.Warn(common::UsageWarning::FoldingException,
+ "REAL(%d) to INTEGER(%d) conversion: invalid argument"_warn_en_US,
+ Operand::kind, TO::kind);
+ } else if (converted.flags.test(RealFlag::Overflow)) {
+ ctx.Warn(common::UsageWarning::FoldingException,
+ "REAL(%d) to INTEGER(%d) conversion overflowed"_warn_en_US,
+ Operand::kind, TO::kind);
}
return ScalarConstantToExpr(std::move(converted.value));
}
@@ -1960,10 +1959,8 @@ Expr<T> FoldOperation(FoldingContext &context, Negate<T> &&x) {
} else if (auto value{GetScalarConstantValue<T>(operand)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto negated{value->Negate()};
- if (negated.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (negated.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"INTEGER(%d) negation overflowed"_warn_en_US, T::kind);
}
return Expr<T>{Constant<T>{std::move(negated.value)}};
@@ -2004,10 +2001,8 @@ Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) {
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto sum{folded->first.AddSigned(folded->second)};
- if (sum.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (sum.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"INTEGER(%d) addition overflowed"_warn_en_US, T::kind);
}
return Expr<T>{Constant<T>{sum.value}};
@@ -2035,10 +2030,8 @@ Expr<T> FoldOperation(FoldingContext &context, Subtract<T> &&x) {
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto difference{folded->first.SubtractSigned(folded->second)};
- if (difference.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (difference.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"INTEGER(%d) subtraction overflowed"_warn_en_US, T::kind);
}
return Expr<T>{Constant<T>{difference.value}};
@@ -2066,10 +2059,8 @@ Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) {
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto product{folded->first.MultiplySigned(folded->second)};
- if (product.SignedMultiplicationOverflowed() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (product.SignedMultiplicationOverflowed()) {
+ context.Warn(common::UsageWarning::FoldingException,
"INTEGER(%d) multiplication overflowed"_warn_en_US, T::kind);
}
return Expr<T>{Constant<T>{product.lower}};
@@ -2116,28 +2107,20 @@ Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) {
if constexpr (T::category == TypeCategory::Integer) {
auto quotAndRem{folded->first.DivideSigned(folded->second)};
if (quotAndRem.divisionByZero) {
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
- "INTEGER(%d) division by zero"_warn_en_US, T::kind);
- }
+ context.Warn(common::UsageWarning::FoldingException,
+ "INTEGER(%d) division by zero"_warn_en_US, T::kind);
return Expr<T>{std::move(x)};
}
- if (quotAndRem.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (quotAndRem.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"INTEGER(%d) division overflowed"_warn_en_US, T::kind);
}
return Expr<T>{Constant<T>{quotAndRem.quotient}};
} else if constexpr (T::category == TypeCategory::Unsigned) {
auto quotAndRem{folded->first.DivideUnsigned(folded->second)};
if (quotAndRem.divisionByZero) {
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
- "UNSIGNED(%d) division by zero"_warn_en_US, T::kind);
- }
+ context.Warn(common::UsageWarning::FoldingException,
+ "UNSIGNED(%d) division by zero"_warn_en_US, T::kind);
return Expr<T>{std::move(x)};
}
return Expr<T>{Constant<T>{quotAndRem.quotient}};
@@ -2177,24 +2160,21 @@ Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
if (auto folded{OperandsAreConstants(x)}) {
if constexpr (T::category == TypeCategory::Integer) {
auto power{folded->first.Power(folded->second)};
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- if (power.divisionByZero) {
- context.messages().Say(common::UsageWarning::FoldingException,
- "INTEGER(%d) zero to negative power"_warn_en_US, T::kind);
- } else if (power.overflow) {
- context.messages().Say(common::UsageWarning::FoldingException,
- "INTEGER(%d) power overflowed"_warn_en_US, T::kind);
- } else if (power.zeroToZero) {
- context.messages().Say(common::UsageWarning::FoldingException,
- "INTEGER(%d) 0**0 is not defined"_warn_en_US, T::kind);
- }
+ if (power.divisionByZero) {
+ context.Warn(common::UsageWarning::FoldingException,
+ "INTEGER(%d) zero to negative power"_warn_en_US, T::kind);
+ } else if (power.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
+ "INTEGER(%d) power overflowed"_warn_en_US, T::kind);
+ } else if (power.zeroToZero) {
+ context.Warn(common::UsageWarning::FoldingException,
+ "INTEGER(%d) 0**0 is not defined"_warn_en_US, T::kind);
}
return Expr<T>{Constant<T>{power.power}};
} else {
if (folded->first.IsZero()) {
if (folded->second.IsZero()) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ context.Warn(common::UsageWarning::FoldingException,
"REAL/COMPLEX 0**0 is not defined"_warn_en_US);
} else {
return Expr<T>(Constant<T>{folded->first}); // 0. ** nonzero -> 0.
@@ -2202,9 +2182,8 @@ Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
} else if (auto callable{GetHostRuntimeWrapper<T, T, T>("pow")}) {
return Expr<T>{
Constant<T>{(*callable)(context, folded->first, folded->second)}};
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
+ } else {
+ context.Warn(common::UsageWarning::FoldingFailure,
"Power for %s cannot be folded on host"_warn_en_US,
T{}.AsFortran());
}
@@ -2291,10 +2270,8 @@ Expr<Type<TypeCategory::Real, KIND>> ToReal(
CHECK(constant);
Scalar<Result> real{constant->GetScalarValue().value()};
From converted{From::ConvertUnsigned(real.RawBits()).value};
- if (original != converted &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingValueChecks)) { // C1601
- context.messages().Say(common::UsageWarning::FoldingValueChecks,
+ if (original != converted) { // C1601
+ context.Warn(common::UsageWarning::FoldingValueChecks,
"Nonzero bits truncated from BOZ literal constant in REAL intrinsic"_warn_en_US);
}
} else if constexpr (IsNumericCategoryExpr<From>()) {
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 352dec4..3628497 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -38,13 +38,13 @@ static bool CheckDimArg(const std::optional<ActualArgument> &dimArg,
const Expr<SomeType> &array, parser::ContextualMessages &messages,
bool isLBound, std::optional<int> &dimVal) {
dimVal.reset();
- if (int rank{array.Rank()}; rank > 0 || IsAssumedRank(array)) {
+ if (int rank{array.Rank()}; rank > 0 || semantics::IsAssumedRank(array)) {
auto named{ExtractNamedEntity(array)};
if (auto dim64{ToInt64(dimArg)}) {
if (*dim64 < 1) {
messages.Say("DIM=%jd dimension must be positive"_err_en_US, *dim64);
return false;
- } else if (!IsAssumedRank(array) && *dim64 > rank) {
+ } else if (!semantics::IsAssumedRank(array) && *dim64 > rank) {
messages.Say(
"DIM=%jd dimension is out of range for rank-%d array"_err_en_US,
*dim64, rank);
@@ -56,7 +56,7 @@ static bool CheckDimArg(const std::optional<ActualArgument> &dimArg,
"DIM=%jd dimension is out of range for rank-%d assumed-size array"_err_en_US,
*dim64, rank);
return false;
- } else if (IsAssumedRank(array)) {
+ } else if (semantics::IsAssumedRank(array)) {
if (*dim64 > common::maxRank) {
messages.Say(
"DIM=%jd dimension is too large for any array (maximum rank %d)"_err_en_US,
@@ -189,7 +189,7 @@ Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
return Expr<T>{std::move(funcRef)};
}
}
- if (IsAssumedRank(*array)) {
+ if (semantics::IsAssumedRank(*array)) {
// Would like to return 1 if DIM=.. is present, but that would be
// hiding a runtime error if the DIM= were too large (including
// the case of an assumed-rank argument that's scalar).
@@ -240,7 +240,7 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
return Expr<T>{std::move(funcRef)};
}
}
- if (IsAssumedRank(*array)) {
+ if (semantics::IsAssumedRank(*array)) {
} else if (int rank{array->Rank()}; rank > 0) {
bool takeBoundsFromShape{true};
if (auto named{ExtractNamedEntity(*array)}) {
@@ -350,10 +350,8 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
CountAccumulator<T, maskKind> accumulator{arrayAndMask->array};
Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
dim, Scalar<T>{}, accumulator)};
- if (accumulator.overflow() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (accumulator.overflow()) {
+ context.Warn(common::UsageWarning::FoldingException,
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
}
return Expr<T>{std::move(result)};
@@ -965,10 +963,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
auto FromInt64{[&name, &context](std::int64_t n) {
Scalar<T> result{n};
- if (result.ToInt64() != n &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (result.ToInt64() != n) {
+ context.Warn(common::UsageWarning::FoldingException,
"Result of intrinsic function '%s' (%jd) overflows its result type"_warn_en_US,
name, std::intmax_t{n});
}
@@ -979,10 +975,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> {
typename Scalar<T>::ValueWithOverflow j{i.ABS()};
- if (j.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (j.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"abs(integer(kind=%d)) folding overflowed"_warn_en_US, KIND);
}
return j.value;
@@ -999,11 +993,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
auto y{x.template ToInteger<Scalar<T>>(mode)};
- if (y.flags.test(RealFlag::Overflow) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(
- common::UsageWarning::FoldingException,
+ if (y.flags.test(RealFlag::Overflow)) {
+ context.Warn(common::UsageWarning::FoldingException,
"%s intrinsic folding overflow"_warn_en_US, name);
}
return y.value;
@@ -1029,10 +1020,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T, T>(
[&context](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
auto result{x.DIM(y)};
- if (result.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (result.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"DIM intrinsic folding overflow"_warn_en_US);
}
return result.value;
@@ -1061,14 +1050,13 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
context.messages().Say(
"Character in intrinsic function %s must have length one"_err_en_US,
name);
- } else if (len.value() > 1 &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::Portability)) {
- // Do not die, this was not checked before
- context.messages().Say(common::UsageWarning::Portability,
- "Character in intrinsic function %s should have length one"_port_en_US,
- name);
} else {
+ // Do not die, this was not checked before
+ if (len.value() > 1) {
+ context.Warn(common::UsageWarning::Portability,
+ "Character in intrinsic function %s should have length one"_port_en_US,
+ name);
+ }
return common::visit(
[&funcRef, &context, &FromInt64](const auto &str) -> Expr<T> {
using Char = typename std::decay_t<decltype(str)>::Result;
@@ -1256,11 +1244,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
bool badPConst{false};
if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
*pExpr = Fold(context, std::move(*pExpr));
- if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
- pConst->IsZero() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (auto pConst{GetScalarConstantValue<T>(*pExpr)};
+ pConst && pConst->IsZero()) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"MOD: P argument is zero"_warn_en_US);
badPConst = true;
}
@@ -1270,17 +1256,12 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
[badPConst](FoldingContext &context, const Scalar<T> &x,
const Scalar<T> &y) -> Scalar<T> {
auto quotRem{x.DivideSigned(y)};
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- if (!badPConst && quotRem.divisionByZero) {
- context.messages().Say(
- common::UsageWarning::FoldingAvoidsRuntimeCrash,
- "mod() by zero"_warn_en_US);
- } else if (quotRem.overflow) {
- context.messages().Say(
- common::UsageWarning::FoldingAvoidsRuntimeCrash,
- "mod() folding overflowed"_warn_en_US);
- }
+ if (!badPConst && quotRem.divisionByZero) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ "mod() by zero"_warn_en_US);
+ } else if (quotRem.overflow) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ "mod() folding overflowed"_warn_en_US);
}
return quotRem.remainder;
}));
@@ -1288,11 +1269,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
bool badPConst{false};
if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
*pExpr = Fold(context, std::move(*pExpr));
- if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
- pConst->IsZero() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (auto pConst{GetScalarConstantValue<T>(*pExpr)};
+ pConst && pConst->IsZero()) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"MODULO: P argument is zero"_warn_en_US);
badPConst = true;
}
@@ -1302,10 +1281,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
const Scalar<T> &x,
const Scalar<T> &y) -> Scalar<T> {
auto result{x.MODULO(y)};
- if (!badPConst && result.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (!badPConst && result.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"modulo() folding overflowed"_warn_en_US);
}
return result.value;
@@ -1405,10 +1382,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T, T>([&context](const Scalar<T> &j,
const Scalar<T> &k) -> Scalar<T> {
typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)};
- if (result.overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (result.overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"sign(integer(kind=%d)) folding overflowed"_warn_en_US, KIND);
}
return result.value;
@@ -1465,11 +1440,11 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
auto realBytes{
context.targetCharacteristics().GetByteSize(TypeCategory::Real,
context.defaults().GetDefaultKind(TypeCategory::Real))};
- if (intBytes != realBytes &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingValueChecks)) {
- context.messages().Say(common::UsageWarning::FoldingValueChecks,
- *context.moduleFileName(),
+ if (intBytes != realBytes) {
+ // Using the low-level API to bypass the module file check in this case.
+ context.messages().Warn(
+ /*isInModuleFile=*/false, context.languageFeatures(),
+ common::UsageWarning::FoldingValueChecks, *context.moduleFileName(),
"NUMERIC_STORAGE_SIZE from ISO_FORTRAN_ENV is not well-defined when default INTEGER and REAL are not consistent due to compiler options"_warn_en_US);
}
return Expr<T>{8 * std::min(intBytes, realBytes)};
@@ -1496,11 +1471,9 @@ Expr<Type<TypeCategory::Unsigned, KIND>> FoldIntrinsicFunction(
bool badPConst{false};
if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
*pExpr = Fold(context, std::move(*pExpr));
- if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
- pConst->IsZero() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (auto pConst{GetScalarConstantValue<T>(*pExpr)};
+ pConst && pConst->IsZero()) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"%s: P argument is zero"_warn_en_US, name);
badPConst = true;
}
@@ -1510,13 +1483,9 @@ Expr<Type<TypeCategory::Unsigned, KIND>> FoldIntrinsicFunction(
[badPConst, &name](FoldingContext &context, const Scalar<T> &x,
const Scalar<T> &y) -> Scalar<T> {
auto quotRem{x.DivideUnsigned(y)};
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- if (!badPConst && quotRem.divisionByZero) {
- context.messages().Say(
- common::UsageWarning::FoldingAvoidsRuntimeCrash,
- "%s() by zero"_warn_en_US, name);
- }
+ if (!badPConst && quotRem.divisionByZero) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ "%s() by zero"_warn_en_US, name);
}
return quotRem.remainder;
}));
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 6950caf..c64f79e 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -530,13 +530,11 @@ static Expr<Type<TypeCategory::Logical, KIND>> RewriteOutOfRange(
if (args.size() >= 3) {
// Bounds depend on round= value
if (auto *round{UnwrapExpr<Expr<SomeType>>(args[2])}) {
- if (const Symbol * whole{UnwrapWholeSymbolDataRef(*round)};
- whole && semantics::IsOptional(whole->GetUltimate()) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::OptionalMustBePresent)) {
+ if (const Symbol *whole{UnwrapWholeSymbolDataRef(*round)};
+ whole && semantics::IsOptional(whole->GetUltimate())) {
if (auto source{args[2]->sourceLocation()}) {
- context.messages().Say(
- common::UsageWarning::OptionalMustBePresent, *source,
+ context.Warn(common::UsageWarning::OptionalMustBePresent,
+ *source,
"ROUND= argument to OUT_OF_RANGE() is an optional dummy argument that must be present at execution"_warn_en_US);
}
}
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index 9237d6e..ae9221f 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -92,10 +92,8 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
elements.push_back(sum);
}
}
- if (overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"MATMUL of %s data overflowed during computation"_warn_en_US,
T::AsFortran());
}
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6fb5249..225e340 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -35,9 +35,8 @@ static Expr<T> FoldTransformationalBessel(
}
return Expr<T>{Constant<T>{
std::move(results), ConstantSubscripts{std::max(n2 - n1 + 1, 0)}}};
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
+ } else {
+ context.Warn(common::UsageWarning::FoldingFailure,
"%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US,
name, T::kind);
}
@@ -131,10 +130,8 @@ static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
context.targetCharacteristics().roundingMode()};
Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
dim, identity, norm2Accumulator)};
- if (norm2Accumulator.overflow() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (norm2Accumulator.overflow()) {
+ context.Warn(common::UsageWarning::FoldingException,
"NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);
}
return Expr<T>{std::move(result)};
@@ -165,9 +162,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) {
return FoldElementalIntrinsic<T, T>(
context, std::move(funcRef), *callable);
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
+ } else {
+ context.Warn(common::UsageWarning::FoldingFailure,
"%s(real(kind=%d)) cannot be folded on host"_warn_en_US, name, KIND);
}
} else if (name == "amax0" || name == "amin0" || name == "amin1" ||
@@ -179,9 +175,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
if (auto callable{GetHostRuntimeWrapper<T, T, T>(localName)}) {
return FoldElementalIntrinsic<T, T, T>(
context, std::move(funcRef), *callable);
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
+ } else {
+ context.Warn(common::UsageWarning::FoldingFailure,
"%s(real(kind=%d), real(kind%d)) cannot be folded on host"_warn_en_US,
name, KIND, KIND);
}
@@ -191,9 +186,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
if (auto callable{GetHostRuntimeWrapper<T, Int4, T>(name)}) {
return FoldElementalIntrinsic<T, Int4, T>(
context, std::move(funcRef), *callable);
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
+ } else {
+ context.Warn(common::UsageWarning::FoldingFailure,
"%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US,
name, KIND);
}
@@ -210,10 +204,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, ComplexT>([&name, &context](
const Scalar<ComplexT> &z) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> y{z.ABS()};
- if (y.flags.test(RealFlag::Overflow) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (y.flags.test(RealFlag::Overflow)) {
+ context.Warn(common::UsageWarning::FoldingException,
"complex ABS intrinsic folding overflow"_warn_en_US, name);
}
return y.value;
@@ -234,10 +226,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T>(
[&name, &context, mode](const Scalar<T> &x) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> y{x.ToWholeNumber(mode)};
- if (y.flags.test(RealFlag::Overflow) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (y.flags.test(RealFlag::Overflow)) {
+ context.Warn(common::UsageWarning::FoldingException,
"%s intrinsic folding overflow"_warn_en_US, name);
}
return y.value;
@@ -247,10 +237,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T, T>([&context](const Scalar<T> &x,
const Scalar<T> &y) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> result{x.DIM(y)};
- if (result.flags.test(RealFlag::Overflow) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (result.flags.test(RealFlag::Overflow)) {
+ context.Warn(common::UsageWarning::FoldingException,
"DIM intrinsic folding overflow"_warn_en_US);
}
return result.value;
@@ -282,10 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T, T>(
[&](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
ValueWithRealFlags<Scalar<T>> result{x.HYPOT(y)};
- if (result.flags.test(RealFlag::Overflow) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (result.flags.test(RealFlag::Overflow)) {
+ context.Warn(common::UsageWarning::FoldingException,
"HYPOT intrinsic folding overflow"_warn_en_US);
}
return result.value;
@@ -307,11 +293,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
bool badPConst{false};
if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
*pExpr = Fold(context, std::move(*pExpr));
- if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
- pConst->IsZero() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (auto pConst{GetScalarConstantValue<T>(*pExpr)};
+ pConst && pConst->IsZero()) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"MOD: P argument is zero"_warn_en_US);
badPConst = true;
}
@@ -320,11 +304,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T, T>([&context, badPConst](const Scalar<T> &x,
const Scalar<T> &y) -> Scalar<T> {
auto result{x.MOD(y)};
- if (!badPConst && result.flags.test(RealFlag::DivideByZero) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(
- common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (!badPConst && result.flags.test(RealFlag::DivideByZero)) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"second argument to MOD must not be zero"_warn_en_US);
}
return result.value;
@@ -334,11 +315,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
bool badPConst{false};
if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
*pExpr = Fold(context, std::move(*pExpr));
- if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
- pConst->IsZero() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (auto pConst{GetScalarConstantValue<T>(*pExpr)};
+ pConst && pConst->IsZero()) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"MODULO: P argument is zero"_warn_en_US);
badPConst = true;
}
@@ -347,11 +326,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, T, T>([&context, badPConst](const Scalar<T> &x,
const Scalar<T> &y) -> Scalar<T> {
auto result{x.MODULO(y)};
- if (!badPConst && result.flags.test(RealFlag::DivideByZero) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
- context.messages().Say(
- common::UsageWarning::FoldingAvoidsRuntimeCrash,
+ if (!badPConst && result.flags.test(RealFlag::DivideByZero)) {
+ context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash,
"second argument to MODULO must not be zero"_warn_en_US);
}
return result.value;
@@ -363,11 +339,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
[&](const auto &sVal) {
using TS = ResultType<decltype(sVal)>;
bool badSConst{false};
- if (auto sConst{GetScalarConstantValue<TS>(sVal)}; sConst &&
- (sConst->IsZero() || sConst->IsNotANumber()) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingValueChecks)) {
- context.messages().Say(common::UsageWarning::FoldingValueChecks,
+ if (auto sConst{GetScalarConstantValue<TS>(sVal)};
+ sConst && (sConst->IsZero() || sConst->IsNotANumber())) {
+ context.Warn(common::UsageWarning::FoldingValueChecks,
"NEAREST: S argument is %s"_warn_en_US,
sConst->IsZero() ? "zero" : "NaN");
badSConst = true;
@@ -375,22 +349,15 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return FoldElementalIntrinsic<T, T, TS>(context, std::move(funcRef),
ScalarFunc<T, T, TS>([&](const Scalar<T> &x,
const Scalar<TS> &s) -> Scalar<T> {
- if (!badSConst && (s.IsZero() || s.IsNotANumber()) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingValueChecks)) {
- context.messages().Say(
- common::UsageWarning::FoldingValueChecks,
+ if (!badSConst && (s.IsZero() || s.IsNotANumber())) {
+ context.Warn(common::UsageWarning::FoldingValueChecks,
"NEAREST: S argument is %s"_warn_en_US,
s.IsZero() ? "zero" : "NaN");
}
auto result{x.NEAREST(!s.IsNegative())};
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- if (result.flags.test(RealFlag::InvalidArgument)) {
- context.messages().Say(
- common::UsageWarning::FoldingException,
- "NEAREST intrinsic folding: bad argument"_warn_en_US);
- }
+ if (result.flags.test(RealFlag::InvalidArgument)) {
+ context.Warn(common::UsageWarning::FoldingException,
+ "NEAREST intrinsic folding: bad argument"_warn_en_US);
}
return result.value;
}));
@@ -427,11 +394,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
template
#endif
SCALE<Scalar<TBY>>(y)};
- if (result.flags.test(RealFlag::Overflow) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(
- common::UsageWarning::FoldingException,
+ if (result.flags.test(RealFlag::Overflow)) {
+ context.Warn(common::UsageWarning::FoldingException,
"SCALE/IEEE_SCALB intrinsic folding overflow"_warn_en_US);
}
return result.value;
@@ -481,12 +445,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
auto yBig{Scalar<LargestReal>::Convert(y).value};
switch (xBig.Compare(yBig)) {
case Relation::Unordered:
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingValueChecks)) {
- context.messages().Say(
- common::UsageWarning::FoldingValueChecks,
- "IEEE_NEXT_AFTER intrinsic folding: arguments are unordered"_warn_en_US);
- }
+ context.Warn(common::UsageWarning::FoldingValueChecks,
+ "IEEE_NEXT_AFTER intrinsic folding: arguments are unordered"_warn_en_US);
return x.NotANumber();
case Relation::Equal:
break;
@@ -507,12 +467,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
ScalarFunc<T, T>([&](const Scalar<T> &x) -> Scalar<T> {
auto result{x.NEAREST(upward)};
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- if (result.flags.test(RealFlag::InvalidArgument)) {
- context.messages().Say(common::UsageWarning::FoldingException,
- "%s intrinsic folding: argument is NaN"_warn_en_US, iName);
- }
+ if (result.flags.test(RealFlag::InvalidArgument)) {
+ context.Warn(common::UsageWarning::FoldingException,
+ "%s intrinsic folding: argument is NaN"_warn_en_US, iName);
}
return result.value;
}));
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index b6f2d21..fe89739 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -112,10 +112,8 @@ static Expr<T> FoldDotProduct(
}
}
}
- if (overflow &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (overflow) {
+ context.Warn(common::UsageWarning::FoldingException,
"DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
T::AsFortran());
}
@@ -334,10 +332,8 @@ static Expr<T> FoldProduct(
ProductAccumulator accumulator{arrayAndMask->array};
auto result{Expr<T>{DoReduction<T>(
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
- if (accumulator.overflow() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (accumulator.overflow()) {
+ context.Warn(common::UsageWarning::FoldingException,
"PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
}
return result;
@@ -406,10 +402,8 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
arrayAndMask->array, context.targetCharacteristics().roundingMode()};
auto result{Expr<T>{DoReduction<T>(
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
- if (accumulator.overflow() &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingException)) {
- context.messages().Say(common::UsageWarning::FoldingException,
+ if (accumulator.overflow()) {
+ context.Warn(common::UsageWarning::FoldingException,
"SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
}
return result;
diff --git a/flang/lib/Evaluate/fold.cpp b/flang/lib/Evaluate/fold.cpp
index 71ead1b..1fbbbba 100644
--- a/flang/lib/Evaluate/fold.cpp
+++ b/flang/lib/Evaluate/fold.cpp
@@ -290,11 +290,8 @@ std::optional<Expr<SomeType>> FoldTransfer(
} else if (source && moldType) {
if (const auto *boz{std::get_if<BOZLiteralConstant>(&source->u)}) {
// TRANSFER(BOZ, MOLD=integer or real) extension
- if (context.languageFeatures().ShouldWarn(
- common::LanguageFeature::TransferBOZ)) {
- context.messages().Say(common::LanguageFeature::TransferBOZ,
- "TRANSFER(BOZ literal) is not standard"_port_en_US);
- }
+ context.Warn(common::LanguageFeature::TransferBOZ,
+ "TRANSFER(BOZ literal) is not standard"_port_en_US);
return Fold(context, ConvertToType(*moldType, Expr<SomeType>{*boz}));
}
}
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 121afc6..ec5dc0b 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -98,6 +98,14 @@ llvm::raw_ostream &ConstantBase<RESULT, VALUE>::AsFortran(
return o;
}
+template <typename RESULT, typename VALUE>
+std::string ConstantBase<RESULT, VALUE>::AsFortran() const {
+ std::string result;
+ llvm::raw_string_ostream sstream(result);
+ AsFortran(sstream);
+ return result;
+}
+
template <int KIND>
llvm::raw_ostream &Constant<Type<TypeCategory::Character, KIND>>::AsFortran(
llvm::raw_ostream &o) const {
@@ -126,6 +134,14 @@ llvm::raw_ostream &Constant<Type<TypeCategory::Character, KIND>>::AsFortran(
return o;
}
+template <int KIND>
+std::string Constant<Type<TypeCategory::Character, KIND>>::AsFortran() const {
+ std::string result;
+ llvm::raw_string_ostream sstream(result);
+ AsFortran(sstream);
+ return result;
+}
+
llvm::raw_ostream &EmitVar(llvm::raw_ostream &o, const Symbol &symbol,
std::optional<parser::CharBlock> name = std::nullopt) {
const auto &renamings{symbol.owner().context().moduleFileOutputRenamings()};
diff --git a/flang/lib/Evaluate/host.cpp b/flang/lib/Evaluate/host.cpp
index 187bb2f..25409ac 100644
--- a/flang/lib/Evaluate/host.cpp
+++ b/flang/lib/Evaluate/host.cpp
@@ -100,13 +100,8 @@ void HostFloatingPointEnvironment::SetUpHostFloatingPointEnvironment(
break;
case common::RoundingMode::TiesAwayFromZero:
fesetround(FE_TONEAREST);
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::FoldingFailure)) {
- context.messages().Say(common::UsageWarning::FoldingFailure,
- "TiesAwayFromZero rounding mode is not available when folding "
- "constants"
- " with host runtime; using TiesToEven instead"_warn_en_US);
- }
+ context.Warn(common::UsageWarning::FoldingFailure,
+ "TiesAwayFromZero rounding mode is not available when folding constants with host runtime; using TiesToEven instead"_warn_en_US);
break;
}
flags_.clear();
diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp
index c37a7f90..abe53c3 100644
--- a/flang/lib/Evaluate/intrinsics.cpp
+++ b/flang/lib/Evaluate/intrinsics.cpp
@@ -666,7 +666,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
{ArgFlag::canBeMoldNull, ArgFlag::onlyConstantInquiry}}},
DefaultInt, Rank::elemental, IntrinsicClass::inquiryFunction},
{"lbound",
- {{"array", AnyData, Rank::anyOrAssumedRank}, RequiredDIM,
+ {{"array", AnyData, Rank::arrayOrAssumedRank}, RequiredDIM,
SizeDefaultKIND},
KINDInt, Rank::scalar, IntrinsicClass::inquiryFunction},
{"lbound", {{"array", AnyData, Rank::arrayOrAssumedRank}, SizeDefaultKIND},
@@ -921,6 +921,10 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
{"back", AnyLogical, Rank::elemental, Optionality::optional},
DefaultingKIND},
KINDInt},
+ {"secnds",
+ {{"refTime", TypePattern{RealType, KindCode::exactKind, 4},
+ Rank::scalar}},
+ TypePattern{RealType, KindCode::exactKind, 4}, Rank::scalar},
{"second", {}, DefaultReal, Rank::scalar},
{"selected_char_kind", {{"name", DefaultChar, Rank::scalar}}, DefaultInt,
Rank::scalar, IntrinsicClass::transformationalFunction},
@@ -1034,7 +1038,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
{"trim", {{"string", SameCharNoLen, Rank::scalar}}, SameCharNoLen,
Rank::scalar, IntrinsicClass::transformationalFunction},
{"ubound",
- {{"array", AnyData, Rank::anyOrAssumedRank}, RequiredDIM,
+ {{"array", AnyData, Rank::arrayOrAssumedRank}, RequiredDIM,
SizeDefaultKIND},
KINDInt, Rank::scalar, IntrinsicClass::inquiryFunction},
{"ubound", {{"array", AnyData, Rank::arrayOrAssumedRank}, SizeDefaultKIND},
@@ -2256,7 +2260,7 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
for (std::size_t j{0}; j < dummies; ++j) {
const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]};
if (const ActualArgument *arg{actualForDummy[j]}) {
- bool isAssumedRank{IsAssumedRank(*arg)};
+ bool isAssumedRank{semantics::IsAssumedRank(*arg)};
if (isAssumedRank && d.rank != Rank::anyOrAssumedRank &&
d.rank != Rank::arrayOrAssumedRank) {
messages.Say(arg->sourceLocation(),
@@ -2617,15 +2621,12 @@ std::optional<SpecificCall> IntrinsicInterface::Match(
if (const Symbol *whole{
UnwrapWholeSymbolOrComponentDataRef(actualForDummy[*dimArg])}) {
if (IsOptional(*whole) || IsAllocatableOrObjectPointer(whole)) {
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::OptionalMustBePresent)) {
- if (rank == Rank::scalarIfDim || arrayRank.value_or(-1) == 1) {
- messages.Say(common::UsageWarning::OptionalMustBePresent,
- "The actual argument for DIM= is optional, pointer, or allocatable, and it is assumed to be present and equal to 1 at execution time"_warn_en_US);
- } else {
- messages.Say(common::UsageWarning::OptionalMustBePresent,
- "The actual argument for DIM= is optional, pointer, or allocatable, and may not be absent during execution; parenthesize to silence this warning"_warn_en_US);
- }
+ if (rank == Rank::scalarIfDim || arrayRank.value_or(-1) == 1) {
+ context.Warn(common::UsageWarning::OptionalMustBePresent,
+ "The actual argument for DIM= is optional, pointer, or allocatable, and it is assumed to be present and equal to 1 at execution time"_warn_en_US);
+ } else {
+ context.Warn(common::UsageWarning::OptionalMustBePresent,
+ "The actual argument for DIM= is optional, pointer, or allocatable, and may not be absent during execution; parenthesize to silence this warning"_warn_en_US);
}
}
}
@@ -3002,7 +3003,7 @@ SpecificCall IntrinsicProcTable::Implementation::HandleNull(
mold = nullptr;
}
if (mold) {
- if (IsAssumedRank(*arguments[0])) {
+ if (semantics::IsAssumedRank(*arguments[0])) {
context.messages().Say(arguments[0]->sourceLocation(),
"MOLD= argument to NULL() must not be assumed-rank"_err_en_US);
}
@@ -3109,16 +3110,12 @@ IntrinsicProcTable::Implementation::HandleC_F_Pointer(
context.messages().Say(at,
"FPTR= argument to C_F_POINTER() may not have a deferred type parameter"_err_en_US);
} else if (type->category() == TypeCategory::Derived) {
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::Interoperability) &&
- type->IsUnlimitedPolymorphic()) {
- context.messages().Say(common::UsageWarning::Interoperability, at,
+ if (type->IsUnlimitedPolymorphic()) {
+ context.Warn(common::UsageWarning::Interoperability, at,
"FPTR= argument to C_F_POINTER() should not be unlimited polymorphic"_warn_en_US);
} else if (!type->GetDerivedTypeSpec().typeSymbol().attrs().test(
- semantics::Attr::BIND_C) &&
- context.languageFeatures().ShouldWarn(
- common::UsageWarning::Portability)) {
- context.messages().Say(common::UsageWarning::Portability, at,
+ semantics::Attr::BIND_C)) {
+ context.Warn(common::UsageWarning::Portability, at,
"FPTR= argument to C_F_POINTER() should not have a derived type that is not BIND(C)"_port_en_US);
}
} else if (!IsInteroperableIntrinsicType(
@@ -3126,16 +3123,11 @@ IntrinsicProcTable::Implementation::HandleC_F_Pointer(
.value_or(true)) {
if (type->category() == TypeCategory::Character &&
type->kind() == 1) {
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::CharacterInteroperability)) {
- context.messages().Say(
- common::UsageWarning::CharacterInteroperability, at,
- "FPTR= argument to C_F_POINTER() should not have the non-interoperable character length %s"_warn_en_US,
- type->AsFortran());
- }
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::Interoperability)) {
- context.messages().Say(common::UsageWarning::Interoperability, at,
+ context.Warn(common::UsageWarning::CharacterInteroperability, at,
+ "FPTR= argument to C_F_POINTER() should not have the non-interoperable character length %s"_warn_en_US,
+ type->AsFortran());
+ } else {
+ context.Warn(common::UsageWarning::Interoperability, at,
"FPTR= argument to C_F_POINTER() should not have the non-interoperable intrinsic type or kind %s"_warn_en_US,
type->AsFortran());
}
@@ -3274,16 +3266,11 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::HandleC_Loc(
if (typeAndShape->type().category() == TypeCategory::Character &&
typeAndShape->type().kind() == 1) {
// Default character kind, but length is not known to be 1
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::CharacterInteroperability)) {
- context.messages().Say(
- common::UsageWarning::CharacterInteroperability,
- arguments[0]->sourceLocation(),
- "C_LOC() argument has non-interoperable character length"_warn_en_US);
- }
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::Interoperability)) {
- context.messages().Say(common::UsageWarning::Interoperability,
+ context.Warn(common::UsageWarning::CharacterInteroperability,
+ arguments[0]->sourceLocation(),
+ "C_LOC() argument has non-interoperable character length"_warn_en_US);
+ } else {
+ context.Warn(common::UsageWarning::Interoperability,
arguments[0]->sourceLocation(),
"C_LOC() argument has non-interoperable intrinsic type or kind"_warn_en_US);
}
@@ -3341,16 +3328,11 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::HandleC_Devloc(
if (typeAndShape->type().category() == TypeCategory::Character &&
typeAndShape->type().kind() == 1) {
// Default character kind, but length is not known to be 1
- if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::CharacterInteroperability)) {
- context.messages().Say(
- common::UsageWarning::CharacterInteroperability,
- arguments[0]->sourceLocation(),
- "C_DEVLOC() argument has non-interoperable character length"_warn_en_US);
- }
- } else if (context.languageFeatures().ShouldWarn(
- common::UsageWarning::Interoperability)) {
- context.messages().Say(common::UsageWarning::Interoperability,
+ context.Warn(common::UsageWarning::CharacterInteroperability,
+ arguments[0]->sourceLocation(),
+ "C_DEVLOC() argument has non-interoperable character length"_warn_en_US);
+ } else {
+ context.Warn(common::UsageWarning::Interoperability,
arguments[0]->sourceLocation(),
"C_DEVLOC() argument has non-interoperable intrinsic type or kind"_warn_en_US);
}
@@ -3673,15 +3655,10 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe(
genericType.category() == TypeCategory::Real) &&
(newType.category() == TypeCategory::Integer ||
newType.category() == TypeCategory::Real))) {
- if (context.languageFeatures().ShouldWarn(
- common::LanguageFeature::
- UseGenericIntrinsicWhenSpecificDoesntMatch)) {
- context.messages().Say(
- common::LanguageFeature::
- UseGenericIntrinsicWhenSpecificDoesntMatch,
- "Argument types do not match specific intrinsic '%s' requirements; using '%s' generic instead and converting the result to %s if needed"_port_en_US,
- call.name, genericName, newType.AsFortran());
- }
+ context.Warn(common::LanguageFeature::
+ UseGenericIntrinsicWhenSpecificDoesntMatch,
+ "Argument types do not match specific intrinsic '%s' requirements; using '%s' generic instead and converting the result to %s if needed"_port_en_US,
+ call.name, genericName, newType.AsFortran());
specificCall->specificIntrinsic.name = call.name;
specificCall->specificIntrinsic.characteristics.value()
.functionResult.value()
diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp
index 2c0f283..6e6b9f3 100644
--- a/flang/lib/Evaluate/real.cpp
+++ b/flang/lib/Evaluate/real.cpp
@@ -750,6 +750,14 @@ llvm::raw_ostream &Real<W, P>::AsFortran(
return o;
}
+template <typename W, int P>
+std::string Real<W, P>::AsFortran(int kind, bool minimal) const {
+ std::string result;
+ llvm::raw_string_ostream sstream(result);
+ AsFortran(sstream, kind, minimal);
+ return result;
+}
+
// 16.9.180
template <typename W, int P> Real<W, P> Real<W, P>::RRSPACING() const {
if (IsNotANumber()) {
diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index 776866d..07bff10 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -623,7 +623,7 @@ MaybeExtentExpr GetRawUpperBound(
} else if (semantics::IsAssumedSizeArray(symbol) &&
dimension + 1 == symbol.Rank()) {
return std::nullopt;
- } else {
+ } else if (IsSafelyCopyable(base, /*admitPureCall=*/true)) {
return ComputeUpperBound(
GetRawLowerBound(base, dimension), GetExtent(base, dimension));
}
@@ -678,9 +678,11 @@ static MaybeExtentExpr GetUBOUND(FoldingContext *context,
} else if (semantics::IsAssumedSizeArray(symbol) &&
dimension + 1 == symbol.Rank()) {
return std::nullopt; // UBOUND() folding replaces with -1
- } else if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) {
- return ComputeUpperBound(
- std::move(*lb), GetExtent(base, dimension, invariantOnly));
+ } else if (IsSafelyCopyable(base, /*admitPureCall=*/true)) {
+ if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) {
+ return ComputeUpperBound(
+ std::move(*lb), GetExtent(base, dimension, invariantOnly));
+ }
}
}
} else if (const auto *assoc{
@@ -947,7 +949,7 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
intrinsic->name == "ubound") {
// For LBOUND/UBOUND, these are the array-valued cases (no DIM=)
if (!call.arguments().empty() && call.arguments().front()) {
- if (IsAssumedRank(*call.arguments().front())) {
+ if (semantics::IsAssumedRank(*call.arguments().front())) {
return Shape{MaybeExtentExpr{}};
} else {
return Shape{
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 9c059b0..1f3cbbf 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -495,7 +495,7 @@ Expr<SomeComplex> PromoteMixedComplexReal(
// N.B. When a "typeless" BOZ literal constant appears as one (not both!) of
// the operands to a dyadic operation where one is permitted, it assumes the
// type and kind of the other operand.
-template <template <typename> class OPR, bool CAN_BE_UNSIGNED>
+template <template <typename> class OPR>
std::optional<Expr<SomeType>> NumericOperation(
parser::ContextualMessages &messages, Expr<SomeType> &&x,
Expr<SomeType> &&y, int defaultRealKind) {
@@ -510,13 +510,8 @@ std::optional<Expr<SomeType>> NumericOperation(
std::move(rx), std::move(ry)));
},
[&](Expr<SomeUnsigned> &&ix, Expr<SomeUnsigned> &&iy) {
- if constexpr (CAN_BE_UNSIGNED) {
- return Package(PromoteAndCombine<OPR, TypeCategory::Unsigned>(
- std::move(ix), std::move(iy)));
- } else {
- messages.Say("Operands must not be UNSIGNED"_err_en_US);
- return NoExpr();
- }
+ return Package(PromoteAndCombine<OPR, TypeCategory::Unsigned>(
+ std::move(ix), std::move(iy)));
},
// Mixed REAL/INTEGER operations
[](Expr<SomeReal> &&rx, Expr<SomeInteger> &&iy) {
@@ -575,34 +570,31 @@ std::optional<Expr<SomeType>> NumericOperation(
},
// Operations with one typeless operand
[&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) {
- return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages,
+ return NumericOperation<OPR>(messages,
AsGenericExpr(ConvertTo(iy, std::move(bx))), std::move(y),
defaultRealKind);
},
[&](BOZLiteralConstant &&bx, Expr<SomeUnsigned> &&iy) {
- return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages,
+ return NumericOperation<OPR>(messages,
AsGenericExpr(ConvertTo(iy, std::move(bx))), std::move(y),
defaultRealKind);
},
[&](BOZLiteralConstant &&bx, Expr<SomeReal> &&ry) {
- return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages,
+ return NumericOperation<OPR>(messages,
AsGenericExpr(ConvertTo(ry, std::move(bx))), std::move(y),
defaultRealKind);
},
[&](Expr<SomeInteger> &&ix, BOZLiteralConstant &&by) {
- return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages,
- std::move(x), AsGenericExpr(ConvertTo(ix, std::move(by))),
- defaultRealKind);
+ return NumericOperation<OPR>(messages, std::move(x),
+ AsGenericExpr(ConvertTo(ix, std::move(by))), defaultRealKind);
},
[&](Expr<SomeUnsigned> &&ix, BOZLiteralConstant &&by) {
- return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages,
- std::move(x), AsGenericExpr(ConvertTo(ix, std::move(by))),
- defaultRealKind);
+ return NumericOperation<OPR>(messages, std::move(x),
+ AsGenericExpr(ConvertTo(ix, std::move(by))), defaultRealKind);
},
[&](Expr<SomeReal> &&rx, BOZLiteralConstant &&by) {
- return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages,
- std::move(x), AsGenericExpr(ConvertTo(rx, std::move(by))),
- defaultRealKind);
+ return NumericOperation<OPR>(messages, std::move(x),
+ AsGenericExpr(ConvertTo(rx, std::move(by))), defaultRealKind);
},
// Error cases
[&](Expr<SomeUnsigned> &&, auto &&) {
@@ -621,7 +613,7 @@ std::optional<Expr<SomeType>> NumericOperation(
std::move(x.u), std::move(y.u));
}
-template std::optional<Expr<SomeType>> NumericOperation<Power, false>(
+template std::optional<Expr<SomeType>> NumericOperation<Power>(
parser::ContextualMessages &, Expr<SomeType> &&, Expr<SomeType> &&,
int defaultRealKind);
template std::optional<Expr<SomeType>> NumericOperation<Multiply>(
@@ -890,29 +882,6 @@ std::optional<Expr<SomeType>> ConvertToType(
}
}
-bool IsAssumedRank(const Symbol &original) {
- if (const auto *assoc{original.detailsIf<semantics::AssocEntityDetails>()}) {
- if (assoc->rank()) {
- return false; // in RANK(n) or RANK(*)
- } else if (assoc->IsAssumedRank()) {
- return true; // RANK DEFAULT
- }
- }
- const Symbol &symbol{semantics::ResolveAssociations(original)};
- const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()};
- return object && object->IsAssumedRank();
-}
-
-bool IsAssumedRank(const ActualArgument &arg) {
- if (const auto *expr{arg.UnwrapExpr()}) {
- return IsAssumedRank(*expr);
- } else {
- const Symbol *assumedTypeDummy{arg.GetAssumedTypeDummy()};
- CHECK(assumedTypeDummy);
- return IsAssumedRank(*assumedTypeDummy);
- }
-}
-
int GetCorank(const ActualArgument &arg) {
const auto *expr{arg.UnwrapExpr()};
return GetCorank(*expr);
@@ -1129,7 +1098,7 @@ struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
CollectCudaSymbolsHelper() : Base{*this} {}
using Base::operator();
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
- return {symbol};
+ return {symbol.GetUltimate()};
}
// Overload some of the operator() to filter out the symbols that are not
// of interest for CUDA data transfer logic.
@@ -1203,6 +1172,15 @@ bool HasVectorSubscript(const Expr<SomeType> &expr) {
return HasVectorSubscriptHelper{}(expr);
}
+bool HasVectorSubscript(const ActualArgument &actual) {
+ auto expr{actual.UnwrapExpr()};
+ return expr && HasVectorSubscript(*expr);
+}
+
+bool IsArraySection(const Expr<SomeType> &expr) {
+ return expr.Rank() > 0 && IsVariable(expr) && !UnwrapWholeSymbolDataRef(expr);
+}
+
// HasConstant()
struct HasConstantHelper : public AnyTraverse<HasConstantHelper, bool,
/*TraverseAssocEntityDetails=*/false> {
@@ -2312,9 +2290,22 @@ bool IsDummy(const Symbol &symbol) {
ResolveAssociations(symbol).details());
}
+bool IsAssumedRank(const Symbol &original) {
+ if (const auto *assoc{original.detailsIf<semantics::AssocEntityDetails>()}) {
+ if (assoc->rank()) {
+ return false; // in RANK(n) or RANK(*)
+ } else if (assoc->IsAssumedRank()) {
+ return true; // RANK DEFAULT
+ }
+ }
+ const Symbol &symbol{semantics::ResolveAssociations(original)};
+ const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()};
+ return object && object->IsAssumedRank();
+}
+
bool IsAssumedShape(const Symbol &symbol) {
const Symbol &ultimate{ResolveAssociations(symbol)};
- const auto *object{ultimate.detailsIf<ObjectEntityDetails>()};
+ const auto *object{ultimate.detailsIf<semantics::ObjectEntityDetails>()};
return object && object->IsAssumedShape() &&
!semantics::IsAllocatableOrObjectPointer(&ultimate);
}
diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp
index d1bff03..b9b34d4 100644
--- a/flang/lib/Evaluate/variable.cpp
+++ b/flang/lib/Evaluate/variable.cpp
@@ -212,21 +212,17 @@ std::optional<Expr<SomeCharacter>> Substring::Fold(FoldingContext &context) {
}
if (!result) { // error cases
if (*lbi < 1) {
- if (context.languageFeatures().ShouldWarn(common::UsageWarning::Bounds)) {
- context.messages().Say(common::UsageWarning::Bounds,
- "Lower bound (%jd) on substring is less than one"_warn_en_US,
- static_cast<std::intmax_t>(*lbi));
- }
+ context.Warn(common::UsageWarning::Bounds,
+ "Lower bound (%jd) on substring is less than one"_warn_en_US,
+ static_cast<std::intmax_t>(*lbi));
*lbi = 1;
lower_ = AsExpr(Constant<SubscriptInteger>{1});
}
if (length && *ubi > *length) {
- if (context.languageFeatures().ShouldWarn(common::UsageWarning::Bounds)) {
- context.messages().Say(common::UsageWarning::Bounds,
- "Upper bound (%jd) on substring is greater than character length (%jd)"_warn_en_US,
- static_cast<std::intmax_t>(*ubi),
- static_cast<std::intmax_t>(*length));
- }
+ context.Warn(common::UsageWarning::Bounds,
+ "Upper bound (%jd) on substring is greater than character length (%jd)"_warn_en_US,
+ static_cast<std::intmax_t>(*ubi),
+ static_cast<std::intmax_t>(*length));
*ubi = *length;
upper_ = AsExpr(Constant<SubscriptInteger>{*ubi});
}
diff --git a/flang/lib/Frontend/CompilerInstance.cpp b/flang/lib/Frontend/CompilerInstance.cpp
index cd8ddda..d97b4b8 100644
--- a/flang/lib/Frontend/CompilerInstance.cpp
+++ b/flang/lib/Frontend/CompilerInstance.cpp
@@ -253,18 +253,15 @@ getExplicitAndImplicitAMDGPUTargetFeatures(clang::DiagnosticsEngine &diags,
const TargetOptions &targetOpts,
const llvm::Triple triple) {
llvm::StringRef cpu = targetOpts.cpu;
- llvm::StringMap<bool> implicitFeaturesMap;
- // Get the set of implicit target features
- llvm::AMDGPU::fillAMDGPUFeatureMap(cpu, triple, implicitFeaturesMap);
+ llvm::StringMap<bool> FeaturesMap;
// Add target features specified by the user
for (auto &userFeature : targetOpts.featuresAsWritten) {
std::string userKeyString = userFeature.substr(1);
- implicitFeaturesMap[userKeyString] = (userFeature[0] == '+');
+ FeaturesMap[userKeyString] = (userFeature[0] == '+');
}
- auto HasError =
- llvm::AMDGPU::insertWaveSizeFeature(cpu, triple, implicitFeaturesMap);
+ auto HasError = llvm::AMDGPU::fillAMDGPUFeatureMap(cpu, triple, FeaturesMap);
if (HasError.first) {
unsigned diagID = diags.getCustomDiagID(clang::DiagnosticsEngine::Error,
"Unsupported feature ID: %0");
@@ -273,9 +270,9 @@ getExplicitAndImplicitAMDGPUTargetFeatures(clang::DiagnosticsEngine &diags,
}
llvm::SmallVector<std::string> featuresVec;
- for (auto &implicitFeatureItem : implicitFeaturesMap) {
- featuresVec.push_back((llvm::Twine(implicitFeatureItem.second ? "+" : "-") +
- implicitFeatureItem.first().str())
+ for (auto &FeatureItem : FeaturesMap) {
+ featuresVec.push_back((llvm::Twine(FeatureItem.second ? "+" : "-") +
+ FeatureItem.first().str())
.str());
}
llvm::sort(featuresVec);
diff --git a/flang/lib/Frontend/CompilerInvocation.cpp b/flang/lib/Frontend/CompilerInvocation.cpp
index 111c5aa4..6295a58 100644
--- a/flang/lib/Frontend/CompilerInvocation.cpp
+++ b/flang/lib/Frontend/CompilerInvocation.cpp
@@ -22,6 +22,7 @@
#include "flang/Tools/TargetSetup.h"
#include "flang/Version.inc"
#include "clang/Basic/DiagnosticDriver.h"
+#include "clang/Basic/DiagnosticFrontend.h"
#include "clang/Basic/DiagnosticOptions.h"
#include "clang/Driver/CommonArgs.h"
#include "clang/Driver/Driver.h"
@@ -1152,6 +1153,17 @@ static bool parseDialectArgs(CompilerInvocation &res, llvm::opt::ArgList &args,
diags.Report(diagID);
}
}
+ // -fcoarray
+ if (args.hasArg(clang::driver::options::OPT_fcoarray)) {
+ res.getFrontendOpts().features.Enable(
+ Fortran::common::LanguageFeature::Coarray);
+ const unsigned diagID =
+ diags.getCustomDiagID(clang::DiagnosticsEngine::Warning,
+ "Support for multi image Fortran features is "
+ "still experimental and in development.");
+ diags.Report(diagID);
+ }
+
return diags.getNumErrors() == numErrorsBefore;
}
@@ -1162,13 +1174,21 @@ static bool parseOpenMPArgs(CompilerInvocation &res, llvm::opt::ArgList &args,
clang::DiagnosticsEngine &diags) {
llvm::opt::Arg *arg = args.getLastArg(clang::driver::options::OPT_fopenmp,
clang::driver::options::OPT_fno_openmp);
- if (!arg || arg->getOption().matches(clang::driver::options::OPT_fno_openmp))
- return true;
+ if (!arg ||
+ arg->getOption().matches(clang::driver::options::OPT_fno_openmp)) {
+ bool isSimdSpecified = args.hasFlag(
+ clang::driver::options::OPT_fopenmp_simd,
+ clang::driver::options::OPT_fno_openmp_simd, /*Default=*/false);
+ if (!isSimdSpecified)
+ return true;
+ res.getLangOpts().OpenMPSimd = 1;
+ }
unsigned numErrorsBefore = diags.getNumErrors();
llvm::Triple t(res.getTargetOpts().triple);
constexpr unsigned newestFullySupported = 31;
+ constexpr unsigned latestFinalized = 60;
// By default OpenMP is set to the most recent fully supported version
res.getLangOpts().OpenMPVersion = newestFullySupported;
res.getFrontendOpts().features.Enable(
@@ -1191,12 +1211,26 @@ static bool parseOpenMPArgs(CompilerInvocation &res, llvm::opt::ArgList &args,
diags.Report(diagID) << value << arg->getAsString(args) << versions.str();
};
+ auto reportFutureVersion = [&](llvm::StringRef value) {
+ const unsigned diagID = diags.getCustomDiagID(
+ clang::DiagnosticsEngine::Warning,
+ "The specification for OpenMP version %0 is still under development; "
+ "the syntax and semantics of new features may be subject to change");
+ std::string buffer;
+ llvm::raw_string_ostream versions(buffer);
+ llvm::interleaveComma(ompVersions, versions);
+
+ diags.Report(diagID) << value;
+ };
+
llvm::StringRef value = arg->getValue();
if (!value.getAsInteger(/*radix=*/10, version)) {
if (llvm::is_contained(ompVersions, version)) {
res.getLangOpts().OpenMPVersion = version;
- if (version > newestFullySupported)
+ if (version > latestFinalized)
+ reportFutureVersion(value);
+ else if (version > newestFullySupported)
diags.Report(clang::diag::warn_openmp_incomplete) << version;
} else if (llvm::is_contained(oldVersions, version)) {
const unsigned diagID =
@@ -1225,7 +1259,7 @@ static bool parseOpenMPArgs(CompilerInvocation &res, llvm::opt::ArgList &args,
clang::driver::options::OPT_fopenmp_host_ir_file_path)) {
res.getLangOpts().OMPHostIRFile = arg->getValue();
if (!llvm::sys::fs::exists(res.getLangOpts().OMPHostIRFile))
- diags.Report(clang::diag::err_drv_omp_host_ir_file_not_found)
+ diags.Report(clang::diag::err_omp_host_ir_file_not_found)
<< res.getLangOpts().OMPHostIRFile;
}
@@ -1696,6 +1730,20 @@ void CompilerInvocation::setDefaultPredefinitions() {
fortranOptions.predefinitions.emplace_back("__flang_patchlevel__",
FLANG_VERSION_PATCHLEVEL_STRING);
+ // Add predefinitions based on the relocation model
+ if (unsigned PICLevel = getCodeGenOpts().PICLevel) {
+ fortranOptions.predefinitions.emplace_back("__PIC__",
+ std::to_string(PICLevel));
+ fortranOptions.predefinitions.emplace_back("__pic__",
+ std::to_string(PICLevel));
+ if (getCodeGenOpts().IsPIE) {
+ fortranOptions.predefinitions.emplace_back("__PIE__",
+ std::to_string(PICLevel));
+ fortranOptions.predefinitions.emplace_back("__pie__",
+ std::to_string(PICLevel));
+ }
+ }
+
// Add predefinitions based on extensions enabled
if (frontendOptions.features.IsEnabled(
Fortran::common::LanguageFeature::OpenACC)) {
@@ -1707,6 +1755,11 @@ void CompilerInvocation::setDefaultPredefinitions() {
fortranOptions.predefinitions);
}
+ if (frontendOptions.features.IsEnabled(
+ Fortran::common::LanguageFeature::CUDA)) {
+ fortranOptions.predefinitions.emplace_back("_CUDA", "1");
+ }
+
llvm::Triple targetTriple{llvm::Triple(this->targetOpts.triple)};
if (targetTriple.isOSLinux()) {
fortranOptions.predefinitions.emplace_back("__linux__", "1");
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 5c66ecf..3bef6b1 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -298,6 +298,7 @@ bool CodeGenAction::beginSourceFileAction() {
bool isOpenMPEnabled =
ci.getInvocation().getFrontendOpts().features.IsEnabled(
Fortran::common::LanguageFeature::OpenMP);
+ bool isOpenMPSimd = ci.getInvocation().getLangOpts().OpenMPSimd;
fir::OpenMPFIRPassPipelineOpts opts;
@@ -329,12 +330,13 @@ bool CodeGenAction::beginSourceFileAction() {
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
mlirModule->getOperation()))
opts.isTargetDevice = offloadMod.getIsTargetDevice();
+ }
- // WARNING: This pipeline must be run immediately after the lowering to
- // ensure that the FIR is correct with respect to OpenMP operations/
- // attributes.
+ // WARNING: This pipeline must be run immediately after the lowering to
+ // ensure that the FIR is correct with respect to OpenMP operations/
+ // attributes.
+ if (isOpenMPEnabled || isOpenMPSimd)
fir::createOpenMPFIRPassPipeline(pm, opts);
- }
pm.enableVerifier(/*verifyPasses=*/true);
pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
@@ -617,12 +619,14 @@ void CodeGenAction::lowerHLFIRToFIR() {
pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
pm.enableVerifier(/*verifyPasses=*/true);
+ fir::EnableOpenMP enableOpenMP = fir::EnableOpenMP::None;
+ if (ci.getInvocation().getFrontendOpts().features.IsEnabled(
+ Fortran::common::LanguageFeature::OpenMP))
+ enableOpenMP = fir::EnableOpenMP::Full;
+ if (ci.getInvocation().getLangOpts().OpenMPSimd)
+ enableOpenMP = fir::EnableOpenMP::Simd;
// Create the pass pipeline
- fir::createHLFIRToFIRPassPipeline(
- pm,
- ci.getInvocation().getFrontendOpts().features.IsEnabled(
- Fortran::common::LanguageFeature::OpenMP),
- level);
+ fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP, level);
(void)mlir::applyPassManagerCLOptions(pm);
mlir::TimingScope timingScopeMLIRPasses = timingScopeRoot.nest(
@@ -748,6 +752,9 @@ void CodeGenAction::generateLLVMIR() {
Fortran::common::LanguageFeature::OpenMP))
config.EnableOpenMP = true;
+ if (ci.getInvocation().getLangOpts().OpenMPSimd)
+ config.EnableOpenMPSimd = true;
+
if (ci.getInvocation().getLoweringOpts().getIntegerWrapAround())
config.NSWOnLoopVarInc = false;
diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp
index 219f920..444b5b6 100644
--- a/flang/lib/Lower/Allocatable.cpp
+++ b/flang/lib/Lower/Allocatable.cpp
@@ -13,9 +13,9 @@
#include "flang/Lower/Allocatable.h"
#include "flang/Evaluate/tools.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/CUDA.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
-#include "flang/Lower/Cuda.h"
#include "flang/Lower/IterationSpace.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/OpenACC.h"
@@ -445,10 +445,14 @@ private:
/*mustBeHeap=*/true);
}
- void postAllocationAction(const Allocation &alloc) {
+ void postAllocationAction(const Allocation &alloc,
+ const fir::MutableBoxValue &box) {
if (alloc.getSymbol().test(Fortran::semantics::Symbol::Flag::AccDeclare))
Fortran::lower::attachDeclarePostAllocAction(converter, builder,
alloc.getSymbol());
+ if (Fortran::semantics::HasCUDAComponent(alloc.getSymbol()))
+ Fortran::lower::initializeDeviceComponentAllocator(
+ converter, alloc.getSymbol(), box);
}
void setPinnedToFalse() {
@@ -481,11 +485,21 @@ private:
// Pointers must use PointerAllocate so that their deallocations
// can be validated.
genInlinedAllocation(alloc, box);
- postAllocationAction(alloc);
+ postAllocationAction(alloc, box);
setPinnedToFalse();
return;
}
+ // Preserve characters' dynamic length.
+ if (lenParams.empty() && box.isCharacter() &&
+ !box.hasNonDeferredLenParams()) {
+ auto charTy = mlir::dyn_cast<fir::CharacterType>(box.getEleTy());
+ if (charTy && charTy.hasDynamicLen()) {
+ fir::ExtendedValue exv{box};
+ lenParams.push_back(fir::factory::readCharLen(builder, loc, exv));
+ }
+ }
+
// Generate a sequence of runtime calls.
errorManager.genStatCheck(builder, loc);
genAllocateObjectInit(box, allocatorIdx);
@@ -504,7 +518,7 @@ private:
genCudaAllocate(builder, loc, box, errorManager, alloc.getSymbol());
}
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
- postAllocationAction(alloc);
+ postAllocationAction(alloc, box);
errorManager.assignStat(builder, loc, stat);
}
@@ -647,7 +661,7 @@ private:
setPinnedToFalse();
}
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
- postAllocationAction(alloc);
+ postAllocationAction(alloc, box);
errorManager.assignStat(builder, loc, stat);
}
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 6b7efe6..c003a5b 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -13,6 +13,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/Allocatable.h"
+#include "flang/Lower/CUDA.h"
#include "flang/Lower/CallInterface.h"
#include "flang/Lower/Coarray.h"
#include "flang/Lower/ConvertCall.h"
@@ -20,7 +21,6 @@
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
-#include "flang/Lower/Cuda.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
@@ -475,7 +475,9 @@ public:
fir::runtime::genMain(*builder, toLocation(),
bridge.getEnvironmentDefaults(),
getFoldingContext().languageFeatures().IsEnabled(
- Fortran::common::LanguageFeature::CUDA));
+ Fortran::common::LanguageFeature::CUDA),
+ getFoldingContext().languageFeatures().IsEnabled(
+ Fortran::common::LanguageFeature::Coarray));
});
finalizeOpenMPLowering(globalOmpRequiresSymbol);
@@ -1400,21 +1402,23 @@ private:
mlir::Value genLoopVariableAddress(mlir::Location loc,
const Fortran::semantics::Symbol &sym,
bool isUnordered) {
- if (isUnordered || sym.has<Fortran::semantics::HostAssocDetails>() ||
- sym.has<Fortran::semantics::UseDetails>()) {
- if (!shallowLookupSymbol(sym) &&
- !GetSymbolDSA(sym).test(
- Fortran::semantics::Symbol::Flag::OmpShared)) {
- // Do concurrent loop variables are not mapped yet since they are local
- // to the Do concurrent scope (same for OpenMP loops).
- mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
- builder->setInsertionPointToStart(builder->getAllocaBlock());
- mlir::Type tempTy = genType(sym);
- mlir::Value temp =
- builder->createTemporaryAlloc(loc, tempTy, toStringRef(sym.name()));
- bindIfNewSymbol(sym, temp);
- builder->restoreInsertionPoint(insPt);
- }
+ if (!shallowLookupSymbol(sym) &&
+ (isUnordered ||
+ GetSymbolDSA(sym).test(Fortran::semantics::Symbol::Flag::OmpPrivate) ||
+ GetSymbolDSA(sym).test(
+ Fortran::semantics::Symbol::Flag::OmpFirstPrivate) ||
+ GetSymbolDSA(sym).test(
+ Fortran::semantics::Symbol::Flag::OmpLastPrivate) ||
+ GetSymbolDSA(sym).test(Fortran::semantics::Symbol::Flag::OmpLinear))) {
+ // Do concurrent loop variables are not mapped yet since they are
+ // local to the Do concurrent scope (same for OpenMP loops).
+ mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+ builder->setInsertionPointToStart(builder->getAllocaBlock());
+ mlir::Type tempTy = genType(sym);
+ mlir::Value temp =
+ builder->createTemporaryAlloc(loc, tempTy, toStringRef(sym.name()));
+ bindIfNewSymbol(sym, temp);
+ builder->restoreInsertionPoint(insPt);
}
auto entry = lookupSymbol(sym);
(void)entry;
@@ -2060,10 +2064,10 @@ private:
// TODO Promote to using `enableDelayedPrivatization` (which is enabled by
// default unlike the staging flag) once the implementation of this is more
// complete.
- bool useDelayedPriv =
- enableDelayedPrivatizationStaging && doConcurrentLoopOp;
+ bool useDelayedPriv = enableDelayedPrivatization && doConcurrentLoopOp;
llvm::SetVector<const Fortran::semantics::Symbol *> allPrivatizedSymbols;
- llvm::SmallSet<const Fortran::semantics::Symbol *, 16> mightHaveReadHostSym;
+ llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 16>
+ mightHaveReadHostSym;
for (const Fortran::semantics::Symbol *symToPrivatize : info.localSymList) {
if (useDelayedPriv) {
@@ -2122,6 +2126,9 @@ private:
}
}
+ if (!doConcurrentLoopOp)
+ return;
+
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<mlir::Attribute> nestReduceAttrs;
@@ -4824,7 +4831,9 @@ private:
void genCUDADataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
const Fortran::evaluate::Assignment &assign,
- hlfir::Entity &lhs, hlfir::Entity &rhs) {
+ hlfir::Entity &lhs, hlfir::Entity &rhs,
+ bool isWholeAllocatableAssignment,
+ bool keepLhsLengthInAllocatableAssignment) {
bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs);
bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs);
@@ -4889,6 +4898,28 @@ private:
// host = device
if (!lhsIsDevice && rhsIsDevice) {
+ if (Fortran::lower::isTransferWithConversion(rhs)) {
+ mlir::OpBuilder::InsertionGuard insertionGuard(builder);
+ auto elementalOp =
+ mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp());
+ assert(elementalOp && "expect elemental op");
+ auto designateOp =
+ *elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
+ builder.setInsertionPoint(elementalOp);
+ // Create a temp to transfer the rhs before applying the conversion.
+ hlfir::Entity entity{designateOp.getMemref()};
+ auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, entity);
+ auto transferKindAttr = cuf::DataTransferKindAttr::get(
+ builder.getContext(), cuf::DataTransferKind::DeviceHost);
+ cuf::DataTransferOp::create(builder, loc, designateOp.getMemref(), temp,
+ /*shape=*/mlir::Value{}, transferKindAttr);
+ designateOp.getMemrefMutable().assign(temp);
+ builder.setInsertionPointAfter(elementalOp);
+ hlfir::AssignOp::create(builder, loc, elementalOp, lhs,
+ isWholeAllocatableAssignment,
+ keepLhsLengthInAllocatableAssignment);
+ return;
+ }
auto transferKindAttr = cuf::DataTransferKindAttr::get(
builder.getContext(), cuf::DataTransferKind::DeviceHost);
cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape,
@@ -4898,7 +4929,6 @@ private:
// device = device
if (lhsIsDevice && rhsIsDevice) {
- assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
auto transferKindAttr = cuf::DataTransferKindAttr::get(
builder.getContext(), cuf::DataTransferKind::DeviceDevice);
cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape,
@@ -5037,7 +5067,9 @@ private:
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
if (isCUDATransfer && !hasCUDAImplicitTransfer)
- genCUDADataTransfer(builder, loc, assign, lhs, rhs);
+ genCUDADataTransfer(builder, loc, assign, lhs, rhs,
+ isWholeAllocatableAssignment,
+ keepLhsLengthInAllocatableAssignment);
else
hlfir::AssignOp::create(builder, loc, rhs, lhs,
isWholeAllocatableAssignment,
diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt
index 8e20abf..eb4d57d 100644
--- a/flang/lib/Lower/CMakeLists.txt
+++ b/flang/lib/Lower/CMakeLists.txt
@@ -15,6 +15,7 @@ add_flang_library(FortranLower
ConvertProcedureDesignator.cpp
ConvertType.cpp
ConvertVariable.cpp
+ CUDA.cpp
CustomIntrinsicCall.cpp
HlfirIntrinsics.cpp
HostAssociations.cpp
@@ -59,6 +60,7 @@ add_flang_library(FortranLower
FortranParser
FortranEvaluate
FortranSemantics
+ FortranUtils
LINK_COMPONENTS
Support
diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp
new file mode 100644
index 0000000..1293d2c
--- /dev/null
+++ b/flang/lib/Lower/CUDA.cpp
@@ -0,0 +1,167 @@
+//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Lower/CUDA.h"
+#include "flang/Lower/AbstractConverter.h"
+#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+
+#define DEBUG_TYPE "flang-lower-cuda"
+
+/// Emit a cuf.set_allocator_index operation for every device-allocatable
+/// ultimate component of the derived-type object \p sym held in \p box.
+///
+/// Nothing is generated when \p sym is not an object entity of derived type,
+/// when the derived type has no CUDA device allocatable ultimate component,
+/// or when \p box is neither an allocatable nor a pointer.
+///
+/// For array objects a fir.do_loop nest is generated over the dimensions
+/// read from the loaded box, and the components are addressed per element.
+/// The builder insertion point is restored on return so that the caller
+/// keeps emitting code after the generated operations (and not inside the
+/// loop nest).
+///
+/// \param converter lowering bridge providing the builder and the location.
+/// \param sym       symbol of the object that has just been allocated.
+/// \param box       mutable box of the allocated object.
+void Fortran::lower::initializeDeviceComponentAllocator(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) {
+  if (const auto *details{
+          sym.GetUltimate()
+              .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
+    const Fortran::semantics::DeclTypeSpec *type{details->type()};
+    const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
+                                                            : nullptr};
+    if (derived) {
+      if (!FindCUDADeviceAllocatableUltimateComponent(*derived))
+        return; // No device components.
+
+      fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+      mlir::Location loc = converter.getCurrentLocation();
+
+      mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType());
+
+      // Only pointers and allocatables need post-allocation initialization
+      // of their component descriptors.
+      if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy))
+        return;
+
+      // Extract the derived type.
+      mlir::Type ty = fir::getDerivedType(baseTy);
+      auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
+      assert(recTy && "expected fir::RecordType");
+
+      if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy))
+        baseTy = boxTy.getEleTy();
+      baseTy = fir::unwrapRefType(baseTy);
+
+      // The array case below moves the insertion point into the body of the
+      // generated loops. Restore it when leaving this function, otherwise
+      // whatever the caller emits next would end up inside the loop nest.
+      mlir::OpBuilder::InsertionGuard insertionGuard(builder);
+
+      Fortran::semantics::UltimateComponentIterator components{*derived};
+      mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr());
+      mlir::Value addr;
+      if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) {
+        mlir::Type idxTy = builder.getIndexType();
+        mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+        mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+        llvm::SmallVector<mlir::Value> indices;
+        llvm::SmallVector<mlir::Value> extents;
+        // NOTE(review): the three fir.box_dims results are used directly as
+        // the loop lower bound, upper bound and step, and result(0) +
+        // result(1) - 1 (clamped to zero) is used as the shape extent.
+        // Confirm this is the intended mapping for non-default lower bounds.
+        for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
+          mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
+          auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy,
+                                                idxTy, loadedBox, dim);
+          mlir::Value lbub = mlir::arith::AddIOp::create(
+              builder, loc, dimInfo.getResult(0), dimInfo.getResult(1));
+          mlir::Value ext =
+              mlir::arith::SubIOp::create(builder, loc, lbub, one);
+          mlir::Value cmp = mlir::arith::CmpIOp::create(
+              builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero);
+          ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero);
+          extents.push_back(ext);
+
+          auto loop = fir::DoLoopOp::create(
+              builder, loc, dimInfo.getResult(0), dimInfo.getResult(1),
+              dimInfo.getResult(2), /*isUnordered=*/true,
+              /*finalCount=*/false, mlir::ValueRange{});
+          indices.push_back(loop.getInductionVar());
+          // Everything that follows is emitted inside this loop body.
+          builder.setInsertionPointToStart(loop.getBody());
+        }
+        mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox);
+        auto shape = fir::ShapeOp::create(builder, loc, extents);
+        addr = fir::ArrayCoorOp::create(
+            builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape,
+            /*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{});
+      } else {
+        addr = fir::BoxAddrOp::create(builder, loc, loadedBox);
+      }
+      // Set the allocator index on each device-allocatable component of the
+      // (element of the) object addressed by `addr`.
+      for (const auto &compSym : components) {
+        if (Fortran::semantics::IsDeviceAllocatable(compSym)) {
+          llvm::SmallVector<mlir::Value> coord;
+          mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType(
+              builder, loc, compSym, recTy, coord);
+          assert(coord.size() == 1 && "expect one coordinate");
+          mlir::Value comp = fir::CoordinateOp::create(
+              builder, loc, builder.getRefType(fieldTy), addr, coord[0]);
+          cuf::DataAttributeAttr dataAttr =
+              Fortran::lower::translateSymbolCUFDataAttribute(
+                  builder.getContext(), compSym);
+          cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr);
+        }
+      }
+    }
+  }
+}
+
+/// Append to \p coordinates the fir.field_index value(s) leading to the
+/// component named like \p sym inside \p recTy, and return the type of that
+/// component.
+///
+/// The component is looked up first among the fields of \p recTy itself
+/// (one coordinate is appended), then among the fields of each record-typed
+/// field of \p recTy (two coordinates are appended: the enclosing field,
+/// then the component). A TODO is raised when no coordinate could be
+/// produced.
+mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType(
+    fir::FirOpBuilder &builder, mlir::Location loc,
+    const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
+    llvm::SmallVector<mlir::Value> &coordinates) {
+  constexpr unsigned notFound = std::numeric_limits<unsigned>::max();
+  const auto componentName = sym.name().ToString();
+
+  // Create the fir.field_index of field `name` in record `owner` and record
+  // it as the next coordinate.
+  auto appendFieldIndex = [&](const auto &name, fir::RecordType owner,
+                              mlir::MLIRContext *context) {
+    mlir::Value index = fir::FieldIndexOp::create(
+        builder, loc, fir::FieldType::get(context), name, owner,
+        /*typeParams=*/mlir::ValueRange{});
+    coordinates.push_back(index);
+  };
+
+  mlir::Type fieldTy;
+  if (unsigned directIdx = recTy.getFieldIndex(componentName);
+      directIdx != notFound) {
+    // The component is a direct field of the base record type.
+    const auto &field = recTy.getTypeList()[directIdx];
+    fieldTy = field.second;
+    appendFieldIndex(field.first, recTy, fieldTy.getContext());
+  } else {
+    // Otherwise look one level down, in the record-typed fields.
+    for (const auto &member : recTy.getTypeList()) {
+      auto nestedRecTy = mlir::dyn_cast<fir::RecordType>(member.second);
+      if (!nestedRecTy)
+        continue;
+      unsigned nestedIdx = nestedRecTy.getFieldIndex(componentName);
+      if (nestedIdx == notFound)
+        continue;
+      appendFieldIndex(member.first, recTy, nestedRecTy.getContext());
+      const auto &field = nestedRecTy.getTypeList()[nestedIdx];
+      fieldTy = field.second;
+      appendFieldIndex(field.first, nestedRecTy, fieldTy.getContext());
+      break;
+    }
+  }
+
+  if (coordinates.empty())
+    TODO(loc, "device resident component in complex derived-type hierarchy");
+  return fieldTy;
+}
+
+/// Build the cuf data attribute matching the CUDA data attribute (if any)
+/// carried by the ultimate symbol of \p sym.
+cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
+    mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
+  const Fortran::semantics::Symbol &ultimate = sym.GetUltimate();
+  std::optional<Fortran::common::CUDADataAttr> attrOnSymbol{
+      Fortran::semantics::GetCUDADataAttr(&ultimate)};
+  return cuf::getDataAttribute(mlirContext, attrOnSymbol);
+}
+
+/// Return true when \p rhs is the result of an hlfir.elemental whose body
+/// contains exactly one hlfir.designate, exactly one fir.load and exactly
+/// one fir.convert operation.
+///
+/// Values that are not defined by an hlfir.elemental (including block
+/// arguments, which have no defining operation) yield false.
+bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
+  // Value::getDefiningOp<OpTy>() returns a null op both for block arguments
+  // and for results of other operations, whereas mlir::dyn_cast on a null
+  // Operation* is invalid.
+  auto elementalOp = rhs.getDefiningOp<hlfir::ElementalOp>();
+  if (!elementalOp)
+    return false;
+  auto *body = elementalOp.getBody();
+  return llvm::hasSingleElement(body->getOps<hlfir::DesignateOp>()) &&
+         llvm::hasSingleElement(body->getOps<fir::LoadOp>()) &&
+         llvm::hasSingleElement(body->getOps<fir::ConvertOp>());
+}
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index bf713f5..04dcc92 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -880,9 +880,10 @@ struct CallContext {
std::optional<mlir::Type> resultType, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symMap,
- Fortran::lower::StatementContext &stmtCtx)
+ Fortran::lower::StatementContext &stmtCtx, bool doCopyIn = true)
: procRef{procRef}, converter{converter}, symMap{symMap},
- stmtCtx{stmtCtx}, resultType{resultType}, loc{loc} {}
+ stmtCtx{stmtCtx}, resultType{resultType}, loc{loc}, doCopyIn{doCopyIn} {
+ }
fir::FirOpBuilder &getBuilder() { return converter.getFirOpBuilder(); }
@@ -924,6 +925,7 @@ struct CallContext {
Fortran::lower::StatementContext &stmtCtx;
std::optional<mlir::Type> resultType;
mlir::Location loc;
+ bool doCopyIn;
};
using ExvAndCleanup =
@@ -1161,18 +1163,6 @@ mlir::Value static getZeroLowerBounds(mlir::Location loc,
return builder.genShift(loc, lowerBounds);
}
-static bool
-isSimplyContiguous(const Fortran::evaluate::ActualArgument &arg,
- Fortran::evaluate::FoldingContext &foldingContext) {
- if (const auto *expr = arg.UnwrapExpr())
- return Fortran::evaluate::IsSimplyContiguous(*expr, foldingContext);
- const Fortran::semantics::Symbol *sym = arg.GetAssumedTypeDummy();
- assert(sym &&
- "expect ActualArguments to be expression or assumed-type symbols");
- return sym->Rank() == 0 ||
- Fortran::evaluate::IsSimplyContiguous(*sym, foldingContext);
-}
-
static bool isParameterObjectOrSubObject(hlfir::Entity entity) {
mlir::Value base = entity;
bool foundParameter = false;
@@ -1204,6 +1194,10 @@ static bool isParameterObjectOrSubObject(hlfir::Entity entity) {
/// fir.box_char...).
/// This function should only be called with an actual that is present.
/// The optional aspects must be handled by this function user.
+///
+/// Note: while Fortran::lower::CallerInterface::PassedEntity (the type of arg)
+/// is technically a template type, in the prepare*ActualArgument() calls
+/// it resolves to Fortran::evaluate::ActualArgument *
static PreparedDummyArgument preparePresentUserCallActualArgument(
mlir::Location loc, fir::FirOpBuilder &builder,
const Fortran::lower::PreparedActualArgument &preparedActual,
@@ -1211,9 +1205,6 @@ static PreparedDummyArgument preparePresentUserCallActualArgument(
const Fortran::lower::CallerInterface::PassedEntity &arg,
CallContext &callContext) {
- Fortran::evaluate::FoldingContext &foldingContext =
- callContext.converter.getFoldingContext();
-
// Step 1: get the actual argument, which includes addressing the
// element if this is an array in an elemental call.
hlfir::Entity actual = preparedActual.getActual(loc, builder);
@@ -1254,13 +1245,20 @@ static PreparedDummyArgument preparePresentUserCallActualArgument(
passingPolymorphicToNonPolymorphic &&
(actual.isArray() || mlir::isa<fir::BaseBoxType>(dummyType));
- // The simple contiguity of the actual is "lost" when passing a polymorphic
- // to a non polymorphic entity because the dummy dynamic type matters for
- // the contiguity.
- const bool mustDoCopyInOut =
- actual.isArray() && arg.mustBeMadeContiguous() &&
- (passingPolymorphicToNonPolymorphic ||
- !isSimplyContiguous(*arg.entity, foldingContext));
+ bool mustDoCopyIn{false};
+ bool mustDoCopyOut{false};
+
+ if (callContext.doCopyIn) {
+ Fortran::evaluate::FoldingContext &foldingContext{
+ callContext.converter.getFoldingContext()};
+
+ bool suggestCopyIn = Fortran::evaluate::MayNeedCopy(
+ arg.entity, arg.characteristics, foldingContext, /*forCopyOut=*/false);
+ bool suggestCopyOut = Fortran::evaluate::MayNeedCopy(
+ arg.entity, arg.characteristics, foldingContext, /*forCopyOut=*/true);
+ mustDoCopyIn = actual.isArray() && suggestCopyIn;
+ mustDoCopyOut = actual.isArray() && suggestCopyOut;
+ }
const bool actualIsAssumedRank = actual.isAssumedRank();
// Create dummy type with actual argument rank when the dummy is an assumed
@@ -1370,8 +1368,14 @@ static PreparedDummyArgument preparePresentUserCallActualArgument(
entity = hlfir::Entity{associate.getBase()};
// Register the temporary destruction after the call.
preparedDummy.pushExprAssociateCleanUp(associate);
- } else if (mustDoCopyInOut) {
+ } else if (mustDoCopyIn || mustDoCopyOut) {
// Copy-in non contiguous variables.
+ //
+ // TODO: copy-in and copy-out are now determined separately, in order
+ // to allow more fine grained copying. While currently both copy-in
+ // and copy-out must be done together, these copy operations could
+ // be separated in the future. (This is related to TODO comment below.)
+ //
// TODO: for non-finalizable monomorphic derived type actual
// arguments associated with INTENT(OUT) dummy arguments
// we may avoid doing the copy and only allocate the temporary.
@@ -1379,7 +1383,7 @@ static PreparedDummyArgument preparePresentUserCallActualArgument(
// allocation for the temp in this case. We can communicate
// this to the codegen via some CopyInOp flag.
// This is a performance concern.
- entity = genCopyIn(entity, arg.mayBeModifiedByCall());
+ entity = genCopyIn(entity, mustDoCopyOut);
}
} else {
const Fortran::lower::SomeExpr *expr = arg.entity->UnwrapExpr();
@@ -2966,8 +2970,11 @@ void Fortran::lower::convertUserDefinedAssignmentToHLFIR(
const evaluate::ProcedureRef &procRef, hlfir::Entity lhs, hlfir::Entity rhs,
Fortran::lower::SymMap &symMap) {
Fortran::lower::StatementContext definedAssignmentContext;
+ // For defined assignment, don't use regular copy-in/copy-out mechanism:
+ // defined assignment generates hlfir.region_assign construct, and this
+ // construct automatically handles any copy-in.
CallContext callContext(procRef, /*resultType=*/std::nullopt, loc, converter,
- symMap, definedAssignmentContext);
+ symMap, definedAssignmentContext, /*doCopyIn=*/false);
Fortran::lower::CallerInterface caller(procRef, converter);
mlir::FunctionType callSiteType = caller.genFunctionType();
PreparedActualArgument preparedLhs{lhs, /*isPresent=*/std::nullopt};
diff --git a/flang/lib/Lower/ConvertConstant.cpp b/flang/lib/Lower/ConvertConstant.cpp
index 768a237..376ec12 100644
--- a/flang/lib/Lower/ConvertConstant.cpp
+++ b/flang/lib/Lower/ConvertConstant.cpp
@@ -145,6 +145,9 @@ private:
fir::FirOpBuilder &builder,
const Fortran::evaluate::Constant<Fortran::evaluate::Type<TC, KIND>>
&constant) {
+ using Element =
+ Fortran::evaluate::Scalar<Fortran::evaluate::Type<TC, KIND>>;
+
static_assert(TC != Fortran::common::TypeCategory::Character,
"must be numerical or logical");
auto attrTc = TC == Fortran::common::TypeCategory::Logical
@@ -152,7 +155,24 @@ private:
: TC;
attributeElementType =
Fortran::lower::getFIRType(builder.getContext(), attrTc, KIND, {});
- for (auto element : constant.values())
+
+ const std::vector<Element> &values = constant.values();
+ auto sameElements = [&]() -> bool {
+ if (values.empty())
+ return false;
+
+ return std::all_of(values.begin(), values.end(),
+ [&](const auto &v) { return v == values.front(); });
+ };
+
+ if (sameElements()) {
+ auto attr = convertToAttribute<TC, KIND>(builder, values.front(),
+ attributeElementType);
+ attributes.assign(values.size(), attr);
+ return;
+ }
+
+ for (auto element : values)
attributes.push_back(
convertToAttribute<TC, KIND>(builder, element, attributeElementType));
}
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 5588f62..d7f94e1 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2750,7 +2750,7 @@ public:
fir::unwrapSequenceType(fir::unwrapPassByRefType(argTy))))
TODO(loc, "passing to an OPTIONAL CONTIGUOUS derived type argument "
"with length parameters");
- if (Fortran::evaluate::IsAssumedRank(*expr))
+ if (Fortran::semantics::IsAssumedRank(*expr))
TODO(loc, "passing an assumed rank entity to an OPTIONAL "
"CONTIGUOUS argument");
// Assumed shape VALUE are currently TODO in the call interface
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 9930dd6..81e09a1 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -26,7 +26,6 @@
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
-#include "flang/Optimizer/Builder/Runtime/Character.h"
#include "flang/Optimizer/Builder/Runtime/Derived.h"
#include "flang/Optimizer/Builder/Runtime/Pointer.h"
#include "flang/Optimizer/Builder/Todo.h"
@@ -1286,16 +1285,8 @@ struct BinaryOp<Fortran::evaluate::Relational<
fir::FirOpBuilder &builder,
const Op &op, hlfir::Entity lhs,
hlfir::Entity rhs) {
- auto [lhsExv, lhsCleanUp] =
- hlfir::translateToExtendedValue(loc, builder, lhs);
- auto [rhsExv, rhsCleanUp] =
- hlfir::translateToExtendedValue(loc, builder, rhs);
- auto cmp = fir::runtime::genCharCompare(
- builder, loc, translateSignedRelational(op.opr), lhsExv, rhsExv);
- if (lhsCleanUp)
- (*lhsCleanUp)();
- if (rhsCleanUp)
- (*rhsCleanUp)();
+ auto cmp = hlfir::CmpCharOp::create(
+ builder, loc, translateSignedRelational(op.opr), lhs, rhs);
return hlfir::EntityWithAttributes{cmp};
}
};
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index a4a8a69..80af7f4 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -14,12 +14,12 @@
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/BoxAnalyzer.h"
+#include "flang/Lower/CUDA.h"
#include "flang/Lower/CallInterface.h"
#include "flang/Lower/ConvertConstant.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertProcedureDesignator.h"
-#include "flang/Lower/Cuda.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
@@ -814,81 +814,24 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
baseTy = boxTy.getEleTy();
baseTy = fir::unwrapRefType(baseTy);
- if (mlir::isa<fir::SequenceType>(baseTy) &&
- (fir::isAllocatableType(fir::getBase(exv).getType()) ||
- fir::isPointerType(fir::getBase(exv).getType())))
+ if (fir::isAllocatableType(fir::getBase(exv).getType()) ||
+ fir::isPointerType(fir::getBase(exv).getType()))
return; // Allocator index need to be set after allocation.
auto recTy =
mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy));
assert(recTy && "expected fir::RecordType");
- llvm::SmallVector<mlir::Value> coordinates;
Fortran::semantics::UltimateComponentIterator components{*derived};
for (const auto &sym : components) {
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
- unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
- mlir::Type fieldTy;
- llvm::SmallVector<mlir::Value> coordinates;
-
- if (fieldIdx != std::numeric_limits<unsigned>::max()) {
- // Field found in the base record type.
- auto fieldName = recTy.getTypeList()[fieldIdx].first;
- fieldTy = recTy.getTypeList()[fieldIdx].second;
- mlir::Value fieldIndex = fir::FieldIndexOp::create(
- builder, loc, fir::FieldType::get(fieldTy.getContext()),
- fieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- coordinates.push_back(fieldIndex);
- } else {
- // Field not found in base record type, search in potential
- // record type components.
- for (auto component : recTy.getTypeList()) {
- if (auto childRecTy =
- mlir::dyn_cast<fir::RecordType>(component.second)) {
- fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
- if (fieldIdx != std::numeric_limits<unsigned>::max()) {
- mlir::Value parentFieldIndex = fir::FieldIndexOp::create(
- builder, loc,
- fir::FieldType::get(childRecTy.getContext()),
- component.first, recTy,
- /*typeParams=*/mlir::ValueRange{});
- coordinates.push_back(parentFieldIndex);
- auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
- fieldTy = childRecTy.getTypeList()[fieldIdx].second;
- mlir::Value childFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fir::FieldType::get(fieldTy.getContext()),
- fieldName, childRecTy,
- /*typeParams=*/mlir::ValueRange{});
- coordinates.push_back(childFieldIndex);
- break;
- }
- }
- }
- }
-
- if (coordinates.empty())
- TODO(loc, "device resident component in complex derived-type "
- "hierarchy");
-
+ llvm::SmallVector<mlir::Value> coord;
+ mlir::Type fieldTy =
+ Fortran::lower::gatherDeviceComponentCoordinatesAndType(
+ builder, loc, sym, recTy, coord);
mlir::Value base = fir::getBase(exv);
- mlir::Value comp;
- if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(base.getType()))) {
- mlir::Value box = fir::LoadOp::create(builder, loc, base);
- mlir::Value addr = fir::BoxAddrOp::create(builder, loc, box);
- llvm::SmallVector<mlir::Value> lenParams;
- assert(coordinates.size() == 1 && "expect one coordinate");
- auto field = mlir::dyn_cast<fir::FieldIndexOp>(
- coordinates[0].getDefiningOp());
- comp = hlfir::DesignateOp::create(
- builder, loc, builder.getRefType(fieldTy), addr,
- /*component=*/field.getFieldName(),
- /*componentShape=*/mlir::Value{},
- hlfir::DesignateOp::Subscripts{});
- } else {
- comp = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(fieldTy), base, coordinates);
- }
+ mlir::Value comp = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(fieldTy), base, coord);
cuf::DataAttributeAttr dataAttr =
Fortran::lower::translateSymbolCUFDataAttribute(
builder.getContext(), sym);
@@ -1777,7 +1720,7 @@ static bool lowerToBoxValue(const Fortran::semantics::Symbol &sym,
return true;
// Assumed rank and optional fir.box cannot yet be read while lowering the
// specifications.
- if (Fortran::evaluate::IsAssumedRank(sym) ||
+ if (Fortran::semantics::IsAssumedRank(sym) ||
Fortran::semantics::IsOptional(sym))
return true;
// Polymorphic entity should be tracked through a fir.box that has the
@@ -1950,13 +1893,6 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
}
-cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
- mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
- std::optional<Fortran::common::CUDADataAttr> cudaAttr =
- Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate());
- return cuf::getDataAttribute(mlirContext, cudaAttr);
-}
-
static bool
isCapturedInInternalProcedure(Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym) {
@@ -2236,7 +2172,7 @@ void Fortran::lower::mapSymbolAttributes(
return;
}
- const bool isAssumedRank = Fortran::evaluate::IsAssumedRank(sym);
+ const bool isAssumedRank = Fortran::semantics::IsAssumedRank(sym);
if (isAssumedRank && !allowAssumedRank)
TODO(loc, "assumed-rank variable in procedure implemented in Fortran");
diff --git a/flang/lib/Lower/HlfirIntrinsics.cpp b/flang/lib/Lower/HlfirIntrinsics.cpp
index 6e1d06a..b9731e9 100644
--- a/flang/lib/Lower/HlfirIntrinsics.cpp
+++ b/flang/lib/Lower/HlfirIntrinsics.cpp
@@ -159,6 +159,18 @@ protected:
hlfir::CharExtremumPredicate pred;
};
+/// Transformational intrinsic lowering whose lowerImpl builds an
+/// hlfir::CharTrimOp from its single operand.
+class HlfirCharTrimLowering : public HlfirTransformationalIntrinsic {
+public:
+  // Constructed from a builder and a location, like the base class.
+  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+};
+
class HlfirCShiftLowering : public HlfirTransformationalIntrinsic {
public:
using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
@@ -170,6 +182,17 @@ protected:
mlir::Type stmtResultType) override;
};
+class HlfirEOShiftLowering : public HlfirTransformationalIntrinsic { // lowering helper whose lowerImpl builds an hlfir::EOShiftOp
+public:
+ using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; // reuse the base class constructors
+
+protected:
+ mlir::Value
+ lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+ const fir::IntrinsicArgumentLoweringRules *argLowering,
+ mlir::Type stmtResultType) override; // expects four operands: array, shift, boundary, dim
+};
+
class HlfirReshapeLowering : public HlfirTransformationalIntrinsic {
public:
using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
@@ -410,6 +433,15 @@ mlir::Value HlfirCharExtremumLowering::lowerImpl(
return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands});
}
+/// Lower the call to an hlfir::CharTrimOp applied to the one and only
+/// operand gathered from \p loweredActuals.
+mlir::Value HlfirCharTrimLowering::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  auto args = getOperandVector(loweredActuals, argLowering);
+  assert(args.size() == 1);
+  mlir::Value string = args[0];
+  return createOp<hlfir::CharTrimOp>(string);
+}
+
mlir::Value HlfirCShiftLowering::lowerImpl(
const Fortran::lower::PreparedActualArguments &loweredActuals,
const fir::IntrinsicArgumentLoweringRules *argLowering,
@@ -430,6 +462,46 @@ mlir::Value HlfirCShiftLowering::lowerImpl(
return createOp<hlfir::CShiftOp>(resultType, operands);
}
+mlir::Value HlfirEOShiftLowering::lowerImpl(
+ const Fortran::lower::PreparedActualArguments &loweredActuals,
+ const fir::IntrinsicArgumentLoweringRules *argLowering,
+ mlir::Type stmtResultType) {
+ auto operands = getOperandVector(loweredActuals, argLowering);
+ assert(operands.size() == 4);
+ mlir::Value array = operands[0];
+ mlir::Value shift = operands[1];
+ mlir::Value boundary = operands[2];
+ mlir::Value dim = operands[3];
+ // If DIM is present, then dereference it if it is a ref.
+ if (dim)
+ dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
+
+ mlir::Type resultType = computeResultType(array, stmtResultType);
+
+ if (boundary && fir::isa_trivial(boundary.getType())) {
+ mlir::Type elementType = hlfir::getFortranElementType(resultType);
+ if (auto logicalTy = mlir::dyn_cast<fir::LogicalType>(elementType)) {
+ // Scalar logical constant boundary might be represented using i1, i2, ...
+ // type. We need to cast it to fir.logical type of the ARRAY/result.
+ if (boundary.getType() != logicalTy)
+ boundary = builder.createConvert(loc, logicalTy, boundary);
+ } else {
+ // When the boundary is a constant like '1u', the lowering converts
+ // it into a signless arith.constant value (which is a requirement
+ // of the Arith dialect). If the ARRAY/RESULT is also UNSIGNED,
+ // we have to cast the boundary to the same unsigned type.
+ auto resultIntTy = mlir::dyn_cast<mlir::IntegerType>(elementType);
+ auto boundaryIntTy =
+ mlir::dyn_cast<mlir::IntegerType>(boundary.getType());
+ if (resultIntTy && boundaryIntTy &&
+ resultIntTy.getSignedness() != boundaryIntTy.getSignedness())
+ boundary = builder.createConvert(loc, resultIntTy, boundary);
+ }
+ }
+
+ return createOp<hlfir::EOShiftOp>(resultType, array, shift, boundary, dim);
+}
+
mlir::Value HlfirReshapeLowering::lowerImpl(
const Fortran::lower::PreparedActualArguments &loweredActuals,
const fir::IntrinsicArgumentLoweringRules *argLowering,
@@ -489,6 +561,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
if (name == "cshift")
return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
+ if (name == "eoshift")
+ return HlfirEOShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
+ stmtResultType);
if (name == "reshape")
return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
@@ -501,6 +576,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
return HlfirCharExtremumLowering{builder, loc,
hlfir::CharExtremumPredicate::max}
.lower(loweredActuals, argLowering, stmtResultType);
+ if (name == "trim")
+ return HlfirCharTrimLowering{builder, loc}.lower(
+ loweredActuals, argLowering, stmtResultType);
}
return std::nullopt;
}
diff --git a/flang/lib/Lower/HostAssociations.cpp b/flang/lib/Lower/HostAssociations.cpp
index 2a330cc..ad6aba1 100644
--- a/flang/lib/Lower/HostAssociations.cpp
+++ b/flang/lib/Lower/HostAssociations.cpp
@@ -431,7 +431,7 @@ public:
mlir::Value box = args.valueInTuple;
mlir::IndexType idxTy = builder.getIndexType();
llvm::SmallVector<mlir::Value> lbounds;
- if (!ba.lboundIsAllOnes() && !Fortran::evaluate::IsAssumedRank(sym)) {
+ if (!ba.lboundIsAllOnes() && !Fortran::semantics::IsAssumedRank(sym)) {
if (ba.isStaticArray()) {
for (std::int64_t lb : ba.staticLBound())
lbounds.emplace_back(builder.createIntegerConstant(loc, idxTy, lb));
@@ -490,7 +490,7 @@ private:
bool isPolymorphic = type && type->IsPolymorphic();
return isScalarOrContiguous && !isPolymorphic &&
!isDerivedWithLenParameters(sym) &&
- !Fortran::evaluate::IsAssumedRank(sym);
+ !Fortran::semantics::IsAssumedRank(sym);
}
};
} // namespace
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 35edcb0..7a84b21 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1575,7 +1575,7 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
if (bounds.empty()) {
llvm::SmallVector<mlir::Value> extents;
mlir::Type idxTy = builder.getIndexType();
- for (auto extent : seqTy.getShape()) {
+ for (auto extent : llvm::reverse(seqTy.getShape())) {
mlir::Value lb = mlir::arith::ConstantOp::create(
builder, loc, idxTy, builder.getIntegerAttr(idxTy, 0));
mlir::Value ub = mlir::arith::ConstantOp::create(
@@ -1607,12 +1607,11 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
}
} else {
// Lowerbound, upperbound and step are passed as block arguments.
- [[maybe_unused]] unsigned nbRangeArgs =
+ unsigned nbRangeArgs =
recipe.getCombinerRegion().getArguments().size() - 2;
assert((nbRangeArgs / 3 == seqTy.getDimension()) &&
"Expect 3 block arguments per dimension");
- for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
- i += 3) {
+ for (int i = nbRangeArgs - 1; i >= 2; i -= 3) {
mlir::Value lb = recipe.getCombinerRegion().getArgument(i);
mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1);
mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2);
@@ -1623,8 +1622,11 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
ivs.push_back(loop.getInductionVar());
}
}
- auto addr1 = fir::CoordinateOp::create(builder, loc, refTy, value1, ivs);
- auto addr2 = fir::CoordinateOp::create(builder, loc, refTy, value2, ivs);
+ llvm::SmallVector<mlir::Value> reversedIvs(ivs.rbegin(), ivs.rend());
+ auto addr1 =
+ fir::CoordinateOp::create(builder, loc, refTy, value1, reversedIvs);
+ auto addr2 =
+ fir::CoordinateOp::create(builder, loc, refTy, value2, reversedIvs);
auto load1 = fir::LoadOp::create(builder, loc, addr1);
auto load2 = fir::LoadOp::create(builder, loc, addr2);
mlir::Value res =
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index ed0bff0..ff82a36 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -43,179 +43,6 @@ namespace omp {
using namespace Fortran::lower::omp;
}
-namespace {
-// An example of a type that can be used to get the return value from
-// the visitor:
-// visitor(type_identity<Xyz>) -> result_type
-using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
-
-struct GetProc
- : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
- false> {
- using Result = const evaluate::ProcedureDesignator *;
- using Base = evaluate::Traverse<GetProc, Result, false>;
- GetProc() : Base(*this) {}
-
- using Base::operator();
-
- static Result Default() { return nullptr; }
-
- Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
- static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
-};
-
-struct WithType {
- WithType(const evaluate::DynamicType &t) : type(t) {
- assert(type.category() != common::TypeCategory::Derived &&
- "Type cannot be a derived type");
- }
-
- template <typename VisitorTy> //
- auto visit(VisitorTy &&visitor) const
- -> std::invoke_result_t<VisitorTy, SomeArgType> {
- switch (type.category()) {
- case common::TypeCategory::Integer:
- switch (type.kind()) {
- case 1:
- return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
- case 2:
- return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
- case 4:
- return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
- case 8:
- return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
- case 16:
- return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
- }
- break;
- case common::TypeCategory::Unsigned:
- switch (type.kind()) {
- case 1:
- return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
- case 2:
- return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
- case 4:
- return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
- case 8:
- return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
- case 16:
- return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
- }
- break;
- case common::TypeCategory::Real:
- switch (type.kind()) {
- case 2:
- return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
- case 3:
- return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
- case 4:
- return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
- case 8:
- return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
- case 10:
- return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
- case 16:
- return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
- }
- break;
- case common::TypeCategory::Complex:
- switch (type.kind()) {
- case 2:
- return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
- case 3:
- return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
- case 4:
- return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
- case 8:
- return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
- case 10:
- return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
- case 16:
- return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
- }
- break;
- case common::TypeCategory::Logical:
- switch (type.kind()) {
- case 1:
- return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
- case 2:
- return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
- case 4:
- return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
- case 8:
- return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
- }
- break;
- case common::TypeCategory::Character:
- switch (type.kind()) {
- case 1:
- return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
- case 2:
- return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
- case 4:
- return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
- }
- break;
- case common::TypeCategory::Derived:
- (void)Derived;
- break;
- }
- llvm_unreachable("Unhandled type");
- }
-
- const evaluate::DynamicType &type;
-
-private:
- // Shorter names.
- static constexpr auto Character = common::TypeCategory::Character;
- static constexpr auto Complex = common::TypeCategory::Complex;
- static constexpr auto Derived = common::TypeCategory::Derived;
- static constexpr auto Integer = common::TypeCategory::Integer;
- static constexpr auto Logical = common::TypeCategory::Logical;
- static constexpr auto Real = common::TypeCategory::Real;
- static constexpr auto Unsigned = common::TypeCategory::Unsigned;
-};
-
-template <typename T, typename U = std::remove_const_t<T>>
-U AsRvalue(T &t) {
- U copy{t};
- return std::move(copy);
-}
-
-template <typename T>
-T &&AsRvalue(T &&t) {
- return std::move(t);
-}
-
-struct ArgumentReplacer
- : public evaluate::Traverse<ArgumentReplacer, bool, false> {
- using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
- using Result = bool;
-
- Result Default() const { return false; }
-
- ArgumentReplacer(evaluate::ActualArguments &&newArgs)
- : Base(*this), args_(std::move(newArgs)) {}
-
- using Base::operator();
-
- template <typename T>
- Result operator()(const evaluate::FunctionRef<T> &x) {
- assert(!done_);
- auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
- mut.arguments() = args_;
- done_ = true;
- return true;
- }
-
- Result Combine(Result &&a, Result &&b) { return a || b; }
-
-private:
- bool done_{false};
- evaluate::ActualArguments &&args_;
-};
-} // namespace
-
[[maybe_unused]] static void
dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
auto whatStr = [](int k) {
@@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
return nullptr;
}
-static bool replaceArgs(semantics::SomeExpr &expr,
- evaluate::ActualArguments &&newArgs) {
- return ArgumentReplacer(std::move(newArgs))(expr);
-}
-
-static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
- const evaluate::ProcedureDesignator &proc,
- const evaluate::ActualArguments &args) {
- return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
- using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
- return evaluate::AsGenericExpr(
- evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
- });
-}
-
-static const evaluate::ProcedureDesignator &
-getProcedureDesignator(const semantics::SomeExpr &call) {
- const evaluate::ProcedureDesignator *proc = GetProc{}(call);
- assert(proc && "Call has no procedure designator");
- return *proc;
-}
-
-static semantics::SomeExpr //
-genReducedMinMax(const semantics::SomeExpr &orig,
- const semantics::SomeExpr *atomArg,
- const std::vector<semantics::SomeExpr> &args) {
- // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
- // One of the a_i's, say a_t, must be atomArg.
- // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
- // call = min/max(a_t, tmp).
- // Return "call".
-
- // The min/max intrinsics have 2 mandatory arguments, the rest is optional.
- // Make sure that the "tmp = min/max(...)" doesn't promote an optional
- // argument to a non-optional position. This could happen if a_t is at
- // position 0 or 1.
- if (args.size() <= 2)
- return orig;
-
- evaluate::ActualArguments nonAtoms;
-
- auto AsActual = [](const semantics::SomeExpr &x) {
- semantics::SomeExpr copy = x;
- return evaluate::ActualArgument(std::move(copy));
- };
- // Semantic checks guarantee that the "atom" shows exactly once in the
- // argument list (with potential conversions around it).
- // For the first two (non-optional) arguments, if "atom" is among them,
- // replace it with another occurrence of the other non-optional argument.
- if (atomArg == &args[0]) {
- // (atom, x, y...) -> (x, x, y...)
- nonAtoms.push_back(AsActual(args[1]));
- nonAtoms.push_back(AsActual(args[1]));
- } else if (atomArg == &args[1]) {
- // (x, atom, y...) -> (x, x, y...)
- nonAtoms.push_back(AsActual(args[0]));
- nonAtoms.push_back(AsActual(args[0]));
- } else {
- // (x, y, z...) -> unchanged
- nonAtoms.push_back(AsActual(args[0]));
- nonAtoms.push_back(AsActual(args[1]));
- }
-
- // The rest of arguments are optional, so we can just skip "atom".
- for (size_t i = 2, e = args.size(); i != e; ++i) {
- if (atomArg != &args[i])
- nonAtoms.push_back(AsActual(args[i]));
- }
-
- // The type of the intermediate min/max is the same as the type of its
- // arguments, which may be different from the type of the original
- // expression. The original expression may have additional coverts.
- auto tmp =
- makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
- semantics::SomeExpr call = orig;
- replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
- return call;
-}
-
static mlir::Operation * //
genAtomicRead(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter,
auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input);
assert(!args.empty() && "Update operation without arguments");
- // Pass args as an argument to avoid capturing a structured binding.
- const semantics::SomeExpr *atomArg = [&](auto &args) {
- for (const semantics::SomeExpr &e : args) {
- if (evaluate::IsSameOrConvertOf(e, atom))
- return &e;
- }
- llvm_unreachable("Atomic variable not in argument list");
- }(args);
-
- if (opcode == evaluate::operation::Operator::Min ||
- opcode == evaluate::operation::Operator::Max) {
- // Min and max operations are expanded inline, so reduce them to
- // operations with exactly two (non-optional) arguments.
- rhs = genReducedMinMax(rhs, atomArg, args);
- input = *evaluate::GetConvertInput(rhs);
- std::tie(opcode, args) =
- evaluate::GetTopLevelOperationIgnoreResizing(input);
- atomArg = nullptr; // No longer valid.
- }
for (auto &arg : args) {
if (!evaluate::IsSameOrConvertOf(arg, atom)) {
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b98ad3c..6b9bd66 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -19,6 +19,7 @@
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
+#include "flang/Utils/OpenMP.h"
#include "llvm/Frontend/OpenMP/OMP.h.inc"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -647,10 +648,8 @@ addAlignedClause(lower::AbstractConverter &converter,
// The default alignment for some targets is equal to 0.
// Do not generate alignment assumption if alignment is less than or equal to
- // 0.
- if (alignment > 0) {
- // alignment value must be power of 2
- assert((alignment & (alignment - 1)) == 0 && "alignment is not power of 2");
+ // 0 or not a power of two.
+ if (alignment > 0 && ((alignment & (alignment - 1)) == 0)) {
auto &objects = std::get<omp::ObjectList>(clause.t);
if (!objects.empty())
genObjectList(objects, converter, alignedVars);
@@ -1179,12 +1178,13 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
}
bool ClauseProcessor::processLink(
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
+ llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
return findRepeatableClause<omp::clause::Link>(
[&](const omp::clause::Link &clause, const parser::CharBlock &) {
// Case: declare target link(var1, var2)...
gatherFuncAndVarSyms(
- clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
+ clause.v, mlir::omp::DeclareTargetCaptureClause::link, result,
+ /*automap=*/false);
});
}
@@ -1280,7 +1280,7 @@ void ClauseProcessor::processMapObjects(
auto location = mlir::NameLoc::get(
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
baseOp.getLoc());
- mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+ mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp(
firOpBuilder, location, baseOp,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
@@ -1507,26 +1507,27 @@ bool ClauseProcessor::processTaskReduction(
}
bool ClauseProcessor::processTo(
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
+ llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
return findRepeatableClause<omp::clause::To>(
[&](const omp::clause::To &clause, const parser::CharBlock &) {
// Case: declare target to(func, var1, var2)...
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
- mlir::omp::DeclareTargetCaptureClause::to, result);
+ mlir::omp::DeclareTargetCaptureClause::to, result,
+ /*automap=*/false);
});
}
bool ClauseProcessor::processEnter(
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
+ llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
return findRepeatableClause<omp::clause::Enter>(
[&](const omp::clause::Enter &clause, const parser::CharBlock &source) {
- mlir::Location currentLocation = converter.genLocation(source);
- if (std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t))
- TODO(currentLocation, "Declare target enter AUTOMAP modifier");
+ bool automap =
+ std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t)
+ .has_value();
// Case: declare target enter(func, var1, var2)...
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
mlir::omp::DeclareTargetCaptureClause::enter,
- result);
+ result, automap);
});
}
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index f8a1f79..c46bdb3 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -118,7 +118,7 @@ public:
bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx,
mlir::omp::DependClauseOps &result) const;
bool
- processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
+ processEnter(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::omp::IfClauseOps &result) const;
bool processInReduction(
@@ -129,7 +129,7 @@ public:
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
- processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
+ processLink(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
// This method is used to process a map clause.
// The optional parameter mapSyms is used to store the original Fortran symbol
@@ -150,7 +150,7 @@ public:
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
- bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
+ bool processTo(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
mlir::omp::UseDeviceAddrClauseOps &result,
@@ -208,11 +208,15 @@ void ClauseProcessor::processTODO(mlir::Location currentLocation,
if (!x)
return;
unsigned version = semaCtx.langOptions().OpenMPVersion;
- TODO(currentLocation,
- "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
- " in " +
- llvm::omp::getOpenMPDirectiveName(directive, version).upper() +
- " construct");
+ bool isSimdDirective = llvm::omp::getOpenMPDirectiveName(directive, version)
+ .upper()
+ .find("SIMD") != llvm::StringRef::npos;
+ if (!semaCtx.langOptions().OpenMPSimd || isSimdDirective)
+ TODO(currentLocation,
+ "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
+ " in " +
+ llvm::omp::getOpenMPDirectiveName(directive, version).upper() +
+ " construct");
};
for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 7f75aae..1a16e1c 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -396,6 +396,8 @@ makePrescriptiveness(parser::OmpPrescriptiveness::Value v) {
switch (v) {
case parser::OmpPrescriptiveness::Value::Strict:
return clause::Prescriptiveness::Strict;
+ case parser::OmpPrescriptiveness::Value::Fallback:
+ return clause::Prescriptiveness::Fallback;
}
llvm_unreachable("Unexpected prescriptiveness");
}
@@ -770,6 +772,27 @@ Doacross make(const parser::OmpClause::Doacross &inp,
// DynamicAllocators: empty
+DynGroupprivate make(const parser::OmpClause::DynGroupprivate &inp,
+ semantics::SemanticsContext &semaCtx) {
+ // inp.v -> OmpDyngroupprivateClause
+ CLAUSET_ENUM_CONVERT( //
+ convert, parser::OmpAccessGroup::Value, DynGroupprivate::AccessGroup,
+ // clang-format off
+ MS(Cgroup, Cgroup)
+ // clang-format on
+ );
+
+ auto &mods = semantics::OmpGetModifiers(inp.v);
+ auto *m0 = semantics::OmpGetUniqueModifier<parser::OmpAccessGroup>(mods);
+ auto *m1 = semantics::OmpGetUniqueModifier<parser::OmpPrescriptiveness>(mods);
+ auto &size = std::get<parser::ScalarIntExpr>(inp.v.t);
+
+ return DynGroupprivate{
+ {/*AccessGroup=*/maybeApplyToV(convert, m0),
+ /*Prescriptiveness=*/maybeApplyToV(makePrescriptiveness, m1),
+ /*Size=*/makeExpr(size, semaCtx)}};
+}
+
Enter make(const parser::OmpClause::Enter &inp,
semantics::SemanticsContext &semaCtx) {
// inp.v -> parser::OmpEnterClause
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 67a9a46..146a252 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -30,18 +30,27 @@
#include "flang/Semantics/tools.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include <variant>
namespace Fortran {
namespace lower {
namespace omp {
bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy(
const semantics::Symbol *symbol, lower::pft::Evaluation &eval) const {
- return eval.visit(
- common::visitors{[&](const parser::OpenMPConstruct &functionParserNode) {
- return symDefMap.count(symbol) &&
- symDefMap.at(symbol) == &functionParserNode;
- },
- [](const auto &functionParserNode) { return false; }});
+ return eval.visit(common::visitors{
+ [&](const parser::OpenMPConstruct &functionParserNode) {
+ return symDefMap.count(symbol) &&
+ symDefMap.at(symbol) == ConstructPtr(&functionParserNode);
+ },
+ [](const auto &functionParserNode) { return false; }});
+}
+
+bool DataSharingProcessor::OMPConstructSymbolVisitor::
+ isSymbolDefineByNestedDeclaration(const semantics::Symbol *symbol) const {
+ return symDefMap.count(symbol) &&
+ std::holds_alternative<const parser::DeclarationConstruct *>(
+ symDefMap.at(symbol));
}
static bool isConstructWithTopLevelTarget(lower::pft::Evaluation &eval) {
@@ -81,13 +90,14 @@ DataSharingProcessor::DataSharingProcessor(lower::AbstractConverter &converter,
isTargetPrivatization) {}
void DataSharingProcessor::processStep1(
- mlir::omp::PrivateClauseOps *clauseOps) {
+ mlir::omp::PrivateClauseOps *clauseOps,
+ std::optional<llvm::omp::Directive> dir) {
collectSymbolsForPrivatization();
collectDefaultSymbols();
collectImplicitSymbols();
collectPreDeterminedSymbols();
- privatize(clauseOps);
+ privatize(clauseOps, dir);
insertBarrier(clauseOps);
}
@@ -414,47 +424,10 @@ static parser::CharBlock getSource(const semantics::SemanticsContext &semaCtx,
});
}
-static void collectPrivatizingConstructs(
- llvm::SmallSet<llvm::omp::Directive, 16> &constructs, unsigned version) {
- using Clause = llvm::omp::Clause;
- using Directive = llvm::omp::Directive;
-
- static const Clause privatizingClauses[] = {
- Clause::OMPC_private,
- Clause::OMPC_lastprivate,
- Clause::OMPC_firstprivate,
- Clause::OMPC_in_reduction,
- Clause::OMPC_reduction,
- Clause::OMPC_linear,
- // TODO: Clause::OMPC_induction,
- Clause::OMPC_task_reduction,
- Clause::OMPC_detach,
- Clause::OMPC_use_device_ptr,
- Clause::OMPC_is_device_ptr,
- };
-
- for (auto dir : llvm::enum_seq_inclusive<Directive>(Directive::First_,
- Directive::Last_)) {
- bool allowsPrivatizing = llvm::any_of(privatizingClauses, [&](Clause cls) {
- return llvm::omp::isAllowedClauseForDirective(dir, cls, version);
- });
- if (allowsPrivatizing)
- constructs.insert(dir);
- }
-}
-
bool DataSharingProcessor::isOpenMPPrivatizingConstruct(
const parser::OpenMPConstruct &omp, unsigned version) {
- static llvm::SmallSet<llvm::omp::Directive, 16> privatizing;
- [[maybe_unused]] static bool init =
- (collectPrivatizingConstructs(privatizing, version), true);
-
- // As of OpenMP 6.0, privatizing constructs (with the test being if they
- // allow a privatizing clause) are: dispatch, distribute, do, for, loop,
- // parallel, scope, sections, simd, single, target, target_data, task,
- // taskgroup, taskloop, and teams.
- return llvm::is_contained(privatizing,
- parser::omp::GetOmpDirectiveName(omp).v);
+ return llvm::omp::isPrivatizingConstruct(
+ parser::omp::GetOmpDirectiveName(omp).v, version);
}
bool DataSharingProcessor::isOpenMPPrivatizingEvaluation(
@@ -550,11 +523,23 @@ void DataSharingProcessor::collectSymbols(
return false;
}
- return sym->test(semantics::Symbol::Flag::OmpImplicit);
+ // Collect implicit symbols only if they are not defined by a nested
+ // `DeclarationConstruct`. If `sym` is not defined by the current OpenMP
+ // evaluation then it is defined by a block nested within the OpenMP
+ // construct. This, in turn, means that the private allocation for the
+ // symbol will be emitted as part of the nested block and there is no need
+ // to privatize it within the OpenMP construct.
+ return !visitor.isSymbolDefineByNestedDeclaration(sym) &&
+ sym->test(semantics::Symbol::Flag::OmpImplicit);
}
- if (collectPreDetermined)
- return sym->test(semantics::Symbol::Flag::OmpPreDetermined);
+ if (collectPreDetermined) {
+ // As with implicit symbols, collect pre-determined symbols only if they
+ // are defined by `eval` and not by a nested `DeclarationConstruct`.
+ return visitor.isSymbolDefineBy(sym, eval) &&
+ !visitor.isSymbolDefineByNestedDeclaration(sym) &&
+ sym->test(semantics::Symbol::Flag::OmpPreDetermined);
+ }
return !sym->test(semantics::Symbol::Flag::OmpImplicit) &&
!sym->test(semantics::Symbol::Flag::OmpPreDetermined);
@@ -597,14 +582,15 @@ void DataSharingProcessor::collectPreDeterminedSymbols() {
preDeterminedSymbols);
}
-void DataSharingProcessor::privatize(mlir::omp::PrivateClauseOps *clauseOps) {
+void DataSharingProcessor::privatize(mlir::omp::PrivateClauseOps *clauseOps,
+ std::optional<llvm::omp::Directive> dir) {
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
if (const auto *commonDet =
sym->detailsIf<semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDet->objects())
- privatizeSymbol(&*mem, clauseOps);
+ privatizeSymbol(&*mem, clauseOps, dir);
} else
- privatizeSymbol(sym, clauseOps);
+ privatizeSymbol(sym, clauseOps, dir);
}
}
@@ -623,7 +609,8 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
void DataSharingProcessor::privatizeSymbol(
const semantics::Symbol *symToPrivatize,
- mlir::omp::PrivateClauseOps *clauseOps) {
+ mlir::omp::PrivateClauseOps *clauseOps,
+ std::optional<llvm::omp::Directive> dir) {
if (!useDelayedPrivatization) {
cloneSymbol(symToPrivatize);
copyFirstPrivateSymbol(symToPrivatize);
@@ -633,7 +620,7 @@ void DataSharingProcessor::privatizeSymbol(
Fortran::lower::privatizeSymbol<mlir::omp::PrivateClauseOp,
mlir::omp::PrivateClauseOps>(
converter, firOpBuilder, symTable, allPrivatizedSymbols,
- mightHaveReadHostSym, symToPrivatize, clauseOps);
+ mightHaveReadHostSym, symToPrivatize, clauseOps, dir);
}
} // namespace omp
} // namespace lower
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 96e7fa6..f6aa865 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include <variant>
namespace mlir {
namespace omp {
@@ -58,20 +59,35 @@ private:
}
void Post(const parser::Name &name) {
- auto *current = !constructs.empty() ? constructs.back() : nullptr;
+ auto current = !constructs.empty() ? constructs.back() : ConstructPtr();
symDefMap.try_emplace(name.symbol, current);
}
- llvm::SmallVector<const parser::OpenMPConstruct *> constructs;
- llvm::DenseMap<semantics::Symbol *, const parser::OpenMPConstruct *>
- symDefMap;
+ bool Pre(const parser::DeclarationConstruct &decl) {
+ constructs.push_back(&decl);
+ return true;
+ }
+
+ void Post(const parser::DeclarationConstruct &decl) {
+ constructs.pop_back();
+ }
/// Given a \p symbol and an \p eval, returns true if eval is the OMP
/// construct that defines symbol.
bool isSymbolDefineBy(const semantics::Symbol *symbol,
lower::pft::Evaluation &eval) const;
+ /// Given a \p symbol, returns true if it is defined by a nested
+ /// `DeclarationConstruct`.
+ bool
+ isSymbolDefineByNestedDeclaration(const semantics::Symbol *symbol) const;
+
private:
+ using ConstructPtr = std::variant<const parser::OpenMPConstruct *,
+ const parser::DeclarationConstruct *>;
+ llvm::SmallVector<ConstructPtr> constructs;
+ llvm::DenseMap<semantics::Symbol *, ConstructPtr> symDefMap;
+
unsigned version;
};
@@ -91,7 +107,7 @@ private:
lower::pft::Evaluation &eval;
bool shouldCollectPreDeterminedSymbols;
bool useDelayedPrivatization;
- llvm::SmallSet<const semantics::Symbol *, 16> mightHaveReadHostSym;
+ llvm::SmallPtrSet<const semantics::Symbol *, 16> mightHaveReadHostSym;
lower::SymMap &symTable;
bool isTargetPrivatization;
OMPConstructSymbolVisitor visitor;
@@ -110,7 +126,8 @@ private:
void collectDefaultSymbols();
void collectImplicitSymbols();
void collectPreDeterminedSymbols();
- void privatize(mlir::omp::PrivateClauseOps *clauseOps);
+ void privatize(mlir::omp::PrivateClauseOps *clauseOps,
+ std::optional<llvm::omp::Directive> dir = std::nullopt);
void copyLastPrivatize(mlir::Operation *op);
void insertLastPrivateCompare(mlir::Operation *op);
void cloneSymbol(const semantics::Symbol *sym);
@@ -151,7 +168,8 @@ public:
// Step2 performs the copying for lastprivates and requires knowledge of the
// MLIR operation to insert the last private update. Step2 adds
// dealocation code as well.
- void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr);
+ void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr,
+ std::optional<llvm::omp::Directive> dir = std::nullopt);
void processStep2(mlir::Operation *op, bool isLoop);
void pushLoopIV(mlir::Value iv) { loopIVs.push_back(iv); }
@@ -168,7 +186,8 @@ public:
}
void privatizeSymbol(const semantics::Symbol *symToPrivatize,
- mlir::omp::PrivateClauseOps *clauseOps);
+ mlir::omp::PrivateClauseOps *clauseOps,
+ std::optional<llvm::omp::Directive> dir = std::nullopt);
};
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index db6a0e2..574c322 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -34,9 +34,11 @@
#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/tools.h"
#include "flang/Support/Flags.h"
#include "flang/Support/OpenMP-utils.h"
+#include "flang/Utils/OpenMP.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Support/StateStack.h"
@@ -46,6 +48,7 @@
using namespace Fortran::lower::omp;
using namespace Fortran::common::openmp;
+using namespace Fortran::utils::openmp;
//===----------------------------------------------------------------------===//
// Code generation helper functions
@@ -406,7 +409,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
const parser::OmpClauseList *endClauseList = nullptr;
common::visit(
common::visitors{
- [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+ [&](const parser::OmpBlockConstruct &ompConstruct) {
beginClauseList = &ompConstruct.BeginDir().Clauses();
if (auto &endSpec = ompConstruct.EndDir())
endClauseList = &endSpec->Clauses();
@@ -533,6 +536,13 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
+ case OMPD_teams_workdistribute:
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
+ [[fallthrough]];
+ case OMPD_target_teams_workdistribute:
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
+ break;
+
// Standalone 'target' case.
case OMPD_target: {
processSingleNestedIf(
@@ -764,14 +774,14 @@ static void getDeclareTargetInfo(
lower::pft::Evaluation &eval,
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
mlir::omp::DeclareTargetOperands &clauseOps,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+ llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause) {
const auto &spec =
std::get<parser::OmpDeclareTargetSpecifier>(declareTargetConstruct.t);
if (const auto *objectList{parser::Unwrap<parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeObjects(*objectList, semaCtx)};
// Case: declare target(func, var1, var2)
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
- symbolAndClause);
+ symbolAndClause, /*automap=*/false);
} else if (const auto *clauseList{
parser::Unwrap<parser::OmpClauseList>(spec.u)}) {
List<Clause> clauses = makeClauses(*clauseList, semaCtx);
@@ -804,21 +814,20 @@ static void collectDeferredDeclareTargets(
llvm::SmallVectorImpl<lower::OMPDeferredDeclareTargetInfo>
&deferredDeclareTarget) {
mlir::omp::DeclareTargetOperands clauseOps;
- llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
+ llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
- for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
- mlir::Operation *op = mod.lookupSymbol(
- converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
+ for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
+ mlir::Operation *op =
+ mod.lookupSymbol(converter.mangleName(symClause.symbol));
if (!op) {
- deferredDeclareTarget.push_back({std::get<0>(symClause),
- clauseOps.deviceType,
- std::get<1>(symClause)});
+ deferredDeclareTarget.push_back({symClause.clause, clauseOps.deviceType,
+ symClause.automap, symClause.symbol});
}
}
}
@@ -829,16 +838,16 @@ getDeclareTargetFunctionDevice(
lower::pft::Evaluation &eval,
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) {
mlir::omp::DeclareTargetOperands clauseOps;
- llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
+ llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
- for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
- mlir::Operation *op = mod.lookupSymbol(
- converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
+ for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
+ mlir::Operation *op =
+ mod.lookupSymbol(converter.mangleName(symClause.symbol));
if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
return clauseOps.deviceType;
@@ -1055,7 +1064,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder,
static void
markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
mlir::omp::DeclareTargetCaptureClause captureClause,
- mlir::omp::DeclareTargetDeviceType deviceType) {
+ mlir::omp::DeclareTargetDeviceType deviceType, bool automap) {
// TODO: Add support for program local variables with declare target applied
auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
if (!declareTargetOp)
@@ -1070,11 +1079,11 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
if (declareTargetOp.isDeclareTarget()) {
if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
- captureClause);
+ captureClause, automap);
return;
}
- declareTargetOp.setDeclareTarget(deviceType, captureClause);
+ declareTargetOp.setDeclareTarget(deviceType, captureClause, automap);
}
//===----------------------------------------------------------------------===//
@@ -2262,7 +2271,8 @@ genOrderedOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
- TODO(loc, "OMPD_ordered");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(loc, "OMPD_ordered");
return nullptr;
}
@@ -2449,7 +2459,8 @@ genScopeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
- TODO(loc, "Scope construct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(loc, "Scope construct");
return nullptr;
}
@@ -2818,6 +2829,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);
}
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ return genOpWithBody<mlir::omp::WorkdistributeOp>(
+ OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+ llvm::omp::Directive::OMPD_workdistribute),
+ queue, item);
+}
+
//===----------------------------------------------------------------------===//
// Code generation functions for the standalone version of constructs that can
// also be a leaf of a composite construct
@@ -3235,7 +3257,7 @@ static mlir::omp::WsloopOp genCompositeDoSimd(
DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
/*useDelayedPrivatization=*/true, symTable);
- simdItemDSP.processStep1(&simdClauseOps);
+ simdItemDSP.processStep1(&simdClauseOps, simdItem->id);
// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
@@ -3276,7 +3298,8 @@ static mlir::omp::TaskloopOp genCompositeTaskloopSimd(
lower::pft::Evaluation &eval, mlir::Location loc,
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
- TODO(loc, "Composite TASKLOOP SIMD");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(loc, "Composite TASKLOOP SIMD");
return nullptr;
}
@@ -3448,13 +3471,18 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_tile: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
- TODO(loc, "Unhandled loop directive (" +
- llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(loc, "Unhandled loop directive (" +
+ llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
+ break;
}
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
- // case llvm::omp::Directive::OMPD_workdistribute:
+ case llvm::omp::Directive::OMPD_workdistribute:
+ newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+ item);
+ break;
case llvm::omp::Directive::OMPD_workshare:
newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
queue, item);
@@ -3484,35 +3512,40 @@ static void
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclarativeAllocate &declarativeAllocate) {
- TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPDeclarativeAssumes &assumesConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration");
}
static void
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OmpDeclareVariantDirective &declareVariantDirective) {
- TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
static void genOMP(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
}
static void
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareSimdConstruct &declareSimdConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct");
}
static void
@@ -3563,14 +3596,14 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) {
mlir::omp::DeclareTargetOperands clauseOps;
- llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
+ llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
- for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
- mlir::Operation *op = mod.lookupSymbol(
- converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
+ for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
+ mlir::Operation *op =
+ mod.lookupSymbol(converter.mangleName(symClause.symbol));
// Some symbols are deferred until later in the module, these are handled
// upon finalization of the module for OpenMP inside of Bridge, so we simply
@@ -3578,16 +3611,21 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (!op)
continue;
- markDeclareTarget(
- op, converter,
- std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
- clauseOps.deviceType);
+ markDeclareTarget(op, converter, symClause.clause, clauseOps.deviceType,
+ symClause.automap);
}
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
+ const parser::OpenMPGroupprivate &directive) {
+ TODO(converter.getCurrentLocation(), "GROUPPRIVATE");
+}
+
+static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval,
const parser::OpenMPRequiresConstruct &requiresConstruct) {
// Requires directives are gathered and processed in semantics and
// then combined in the lowering bridge before triggering codegen
@@ -3708,14 +3746,16 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
(void)objects;
(void)clauses;
- TODO(converter.getCurrentLocation(), "OpenMPDepobjConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPDepobjConstruct");
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPInteropConstruct &interopConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPInteropConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPInteropConstruct");
}
static void
@@ -3731,7 +3771,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPAllocatorsConstruct &allocsConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct");
}
//===----------------------------------------------------------------------===//
@@ -3748,7 +3789,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
- const parser::OpenMPBlockConstruct &blockConstruct) {
+ const parser::OmpBlockConstruct &blockConstruct) {
const parser::OmpDirectiveSpecification &beginSpec =
blockConstruct.BeginDir();
List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx);
@@ -3797,7 +3838,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::Detach>(clause.u)) {
std::string name =
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
- TODO(clauseLocation, name + " clause is not implemented yet");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(clauseLocation, name + " clause is not implemented yet");
}
}
@@ -3813,46 +3855,61 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::pft::Evaluation &eval,
const parser::OpenMPAssumeConstruct &assumeConstruct) {
mlir::Location clauseLocation = converter.genLocation(assumeConstruct.source);
- TODO(clauseLocation, "OpenMP ASSUME construct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(clauseLocation, "OpenMP ASSUME construct");
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPCriticalConstruct &criticalConstruct) {
- const auto &cd = std::get<parser::OmpCriticalDirective>(criticalConstruct.t);
- List<Clause> clauses =
- makeClauses(std::get<parser::OmpClauseList>(cd.t), semaCtx);
+ const parser::OmpDirectiveSpecification &beginSpec =
+ criticalConstruct.BeginDir();
+ List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx);
ConstructQueue queue{buildConstructQueue(
- converter.getFirOpBuilder().getModule(), semaCtx, eval, cd.source,
+ converter.getFirOpBuilder().getModule(), semaCtx, eval, beginSpec.source,
llvm::omp::Directive::OMPD_critical, clauses)};
- const auto &name = std::get<std::optional<parser::Name>>(cd.t);
+ std::optional<parser::Name> critName;
+ const parser::OmpArgumentList &args = beginSpec.Arguments();
+ if (!args.v.empty()) {
+ // All of these things should be guaranteed to exist after semantic checks.
+ auto *object = parser::Unwrap<parser::OmpObject>(args.v.front());
+ assert(object && "Expecting object as argument");
+ auto *designator = semantics::omp::GetDesignatorFromObj(*object);
+ assert(designator && "Expecting desginator in argument");
+ auto *name = semantics::getDesignatorNameIfDataRef(*designator);
+ assert(name && "Expecting dataref in designator");
+ critName = *name;
+ }
mlir::Location currentLocation = converter.getCurrentLocation();
genCriticalOp(converter, symTable, semaCtx, eval, currentLocation, queue,
- queue.begin(), name);
+ queue.begin(), critName);
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPUtilityConstruct &) {
- TODO(converter.getCurrentLocation(), "OpenMPUtilityConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPUtilityConstruct");
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPDispatchConstruct &) {
- TODO(converter.getCurrentLocation(), "OpenMPDispatchConstruct");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPDispatchConstruct");
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPExecutableAllocate &execAllocConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
+ if (!semaCtx.langOptions().OpenMPSimd)
+ TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
}
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
@@ -3924,9 +3981,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
List<Clause> clauses = makeClauses(
std::get<parser::OmpClauseList>(beginSectionsDirective.t), semaCtx);
const auto &endSectionsDirective =
- std::get<parser::OmpEndSectionsDirective>(sectionsConstruct.t);
+ std::get<std::optional<parser::OmpEndSectionsDirective>>(
+ sectionsConstruct.t);
+ assert(endSectionsDirective &&
+ "Missing end section directive should have been handled in semantics");
clauses.append(makeClauses(
- std::get<parser::OmpClauseList>(endSectionsDirective.t), semaCtx));
+ std::get<parser::OmpClauseList>(endSectionsDirective->t), semaCtx));
mlir::Location currentLocation = converter.getCurrentLocation();
llvm::omp::Directive directive =
@@ -4090,7 +4150,7 @@ void Fortran::lower::genDeclareTargetIntGlobal(
bool Fortran::lower::isOpenMPTargetConstruct(
const parser::OpenMPConstruct &omp) {
llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
- if (const auto *block = std::get_if<parser::OpenMPBlockConstruct>(&omp.u)) {
+ if (const auto *block = std::get_if<parser::OmpBlockConstruct>(&omp.u)) {
dir = block->BeginDir().DirId();
} else if (const auto *loop =
std::get_if<parser::OpenMPLoopConstruct>(&omp.u)) {
@@ -4164,7 +4224,7 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
deviceCodeFound = true;
markDeclareTarget(op, converter, declTar.declareTargetCaptureClause,
- devType);
+ devType, declTar.automap);
}
return deviceCodeFound;
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 13fda97..cb6dd57 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -24,6 +24,7 @@
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
+#include <flang/Utils/OpenMP.h>
#include <llvm/Support/CommandLine.h>
#include <iterator>
@@ -102,41 +103,10 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval) {
void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+ llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause,
+ bool automap) {
for (const Object &object : objects)
- symbolAndClause.emplace_back(clause, *object.sym());
-}
-
-mlir::omp::MapInfoOp
-createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Value baseAddr, mlir::Value varPtrPtr,
- llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds,
- llvm::ArrayRef<mlir::Value> members,
- mlir::ArrayAttr membersIndex, uint64_t mapType,
- mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
- bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
- if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
- baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr);
- retTy = baseAddr.getType();
- }
-
- mlir::TypeAttr varType = mlir::TypeAttr::get(
- llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
-
- // For types with unknown extents such as <2x?xi32> we discard the incomplete
- // type info and only retain the base type. The correct dimensions are later
- // recovered through the bounds info.
- if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue()))
- if (seqType.hasDynamicExtents())
- varType = mlir::TypeAttr::get(seqType.getEleTy());
-
- mlir::omp::MapInfoOp op = mlir::omp::MapInfoOp::create(
- builder, loc, retTy, baseAddr, varType,
- builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
- builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
- varPtrPtr, members, membersIndex, bounds, mapperId,
- builder.getStringAttr(name), builder.getBoolAttr(partialMap));
- return op;
+ symbolAndClause.emplace_back(clause, *object.sym(), automap);
}
// This function gathers the individual omp::Object's that make up a
@@ -402,7 +372,7 @@ mlir::Value createParentSymAndGenIntermediateMaps(
// Create a map for the intermediate member and insert it and it's
// indices into the parentMemberIndices list to track it.
- mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+ mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp(
firOpBuilder, clauseLocation, curValue,
/*varPtrPtr=*/mlir::Value{}, asFortran,
/*bounds=*/interimBounds,
@@ -562,7 +532,7 @@ void insertChildMapInfoIntoParent(
converter.getCurrentLocation(), asFortran, bounds,
treatIndexAsSection);
- mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+ mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp(
firOpBuilder, info.rawInput.getLoc(), info.rawInput,
/*varPtrPtr=*/mlir::Value(), asFortran.str(), bounds, members,
firOpBuilder.create2DI64ArrayAttr(
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 11641ba..88371ab 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -42,8 +42,15 @@ class AbstractConverter;
namespace omp {
-using DeclareTargetCapturePair =
- std::pair<mlir::omp::DeclareTargetCaptureClause, const semantics::Symbol &>;
+struct DeclareTargetCaptureInfo {
+ mlir::omp::DeclareTargetCaptureClause clause;
+ bool automap = false;
+ const semantics::Symbol &symbol;
+
+ DeclareTargetCaptureInfo(mlir::omp::DeclareTargetCaptureClause c,
+ const semantics::Symbol &s, bool a = false)
+ : clause(c), automap(a), symbol(s) {}
+};
// A small helper structure for keeping track of a component members MapInfoOp
// and index data when lowering OpenMP map clauses. Keeps track of the
@@ -107,16 +114,6 @@ struct OmpMapParentAndMemberData {
semantics::SemanticsContext &semaCtx);
};
-mlir::omp::MapInfoOp
-createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Value baseAddr, mlir::Value varPtrPtr,
- llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds,
- llvm::ArrayRef<mlir::Value> members,
- mlir::ArrayAttr membersIndex, uint64_t mapType,
- mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
- bool partialMap = false,
- mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr());
-
void insertChildMapInfoIntoParent(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
@@ -150,7 +147,8 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval);
void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
+ llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause,
+ bool automap = false);
int64_t getCollapseValue(const List<Clause> &clauses);
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index a28cc01..80f31c2 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -1742,11 +1742,11 @@ private:
layeredVarList[i].end());
}
- llvm::SmallSet<const semantics::Symbol *, 32> seen;
+ llvm::SmallPtrSet<const semantics::Symbol *, 32> seen;
std::vector<Fortran::lower::pft::VariableList> layeredVarList;
- llvm::SmallSet<const semantics::Symbol *, 32> aliasSyms;
+ llvm::SmallPtrSet<const semantics::Symbol *, 32> aliasSyms;
/// Set of scopes that have been analyzed for aliases.
- llvm::SmallSet<const semantics::Scope *, 4> analyzedScopes;
+ llvm::SmallPtrSet<const semantics::Scope *, 4> analyzedScopes;
std::vector<Fortran::lower::pft::Variable::AggregateStore> stores;
};
} // namespace
diff --git a/flang/lib/Lower/Runtime.cpp b/flang/lib/Lower/Runtime.cpp
index fc59a24..494dd49 100644
--- a/flang/lib/Lower/Runtime.cpp
+++ b/flang/lib/Lower/Runtime.cpp
@@ -39,8 +39,7 @@ static void genUnreachable(fir::FirOpBuilder &builder, mlir::Location loc) {
if (parentOp->getDialect()->getNamespace() ==
mlir::omp::OpenMPDialect::getDialectNamespace())
Fortran::lower::genOpenMPTerminator(builder, parentOp, loc);
- else if (parentOp->getDialect()->getNamespace() ==
- mlir::acc::OpenACCDialect::getDialectNamespace())
+ else if (Fortran::lower::isInsideOpenACCComputeConstruct(builder))
Fortran::lower::genOpenACCTerminator(builder, parentOp, loc);
else
fir::UnreachableOp::create(builder, loc);
diff --git a/flang/lib/Lower/Support/PrivateReductionUtils.cpp b/flang/lib/Lower/Support/PrivateReductionUtils.cpp
index fff060b..1b09801 100644
--- a/flang/lib/Lower/Support/PrivateReductionUtils.cpp
+++ b/flang/lib/Lower/Support/PrivateReductionUtils.cpp
@@ -616,6 +616,8 @@ void PopulateInitAndCleanupRegionsHelper::populateByRefInitAndCleanupRegions() {
assert(sym && "Symbol information is required to privatize derived types");
assert(!scalarInitValue && "ScalarInitvalue is unused for privatization");
}
+ if (hlfir::Entity{moldArg}.isAssumedRank())
+ TODO(loc, "Privatization of assumed rank variable");
mlir::Type valTy = fir::unwrapRefType(argType);
if (fir::isa_trivial(valTy)) {
diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp
index 881401e..1b4d37e 100644
--- a/flang/lib/Lower/Support/Utils.cpp
+++ b/flang/lib/Lower/Support/Utils.cpp
@@ -654,8 +654,9 @@ void privatizeSymbol(
lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder,
lower::SymMap &symTable,
llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols,
- llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym,
- const semantics::Symbol *symToPrivatize, OperandsStructType *clauseOps) {
+ llvm::SmallPtrSet<const semantics::Symbol *, 16> &mightHaveReadHostSym,
+ const semantics::Symbol *symToPrivatize, OperandsStructType *clauseOps,
+ std::optional<llvm::omp::Directive> dir) {
constexpr bool isDoConcurrent =
std::is_same_v<OpType, fir::LocalitySpecifierOp>;
mlir::OpBuilder::InsertPoint dcIP;
@@ -676,6 +677,13 @@ void privatizeSymbol(
bool emitCopyRegion =
symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
symToPrivatize->test(semantics::Symbol::Flag::LocalityLocalInit);
+ // A symbol attached to the simd directive can have the firstprivate flag set
+ // on it when it is also used in a non-firstprivate privatization clause.
+ // For instance: $omp do simd lastprivate(a) firstprivate(a)
+ // We cannot apply the firstprivate privatizer to simd, so make sure we do
+ // not emit the copy region when dealing with the SIMD directive.
+ if (dir && dir == llvm::omp::Directive::OMPD_simd)
+ emitCopyRegion = false;
mlir::Value privVal = hsb.getAddr();
mlir::Type allocType = privVal.getType();
@@ -846,17 +854,19 @@ privatizeSymbol<mlir::omp::PrivateClauseOp, mlir::omp::PrivateClauseOps>(
lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder,
lower::SymMap &symTable,
llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols,
- llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym,
+ llvm::SmallPtrSet<const semantics::Symbol *, 16> &mightHaveReadHostSym,
const semantics::Symbol *symToPrivatize,
- mlir::omp::PrivateClauseOps *clauseOps);
+ mlir::omp::PrivateClauseOps *clauseOps,
+ std::optional<llvm::omp::Directive> dir);
template void
privatizeSymbol<fir::LocalitySpecifierOp, fir::LocalitySpecifierOperands>(
lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder,
lower::SymMap &symTable,
llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols,
- llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym,
+ llvm::SmallPtrSet<const semantics::Symbol *, 16> &mightHaveReadHostSym,
const semantics::Symbol *symToPrivatize,
- fir::LocalitySpecifierOperands *clauseOps);
+ fir::LocalitySpecifierOperands *clauseOps,
+ std::optional<llvm::omp::Directive> dir);
} // end namespace Fortran::lower
diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt
index 31ae395..404afd1 100644
--- a/flang/lib/Optimizer/Builder/CMakeLists.txt
+++ b/flang/lib/Optimizer/Builder/CMakeLists.txt
@@ -16,6 +16,7 @@ add_flang_library(FIRBuilder
Runtime/Allocatable.cpp
Runtime/ArrayConstructor.cpp
Runtime/Assign.cpp
+ Runtime/Coarray.cpp
Runtime/Character.cpp
Runtime/Command.cpp
Runtime/CUDA/Descriptor.cpp
@@ -49,6 +50,7 @@ add_flang_library(FIRBuilder
FIRDialectSupport
FIRSupport
FortranEvaluate
+ FortranSupport
HLFIRDialect
MLIR_DEPS
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 87a52ff..b6501fd 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -147,8 +147,20 @@ mlir::Value fir::FirOpBuilder::createIntegerConstant(mlir::Location loc,
assert((cst >= 0 || mlir::isa<mlir::IndexType>(ty) ||
mlir::cast<mlir::IntegerType>(ty).getWidth() <= 64) &&
"must use APint");
- return mlir::arith::ConstantOp::create(*this, loc, ty,
- getIntegerAttr(ty, cst));
+
+ mlir::Type cstType = ty;
+ if (auto intType = mlir::dyn_cast<mlir::IntegerType>(ty)) {
+ // Signed and unsigned constants must be encoded as signless
+ // arith.constant followed by fir.convert cast.
+ if (intType.isUnsigned())
+ cstType = mlir::IntegerType::get(getContext(), intType.getWidth());
+ else if (intType.isSigned())
+ TODO(loc, "signed integer constant");
+ }
+
+ mlir::Value cstValue = mlir::arith::ConstantOp::create(
+ *this, loc, cstType, getIntegerAttr(cstType, cst));
+ return createConvert(loc, ty, cstValue);
}
mlir::Value fir::FirOpBuilder::createAllOnesInteger(mlir::Location loc,
@@ -411,10 +423,11 @@ mlir::Value fir::FirOpBuilder::genTempDeclareOp(
llvm::ArrayRef<mlir::Value> typeParams,
fir::FortranVariableFlagsAttr fortranAttrs) {
auto nameAttr = mlir::StringAttr::get(builder.getContext(), name);
- return fir::DeclareOp::create(builder, loc, memref.getType(), memref, shape,
- typeParams,
- /*dummy_scope=*/nullptr, nameAttr, fortranAttrs,
- cuf::DataAttributeAttr{});
+ return fir::DeclareOp::create(
+ builder, loc, memref.getType(), memref, shape, typeParams,
+ /*dummy_scope=*/nullptr,
+ /*storage=*/nullptr,
+ /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{});
}
mlir::Value fir::FirOpBuilder::genStackSave(mlir::Location loc) {
@@ -1947,17 +1960,17 @@ void fir::factory::genDimInfoFromBox(
mlir::Value fir::factory::genLifetimeStart(mlir::OpBuilder &builder,
mlir::Location loc,
- fir::AllocaOp alloc, int64_t size,
+ fir::AllocaOp alloc,
const mlir::DataLayout *dl) {
mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(
alloc.getContext(), getAllocaAddressSpace(dl));
mlir::Value cast =
fir::ConvertOp::create(builder, loc, ptrTy, alloc.getResult());
- mlir::LLVM::LifetimeStartOp::create(builder, loc, size, cast);
+ mlir::LLVM::LifetimeStartOp::create(builder, loc, cast);
return cast;
}
void fir::factory::genLifetimeEnd(mlir::OpBuilder &builder, mlir::Location loc,
- mlir::Value cast, int64_t size) {
- mlir::LLVM::LifetimeEndOp::create(builder, loc, size, cast);
+ mlir::Value cast) {
+ mlir::LLVM::LifetimeEndOp::create(builder, loc, cast);
}
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index b6d692a..086dd66 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -416,7 +416,10 @@ hlfir::Entity hlfir::loadTrivialScalar(mlir::Location loc,
entity = derefPointersAndAllocatables(loc, builder, entity);
if (entity.isVariable() && entity.isScalar() &&
fir::isa_trivial(entity.getFortranElementType())) {
- return Entity{fir::LoadOp::create(builder, loc, entity)};
+ // Optional entities may be represented with !fir.box<i32/f32/...>.
+ // We need to take the data pointer before loading the scalar.
+ mlir::Value base = genVariableRawAddress(loc, builder, entity);
+ return Entity{fir::LoadOp::create(builder, loc, base)};
}
return entity;
}
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index bfa470d..e1c9520 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -25,6 +25,7 @@
#include "flang/Optimizer/Builder/Runtime/Allocatable.h"
#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
#include "flang/Optimizer/Builder/Runtime/Character.h"
+#include "flang/Optimizer/Builder/Runtime/Coarray.h"
#include "flang/Optimizer/Builder/Runtime/Command.h"
#include "flang/Optimizer/Builder/Runtime/Derived.h"
#include "flang/Optimizer/Builder/Runtime/Exceptions.h"
@@ -137,7 +138,7 @@ static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
/// Table that drives the fir generation depending on the intrinsic or intrinsic
/// module procedure one to one mapping with Fortran arguments. If no mapping is
/// defined here for a generic intrinsic, genRuntimeCall will be called
-/// to look for a match in the runtime a emit a call. Note that the argument
+/// to look for a match in the runtime and emit a call. Note that the argument
/// lowering rules for an intrinsic need to be provided only if at least one
/// argument must not be lowered by value. In which case, the lowering rules
/// should be provided for all the intrinsic arguments for completeness.
@@ -778,6 +779,10 @@ static constexpr IntrinsicHandler handlers[]{
/*isElemental=*/false},
{"not", &I::genNot},
{"null", &I::genNull, {{{"mold", asInquired}}}, /*isElemental=*/false},
+ {"num_images",
+ &I::genNumImages,
+ {{{"team", asAddr}, {"team_number", asAddr}}},
+ /*isElemental*/ false},
{"pack",
&I::genPack,
{{{"array", asBox},
@@ -864,6 +869,10 @@ static constexpr IntrinsicHandler handlers[]{
{"back", asValue, handleDynamicOptional},
{"kind", asValue}}},
/*isElemental=*/true},
+ {"secnds",
+ &I::genSecnds,
+ {{{"refTime", asAddr}}},
+ /*isElemental=*/false},
{"second",
&I::genSecond,
{{{"time", asAddr}}},
@@ -947,6 +956,12 @@ static constexpr IntrinsicHandler handlers[]{
{"tand", &I::genTand},
{"tanpi", &I::genTanpi},
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
+ {"this_image",
+ &I::genThisImage,
+ {{{"coarray", asBox},
+ {"dim", asAddr},
+ {"team", asBox, handleDynamicOptional}}},
+ /*isElemental=*/false},
{"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false},
{"this_warp", &I::genThisWarp, {}, /*isElemental=*/false},
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
@@ -1047,7 +1062,7 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::StringRef suffix, mlir::FunctionType funcType) {
std::string output = prefix.str();
llvm::raw_string_ostream sstream(output);
- if (name == "pow") {
+ if (name == "pow" || name == "pow-unsigned") {
assert(funcType.getNumInputs() == 2 && "power operator has two arguments");
std::string displayName{" ** "};
sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
@@ -1276,6 +1291,26 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
+mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
+ const MathOperation &mathOp,
+ mlir::FunctionType mathLibFuncType,
+ llvm::ArrayRef<mlir::Value> args) {
+ bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
+ if (!isAMDGPU)
+ return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
+ // AMDGCN: build (exponent, 0) as a complex value and emit complex.pow.
+ auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ mlir::Value complexExp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result =
+ builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+ return result;
+}
+
/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1625,17 +1660,29 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
+ genComplexPow},
{"pow", RTNAME_STRING(zpowi),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
+ genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
{"pow", RTNAME_STRING(cpowk),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
+ genComplexPow},
{"pow", RTNAME_STRING(zpowk),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
+ genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
+ {"pow-unsigned", RTNAME_STRING(UPow1),
+ genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
+ {"pow-unsigned", RTNAME_STRING(UPow2),
+ genFuncType<Ty::Integer<2>, Ty::Integer<2>, Ty::Integer<2>>, genLibCall},
+ {"pow-unsigned", RTNAME_STRING(UPow4),
+ genFuncType<Ty::Integer<4>, Ty::Integer<4>, Ty::Integer<4>>, genLibCall},
+ {"pow-unsigned", RTNAME_STRING(UPow8),
+ genFuncType<Ty::Integer<8>, Ty::Integer<8>, Ty::Integer<8>>, genLibCall},
{"remainder", "remainderf",
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Real<4>>, genLibCall},
{"remainder", "remainder",
@@ -2672,10 +2719,11 @@ mlir::Value IntrinsicLibrary::genAcosd(mlir::Type resultType,
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
mlir::Value result =
getRuntimeCallGenerator("acos", ftype)(builder, loc, {args[0]});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor = builder.createRealConstant(
- loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi);
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ const llvm::fltSemantics &fltSem =
+ llvm::cast<mlir::FloatType>(resultType).getFloatSemantics();
+ llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(
+ loc, resultType, llvm::APFloat(fltSem, "180.0") / pi);
return mlir::arith::MulFOp::create(builder, loc, result, factor);
}
@@ -2687,10 +2735,10 @@ mlir::Value IntrinsicLibrary::genAcospi(mlir::Type resultType,
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
mlir::Value acos = getRuntimeCallGenerator("acos", ftype)(builder, loc, args);
- llvm::APFloat inv_pi = llvm::APFloat(llvm::numbers::inv_pi);
- mlir::Value dfactor =
- builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi);
- mlir::Value factor = builder.createConvert(loc, resultType, dfactor);
+ llvm::APFloat inv_pi =
+ llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(),
+ llvm::numbers::inv_pis);
+ mlir::Value factor = builder.createRealConstant(loc, resultType, inv_pi);
return mlir::arith::MulFOp::create(builder, loc, acos, factor);
}
@@ -2840,10 +2888,11 @@ mlir::Value IntrinsicLibrary::genAsind(mlir::Type resultType,
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
mlir::Value result =
getRuntimeCallGenerator("asin", ftype)(builder, loc, {args[0]});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor = builder.createRealConstant(
- loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi);
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ const llvm::fltSemantics &fltSem =
+ llvm::cast<mlir::FloatType>(resultType).getFloatSemantics();
+ llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(
+ loc, resultType, llvm::APFloat(fltSem, "180.0") / pi);
return mlir::arith::MulFOp::create(builder, loc, result, factor);
}
@@ -2855,10 +2904,10 @@ mlir::Value IntrinsicLibrary::genAsinpi(mlir::Type resultType,
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
mlir::Value asin = getRuntimeCallGenerator("asin", ftype)(builder, loc, args);
- llvm::APFloat inv_pi = llvm::APFloat(llvm::numbers::inv_pi);
- mlir::Value dfactor =
- builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi);
- mlir::Value factor = builder.createConvert(loc, resultType, dfactor);
+ llvm::APFloat inv_pi =
+ llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(),
+ llvm::numbers::inv_pis);
+ mlir::Value factor = builder.createRealConstant(loc, resultType, inv_pi);
return mlir::arith::MulFOp::create(builder, loc, asin, factor);
}
@@ -2880,10 +2929,11 @@ mlir::Value IntrinsicLibrary::genAtand(mlir::Type resultType,
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
atan = getRuntimeCallGenerator("atan", ftype)(builder, loc, args);
}
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor = builder.createRealConstant(
- loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi);
- mlir::Value factor = builder.createConvert(loc, resultType, dfactor);
+ const llvm::fltSemantics &fltSem =
+ llvm::cast<mlir::FloatType>(resultType).getFloatSemantics();
+ llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(
+ loc, resultType, llvm::APFloat(fltSem, "180.0") / pi);
return mlir::arith::MulFOp::create(builder, loc, atan, factor);
}
@@ -2905,10 +2955,10 @@ mlir::Value IntrinsicLibrary::genAtanpi(mlir::Type resultType,
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
atan = getRuntimeCallGenerator("atan", ftype)(builder, loc, args);
}
- llvm::APFloat inv_pi = llvm::APFloat(llvm::numbers::inv_pi);
- mlir::Value dfactor =
- builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi);
- mlir::Value factor = builder.createConvert(loc, resultType, dfactor);
+ llvm::APFloat inv_pi =
+ llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(),
+ llvm::numbers::inv_pis);
+ mlir::Value factor = builder.createRealConstant(loc, resultType, inv_pi);
return mlir::arith::MulFOp::create(builder, loc, atan, factor);
}
@@ -3669,10 +3719,11 @@ mlir::Value IntrinsicLibrary::genCosd(mlir::Type resultType,
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor = builder.createRealConstant(
- loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ const llvm::fltSemantics &fltSem =
+ llvm::cast<mlir::FloatType>(resultType).getFloatSemantics();
+ llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(
+ loc, resultType, pi / llvm::APFloat(fltSem, "180.0"));
mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor);
return getRuntimeCallGenerator("cos", ftype)(builder, loc, {arg});
}
@@ -3684,10 +3735,10 @@ mlir::Value IntrinsicLibrary::genCospi(mlir::Type resultType,
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor =
- builder.createRealConstant(loc, mlir::Float64Type::get(context), pi);
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ llvm::APFloat pi =
+ llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(),
+ llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(loc, resultType, pi);
mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor);
return getRuntimeCallGenerator("cos", ftype)(builder, loc, {arg});
}
@@ -4031,21 +4082,20 @@ void IntrinsicLibrary::genExecuteCommandLine(
mlir::Value waitAddr = fir::getBase(wait);
mlir::Value waitIsPresentAtRuntime =
builder.genIsNotNullAddr(loc, waitAddr);
- waitBool = builder
- .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
- /*withElseRegion=*/true)
- .genThen([&]() {
- auto waitLoad =
- fir::LoadOp::create(builder, loc, waitAddr);
- mlir::Value cast =
- builder.createConvert(loc, i1Ty, waitLoad);
- fir::ResultOp::create(builder, loc, cast);
- })
- .genElse([&]() {
- mlir::Value trueVal = builder.createBool(loc, true);
- fir::ResultOp::create(builder, loc, trueVal);
- })
- .getResults()[0];
+ waitBool =
+ builder
+ .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
+ mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
+ fir::ResultOp::create(builder, loc, cast);
+ })
+ .genElse([&]() {
+ mlir::Value trueVal = builder.createBool(loc, true);
+ fir::ResultOp::create(builder, loc, trueVal);
+ })
+ .getResults()[0];
}
mlir::Value exitstatBox =
@@ -7277,6 +7327,19 @@ IntrinsicLibrary::genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue> args) {
return fir::MutableBoxValue(boxStorage, mold->nonDeferredLenParams(), {});
}
+// NUM_IMAGES
+fir::ExtendedValue
+IntrinsicLibrary::genNumImages(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ checkCoarrayEnabled();
+ assert(args.size() == 0 || args.size() == 1);
+ // With one argument use the team-aware runtime query, else the plain one.
+ if (args.size())
+ return fir::runtime::getNumImagesWithTeam(builder, loc,
+ fir::getBase(args[0]));
+ return fir::runtime::getNumImages(builder, loc);
+}
+
// CLOCK, CLOCK64, GLOBALTIMER
template <typename OpTy>
mlir::Value IntrinsicLibrary::genNVVMTime(mlir::Type resultType,
@@ -7813,6 +7876,22 @@ IntrinsicLibrary::genScan(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "SCAN");
}
+// SECNDS
+fir::ExtendedValue
+IntrinsicLibrary::genSecnds(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1 && "SECNDS expects one argument");
+ // The handler table lowers the reference-time argument by address (asAddr).
+ mlir::Value refTime = fir::getBase(args[0]);
+
+ if (!refTime)
+ fir::emitFatalError(loc, "expected REFERENCE TIME parameter");
+ // Delegate the computation to the runtime.
+ mlir::Value result = fir::runtime::genSecnds(builder, loc, refTime);
+ // Convert the runtime result to the requested result type.
+ return builder.createConvert(loc, resultType, result);
+}
+
// SECOND
fir::ExtendedValue
IntrinsicLibrary::genSecond(std::optional<mlir::Type> resultType,
@@ -8121,10 +8200,11 @@ mlir::Value IntrinsicLibrary::genSind(mlir::Type resultType,
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor = builder.createRealConstant(
- loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ const llvm::fltSemantics &fltSem =
+ llvm::cast<mlir::FloatType>(resultType).getFloatSemantics();
+ llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(
+ loc, resultType, pi / llvm::APFloat(fltSem, "180.0"));
mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor);
return getRuntimeCallGenerator("sin", ftype)(builder, loc, {arg});
}
@@ -8136,10 +8216,10 @@ mlir::Value IntrinsicLibrary::genSinpi(mlir::Type resultType,
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor =
- builder.createRealConstant(loc, mlir::Float64Type::get(context), pi);
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ llvm::APFloat pi =
+ llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(),
+ llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(loc, resultType, pi);
mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor);
return getRuntimeCallGenerator("sin", ftype)(builder, loc, {arg});
}
@@ -8218,10 +8298,11 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor = builder.createRealConstant(
- loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0));
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ const llvm::fltSemantics &fltSem =
+ llvm::cast<mlir::FloatType>(resultType).getFloatSemantics();
+ llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(
+ loc, resultType, pi / llvm::APFloat(fltSem, "180.0"));
mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor);
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
}
@@ -8233,10 +8314,10 @@ mlir::Value IntrinsicLibrary::genTanpi(mlir::Type resultType,
mlir::MLIRContext *context = builder.getContext();
mlir::FunctionType ftype =
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi);
- mlir::Value dfactor =
- builder.createRealConstant(loc, mlir::Float64Type::get(context), pi);
- mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor);
+ llvm::APFloat pi =
+ llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(),
+ llvm::numbers::pis);
+ mlir::Value factor = builder.createRealConstant(loc, resultType, pi);
mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor);
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
}
@@ -8327,6 +8408,27 @@ mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
return res;
}
+// THIS_IMAGE
+fir::ExtendedValue
+IntrinsicLibrary::genThisImage(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ checkCoarrayEnabled();
+ assert(args.size() >= 1 && args.size() <= 3);
+ const bool coarrayIsAbsent = args.size() == 1;
+ mlir::Value team =
+ !isStaticallyAbsent(args, args.size() - 1)
+ ? fir::getBase(args[args.size() - 1])
+ : builder
+ .create<fir::AbsentOp>(loc,
+ fir::BoxType::get(builder.getNoneType()))
+ .getResult();
+ // Only the form without a COARRAY argument is lowered so far.
+ if (!coarrayIsAbsent)
+ TODO(loc, "this_image with coarray argument.");
+ mlir::Value res = fir::runtime::getThisImage(builder, loc, team);
+ return builder.createConvert(loc, resultType, res);
+}
+
// THIS_THREAD_BLOCK
mlir::Value
IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType,
@@ -9347,6 +9449,14 @@ mlir::Value genPow(fir::FirOpBuilder &builder, mlir::Location loc,
// implementation and mark it 'strictfp'.
// Another option is to implement it in Fortran runtime library
// (just like matmul).
+ if (type.isUnsignedInteger()) {
+ assert(x.getType().isUnsignedInteger() && y.getType().isUnsignedInteger() &&
+ "unsigned pow requires unsigned arguments");
+ return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow-unsigned", type,
+ {x, y});
+ }
+ assert(!x.getType().isUnsignedInteger() && !y.getType().isUnsignedInteger() &&
+ "non-unsigned pow requires non-unsigned arguments");
return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow", type, {x, y});
}
diff --git a/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp
new file mode 100644
index 0000000..fb72fc2
--- /dev/null
+++ b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp
@@ -0,0 +1,86 @@
+//===-- Coarray.cpp -- runtime API for coarray intrinsics -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Runtime/Coarray.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+using namespace Fortran::runtime;
+using namespace Fortran::semantics;
+
+/// Generate a call to runtime prif_init; returns the i32 loaded from its result.
+mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder,
+ mlir::Location loc) {
+ mlir::Type i32Ty = builder.getI32Type();
+ mlir::Value result = builder.createTemporary(loc, i32Ty);
+ mlir::FunctionType ftype = PRIF_FUNCTYPE(builder.getRefType(i32Ty));
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, PRIFNAME_SUB("init"), ftype);
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, ftype, result);
+ builder.create<fir::CallOp>(loc, funcOp, args);
+ return builder.create<fir::LoadOp>(loc, result);
+}
+
+/// Generate a call to runtime prif_num_images; returns the loaded i32 result.
+mlir::Value fir::runtime::getNumImages(fir::FirOpBuilder &builder,
+ mlir::Location loc) {
+ mlir::Value result = builder.createTemporary(loc, builder.getI32Type());
+ mlir::FunctionType ftype =
+ PRIF_FUNCTYPE(builder.getRefType(builder.getI32Type()));
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, PRIFNAME_SUB("num_images"), ftype);
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, ftype, result);
+ builder.create<fir::CallOp>(loc, funcOp, args);
+ return builder.create<fir::LoadOp>(loc, result);
+}
+
+/// Generate a call to runtime prif_num_images_with_{team|team_number}.
+mlir::Value fir::runtime::getNumImagesWithTeam(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value team) {
+ bool isTeamNumber = fir::unwrapPassByRefType(team.getType()).isInteger();
+ std::string numImagesName = isTeamNumber
+ ? PRIFNAME_SUB("num_images_with_team_number")
+ : PRIFNAME_SUB("num_images_with_team");
+ // The i32 result comes back through a temporary passed by reference.
+ mlir::Value result = builder.createTemporary(loc, builder.getI32Type());
+ mlir::Type refTy = builder.getRefType(builder.getI32Type());
+ mlir::FunctionType ftype =
+ isTeamNumber
+ ? PRIF_FUNCTYPE(builder.getRefType(builder.getI64Type()), refTy)
+ : PRIF_FUNCTYPE(fir::BoxType::get(builder.getNoneType()), refTy);
+ mlir::func::FuncOp funcOp = builder.createFunction(loc, numImagesName, ftype);
+ // Only a TEAM argument is boxed before the call; a team number is not.
+ if (!isTeamNumber)
+ team = builder.createBox(loc, team);
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, ftype, team, result);
+ builder.create<fir::CallOp>(loc, funcOp, args);
+ return builder.create<fir::LoadOp>(loc, result);
+}
+
+/// Generate a call to runtime prif_this_image_no_coarray; returns the i32 result.
+mlir::Value fir::runtime::getThisImage(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value team) {
+ mlir::Type refTy = builder.getRefType(builder.getI32Type());
+ mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
+ mlir::FunctionType ftype = PRIF_FUNCTYPE(boxTy, refTy);
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, PRIFNAME_SUB("this_image_no_coarray"), ftype);
+ // A null team value is replaced by a fir.absent box of the expected type.
+ mlir::Value result = builder.createTemporary(loc, builder.getI32Type());
+ mlir::Value teamArg =
+ !team ? builder.create<fir::AbsentOp>(loc, boxTy) : team;
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, ftype, teamArg, result);
+ builder.create<fir::CallOp>(loc, funcOp, args);
+ return builder.create<fir::LoadOp>(loc, result);
+}
diff --git a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp
index ee15157..dc61903 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp
@@ -276,6 +276,23 @@ void fir::runtime::genRename(fir::FirOpBuilder &builder, mlir::Location loc,
fir::CallOp::create(builder, loc, runtimeFunc, args);
}
+mlir::Value fir::runtime::genSecnds(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value refTime) {
+ auto runtimeFunc =
+ fir::runtime::getRuntimeFunc<mkRTKey(Secnds)>(loc, builder);
+
+ mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType();
+ // Besides refTime, the call is given the source file name and line number.
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, runtimeFuncTy.getInput(2));
+
+ llvm::SmallVector<mlir::Value> args = {refTime, sourceFile, sourceLine};
+ args = fir::runtime::createArguments(builder, loc, runtimeFuncTy, args);
+ // Emit the call and return its first result.
+ return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0);
+}
+
/// generate runtime call to time intrinsic
mlir::Value fir::runtime::genTime(fir::FirOpBuilder &builder,
mlir::Location loc) {
diff --git a/flang/lib/Optimizer/Builder/Runtime/Main.cpp b/flang/lib/Optimizer/Builder/Runtime/Main.cpp
index d35f687..d303e0a 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Main.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Main.cpp
@@ -10,6 +10,7 @@
#include "flang/Lower/EnvironmentDefault.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/Runtime/Coarray.h"
#include "flang/Optimizer/Builder/Runtime/EnvironmentDefaults.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Dialect/FIROps.h"
@@ -23,8 +24,8 @@ using namespace Fortran::runtime;
/// Create a `int main(...)` that calls the Fortran entry point
void fir::runtime::genMain(
fir::FirOpBuilder &builder, mlir::Location loc,
- const std::vector<Fortran::lower::EnvironmentDefault> &defs,
- bool initCuda) {
+ const std::vector<Fortran::lower::EnvironmentDefault> &defs, bool initCuda,
+ bool initCoarrayEnv) {
auto *context = builder.getContext();
auto argcTy = builder.getDefaultIntegerType();
auto ptrTy = mlir::LLVM::LLVMPointerType::get(context);
@@ -69,6 +70,8 @@ void fir::runtime::genMain(
loc, RTNAME_STRING(CUFInit), mlir::FunctionType::get(context, {}, {}));
fir::CallOp::create(builder, loc, initFn);
}
+ if (initCoarrayEnv)
+ fir::runtime::genInitCoarray(builder, loc);
fir::CallOp::create(builder, loc, qqMainFn);
fir::CallOp::create(builder, loc, stopFn);
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1b289ae..76f3cbd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -87,14 +87,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
return mlir::IntegerType::get(context, 8);
}
-static mlir::LLVM::ConstantOp
-genConstantIndex(mlir::Location loc, mlir::Type ity,
- mlir::ConversionPatternRewriter &rewriter,
- std::int64_t offset) {
- auto cattr = rewriter.getI64IntegerAttr(offset);
- return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr);
-}
-
static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
mlir::Block *insertBefore) {
assert(insertBefore && "expected valid insertion block");
@@ -208,39 +200,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op,
TODO(op.getLoc(), "did not find allocation function");
}
-// Compute the alloc scale size (constant factors encoded in the array type).
-// We do this for arrays without a constant interior or arrays of character with
-// dynamic length arrays, since those are the only ones that get decayed to a
-// pointer to the element type.
-template <typename OP>
-static mlir::Value
-genAllocationScaleSize(OP op, mlir::Type ity,
- mlir::ConversionPatternRewriter &rewriter) {
- mlir::Location loc = op.getLoc();
- mlir::Type dataTy = op.getInType();
- auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
- fir::SequenceType::Extent constSize = 1;
- if (seqTy) {
- int constRows = seqTy.getConstantRows();
- const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
- if (constRows != static_cast<int>(shape.size())) {
- for (auto extent : shape) {
- if (constRows-- > 0)
- continue;
- if (extent != fir::SequenceType::getUnknownExtent())
- constSize *= extent;
- }
- }
- }
-
- if (constSize != 1) {
- mlir::Value constVal{
- genConstantIndex(loc, ity, rewriter, constSize).getResult()};
- return constVal;
- }
- return nullptr;
-}
-
namespace {
struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> {
public:
@@ -275,7 +234,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
auto loc = alloc.getLoc();
mlir::Type ity = lowerTy().indexType();
unsigned i = 0;
- mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
+ mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult();
mlir::Type firObjType = fir::unwrapRefType(alloc.getType());
mlir::Type llvmObjectType = convertObjectType(firObjType);
if (alloc.hasLenParams()) {
@@ -307,7 +266,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
<< scalarType << " with type parameters";
}
}
- if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
+ if (auto scaleSize = fir::genAllocationScaleSize(
+ alloc.getLoc(), alloc.getInType(), ity, rewriter))
size =
rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
if (alloc.hasShapeOperands()) {
@@ -484,7 +444,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> {
auto loc = boxisarray.getLoc();
TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType());
mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter);
- mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0);
+ mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0);
return mlir::success();
@@ -820,7 +780,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
// Do folding for constant inputs.
if (auto constVal = fir::getIntIfConstant(op0)) {
mlir::Value normVal =
- genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
+ fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
rewriter.replaceOp(convert, normVal);
return mlir::success();
}
@@ -833,7 +793,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
}
// Compare the input with zero.
- mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
+ mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0);
auto isTrue = mlir::LLVM::ICmpOp::create(
rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
@@ -1082,21 +1042,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
return getMallocInModule(mod, op, rewriter, indexType);
}
-/// Helper function for generating the LLVM IR that computes the distance
-/// in bytes between adjacent elements pointed to by a pointer
-/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
-/// type.
-static mlir::Value
-computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
- mlir::Type idxTy,
- mlir::ConversionPatternRewriter &rewriter,
- const mlir::DataLayout &dataLayout) {
- llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
- unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
- std::int64_t distance = llvm::alignTo(size, alignment);
- return genConstantIndex(loc, idxTy, rewriter, distance);
-}
-
/// Return value of the stride in bytes between adjacent elements
/// of LLVM type \p llTy. The result is returned as a value of
/// \p idxTy integer type.
@@ -1105,7 +1050,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
const mlir::DataLayout &dataLayout) {
// Create a pointer type and use computeElementDistance().
- return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
+ return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
}
namespace {
@@ -1124,8 +1069,9 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
TODO(loc, "fir.allocmem codegen of derived type with length parameters");
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
- if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
- size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
+ if (auto scaleSize =
+ fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter))
+ size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
for (mlir::Value opnd : adaptor.getOperands())
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size,
integerCast(loc, rewriter, ity, opnd));
@@ -1133,8 +1079,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
// As the return value of malloc(0) is implementation defined, allocate one
// byte to ensure the allocation status being true. This behavior aligns to
// what the runtime has.
- mlir::Value zero = genConstantIndex(loc, ity, rewriter, 0);
- mlir::Value one = genConstantIndex(loc, ity, rewriter, 1);
+ mlir::Value zero = fir::genConstantIndex(loc, ity, rewriter, 0);
+ mlir::Value one = fir::genConstantIndex(loc, ity, rewriter, 1);
mlir::Value cmp = mlir::LLVM::ICmpOp::create(
rewriter, loc, mlir::LLVM::ICmpPredicate::sgt, size, zero);
size = mlir::LLVM::SelectOp::create(rewriter, loc, cmp, size, one);
@@ -1157,7 +1103,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type llTy) const {
- return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
+ return fir::computeElementDistance(loc, llTy, idxTy, rewriter,
+ getDataLayout());
}
};
} // namespace
@@ -1344,7 +1291,7 @@ genCUFAllocDescriptor(mlir::Location loc,
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
mlir::Value sizeInBytes =
- genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
+ fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
return mlir::LLVM::CallOp::create(rewriter, loc, fctTy,
RTNAME_STRING(CUFAllocDescriptor), args)
@@ -1599,7 +1546,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
// representation of derived types with pointer/allocatable components.
// This has been seen in hashing algorithms using TRANSFER.
mlir::Value zero =
- genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
+ fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
descriptor = insertField(rewriter, loc, descriptor,
{getLenParamFieldId(boxTy), 0}, zero);
}
@@ -1944,8 +1891,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
bool hasSlice = !xbox.getSlice().empty();
unsigned sliceOffset = xbox.getSliceOperandIndex();
mlir::Location loc = xbox.getLoc();
- mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
- mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
+ mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0);
+ mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1);
mlir::Value prevPtrOff = one;
mlir::Type eleTy = boxTy.getEleTy();
const unsigned rank = xbox.getRank();
@@ -1994,7 +1941,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
prevDimByteStride =
getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
} else {
- prevDimByteStride = genConstantIndex(
+ prevDimByteStride = fir::genConstantIndex(
loc, i64Ty, rewriter,
charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
}
@@ -2152,7 +2099,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) {
if (charTy.hasConstantLen()) {
mlir::Value len =
- genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
+ fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
lenParams.emplace_back(len);
} else {
mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair,
@@ -2161,7 +2108,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
assert(!isInGlobalOp(rewriter) &&
"character target in global op must have constant length");
mlir::Value width =
- genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
+ fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width);
}
lenParams.emplace_back(len);
@@ -2215,8 +2162,9 @@ private:
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
mlir::Value zero =
- genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
- mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
+ fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
+ mlir::Value one =
+ fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
mlir::Value extent = std::get<0>(iter.value());
unsigned dim = iter.index();
@@ -2249,7 +2197,7 @@ private:
mlir::Location loc = rebox.getLoc();
mlir::Type byteTy = ::getI8Type(rebox.getContext());
mlir::Type idxTy = lowerTy().indexType();
- mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
+ mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0);
// Apply subcomponent and substring shift on base address.
if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
// Cast to inputEleTy* so that a GEP can be used.
@@ -2277,7 +2225,7 @@ private:
// and strides.
llvm::SmallVector<mlir::Value> slicedExtents;
llvm::SmallVector<mlir::Value> slicedStrides;
- mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+ mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
const bool sliceHasOrigins = !rebox.getShift().empty();
unsigned sliceOps = rebox.getSliceOperandIndex();
unsigned shiftOps = rebox.getShiftOperandIndex();
@@ -2350,7 +2298,7 @@ private:
// which may be OK if all new extents are ones, the stride does not
// matter, use one.
mlir::Value stride = inputStrides.empty()
- ? genConstantIndex(loc, idxTy, rewriter, 1)
+ ? fir::genConstantIndex(loc, idxTy, rewriter, 1)
: inputStrides[0];
for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i];
@@ -2585,9 +2533,9 @@ struct XArrayCoorOpConversion
unsigned shiftOffset = coor.getShiftOperandIndex();
unsigned sliceOffset = coor.getSliceOperandIndex();
auto sliceOps = coor.getSlice().begin();
- mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+ mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
mlir::Value prevExt = one;
- mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
+ mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0);
const bool isShifted = !coor.getShift().empty();
const bool isSliced = !coor.getSlice().empty();
const bool baseIsBoxed =
@@ -2918,7 +2866,7 @@ private:
// of lower bound aspects. This both accounts for dynamically sized
// types and non contiguous arrays.
auto idxTy = lowerTy().indexType();
- mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
+ mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0);
unsigned arrayDim = arrTy.getDimension();
for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) {
mlir::Value stride =
@@ -3525,114 +3473,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
}
};
-/// Helper function for converting select ops. This function converts the
-/// signature of the given block. If the new block signature is different from
-/// `expectedTypes`, returns "failure".
-static llvm::FailureOr<mlir::Block *>
-getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter,
- mlir::Operation *branchOp, mlir::Block *block,
- mlir::TypeRange expectedTypes) {
- assert(converter && "expected non-null type converter");
- assert(!block->isEntryBlock() && "entry blocks have no predecessors");
-
- // There is nothing to do if the types already match.
- if (block->getArgumentTypes() == expectedTypes)
- return block;
-
- // Compute the new block argument types and convert the block.
- std::optional<mlir::TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion)
- return rewriter.notifyMatchFailure(branchOp,
- "could not compute block signature");
- if (expectedTypes != conversion->getConvertedTypes())
- return rewriter.notifyMatchFailure(
- branchOp,
- "mismatch between adaptor operand types and computed block signature");
- return rewriter.applySignatureConversion(block, *conversion, converter);
-}
-
+/// Base class for SelectOpConversion and SelectRankOpConversion.
template <typename OP>
-static llvm::LogicalResult
-selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
- typename OP::Adaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter,
- const mlir::TypeConverter *converter) {
- unsigned conds = select.getNumConditions();
- auto cases = select.getCases().getValue();
- mlir::Value selector = adaptor.getSelector();
- auto loc = select.getLoc();
- assert(conds > 0 && "select must have cases");
-
- llvm::SmallVector<mlir::Block *> destinations;
- llvm::SmallVector<mlir::ValueRange> destinationsOperands;
- mlir::Block *defaultDestination;
- mlir::ValueRange defaultOperands;
- llvm::SmallVector<int32_t> caseValues;
-
- for (unsigned t = 0; t != conds; ++t) {
- mlir::Block *dest = select.getSuccessor(t);
- auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
- const mlir::Attribute &attr = cases[t];
- if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
- destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
- auto convertedBlock =
- getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(destinationsOperands.back()));
+struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
+ using fir::FIROpConversion<OP>::FIROpConversion;
+
+private:
+ /// Helper function for converting select ops. This function converts the
+ /// signature of the given block. If the new block signature is different from
+ /// `expectedTypes`, returns "failure".
+ llvm::FailureOr<mlir::Block *>
+ getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
+ mlir::Operation *branchOp, mlir::Block *block,
+ mlir::TypeRange expectedTypes) const {
+ const mlir::TypeConverter *converter = this->getTypeConverter();
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<mlir::TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(branchOp,
+ "mismatch between adaptor operand "
+ "types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+ }
+
+protected:
+ llvm::LogicalResult
+ selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ unsigned conds = select.getNumConditions();
+ auto cases = select.getCases().getValue();
+ mlir::Value selector = adaptor.getSelector();
+ auto loc = select.getLoc();
+ assert(conds > 0 && "select must have cases");
+
+ llvm::SmallVector<mlir::Block *> destinations;
+ llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+ mlir::Block *defaultDestination;
+ mlir::ValueRange defaultOperands;
+ // LLVM::SwitchOp selector type and the case values types
+ // must have the same bit width, so cast the selector to i64,
+ // and use i64 for the case values. It is hard to imagine
+ // a computed GO TO with the number of labels in the label-list
+ // bigger than INT_MAX, but let's use i64 to be on the safe side.
+ // Moreover, fir.select operation is more relaxed than
+ // a Fortran computed GO TO, so it may specify such a case value
+ // even if there is just a single label/case.
+ llvm::SmallVector<int64_t> caseValues;
+
+ for (unsigned t = 0; t != conds; ++t) {
+ mlir::Block *dest = select.getSuccessor(t);
+ auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+ const mlir::Attribute &attr = cases[t];
+ if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
+ destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
+ auto convertedBlock =
+ getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(destinationsOperands.back()));
+ if (mlir::failed(convertedBlock))
+ return mlir::failure();
+ destinations.push_back(*convertedBlock);
+ caseValues.push_back(intAttr.getInt());
+ continue;
+ }
+ assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
+ assert((t + 1 == conds) && "unit must be last");
+ defaultOperands = destOps ? *destOps : mlir::ValueRange{};
+ auto convertedBlock = getConvertedBlock(rewriter, select, dest,
+ mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
- destinations.push_back(*convertedBlock);
- caseValues.push_back(intAttr.getInt());
- continue;
+ defaultDestination = *convertedBlock;
}
- assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
- assert((t + 1 == conds) && "unit must be last");
- defaultOperands = destOps ? *destOps : mlir::ValueRange{};
- auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
- mlir::TypeRange(defaultOperands));
- if (mlir::failed(convertedBlock))
- return mlir::failure();
- defaultDestination = *convertedBlock;
- }
-
- // LLVM::SwitchOp takes a i32 type for the selector.
- if (select.getSelector().getType() != rewriter.getI32Type())
- selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
- selector);
-
- rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
- select, selector,
- /*defaultDestination=*/defaultDestination,
- /*defaultOperands=*/defaultOperands,
- /*caseValues=*/caseValues,
- /*caseDestinations=*/destinations,
- /*caseOperands=*/destinationsOperands,
- /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
- return mlir::success();
-}
+ selector =
+ this->integerCast(loc, rewriter, rewriter.getI64Type(), selector);
+
+ rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+ select, selector,
+ /*defaultDestination=*/defaultDestination,
+ /*defaultOperands=*/defaultOperands,
+ /*caseValues=*/rewriter.getI64VectorAttr(caseValues),
+ /*caseDestinations=*/destinations,
+ /*caseOperands=*/destinationsOperands,
+ /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
+ return mlir::success();
+ }
+};
/// conversion of fir::SelectOp to an if-then-else ladder
-struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
- rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
/// conversion of fir::SelectRankOp to an if-then-else ladder
-struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
- using FIROpConversion::FIROpConversion;
+struct SelectRankOpConversion
+ : public SelectOpConversionBase<fir::SelectRankOp> {
+ using SelectOpConversionBase::SelectOpConversionBase;
llvm::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- return selectMatchAndRewrite<fir::SelectRankOp>(
- lowerTy(), op, adaptor, rewriter, getTypeConverter());
+ return this->selectMatchAndRewrite(op, adaptor, rewriter);
}
};
@@ -3837,7 +3794,7 @@ struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> {
ptr = mlir::LLVM::ExtractValueOp::create(rewriter, loc, ptr, 0);
}
mlir::LLVM::ConstantOp c0 =
- genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
+ fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 37f1c9f..97912bd 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -21,6 +21,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -125,10 +126,58 @@ struct PrivateClauseOpConversion
return mlir::success();
}
};
+
+// Convert FIR type to LLVM without turning fir.box<T> into memory
+// reference.
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
+ mlir::Type firType) {
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
+ return converter.convertBoxTypeAsStruct(boxTy);
+ return converter.convertType(firType);
+}
+
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
+ : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
+ using OpenMPFIROpConversion::OpenMPFIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Type heapTy = allocmemOp.getAllocatedType();
+ mlir::Location loc = allocmemOp.getLoc();
+ auto ity = lowerTy().indexType();
+ mlir::Type dataTy = fir::unwrapRefType(heapTy);
+ mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+ if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
+ TODO(loc, "omp.target_allocmem codegen of derived type with length "
+ "parameters");
+ mlir::Value size = fir::computeElementDistance(
+ loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
+ if (auto scaleSize = fir::genAllocationScaleSize(
+ loc, allocmemOp.getInType(), ity, rewriter))
+ size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+ for (mlir::Value opnd : adaptor.getOperands().drop_front())
+ size = rewriter.create<mlir::LLVM::MulOp>(
+ loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
+ auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+ auto mallocTy =
+ mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
+ if (mallocTyWidth != ity.getIntOrFloatBitWidth())
+ size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
+ rewriter.modifyOpInPlace(allocmemOp, [&]() {
+ allocmemOp.setInType(rewriter.getI8Type());
+ allocmemOp.getTypeparamsMutable().clear();
+ allocmemOp.getTypeparamsMutable().append(size);
+ });
+ return mlir::success();
+ }
+};
} // namespace
void fir::populateOpenMPFIRToLLVMConversionPatterns(
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
+ patterns.add<TargetAllocMemOpConversion>(converter);
}
diff --git a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp
index 52c733d..bd0499f 100644
--- a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFEnumAttr.cpp.inc"
@@ -29,4 +30,26 @@ void CUFDialect::registerAttributes() {
LaunchBoundsAttr, ProcAttributeAttr>();
}
+cuf::DataAttributeAttr getDataAttr(mlir::Operation *op) {
+ if (!op)
+ return {};
+
+ if (auto dataAttr =
+ op->getAttrOfType<cuf::DataAttributeAttr>(cuf::getDataAttrName()))
+ return dataAttr;
+
+ // When the attribute is declared on the operation, it doesn't have a prefix.
+ if (auto dataAttr =
+ op->getAttrOfType<cuf::DataAttributeAttr>(cuf::dataAttrName))
+ return dataAttr;
+
+ return {};
+}
+
+bool hasDataAttr(mlir::Operation *op, cuf::DataAttribute value) {
+ if (auto dataAttr = getDataAttr(op))
+ return dataAttr.getValue() == value;
+ return false;
+}
+
} // namespace cuf
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 01975f3..87f9899 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) {
}
/// Parser shared by Alloca and Allocmem
-///
/// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type
/// ( `(` $typeparams `)` )? ( `,` $shape )?
/// attr-dict-without-keyword
diff --git a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp
index 034f8c7..f16072a 100644
--- a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp
+++ b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp
@@ -68,3 +68,31 @@ fir::FortranVariableOpInterface::verifyDeclareLikeOpImpl(mlir::Value memref) {
}
return mlir::success();
}
+
+mlir::LogicalResult
+fir::detail::verifyFortranVariableStorageOpInterface(mlir::Operation *op) {
+ auto storageIface = mlir::cast<fir::FortranVariableStorageOpInterface>(op);
+ mlir::Value storage = storageIface.getStorage();
+ std::uint64_t storageOffset = storageIface.getStorageOffset();
+ if (!storage) {
+ if (storageOffset != 0)
+ return op->emitOpError(
+ "storage offset specified without the storage reference");
+ return mlir::success();
+ }
+
+ auto storageType =
+ mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(storage.getType()));
+ if (!storageType || storageType.getDimension() != 1)
+ return op->emitOpError("storage must be a vector");
+ if (storageType.hasDynamicExtents())
+ return op->emitOpError("storage must have known extent");
+ if (storageType.getEleTy() != mlir::IntegerType::get(op->getContext(), 8))
+ return op->emitOpError("storage must be an array of i8 elements");
+ if (storageOffset > storageType.getConstantArraySize())
+ return op->emitOpError("storage offset exceeds the storage size");
+ // TODO: we should probably verify that the (offset + sizeof(var))
+ // is within the storage object, but this requires mlir::DataLayout.
+ // Can we make it available during the verification?
+ return mlir::success();
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index ed102db..2971a72 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -279,7 +279,8 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
auto [hlfirVariableType, firVarType] =
getDeclareOutputTypes(inputType, hasExplicitLbs);
build(builder, result, {hlfirVariableType, firVarType}, memref, shape,
- typeparams, dummy_scope, nameAttr, fortran_attrs, data_attr);
+ typeparams, dummy_scope, /*storage=*/nullptr, /*storage_offset=*/0,
+ nameAttr, fortran_attrs, data_attr);
}
llvm::LogicalResult hlfir::DeclareOp::verify() {
@@ -821,6 +822,62 @@ void hlfir::ConcatOp::getEffects(
}
//===----------------------------------------------------------------------===//
+// CmpCharOp
+//===----------------------------------------------------------------------===//
+
+llvm::LogicalResult hlfir::CmpCharOp::verify() {
+ mlir::Value lchr = getLchr();
+ mlir::Value rchr = getRchr();
+
+ unsigned kind = getCharacterKind(lchr.getType());
+ if (kind != getCharacterKind(rchr.getType()))
+ return emitOpError("character arguments must have the same KIND");
+
+ switch (getPredicate()) {
+ case mlir::arith::CmpIPredicate::slt:
+ case mlir::arith::CmpIPredicate::sle:
+ case mlir::arith::CmpIPredicate::eq:
+ case mlir::arith::CmpIPredicate::ne:
+ case mlir::arith::CmpIPredicate::sgt:
+ case mlir::arith::CmpIPredicate::sge:
+ break;
+ default:
+ return emitOpError("expected signed predicate");
+ }
+
+ return mlir::success();
+}
+
+void hlfir::CmpCharOp::getEffects(
+ llvm::SmallVectorImpl<
+ mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
+ &effects) {
+ getIntrinsicEffects(getOperation(), effects);
+}
+
+//===----------------------------------------------------------------------===//
+// CharTrimOp
+//===----------------------------------------------------------------------===//
+
+void hlfir::CharTrimOp::build(mlir::OpBuilder &builder,
+ mlir::OperationState &result, mlir::Value chr) {
+ unsigned kind = getCharacterKind(chr.getType());
+ auto resultType = hlfir::ExprType::get(
+ builder.getContext(), hlfir::ExprType::Shape{},
+ fir::CharacterType::get(builder.getContext(), kind,
+ fir::CharacterType::unknownLen()),
+ /*polymorphic=*/false);
+ build(builder, result, resultType, chr);
+}
+
+void hlfir::CharTrimOp::getEffects(
+ llvm::SmallVectorImpl<
+ mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
+ &effects) {
+ getIntrinsicEffects(getOperation(), effects);
+}
+
+//===----------------------------------------------------------------------===//
// NumericalReductionOp
//===----------------------------------------------------------------------===//
@@ -1440,44 +1497,46 @@ void hlfir::MatmulTransposeOp::getEffects(
}
//===----------------------------------------------------------------------===//
-// CShiftOp
+// Array shifts: CShiftOp/EOShiftOp
//===----------------------------------------------------------------------===//
-llvm::LogicalResult hlfir::CShiftOp::verify() {
- mlir::Value array = getArray();
+template <typename Op>
+static llvm::LogicalResult verifyArrayShift(Op op) {
+ mlir::Value array = op.getArray();
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(array.getType()));
llvm::ArrayRef<int64_t> inShape = arrayTy.getShape();
std::size_t arrayRank = inShape.size();
mlir::Type eleTy = arrayTy.getEleTy();
- hlfir::ExprType resultTy = mlir::cast<hlfir::ExprType>(getResult().getType());
+ hlfir::ExprType resultTy =
+ mlir::cast<hlfir::ExprType>(op.getResult().getType());
llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
std::size_t resultRank = resultShape.size();
mlir::Type resultEleTy = resultTy.getEleTy();
- mlir::Value shift = getShift();
+ mlir::Value shift = op.getShift();
mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType());
- // TODO: turn allowCharacterLenMismatch into true.
- if (auto match = areMatchingTypes(*this, eleTy, resultEleTy,
- /*allowCharacterLenMismatch=*/false);
+ if (auto match = areMatchingTypes(
+ op, eleTy, resultEleTy,
+ /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier);
match.failed())
- return emitOpError(
+ return op.emitOpError(
"input and output arrays should have the same element type");
if (arrayRank != resultRank)
- return emitOpError("input and output arrays should have the same rank");
+ return op.emitOpError("input and output arrays should have the same rank");
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (auto [inDim, resultDim] : llvm::zip(inShape, resultShape))
if (inDim != unknownExtent && resultDim != unknownExtent &&
inDim != resultDim)
- return emitOpError(
+ return op.emitOpError(
"output array's shape conflicts with the input array's shape");
int64_t dimVal = -1;
- if (!getDim())
+ if (!op.getDim())
dimVal = 1;
- else if (auto dim = fir::getIntIfConstant(getDim()))
+ else if (auto dim = fir::getIntIfConstant(op.getDim()))
dimVal = *dim;
// The DIM argument may be statically invalid (e.g. exceed the
@@ -1485,44 +1544,79 @@ llvm::LogicalResult hlfir::CShiftOp::verify() {
// so avoid some checks unless useStrictIntrinsicVerifier is true.
if (useStrictIntrinsicVerifier && dimVal != -1) {
if (dimVal < 1)
- return emitOpError("DIM must be >= 1");
+ return op.emitOpError("DIM must be >= 1");
if (dimVal > static_cast<int64_t>(arrayRank))
- return emitOpError("DIM must be <= input array's rank");
+ return op.emitOpError("DIM must be <= input array's rank");
}
- if (auto shiftSeqTy = mlir::dyn_cast<fir::SequenceType>(shiftTy)) {
- // SHIFT is an array. Verify the rank and the shape (if DIM is constant).
- llvm::ArrayRef<int64_t> shiftShape = shiftSeqTy.getShape();
- std::size_t shiftRank = shiftShape.size();
- if (shiftRank != arrayRank - 1)
- return emitOpError(
- "SHIFT's rank must be 1 less than the input array's rank");
-
- if (useStrictIntrinsicVerifier && dimVal != -1) {
- // SHIFT's shape must be [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)],
- // where [d(1), d(2), ..., d(n)] is the shape of the ARRAY.
- int64_t arrayDimIdx = 0;
- int64_t shiftDimIdx = 0;
- for (auto shiftDim : shiftShape) {
- if (arrayDimIdx == dimVal - 1)
+ // A helper lambda to verify the shape of the array types of
+ // certain operands of the array shift (e.g. the SHIFT and BOUNDARY operands).
+ auto verifyOperandTypeShape = [&](mlir::Type type,
+ llvm::Twine name) -> llvm::LogicalResult {
+ if (auto opndSeqTy = mlir::dyn_cast<fir::SequenceType>(type)) {
+ // The operand is an array. Verify the rank and the shape (if DIM is
+ // constant).
+ llvm::ArrayRef<int64_t> opndShape = opndSeqTy.getShape();
+ std::size_t opndRank = opndShape.size();
+ if (opndRank != arrayRank - 1)
+ return op.emitOpError(
+ name + "'s rank must be 1 less than the input array's rank");
+
+ if (useStrictIntrinsicVerifier && dimVal != -1) {
+ // The operand's shape must be
+ // [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)],
+ // where [d(1), d(2), ..., d(n)] is the shape of the ARRAY.
+ int64_t arrayDimIdx = 0;
+ int64_t opndDimIdx = 0;
+ for (auto opndDim : opndShape) {
+ if (arrayDimIdx == dimVal - 1)
+ ++arrayDimIdx;
+
+ if (inShape[arrayDimIdx] != unknownExtent &&
+ opndDim != unknownExtent && inShape[arrayDimIdx] != opndDim)
+ return op.emitOpError("SHAPE(ARRAY)(" +
+ llvm::Twine(arrayDimIdx + 1) +
+ ") must be equal to SHAPE(" + name + ")(" +
+ llvm::Twine(opndDimIdx + 1) +
+ "): " + llvm::Twine(inShape[arrayDimIdx]) +
+ " != " + llvm::Twine(opndDim));
++arrayDimIdx;
-
- if (inShape[arrayDimIdx] != unknownExtent &&
- shiftDim != unknownExtent && inShape[arrayDimIdx] != shiftDim)
- return emitOpError("SHAPE(ARRAY)(" + llvm::Twine(arrayDimIdx + 1) +
- ") must be equal to SHAPE(SHIFT)(" +
- llvm::Twine(shiftDimIdx + 1) +
- "): " + llvm::Twine(inShape[arrayDimIdx]) +
- " != " + llvm::Twine(shiftDim));
- ++arrayDimIdx;
- ++shiftDimIdx;
+ ++opndDimIdx;
+ }
}
}
+ return mlir::success();
+ };
+
+ if (failed(verifyOperandTypeShape(shiftTy, "SHIFT")))
+ return mlir::failure();
+
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) {
+ if (mlir::Value boundary = op.getBoundary()) {
+ mlir::Type boundaryTy =
+ hlfir::getFortranElementOrSequenceType(boundary.getType());
+ if (auto match = areMatchingTypes(
+ op, eleTy, hlfir::getFortranElementType(boundaryTy),
+ /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier);
+ match.failed())
+ return op.emitOpError(
+ "ARRAY and BOUNDARY operands must have the same element type");
+ if (failed(verifyOperandTypeShape(boundaryTy, "BOUNDARY")))
+ return mlir::failure();
+ }
}
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// CShiftOp
+//===----------------------------------------------------------------------===//
+
+llvm::LogicalResult hlfir::CShiftOp::verify() {
+ return verifyArrayShift(*this);
+}
+
void hlfir::CShiftOp::getEffects(
llvm::SmallVectorImpl<
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
@@ -1531,6 +1625,21 @@ void hlfir::CShiftOp::getEffects(
}
//===----------------------------------------------------------------------===//
+// EOShiftOp
+//===----------------------------------------------------------------------===//
+
+llvm::LogicalResult hlfir::EOShiftOp::verify() {
+ return verifyArrayShift(*this);
+}
+
+void hlfir::EOShiftOp::getEffects(
+ llvm::SmallVectorImpl<
+ mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
+ &effects) {
+ getIntrinsicEffects(getOperation(), effects);
+}
+
+//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@@ -1543,7 +1652,8 @@ llvm::LogicalResult hlfir::ReshapeOp::verify() {
hlfir::getFortranElementOrSequenceType(array.getType()));
if (auto match = areMatchingTypes(
*this, hlfir::getFortranElementType(resultType),
- arrayType.getElementType(), /*allowCharacterLenMismatch=*/true);
+ arrayType.getElementType(),
+ /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier);
match.failed())
return emitOpError("ARRAY and the result must have the same element type");
if (hlfir::isPolymorphicType(resultType) !=
@@ -1565,9 +1675,9 @@ llvm::LogicalResult hlfir::ReshapeOp::verify() {
if (mlir::Value pad = getPad()) {
auto padArrayType = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(pad.getType()));
- if (auto match = areMatchingTypes(*this, arrayType.getElementType(),
- padArrayType.getElementType(),
- /*allowCharacterLenMismatch=*/true);
+ if (auto match = areMatchingTypes(
+ *this, arrayType.getElementType(), padArrayType.getElementType(),
+ /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier);
match.failed())
return emitOpError("ARRAY and PAD must be of the same type");
}
@@ -1847,8 +1957,7 @@ hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf,
// shape information is not available at compile time
return llvm::LogicalResult::failure();
- rewriter.replaceAllUsesWith(shapeOf.getResult(), shape);
- rewriter.eraseOp(shapeOf);
+ rewriter.replaceOp(shapeOf, shape);
return llvm::LogicalResult::success();
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 9109f2b..886a8a5 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -455,12 +455,8 @@ struct AssociateOpConversion
mlir::Type associateHlfirVarType = associate.getResultTypes()[0];
hlfirVar = adjustVar(hlfirVar, associateHlfirVarType);
- associate.getResult(0).replaceAllUsesWith(hlfirVar);
-
mlir::Type associateFirVarType = associate.getResultTypes()[1];
firVar = adjustVar(firVar, associateFirVarType);
- associate.getResult(1).replaceAllUsesWith(firVar);
- associate.getResult(2).replaceAllUsesWith(flag);
// FIXME: note that the AssociateOp that is being erased
// here will continue to be a user of the original Source
// operand (e.g. a result of hlfir.elemental), because
@@ -472,7 +468,7 @@ struct AssociateOpConversion
// the conversions, so that we can analyze HLFIR in its
// original form and decide which of the AssociateOp
// users of hlfir.expr can reuse the buffer (if it can).
- rewriter.eraseOp(associate);
+ rewriter.replaceOp(associate, {hlfirVar, firVar, flag});
};
// If this is the last use of the expression value and this is an hlfir.expr
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
index cc74273..3775a13 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
@@ -27,6 +27,8 @@ add_flang_library(HLFIRTransforms
FIRSupport
FIRTransforms
FlangOpenMPTransforms
+ FortranEvaluate
+ FortranSupport
HLFIRDialect
LINK_COMPONENTS
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
index 2e27324..8104e53 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
@@ -305,6 +305,8 @@ public:
auto firDeclareOp = fir::DeclareOp::create(
rewriter, loc, memref.getType(), memref, declareOp.getShape(),
declareOp.getTypeparams(), declareOp.getDummyScope(),
+ /*storage=*/declareOp.getStorage(),
+ /*storage_offset=*/declareOp.getStorageOffset(),
declareOp.getUniqName(), fortranAttrs, dataAttr);
// Propagate other attributes from hlfir.declare to fir.declare.
@@ -490,15 +492,18 @@ public:
}
baseEleTy = hlfir::getFortranElementType(componentType);
shape = designate.getComponentShape();
- } else {
- // array%component[(indices) substring|complex part] cases.
- // Component ref of array bases are dealt with below in embox/rebox.
- assert(mlir::isa<fir::BaseBoxType>(designateResultType));
}
}
- if (mlir::isa<fir::BaseBoxType>(designateResultType)) {
- // Generate embox or rebox.
+ if (mlir::isa<fir::BaseBoxType>(designateResultType) ||
+ // Convert the component array slices using embox/rebox
+ // even if the result is a contiguous array section, e.g.:
+ // hlfir.designate %base{"i"} shape %shape :
+ // (!fir.box<!fir.array<2x!fir.type<_QMtypesTt{i:i32}>>>,
+ // !fir.shape<1>) -> !fir.ref<!fir.array<2xi32>>
+ // fir.coordinate_of should probably be a better option, though.
+ (fieldIndex && baseEntity.isArray())) {
+ // Generate embox or rebox for slicing.
mlir::Type eleTy = fir::unwrapPassByRefType(designateResultType);
bool isScalarDesignator = !mlir::isa<fir::SequenceType>(eleTy);
mlir::Value sourceBox;
@@ -575,8 +580,13 @@ public:
else
assert(sliceFields.empty() && substring.empty());
- llvm::SmallVector<mlir::Type> resultType{
- fir::updateTypeWithVolatility(designateResultType, isVolatile)};
+ // If the designate's result type is not a box, then create
+ // a box type to be used for the result of the embox/rebox.
+ mlir::Type resultType = designateResultType;
+ if (!mlir::isa<fir::BaseBoxType>(resultType))
+ resultType = fir::wrapInClassOrBoxType(resultType);
+
+ resultType = fir::updateTypeWithVolatility(resultType, isVolatile);
mlir::Value resultBox;
if (mlir::isa<fir::BaseBoxType>(base.getType())) {
@@ -587,6 +597,13 @@ public:
fir::EmboxOp::create(builder, loc, resultType, base, shape, slice,
firBaseTypeParameters, sourceBox);
}
+
+ if (!mlir::isa<fir::BaseBoxType>(designateResultType)) {
+ // If the designate's result is not a box, use the raw address
+ // as the new result.
+ resultBox = fir::BoxAddrOp::create(rewriter, loc, resultBox);
+ resultBox = builder.createConvert(loc, designateResultType, resultBox);
+ }
rewriter.replaceOp(designate, resultBox);
return mlir::success();
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
index c42b895..ff84a3c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
@@ -101,9 +101,8 @@ public:
elemental.getLoc(), builder, elemental, apply.getIndices());
// remove the old elemental and all of the bookkeeping
- rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue());
+ rewriter.replaceOp(apply, {yield.getElementValue()});
rewriter.eraseOp(yield);
- rewriter.eraseOp(apply);
rewriter.eraseOp(destroy);
rewriter.eraseOp(elemental);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 3c29d68..a913cfa 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -469,33 +469,49 @@ struct MatmulTransposeOpConversion
}
};
-class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> {
- using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion;
+// A converter for hlfir.cshift and hlfir.eoshift.
+template <typename T>
+class ArrayShiftOpConversion : public HlfirIntrinsicConversion<T> {
+ using HlfirIntrinsicConversion<T>::HlfirIntrinsicConversion;
+ using HlfirIntrinsicConversion<T>::lowerArguments;
+ using HlfirIntrinsicConversion<T>::processReturnValue;
+ using typename HlfirIntrinsicConversion<T>::IntrinsicArgument;
llvm::LogicalResult
- matchAndRewrite(hlfir::CShiftOp cshift,
- mlir::PatternRewriter &rewriter) const override {
- fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
- const mlir::Location &loc = cshift->getLoc();
+ matchAndRewrite(T op, mlir::PatternRewriter &rewriter) const override {
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
+ const mlir::Location &loc = op->getLoc();
- llvm::SmallVector<IntrinsicArgument, 3> inArgs;
- mlir::Value array = cshift.getArray();
+ llvm::SmallVector<IntrinsicArgument, 4> inArgs;
+ llvm::StringRef intrinsicName{[]() {
+ if constexpr (std::is_same_v<T, hlfir::EOShiftOp>)
+ return "eoshift";
+ else if constexpr (std::is_same_v<T, hlfir::CShiftOp>)
+ return "cshift";
+ else
+ llvm_unreachable("unsupported array shift");
+ }()};
+
+ mlir::Value array = op.getArray();
inArgs.push_back({array, array.getType()});
- mlir::Value shift = cshift.getShift();
+ mlir::Value shift = op.getShift();
inArgs.push_back({shift, shift.getType()});
- inArgs.push_back({cshift.getDim(), builder.getI32Type()});
+ if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) {
+ mlir::Value boundary = op.getBoundary();
+ inArgs.push_back({boundary, boundary ? boundary.getType() : nullptr});
+ }
+ inArgs.push_back({op.getDim(), builder.getI32Type()});
- auto *argLowering = fir::getIntrinsicArgumentLowering("cshift");
+ auto *argLowering = fir::getIntrinsicArgumentLowering(intrinsicName);
llvm::SmallVector<fir::ExtendedValue, 3> args =
- lowerArguments(cshift, inArgs, rewriter, argLowering);
+ lowerArguments(op, inArgs, rewriter, argLowering);
- mlir::Type scalarResultType =
- hlfir::getFortranElementType(cshift.getType());
+ mlir::Type scalarResultType = hlfir::getFortranElementType(op.getType());
- auto [resultExv, mustBeFreed] =
- fir::genIntrinsicCall(builder, loc, "cshift", scalarResultType, args);
+ auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+ builder, loc, intrinsicName, scalarResultType, args);
- processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter);
+ processReturnValue(op, resultExv, mustBeFreed, builder, rewriter);
return mlir::success();
}
};
@@ -535,6 +551,68 @@ class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> {
}
};
+class CmpCharOpConversion : public HlfirIntrinsicConversion<hlfir::CmpCharOp> {
+ using HlfirIntrinsicConversion<hlfir::CmpCharOp>::HlfirIntrinsicConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::CmpCharOp cmp,
+ mlir::PatternRewriter &rewriter) const override {
+ fir::FirOpBuilder builder{rewriter, cmp.getOperation()};
+ const mlir::Location &loc = cmp->getLoc();
+ hlfir::Entity lhs{cmp.getLchr()};
+ hlfir::Entity rhs{cmp.getRchr()};
+
+ auto [lhsExv, lhsCleanUp] =
+ hlfir::translateToExtendedValue(loc, builder, lhs);
+ auto [rhsExv, rhsCleanUp] =
+ hlfir::translateToExtendedValue(loc, builder, rhs);
+
+ auto resultVal = fir::runtime::genCharCompare(
+ builder, loc, cmp.getPredicate(), lhsExv, rhsExv);
+ if (lhsCleanUp || rhsCleanUp) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfter(cmp);
+ if (lhsCleanUp)
+ (*lhsCleanUp)();
+ if (rhsCleanUp)
+ (*rhsCleanUp)();
+ }
+ auto resultEntity = hlfir::EntityWithAttributes{resultVal};
+
+ processReturnValue(cmp, resultEntity, /*mustBeFreed=*/false, builder,
+ rewriter);
+ return mlir::success();
+ }
+};
+
+class CharTrimOpConversion
+ : public HlfirIntrinsicConversion<hlfir::CharTrimOp> {
+ using HlfirIntrinsicConversion<hlfir::CharTrimOp>::HlfirIntrinsicConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::CharTrimOp trim,
+ mlir::PatternRewriter &rewriter) const override {
+ fir::FirOpBuilder builder{rewriter, trim.getOperation()};
+ const mlir::Location &loc = trim->getLoc();
+
+ llvm::SmallVector<IntrinsicArgument, 1> inArgs;
+ mlir::Value chr = trim.getChr();
+ inArgs.push_back({chr, chr.getType()});
+
+ auto *argLowering = fir::getIntrinsicArgumentLowering("trim");
+ llvm::SmallVector<fir::ExtendedValue, 1> args =
+ lowerArguments(trim, inArgs, rewriter, argLowering);
+
+ mlir::Type resultType = hlfir::getFortranElementType(trim.getType());
+
+ auto [resultExv, mustBeFreed] =
+ fir::genIntrinsicCall(builder, loc, "trim", resultType, args);
+
+ processReturnValue(trim, resultExv, mustBeFreed, builder, rewriter);
+ return mlir::success();
+ }
+};
+
class LowerHLFIRIntrinsics
: public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
public:
@@ -547,7 +625,9 @@ public:
AnyOpConversion, SumOpConversion, ProductOpConversion,
TransposeOpConversion, CountOpConversion, DotProductOpConversion,
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
- MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
+ MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>,
+ ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion,
+ CmpCharOpConversion, CharTrimOpConversion>(context);
// While conceptually this pass is performing dialect conversion, we use
// pattern rewrites here instead of dialect conversion because this pass
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 8e25298..32998ab 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -96,7 +96,7 @@ struct MaskedArrayExpr {
/// hlfir.elemental_addr that form the elemental tree producing
/// the expression value. hlfir.elemental that produce values
/// used inside transformational operations are not part of this set.
- llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
+ llvm::SmallPtrSet<mlir::Operation *, 4> elementalParts{};
/// Was generateNoneElementalPart called?
bool noneElementalPartWasGenerated = false;
/// Is this expression the mask expression of the outer where statement?
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
index 722cd8a..a48b7ba 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
@@ -137,7 +137,7 @@ private:
// Schedule being built.
hlfir::Schedule schedule;
/// Leaf regions that have been saved so far.
- llvm::SmallSet<mlir::Region *, 16> savedRegions;
+ llvm::SmallPtrSet<mlir::Region *, 16> savedRegions;
/// Is schedule.back() a schedule that is only saving region with read
/// effects?
bool currentRunIsReadOnly = false;
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index b27c3a8..d8e36ea 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -10,6 +10,7 @@
// into the calling function.
//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
@@ -1269,64 +1270,91 @@ public:
}
};
-class CShiftConversion : public mlir::OpRewritePattern<hlfir::CShiftOp> {
+template <typename Op>
+class ArrayShiftConversion : public mlir::OpRewritePattern<Op> {
public:
- using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
+ // The implementation below only supports CShiftOp and EOShiftOp.
+ static_assert(std::is_same_v<Op, hlfir::CShiftOp> ||
+ std::is_same_v<Op, hlfir::EOShiftOp>);
+
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
llvm::LogicalResult
- matchAndRewrite(hlfir::CShiftOp cshift,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
- hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(op.getType());
assert(expr &&
- "expected an expression type for the result of hlfir.cshift");
+ "expected an expression type for the result of the array shift");
unsigned arrayRank = expr.getRank();
- // When it is a 1D CSHIFT, we may assume that the DIM argument
+ // When it is a 1D CSHIFT/EOSHIFT, we may assume that the DIM argument
// (whether it is present or absent) is equal to 1, otherwise,
// the program is illegal.
int64_t dimVal = 1;
if (arrayRank != 1)
- if (mlir::Value dim = cshift.getDim()) {
+ if (mlir::Value dim = op.getDim()) {
auto constDim = fir::getIntIfConstant(dim);
if (!constDim)
- return rewriter.notifyMatchFailure(cshift,
- "Nonconstant DIM for CSHIFT");
+ return rewriter.notifyMatchFailure(
+ op, "Nonconstant DIM for CSHIFT/EOSHIFT");
dimVal = *constDim;
}
if (dimVal <= 0 || dimVal > arrayRank)
- return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");
+ return rewriter.notifyMatchFailure(op, "Invalid DIM for CSHIFT/EOSHIFT");
+
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) {
+ // TODO: the EOSHIFT inlining code is not ready to produce
+ // fir.if selecting between ARRAY and BOUNDARY (or the default
+ // boundary value), when they are expressions of type CHARACTER.
+ // This needs more work.
+ if (mlir::isa<fir::CharacterType>(expr.getEleTy())) {
+ if (!hlfir::Entity{op.getArray()}.isVariable())
+ return rewriter.notifyMatchFailure(
+ op, "EOSHIFT with ARRAY being CHARACTER expression");
+ if (op.getBoundary() && !hlfir::Entity{op.getBoundary()}.isVariable())
+ return rewriter.notifyMatchFailure(
+ op, "EOSHIFT with BOUNDARY being CHARACTER expression");
+ }
+ // TODO: selecting between ARRAY and BOUNDARY values with derived types
+ // needs more work.
+ if (fir::isa_derived(expr.getEleTy()))
+ return rewriter.notifyMatchFailure(op, "EOSHIFT of derived type");
+ }
// When DIM==1 and the contiguity of the input array is not statically
// known, try to exploit the fact that the leading dimension might be
// contiguous. We can do this now using hlfir.eval_in_mem with
// a dynamic check for the leading dimension contiguity.
- // Otherwise, convert hlfir.cshift to hlfir.elemental.
+ // Otherwise, convert hlfir.cshift/eoshift to hlfir.elemental.
//
// Note that the hlfir.elemental can be inlined into other hlfir.elemental,
// while hlfir.eval_in_mem prevents this, and we will end up creating
// a temporary array for the result. We may need to come up with
// a more sophisticated logic for picking the most efficient
// representation.
- hlfir::Entity array = hlfir::Entity{cshift.getArray()};
+ hlfir::Entity array = hlfir::Entity{op.getArray()};
mlir::Type elementType = array.getFortranElementType();
if (dimVal == 1 && fir::isa_trivial(elementType) &&
- // genInMemCShift() only works for variables currently.
+ // genInMemArrayShift() only works for variables currently.
array.isVariable())
- rewriter.replaceOp(cshift, genInMemCShift(rewriter, cshift, dimVal));
+ rewriter.replaceOp(op, genInMemArrayShift(rewriter, op, dimVal));
else
- rewriter.replaceOp(cshift, genElementalCShift(rewriter, cshift, dimVal));
+ rewriter.replaceOp(op, genElementalArrayShift(rewriter, op, dimVal));
return mlir::success();
}
private:
- /// Generate MODULO(\p shiftVal, \p extent).
+ /// For CSHIFT, generate MODULO(\p shiftVal, \p extent).
+ /// For EOSHIFT, return \p shiftVal converted to \p calcType.
static mlir::Value normalizeShiftValue(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value shiftVal,
mlir::Value extent,
mlir::Type calcType) {
shiftVal = builder.createConvert(loc, calcType, shiftVal);
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
+ return shiftVal;
+
extent = builder.createConvert(loc, calcType, extent);
// Make sure that we do not divide by zero. When the dimension
// has zero size, turn the extent into 1. Note that the computed
@@ -1342,24 +1370,227 @@ private:
return builder.createConvert(loc, calcType, shiftVal);
}
- /// Convert \p cshift into an hlfir.elemental using
+ /// The index computations for the array shifts are done using I64 type.
+ /// For CSHIFT, none of the computations overflow signed or unsigned I64.
+ /// For EOSHIFT, some computations may involve negative shift values,
+ /// so using the no-unsigned-wrap flag would be incorrect.
+ static void setArithOverflowFlags(Op op, fir::FirOpBuilder &builder) {
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw);
+ else
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
+ mlir::arith::IntegerOverflowFlags::nuw);
+ }
+
+ /// Return the element type of the EOSHIFT boundary that may be omitted
+ /// statically or dynamically. This element type might be used
+ /// to generate MLIR where we have to select between the default
+ /// boundary value and the dynamically absent/present boundary value.
+ /// If the boundary has a type not defined in Table 16.4 in 16.9.77
+ /// of F2023, then the return value is nullptr.
+ static mlir::Type getDefaultBoundaryValueType(mlir::Type elementType) {
+ // To be able to generate a "select" between the default boundary value
+ // and the dynamic boundary value, use BoxCharType for the CHARACTER
+ // cases. This might be a little bit inefficient, because we may
+ // create unnecessary tuples, but it simplifies the inlining code.
+ if (auto charTy = mlir::dyn_cast<fir::CharacterType>(elementType))
+ return fir::BoxCharType::get(charTy.getContext(), charTy.getFKind());
+
+ if (mlir::isa<fir::LogicalType>(elementType) ||
+ fir::isa_integer(elementType) || fir::isa_real(elementType) ||
+ fir::isa_complex(elementType))
+ return elementType;
+
+ return nullptr;
+ }
+
+ /// Generate the default boundary value as defined in Table 16.4 in 16.9.77
+ /// of F2023.
+ static mlir::Value genDefaultBoundary(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Type elementType) {
+ assert(getDefaultBoundaryValueType(elementType) &&
+ "default boundary value cannot be computed for the given type");
+ if (mlir::isa<fir::CharacterType>(elementType)) {
+ // Create an empty CHARACTER of the same kind. The assignment
+ // of this empty CHARACTER into the result will add the padding
+ // if necessary.
+ fir::factory::CharacterExprHelper charHelper{builder, loc};
+ mlir::Value zeroLen = builder.createIntegerConstant(
+ loc, builder.getCharacterLengthType(), 0);
+ fir::CharBoxValue emptyCharTemp =
+ charHelper.createCharacterTemp(elementType, zeroLen);
+ return charHelper.createEmbox(emptyCharTemp);
+ }
+
+ return fir::factory::createZeroValue(builder, loc, elementType);
+ }
+
+ /// \p entity represents the boundary operand of hlfir.eoshift.
+ /// This method generates a scalar boundary value fetched
+ /// from the boundary entity using \p indices (which may be empty,
+ /// if the boundary operand is scalar).
+ static mlir::Value loadEoshiftVal(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ hlfir::Entity entity,
+ mlir::ValueRange indices = {}) {
+ hlfir::Entity boundaryVal =
+ hlfir::loadElementAt(loc, builder, entity, indices);
+
+ mlir::Type boundaryValTy =
+ getDefaultBoundaryValueType(entity.getFortranElementType());
+
+ // Boxed !fir.char<KIND,LEN> with known LEN are loaded
+ // as raw references to !fir.char<KIND,LEN>.
+ // We need to wrap them into the !fir.boxchar.
+ if (boundaryVal.isVariable() && boundaryValTy &&
+ mlir::isa<fir::BoxCharType>(boundaryValTy))
+ return hlfir::genVariableBoxChar(loc, builder, boundaryVal);
+ return boundaryVal;
+ }
+
+ /// This method generates a scalar boundary value for the given hlfir.eoshift
+ /// \p op that can be used to initialize cells of the result
+ /// if the scalar/array boundary operand is statically or dynamically
+ /// absent. The first result is the scalar boundary value. The second result
+ /// is a dynamic predicate indicating whether the scalar boundary value
+ /// should actually be used.
+ [[maybe_unused]] static std::pair<mlir::Value, mlir::Value>
+ genScalarBoundaryForEOShift(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::EOShiftOp op) {
+ hlfir::Entity array{op.getArray()};
+ mlir::Type elementType = array.getFortranElementType();
+
+ if (!op.getBoundary()) {
+ // Boundary operand is statically absent.
+ mlir::Value defaultVal = genDefaultBoundary(loc, builder, elementType);
+ mlir::Value boundaryIsScalarPred = builder.createBool(loc, true);
+ return {defaultVal, boundaryIsScalarPred};
+ }
+
+ hlfir::Entity boundary{op.getBoundary()};
+ mlir::Type boundaryValTy = getDefaultBoundaryValueType(elementType);
+
+ if (boundary.isScalar()) {
+ if (!boundaryValTy || !boundary.mayBeOptional()) {
+ // The boundary must be present.
+ mlir::Value boundaryVal = loadEoshiftVal(loc, builder, boundary);
+ mlir::Value boundaryIsScalarPred = builder.createBool(loc, true);
+ return {boundaryVal, boundaryIsScalarPred};
+ }
+
+ // Boundary is a scalar that may be dynamically absent.
+ // If boundary is not present dynamically, we must use the default
+ // value.
+ assert(mlir::isa<fir::BaseBoxType>(boundary.getType()));
+ mlir::Value isPresentPred =
+ fir::IsPresentOp::create(builder, loc, builder.getI1Type(), boundary);
+ mlir::Value boundaryVal =
+ builder
+ .genIfOp(loc, {boundaryValTy}, isPresentPred,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value boundaryVal =
+ loadEoshiftVal(loc, builder, boundary);
+ fir::ResultOp::create(builder, loc, boundaryVal);
+ })
+ .genElse([&]() {
+ mlir::Value defaultVal =
+ genDefaultBoundary(loc, builder, elementType);
+ fir::ResultOp::create(builder, loc, defaultVal);
+ })
+ .getResults()[0];
+ mlir::Value boundaryIsScalarPred = builder.createBool(loc, true);
+ return {boundaryVal, boundaryIsScalarPred};
+ }
+ if (!boundaryValTy || !boundary.mayBeOptional()) {
+ // The boundary must be present
+ mlir::Value boundaryIsScalarPred = builder.createBool(loc, false);
+ return {nullptr, boundaryIsScalarPred};
+ }
+
+ // Boundary is an array that may be dynamically absent.
+ mlir::Value defaultVal = genDefaultBoundary(loc, builder, elementType);
+ mlir::Value isPresentPred =
+ fir::IsPresentOp::create(builder, loc, builder.getI1Type(), boundary);
+ // If the array is present, then boundaryIsScalarPred must be equal
+ // to false, otherwise, it should be true.
+ mlir::Value trueVal = builder.createBool(loc, true);
+ mlir::Value falseVal = builder.createBool(loc, false);
+ mlir::Value boundaryIsScalarPred = mlir::arith::SelectOp::create(
+ builder, loc, isPresentPred, falseVal, trueVal);
+ return {defaultVal, boundaryIsScalarPred};
+ }
+
+ /// Generate code that produces the final boundary value to be assigned
+ /// to the result of hlfir.eoshift \p op. \p precomputedScalarBoundary
+ /// specifies the scalar boundary value pre-computed before the elemental
+ /// or the assignment loop. If it is nullptr, then the boundary operand
+ /// of \p op must be a present array. \p boundaryIsScalarPred is a dynamic
+ /// predicate that is true, when the pre-computed scalar value must be used.
+ /// \p oneBasedIndices specify the indices to address into the boundary
+ /// array - they may be empty, if the boundary is scalar.
+ [[maybe_unused]] static mlir::Value selectBoundaryValue(
+ mlir::Location loc, fir::FirOpBuilder &builder, hlfir::EOShiftOp op,
+ mlir::Value precomputedScalarBoundary, mlir::Value boundaryIsScalarPred,
+ mlir::ValueRange oneBasedIndices) {
+ // Boundary is statically absent: a default value has been precomputed.
+ if (!op.getBoundary())
+ return precomputedScalarBoundary;
+
+ // Boundary is statically present and is a scalar: boundary does not depend
+ // upon the indices and so it has been precomputed.
+ hlfir::Entity boundary{op.getBoundary()};
+ if (boundary.isScalar())
+ return precomputedScalarBoundary;
+
+ // Boundary is statically present and is an array: if the scalar
+ // boundary has not been precomputed, this means that the data type
+ // of the shifted values does not provide a way to compute
+ // the default boundary value, so the array boundary must be dynamically
+ // present, and we can load the boundary values from it.
+ bool mustBePresent = !precomputedScalarBoundary;
+ if (mustBePresent)
+ return loadEoshiftVal(loc, builder, boundary, oneBasedIndices);
+
+ // The array boundary may be dynamically absent.
+ // In this case, precomputedScalarBoundary is a pre-computed scalar
+ // boundary value that has to be used if boundaryIsScalarPred
+ // is true, otherwise, the boundary value has to be loaded
+ // from the boundary array.
+ mlir::Type boundaryValTy = precomputedScalarBoundary.getType();
+ mlir::Value newBoundaryVal =
+ builder
+ .genIfOp(loc, {boundaryValTy}, boundaryIsScalarPred,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ fir::ResultOp::create(builder, loc, precomputedScalarBoundary);
+ })
+ .genElse([&]() {
+ mlir::Value elem =
+ loadEoshiftVal(loc, builder, boundary, oneBasedIndices);
+ fir::ResultOp::create(builder, loc, elem);
+ })
+ .getResults()[0];
+ return newBoundaryVal;
+ }
+
+ /// Convert \p op into an hlfir.elemental using
/// the pre-computed constant \p dimVal.
- static mlir::Operation *genElementalCShift(mlir::PatternRewriter &rewriter,
- hlfir::CShiftOp cshift,
- int64_t dimVal) {
+ static mlir::Operation *
+ genElementalArrayShift(mlir::PatternRewriter &rewriter, Op op,
+ int64_t dimVal) {
using Fortran::common::maxRank;
- hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
- hlfir::Entity array = hlfir::Entity{cshift.getArray()};
+ hlfir::Entity shift = hlfir::Entity{op.getShift()};
+ hlfir::Entity array = hlfir::Entity{op.getArray()};
- mlir::Location loc = cshift.getLoc();
- fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
+ mlir::Location loc = op.getLoc();
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
// The new index computation involves MODULO, which is not implemented
// for IndexType, so use I64 instead.
mlir::Type calcType = builder.getI64Type();
- // All the indices arithmetic used below does not overflow
- // signed and unsigned I64.
- builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
- mlir::arith::IntegerOverflowFlags::nuw);
+ // Set the indices arithmetic overflow flags.
+ setArithOverflowFlags(op, builder);
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value, maxRank> arrayExtents =
@@ -1374,6 +1605,17 @@ private:
shiftVal =
normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType);
}
+ // The boundary operand of hlfir.eoshift may be statically or
+ // dynamically absent.
+ // In both cases, it is assumed to be a scalar with the value
+ // corresponding to the array element type.
+ // boundaryIsScalarPred is a dynamic predicate that identifies
+ // these cases. If boundaryIsScalarPred is dynamically false,
+ // then the boundary operand must be a present array.
+ mlir::Value boundaryVal, boundaryIsScalarPred;
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
+ std::tie(boundaryVal, boundaryIsScalarPred) =
+ genScalarBoundaryForEOShift(loc, builder, op);
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
@@ -1394,34 +1636,84 @@ private:
shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent,
calcType);
}
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) {
+ llvm::SmallVector<mlir::Value, maxRank> boundaryIndices{indices};
+ boundaryIndices.erase(boundaryIndices.begin() + dimVal - 1);
+ boundaryVal =
+ selectBoundaryValue(loc, builder, op, boundaryVal,
+ boundaryIsScalarPred, boundaryIndices);
+ }
- // Element i of the result (1-based) is element
- // 'MODULO(i + SH - 1, SIZE(ARRAY,DIM)) + 1' (1-based) of the original
- // ARRAY (or its section, when ARRAY is not a vector).
-
- // Compute the index into the original array using the normalized
- // shift value, which satisfies (SH >= 0 && SH < SIZE(ARRAY,DIM)):
- // newIndex =
- // i + ((i <= SIZE(ARRAY,DIM) - SH) ? SH : SH - SIZE(ARRAY,DIM))
- //
- // Such index computation allows for further loop vectorization
- // in LLVM.
- mlir::Value wrapBound =
- mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal);
- mlir::Value adjustedShiftVal =
- mlir::arith::SubIOp::create(builder, loc, shiftVal, shiftDimExtent);
- mlir::Value index =
- builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
- mlir::Value wrapCheck = mlir::arith::CmpIOp::create(
- builder, loc, mlir::arith::CmpIPredicate::sle, index, wrapBound);
- mlir::Value actualShift = mlir::arith::SelectOp::create(
- builder, loc, wrapCheck, shiftVal, adjustedShiftVal);
- mlir::Value newIndex =
- mlir::arith::AddIOp::create(builder, loc, index, actualShift);
- newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);
- indices[dimVal - 1] = newIndex;
- hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
- return hlfir::loadTrivialScalar(loc, builder, element);
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) {
+ // EOSHIFT:
+ // Element i of the result (1-based) is the element of the original
+ // array (or its section, when ARRAY is not a vector) with index
+ // (i + SH), if (1 <= i + SH <= SIZE(ARRAY,DIM)), otherwise
+ // it is the BOUNDARY value.
+ mlir::Value index =
+ builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
+ mlir::arith::IntegerOverflowFlags savedFlags =
+ builder.getIntegerOverflowFlags();
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw);
+ mlir::Value indexPlusShift =
+ mlir::arith::AddIOp::create(builder, loc, index, shiftVal);
+ builder.setIntegerOverflowFlags(savedFlags);
+ mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
+ mlir::Value cmp1 = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sge, indexPlusShift, one);
+ mlir::Value cmp2 = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sle, indexPlusShift,
+ shiftDimExtent);
+ mlir::Value loadFromArray =
+ mlir::arith::AndIOp::create(builder, loc, cmp1, cmp2);
+ mlir::Type boundaryValTy = boundaryVal.getType();
+ mlir::Value result =
+ builder
+ .genIfOp(loc, {boundaryValTy}, loadFromArray,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ indices[dimVal - 1] = builder.createConvert(
+ loc, builder.getIndexType(), indexPlusShift);
+ ;
+ mlir::Value elem =
+ loadEoshiftVal(loc, builder, array, indices);
+ fir::ResultOp::create(builder, loc, elem);
+ })
+ .genElse(
+ [&]() { fir::ResultOp::create(builder, loc, boundaryVal); })
+ .getResults()[0];
+ return hlfir::Entity{result};
+ } else {
+ // CSHIFT:
+ // Element i of the result (1-based) is element
+ // 'MODULO(i + SH - 1, SIZE(ARRAY,DIM)) + 1' (1-based) of the original
+ // ARRAY (or its section, when ARRAY is not a vector).
+
+ // Compute the index into the original array using the normalized
+ // shift value, which satisfies (SH >= 0 && SH < SIZE(ARRAY,DIM)):
+ // newIndex =
+ // i + ((i <= SIZE(ARRAY,DIM) - SH) ? SH : SH - SIZE(ARRAY,DIM))
+ //
+ // Such index computation allows for further loop vectorization
+ // in LLVM.
+ mlir::Value wrapBound =
+ mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal);
+ mlir::Value adjustedShiftVal =
+ mlir::arith::SubIOp::create(builder, loc, shiftVal, shiftDimExtent);
+ mlir::Value index =
+ builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
+ mlir::Value wrapCheck = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sle, index, wrapBound);
+ mlir::Value actualShift = mlir::arith::SelectOp::create(
+ builder, loc, wrapCheck, shiftVal, adjustedShiftVal);
+ mlir::Value newIndex =
+ mlir::arith::AddIOp::create(builder, loc, index, actualShift);
+ newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);
+ indices[dimVal - 1] = newIndex;
+ hlfir::Entity element =
+ hlfir::getElementAt(loc, builder, array, indices);
+ return hlfir::loadTrivialScalar(loc, builder, element);
+ }
};
mlir::Type elementType = array.getFortranElementType();
@@ -1429,19 +1721,42 @@ private:
loc, builder, elementType, arrayShape, typeParams, genKernel,
/*isUnordered=*/true,
array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
- cshift.getResult().getType());
+ op.getResult().getType());
return elementalOp.getOperation();
}
- /// Convert \p cshift into an hlfir.eval_in_mem using the pre-computed
+ /// Convert \p op into an hlfir.eval_in_mem using the pre-computed
/// constant \p dimVal.
- /// The converted code looks like this:
- /// do i=1,SH
- /// result(i + (SIZE(ARRAY,DIM) - SH)) = array(i)
+ /// The converted code for CSHIFT looks like this:
+ /// DEST_OFFSET = SIZE(ARRAY,DIM) - SH
+ /// COPY_END1 = SH
+ /// do i=1,COPY_END1
+ /// result(i + DEST_OFFSET) = array(i)
/// end
- /// do i=1,SIZE(ARRAY,DIM) - SH
- /// result(i) = array(i + SH)
+ /// SOURCE_OFFSET = SH
+ /// COPY_END2 = SIZE(ARRAY,DIM) - SH
+ /// do i=1,COPY_END2
+ /// result(i) = array(i + SOURCE_OFFSET)
/// end
+ /// Where SH is the normalized shift value, which satisfies
+ /// (SH >= 0 && SH < SIZE(ARRAY,DIM)).
+ ///
+ /// The converted code for EOSHIFT looks like this:
+ /// EXTENT = SIZE(ARRAY,DIM)
+ /// DEST_OFFSET = SH < 0 ? -SH : 0
+ /// SOURCE_OFFSET = SH < 0 ? 0 : SH
+ /// COPY_END = SH < 0 ?
+ /// (-EXTENT > SH ? 0 : EXTENT + SH) :
+ /// (EXTENT < SH ? 0 : EXTENT - SH)
+ /// do i=1,COPY_END
+ /// result(i + DEST_OFFSET) = array(i + SOURCE_OFFSET)
+ /// end
+ /// INIT_END = EXTENT - COPY_END
+ /// INIT_OFFSET = SH < 0 ? 0 : COPY_END
+ /// do i=1,INIT_END
+ /// result(i + INIT_OFFSET) = BOUNDARY
+ /// end
+ /// Where SH is the original shift value.
///
/// When \p dimVal is 1, we generate the same code twice
/// under a dynamic check for the contiguity of the leading
@@ -1450,24 +1765,21 @@ private:
/// as a contiguous slice of the original array.
/// This allows recognizing the above two loops as memcpy
/// loop idioms in LLVM.
- static mlir::Operation *genInMemCShift(mlir::PatternRewriter &rewriter,
- hlfir::CShiftOp cshift,
- int64_t dimVal) {
+ static mlir::Operation *genInMemArrayShift(mlir::PatternRewriter &rewriter,
+ Op op, int64_t dimVal) {
using Fortran::common::maxRank;
- hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
- hlfir::Entity array = hlfir::Entity{cshift.getArray()};
+ hlfir::Entity shift = hlfir::Entity{op.getShift()};
+ hlfir::Entity array = hlfir::Entity{op.getArray()};
assert(array.isVariable() && "array must be a variable");
assert(!array.isPolymorphic() &&
- "genInMemCShift does not support polymorphic types");
- mlir::Location loc = cshift.getLoc();
- fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
+ "genInMemArrayShift does not support polymorphic types");
+ mlir::Location loc = op.getLoc();
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
// The new index computation involves MODULO, which is not implemented
// for IndexType, so use I64 instead.
mlir::Type calcType = builder.getI64Type();
- // All the indices arithmetic used below does not overflow
- // signed and unsigned I64.
- builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
- mlir::arith::IntegerOverflowFlags::nuw);
+ // Set the indices arithmetic overflow flags.
+ setArithOverflowFlags(op, builder);
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value, maxRank> arrayExtents =
@@ -1482,10 +1794,20 @@ private:
shiftVal =
normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType);
}
+ // The boundary operand of hlfir.eoshift may be statically or
+ // dynamically absent.
+ // In both cases, it is assumed to be a scalar with the value
+ // corresponding to the array element type.
+ // boundaryIsScalarPred is a dynamic predicate that identifies
+ // these cases. If boundaryIsScalarPred is dynamically false,
+ // then the boundary operand must be a present array.
+ mlir::Value boundaryVal, boundaryIsScalarPred;
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
+ std::tie(boundaryVal, boundaryIsScalarPred) =
+ genScalarBoundaryForEOShift(loc, builder, op);
hlfir::EvaluateInMemoryOp evalOp = hlfir::EvaluateInMemoryOp::create(
- builder, loc, mlir::cast<hlfir::ExprType>(cshift.getType()),
- arrayShape);
+ builder, loc, mlir::cast<hlfir::ExprType>(op.getType()), arrayShape);
builder.setInsertionPointToStart(&evalOp.getBody().front());
mlir::Value resultArray = evalOp.getMemory();
@@ -1499,11 +1821,12 @@ private:
// (if any). If exposeContiguity is true, the array's section
// array(s(1), ..., s(dim-1), :, s(dim+1), ..., s(n)) is represented
// as a contiguous 1D array.
- // shiftVal is the normalized shift value that satisfies (SH >= 0 && SH <
- // SIZE(ARRAY,DIM)).
+ // For CSHIFT, shiftVal is the normalized shift value that satisfies
+ // (SH >= 0 && SH < SIZE(ARRAY,DIM)).
//
auto genDimensionShift = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::Value shiftVal, bool exposeContiguity,
+ mlir::Value shiftVal, mlir::Value boundary,
+ bool exposeContiguity,
mlir::ValueRange oneBasedIndices)
-> llvm::SmallVector<mlir::Value, 0> {
// Create a vector of indices (s(1), ..., s(dim-1), nullptr, s(dim+1),
@@ -1536,63 +1859,143 @@ private:
srcIndices.resize(1);
}
- // Copy first portion of the array:
- // do i=1,SH
- // result(i + (SIZE(ARRAY,DIM) - SH)) = array(i)
- // end
- auto genAssign1 = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange index,
- mlir::ValueRange reductionArgs)
+ // genCopy lambda generates the body of a generic copy loop.
+ // do i=1,COPY_END
+ // result(i + DEST_OFFSET) = array(i + SOURCE_OFFSET)
+ // end
+ //
+ // It is parameterized by DEST_OFFSET and SOURCE_OFFSET.
+ mlir::Value dstOffset, srcOffset;
+ auto genCopy = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index, mlir::ValueRange reductionArgs)
-> llvm::SmallVector<mlir::Value, 0> {
assert(index.size() == 1 && "expected single loop");
mlir::Value srcIndex = builder.createConvert(loc, calcType, index[0]);
+ mlir::Value dstIndex = srcIndex;
+ if (srcOffset)
+ srcIndex =
+ mlir::arith::AddIOp::create(builder, loc, srcIndex, srcOffset);
srcIndices[dimVal - 1] = srcIndex;
hlfir::Entity srcElementValue =
hlfir::loadElementAt(loc, builder, srcArray, srcIndices);
- mlir::Value dstIndex = mlir::arith::AddIOp::create(
- builder, loc, srcIndex,
- mlir::arith::SubIOp::create(builder, loc, shiftDimExtent,
- shiftVal));
+ if (dstOffset)
+ dstIndex =
+ mlir::arith::AddIOp::create(builder, loc, dstIndex, dstOffset);
dstIndices[dimVal - 1] = dstIndex;
hlfir::Entity dstElement = hlfir::getElementAt(
loc, builder, hlfir::Entity{resultArray}, dstIndices);
hlfir::AssignOp::create(builder, loc, srcElementValue, dstElement);
+ // Reset the external parameters' values to make sure
+ // they are properly updated between the lambda calls.
+ // WARNING: if genLoopNestWithReductions() calls the lambda
+ // multiple times, this is going to be a problem.
+ dstOffset = nullptr;
+ srcOffset = nullptr;
return {};
};
- // Generate the first loop.
- hlfir::genLoopNestWithReductions(loc, builder, {shiftVal},
- /*reductionInits=*/{}, genAssign1,
- /*isUnordered=*/true);
-
- // Copy second portion of the array:
- // do i=1,SIZE(ARRAY,DIM)-SH
- // result(i) = array(i + SH)
- // end
- auto genAssign2 = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange index,
- mlir::ValueRange reductionArgs)
- -> llvm::SmallVector<mlir::Value, 0> {
- assert(index.size() == 1 && "expected single loop");
- mlir::Value dstIndex = builder.createConvert(loc, calcType, index[0]);
- mlir::Value srcIndex =
- mlir::arith::AddIOp::create(builder, loc, dstIndex, shiftVal);
- srcIndices[dimVal - 1] = srcIndex;
- hlfir::Entity srcElementValue =
- hlfir::loadElementAt(loc, builder, srcArray, srcIndices);
- dstIndices[dimVal - 1] = dstIndex;
- hlfir::Entity dstElement = hlfir::getElementAt(
- loc, builder, hlfir::Entity{resultArray}, dstIndices);
- hlfir::AssignOp::create(builder, loc, srcElementValue, dstElement);
- return {};
- };
-
- // Generate the second loop.
- mlir::Value bound =
- mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal);
- hlfir::genLoopNestWithReductions(loc, builder, {bound},
- /*reductionInits=*/{}, genAssign2,
- /*isUnordered=*/true);
+ if constexpr (std::is_same_v<Op, hlfir::CShiftOp>) {
+ // Copy first portion of the array:
+ // DEST_OFFSET = SIZE(ARRAY,DIM) - SH
+ // COPY_END1 = SH
+ // do i=1,COPY_END1
+ // result(i + DEST_OFFSET) = array(i)
+ // end
+ dstOffset =
+ mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal);
+ srcOffset = nullptr;
+ hlfir::genLoopNestWithReductions(loc, builder, {shiftVal},
+ /*reductionInits=*/{}, genCopy,
+ /*isUnordered=*/true);
+
+ // Copy second portion of the array:
+ // SOURCE_OFFSET = SH
+ // COPY_END2 = SIZE(ARRAY,DIM) - SH
+ // do i=1,COPY_END2
+ // result(i) = array(i + SOURCE_OFFSET)
+ // end
+ mlir::Value bound =
+ mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal);
+ dstOffset = nullptr;
+ srcOffset = shiftVal;
+ hlfir::genLoopNestWithReductions(loc, builder, {bound},
+ /*reductionInits=*/{}, genCopy,
+ /*isUnordered=*/true);
+ } else {
+ // Do the copy:
+ // EXTENT = SIZE(ARRAY,DIM)
+ // DEST_OFFSET = SH < 0 ? -SH : 0
+ // SOURCE_OFFSET = SH < 0 ? 0 : SH
+ // COPY_END = SH < 0 ?
+ // (-EXTENT > SH ? 0 : EXTENT + SH) :
+ // (EXTENT < SH ? 0 : EXTENT - SH)
+ // do i=1,COPY_END
+ // result(i + DEST_OFFSET) = array(i + SOURCE_OFFSET)
+ // end
+ mlir::arith::IntegerOverflowFlags savedFlags =
+ builder.getIntegerOverflowFlags();
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw);
+
+ mlir::Value zero = builder.createIntegerConstant(loc, calcType, 0);
+ mlir::Value isNegativeShift = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::slt, shiftVal, zero);
+ mlir::Value shiftNeg =
+ mlir::arith::SubIOp::create(builder, loc, zero, shiftVal);
+ dstOffset = mlir::arith::SelectOp::create(builder, loc, isNegativeShift,
+ shiftNeg, zero);
+ srcOffset = mlir::arith::SelectOp::create(builder, loc, isNegativeShift,
+ zero, shiftVal);
+ mlir::Value extentNeg =
+ mlir::arith::SubIOp::create(builder, loc, zero, shiftDimExtent);
+ mlir::Value extentPlusShift =
+ mlir::arith::AddIOp::create(builder, loc, shiftDimExtent, shiftVal);
+ mlir::Value extentNegShiftCmp = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sgt, extentNeg, shiftVal);
+ mlir::Value negativeShiftBound = mlir::arith::SelectOp::create(
+ builder, loc, extentNegShiftCmp, zero, extentPlusShift);
+ mlir::Value extentMinusShift =
+ mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal);
+ mlir::Value extentShiftCmp = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::slt, shiftDimExtent,
+ shiftVal);
+ mlir::Value positiveShiftBound = mlir::arith::SelectOp::create(
+ builder, loc, extentShiftCmp, zero, extentMinusShift);
+ mlir::Value copyEnd = mlir::arith::SelectOp::create(
+ builder, loc, isNegativeShift, negativeShiftBound,
+ positiveShiftBound);
+ hlfir::genLoopNestWithReductions(loc, builder, {copyEnd},
+ /*reductionInits=*/{}, genCopy,
+ /*isUnordered=*/true);
+
+ // Do the init:
+ // INIT_END = EXTENT - COPY_END
+ // INIT_OFFSET = SH < 0 ? 0 : COPY_END
+ // do i=1,INIT_END
+ // result(i + INIT_OFFSET) = BOUNDARY
+ // end
+ assert(boundary && "boundary cannot be null");
+ mlir::Value initEnd =
+ mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, copyEnd);
+ mlir::Value initOffset = mlir::arith::SelectOp::create(
+ builder, loc, isNegativeShift, zero, copyEnd);
+ auto genInit = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 0> {
+ mlir::Value dstIndex = builder.createConvert(loc, calcType, index[0]);
+ dstIndex =
+ mlir::arith::AddIOp::create(builder, loc, dstIndex, initOffset);
+ dstIndices[dimVal - 1] = dstIndex;
+ hlfir::Entity dstElement = hlfir::getElementAt(
+ loc, builder, hlfir::Entity{resultArray}, dstIndices);
+ hlfir::AssignOp::create(builder, loc, boundary, dstElement);
+ return {};
+ };
+ hlfir::genLoopNestWithReductions(loc, builder, {initEnd},
+ /*reductionInits=*/{}, genInit,
+ /*isUnordered=*/true);
+ builder.setIntegerOverflowFlags(savedFlags);
+ }
return {};
};
@@ -1614,6 +2017,10 @@ private:
shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent,
calcType);
}
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
+ boundaryVal =
+ selectBoundaryValue(loc, builder, op, boundaryVal,
+ boundaryIsScalarPred, oneBasedIndices);
// If we can fetch the byte stride of the leading dimension,
// and the byte size of the element, then we can generate
@@ -1635,8 +2042,8 @@ private:
}
if (array.isSimplyContiguous() || !elemSize || !stride) {
- genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/false,
- oneBasedIndices);
+ genDimensionShift(loc, builder, shiftVal, boundaryVal,
+ /*exposeContiguity=*/false, oneBasedIndices);
return {};
}
@@ -1644,11 +2051,11 @@ private:
builder, loc, mlir::arith::CmpIPredicate::eq, elemSize, stride);
builder.genIfOp(loc, {}, isContiguous, /*withElseRegion=*/true)
.genThen([&]() {
- genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/true,
- oneBasedIndices);
+ genDimensionShift(loc, builder, shiftVal, boundaryVal,
+ /*exposeContiguity=*/true, oneBasedIndices);
})
.genElse([&]() {
- genDimensionShift(loc, builder, shiftVal,
+ genDimensionShift(loc, builder, shiftVal, boundaryVal,
/*exposeContiguity=*/false, oneBasedIndices);
});
@@ -1671,6 +2078,212 @@ private:
}
};
+class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::CmpCharOp cmp,
+ mlir::PatternRewriter &rewriter) const override {
+
+ fir::FirOpBuilder builder{rewriter, cmp.getOperation()};
+ const mlir::Location &loc = cmp->getLoc();
+
+ auto toVariable =
+ [&builder,
+ &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> {
+ mlir::Value opnd;
+ hlfir::AssociateOp associate;
+ if (mlir::isa<hlfir::ExprType>(val.getType())) {
+ hlfir::Entity entity{val};
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
+ associate = hlfir::genAssociateExpr(loc, builder, entity,
+ entity.getType(), "", byRefAttr);
+ opnd = associate.getBase();
+ } else {
+ opnd = val;
+ }
+ return {opnd, associate};
+ };
+
+ auto [lhsOpnd, lhsAssociate] = toVariable(cmp.getLchr());
+ auto [rhsOpnd, rhsAssociate] = toVariable(cmp.getRchr());
+
+ hlfir::Entity lhs{lhsOpnd};
+ hlfir::Entity rhs{rhsOpnd};
+
+ auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType());
+ unsigned kind = charTy.getFKind();
+
+ auto bits = builder.getKindMap().getCharacterBitsize(kind);
+ auto intTy = builder.getIntegerType(bits);
+
+ auto idxTy = builder.getIndexType();
+ auto charLen1Ty =
+ fir::CharacterType::getSingleton(builder.getContext(), kind);
+ mlir::Type designatorType =
+ fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
+ auto idxAttr = builder.getIntegerAttr(idxTy, 0);
+
+ auto genExtractAndConvertToInt =
+ [&idxAttr, &intTy, &designatorType](
+ mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity &charStr, mlir::Value index, mlir::Value length) {
+ auto singleChr = hlfir::DesignateOp::create(
+ builder, loc, designatorType, charStr, /*component=*/{},
+ /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
+ /*substring=*/mlir::ValueRange{index, index},
+ /*complexPart=*/std::nullopt,
+ /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{length},
+ fir::FortranVariableFlagsAttr{});
+ auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
+ mlir::Value intVal = fir::ExtractValueOp::create(
+ builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
+ return intVal;
+ };
+
+ mlir::arith::CmpIPredicate predicate = cmp.getPredicate();
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ mlir::Value lhsLen = builder.createConvert(
+ loc, idxTy, hlfir::genCharLength(loc, builder, lhs));
+ mlir::Value rhsLen = builder.createConvert(
+ loc, idxTy, hlfir::genCharLength(loc, builder, rhs));
+
+ enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight };
+
+ mlir::Value zeroInt = builder.createIntegerConstant(loc, intTy, 0);
+ mlir::Value oneInt = builder.createIntegerConstant(loc, intTy, 1);
+ mlir::Value negOneInt = builder.createIntegerConstant(loc, intTy, -1);
+ mlir::Value blankInt = builder.createIntegerConstant(loc, intTy, ' ');
+
+ auto step = GenCmp::LeftToRight;
+ auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index, mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 1> {
+ assert(index.size() == 1 && "expected single loop");
+ assert(reductionArgs.size() == 1 && "expected single reduction value");
+ mlir::Value inRes = reductionArgs[0];
+ auto accEQzero = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt);
+
+ mlir::Value res =
+ builder
+ .genIfOp(loc, {intTy}, accEQzero,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value offset =
+ builder.createConvert(loc, idxTy, index[0]);
+ mlir::Value lhsInt;
+ mlir::Value rhsInt;
+ if (step == GenCmp::LeftToRight) {
+ lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset,
+ oneIdx);
+ rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset,
+ oneIdx);
+ } else if (step == GenCmp::LeftToBlank) {
+ // lhsLen > rhsLen
+ offset =
+ mlir::arith::AddIOp::create(builder, loc, rhsLen, offset);
+
+ lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset,
+ oneIdx);
+ rhsInt = blankInt;
+ } else if (step == GenCmp::BlankToRight) {
+ // rhsLen > lhsLen
+ offset =
+ mlir::arith::AddIOp::create(builder, loc, lhsLen, offset);
+
+ lhsInt = blankInt;
+ rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset,
+ oneIdx);
+ } else {
+ llvm_unreachable(
+ "unknown compare step for CmpCharOp lowering");
+ }
+
+ mlir::Value newVal = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(builder, loc,
+ mlir::arith::CmpIPredicate::ult,
+ lhsInt, rhsInt),
+ negOneInt, inRes);
+ newVal = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(builder, loc,
+ mlir::arith::CmpIPredicate::ugt,
+ lhsInt, rhsInt),
+ oneInt, newVal);
+ fir::ResultOp::create(builder, loc, newVal);
+ })
+ .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
+ .getResults()[0];
+
+ return {res};
+ };
+
+ // First generate comparison of two strings for the length of the shorter
+ // one.
+ mlir::Value minLen = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen),
+ lhsLen, rhsLen);
+
+ llvm::SmallVector<mlir::Value, 1> loopOut =
+ hlfir::genLoopNestWithReductions(loc, builder, {minLen},
+ /*reductionInits=*/{zeroInt}, genCmp,
+ /*isUnordered=*/false);
+ mlir::Value partRes = loopOut[0];
+
+ auto lhsLonger = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen);
+ mlir::Value tempRes =
+ builder
+ .genIfOp(loc, {intTy}, lhsLonger,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ // If left is the longer string generate compare left to blank.
+ step = GenCmp::LeftToBlank;
+ auto lenDiff =
+ mlir::arith::SubIOp::create(builder, loc, lhsLen, rhsLen);
+
+ llvm::SmallVector<mlir::Value, 1> output =
+ hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
+ /*reductionInits=*/{partRes},
+ genCmp,
+ /*isUnordered=*/false);
+ mlir::Value res = output[0];
+ fir::ResultOp::create(builder, loc, res);
+ })
+ .genElse([&]() {
+ // If right is the longer string generate compare blank to
+ // right.
+ step = GenCmp::BlankToRight;
+ auto lenDiff =
+ mlir::arith::SubIOp::create(builder, loc, rhsLen, lhsLen);
+ llvm::SmallVector<mlir::Value, 1> output =
+ hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
+ /*reductionInits=*/{partRes},
+ genCmp,
+ /*isUnordered=*/false);
+
+ mlir::Value res = output[0];
+ fir::ResultOp::create(builder, loc, res);
+ })
+ .getResults()[0];
+ if (lhsAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, lhsAssociate);
+ if (rhsAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, rhsAssociate);
+
+ auto finalCmpResult =
+ mlir::arith::CmpIOp::create(builder, loc, predicate, tempRes, zeroInt);
+ rewriter.replaceOp(cmp, finalCmpResult);
+ return mlir::success();
+ }
+};
+
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
@@ -2339,9 +2952,10 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
- patterns.insert<CShiftConversion>(context);
+ patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
+ patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
+ patterns.insert<CmpCharOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
-
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index e5fd19d..c9aff59 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -271,8 +271,6 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var,
mlir::Value extent = val;
mlir::Value upperbound =
mlir::arith::SubIOp::create(builder, loc, extent, one);
- upperbound = mlir::arith::AddIOp::create(builder, loc, lowerbound,
- upperbound);
mlir::Value stride = one;
if (strideIncludeLowerExtent) {
stride = cummulativeExtent;
@@ -591,7 +589,8 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
hlfir::AssignOp::create(firBuilder, loc, initVal,
declareOp.getBase());
} else {
- for (auto ext : seqTy.getShape()) {
+ // Generate loop nest from slowest to fastest running dimension
+ for (auto ext : llvm::reverse(seqTy.getShape())) {
auto lb = firBuilder.createIntegerConstant(loc, idxTy, 0);
auto ub = firBuilder.createIntegerConstant(loc, idxTy, ext - 1);
auto step = firBuilder.createIntegerConstant(loc, idxTy, 1);
@@ -614,6 +613,11 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
if (fir::isa_trivial(innerTy)) {
retVal = getDeclareOpForType(unwrappedTy).getBase();
+ mlir::Value allocatedScalar =
+ fir::AllocMemOp::create(builder, loc, innerTy);
+ mlir::Value firClass =
+ fir::EmboxOp::create(builder, loc, boxTy, allocatedScalar);
+ fir::StoreOp::create(builder, loc, firClass, retVal);
} else if (mlir::isa<fir::SequenceType>(innerTy)) {
hlfir::Entity source = hlfir::Entity{var};
auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source);
diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp
new file mode 100644
index 0000000..8b99913
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp
@@ -0,0 +1,159 @@
+//===- AutomapToTargetData.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/DirectivesCommon.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+
+namespace flangomp {
+#define GEN_PASS_DEF_AUTOMAPTOTARGETDATAPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+using namespace mlir;
+
+namespace {
+/// Pass that inserts OpenMP data-mapping operations around the allocation and
+/// deallocation of AUTOMAP declare-target globals.
+///
+/// For every fir::GlobalOp that is declare-target, has the automap flag set
+/// and whose device type is not `host`, the pass inserts:
+///   * an omp::MapInfoOp (OMP_MAP_TO) and an omp::TargetEnterDataOp after each
+///     store of an embox'ed fir::AllocMemOp result into the declared global;
+///   * an omp::MapInfoOp (OMP_MAP_DELETE) and an omp::TargetExitDataOp after
+///     each load of the declared global whose box address reaches a
+///     fir::FreeMemOp.
+class AutomapToTargetDataPass
+    : public flangomp::impl::AutomapToTargetDataPassBase<
+          AutomapToTargetDataPass> {
+
+  /// Returns true if the variable has a dynamic size and therefore requires
+  /// bounds operations to describe its extents. `var` must have a
+  /// pointer-like type.
+  inline bool needsBoundsOps(mlir::Value var) {
+    assert(mlir::isa<mlir::omp::PointerLikeType>(var.getType()) &&
+           "only pointer like types expected");
+    mlir::Type t = fir::unwrapRefType(var.getType());
+    if (mlir::Type inner = fir::dyn_cast_ptrOrBoxEleTy(t))
+      return fir::hasDynamicSize(inner);
+    return fir::hasDynamicSize(t);
+  }
+
+  /// Generates the implicit MapBoundsOp operations for `var` and appends them
+  /// to `boundsOps`.
+  inline void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value var,
+                           llvm::SmallVectorImpl<mlir::Value> &boundsOps) {
+    mlir::Location loc = var.getLoc();
+    fir::factory::AddrAndBoundsInfo info =
+        fir::factory::getDataOperandBaseAddr(builder, var,
+                                             /*isOptional=*/false, loc);
+    fir::ExtendedValue exv =
+        hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{info.addr},
+                                        /*contiguousHint=*/true)
+            .first;
+    llvm::SmallVector<mlir::Value> tmp =
+        fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                           mlir::omp::MapBoundsType>(
+            builder, info, exv, /*dataExvIsAssumedSize=*/false, loc);
+    llvm::append_range(boundsOps, tmp);
+  }
+
+  /// Inspects the users of the hlfir::DeclareOp fed by `addressOfOp` and
+  /// collects:
+  ///   * into `allocmems`, every fir::StoreOp whose stored value is a
+  ///     fir::EmboxOp of a fir::AllocMemOp result;
+  ///   * into `freemems`, every fir::LoadOp that has a fir::BoxAddrOp user
+  ///     which is in turn used by a fir::FreeMemOp.
+  ///
+  /// Precondition: `addressOfOp` has exactly one use and that user is an
+  /// hlfir::DeclareOp.
+  void findRelatedAllocmemFreemem(fir::AddrOfOp addressOfOp,
+                                  llvm::DenseSet<fir::StoreOp> &allocmems,
+                                  llvm::DenseSet<fir::LoadOp> &freemems) {
+    assert(addressOfOp->hasOneUse() && "op must have single use");
+
+    auto declaredRef =
+        cast<hlfir::DeclareOp>(*addressOfOp->getUsers().begin())->getResult(0);
+
+    for (Operation *refUser : declaredRef.getUsers()) {
+      if (auto storeOp = dyn_cast<fir::StoreOp>(refUser))
+        if (auto emboxOp = storeOp.getValue().getDefiningOp<fir::EmboxOp>())
+          if (emboxOp.getOperand(0).getDefiningOp<fir::AllocMemOp>())
+            allocmems.insert(storeOp);
+
+      if (auto loadOp = dyn_cast<fir::LoadOp>(refUser))
+        for (Operation *loadUser : loadOp.getResult().getUsers())
+          if (auto boxAddrOp = dyn_cast<fir::BoxAddrOp>(loadUser))
+            for (Operation *boxAddrUser : boxAddrOp.getResult().getUsers())
+              if (isa<fir::FreeMemOp>(boxAddrUser))
+                freemems.insert(loadOp);
+    }
+  }
+
+  void runOnOperation() override {
+    // Resolve the module either as an ancestor of the pass root or as the
+    // pass root itself; there is nothing to do without one.
+    ModuleOp module = getOperation()->getParentOfType<ModuleOp>();
+    if (!module)
+      module = dyn_cast<ModuleOp>(getOperation());
+    if (!module)
+      return;
+
+    // Build FIR builder for helper utilities.
+    fir::KindMapping kindMap = fir::getKindMapping(module);
+    fir::FirOpBuilder builder{module, std::move(kindMap)};
+
+    // Collect global variables with AUTOMAP flag.
+    llvm::DenseSet<fir::GlobalOp> automapGlobals;
+    module.walk([&](fir::GlobalOp globalOp) {
+      if (auto iface =
+              dyn_cast<omp::DeclareTargetInterface>(globalOp.getOperation()))
+        if (iface.isDeclareTarget() && iface.getDeclareTargetAutomap() &&
+            iface.getDeclareTargetDeviceType() !=
+                omp::DeclareTargetDeviceType::host)
+          automapGlobals.insert(globalOp);
+    });
+
+    // Emits the map-info plus enter/exit-data pair right after `memOp`.
+    // `memOp` is either one of the collected stores (allocation side) or one
+    // of the collected loads (deallocation side).
+    auto addMapInfo = [&](auto globalOp, auto memOp) {
+      // Stores map the data to the device; loads delete the device copy.
+      const bool isAllocation = isa<fir::StoreOp>(memOp);
+
+      builder.setInsertionPointAfter(memOp);
+      SmallVector<Value> bounds;
+      if (needsBoundsOps(memOp.getMemref()))
+        genBoundsOps(builder, memOp.getMemref(), bounds);
+
+      omp::TargetEnterExitUpdateDataOperands clauses;
+      mlir::omp::MapInfoOp mapInfo = mlir::omp::MapInfoOp::create(
+          builder, memOp.getLoc(), memOp.getMemref().getType(),
+          memOp.getMemref(),
+          TypeAttr::get(fir::unwrapRefType(memOp.getMemref().getType())),
+          builder.getIntegerAttr(
+              builder.getIntegerType(64, false),
+              static_cast<unsigned>(
+                  isAllocation
+                      ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
+                      : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)),
+          builder.getAttr<omp::VariableCaptureKindAttr>(
+              omp::VariableCaptureKind::ByCopy),
+          /*var_ptr_ptr=*/mlir::Value{},
+          /*members=*/SmallVector<Value>{},
+          /*members_index=*/ArrayAttr{}, bounds,
+          /*mapperId=*/mlir::FlatSymbolRefAttr(), globalOp.getSymNameAttr(),
+          builder.getBoolAttr(false));
+      clauses.mapVars.push_back(mapInfo);
+
+      // Same creation style as the MapInfoOp above.
+      if (isAllocation)
+        omp::TargetEnterDataOp::create(builder, memOp.getLoc(), clauses);
+      else
+        omp::TargetExitDataOp::create(builder, memOp.getLoc(), clauses);
+    };
+
+    for (fir::GlobalOp globalOp : automapGlobals) {
+      if (auto uses = globalOp.getSymbolUses(module.getOperation())) {
+        llvm::DenseSet<fir::StoreOp> allocmemStores;
+        llvm::DenseSet<fir::LoadOp> freememLoads;
+        for (auto &x : *uses)
+          if (auto addrOp = dyn_cast<fir::AddrOfOp>(x.getUser()))
+            findRelatedAllocmemFreemem(addrOp, allocmemStores, freememLoads);
+
+        for (auto storeOp : allocmemStores)
+          addMapInfo(globalOp, storeOp);
+
+        for (auto loadOp : freememLoads)
+          addMapInfo(globalOp, loadOp);
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e315433..e0aebd0 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -1,6 +1,7 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FlangOpenMPTransforms
+ AutomapToTargetData.cpp
DoConcurrentConversion.cpp
FunctionFiltering.cpp
GenericLoopConversion.cpp
@@ -9,6 +10,7 @@ add_flang_library(FlangOpenMPTransforms
MarkDeclareTarget.cpp
LowerWorkshare.cpp
LowerNontemporal.cpp
+ SimdOnly.cpp
DEPENDS
FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index 2b3ac16..c928b76 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -173,9 +173,11 @@ public:
DoConcurrentConversion(
mlir::MLIRContext *context, bool mapToDevice,
- llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip)
+ llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
+ mlir::SymbolTable &moduleSymbolTable)
: OpConversionPattern(context), mapToDevice(mapToDevice),
- concurrentLoopsToSkip(concurrentLoopsToSkip) {}
+ concurrentLoopsToSkip(concurrentLoopsToSkip),
+ moduleSymbolTable(moduleSymbolTable) {}
mlir::LogicalResult
matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
@@ -332,8 +334,8 @@ private:
loop.getLocalVars(),
loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
loop.getRegionLocalArgs())) {
- auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom<
- fir::LocalitySpecifierOp>(loop, sym);
+ auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+ sym.getLeafReference());
if (localizer.getLocalitySpecifierType() ==
fir::LocalitySpecifierType::LocalInit)
TODO(localizer.getLoc(),
@@ -352,6 +354,8 @@ private:
cloneFIRRegionToOMP(localizer.getDeallocRegion(),
privatizer.getDeallocRegion());
+ moduleSymbolTable.insert(privatizer);
+
wsloopClauseOps.privateVars.push_back(op);
wsloopClauseOps.privateSyms.push_back(
mlir::SymbolRefAttr::get(privatizer));
@@ -362,28 +366,34 @@ private:
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
loop.getRegionReduceArgs())) {
- auto firReducer =
- mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
- loop, sym);
+ auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
+ sym.getLeafReference());
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(firReducer);
-
- auto ompReducer = mlir::omp::DeclareReductionOp::create(
- rewriter, firReducer.getLoc(),
- sym.getLeafReference().str() + ".omp",
- firReducer.getTypeAttr().getValue());
-
- cloneFIRRegionToOMP(firReducer.getAllocRegion(),
- ompReducer.getAllocRegion());
- cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
- ompReducer.getInitializerRegion());
- cloneFIRRegionToOMP(firReducer.getReductionRegion(),
- ompReducer.getReductionRegion());
- cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
- ompReducer.getAtomicReductionRegion());
- cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
- ompReducer.getCleanupRegion());
+ std::string ompReducerName = sym.getLeafReference().str() + ".omp";
+
+ auto ompReducer =
+ moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
+ rewriter.getStringAttr(ompReducerName));
+
+ if (!ompReducer) {
+ ompReducer = mlir::omp::DeclareReductionOp::create(
+ rewriter, firReducer.getLoc(), ompReducerName,
+ firReducer.getTypeAttr().getValue());
+
+ cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+ ompReducer.getAllocRegion());
+ cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+ ompReducer.getInitializerRegion());
+ cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+ ompReducer.getReductionRegion());
+ cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+ ompReducer.getAtomicReductionRegion());
+ cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+ ompReducer.getCleanupRegion());
+ moduleSymbolTable.insert(ompReducer);
+ }
wsloopClauseOps.reductionVars.push_back(op);
wsloopClauseOps.reductionByref.push_back(byRef);
@@ -431,6 +441,7 @@ private:
bool mapToDevice;
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
+ mlir::SymbolTable &moduleSymbolTable;
};
class DoConcurrentConversionPass
@@ -444,12 +455,9 @@ public:
: DoConcurrentConversionPassBase(options) {}
void runOnOperation() override {
- mlir::func::FuncOp func = getOperation();
-
- if (func.isDeclaration())
- return;
-
+ mlir::ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
+ mlir::SymbolTable moduleSymbolTable(module);
if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host &&
mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) {
@@ -463,7 +471,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.insert<DoConcurrentConversion>(
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
- concurrentLoopsToSkip);
+ concurrentLoopsToSkip, moduleSymbolTable);
mlir::ConversionTarget target(*context);
target.addDynamicallyLegalOp<fir::DoConcurrentOp>(
[&](fir::DoConcurrentOp op) {
@@ -472,8 +480,8 @@ public:
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
- if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
- std::move(patterns)))) {
+ if (mlir::failed(
+ mlir::applyFullConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index ae5c0ec..3031bb5 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -95,8 +95,9 @@ public:
return WalkResult::skip();
}
if (declareTargetOp)
- declareTargetOp.setDeclareTarget(declareType,
- omp::DeclareTargetCaptureClause::to);
+ declareTargetOp.setDeclareTarget(
+ declareType, omp::DeclareTargetCaptureClause::to,
+ declareTargetOp.getDeclareTargetAutomap());
}
return WalkResult::advance();
});
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 970f7d7..3032857 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -53,6 +53,7 @@ class MapsForPrivatizedSymbolsPass
: public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
MapsForPrivatizedSymbolsPass> {
+ // TODO Use `createMapInfoOp` from `flang/Utils/OpenMP.h`.
omp::MapInfoOp createMapInfo(Location loc, Value var,
fir::FirOpBuilder &builder) {
// Check if a value of type `type` can be passed to the kernel by value.
diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index a7ffd5f..0b0e6bd 100644
--- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -33,7 +33,7 @@ class MarkDeclareTargetPass
void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
mlir::omp::DeclareTargetCaptureClause parentCapClause,
- mlir::Operation *currOp,
+ bool parentAutomap, mlir::Operation *currOp,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (visited.contains(currOp))
return;
@@ -57,13 +57,16 @@ class MarkDeclareTargetPass
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(
mlir::omp::DeclareTargetDeviceType::any,
- current.getDeclareTargetCaptureClause());
+ current.getDeclareTargetCaptureClause(),
+ current.getDeclareTargetAutomap());
}
} else {
- current.setDeclareTarget(parentDevTy, parentCapClause);
+ current.setDeclareTarget(parentDevTy, parentCapClause,
+ parentAutomap);
}
- markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited);
+ markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
+ currFOp, visited);
}
}
}
@@ -81,7 +84,8 @@ class MarkDeclareTargetPass
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
- functionOp, visited);
+ declareTargetOp.getDeclareTargetAutomap(), functionOp,
+ visited);
}
}
@@ -92,9 +96,10 @@ class MarkDeclareTargetPass
// the contents of the device clause
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
- markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost,
- mlir::omp::DeclareTargetCaptureClause::to, tarOp,
- visited);
+ markNestedFuncs(
+ /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
+ /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
+ /*parentAutomap=*/false, tarOp, visited);
});
}
};
diff --git a/flang/lib/Optimizer/OpenMP/SimdOnly.cpp b/flang/lib/Optimizer/OpenMP/SimdOnly.cpp
new file mode 100644
index 0000000..4a559d2
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/SimdOnly.cpp
@@ -0,0 +1,209 @@
+//===-- SimdOnly.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace flangomp {
+#define GEN_PASS_DEF_SIMDONLYPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+namespace {
+
+#define DEBUG_TYPE "omp-simd-only-pass"
+
+/// Rewrite and remove OpenMP operations left after the parse tree rewriting for
+/// -fopenmp-simd is done. If possible, OpenMP constructs should be rewritten at
+/// the parse tree stage. This pass is supposed to only handle complexities
+/// around untangling composite simd constructs, and perform the necessary
+/// cleanup.
+///
+/// The pattern matches any operation and rejects everything that is not in
+/// the OpenMP dialect. Of the OpenMP operations it:
+///   * keeps omp::SimdOp (clearing its composite flag) and the yield, scan,
+///     loop-nest and terminator ops nested in a simd construct;
+///   * erases privatizers and declare-reductions not referenced by a simd op;
+///   * erases map-bounds and forwards map-info / threadprivate results to
+///     their underlying address operand;
+///   * inlines the region of teams, parallel, target, wsloop and distribute
+///     ops into the parent;
+///   * reports an error on any other OpenMP operation.
+class SimdOnlyConversionPattern : public mlir::RewritePattern {
+public:
+  SimdOnlyConversionPattern(mlir::MLIRContext *ctx)
+      : mlir::RewritePattern(MatchAnyOpTypeTag{}, 1, ctx) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Query the namespace through the operation name: Operation::getDialect()
+    // returns null for operations whose dialect is not loaded.
+    if (op->getName().getDialectNamespace() !=
+        mlir::omp::OpenMPDialect::getDialectNamespace())
+      return rewriter.notifyMatchFailure(op, "Not an OpenMP op");
+
+    if (auto simdOp = mlir::dyn_cast<mlir::omp::SimdOp>(op)) {
+      // Remove the composite attr given that the op will no longer be composite
+      if (simdOp.isComposite()) {
+        // In-place changes made by a pattern must go through the rewriter so
+        // that the driver is notified of the modification.
+        rewriter.modifyOpInPlace(simdOp, [&] { simdOp.setComposite(false); });
+        return mlir::success();
+      }
+
+      return rewriter.notifyMatchFailure(op, "Op is a plain SimdOp");
+    }
+
+    if (op->getParentOfType<mlir::omp::SimdOp>() &&
+        (mlir::isa<mlir::omp::YieldOp>(op) ||
+         mlir::isa<mlir::omp::ScanOp>(op) ||
+         mlir::isa<mlir::omp::LoopNestOp>(op) ||
+         mlir::isa<mlir::omp::TerminatorOp>(op)))
+      return rewriter.notifyMatchFailure(op, "Op is part of a simd construct");
+
+    if (!mlir::isa<mlir::func::FuncOp>(op->getParentOp()) &&
+        (mlir::isa<mlir::omp::TerminatorOp>(op) ||
+         mlir::isa<mlir::omp::YieldOp>(op)))
+      return rewriter.notifyMatchFailure(op,
+                                         "Non top-level yield or terminator");
+
+    LLVM_DEBUG(llvm::dbgs() << "SimdOnlyPass matched OpenMP op:\n");
+    LLVM_DEBUG(op->dump());
+
+    // Erases `ompOp` unless the symbol `name` is referenced by an omp::SimdOp
+    // within the parent operation of the matched op.
+    auto eraseUnlessUsedBySimd = [&](mlir::Operation *ompOp,
+                                     mlir::StringAttr name) {
+      if (auto uses =
+              mlir::SymbolTable::getSymbolUses(name, op->getParentOp())) {
+        for (auto &use : *uses)
+          if (mlir::isa<mlir::omp::SimdOp>(use.getUser()))
+            return rewriter.notifyMatchFailure(op,
+                                               "Op used by a simd construct");
+      }
+      rewriter.eraseOp(ompOp);
+      return mlir::success();
+    };
+
+    if (auto ompOp = mlir::dyn_cast<mlir::omp::PrivateClauseOp>(op))
+      return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr());
+    if (auto ompOp = mlir::dyn_cast<mlir::omp::DeclareReductionOp>(op))
+      return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr());
+
+    // Might be left over from rewriting composite simd with target map
+    if (mlir::isa<mlir::omp::MapBoundsOp>(op)) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+    if (auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(op)) {
+      rewriter.replaceOp(mapInfoOp, {mapInfoOp.getVarPtr()});
+      return mlir::success();
+    }
+
+    // Might be leftover after parse tree rewriting
+    if (auto threadPrivateOp = mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op)) {
+      rewriter.replaceOp(threadPrivateOp, {threadPrivateOp.getSymAddr()});
+      return mlir::success();
+    }
+
+    fir::FirOpBuilder builder(rewriter, op);
+    mlir::Location loc = op->getLoc();
+
+    // Inlines the single region of `ompOp` into its parent and erases the
+    // operation. Returns false (doing nothing) when `ompOp` is null, so it can
+    // be chained over several dyn_cast results.
+    auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
+      if (!ompOp)
+        return false;
+
+      assert("OpenMP operation has one region" && ompOp->getNumRegions() == 1);
+
+      // Entry block arguments are replaced by the values they stand for.
+      llvm::SmallVector<std::pair<mlir::Value, mlir::BlockArgument>>
+          blockArgsPairs;
+      if (auto iface =
+              mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
+        iface.getBlockArgsPairs(blockArgsPairs);
+        for (auto [value, argument] : blockArgsPairs)
+          rewriter.replaceAllUsesWith(argument, value);
+      }
+
+      if (ompOp->getRegion(0).getBlocks().size() == 1) {
+        auto &block = *ompOp->getRegion(0).getBlocks().begin();
+        // This block is about to be removed so any arguments should have been
+        // replaced by now.
+        block.eraseArguments(0, block.getNumArguments());
+        if (auto terminatorOp =
+                mlir::dyn_cast<mlir::omp::TerminatorOp>(block.back())) {
+          rewriter.eraseOp(terminatorOp);
+        }
+        rewriter.inlineBlockBefore(&block, ompOp, {});
+      } else {
+        // When dealing with multi-block regions we need to fix up the control
+        // flow
+        auto *origBlock = ompOp->getBlock();
+        auto *newBlock = rewriter.splitBlock(origBlock, ompOp->getIterator());
+        auto *innerFrontBlock = &ompOp->getRegion(0).getBlocks().front();
+        builder.setInsertionPointToEnd(origBlock);
+        mlir::cf::BranchOp::create(builder, loc, innerFrontBlock);
+        // We are no longer passing any arguments to the first block in the
+        // region, so this should be safe to erase.
+        innerFrontBlock->eraseArguments(0, innerFrontBlock->getNumArguments());
+
+        for (auto &innerBlock : ompOp->getRegion(0).getBlocks()) {
+          // Remove now-unused block arguments. The predicate overload is used
+          // because erasing while iterating over getArguments() shifts the
+          // remaining arguments and invalidates the iteration.
+          // NOTE(review): branch operands of predecessor blocks are not
+          // adjusted here - confirm unused arguments can only occur on blocks
+          // that receive no operands.
+          innerBlock.eraseArguments(
+              [](mlir::BlockArgument arg) { return arg.use_empty(); });
+          if (auto terminatorOp =
+                  mlir::dyn_cast<mlir::omp::TerminatorOp>(innerBlock.back())) {
+            builder.setInsertionPointToEnd(&innerBlock);
+            mlir::cf::BranchOp::create(builder, loc, newBlock);
+            rewriter.eraseOp(terminatorOp);
+          }
+        }
+
+        rewriter.inlineRegionBefore(ompOp->getRegion(0), newBlock);
+      }
+
+      rewriter.eraseOp(op);
+      return true;
+    };
+
+    // Remove ops that will be surrounding simd once a composite simd construct
+    // goes through the codegen stage. All of the other ones should have already
+    // been removed in the parse tree rewriting stage.
+    if (inlineSimpleOp(mlir::dyn_cast<mlir::omp::TeamsOp>(op)) ||
+        inlineSimpleOp(mlir::dyn_cast<mlir::omp::ParallelOp>(op)) ||
+        inlineSimpleOp(mlir::dyn_cast<mlir::omp::TargetOp>(op)) ||
+        inlineSimpleOp(mlir::dyn_cast<mlir::omp::WsloopOp>(op)) ||
+        inlineSimpleOp(mlir::dyn_cast<mlir::omp::DistributeOp>(op)))
+      return mlir::success();
+
+    op->emitOpError("left unhandled after SimdOnly pass.");
+    return mlir::failure();
+  }
+};
+
+/// Module pass that applies SimdOnlyConversionPattern greedily and signals a
+/// pass failure when the greedy driver reports failure.
+class SimdOnlyPass : public flangomp::impl::SimdOnlyPassBase<SimdOnlyPass> {
+
+public:
+  SimdOnlyPass() = default;
+
+  void runOnOperation() override {
+    mlir::MLIRContext *ctx = &getContext();
+    mlir::ModuleOp moduleOp = getOperation();
+
+    // Region simplification stays disabled so the pattern driver never merges
+    // blocks.
+    mlir::GreedyRewriteConfig rewriteConfig;
+    rewriteConfig.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
+
+    mlir::RewritePatternSet simdOnlyPatterns(ctx);
+    simdOnlyPatterns.insert<SimdOnlyConversionPattern>(ctx);
+
+    const bool applied = mlir::succeeded(mlir::applyPatternsGreedily(
+        moduleOp, std::move(simdOnlyPatterns), rewriteConfig));
+    if (applied)
+      return;
+
+    mlir::emitError(moduleOp.getLoc(), "Error in SimdOnly conversion pass");
+    signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index ca8e8206..7c2777b 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -14,7 +14,7 @@
/// Force setting the no-alias attribute on function arguments when possible.
static llvm::cl::opt<bool> forceNoAlias("force-no-alias", llvm::cl::Hidden,
- llvm::cl::init(false));
+ llvm::cl::init(true));
namespace fir {
@@ -217,9 +217,6 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(fir::createSimplifyFIROperations(
{/*preferInlineImplementation=*/pc.OptLevel.isOptimizingForSpeed()}));
- if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags)
- pm.addPass(fir::createAddAliasTags());
-
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, fir::createStackReclaim);
// convert control flow to CFG form
@@ -242,7 +239,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
/// \param pm - MLIR pass manager that will hold the pipeline definition
/// \param optLevel - optimization level used for creating FIR optimization
/// passes pipeline
-void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
+void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
+ EnableOpenMP enableOpenMP,
llvm::OptimizationLevel optLevel) {
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
@@ -294,8 +292,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, hlfir::createInlineHLFIRAssign);
pm.addPass(hlfir::createConvertHLFIRtoFIR());
- if (enableOpenMP)
+ if (enableOpenMP != EnableOpenMP::None)
pm.addPass(flangomp::createLowerWorkshare());
+ if (enableOpenMP == EnableOpenMP::Simd)
+ pm.addPass(flangomp::createSimdOnlyPass());
}
/// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -316,13 +316,13 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
pm.addPass(flangomp::createDoConcurrentConversionPass(
opts.doConcurrentMappingKind == DoConcurrentMappingKind::DCMK_Device));
- // The MapsForPrivatizedSymbols pass needs to run before
- // MapInfoFinalizationPass because the former creates new
- // MapInfoOp instances, typically for descriptors.
- // MapInfoFinalizationPass adds MapInfoOp instances for the descriptors
- // underlying data which is necessary to access the data on the offload
- // target device.
+ // The MapsForPrivatizedSymbols and AutomapToTargetData passes need to run
+ // before MapInfoFinalizationPass because they create new MapInfoOp
+ // instances, typically for descriptors. MapInfoFinalizationPass adds
+ // MapInfoOp instances for the descriptors' underlying data, which is
+ // necessary to access the data on the offload target device.
pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
+ pm.addPass(flangomp::createAutomapToTargetDataPass());
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
pm.addPass(flangomp::createGenericLoopConversionPass());
@@ -342,6 +342,9 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
MLIRToLLVMPassPipelineConfig config,
llvm::StringRef inputFilename) {
fir::addBoxedProcedurePass(pm);
+ if (config.OptLevel.isOptimizingForSpeed() && config.AliasAnalysis &&
+ !disableFirAliasTags && !useOldAliasTags)
+ pm.addPass(fir::createAddAliasTags());
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, fir::createAbstractResultOpt);
addPassToGPUModuleOperations<PassConstructor>(pm,
@@ -396,7 +399,12 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
MLIRToLLVMPassPipelineConfig &config,
llvm::StringRef inputFilename) {
- fir::createHLFIRToFIRPassPipeline(pm, config.EnableOpenMP, config.OptLevel);
+ fir::EnableOpenMP enableOpenMP = fir::EnableOpenMP::None;
+ if (config.EnableOpenMP)
+ enableOpenMP = fir::EnableOpenMP::Full;
+ if (config.EnableOpenMPSimd)
+ enableOpenMP = fir::EnableOpenMP::Simd;
+ fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP, config.OptLevel);
// Add default optimizer pass pipeline.
fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp
index 5d663e2..c71642c 100644
--- a/flang/lib/Optimizer/Support/Utils.cpp
+++ b/flang/lib/Optimizer/Support/Utils.cpp
@@ -50,3 +50,74 @@ std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault(
return componentInfo.getLowerBounds();
return std::nullopt;
}
+
+/// Creates an mlir::LLVM::ConstantOp of type `ity` at `loc` whose value is
+/// the 64-bit integer attribute built from `offset`.
+mlir::LLVM::ConstantOp
+fir::genConstantIndex(mlir::Location loc, mlir::Type ity,
+                      mlir::ConversionPatternRewriter &rewriter,
+                      std::int64_t offset) {
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, ity, rewriter.getI64IntegerAttr(offset));
+}
+
+/// Returns a constant of type `idxTy` holding the size of `llvmObjectType`
+/// rounded up to a multiple of its ABI alignment, both taken from
+/// `dataLayout`.
+mlir::Value
+fir::computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
+                            mlir::Type idxTy,
+                            mlir::ConversionPatternRewriter &rewriter,
+                            const mlir::DataLayout &dataLayout) {
+  llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
+  // Keep the alignment at full width: a narrower local would silently
+  // truncate large alignments before the rounding below.
+  std::uint64_t alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
+  std::int64_t distance = llvm::alignTo(size, alignment);
+  return fir::genConstantIndex(loc, idxTy, rewriter, distance);
+}
+
+/// Computes the constant scale factor of an allocation of `dataTy`.
+///
+/// For a fir::SequenceType, the first getConstantRows() extents are skipped
+/// and the remaining extents that are not the unknown extent are multiplied
+/// together. Returns a constant of type `ity` holding that product, or a null
+/// value when the product is 1 (including when `dataTy` is not a sequence
+/// type).
+mlir::Value
+fir::genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy,
+                            mlir::Type ity,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  fir::SequenceType::Extent scale = 1;
+  if (auto arrayTy = mlir::dyn_cast<fir::SequenceType>(dataTy)) {
+    const fir::SequenceType::ShapeRef &dims = arrayTy.getShape();
+    const int numDims = static_cast<int>(dims.size());
+    // Start past the leading constant rows; the loop is empty when they cover
+    // the whole shape.
+    for (int dim = arrayTy.getConstantRows(); dim < numDims; ++dim)
+      if (dims[dim] != fir::SequenceType::getUnknownExtent())
+        scale *= dims[dim];
+  }
+
+  if (scale == 1)
+    return nullptr;
+  return fir::genConstantIndex(loc, ity, rewriter, scale).getResult();
+}
+
+/// Converts `val` to the integer type `ty`: truncates when `ty` is narrower
+/// than the (lowered) type of `val`, sign-extends when it is wider, and
+/// returns `val` unchanged when both have the same bit width. When `fold` is
+/// set the cast is built with createOrFold instead of create.
+mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter,
+                             mlir::Location loc,
+                             mlir::ConversionPatternRewriter &rewriter,
+                             mlir::Type ty, mlir::Value val, bool fold) {
+  // A source type that is not yet an integer type is lowered first so that
+  // getPrimitiveTypeSizeInBits can be applied to it.
+  mlir::Type srcTy = val.getType();
+  if (!mlir::isa<mlir::IntegerType>(srcTy))
+    srcTy = converter.convertType(srcTy);
+  const auto dstBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+  const auto srcBits = mlir::LLVM::getPrimitiveTypeSizeInBits(srcTy);
+
+  if (dstBits < srcBits) {
+    if (fold)
+      return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
+    return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+  }
+  if (dstBits > srcBits) {
+    if (fold)
+      return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
+    return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+  }
+  return val;
+}
diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
index f1c66a5..430ef62 100644
--- a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
@@ -117,10 +117,7 @@ public:
op.getValue());
return success();
}
- rewriter.startOpModification(op->getParentOp());
- op.getResult().replaceAllUsesWith(op.getValue());
- rewriter.finalizeOpModification(op->getParentOp());
- rewriter.eraseOp(op);
+ rewriter.replaceOp(op, op.getValue());
}
return success();
}
diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index b032767..061a7d2 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -25,7 +25,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Visitors.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -451,10 +451,10 @@ static void rewriteStore(fir::StoreOp storeOp,
}
static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
- for (auto &bodyOp : block->getOperations()) {
+ for (auto &bodyOp : llvm::make_early_inc_range(block->getOperations())) {
if (isa<fir::LoadOp>(bodyOp))
rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
- if (isa<fir::StoreOp>(bodyOp))
+ else if (isa<fir::StoreOp>(bodyOp))
rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
}
}
@@ -476,6 +476,8 @@ public:
loop.dump(););
LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
functionAnalysis.getChildLoopAnalysis(loop);
+ if (!loopAnalysis.canPromoteToAffine())
+ return rewriter.notifyMatchFailure(loop, "cannot promote to affine");
auto &loopOps = loop.getBody()->getOperations();
auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator());
auto results = resultOp.getOperands();
@@ -576,12 +578,14 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
public:
using OpRewritePattern::OpRewritePattern;
AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
- : OpRewritePattern(context) {}
+ : OpRewritePattern(context), functionAnalysis(afa) {}
llvm::LogicalResult
matchAndRewrite(fir::IfOp op,
mlir::PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
op.dump(););
+ if (!functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine())
+ return rewriter.notifyMatchFailure(op, "cannot promote to affine");
auto &ifOps = op.getThenRegion().front().getOperations();
auto affineCondition = AffineIfCondition(op.getCondition());
if (!affineCondition.hasIntegerSet()) {
@@ -611,6 +615,8 @@ public:
rewriter.replaceOp(op, affineIf.getOperation()->getResults());
return success();
}
+
+ AffineFunctionAnalysis &functionAnalysis;
};
/// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
@@ -627,28 +633,11 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.insert<AffineIfConversion>(context, functionAnalysis);
patterns.insert<AffineLoopConversion>(context, functionAnalysis);
- mlir::ConversionTarget target = *context;
- target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect,
- mlir::scf::SCFDialect, mlir::arith::ArithDialect,
- mlir::func::FuncDialect>();
- target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
- return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
- });
- target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
- fir::DoLoopOp op) {
- return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
- });
-
LLVM_DEBUG(llvm::dbgs()
<< "AffineDialectPromotion: running promotion on: \n";
function.print(llvm::dbgs()););
// apply the patterns
- if (mlir::failed(mlir::applyPartialConversion(function, target,
- std::move(patterns)))) {
- mlir::emitError(mlir::UnknownLoc::get(context),
- "error in converting to affine dialect\n");
- signalPassFailure();
- }
+ walkAndApplyPatterns(function, std::move(patterns));
}
};
} // namespace
diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
index 247ba95..ed9a2ae 100644
--- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
+++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
@@ -1264,7 +1264,6 @@ public:
auto lhsEltRefType = toRefType(update.getMerge().getType());
auto [_, lhsLoadResult] = materializeAssignment(
loc, rewriter, update, assignElement, lhsEltRefType);
- update.replaceAllUsesWith(lhsLoadResult);
rewriter.replaceOp(update, lhsLoadResult);
return mlir::success();
}
@@ -1287,7 +1286,6 @@ public:
auto lhsEltRefType = modify.getResult(0).getType();
auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
loc, rewriter, modify, assignElement, lhsEltRefType);
- modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
return mlir::success();
}
@@ -1339,7 +1337,6 @@ public:
// This array_access is associated with an array_amend and there is a
// conflict. Make a copy to store into.
auto result = referenceToClone(loc, rewriter, access);
- access.replaceAllUsesWith(result);
rewriter.replaceOp(access, result);
return mlir::success();
}
diff --git a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp b/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp
index 5e910f7..6e04c71 100644
--- a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp
@@ -38,6 +38,15 @@ using namespace Fortran::runtime::cuda;
namespace {
+static bool isAssumedSize(mlir::ValueRange shape) {
+ if (shape.size() != 1)
+ return false;
+ std::optional<std::int64_t> val = fir::getIntIfConstant(shape[0]);
+ if (val && *val == -1)
+ return true;
+ return false;
+}
+
struct CUFComputeSharedMemoryOffsetsAndSize
: public fir::impl::CUFComputeSharedMemoryOffsetsAndSizeBase<
CUFComputeSharedMemoryOffsetsAndSize> {
@@ -82,12 +91,12 @@ struct CUFComputeSharedMemoryOffsetsAndSize
alignment = std::max(alignment, align);
uint64_t tySize = dl->getTypeSize(ty);
++nbDynamicSharedVariables;
- if (crtDynOffset) {
- sharedOp.getOffsetMutable().assign(
- builder.createConvert(loc, i32Ty, crtDynOffset));
- } else {
+ if (isAssumedSize(sharedOp.getShape()) || !crtDynOffset) {
mlir::Value zero = builder.createIntegerConstant(loc, i32Ty, 0);
sharedOp.getOffsetMutable().assign(zero);
+ } else {
+ sharedOp.getOffsetMutable().assign(
+ builder.createConvert(loc, i32Ty, crtDynOffset));
}
mlir::Value dynSize =
diff --git a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
index 5dcb54e..d038c46 100644
--- a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
+++ b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
@@ -178,8 +178,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertBoxedSequenceType(
context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr,
/*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy,
mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0,
- elements, dataLocation, rank, /*allocated=*/nullptr,
- /*associated=*/nullptr);
+ dataLocation, rank, /*allocated=*/nullptr,
+ /*associated=*/nullptr, elements);
}
addOp(llvm::dwarf::DW_OP_push_object_address, {});
@@ -255,8 +255,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertBoxedSequenceType(
return mlir::LLVM::DICompositeTypeAttr::get(
context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr,
/*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy,
- mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, elements,
- dataLocation, /*rank=*/nullptr, allocated, associated);
+ mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0,
+ dataLocation, /*rank=*/nullptr, allocated, associated, elements);
}
std::pair<std::uint64_t, unsigned short>
@@ -389,8 +389,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
context, recId, /*isRecSelf=*/true, llvm::dwarf::DW_TAG_structure_type,
mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope,
/*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0,
- /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
- /*allocated=*/nullptr, /*associated=*/nullptr);
+ /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr,
+ /*allocated=*/nullptr, /*associated=*/nullptr, elements);
DerivedTypeCache::ActiveLevels nestedRecursions =
derivedTypeCache.startTranslating(Ty, placeHolder);
@@ -429,8 +429,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
/*file=*/nullptr, /*line=*/0, /*scope=*/nullptr,
convertType(seqTy.getEleTy(), fileAttr, scope, declOp),
mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0,
- arrayElements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
- /*allocated=*/nullptr, /*associated=*/nullptr);
+ /*dataLocation=*/nullptr, /*rank=*/nullptr,
+ /*allocated=*/nullptr, /*associated=*/nullptr, arrayElements);
} else
elemTy = convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr);
offset = llvm::alignTo(offset, byteAlign);
@@ -448,8 +448,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
context, recId, /*isRecSelf=*/false, llvm::dwarf::DW_TAG_structure_type,
mlir::StringAttr::get(context, sourceName.name), fileAttr, line, scope,
/*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8,
- /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
- /*allocated=*/nullptr, /*associated=*/nullptr);
+ /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr,
+ /*allocated=*/nullptr, /*associated=*/nullptr, elements);
derivedTypeCache.finalize(Ty, finalAttr, std::move(nestedRecursions));
@@ -490,8 +490,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
context, llvm::dwarf::DW_TAG_structure_type,
mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope,
/*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8,
- /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
- /*allocated=*/nullptr, /*associated=*/nullptr);
+ /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr,
+ /*allocated=*/nullptr, /*associated=*/nullptr, elements);
derivedTypeCache.finalize(Ty, typeAttr, std::move(nestedRecursions));
return typeAttr;
}
@@ -554,9 +554,9 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertSequenceType(
return mlir::LLVM::DICompositeTypeAttr::get(
context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr,
/*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy,
- mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, elements,
+ mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0,
/*dataLocation=*/nullptr, /*rank=*/nullptr, /*allocated=*/nullptr,
- /*associated=*/nullptr);
+ /*associated=*/nullptr, elements);
}
mlir::LLVM::DITypeAttr DebugTypeGenerator::convertVectorType(
@@ -587,9 +587,9 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertVectorType(
context, llvm::dwarf::DW_TAG_array_type,
mlir::StringAttr::get(context, name),
/*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy,
- mlir::LLVM::DIFlags::Vector, sizeInBits, /*alignInBits=*/0, elements,
+ mlir::LLVM::DIFlags::Vector, sizeInBits, /*alignInBits=*/0,
/*dataLocation=*/nullptr, /*rank=*/nullptr, /*allocated=*/nullptr,
- /*associated=*/nullptr);
+ /*associated=*/nullptr, elements);
}
mlir::LLVM::DITypeAttr DebugTypeGenerator::convertCharacterType(
diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
index 2fcff87..031a5ae 100644
--- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
@@ -76,12 +76,49 @@ void ExternalNameConversionPass::runOnOperation() {
auto *context = &getContext();
llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
+ mlir::SymbolTable symbolTable(op);
auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) {
auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
mlir::SymbolTable::getSymbolAttrName());
auto deconstructedName = fir::NameUniquer::deconstruct(symName);
if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
+ // Check if this is a private function that would conflict with a common
+ // block and get its mangled name.
+ if (auto funcOp = llvm::dyn_cast<mlir::func::FuncOp>(funcOrGlobal)) {
+ if (funcOp.isPrivate()) {
+ std::string mangledName =
+ mangleExternalName(deconstructedName, appendUnderscoreOpt);
+ auto mod = funcOp->getParentOfType<mlir::ModuleOp>();
+ bool hasConflictingCommonBlock = false;
+
+ // Check if any existing global has the same mangled name.
+ if (symbolTable.lookup<fir::GlobalOp>(mangledName))
+ hasConflictingCommonBlock = true;
+
+ // Skip externalization if the function has a conflicting common block
+ // and is not directly called (i.e. procedure pointers or type
+ // specifications)
+ if (hasConflictingCommonBlock) {
+ bool isDirectlyCalled = false;
+ std::optional<SymbolTable::UseRange> uses =
+ funcOp.getSymbolUses(mod);
+ if (uses.has_value()) {
+ for (auto use : *uses) {
+ mlir::Operation *user = use.getUser();
+ if (mlir::isa<fir::CallOp>(user) ||
+ mlir::isa<mlir::func::CallOp>(user)) {
+ isDirectlyCalled = true;
+ break;
+ }
+ }
+ }
+ if (!isDirectlyCalled)
+ return;
+ }
+ }
+ }
+
auto newName = mangleExternalName(deconstructedName, appendUnderscoreOpt);
auto newAttr = mlir::StringAttr::get(context, newName);
mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 1902757..70d6ebb 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -9,36 +9,34 @@
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace fir {
#define GEN_PASS_DEF_FIRTOSCFPASS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
-using namespace fir;
-using namespace mlir;
-
namespace {
class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
public:
void runOnOperation() override;
};
-struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
+struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
- PatternRewriter &rewriter) const override {
- auto loc = doLoopOp.getLoc();
+ mlir::LogicalResult
+ matchAndRewrite(fir::DoLoopOp doLoopOp,
+ mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = doLoopOp.getLoc();
bool hasFinalValue = doLoopOp.getFinalValue().has_value();
// Get loop values from the DoLoopOp
- auto low = doLoopOp.getLowerBound();
- auto high = doLoopOp.getUpperBound();
+ mlir::Value low = doLoopOp.getLowerBound();
+ mlir::Value high = doLoopOp.getUpperBound();
assert(low && high && "must be a Value");
- auto step = doLoopOp.getStep();
- llvm::SmallVector<Value> iterArgs;
+ mlir::Value step = doLoopOp.getStep();
+ mlir::SmallVector<mlir::Value> iterArgs;
if (hasFinalValue)
iterArgs.push_back(low);
iterArgs.append(doLoopOp.getIterOperands().begin(),
@@ -49,31 +47,33 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
// must be a positive value.
// For easier conversion, we calculate the trip count and use a canonical
// induction variable.
- auto diff = arith::SubIOp::create(rewriter, loc, high, low);
- auto distance = arith::AddIOp::create(rewriter, loc, diff, step);
- auto tripCount = arith::DivSIOp::create(rewriter, loc, distance, step);
- auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
- auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto diff = mlir::arith::SubIOp::create(rewriter, loc, high, low);
+ auto distance = mlir::arith::AddIOp::create(rewriter, loc, diff, step);
+ auto tripCount =
+ mlir::arith::DivSIOp::create(rewriter, loc, distance, step);
+ auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1);
auto scfForOp =
- scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs);
+ mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs);
auto &loopOps = doLoopOp.getBody()->getOperations();
- auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
+ auto resultOp =
+ mlir::cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
auto results = resultOp.getOperands();
- Block *loweredBody = scfForOp.getBody();
+ mlir::Block *loweredBody = scfForOp.getBody();
loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
loopOps.begin(),
std::prev(loopOps.end()));
rewriter.setInsertionPointToStart(loweredBody);
- Value iv =
- arith::MulIOp::create(rewriter, loc, scfForOp.getInductionVar(), step);
- iv = arith::AddIOp::create(rewriter, loc, low, iv);
+ mlir::Value iv = mlir::arith::MulIOp::create(
+ rewriter, loc, scfForOp.getInductionVar(), step);
+ iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv);
if (!results.empty()) {
rewriter.setInsertionPointToEnd(loweredBody);
- scf::YieldOp::create(rewriter, resultOp->getLoc(), results);
+ mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results);
}
doLoopOp.getInductionVar().replaceAllUsesWith(iv);
rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
@@ -84,34 +84,103 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
// Copy all the attributes from the old to new op.
scfForOp->setAttrs(doLoopOp->getAttrs());
rewriter.replaceOp(doLoopOp, scfForOp);
- return success();
+ return mlir::success();
+ }
+};
+
+struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
+ using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::IterWhileOp iterWhileOp,
+ mlir::PatternRewriter &rewriter) const override {
+
+ mlir::Location loc = iterWhileOp.getLoc();
+ mlir::Value lowerBound = iterWhileOp.getLowerBound();
+ mlir::Value upperBound = iterWhileOp.getUpperBound();
+ mlir::Value step = iterWhileOp.getStep();
+
+ mlir::Value okInit = iterWhileOp.getIterateIn();
+ mlir::ValueRange iterArgs = iterWhileOp.getInitArgs();
+
+ mlir::SmallVector<mlir::Value> initVals;
+ initVals.push_back(lowerBound);
+ initVals.push_back(okInit);
+ initVals.append(iterArgs.begin(), iterArgs.end());
+
+ mlir::SmallVector<mlir::Type> loopTypes;
+ loopTypes.push_back(lowerBound.getType());
+ loopTypes.push_back(okInit.getType());
+ for (auto val : iterArgs)
+ loopTypes.push_back(val.getType());
+
+ auto scfWhileOp =
+ mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
+
+ auto &beforeBlock = *rewriter.createBlock(
+ &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes,
+ mlir::SmallVector<mlir::Location>(loopTypes.size(), loc));
+
+ mlir::Region::BlockArgListType argsInBefore =
+ scfWhileOp.getBefore().getArguments();
+ auto ivInBefore = argsInBefore[0];
+ auto earlyExitInBefore = argsInBefore[1];
+
+ rewriter.setInsertionPointToStart(&beforeBlock);
+
+ mlir::Value inductionCmp = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound);
+ mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp,
+ earlyExitInBefore);
+
+ mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore);
+
+ rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(),
+ scfWhileOp.getAfter().begin());
+
+ auto *afterBody = scfWhileOp.getAfterBody();
+ auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator());
+ mlir::SmallVector<mlir::Value> results(resultOp->getOperands());
+ mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0];
+
+ rewriter.setInsertionPointToStart(afterBody);
+ results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step);
+
+ rewriter.setInsertionPointToEnd(afterBody);
+ rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results);
+
+ scfWhileOp->setAttrs(iterWhileOp->getAttrs());
+ rewriter.replaceOp(iterWhileOp, scfWhileOp);
+ return mlir::success();
}
};
-void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
- Block &dstBlock) {
- Operation *srcTerminator = srcBlock.getTerminator();
- auto resultOp = cast<fir::ResultOp>(srcTerminator);
+void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter,
+ mlir::Block &srcBlock, mlir::Block &dstBlock) {
+ mlir::Operation *srcTerminator = srcBlock.getTerminator();
+ auto resultOp = mlir::cast<fir::ResultOp>(srcTerminator);
dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(),
srcBlock.begin(), std::prev(srcBlock.end()));
if (!resultOp->getOperands().empty()) {
rewriter.setInsertionPointToEnd(&dstBlock);
- scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands());
+ mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(),
+ resultOp->getOperands());
}
rewriter.eraseOp(srcTerminator);
}
-struct IfConversion : public OpRewritePattern<fir::IfOp> {
+struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
using OpRewritePattern<fir::IfOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(fir::IfOp ifOp,
- PatternRewriter &rewriter) const override {
+ mlir::LogicalResult
+ matchAndRewrite(fir::IfOp ifOp,
+ mlir::PatternRewriter &rewriter) const override {
bool hasElse = !ifOp.getElseRegion().empty();
auto scfIfOp =
- scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(),
- ifOp.getCondition(), hasElse);
+ mlir::scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(),
+ ifOp.getCondition(), hasElse);
copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(),
scfIfOp.getThenRegion().front());
@@ -123,22 +192,18 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
scfIfOp->setAttrs(ifOp->getAttrs());
rewriter.replaceOp(ifOp, scfIfOp);
- return success();
+ return mlir::success();
}
};
} // namespace
void FIRToSCFPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
- ConversionTarget target(getContext());
- target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
- target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- if (failed(
- applyPartialConversion(getOperation(), target, std::move(patterns))))
- signalPassFailure();
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
+ patterns.getContext());
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
}
-std::unique_ptr<Pass> fir::createFIRToSCFPass() {
+std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
return std::make_unique<FIRToSCFPass>();
}
diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp
index 5ac4ed8..9dfe26cb 100644
--- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp
+++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp
@@ -95,10 +95,6 @@ void FunctionAttrPass::runOnOperation() {
func->setAttr(
mlir::LLVM::LLVMFuncOp::getNoNansFpMathAttrName(llvmFuncOpName),
mlir::BoolAttr::get(context, true));
- if (approxFuncFPMath)
- func->setAttr(
- mlir::LLVM::LLVMFuncOp::getApproxFuncFpMathAttrName(llvmFuncOpName),
- mlir::BoolAttr::get(context, true));
if (noSignedZerosFPMath)
func->setAttr(
mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName),
diff --git a/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp b/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp
index 1688f28..68f5b5a 100644
--- a/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp
+++ b/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp
@@ -26,6 +26,8 @@ namespace fir {
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
+#define DEBUG_TYPE "optimize-array-repacking"
+
namespace {
class OptimizeArrayRepackingPass
: public fir::impl::OptimizeArrayRepackingBase<OptimizeArrayRepackingPass> {
@@ -56,8 +58,7 @@ PackingOfContiguous::matchAndRewrite(fir::PackArrayOp op,
mlir::PatternRewriter &rewriter) const {
mlir::Value box = op.getArray();
if (hlfir::isSimplyContiguous(box, !op.getInnermost())) {
- rewriter.replaceAllUsesWith(op, box);
- rewriter.eraseOp(op);
+ rewriter.replaceOp(op, box);
return mlir::success();
}
return mlir::failure();
@@ -78,13 +79,19 @@ void OptimizeArrayRepackingPass::runOnOperation() {
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config
+ .setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled)
+ // Traverse the operations top-down, so that fir.pack_array
+ // operations are optimized before their using fir.pack_array
+ // operations. This way the rewrite may converge faster.
+ .setUseTopDownTraversal();
patterns.insert<PackingOfContiguous>(context);
patterns.insert<NoopUnpacking>(context);
if (mlir::failed(
mlir::applyPatternsGreedily(funcOp, std::move(patterns), config))) {
- mlir::emitError(funcOp.getLoc(), "failure in array repacking optimization");
- signalPassFailure();
+ // Failure may happen if the rewriter does not converge soon enough.
+ // That is not an error, so just report a diagnostic under debug.
+ LLVM_DEBUG(mlir::emitError(funcOp.getLoc(),
+ "failure in array repacking optimization"));
}
}
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index c6aec96..03f97eb 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -210,19 +210,33 @@ public:
mapper.map(region.getArguments(), regionArgs);
for (mlir::Operation &op : region.front().without_terminator())
(void)rewriter.clone(op, mapper);
+
+ auto yield = mlir::cast<fir::YieldOp>(region.front().getTerminator());
+ assert(yield.getResults().size() < 2);
+
+ return yield.getResults().empty()
+ ? mlir::Value{}
+ : mapper.lookup(yield.getResults()[0]);
};
- if (!localizer.getInitRegion().empty())
- cloneLocalizerRegion(localizer.getInitRegion(), {localVar, localArg},
- rewriter.getInsertionPoint());
+ if (!localizer.getInitRegion().empty()) {
+ // Prefer the value yielded from the init region to the allocated
+ // private variable in case the region is operating on arguments
+ // by-value (e.g. Fortran character boxes).
+ localAlloc = cloneLocalizerRegion(localizer.getInitRegion(),
+ {localVar, localAlloc},
+ rewriter.getInsertionPoint());
+ assert(localAlloc);
+ }
if (localizer.getLocalitySpecifierType() ==
fir::LocalitySpecifierType::LocalInit)
- cloneLocalizerRegion(localizer.getCopyRegion(), {localVar, localArg},
+ cloneLocalizerRegion(localizer.getCopyRegion(),
+ {localVar, localAlloc},
rewriter.getInsertionPoint());
if (!localizer.getDeallocRegion().empty())
- cloneLocalizerRegion(localizer.getDeallocRegion(), {localArg},
+ cloneLocalizerRegion(localizer.getDeallocRegion(), {localAlloc},
rewriter.getInsertionBlock()->end());
rewriter.replaceAllUsesWith(localArg, localAlloc);
diff --git a/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp b/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp
index 7d1f86f..0cd2858 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp
@@ -26,22 +26,16 @@ class SimplifyRegionLitePass
public:
void runOnOperation() override;
};
-
-class DummyRewriter : public mlir::PatternRewriter {
-public:
- DummyRewriter(mlir::MLIRContext *ctx) : mlir::PatternRewriter(ctx) {}
-};
-
} // namespace
void SimplifyRegionLitePass::runOnOperation() {
auto op = getOperation();
auto regions = op->getRegions();
mlir::RewritePatternSet patterns(op.getContext());
- DummyRewriter rewriter(op.getContext());
if (regions.empty())
return;
+ mlir::PatternRewriter rewriter(op.getContext());
(void)mlir::eraseUnreachableBlocks(rewriter, regions);
(void)mlir::runRegionDCE(rewriter, regions);
}
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 0d13129..80b3f68 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -600,10 +600,7 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
// replace references to heap allocation with references to stack allocation
mlir::Value newValue = convertAllocationType(
rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
- rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);
-
- // remove allocmem operation
- rewriter.eraseOp(allocmem.getOperation());
+ rewriter.replaceOp(allocmem, newValue);
return mlir::success();
}
@@ -813,10 +810,10 @@ void AllocMemConversion::insertLifetimeMarkers(
mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPoint(oldAlloc);
mlir::Value ptr = fir::factory::genLifetimeStart(
- rewriter, newAlloc.getLoc(), newAlloc, *size, &*dl);
+ rewriter, newAlloc.getLoc(), newAlloc, &*dl);
visitFreeMemOp(oldAlloc, [&](mlir::Operation *op) {
rewriter.setInsertionPoint(op);
- fir::factory::genLifetimeEnd(rewriter, op->getLoc(), ptr, *size);
+ fir::factory::genLifetimeEnd(rewriter, op->getLoc(), ptr);
});
newAlloc->setAttr(attrName, rewriter.getUnitAttr());
}
diff --git a/flang/lib/Parser/CMakeLists.txt b/flang/lib/Parser/CMakeLists.txt
index 1855b8a..20c6c2a 100644
--- a/flang/lib/Parser/CMakeLists.txt
+++ b/flang/lib/Parser/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FortranParser
message.cpp
openacc-parsers.cpp
openmp-parsers.cpp
+ openmp-utils.cpp
parse-tree.cpp
parsing.cpp
preprocessor.cpp
diff --git a/flang/lib/Parser/characters.cpp b/flang/lib/Parser/characters.cpp
index f6ac777..1a00b16 100644
--- a/flang/lib/Parser/characters.cpp
+++ b/flang/lib/Parser/characters.cpp
@@ -289,7 +289,8 @@ RESULT DecodeString(const std::string &s, bool backslashEscapes) {
DecodeCharacter<ENCODING>(p, bytes, backslashEscapes)};
if (decoded.bytes > 0) {
if (static_cast<std::size_t>(decoded.bytes) <= bytes) {
- result.append(1, decoded.codepoint);
+ result.append(
+ 1, static_cast<typename RESULT::value_type>(decoded.codepoint));
bytes -= decoded.bytes;
p += decoded.bytes;
continue;
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 84d1e81..ce46a86 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -469,6 +469,9 @@ TYPE_PARSER(sourced(construct<OmpContextSelectorSpecification>(
// --- Parsers for clause modifiers -----------------------------------
+TYPE_PARSER(construct<OmpAccessGroup>( //
+ "CGROUP" >> pure(OmpAccessGroup::Value::Cgroup)))
+
TYPE_PARSER(construct<OmpAlignment>(scalarIntExpr))
TYPE_PARSER(construct<OmpAlignModifier>( //
@@ -573,7 +576,8 @@ TYPE_PARSER(construct<OmpOrderingModifier>(
"SIMD" >> pure(OmpOrderingModifier::Value::Simd)))
TYPE_PARSER(construct<OmpPrescriptiveness>(
- "STRICT" >> pure(OmpPrescriptiveness::Value::Strict)))
+ "STRICT" >> pure(OmpPrescriptiveness::Value::Strict) ||
+ "FALLBACK" >> pure(OmpPrescriptiveness::Value::Fallback)))
TYPE_PARSER(construct<OmpPresentModifier>( //
"PRESENT" >> pure(OmpPresentModifier::Value::Present)))
@@ -636,6 +640,12 @@ TYPE_PARSER(sourced(construct<OmpDependClause::TaskDep::Modifier>(sourced(
construct<OmpDependClause::TaskDep::Modifier>(
Parser<OmpTaskDependenceType>{})))))
+TYPE_PARSER( //
+ sourced(construct<OmpDynGroupprivateClause::Modifier>(
+ Parser<OmpAccessGroup>{})) ||
+ sourced(construct<OmpDynGroupprivateClause::Modifier>(
+ Parser<OmpPrescriptiveness>{})))
+
TYPE_PARSER(
sourced(construct<OmpDeviceClause::Modifier>(Parser<OmpDeviceModifier>{})))
@@ -777,6 +787,10 @@ TYPE_PARSER(construct<OmpDefaultClause>(
Parser<OmpDefaultClause::DataSharingAttribute>{}) ||
construct<OmpDefaultClause>(indirect(Parser<OmpDirectiveSpecification>{}))))
+TYPE_PARSER(construct<OmpDynGroupprivateClause>(
+ maybe(nonemptyList(Parser<OmpDynGroupprivateClause::Modifier>{}) / ":"),
+ scalarIntExpr))
+
TYPE_PARSER(construct<OmpEnterClause>(
maybe(nonemptyList(Parser<OmpEnterClause::Modifier>{}) / ":"),
Parser<OmpObjectList>{}))
@@ -1068,6 +1082,9 @@ TYPE_PARSER( //
construct<OmpClause>(parenthesized(Parser<OmpDoacrossClause>{})) ||
"DYNAMIC_ALLOCATORS" >>
construct<OmpClause>(construct<OmpClause::DynamicAllocators>()) ||
+ "DYN_GROUPPRIVATE" >>
+ construct<OmpClause>(construct<OmpClause::DynGroupprivate>(
+ parenthesized(Parser<OmpDynGroupprivateClause>{}))) ||
"ENTER" >> construct<OmpClause>(construct<OmpClause::Enter>(
parenthesized(Parser<OmpEnterClause>{}))) ||
"EXCLUSIVE" >> construct<OmpClause>(construct<OmpClause::Exclusive>(
@@ -1264,6 +1281,16 @@ static bool IsFortranBlockConstruct(const ExecutionPartConstruct &epc) {
}
}
+static bool IsStandaloneOrdered(const OmpDirectiveSpecification &dirSpec) {
+ // An ORDERED construct is standalone if it has DOACROSS or DEPEND clause.
+ return dirSpec.DirId() == llvm::omp::Directive::OMPD_ordered &&
+ llvm::any_of(dirSpec.Clauses().v, [](const OmpClause &clause) {
+ llvm::omp::Clause id{clause.Id()};
+ return id == llvm::omp::Clause::OMPC_depend ||
+ id == llvm::omp::Clause::OMPC_doacross;
+ });
+}
+
struct StrictlyStructuredBlockParser {
using resultType = Block;
@@ -1272,9 +1299,9 @@ struct StrictlyStructuredBlockParser {
if (lookAhead(skipStuffBeforeStatement >> "BLOCK"_tok).Parse(state)) {
if (auto epc{Parser<ExecutionPartConstruct>{}.Parse(state)}) {
if (IsFortranBlockConstruct(*epc)) {
- Block block;
- block.emplace_back(std::move(*epc));
- return std::move(block);
+ Block body;
+ body.emplace_back(std::move(*epc));
+ return std::move(body);
}
}
}
@@ -1290,22 +1317,11 @@ struct LooselyStructuredBlockParser {
if (lookAhead(skipStuffBeforeStatement >> "BLOCK"_tok).Parse(state)) {
return std::nullopt;
}
- Block body;
- if (auto epc{attempt(Parser<ExecutionPartConstruct>{}).Parse(state)}) {
- if (!IsFortranBlockConstruct(*epc)) {
- body.emplace_back(std::move(*epc));
- if (auto &&blk{attempt(block).Parse(state)}) {
- for (auto &&s : *blk) {
- body.emplace_back(std::move(s));
- }
- }
- } else {
- // Fail if the first construct is BLOCK.
- return std::nullopt;
- }
+ if (auto &&body{block.Parse(state)}) {
+ // Empty body is ok.
+ return std::move(body);
}
- // Empty body is ok.
- return std::move(body);
+ return std::nullopt;
}
};
@@ -1458,6 +1474,9 @@ struct OmpBlockConstructParser {
std::optional<resultType> Parse(ParseState &state) const {
if (auto &&begin{OmpBeginDirectiveParser(dir_).Parse(state)}) {
+ if (IsStandaloneOrdered(*begin)) {
+ return std::nullopt;
+ }
if (auto &&body{attempt(StrictlyStructuredBlockParser{}).Parse(state)}) {
// Try strictly-structured block with an optional end-directive
auto end{maybe(OmpEndDirectiveParser{dir_}).Parse(state)};
@@ -1467,11 +1486,14 @@ struct OmpBlockConstructParser {
[](auto &&s) { return OmpEndDirective(std::move(s)); })};
} else if (auto &&body{
attempt(LooselyStructuredBlockParser{}).Parse(state)}) {
- // Try loosely-structured block with a mandatory end-directive
- if (auto end{OmpEndDirectiveParser{dir_}.Parse(state)}) {
- return OmpBlockConstruct{OmpBeginDirective(std::move(*begin)),
- std::move(*body), OmpEndDirective{std::move(*end)}};
- }
+ // Try loosely-structured block with a mandatory end-directive.
+ auto end{maybe(OmpEndDirectiveParser{dir_}).Parse(state)};
+ // Delay the error for a missing end-directive until semantics so that
+ // we have better control over the output.
+ return OmpBlockConstruct{OmpBeginDirective(std::move(*begin)),
+ std::move(*body),
+ llvm::transformOptional(std::move(*end),
+ [](auto &&s) { return OmpEndDirective(std::move(s)); })};
}
}
return std::nullopt;
@@ -1622,7 +1644,6 @@ TYPE_PARSER(sourced( //
static bool IsSimpleStandalone(const OmpDirectiveName &name) {
switch (name.v) {
case llvm::omp::Directive::OMPD_barrier:
- case llvm::omp::Directive::OMPD_ordered:
case llvm::omp::Directive::OMPD_scan:
case llvm::omp::Directive::OMPD_target_enter_data:
case llvm::omp::Directive::OMPD_target_exit_data:
@@ -1638,7 +1659,9 @@ static bool IsSimpleStandalone(const OmpDirectiveName &name) {
TYPE_PARSER(sourced( //
construct<OpenMPSimpleStandaloneConstruct>(
predicated(OmpDirectiveNameParser{}, IsSimpleStandalone) >=
- Parser<OmpDirectiveSpecification>{})))
+ Parser<OmpDirectiveSpecification>{}) ||
+ construct<OpenMPSimpleStandaloneConstruct>(
+ predicated(Parser<OmpDirectiveSpecification>{}, IsStandaloneOrdered))))
TYPE_PARSER(sourced( //
construct<OpenMPFlushConstruct>(
@@ -1758,17 +1781,8 @@ TYPE_PARSER(sourced(construct<OpenMPDeclareMapperConstruct>(
TYPE_PARSER(construct<OmpReductionCombiner>(Parser<AssignmentStmt>{}) ||
construct<OmpReductionCombiner>(Parser<FunctionReference>{}))
-// 2.13.2 OMP CRITICAL
-TYPE_PARSER(startOmpLine >>
- sourced(construct<OmpEndCriticalDirective>(
- verbatim("END CRITICAL"_tok), maybe(parenthesized(name)))) /
- endOmpLine)
-TYPE_PARSER(sourced(construct<OmpCriticalDirective>(verbatim("CRITICAL"_tok),
- maybe(parenthesized(name)), Parser<OmpClauseList>{})) /
- endOmpLine)
-
TYPE_PARSER(construct<OpenMPCriticalConstruct>(
- Parser<OmpCriticalDirective>{}, block, Parser<OmpEndCriticalDirective>{}))
+ OmpBlockConstructParser{llvm::omp::Directive::OMPD_critical}))
// 2.11.3 Executable Allocate directive
TYPE_PARSER(
@@ -1782,6 +1796,12 @@ TYPE_PARSER(sourced(construct<OpenMPDeclareSimdConstruct>(
verbatim("DECLARE SIMD"_tok) || verbatim("DECLARE_SIMD"_tok),
maybe(parenthesized(name)), Parser<OmpClauseList>{})))
+TYPE_PARSER(sourced( //
+ construct<OpenMPGroupprivate>(
+ predicated(OmpDirectiveNameParser{},
+ IsDirective(llvm::omp::Directive::OMPD_groupprivate)) >=
+ Parser<OmpDirectiveSpecification>{})))
+
// 2.4 Requires construct
TYPE_PARSER(sourced(construct<OpenMPRequiresConstruct>(
verbatim("REQUIRES"_tok), Parser<OmpClauseList>{})))
@@ -1818,6 +1838,8 @@ TYPE_PARSER(
construct<OpenMPDeclarativeConstruct>(
Parser<OpenMPDeclarativeAllocate>{}) ||
construct<OpenMPDeclarativeConstruct>(
+ Parser<OpenMPGroupprivate>{}) ||
+ construct<OpenMPDeclarativeConstruct>(
Parser<OpenMPRequiresConstruct>{}) ||
construct<OpenMPDeclarativeConstruct>(
Parser<OpenMPThreadprivate>{}) ||
@@ -1827,20 +1849,12 @@ TYPE_PARSER(
Parser<OmpMetadirectiveDirective>{})) /
endOmpLine))
-// Assume Construct
-TYPE_PARSER(sourced(construct<OmpAssumeDirective>(
- verbatim("ASSUME"_tok), Parser<OmpClauseList>{})))
-
-TYPE_PARSER(sourced(construct<OmpEndAssumeDirective>(
- startOmpLine >> verbatim("END ASSUME"_tok))))
-
-TYPE_PARSER(sourced(
- construct<OpenMPAssumeConstruct>(Parser<OmpAssumeDirective>{} / endOmpLine,
- block, maybe(Parser<OmpEndAssumeDirective>{} / endOmpLine))))
+TYPE_PARSER(construct<OpenMPAssumeConstruct>(
+ sourced(OmpBlockConstructParser{llvm::omp::Directive::OMPD_assume})))
// Block Construct
#define MakeBlockConstruct(dir) \
- construct<OpenMPBlockConstruct>(OmpBlockConstructParser{dir})
+ construct<OmpBlockConstruct>(OmpBlockConstructParser{dir})
TYPE_PARSER( //
MakeBlockConstruct(llvm::omp::Directive::OMPD_masked) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_master) ||
@@ -1854,11 +1868,15 @@ TYPE_PARSER( //
MakeBlockConstruct(llvm::omp::Directive::OMPD_target_data) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_target_parallel) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_target_teams) ||
+ MakeBlockConstruct(
+ llvm::omp::Directive::OMPD_target_teams_workdistribute) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_target) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_task) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_taskgroup) ||
MakeBlockConstruct(llvm::omp::Directive::OMPD_teams) ||
- MakeBlockConstruct(llvm::omp::Directive::OMPD_workshare))
+ MakeBlockConstruct(llvm::omp::Directive::OMPD_teams_workdistribute) ||
+ MakeBlockConstruct(llvm::omp::Directive::OMPD_workshare) ||
+ MakeBlockConstruct(llvm::omp::Directive::OMPD_workdistribute))
#undef MakeBlockConstruct
// OMP SECTIONS Directive
@@ -1887,7 +1905,7 @@ TYPE_PARSER(sourced(construct<OpenMPSectionsConstruct>(
construct<OpenMPSectionConstruct>(maybe(sectionDir), block))),
many(construct<OpenMPConstruct>(
sourced(construct<OpenMPSectionConstruct>(sectionDir, block))))),
- Parser<OmpEndSectionsDirective>{} / endOmpLine)))
+ maybe(Parser<OmpEndSectionsDirective>{} / endOmpLine))))
static bool IsExecutionPart(const OmpDirectiveName &name) {
return name.IsExecutionPart();
@@ -1901,8 +1919,8 @@ TYPE_CONTEXT_PARSER("OpenMP construct"_en_US,
withMessage("expected OpenMP construct"_err_en_US,
first(construct<OpenMPConstruct>(Parser<OpenMPSectionsConstruct>{}),
construct<OpenMPConstruct>(Parser<OpenMPLoopConstruct>{}),
- construct<OpenMPConstruct>(Parser<OpenMPBlockConstruct>{}),
- // OpenMPBlockConstruct is attempted before
+ construct<OpenMPConstruct>(Parser<OmpBlockConstruct>{}),
+ // OmpBlockConstruct is attempted before
// OpenMPStandaloneConstruct to resolve !$OMP ORDERED
construct<OpenMPConstruct>(Parser<OpenMPStandaloneConstruct>{}),
construct<OpenMPConstruct>(Parser<OpenMPAtomicConstruct>{}),
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
new file mode 100644
index 0000000..ef7e4fc
--- /dev/null
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -0,0 +1,64 @@
+//===-- flang/Parser/openmp-utils.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Common OpenMP utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Parser/openmp-utils.h"
+
+#include "flang/Common/template.h"
+#include "flang/Common/visit.h"
+
+#include <tuple>
+#include <type_traits>
+#include <variant>
+
+namespace Fortran::parser::omp {
+
+const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
+ // Clauses whose data member (x.v) is itself the OmpObjectList.
+ using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
+ OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate,
+ OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr,
+ OmpClause::Link, OmpClause::Private, OmpClause::Shared,
+ OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
+
+ // Clauses whose data member holds a tuple (x.v.t) containing an OmpObjectList.
+ using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
+ OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
+ OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
+ OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
+ OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
+
+ // TODO: Generate the tuples using TableGen.
+ return common::visit(
+ common::visitors{
+ [&](const OmpClause::Depend &x) -> const OmpObjectList * {
+ if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) {
+ return &std::get<OmpObjectList>(taskDep->t);
+ } else {
+ return nullptr;
+ }
+ },
+ [&](const auto &x) -> const OmpObjectList * {
+ using Ty = std::decay_t<decltype(x)>;
+ if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
+ return &x.v;
+ } else if constexpr (common::HasMember<Ty,
+ TupleObjectListClauses>) {
+ return &std::get<OmpObjectList>(x.v.t);
+ } else {
+ return nullptr;
+ }
+ },
+ },
+ clause.u);
+}
+
+} // namespace Fortran::parser::omp
diff --git a/flang/lib/Parser/parsing.cpp b/flang/lib/Parser/parsing.cpp
index ceea747..8a8c6ef 100644
--- a/flang/lib/Parser/parsing.cpp
+++ b/flang/lib/Parser/parsing.cpp
@@ -96,9 +96,6 @@ const SourceFile *Parsing::Prescan(const std::string &path, Options options) {
prescanner.AddCompilerDirectiveSentinel("$cuf");
prescanner.AddCompilerDirectiveSentinel("@cuf");
}
- if (options.features.IsEnabled(LanguageFeature::CUDA)) {
- preprocessor_.Define("_CUDA", "1");
- }
ProvenanceRange range{allSources.AddIncludedFile(
*sourceFile, ProvenanceRange{}, options.isModuleFile)};
prescanner.Prescan(range);
diff --git a/flang/lib/Parser/preprocessor.cpp b/flang/lib/Parser/preprocessor.cpp
index 0aadc41..9176b4d 100644
--- a/flang/lib/Parser/preprocessor.cpp
+++ b/flang/lib/Parser/preprocessor.cpp
@@ -414,7 +414,7 @@ std::optional<TokenSequence> Preprocessor::MacroReplacement(
const TokenSequence &input, Prescanner &prescanner,
std::optional<std::size_t> *partialFunctionLikeMacro, bool inIfExpression) {
// Do quick scan for any use of a defined name.
- if (definitions_.empty()) {
+ if (!inIfExpression && definitions_.empty()) {
return std::nullopt;
}
std::size_t tokens{input.SizeInTokens()};
@@ -742,12 +742,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) {
"# missing or invalid name"_err_en_US);
} else {
if (dir.IsAnythingLeft(++j)) {
- if (prescanner.features().ShouldWarn(
- common::UsageWarning::Portability)) {
- prescanner.Say(common::UsageWarning::Portability,
- dir.GetIntervalProvenanceRange(j, tokens - j),
- "#undef: excess tokens at end of directive"_port_en_US);
- }
+ prescanner.Warn(common::UsageWarning::Portability,
+ dir.GetIntervalProvenanceRange(j, tokens - j),
+ "#undef: excess tokens at end of directive"_port_en_US);
} else {
definitions_.erase(nameToken);
}
@@ -760,12 +757,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) {
"#%s: missing name"_err_en_US, dirName);
} else {
if (dir.IsAnythingLeft(++j)) {
- if (prescanner.features().ShouldWarn(
- common::UsageWarning::Portability)) {
- prescanner.Say(common::UsageWarning::Portability,
- dir.GetIntervalProvenanceRange(j, tokens - j),
- "#%s: excess tokens at end of directive"_port_en_US, dirName);
- }
+ prescanner.Warn(common::UsageWarning::Portability,
+ dir.GetIntervalProvenanceRange(j, tokens - j),
+ "#%s: excess tokens at end of directive"_port_en_US, dirName);
}
doThen = IsNameDefined(nameToken) == (dirName == "ifdef");
}
@@ -784,11 +778,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) {
}
} else if (dirName == "else") {
if (dir.IsAnythingLeft(j)) {
- if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) {
- prescanner.Say(common::UsageWarning::Portability,
- dir.GetIntervalProvenanceRange(j, tokens - j),
- "#else: excess tokens at end of directive"_port_en_US);
- }
+ prescanner.Warn(common::UsageWarning::Portability,
+ dir.GetIntervalProvenanceRange(j, tokens - j),
+ "#else: excess tokens at end of directive"_port_en_US);
}
if (ifStack_.empty()) {
prescanner.Say(dir.GetTokenProvenanceRange(dirOffset),
@@ -815,11 +807,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) {
}
} else if (dirName == "endif") {
if (dir.IsAnythingLeft(j)) {
- if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) {
- prescanner.Say(common::UsageWarning::Portability,
- dir.GetIntervalProvenanceRange(j, tokens - j),
- "#endif: excess tokens at end of directive"_port_en_US);
- }
+ prescanner.Warn(common::UsageWarning::Portability,
+ dir.GetIntervalProvenanceRange(j, tokens - j),
+ "#endif: excess tokens at end of directive"_port_en_US);
} else if (ifStack_.empty()) {
prescanner.Say(dir.GetTokenProvenanceRange(dirOffset),
"#endif: no #if, #ifdef, or #ifndef"_err_en_US);
@@ -866,12 +856,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) {
++k;
}
if (k >= pathTokens) {
- if (prescanner.features().ShouldWarn(
- common::UsageWarning::Portability)) {
- prescanner.Say(common::UsageWarning::Portability,
- dir.GetIntervalProvenanceRange(j, tokens - j),
- "#include: expected '>' at end of included file"_port_en_US);
- }
+ prescanner.Warn(common::UsageWarning::Portability,
+ dir.GetIntervalProvenanceRange(j, tokens - j),
+ "#include: expected '>' at end of included file"_port_en_US);
}
TokenSequence braced{path, 1, k - 1};
include = braced.ToString();
@@ -897,11 +884,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) {
}
k = path.SkipBlanks(k + 1);
if (k < pathTokens && path.TokenAt(k).ToString() != "!") {
- if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) {
- prescanner.Say(common::UsageWarning::Portability,
- dir.GetIntervalProvenanceRange(j, tokens - j),
- "#include: extra stuff ignored after file name"_port_en_US);
- }
+ prescanner.Warn(common::UsageWarning::Portability,
+ dir.GetIntervalProvenanceRange(j, tokens - j),
+ "#include: extra stuff ignored after file name"_port_en_US);
}
std::string buf;
llvm::raw_string_ostream error{buf};
diff --git a/flang/lib/Parser/prescan.h b/flang/lib/Parser/prescan.h
index f650d54..c181c03 100644
--- a/flang/lib/Parser/prescan.h
+++ b/flang/lib/Parser/prescan.h
@@ -91,6 +91,15 @@ public:
return messages_.Say(std::forward<A>(a)...);
}
+ template <typename... A>
+ Message *Warn(common::UsageWarning warning, A &&...a) {
+ return messages_.Warn(false, features_, warning, std::forward<A>(a)...);
+ }
+ template <typename... A>
+ Message *Warn(common::LanguageFeature feature, A &&...a) {
+ return messages_.Warn(false, features_, feature, std::forward<A>(a)...);
+ }
+
private:
struct LineClassification {
enum class Kind {
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 46141e2..dc6d336 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2250,6 +2250,11 @@ public:
Walk(std::get<OmpObjectList>(x.t));
Walk(": ", std::get<std::optional<std::list<Modifier>>>(x.t));
}
+ void Unparse(const OmpDynGroupprivateClause &x) {
+ using Modifier = OmpDynGroupprivateClause::Modifier;
+ Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ": ");
+ Walk(std::get<ScalarIntExpr>(x.t));
+ }
void Unparse(const OmpEnterClause &x) {
using Modifier = OmpEnterClause::Modifier;
Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ": ");
@@ -2575,40 +2580,14 @@ public:
Put("\n");
EndOpenMP();
}
- void Unparse(const OpenMPAllocatorsConstruct &x) { //
+ void Unparse(const OpenMPAllocatorsConstruct &x) {
Unparse(static_cast<const OmpBlockConstruct &>(x));
}
- void Unparse(const OmpAssumeDirective &x) {
- BeginOpenMP();
- Word("!$OMP ASSUME");
- Walk(" ", std::get<OmpClauseList>(x.t).v);
- Put("\n");
- EndOpenMP();
- }
- void Unparse(const OmpEndAssumeDirective &x) {
- BeginOpenMP();
- Word("!$OMP END ASSUME\n");
- EndOpenMP();
- }
- void Unparse(const OmpCriticalDirective &x) {
- BeginOpenMP();
- Word("!$OMP CRITICAL");
- Walk(" (", std::get<std::optional<Name>>(x.t), ")");
- Walk(std::get<OmpClauseList>(x.t));
- Put("\n");
- EndOpenMP();
- }
- void Unparse(const OmpEndCriticalDirective &x) {
- BeginOpenMP();
- Word("!$OMP END CRITICAL");
- Walk(" (", std::get<std::optional<Name>>(x.t), ")");
- Put("\n");
- EndOpenMP();
+ void Unparse(const OpenMPAssumeConstruct &x) {
+ Unparse(static_cast<const OmpBlockConstruct &>(x));
}
void Unparse(const OpenMPCriticalConstruct &x) {
- Walk(std::get<OmpCriticalDirective>(x.t));
- Walk(std::get<Block>(x.t), "");
- Walk(std::get<OmpEndCriticalDirective>(x.t));
+ Unparse(static_cast<const OmpBlockConstruct &>(x));
}
void Unparse(const OmpDeclareTargetWithList &x) {
Put("("), Walk(x.v), Put(")");
@@ -2718,6 +2697,13 @@ public:
void Unparse(const OpenMPDispatchConstruct &x) { //
Unparse(static_cast<const OmpBlockConstruct &>(x));
}
+ void Unparse(const OpenMPGroupprivate &x) {
+ BeginOpenMP();
+ Word("!$OMP ");
+ Walk(x.v);
+ Put("\n");
+ EndOpenMP();
+ }
void Unparse(const OpenMPRequiresConstruct &y) {
BeginOpenMP();
Word("!$OMP REQUIRES ");
@@ -2778,7 +2764,7 @@ public:
Walk(std::get<std::list<OpenMPConstruct>>(x.t), "");
BeginOpenMP();
Word("!$OMP END ");
- Walk(std::get<OmpEndSectionsDirective>(x.t));
+ Walk(std::get<std::optional<OmpEndSectionsDirective>>(x.t));
Put("\n");
EndOpenMP();
}
@@ -2847,9 +2833,6 @@ public:
Put("\n");
EndOpenMP();
}
- void Unparse(const OpenMPBlockConstruct &x) {
- Unparse(static_cast<const OmpBlockConstruct &>(x));
- }
void Unparse(const OpenMPLoopConstruct &x) {
BeginOpenMP();
Word("!$OMP ");
@@ -2943,6 +2926,7 @@ public:
WALK_NESTED_ENUM(OmpTaskDependenceType, Value) // OMP task-dependence-type
WALK_NESTED_ENUM(OmpScheduleClause, Kind) // OMP schedule-kind
WALK_NESTED_ENUM(OmpSeverityClause, Severity) // OMP severity
+ WALK_NESTED_ENUM(OmpAccessGroup, Value)
WALK_NESTED_ENUM(OmpDeviceModifier, Value) // OMP device modifier
WALK_NESTED_ENUM(
OmpDeviceTypeClause, DeviceTypeDescription) // OMP device_type
diff --git a/flang/lib/Semantics/check-acc-structure.cpp b/flang/lib/Semantics/check-acc-structure.cpp
index 051abdc..6cb7e5e 100644
--- a/flang/lib/Semantics/check-acc-structure.cpp
+++ b/flang/lib/Semantics/check-acc-structure.cpp
@@ -983,24 +983,26 @@ void AccStructureChecker::Enter(const parser::AccClause::Reduction &reduction) {
[&](const parser::Designator &designator) {
if (const auto *name = getDesignatorNameIfDataRef(designator)) {
if (name->symbol) {
- const auto *type{name->symbol->GetType()};
- if (type->IsNumeric(TypeCategory::Integer) &&
- !reductionIntegerSet.test(op.v)) {
- context_.Say(GetContext().clauseSource,
- "reduction operator not supported for integer type"_err_en_US);
- } else if (type->IsNumeric(TypeCategory::Real) &&
- !reductionRealSet.test(op.v)) {
- context_.Say(GetContext().clauseSource,
- "reduction operator not supported for real type"_err_en_US);
- } else if (type->IsNumeric(TypeCategory::Complex) &&
- !reductionComplexSet.test(op.v)) {
- context_.Say(GetContext().clauseSource,
- "reduction operator not supported for complex type"_err_en_US);
- } else if (type->category() ==
- Fortran::semantics::DeclTypeSpec::Category::Logical &&
- !reductionLogicalSet.test(op.v)) {
- context_.Say(GetContext().clauseSource,
- "reduction operator not supported for logical type"_err_en_US);
+ if (const auto *type{name->symbol->GetType()}) {
+ if (type->IsNumeric(TypeCategory::Integer) &&
+ !reductionIntegerSet.test(op.v)) {
+ context_.Say(GetContext().clauseSource,
+ "reduction operator not supported for integer type"_err_en_US);
+ } else if (type->IsNumeric(TypeCategory::Real) &&
+ !reductionRealSet.test(op.v)) {
+ context_.Say(GetContext().clauseSource,
+ "reduction operator not supported for real type"_err_en_US);
+ } else if (type->IsNumeric(TypeCategory::Complex) &&
+ !reductionComplexSet.test(op.v)) {
+ context_.Say(GetContext().clauseSource,
+ "reduction operator not supported for complex type"_err_en_US);
+ } else if (type->category() ==
+ Fortran::semantics::DeclTypeSpec::Category::
+ Logical &&
+ !reductionLogicalSet.test(op.v)) {
+ context_.Say(GetContext().clauseSource,
+ "reduction operator not supported for logical type"_err_en_US);
+ }
}
// TODO: check composite type.
}
diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp
index 0805359..823aa4e 100644
--- a/flang/lib/Semantics/check-allocate.cpp
+++ b/flang/lib/Semantics/check-allocate.cpp
@@ -548,7 +548,7 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) {
}
}
// Shape related checks
- if (ultimate_ && evaluate::IsAssumedRank(*ultimate_)) {
+ if (ultimate_ && IsAssumedRank(*ultimate_)) {
context.Say(name_.source,
"An assumed-rank dummy argument may not appear in an ALLOCATE statement"_err_en_US);
return false;
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index 6f250328..f0078fd 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -67,7 +67,7 @@ static void CheckImplicitInterfaceArg(evaluate::ActualArgument &arg,
"Null pointer argument requires an explicit interface"_err_en_US);
} else if (auto named{evaluate::ExtractNamedEntity(*expr)}) {
const Symbol &symbol{named->GetLastSymbol()};
- if (evaluate::IsAssumedRank(symbol)) {
+ if (IsAssumedRank(symbol)) {
messages.Say(
"Assumed rank argument requires an explicit interface"_err_en_US);
}
@@ -131,7 +131,7 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual,
dummy.type.type().kind() == actualType.type().kind() &&
!dummy.attrs.test(
characteristics::DummyDataObject::Attr::DeducedFromActual)) {
- bool actualIsAssumedRank{evaluate::IsAssumedRank(actual)};
+ bool actualIsAssumedRank{IsAssumedRank(actual)};
if (actualIsAssumedRank &&
!dummy.type.attrs().test(
characteristics::TypeAndShape::Attr::AssumedRank)) {
@@ -140,7 +140,8 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual,
messages.Say(
"Assumed-rank character array may not be associated with a dummy argument that is not assumed-rank"_err_en_US);
} else {
- context.Warn(common::LanguageFeature::AssumedRankPassedToNonAssumedRank,
+ context.Warn(messages,
+ common::LanguageFeature::AssumedRankPassedToNonAssumedRank,
messages.at(),
"Assumed-rank character array should not be associated with a dummy argument that is not assumed-rank"_port_en_US);
}
@@ -187,9 +188,9 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual,
"Actual argument has fewer characters remaining in storage sequence (%jd) than %s (%jd)"_err_en_US,
static_cast<std::intmax_t>(actualChars), dummyName,
static_cast<std::intmax_t>(dummyChars));
- } else if (context.ShouldWarn(
- common::UsageWarning::ShortCharacterActual)) {
- messages.Say(common::UsageWarning::ShortCharacterActual,
+ } else {
+ context.Warn(messages,
+ common::UsageWarning::ShortCharacterActual,
"Actual argument has fewer characters remaining in storage sequence (%jd) than %s (%jd)"_warn_en_US,
static_cast<std::intmax_t>(actualChars), dummyName,
static_cast<std::intmax_t>(dummyChars));
@@ -207,9 +208,9 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual,
static_cast<std::intmax_t>(*actualSize * *actualLength),
dummyName,
static_cast<std::intmax_t>(*dummySize * *dummyLength));
- } else if (context.ShouldWarn(
- common::UsageWarning::ShortCharacterActual)) {
- messages.Say(common::UsageWarning::ShortCharacterActual,
+ } else {
+ context.Warn(messages,
+ common::UsageWarning::ShortCharacterActual,
"Actual argument array has fewer characters (%jd) than %s array (%jd)"_warn_en_US,
static_cast<std::intmax_t>(*actualSize * *actualLength),
dummyName,
@@ -229,17 +230,14 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual,
} else if (*actualLength < *dummyLength) {
CHECK(dummy.type.Rank() == 0);
bool isVariable{evaluate::IsVariable(actual)};
- if (context.ShouldWarn(
- common::UsageWarning::ShortCharacterActual)) {
- if (isVariable) {
- messages.Say(common::UsageWarning::ShortCharacterActual,
- "Actual argument variable length '%jd' is less than expected length '%jd'"_warn_en_US,
- *actualLength, *dummyLength);
- } else {
- messages.Say(common::UsageWarning::ShortCharacterActual,
- "Actual argument expression length '%jd' is less than expected length '%jd'"_warn_en_US,
- *actualLength, *dummyLength);
- }
+ if (isVariable) {
+ context.Warn(messages, common::UsageWarning::ShortCharacterActual,
+ "Actual argument variable length '%jd' is less than expected length '%jd'"_warn_en_US,
+ *actualLength, *dummyLength);
+ } else {
+ context.Warn(messages, common::UsageWarning::ShortCharacterActual,
+ "Actual argument expression length '%jd' is less than expected length '%jd'"_warn_en_US,
+ *actualLength, *dummyLength);
}
if (!isVariable) {
auto converted{
@@ -279,9 +277,8 @@ static void ConvertIntegerActual(evaluate::Expr<evaluate::SomeType> &actual,
messages.Say(
"Actual argument scalar expression of type INTEGER(%d) cannot be implicitly converted to smaller dummy argument type INTEGER(%d)"_err_en_US,
actualType.type().kind(), dummyType.type().kind());
- } else if (semanticsContext.ShouldWarn(common::LanguageFeature::
- ActualIntegerConvertedToSmallerKind)) {
- messages.Say(
+ } else {
+ semanticsContext.Warn(messages,
common::LanguageFeature::ActualIntegerConvertedToSmallerKind,
"Actual argument scalar expression of type INTEGER(%d) was converted to smaller dummy argument type INTEGER(%d)"_port_en_US,
actualType.type().kind(), dummyType.type().kind());
@@ -364,20 +361,16 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
if (const auto *constantChar{
evaluate::UnwrapConstantValue<evaluate::Ascii>(actual)};
constantChar && constantChar->wasHollerith() &&
- dummy.type.type().IsUnlimitedPolymorphic() &&
- context.ShouldWarn(common::LanguageFeature::HollerithPolymorphic)) {
- messages.Say(common::LanguageFeature::HollerithPolymorphic,
+ dummy.type.type().IsUnlimitedPolymorphic()) {
+ foldingContext.Warn(common::LanguageFeature::HollerithPolymorphic,
"passing Hollerith to unlimited polymorphic as if it were CHARACTER"_port_en_US);
}
} else if (dummyRank == 0 && allowActualArgumentConversions) {
// Extension: pass Hollerith literal to scalar as if it had been BOZ
if (auto converted{evaluate::HollerithToBOZ(
foldingContext, actual, dummy.type.type())}) {
- if (context.ShouldWarn(
- common::LanguageFeature::HollerithOrCharacterAsBOZ)) {
- messages.Say(common::LanguageFeature::HollerithOrCharacterAsBOZ,
- "passing Hollerith or character literal as if it were BOZ"_port_en_US);
- }
+ foldingContext.Warn(common::LanguageFeature::HollerithOrCharacterAsBOZ,
+ "passing Hollerith or character literal as if it were BOZ"_port_en_US);
actual = *converted;
actualType.type() = dummy.type.type();
typesCompatible = true;
@@ -387,7 +380,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
characteristics::TypeAndShape::Attr::AssumedRank)};
bool actualIsAssumedSize{actualType.attrs().test(
characteristics::TypeAndShape::Attr::AssumedSize)};
- bool actualIsAssumedRank{evaluate::IsAssumedRank(actual)};
+ bool actualIsAssumedRank{IsAssumedRank(actual)};
bool actualIsPointer{evaluate::IsObjectPointer(actual)};
bool actualIsAllocatable{evaluate::IsAllocatableDesignator(actual)};
bool actualMayBeAssumedSize{actualIsAssumedSize ||
@@ -411,7 +404,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
"%s actual argument may not be associated with INTENT(OUT) assumed-rank dummy argument requiring finalization, destruction, or initialization"_err_en_US,
actualDesc);
} else {
- context.Warn(common::UsageWarning::Portability, messages.at(),
+ foldingContext.Warn(common::UsageWarning::Portability, messages.at(),
"%s actual argument should not be associated with INTENT(OUT) assumed-rank dummy argument"_port_en_US,
actualDesc);
}
@@ -671,9 +664,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
"Actual argument has fewer elements remaining in storage sequence (%jd) than %s array (%jd)"_err_en_US,
static_cast<std::intmax_t>(*actualElements), dummyName,
static_cast<std::intmax_t>(*dummySize));
- } else if (context.ShouldWarn(
- common::UsageWarning::ShortArrayActual)) {
- messages.Say(common::UsageWarning::ShortArrayActual,
+ } else {
+ context.Warn(common::UsageWarning::ShortArrayActual,
"Actual argument has fewer elements remaining in storage sequence (%jd) than %s array (%jd)"_warn_en_US,
static_cast<std::intmax_t>(*actualElements), dummyName,
static_cast<std::intmax_t>(*dummySize));
@@ -690,9 +682,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
"Actual argument array has fewer elements (%jd) than %s array (%jd)"_err_en_US,
static_cast<std::intmax_t>(*actualSize), dummyName,
static_cast<std::intmax_t>(*dummySize));
- } else if (context.ShouldWarn(
- common::UsageWarning::ShortArrayActual)) {
- messages.Say(common::UsageWarning::ShortArrayActual,
+ } else {
+ context.Warn(common::UsageWarning::ShortArrayActual,
"Actual argument array has fewer elements (%jd) than %s array (%jd)"_warn_en_US,
static_cast<std::intmax_t>(*actualSize), dummyName,
static_cast<std::intmax_t>(*dummySize));
@@ -779,24 +770,36 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
// Cases when temporaries might be needed but must not be permitted.
bool dummyIsAssumedShape{dummy.type.attrs().test(
characteristics::TypeAndShape::Attr::AssumedShape)};
- if ((actualIsAsynchronous || actualIsVolatile) &&
- (dummyIsAsynchronous || dummyIsVolatile) && !dummyIsValue) {
- if (actualCoarrayRef) { // C1538
- messages.Say(
- "Coindexed ASYNCHRONOUS or VOLATILE actual argument may not be associated with %s with ASYNCHRONOUS or VOLATILE attributes unless VALUE"_err_en_US,
- dummyName);
- }
- if ((actualRank > 0 || actualIsAssumedRank) && !actualIsContiguous) {
- if (dummyIsContiguous ||
- !(dummyIsAssumedShape || dummyIsAssumedRank ||
- (actualIsPointer && dummyIsPointer))) { // C1539 & C1540
+ if (!dummyIsValue && (dummyIsAsynchronous || dummyIsVolatile)) {
+ if (actualIsAsynchronous || actualIsVolatile) {
+ if (actualCoarrayRef) { // F'2023 C1547
messages.Say(
- "ASYNCHRONOUS or VOLATILE actual argument that is not simply contiguous may not be associated with a contiguous ASYNCHRONOUS or VOLATILE %s"_err_en_US,
+ "Coindexed ASYNCHRONOUS or VOLATILE actual argument may not be associated with %s with ASYNCHRONOUS or VOLATILE attributes unless VALUE"_err_en_US,
dummyName);
}
+ if ((actualRank > 0 || actualIsAssumedRank) && !actualIsContiguous) {
+ if (dummyIsContiguous ||
+ !(dummyIsAssumedShape || dummyIsAssumedRank ||
+ (actualIsPointer && dummyIsPointer))) { // F'2023 C1548 & C1549
+ messages.Say(
+ "ASYNCHRONOUS or VOLATILE actual argument that is not simply contiguous may not be associated with a contiguous ASYNCHRONOUS or VOLATILE %s"_err_en_US,
+ dummyName);
+ }
+ }
+ // The vector subscript case is handled by the definability check above.
+ // The copy-in/copy-out cases are handled by the previous checks.
+ // Nag, GFortran, and NVFortran all error on this case, even though it is
+ // ok, possibly as an over-restriction of C1548.
+ } else if (!(dummyIsAssumedShape || dummyIsAssumedRank ||
+ (actualIsPointer && dummyIsPointer)) &&
+ evaluate::IsArraySection(actual) &&
+ !evaluate::HasVectorSubscript(actual)) {
+ context.Warn(common::UsageWarning::Portability, messages.at(),
+ "The array section '%s' should not be associated with %s with %s attribute, unless the dummy is assumed-shape or assumed-rank"_port_en_US,
+ actual.AsFortran(), dummyName,
+ dummyIsAsynchronous ? "ASYNCHRONOUS" : "VOLATILE");
}
}
-
// 15.5.2.6 -- dummy is ALLOCATABLE
bool dummyIsOptional{
dummy.attrs.test(characteristics::DummyDataObject::Attr::Optional)};
@@ -821,10 +824,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
messages.Say(
"A null pointer should not be associated with allocatable %s without INTENT(IN)"_warn_en_US,
dummyName);
- } else if (dummy.intent == common::Intent::In &&
- context.ShouldWarn(
- common::LanguageFeature::NullActualForAllocatable)) {
- messages.Say(common::LanguageFeature::NullActualForAllocatable,
+ } else if (dummy.intent == common::Intent::In) {
+ foldingContext.Warn(common::LanguageFeature::NullActualForAllocatable,
"Allocatable %s is associated with a null pointer"_port_en_US,
dummyName);
}
@@ -878,11 +879,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
checkTypeCompatibility = false;
if (dummyIsUnlimited && dummy.intent == common::Intent::In &&
context.IsEnabled(common::LanguageFeature::RelaxedIntentInChecking)) {
- if (context.ShouldWarn(
- common::LanguageFeature::RelaxedIntentInChecking)) {
- messages.Say(common::LanguageFeature::RelaxedIntentInChecking,
- "If a POINTER or ALLOCATABLE dummy or actual argument is unlimited polymorphic, both should be so"_port_en_US);
- }
+ foldingContext.Warn(common::LanguageFeature::RelaxedIntentInChecking,
+ "If a POINTER or ALLOCATABLE dummy or actual argument is unlimited polymorphic, both should be so"_port_en_US);
} else {
messages.Say(
"If a POINTER or ALLOCATABLE dummy or actual argument is unlimited polymorphic, both must be so"_err_en_US);
@@ -890,21 +888,15 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
} else if (dummyIsPolymorphic != actualIsPolymorphic) {
if (dummyIsPolymorphic && dummy.intent == common::Intent::In &&
context.IsEnabled(common::LanguageFeature::RelaxedIntentInChecking)) {
- if (context.ShouldWarn(
- common::LanguageFeature::RelaxedIntentInChecking)) {
- messages.Say(common::LanguageFeature::RelaxedIntentInChecking,
- "If a POINTER or ALLOCATABLE dummy or actual argument is polymorphic, both should be so"_port_en_US);
- }
+ foldingContext.Warn(common::LanguageFeature::RelaxedIntentInChecking,
+ "If a POINTER or ALLOCATABLE dummy or actual argument is polymorphic, both should be so"_port_en_US);
} else if (actualIsPolymorphic &&
context.IsEnabled(common::LanguageFeature::
PolymorphicActualAllocatableOrPointerToMonomorphicDummy)) {
- if (context.ShouldWarn(common::LanguageFeature::
- PolymorphicActualAllocatableOrPointerToMonomorphicDummy)) {
- messages.Say(
- common::LanguageFeature::
- PolymorphicActualAllocatableOrPointerToMonomorphicDummy,
- "If a POINTER or ALLOCATABLE actual argument is polymorphic, the corresponding dummy argument should also be so"_port_en_US);
- }
+ foldingContext.Warn(
+ common::LanguageFeature::
+ PolymorphicActualAllocatableOrPointerToMonomorphicDummy,
+ "If a POINTER or ALLOCATABLE actual argument is polymorphic, the corresponding dummy argument should also be so"_port_en_US);
} else {
checkTypeCompatibility = false;
messages.Say(
@@ -916,11 +908,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
if (dummy.intent == common::Intent::In &&
context.IsEnabled(
common::LanguageFeature::RelaxedIntentInChecking)) {
- if (context.ShouldWarn(
- common::LanguageFeature::RelaxedIntentInChecking)) {
- messages.Say(common::LanguageFeature::RelaxedIntentInChecking,
- "POINTER or ALLOCATABLE dummy and actual arguments should have the same declared type and kind"_port_en_US);
- }
+ foldingContext.Warn(common::LanguageFeature::RelaxedIntentInChecking,
+ "POINTER or ALLOCATABLE dummy and actual arguments should have the same declared type and kind"_port_en_US);
} else {
messages.Say(
"POINTER or ALLOCATABLE dummy and actual arguments must have the same declared type and kind"_err_en_US);
@@ -991,13 +980,13 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
bool actualIsTemp{
!actualIsVariable || HasVectorSubscript(actual) || actualCoarrayRef};
if (actualIsTemp) {
- messages.Say(common::UsageWarning::NonTargetPassedToTarget,
+ foldingContext.Warn(common::UsageWarning::NonTargetPassedToTarget,
"Any pointer associated with TARGET %s during this call will not be associated with the value of '%s' afterwards"_warn_en_US,
dummyName, actual.AsFortran());
} else {
auto actualSymbolVector{GetSymbolVector(actual)};
if (!evaluate::GetLastTarget(actualSymbolVector)) {
- messages.Say(common::UsageWarning::NonTargetPassedToTarget,
+ foldingContext.Warn(common::UsageWarning::NonTargetPassedToTarget,
"Any pointer associated with TARGET %s during this call must not be used afterwards, as '%s' is not a target"_warn_en_US,
dummyName, actual.AsFortran());
}
@@ -1058,12 +1047,11 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
dummyName);
}
}
- std::optional<std::string> warning;
bool isHostDeviceProc{procedure.cudaSubprogramAttrs &&
*procedure.cudaSubprogramAttrs ==
common::CUDASubprogramAttrs::HostDevice};
if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
- dummy.ignoreTKR, &warning, /*allowUnifiedMatchingRule=*/true,
+ dummy.ignoreTKR, /*allowUnifiedMatchingRule=*/true,
isHostDeviceProc, &context.languageFeatures())) {
auto toStr{[](std::optional<common::CUDADataAttr> x) {
return x ? "ATTRIBUTES("s +
@@ -1074,10 +1062,6 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
"%s has %s but its associated actual argument has %s"_err_en_US,
dummyName, toStr(dummyDataAttr), toStr(actualDataAttr));
}
- if (warning && context.ShouldWarn(common::UsageWarning::CUDAUsage)) {
- messages.Say(common::UsageWarning::CUDAUsage, "%s"_warn_en_US,
- std::move(*warning));
- }
}
// Warning for breaking F'2023 change with character allocatables
@@ -1131,9 +1115,8 @@ static void CheckProcedureArg(evaluate::ActualArgument &arg,
evaluate::SayWithDeclaration(messages, *argProcSymbol,
"Procedure binding '%s' passed as an actual argument"_err_en_US,
argProcSymbol->name());
- } else if (context.ShouldWarn(
- common::LanguageFeature::BindingAsProcedure)) {
- evaluate::SayWithDeclaration(messages, *argProcSymbol,
+ } else {
+ evaluate::WarnWithDeclaration(foldingContext, *argProcSymbol,
common::LanguageFeature::BindingAsProcedure,
"Procedure binding '%s' passed as an actual argument"_port_en_US,
argProcSymbol->name());
@@ -1185,15 +1168,14 @@ static void CheckProcedureArg(evaluate::ActualArgument &arg,
messages.Say(
"Actual procedure argument for %s of a PURE procedure must have an explicit interface"_err_en_US,
dummyName);
- } else if (context.ShouldWarn(
- common::UsageWarning::ImplicitInterfaceActual)) {
- messages.Say(common::UsageWarning::ImplicitInterfaceActual,
+ } else {
+ foldingContext.Warn(
+ common::UsageWarning::ImplicitInterfaceActual,
"Actual procedure argument has an implicit interface which is not known to be compatible with %s which has an explicit interface"_warn_en_US,
dummyName);
}
- } else if (warning &&
- context.ShouldWarn(common::UsageWarning::ProcDummyArgShapes)) {
- messages.Say(common::UsageWarning::ProcDummyArgShapes,
+ } else if (warning) {
+ foldingContext.Warn(common::UsageWarning::ProcDummyArgShapes,
"Actual procedure argument has possible interface incompatibility with %s: %s"_warn_en_US,
dummyName, std::move(*warning));
}
@@ -1368,16 +1350,14 @@ static void CheckExplicitInterfaceArg(evaluate::ActualArgument &arg,
messages.Say(
"NULL() actual argument '%s' may not be associated with allocatable dummy argument %s that is INTENT(OUT) or INTENT(IN OUT)"_err_en_US,
expr->AsFortran(), dummyName);
- } else if (object.intent == common::Intent::Default &&
- context.ShouldWarn(common::UsageWarning::
- NullActualForDefaultIntentAllocatable)) {
- messages.Say(common::UsageWarning::
- NullActualForDefaultIntentAllocatable,
+ } else if (object.intent == common::Intent::Default) {
+ foldingContext.Warn(
+ common::UsageWarning::
+ NullActualForDefaultIntentAllocatable,
"NULL() actual argument '%s' should not be associated with allocatable dummy argument %s without INTENT(IN)"_warn_en_US,
expr->AsFortran(), dummyName);
- } else if (context.ShouldWarn(common::LanguageFeature::
- NullActualForAllocatable)) {
- messages.Say(
+ } else {
+ foldingContext.Warn(
common::LanguageFeature::NullActualForAllocatable,
"Allocatable %s is associated with %s"_port_en_US,
dummyName, expr->AsFortran());
@@ -1395,8 +1375,7 @@ static void CheckExplicitInterfaceArg(evaluate::ActualArgument &arg,
assumed.name(), dummyName);
} else if (object.type.attrs().test(characteristics::
TypeAndShape::Attr::AssumedRank) &&
- !IsAssumedShape(assumed) &&
- !evaluate::IsAssumedRank(assumed)) {
+ !IsAssumedShape(assumed) && !IsAssumedRank(assumed)) {
messages.Say( // C711
"Assumed-type '%s' must be either assumed shape or assumed rank to be associated with assumed rank %s"_err_en_US,
assumed.name(), dummyName);
@@ -1567,7 +1546,7 @@ static void CheckAssociated(evaluate::ActualArguments &arguments,
if (semanticsContext.ShouldWarn(common::UsageWarning::Portability)) {
if (!evaluate::ExtractDataRef(*pointerExpr) &&
!evaluate::IsProcedurePointer(*pointerExpr)) {
- messages.Say(common::UsageWarning::Portability,
+ foldingContext.Warn(common::UsageWarning::Portability,
pointerArg->sourceLocation(),
"POINTER= argument of ASSOCIATED() is required by some other compilers to be a pointer"_port_en_US);
} else if (scope && !evaluate::UnwrapProcedureRef(*pointerExpr)) {
@@ -1578,7 +1557,8 @@ static void CheckAssociated(evaluate::ActualArguments &arguments,
DefinabilityFlag::DoNotNoteDefinition},
*pointerExpr)}) {
if (whyNot->IsFatal()) {
- if (auto *msg{messages.Say(common::UsageWarning::Portability,
+ if (auto *msg{foldingContext.Warn(
+ common::UsageWarning::Portability,
pointerArg->sourceLocation(),
"POINTER= argument of ASSOCIATED() is required by some other compilers to be a valid left-hand side of a pointer assignment statement"_port_en_US)}) {
msg->Attach(std::move(
@@ -2005,8 +1985,9 @@ static void CheckReduce(
}
}
}
- const auto *result{
- procChars ? procChars->functionResult->GetTypeAndShape() : nullptr};
+ const auto *result{procChars && procChars->functionResult
+ ? procChars->functionResult->GetTypeAndShape()
+ : nullptr};
if (!procChars || !procChars->IsPure() ||
procChars->dummyArguments.size() != 2 || !procChars->functionResult) {
messages.Say(
@@ -2092,10 +2073,8 @@ static void CheckReduce(
// TRANSFER (16.9.193)
static void CheckTransferOperandType(SemanticsContext &context,
const evaluate::DynamicType &type, const char *which) {
- if (type.IsPolymorphic() &&
- context.ShouldWarn(common::UsageWarning::PolymorphicTransferArg)) {
- context.foldingContext().messages().Say(
- common::UsageWarning::PolymorphicTransferArg,
+ if (type.IsPolymorphic()) {
+ context.foldingContext().Warn(common::UsageWarning::PolymorphicTransferArg,
"%s of TRANSFER is polymorphic"_warn_en_US, which);
} else if (!type.IsUnlimitedPolymorphic() &&
type.category() == TypeCategory::Derived &&
@@ -2103,7 +2082,7 @@ static void CheckTransferOperandType(SemanticsContext &context,
DirectComponentIterator directs{type.GetDerivedTypeSpec()};
if (auto bad{std::find_if(directs.begin(), directs.end(), IsDescriptor)};
bad != directs.end()) {
- evaluate::SayWithDeclaration(context.foldingContext().messages(), *bad,
+ evaluate::WarnWithDeclaration(context.foldingContext(), *bad,
common::UsageWarning::PointerComponentTransferArg,
"%s of TRANSFER contains allocatable or pointer component %s"_warn_en_US,
which, bad.BuildResultDesignatorName());
@@ -2133,8 +2112,8 @@ static void CheckTransfer(evaluate::ActualArguments &arguments,
messages.Say(
"Element size of MOLD= array may not be zero when SOURCE= is not empty"_err_en_US);
}
- } else if (context.ShouldWarn(common::UsageWarning::VoidMold)) {
- messages.Say(common::UsageWarning::VoidMold,
+ } else {
+ foldingContext.Warn(common::UsageWarning::VoidMold,
"Element size of MOLD= array may not be zero unless SOURCE= is empty"_warn_en_US);
}
}
@@ -2150,7 +2129,7 @@ static void CheckTransfer(evaluate::ActualArguments &arguments,
} else if (context.ShouldWarn(
common::UsageWarning::TransferSizePresence) &&
IsAllocatableOrObjectPointer(whole)) {
- messages.Say(common::UsageWarning::TransferSizePresence,
+ foldingContext.Warn(common::UsageWarning::TransferSizePresence,
"SIZE= argument that is allocatable or pointer must be present at execution; parenthesize to silence this warning"_warn_en_US);
}
}
@@ -2373,13 +2352,10 @@ bool CheckArguments(const characteristics::Procedure &proc,
/*extentErrors=*/true, ignoreImplicitVsExplicit)};
if (!buffer.empty()) {
if (treatingExternalAsImplicit) {
- if (context.ShouldWarn(
- common::UsageWarning::KnownBadImplicitInterface)) {
- if (auto *msg{messages.Say(
- common::UsageWarning::KnownBadImplicitInterface,
- "If the procedure's interface were explicit, this reference would be in error"_warn_en_US)}) {
- buffer.AttachTo(*msg, parser::Severity::Because);
- }
+ if (auto *msg{foldingContext.Warn(
+ common::UsageWarning::KnownBadImplicitInterface,
+ "If the procedure's interface were explicit, this reference would be in error"_warn_en_US)}) {
+ buffer.AttachTo(*msg, parser::Severity::Because);
} else {
buffer.clear();
}
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index d769f22..84edceb 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -130,21 +130,14 @@ private:
}
template <typename FeatureOrUsageWarning, typename... A>
parser::Message *Warn(FeatureOrUsageWarning warning, A &&...x) {
- if (!context_.ShouldWarn(warning) || InModuleFile()) {
- return nullptr;
- } else {
- return messages_.Say(warning, std::forward<A>(x)...);
- }
+ return messages_.Warn(InModuleFile(), context_.languageFeatures(), warning,
+ std::forward<A>(x)...);
}
template <typename FeatureOrUsageWarning, typename... A>
parser::Message *Warn(
FeatureOrUsageWarning warning, parser::CharBlock source, A &&...x) {
- if (!context_.ShouldWarn(warning) ||
- FindModuleFileContaining(context_.FindScope(source))) {
- return nullptr;
- } else {
- return messages_.Say(warning, source, std::forward<A>(x)...);
- }
+ return messages_.Warn(FindModuleFileContaining(context_.FindScope(source)),
+ context_.languageFeatures(), warning, source, std::forward<A>(x)...);
}
bool IsResultOkToDiffer(const FunctionResult &);
void CheckGlobalName(const Symbol &);
@@ -326,7 +319,7 @@ void CheckHelper::Check(const Symbol &symbol) {
!IsDummy(symbol)) {
if (context_.IsEnabled(
common::LanguageFeature::IgnoreIrrelevantAttributes)) {
- context_.Warn(common::LanguageFeature::IgnoreIrrelevantAttributes,
+ Warn(common::LanguageFeature::IgnoreIrrelevantAttributes,
"Only a dummy argument should have an INTENT, VALUE, or OPTIONAL attribute"_warn_en_US);
} else {
messages_.Say(
@@ -633,7 +626,7 @@ void CheckHelper::CheckValue(
"VALUE attribute may not apply to a type with a coarray ultimate component"_err_en_US);
}
}
- if (evaluate::IsAssumedRank(symbol)) {
+ if (IsAssumedRank(symbol)) {
messages_.Say(
"VALUE attribute may not apply to an assumed-rank array"_err_en_US);
}
@@ -743,7 +736,7 @@ void CheckHelper::CheckObjectEntity(
"Coarray '%s' may not have type TEAM_TYPE, C_PTR, or C_FUNPTR"_err_en_US,
symbol.name());
}
- if (evaluate::IsAssumedRank(symbol)) {
+ if (IsAssumedRank(symbol)) {
messages_.Say("Coarray '%s' may not be an assumed-rank array"_err_en_US,
symbol.name());
}
@@ -889,7 +882,7 @@ void CheckHelper::CheckObjectEntity(
"!DIR$ IGNORE_TKR may not apply to an allocatable or pointer"_err_en_US);
}
} else if (ignoreTKR.test(common::IgnoreTKR::Rank)) {
- if (ignoreTKR.count() == 1 && evaluate::IsAssumedRank(symbol)) {
+ if (ignoreTKR.count() == 1 && IsAssumedRank(symbol)) {
Warn(common::UsageWarning::IgnoreTKRUsage,
"!DIR$ IGNORE_TKR(R) is not meaningful for an assumed-rank array"_warn_en_US);
} else if (inExplicitExternalInterface) {
@@ -1214,7 +1207,7 @@ void CheckHelper::CheckObjectEntity(
SayWithDeclaration(symbol,
"Deferred-shape entity of %s type is not supported"_err_en_US,
typeName);
- } else if (evaluate::IsAssumedRank(symbol)) {
+ } else if (IsAssumedRank(symbol)) {
SayWithDeclaration(symbol,
"Assumed rank entity of %s type is not supported"_err_en_US,
typeName);
@@ -2428,7 +2421,7 @@ void CheckHelper::CheckVolatile(const Symbol &symbol,
void CheckHelper::CheckContiguous(const Symbol &symbol) {
if (evaluate::IsVariable(symbol) &&
((IsPointer(symbol) && symbol.Rank() > 0) || IsAssumedShape(symbol) ||
- evaluate::IsAssumedRank(symbol))) {
+ IsAssumedRank(symbol))) {
} else {
parser::MessageFixedText msg{symbol.owner().IsDerivedType()
? "CONTIGUOUS component '%s' should be an array with the POINTER attribute"_port_en_US
@@ -2957,7 +2950,7 @@ static bool IsSubprogramDefinition(const Symbol &symbol) {
static bool IsExternalProcedureDefinition(const Symbol &symbol) {
return IsBlockData(symbol) ||
- (IsSubprogramDefinition(symbol) &&
+ ((IsSubprogramDefinition(symbol) || IsAlternateEntry(&symbol)) &&
(IsExternal(symbol) || symbol.GetBindName()));
}
@@ -3141,16 +3134,14 @@ parser::Messages CheckHelper::WhyNotInteroperableDerivedType(
*dyType, &context_.languageFeatures())
.value_or(false)) {
if (type->category() == DeclTypeSpec::Logical) {
- if (context_.ShouldWarn(common::UsageWarning::LogicalVsCBool)) {
- msgs.Say(common::UsageWarning::LogicalVsCBool, component.name(),
- "A LOGICAL component of an interoperable type should have the interoperable KIND=C_BOOL"_port_en_US);
- }
+ context().Warn(msgs, common::UsageWarning::LogicalVsCBool,
+ component.name(),
+ "A LOGICAL component of an interoperable type should have the interoperable KIND=C_BOOL"_port_en_US);
} else if (type->category() == DeclTypeSpec::Character && dyType &&
dyType->kind() == 1) {
- if (context_.ShouldWarn(common::UsageWarning::BindCCharLength)) {
- msgs.Say(common::UsageWarning::BindCCharLength, component.name(),
- "A CHARACTER component of an interoperable type should have length 1"_port_en_US);
- }
+ context().Warn(msgs, common::UsageWarning::BindCCharLength,
+ component.name(),
+ "A CHARACTER component of an interoperable type should have length 1"_port_en_US);
} else {
msgs.Say(component.name(),
"Each component of an interoperable derived type must have an interoperable type"_err_en_US);
@@ -3165,10 +3156,9 @@ parser::Messages CheckHelper::WhyNotInteroperableDerivedType(
}
}
if (derived->componentNames().empty()) { // F'2023 C1805
- if (context_.ShouldWarn(common::LanguageFeature::EmptyBindCDerivedType)) {
- msgs.Say(common::LanguageFeature::EmptyBindCDerivedType, symbol.name(),
- "A derived type with the BIND attribute should not be empty"_warn_en_US);
- }
+ context().Warn(msgs, common::LanguageFeature::EmptyBindCDerivedType,
+ symbol.name(),
+ "A derived type with the BIND attribute should not be empty"_warn_en_US);
}
}
if (msgs.AnyFatalError()) {
@@ -3218,7 +3208,7 @@ parser::Messages CheckHelper::WhyNotInteroperableObject(
if (derived && !derived->typeSymbol().attrs().test(Attr::BIND_C)) {
if (allowNonInteroperableType) { // portability warning only
evaluate::AttachDeclaration(
- context_.Warn(common::UsageWarning::Portability, symbol.name(),
+ Warn(common::UsageWarning::Portability, symbol.name(),
"The derived type of this interoperable object should be BIND(C)"_port_en_US),
derived->typeSymbol());
} else if (!context_.IsEnabled(
@@ -3260,10 +3250,10 @@ parser::Messages CheckHelper::WhyNotInteroperableObject(
} else if (type->category() == DeclTypeSpec::Logical) {
if (context_.ShouldWarn(common::UsageWarning::LogicalVsCBool)) {
if (IsDummy(symbol)) {
- msgs.Say(common::UsageWarning::LogicalVsCBool, symbol.name(),
+ Warn(common::UsageWarning::LogicalVsCBool, symbol.name(),
"A BIND(C) LOGICAL dummy argument should have the interoperable KIND=C_BOOL"_port_en_US);
} else {
- msgs.Say(common::UsageWarning::LogicalVsCBool, symbol.name(),
+ Warn(common::UsageWarning::LogicalVsCBool, symbol.name(),
"A BIND(C) LOGICAL object should have the interoperable KIND=C_BOOL"_port_en_US);
}
}
@@ -3459,7 +3449,7 @@ void CheckHelper::CheckBindC(const Symbol &symbol) {
bool CheckHelper::CheckDioDummyIsData(
const Symbol &subp, const Symbol *arg, std::size_t position) {
if (arg && arg->detailsIf<ObjectEntityDetails>()) {
- if (evaluate::IsAssumedRank(*arg)) {
+ if (IsAssumedRank(*arg)) {
messages_.Say(arg->name(),
"Dummy argument '%s' may not be assumed-rank"_err_en_US, arg->name());
return false;
diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index a5fdabf..f25497e 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -11,13 +11,16 @@
//===----------------------------------------------------------------------===//
#include "check-omp-structure.h"
-#include "openmp-utils.h"
#include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
#include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
+#include "flang/Evaluate/rewrite.h"
#include "flang/Evaluate/tools.h"
#include "flang/Parser/char-block.h"
#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
@@ -42,11 +45,167 @@ using namespace Fortran::semantics::omp;
namespace operation = Fortran::evaluate::operation;
+static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr);
+
template <typename T, typename U>
static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) {
return !(e == f);
}
+namespace {
+template <typename...> struct IsIntegral {
+ static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsIntegral<evaluate::Type<C, K>> {
+ static constexpr bool value{//
+ C == common::TypeCategory::Integer ||
+ C == common::TypeCategory::Unsigned ||
+ C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
+
+template <typename...> struct IsFloatingPoint {
+ static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsFloatingPoint<evaluate::Type<C, K>> {
+ static constexpr bool value{//
+ C == common::TypeCategory::Real || C == common::TypeCategory::Complex};
+};
+
+template <typename T>
+constexpr bool is_floating_point_v{IsFloatingPoint<T>::value};
+
+template <typename T>
+constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>};
+
+template <typename T, typename Op0, typename Op1>
+using ReassocOpBase = evaluate::match::AnyOfPattern< //
+ evaluate::match::Add<T, Op0, Op1>, //
+ evaluate::match::Mul<T, Op0, Op1>>;
+
+template <typename T, typename Op0, typename Op1>
+struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
+ using Base = ReassocOpBase<T, Op0, Op1>;
+ using Base::Base;
+};
+
+template <typename T, typename Op0, typename Op1>
+ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) {
+ return ReassocOp<T, Op0, Op1>(op0, op1);
+}
+} // namespace
+
+struct ReassocRewriter : public evaluate::rewrite::Identity {
+ using Id = evaluate::rewrite::Identity;
+ struct NonIntegralTag {};
+
+ ReassocRewriter(const SomeExpr &atom, const SemanticsContext &context)
+ : atom_(atom), context_(context) {}
+
+ // Try to find cases where the input expression is of the form
+ // (1) (a . b) . c, or
+ // (2) a . (b . c),
+ // where . denotes an associative operation (currently + or *), and a, b, c
+ // are some subexpressions.
+ // If one of the operands in the nested operation is the atomic variable
+ // (with some possible type conversions applied to it), bring it to the
+ // top-level operation, and move the top-level operand into the nested
+ // operation.
+ // For example, assuming x is the atomic variable:
+ // (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
+ template <typename T, typename U,
+ typename = std::enable_if_t<is_numeric_v<T>>>
+ evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
+ if constexpr (is_floating_point_v<T>) {
+ if (!context_.langOptions().AssociativeMath) {
+ return Id::operator()(std::move(x), u);
+ }
+ }
+ // As per the above comment, there are 3 subexpressions involved in this
+ // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is
+ // same as U, plus it will store a pointer (ref) to the matched expression.
+ // When the match is successful, the sub[i].ref will point to a, b, x (in
+ // some order) from the example above.
+ evaluate::match::Expr<T> sub[3];
+ auto inner{reassocOp<T>(sub[0], sub[1])};
+ auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
+ auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+#if !defined(__clang__) && !defined(_MSC_VER) && \
+ (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
+ // If GCC version < 8.5, use this definition. For the other definition
+ // (which is equivalent), GCC 7.5 emits a somewhat cryptic error:
+ // use of ‘outer1’ before deduction of ‘auto’
+ // inside of the visitor function in common::visit.
+ // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming
+ // that this is some kind of a GCC issue.
+ using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>;
+#else
+ using MatchTypes = typename decltype(outer1)::MatchTypes;
+#endif
+ // There is no way to ensure that the outer operation is the same as
+ // the inner one. They are matched independently, so we need to compare
+ // the index in the member variant that represents the matched type.
+ if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) ||
+ (match(outer2, x) && outer2.ref.index() == inner.ref.index())) {
+ size_t atomIdx{[&]() { // sub[atomIdx] will be the atom.
+ size_t idx;
+ for (idx = 0; idx != 3; ++idx) {
+ if (IsAtom(*sub[idx].ref)) {
+ break;
+ }
+ }
+ return idx;
+ }()};
+
+ if (atomIdx > 2) {
+ return Id::operator()(std::move(x), u);
+ }
+ return common::visit(
+ [&](auto &&s) {
+ using Expr = evaluate::Expr<T>;
+ using TypeS = llvm::remove_cvref_t<decltype(s)>;
+ // This visitor has to be semantically correct for all possible
+ // types of s even though at runtime s will only be one of the
+ // matched types.
+ // Limit the construction to the operation types that we tried
+ // to match (otherwise TypeS(op1, op2) would fail for non-binary
+ // operations).
+ if constexpr (common::HasMember<TypeS, MatchTypes>) {
+ Expr atom{*sub[atomIdx].ref};
+ Expr op1{*sub[(atomIdx + 1) % 3].ref};
+ Expr op2{*sub[(atomIdx + 2) % 3].ref};
+ return Expr(
+ TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
+ } else {
+ return Expr(TypeS(s));
+ }
+ },
+ evaluate::match::deparen(x).u);
+ }
+ return Id::operator()(std::move(x), u);
+ }
+
+ template <typename T, typename U,
+ typename = std::enable_if_t<!is_numeric_v<T>>>
+ evaluate::Expr<T> operator()(
+ evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
+ return Id::operator()(std::move(x), u);
+ }
+
+private:
+ template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
+ return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
+ }
+
+ const SomeExpr &atom_;
+ const SemanticsContext &context_;
+};
+
struct AnalyzedCondStmt {
SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
parser::CharBlock source;
@@ -196,6 +355,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
llvm_unreachable("Could not find assignment operator");
}
+static std::vector<SomeExpr> GetNonAtomExpressions(
+ const SomeExpr &atom, const std::vector<SomeExpr> &exprs) {
+ std::vector<SomeExpr> nonAtom;
+ for (const SomeExpr &e : exprs) {
+ if (!IsSameOrConvertOf(e, atom)) {
+ nonAtom.push_back(e);
+ }
+ }
+ return nonAtom;
+}
+
+static std::vector<SomeExpr> GetNonAtomArguments(
+ const SomeExpr &atom, const SomeExpr &expr) {
+ if (auto &&maybe{GetConvertInput(expr)}) {
+ return GetNonAtomExpressions(
+ atom, GetTopLevelOperationIgnoreResizing(*maybe).second);
+ }
+ return {};
+}
+
static bool IsCheckForAssociated(const SomeExpr &cond) {
return GetTopLevelOperationIgnoreResizing(cond).first ==
operation::Operator::Associated;
@@ -222,47 +401,85 @@ static void SetAssignment(parser::AssignmentStmt::TypedAssignment &assign,
}
}
-static parser::OpenMPAtomicConstruct::Analysis::Op MakeAtomicAnalysisOp(
- int what,
- const std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) {
- parser::OpenMPAtomicConstruct::Analysis::Op operation;
- operation.what = what;
- SetAssignment(operation.assign, maybeAssign);
- return operation;
-}
+namespace {
+struct AtomicAnalysis {
+ AtomicAnalysis(const SomeExpr &atom, const MaybeExpr &cond = std::nullopt)
+ : atom_(atom), cond_(cond) {}
-static parser::OpenMPAtomicConstruct::Analysis MakeAtomicAnalysis(
- const SomeExpr &atom, const MaybeExpr &cond,
- parser::OpenMPAtomicConstruct::Analysis::Op &&op0,
- parser::OpenMPAtomicConstruct::Analysis::Op &&op1) {
- // Defined in flang/include/flang/Parser/parse-tree.h
- //
- // struct Analysis {
- // struct Kind {
- // static constexpr int None = 0;
- // static constexpr int Read = 1;
- // static constexpr int Write = 2;
- // static constexpr int Update = Read | Write;
- // static constexpr int Action = 3; // Bits containing N, R, W, U
- // static constexpr int IfTrue = 4;
- // static constexpr int IfFalse = 8;
- // static constexpr int Condition = 12; // Bits containing IfTrue, IfFalse
- // };
- // struct Op {
- // int what;
- // TypedAssignment assign;
- // };
- // TypedExpr atom, cond;
- // Op op0, op1;
- // };
-
- parser::OpenMPAtomicConstruct::Analysis an;
- SetExpr(an.atom, atom);
- SetExpr(an.cond, cond);
- an.op0 = std::move(op0);
- an.op1 = std::move(op1);
- return an;
-}
+ AtomicAnalysis &addOp0(int what,
+ const std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) {
+ return addOp(op0_, what, maybeAssign);
+ }
+ AtomicAnalysis &addOp1(int what,
+ const std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) {
+ return addOp(op1_, what, maybeAssign);
+ }
+
+ operator parser::OpenMPAtomicConstruct::Analysis() const {
+ // Defined in flang/include/flang/Parser/parse-tree.h
+ //
+ // struct Analysis {
+ // struct Kind {
+ // static constexpr int None = 0;
+ // static constexpr int Read = 1;
+ // static constexpr int Write = 2;
+ // static constexpr int Update = Read | Write;
+ // static constexpr int Action = 3; // Bits containing None, Read,
+ // // Write, Update
+ // static constexpr int IfTrue = 4;
+ // static constexpr int IfFalse = 8;
+ // static constexpr int Condition = 12; // Bits containing IfTrue,
+ // // IfFalse
+ // };
+ // struct Op {
+ // int what;
+ // TypedAssignment assign;
+ // };
+ // TypedExpr atom, cond;
+ // Op op0, op1;
+ // };
+
+ parser::OpenMPAtomicConstruct::Analysis an;
+ SetExpr(an.atom, atom_);
+ SetExpr(an.cond, cond_);
+ an.op0 = std::move(op0_);
+ an.op1 = std::move(op1_);
+ return an;
+ }
+
+private:
+ struct Op {
+ operator parser::OpenMPAtomicConstruct::Analysis::Op() const {
+ parser::OpenMPAtomicConstruct::Analysis::Op op;
+ op.what = what;
+ SetAssignment(op.assign, assign);
+ return op;
+ }
+
+ int what;
+ std::optional<evaluate::Assignment> assign;
+ };
+
+ AtomicAnalysis &addOp(Op &op, int what,
+ const std::optional<evaluate::Assignment> &maybeAssign) {
+ op.what = what;
+ if (maybeAssign) {
+ if (MaybeExpr rewritten{PostSemaRewrite(atom_, maybeAssign->rhs)}) {
+ op.assign = evaluate::Assignment(
+ AsRvalue(maybeAssign->lhs), std::move(*rewritten));
+ op.assign->u = std::move(maybeAssign->u);
+ } else {
+ op.assign = *maybeAssign;
+ }
+ }
+ return *this;
+ }
+
+ const SomeExpr &atom_;
+ const MaybeExpr &cond_;
+ Op op0_, op1_;
+};
+} // namespace
/// Check if `expr` satisfies the following conditions for x and v:
///
@@ -535,6 +752,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
const evaluate::Assignment &capture, const SomeExpr &atom,
parser::CharBlock source) {
auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+ (void)lsrc;
const SomeExpr &cap{capture.lhs};
if (!IsVarOrFunctionRef(atom)) {
@@ -551,6 +769,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment(
void OmpStructureChecker::CheckAtomicReadAssignment(
const evaluate::Assignment &read, parser::CharBlock source) {
auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+ (void)lsrc;
if (auto maybe{GetConvertInput(read.rhs)}) {
const SomeExpr &atom{*maybe};
@@ -584,7 +803,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment(
}
}
-void OmpStructureChecker::CheckAtomicUpdateAssignment(
+std::optional<evaluate::Assignment>
+OmpStructureChecker::CheckAtomicUpdateAssignment(
const evaluate::Assignment &update, parser::CharBlock source) {
// [6.0:191:1-7]
// An update structured block is update-statement, an update statement
@@ -600,14 +820,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
if (!IsVarOrFunctionRef(atom)) {
ErrorShouldBeVariable(atom, rsrc);
// Skip other checks.
- return;
+ return std::nullopt;
}
CheckAtomicVariable(atom, lsrc);
+ auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs(
+ atom, update.rhs, source, /*suppressDiagnostics=*/true)};
+
+ if (!hasErrors) {
+ CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source);
+ return std::nullopt;
+ } else if (tryReassoc) {
+ ReassocRewriter ra(atom, context_);
+ SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)};
+
+ std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs(
+ atom, raRhs, source, /*suppressDiagnostics=*/true);
+ if (!hasErrors) {
+ CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source);
+
+ evaluate::Assignment raAssign(update);
+ raAssign.rhs = raRhs;
+ return raAssign;
+ }
+ }
+
+ // This is guaranteed to report errors.
+ CheckAtomicUpdateAssignmentRhs(
+ atom, update.rhs, source, /*suppressDiagnostics=*/false);
+ return std::nullopt;
+}
+
+std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs(
+ const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source,
+ bool suppressDiagnostics) {
+ auto [lsrc, rsrc]{SplitAssignmentSource(source)};
+ (void)lsrc;
+
std::pair<operation::Operator, std::vector<SomeExpr>> top{
operation::Operator::Unknown, {}};
- if (auto &&maybeInput{GetConvertInput(update.rhs)}) {
+ if (auto &&maybeInput{GetConvertInput(rhs)}) {
top = GetTopLevelOperationIgnoreResizing(*maybeInput);
}
switch (top.first) {
@@ -624,29 +877,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
case operation::Operator::Identity:
break;
case operation::Operator::Call:
- context_.Say(source,
- "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
- return;
+ if (!suppressDiagnostics) {
+ context_.Say(source,
+ "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US);
+ }
+ return std::make_pair(true, false);
case operation::Operator::Convert:
- context_.Say(source,
- "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
- return;
+ if (!suppressDiagnostics) {
+ context_.Say(source,
+ "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US);
+ }
+ return std::make_pair(true, false);
case operation::Operator::Intrinsic:
- context_.Say(source,
- "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
- return;
+ if (!suppressDiagnostics) {
+ context_.Say(source,
+ "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US);
+ }
+ return std::make_pair(true, false);
case operation::Operator::Constant:
case operation::Operator::Unknown:
- context_.Say(
- source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
- return;
+ if (!suppressDiagnostics) {
+ context_.Say(
+ source, "This is not a valid ATOMIC UPDATE operation"_err_en_US);
+ }
+ return std::make_pair(true, false);
default:
assert(
top.first != operation::Operator::Identity && "Handle this separately");
- context_.Say(source,
- "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
- operation::ToString(top.first));
- return;
+ if (!suppressDiagnostics) {
+ context_.Say(source,
+ "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US,
+ operation::ToString(top.first));
+ }
+ return std::make_pair(true, false);
}
// Check how many times `atom` occurs as an argument, if it's a subexpression
// of an argument, and collect the non-atom arguments.
@@ -667,39 +930,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment(
return count;
}()};
- bool hasError{false};
+ bool hasError{false}, tryReassoc{false};
if (subExpr) {
- context_.Say(rsrc,
- "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
- atom.AsFortran(), subExpr->AsFortran());
+ if (!suppressDiagnostics) {
+ context_.Say(rsrc,
+ "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US,
+ atom.AsFortran(), subExpr->AsFortran());
+ }
hasError = true;
}
if (top.first == operation::Operator::Identity) {
// This is "x = y".
assert((atomCount == 0 || atomCount == 1) && "Unexpected count");
if (atomCount == 0) {
- context_.Say(rsrc,
- "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
- atom.AsFortran());
+ if (!suppressDiagnostics) {
+ context_.Say(rsrc,
+ "The atomic variable %s should appear as an argument in the update operation"_err_en_US,
+ atom.AsFortran());
+ }
hasError = true;
}
} else {
if (atomCount == 0) {
- context_.Say(rsrc,
- "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
- atom.AsFortran(), operation::ToString(top.first));
+ if (!suppressDiagnostics) {
+ context_.Say(rsrc,
+ "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US,
+ atom.AsFortran(), operation::ToString(top.first));
+ }
+ // If `atom` is a proper subexpression, and is not present as an
+ // argument on its own, reassociation may be able to help.
+ tryReassoc = subExpr.has_value();
hasError = true;
} else if (atomCount > 1) {
- context_.Say(rsrc,
- "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
- atom.AsFortran(), operation::ToString(top.first));
+ if (!suppressDiagnostics) {
+ context_.Say(rsrc,
+ "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US,
+ atom.AsFortran(), operation::ToString(top.first));
+ }
hasError = true;
}
}
- if (!hasError) {
- CheckStorageOverlap(atom, nonAtom, source);
- }
+ return std::make_pair(hasError, tryReassoc);
}
void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment(
@@ -802,12 +1074,14 @@ void OmpStructureChecker::CheckAtomicUpdateOnly(
SourcedActionStmt action{GetActionStmt(&body.front())};
if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) {
const SomeExpr &atom{maybeUpdate->lhs};
- CheckAtomicUpdateAssignment(*maybeUpdate, action.source);
+ auto maybeAssign{
+ CheckAtomicUpdateAssignment(*maybeUpdate, action.source)};
+ auto &updateAssign{maybeAssign.has_value() ? maybeAssign : maybeUpdate};
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
- x.analysis = MakeAtomicAnalysis(atom, std::nullopt,
- MakeAtomicAnalysisOp(Analysis::Update, maybeUpdate),
- MakeAtomicAnalysisOp(Analysis::None));
+ x.analysis = AtomicAnalysis(atom)
+ .addOp0(Analysis::Update, updateAssign)
+ .addOp1(Analysis::None);
} else if (!IsAssignment(action.stmt)) {
context_.Say(
source, "ATOMIC UPDATE operation should be an assignment"_err_en_US);
@@ -889,9 +1163,11 @@ void OmpStructureChecker::CheckAtomicConditionalUpdate(
}
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
- x.analysis = MakeAtomicAnalysis(assign.lhs, update.cond,
- MakeAtomicAnalysisOp(Analysis::Update | Analysis::IfTrue, assign),
- MakeAtomicAnalysisOp(Analysis::None));
+ const SomeExpr &atom{assign.lhs};
+
+ x.analysis = AtomicAnalysis(atom, update.cond)
+ .addOp0(Analysis::Update | Analysis::IfTrue, assign)
+ .addOp1(Analysis::None);
}
void OmpStructureChecker::CheckAtomicUpdateCapture(
@@ -920,29 +1196,32 @@ void OmpStructureChecker::CheckAtomicUpdateCapture(
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
int action;
+ std::optional<evaluate::Assignment> updateAssign{update};
if (IsMaybeAtomicWrite(update)) {
action = Analysis::Write;
CheckAtomicWriteAssignment(update, uact.source);
} else {
action = Analysis::Update;
- CheckAtomicUpdateAssignment(update, uact.source);
+ if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) {
+ updateAssign = maybe;
+ }
}
CheckAtomicCaptureAssignment(capture, atom, cact.source);
- if (IsPointerAssignment(update) != IsPointerAssignment(capture)) {
+ if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) {
context_.Say(cact.source,
"The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US);
return;
}
if (GetActionStmt(&body.front()).stmt == uact.stmt) {
- x.analysis = MakeAtomicAnalysis(atom, std::nullopt,
- MakeAtomicAnalysisOp(action, update),
- MakeAtomicAnalysisOp(Analysis::Read, capture));
+ x.analysis = AtomicAnalysis(atom)
+ .addOp0(action, updateAssign)
+ .addOp1(Analysis::Read, capture);
} else {
- x.analysis = MakeAtomicAnalysis(atom, std::nullopt,
- MakeAtomicAnalysisOp(Analysis::Read, capture),
- MakeAtomicAnalysisOp(action, update));
+ x.analysis = AtomicAnalysis(atom)
+ .addOp0(Analysis::Read, capture)
+ .addOp1(action, updateAssign);
}
}
@@ -1087,15 +1366,16 @@ void OmpStructureChecker::CheckAtomicConditionalUpdateCapture(
evaluate::Assignment updAssign{*GetEvaluateAssignment(update.ift.stmt)};
evaluate::Assignment capAssign{*GetEvaluateAssignment(capture.stmt)};
+ const SomeExpr &atom{updAssign.lhs};
if (captureFirst) {
- x.analysis = MakeAtomicAnalysis(updAssign.lhs, update.cond,
- MakeAtomicAnalysisOp(Analysis::Read | captureWhen, capAssign),
- MakeAtomicAnalysisOp(Analysis::Write | updateWhen, updAssign));
+ x.analysis = AtomicAnalysis(atom, update.cond)
+ .addOp0(Analysis::Read | captureWhen, capAssign)
+ .addOp1(Analysis::Write | updateWhen, updAssign);
} else {
- x.analysis = MakeAtomicAnalysis(updAssign.lhs, update.cond,
- MakeAtomicAnalysisOp(Analysis::Write | updateWhen, updAssign),
- MakeAtomicAnalysisOp(Analysis::Read | captureWhen, capAssign));
+ x.analysis = AtomicAnalysis(atom, update.cond)
+ .addOp0(Analysis::Write | updateWhen, updAssign)
+ .addOp1(Analysis::Read | captureWhen, capAssign);
}
}
@@ -1125,9 +1405,9 @@ void OmpStructureChecker::CheckAtomicRead(
if (auto maybe{GetConvertInput(maybeRead->rhs)}) {
const SomeExpr &atom{*maybe};
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
- x.analysis = MakeAtomicAnalysis(atom, std::nullopt,
- MakeAtomicAnalysisOp(Analysis::Read, maybeRead),
- MakeAtomicAnalysisOp(Analysis::None));
+ x.analysis = AtomicAnalysis(atom)
+ .addOp0(Analysis::Read, maybeRead)
+ .addOp1(Analysis::None);
}
} else if (!IsAssignment(action.stmt)) {
context_.Say(
@@ -1159,9 +1439,9 @@ void OmpStructureChecker::CheckAtomicWrite(
CheckAtomicWriteAssignment(*maybeWrite, action.source);
using Analysis = parser::OpenMPAtomicConstruct::Analysis;
- x.analysis = MakeAtomicAnalysis(atom, std::nullopt,
- MakeAtomicAnalysisOp(Analysis::Write, maybeWrite),
- MakeAtomicAnalysisOp(Analysis::None));
+ x.analysis = AtomicAnalysis(atom)
+ .addOp0(Analysis::Write, maybeWrite)
+ .addOp1(Analysis::None);
} else if (!IsAssignment(action.stmt)) {
context_.Say(
x.source, "ATOMIC WRITE operation should be an assignment"_err_en_US);
@@ -1260,4 +1540,118 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) {
dirContext_.pop_back();
}
+// Rewrite min/max:
+// Min and max intrinsics in Fortran take an arbitrary number of arguments
+// (two or more). The first two are mandatory, the rest are optional. That
+// means that arguments beyond the first two may be optional dummy arguments
+// of the enclosing procedure. In that case, a reference to such an argument
+// will cause a presence test to be emitted, which cannot go inside of the
+// atomic operation. Since the atom operand must be present, rewrite the
+// min/max operation in a way that avoids presence tests in the atomic code.
+// For example, in
+// subroutine f(atom, x, y, z)
+// integer :: atom, x
+// integer, optional :: y, z
+// !$omp atomic update
+// atom = min(atom, x, y, z)
+// end
+// the min operation will become
+// atom = min(atom, min(x, y, z))
+// and in the final code
+// // Presence check is fine here.
+// tmp = min(x, y, z)
+// atomic update {
+// // Both operands are mandatory, no presence check needed.
+// atom = min(atom, tmp)
+// }
+struct MinMaxRewriter : public evaluate::rewrite::Identity {
+ using Id = evaluate::rewrite::Identity;
+ using Id::operator();
+
+ MinMaxRewriter(const SomeExpr &atom) : atom_(atom) {}
+
+ static bool IsMinMax(const evaluate::ProcedureDesignator &p) {
+ if (auto *intrin{p.GetSpecificIntrinsic()}) {
+ return intrin->name == "min" || intrin->name == "max";
+ }
+ return false;
+ }
+
+ // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
+ // One of the a_i's, say a_t, must be the atom.
+ // Generate
+ // min/max(a_t, min/max(a0, a1, ... [except a_t]))
+ template <typename T>
+ evaluate::Expr<T> operator()(
+ evaluate::Expr<T> &&x, const evaluate::FunctionRef<T> &f) {
+ const evaluate::ProcedureDesignator &proc = f.proc();
+ if (!IsMinMax(proc) || f.arguments().size() <= 2) {
+ return Id::operator()(std::move(x), f);
+ }
+
+ // Collect arguments as SomeExpr's and find out which argument
+ // corresponds to atom.
+ const SomeExpr *atomArg{nullptr};
+ std::vector<const SomeExpr *> args;
+ for (const std::optional<evaluate::ActualArgument> &a : f.arguments()) {
+ if (!a) {
+ continue;
+ }
+ if (const SomeExpr *e{a->UnwrapExpr()}) {
+ if (evaluate::IsSameOrConvertOf(*e, atom_)) {
+ atomArg = e;
+ }
+ args.push_back(e);
+ }
+ }
+ if (!atomArg) {
+ return Id::operator()(std::move(x), f);
+ }
+
+ evaluate::ActualArguments nonAtoms;
+
+ auto AsActual = [](const SomeExpr &z) {
+ SomeExpr copy = z;
+ return evaluate::ActualArgument(std::move(copy));
+ };
+ // Semantic checks guarantee that the "atom" appears exactly once in the
+ // argument list (with potential conversions around it).
+ // For the first two (non-optional) arguments, if "atom" is among them,
+ // replace it with another occurrence of the other non-optional argument.
+ if (atomArg == args[0]) {
+ // (atom, x, y...) -> (x, x, y...)
+ nonAtoms.push_back(AsActual(*args[1]));
+ nonAtoms.push_back(AsActual(*args[1]));
+ } else if (atomArg == args[1]) {
+ // (x, atom, y...) -> (x, x, y...)
+ nonAtoms.push_back(AsActual(*args[0]));
+ nonAtoms.push_back(AsActual(*args[0]));
+ } else {
+ // (x, y, z...) -> unchanged
+ nonAtoms.push_back(AsActual(*args[0]));
+ nonAtoms.push_back(AsActual(*args[1]));
+ }
+
+ // The rest of the arguments are optional, so we can just skip "atom".
+ for (size_t i = 2, e = args.size(); i != e; ++i) {
+ if (atomArg != args[i])
+ nonAtoms.push_back(AsActual(*args[i]));
+ }
+
+ SomeExpr tmp = evaluate::AsGenericExpr(
+ evaluate::FunctionRef<T>(AsRvalue(proc), AsRvalue(nonAtoms)));
+
+ return evaluate::Expr<T>(evaluate::FunctionRef<T>(
+ AsRvalue(proc), {AsActual(*atomArg), AsActual(tmp)}));
+ }
+
+private:
+ const SomeExpr &atom_;
+};
+
+static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr) {
+ MinMaxRewriter rewriter(atom);
+ return evaluate::rewrite::Mutator(rewriter)(expr);
+}
+
} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index 59d57a2..9384e03 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -13,7 +13,6 @@
#include "check-omp-structure.h"
#include "check-directive-structure.h"
-#include "openmp-utils.h"
#include "flang/Common/idioms.h"
#include "flang/Common/visit.h"
@@ -23,6 +22,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/openmp-modifiers.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
@@ -196,7 +196,7 @@ void OmpStructureChecker::CheckSIMDNest(const parser::OpenMPConstruct &c) {
common::visit(
common::visitors{
// Allow `!$OMP ORDERED SIMD`
- [&](const parser::OpenMPBlockConstruct &c) {
+ [&](const parser::OmpBlockConstruct &c) {
const parser::OmpDirectiveSpecification &beginSpec{c.BeginDir()};
if (beginSpec.DirId() == llvm::omp::Directive::OMPD_ordered) {
for (const auto &clause : beginSpec.Clauses().v) {
diff --git a/flang/lib/Semantics/check-omp-metadirective.cpp b/flang/lib/Semantics/check-omp-metadirective.cpp
index 03487da..cf5ea90 100644
--- a/flang/lib/Semantics/check-omp-metadirective.cpp
+++ b/flang/lib/Semantics/check-omp-metadirective.cpp
@@ -12,8 +12,6 @@
#include "check-omp-structure.h"
-#include "openmp-utils.h"
-
#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
#include "flang/Common/visit.h"
@@ -21,6 +19,7 @@
#include "flang/Parser/message.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-modifiers.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/tools.h"
#include "llvm/Frontend/OpenMP/OMP.h"
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index a9c56c3..85d79a00 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -10,7 +10,6 @@
#include "check-directive-structure.h"
#include "definable.h"
-#include "openmp-utils.h"
#include "resolve-names-utils.h"
#include "flang/Common/idioms.h"
@@ -21,12 +20,14 @@
#include "flang/Parser/char-block.h"
#include "flang/Parser/characters.h"
#include "flang/Parser/message.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/openmp-modifiers.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
@@ -57,6 +58,7 @@
namespace Fortran::semantics {
using namespace Fortran::semantics::omp;
+using namespace Fortran::parser::omp;
// Use when clause falls under 'struct OmpClause' in 'parse-tree.h'.
#define CHECK_SIMPLE_CLAUSE(X, Y) \
@@ -141,6 +143,64 @@ private:
parser::CharBlock source_;
};
+// 'OmpWorkdistributeBlockChecker' is used to check the validity of the
+// assignment statements and the expressions enclosed in an OpenMP
+// WORKDISTRIBUTE construct
+class OmpWorkdistributeBlockChecker {
+public:
+ OmpWorkdistributeBlockChecker(
+ SemanticsContext &context, parser::CharBlock source)
+ : context_{context}, source_{source} {}
+
+ template <typename T> bool Pre(const T &) { return true; }
+ template <typename T> void Post(const T &) {}
+
+ bool Pre(const parser::AssignmentStmt &assignment) {
+ const auto &var{std::get<parser::Variable>(assignment.t)};
+ const auto &expr{std::get<parser::Expr>(assignment.t)};
+ const auto *lhs{GetExpr(context_, var)};
+ const auto *rhs{GetExpr(context_, expr)};
+ if (lhs && rhs) {
+ Tristate isDefined{semantics::IsDefinedAssignment(
+ lhs->GetType(), lhs->Rank(), rhs->GetType(), rhs->Rank())};
+ if (isDefined == Tristate::Yes) {
+ context_.Say(expr.source,
+ "Defined assignment statement is not allowed in a WORKDISTRIBUTE construct"_err_en_US);
+ }
+ }
+ return true;
+ }
+
+ bool Pre(const parser::Expr &expr) {
+ if (const auto *e{GetExpr(context_, expr)}) {
+ if (!e)
+ return false;
+ for (const Symbol &symbol : evaluate::CollectSymbols(*e)) {
+ const Symbol &root{GetAssociationRoot(symbol)};
+ if (IsFunction(root)) {
+ std::vector<std::string> attrs;
+ if (!IsElementalProcedure(root)) {
+ attrs.push_back("non-ELEMENTAL");
+ }
+ if (root.attrs().test(Attr::IMPURE)) {
+ attrs.push_back("IMPURE");
+ }
+ std::string attrsStr =
+ attrs.empty() ? "" : " " + llvm::join(attrs, ", ");
+ context_.Say(expr.source,
+ "User defined%s function '%s' is not allowed in a WORKDISTRIBUTE construct"_err_en_US,
+ attrsStr, root.name());
+ }
+ }
+ }
+ return false;
+ }
+
+private:
+ SemanticsContext &context_;
+ parser::CharBlock source_;
+};
+
// `OmpUnitedTaskDesignatorChecker` is used to check if the designator
// can appear within the TASK construct
class OmpUnitedTaskDesignatorChecker {
@@ -208,6 +268,41 @@ bool OmpStructureChecker::CheckAllowedClause(llvmOmpClause clause) {
return CheckAllowed(clause);
}
+void OmpStructureChecker::AnalyzeObject(const parser::OmpObject &object) {
+ if (std::holds_alternative<parser::Name>(object.u)) {
+ // Do not analyze common block names. The analyzer will flag an error
+ // on those.
+ return;
+ }
+ if (auto *symbol{GetObjectSymbol(object)}) {
+ // Eliminate certain kinds of symbols before running the analyzer to
+ // avoid confusing error messages. The analyzer assumes that the context
+ // of the object use is an expression, and some diagnostics are tailored
+ // to that.
+ if (symbol->has<DerivedTypeDetails>() || symbol->has<MiscDetails>()) {
+ // Type names, construct names, etc.
+ return;
+ }
+ if (auto *typeSpec{symbol->GetType()}) {
+ if (typeSpec->category() == DeclTypeSpec::Category::Character) {
+ // Don't pass character objects to the analyzer, it can emit somewhat
+ // cryptic errors (e.g. "'obj' is not an array"). Substrings are
+ // checked elsewhere in OmpStructureChecker.
+ return;
+ }
+ }
+ }
+ evaluate::ExpressionAnalyzer ea{context_};
+ auto restore{ea.AllowWholeAssumedSizeArray(true)};
+ common::visit([&](auto &&s) { ea.Analyze(s); }, object.u);
+}
+
+void OmpStructureChecker::AnalyzeObjects(const parser::OmpObjectList &objects) {
+ for (const parser::OmpObject &object : objects.v) {
+ AnalyzeObject(object);
+ }
+}
+
bool OmpStructureChecker::IsCloselyNestedRegion(const OmpDirectiveSet &set) {
// Definition of close nesting:
//
@@ -529,22 +624,6 @@ template <typename Checker> struct DirectiveSpellingVisitor {
checker_(GetDirName(x.t).source, Directive::OMPD_allocators);
return false;
}
- bool Pre(const parser::OmpAssumeDirective &x) {
- checker_(std::get<parser::Verbatim>(x.t).source, Directive::OMPD_assume);
- return false;
- }
- bool Pre(const parser::OmpEndAssumeDirective &x) {
- checker_(x.v.source, Directive::OMPD_assume);
- return false;
- }
- bool Pre(const parser::OmpCriticalDirective &x) {
- checker_(std::get<parser::Verbatim>(x.t).source, Directive::OMPD_critical);
- return false;
- }
- bool Pre(const parser::OmpEndCriticalDirective &x) {
- checker_(std::get<parser::Verbatim>(x.t).source, Directive::OMPD_critical);
- return false;
- }
bool Pre(const parser::OmpMetadirectiveDirective &x) {
checker_(
std::get<parser::Verbatim>(x.t).source, Directive::OMPD_metadirective);
@@ -579,6 +658,10 @@ template <typename Checker> struct DirectiveSpellingVisitor {
Directive::OMPD_declare_variant);
return false;
}
+ bool Pre(const parser::OpenMPGroupprivate &x) {
+ checker_(x.v.DirName().source, Directive::OMPD_groupprivate);
+ return false;
+ }
bool Pre(const parser::OpenMPThreadprivate &x) {
checker_(
std::get<parser::Verbatim>(x.t).source, Directive::OMPD_threadprivate);
@@ -731,7 +814,7 @@ void OmpStructureChecker::CheckTargetNest(const parser::OpenMPConstruct &c) {
parser::CharBlock source;
common::visit(
common::visitors{
- [&](const parser::OpenMPBlockConstruct &c) {
+ [&](const parser::OmpBlockConstruct &c) {
const parser::OmpDirectiveSpecification &beginSpec{c.BeginDir()};
source = beginSpec.DirName().source;
if (beginSpec.DirId() == llvm::omp::Directive::OMPD_target_data) {
@@ -781,12 +864,44 @@ void OmpStructureChecker::CheckTargetNest(const parser::OpenMPConstruct &c) {
}
}
-void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) {
+void OmpStructureChecker::Enter(const parser::OmpBlockConstruct &x) {
const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
const std::optional<parser::OmpEndDirective> &endSpec{x.EndDir()};
const parser::Block &block{std::get<parser::Block>(x.t)};
PushContextAndClauseSets(beginSpec.DirName().source, beginSpec.DirId());
+
+ // Missing mandatory end block: this is checked in semantics because that
+ // makes it easier to control the error messages.
+ // The end block is mandatory when the construct is not applied to a strictly
+ // structured block (i.e. it is applied to a loosely structured block). In
+ // other words, the body doesn't contain exactly one parser::BlockConstruct.
+ auto isStrictlyStructuredBlock{[](const parser::Block &block) -> bool {
+ if (block.size() != 1) {
+ return false;
+ }
+ const parser::ExecutionPartConstruct &contents{block.front()};
+ auto *executableConstruct{
+ std::get_if<parser::ExecutableConstruct>(&contents.u)};
+ if (!executableConstruct) {
+ return false;
+ }
+ return std::holds_alternative<common::Indirection<parser::BlockConstruct>>(
+ executableConstruct->u);
+ }};
+ if (!endSpec && !isStrictlyStructuredBlock(block)) {
+ llvm::omp::Directive dirId{beginSpec.DirId()};
+ auto &msg{context_.Say(beginSpec.source,
+ "Expected OpenMP END %s directive"_err_en_US,
+ parser::ToUpperCaseLetters(getDirectiveName(dirId)))};
+ // ORDERED has two variants, so be explicit about which variant we think
+ // this is.
+ if (dirId == llvm::omp::Directive::OMPD_ordered) {
+ msg.Attach(
+ beginSpec.source, "The ORDERED directive is block-associated"_en_US);
+ }
+ }
+
if (llvm::omp::allTargetSet.test(GetContext().directive)) {
EnterDirectiveNest(TargetNest);
}
@@ -817,6 +932,12 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) {
"TARGET construct with nested TEAMS region contains statements or "
"directives outside of the TEAMS construct"_err_en_US);
}
+ if (GetContext().directive == llvm::omp::Directive::OMPD_workdistribute &&
+ GetContextParent().directive != llvm::omp::Directive::OMPD_teams) {
+ context_.Say(x.BeginDir().DirName().source,
+ "%s region can only be strictly nested within TEAMS region"_err_en_US,
+ ContextDirectiveAsFortran());
+ }
}
CheckNoBranching(block, beginSpec.DirId(), beginSpec.source);
@@ -900,6 +1021,17 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) {
HasInvalidWorksharingNesting(
beginSpec.source, llvm::omp::nestedWorkshareErrSet);
break;
+ case llvm::omp::OMPD_workdistribute:
+ if (!CurrentDirectiveIsNested()) {
+ context_.Say(beginSpec.source,
+ "A WORKDISTRIBUTE region must be nested inside TEAMS region only."_err_en_US);
+ }
+ CheckWorkdistributeBlockStmts(block, beginSpec.source);
+ break;
+ case llvm::omp::OMPD_teams_workdistribute:
+ case llvm::omp::OMPD_target_teams_workdistribute:
+ CheckWorkdistributeBlockStmts(block, beginSpec.source);
+ break;
case llvm::omp::Directive::OMPD_scope:
case llvm::omp::Directive::OMPD_single:
// TODO: This check needs to be extended while implementing nesting of
@@ -921,7 +1053,7 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) {
}
void OmpStructureChecker::CheckMasterNesting(
- const parser::OpenMPBlockConstruct &x) {
+ const parser::OmpBlockConstruct &x) {
// A MASTER region may not be `closely nested` inside a worksharing, loop,
// task, taskloop, or atomic region.
// TODO: Expand the check to include `LOOP` construct as well when it is
@@ -950,7 +1082,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclarativeAssumes &) {
dirContext_.pop_back();
}
-void OmpStructureChecker::Leave(const parser::OpenMPBlockConstruct &) {
+void OmpStructureChecker::Leave(const parser::OmpBlockConstruct &) {
if (GetDirectiveNest(TargetBlockOnlyTeams)) {
ExitDirectiveNest(TargetBlockOnlyTeams);
}
@@ -1041,14 +1173,23 @@ void OmpStructureChecker::Leave(const parser::OmpBeginDirective &) {
void OmpStructureChecker::Enter(const parser::OpenMPSectionsConstruct &x) {
const auto &beginSectionsDir{
std::get<parser::OmpBeginSectionsDirective>(x.t)};
- const auto &endSectionsDir{std::get<parser::OmpEndSectionsDirective>(x.t)};
+ const auto &endSectionsDir{
+ std::get<std::optional<parser::OmpEndSectionsDirective>>(x.t)};
const auto &beginDir{
std::get<parser::OmpSectionsDirective>(beginSectionsDir.t)};
- const auto &endDir{std::get<parser::OmpSectionsDirective>(endSectionsDir.t)};
+ PushContextAndClauseSets(beginDir.source, beginDir.v);
+
+ if (!endSectionsDir) {
+ context_.Say(beginSectionsDir.source,
+ "Expected OpenMP END SECTIONS directive"_err_en_US);
+ // Following code assumes the option is present.
+ return;
+ }
+
+ const auto &endDir{std::get<parser::OmpSectionsDirective>(endSectionsDir->t)};
CheckMatching<parser::OmpSectionsDirective>(beginDir, endDir);
- PushContextAndClauseSets(beginDir.source, beginDir.v);
- AddEndDirectiveClauses(std::get<parser::OmpClauseList>(endSectionsDir.t));
+ AddEndDirectiveClauses(std::get<parser::OmpClauseList>(endSectionsDir->t));
const auto &sectionBlocks{std::get<std::list<parser::OpenMPConstruct>>(x.t)};
for (const parser::OpenMPConstruct &construct : sectionBlocks) {
@@ -1090,113 +1231,155 @@ void OmpStructureChecker::Leave(const parser::OmpEndSectionsDirective &x) {
}
void OmpStructureChecker::CheckThreadprivateOrDeclareTargetVar(
+ const parser::Designator &designator) {
+ auto *name{parser::Unwrap<parser::Name>(designator)};
+ // If the symbol is null, return early, CheckSymbolNames
+ // should have already reported the missing symbol as a
+ // diagnostic error
+ if (!name || !name->symbol) {
+ return;
+ }
+
+ llvm::omp::Directive directive{GetContext().directive};
+
+ if (name->symbol->GetUltimate().IsSubprogram()) {
+ if (directive == llvm::omp::Directive::OMPD_threadprivate)
+ context_.Say(name->source,
+ "The procedure name cannot be in a %s directive"_err_en_US,
+ ContextDirectiveAsFortran());
+ // TODO: Check for procedure name in declare target directive.
+ } else if (name->symbol->attrs().test(Attr::PARAMETER)) {
+ if (directive == llvm::omp::Directive::OMPD_threadprivate)
+ context_.Say(name->source,
+ "The entity with PARAMETER attribute cannot be in a %s directive"_err_en_US,
+ ContextDirectiveAsFortran());
+ else if (directive == llvm::omp::Directive::OMPD_declare_target)
+ context_.Warn(common::UsageWarning::OpenMPUsage, name->source,
+ "The entity with PARAMETER attribute is used in a %s directive"_warn_en_US,
+ ContextDirectiveAsFortran());
+ } else if (FindCommonBlockContaining(*name->symbol)) {
+ context_.Say(name->source,
+ "A variable in a %s directive cannot be an element of a common block"_err_en_US,
+ ContextDirectiveAsFortran());
+ } else if (FindEquivalenceSet(*name->symbol)) {
+ context_.Say(name->source,
+ "A variable in a %s directive cannot appear in an EQUIVALENCE statement"_err_en_US,
+ ContextDirectiveAsFortran());
+ } else if (name->symbol->test(Symbol::Flag::OmpThreadprivate) &&
+ directive == llvm::omp::Directive::OMPD_declare_target) {
+ context_.Say(name->source,
+ "A THREADPRIVATE variable cannot appear in a %s directive"_err_en_US,
+ ContextDirectiveAsFortran());
+ } else {
+ const semantics::Scope &useScope{
+ context_.FindScope(GetContext().directiveSource)};
+ const semantics::Scope &curScope = name->symbol->GetUltimate().owner();
+ if (!curScope.IsTopLevel()) {
+ const semantics::Scope &declScope =
+ GetProgramUnitOrBlockConstructContaining(curScope);
+ const semantics::Symbol *sym{
+ declScope.parent().FindSymbol(name->symbol->name())};
+ if (sym &&
+ (sym->has<MainProgramDetails>() || sym->has<ModuleDetails>())) {
+ context_.Say(name->source,
+ "The module name cannot be in a %s directive"_err_en_US,
+ ContextDirectiveAsFortran());
+ } else if (!IsSaved(*name->symbol) &&
+ declScope.kind() != Scope::Kind::MainProgram &&
+ declScope.kind() != Scope::Kind::Module) {
+ context_.Say(name->source,
+ "A variable that appears in a %s directive must be declared in the scope of a module or have the SAVE attribute, either explicitly or implicitly"_err_en_US,
+ ContextDirectiveAsFortran());
+ } else if (useScope != declScope) {
+ context_.Say(name->source,
+ "The %s directive and the common block or variable in it must appear in the same declaration section of a scoping unit"_err_en_US,
+ ContextDirectiveAsFortran());
+ }
+ }
+ }
+}
+
+void OmpStructureChecker::CheckThreadprivateOrDeclareTargetVar(
+ const parser::Name &name) {
+ if (!name.symbol) {
+ return;
+ }
+
+ if (auto *cb{name.symbol->detailsIf<CommonBlockDetails>()}) {
+ for (const auto &obj : cb->objects()) {
+ if (FindEquivalenceSet(*obj)) {
+ context_.Say(name.source,
+ "A variable in a %s directive cannot appear in an EQUIVALENCE statement (variable '%s' from common block '/%s/')"_err_en_US,
+ ContextDirectiveAsFortran(), obj->name(), name.symbol->name());
+ }
+ }
+ }
+}
+
+void OmpStructureChecker::CheckThreadprivateOrDeclareTargetVar(
const parser::OmpObjectList &objList) {
for (const auto &ompObject : objList.v) {
- common::visit(
- common::visitors{
- [&](const parser::Designator &) {
- if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
- // The symbol is null, return early, CheckSymbolNames
- // should have already reported the missing symbol as a
- // diagnostic error
- if (!name->symbol) {
- return;
- }
-
- if (name->symbol->GetUltimate().IsSubprogram()) {
- if (GetContext().directive ==
- llvm::omp::Directive::OMPD_threadprivate)
- context_.Say(name->source,
- "The procedure name cannot be in a %s "
- "directive"_err_en_US,
- ContextDirectiveAsFortran());
- // TODO: Check for procedure name in declare target directive.
- } else if (name->symbol->attrs().test(Attr::PARAMETER)) {
- if (GetContext().directive ==
- llvm::omp::Directive::OMPD_threadprivate)
- context_.Say(name->source,
- "The entity with PARAMETER attribute cannot be in a %s "
- "directive"_err_en_US,
- ContextDirectiveAsFortran());
- else if (GetContext().directive ==
- llvm::omp::Directive::OMPD_declare_target)
- context_.Warn(common::UsageWarning::OpenMPUsage,
- name->source,
- "The entity with PARAMETER attribute is used in a %s directive"_warn_en_US,
- ContextDirectiveAsFortran());
- } else if (FindCommonBlockContaining(*name->symbol)) {
- context_.Say(name->source,
- "A variable in a %s directive cannot be an element of a "
- "common block"_err_en_US,
- ContextDirectiveAsFortran());
- } else if (FindEquivalenceSet(*name->symbol)) {
- context_.Say(name->source,
- "A variable in a %s directive cannot appear in an "
- "EQUIVALENCE statement"_err_en_US,
- ContextDirectiveAsFortran());
- } else if (name->symbol->test(Symbol::Flag::OmpThreadprivate) &&
- GetContext().directive ==
- llvm::omp::Directive::OMPD_declare_target) {
- context_.Say(name->source,
- "A THREADPRIVATE variable cannot appear in a %s "
- "directive"_err_en_US,
- ContextDirectiveAsFortran());
- } else {
- const semantics::Scope &useScope{
- context_.FindScope(GetContext().directiveSource)};
- const semantics::Scope &curScope =
- name->symbol->GetUltimate().owner();
- if (!curScope.IsTopLevel()) {
- const semantics::Scope &declScope =
- GetProgramUnitOrBlockConstructContaining(curScope);
- const semantics::Symbol *sym{
- declScope.parent().FindSymbol(name->symbol->name())};
- if (sym &&
- (sym->has<MainProgramDetails>() ||
- sym->has<ModuleDetails>())) {
- context_.Say(name->source,
- "The module name cannot be in a %s "
- "directive"_err_en_US,
- ContextDirectiveAsFortran());
- } else if (!IsSaved(*name->symbol) &&
- declScope.kind() != Scope::Kind::MainProgram &&
- declScope.kind() != Scope::Kind::Module) {
- context_.Say(name->source,
- "A variable that appears in a %s directive must be "
- "declared in the scope of a module or have the SAVE "
- "attribute, either explicitly or "
- "implicitly"_err_en_US,
- ContextDirectiveAsFortran());
- } else if (useScope != declScope) {
- context_.Say(name->source,
- "The %s directive and the common block or variable "
- "in it must appear in the same declaration section "
- "of a scoping unit"_err_en_US,
- ContextDirectiveAsFortran());
- }
- }
- }
- }
- },
- [&](const parser::Name &name) {
- if (name.symbol) {
- if (auto *cb{name.symbol->detailsIf<CommonBlockDetails>()}) {
- for (const auto &obj : cb->objects()) {
- if (FindEquivalenceSet(*obj)) {
- context_.Say(name.source,
- "A variable in a %s directive cannot appear in an EQUIVALENCE statement (variable '%s' from common block '/%s/')"_err_en_US,
- ContextDirectiveAsFortran(), obj->name(),
- name.symbol->name());
- }
- }
- }
- }
- },
- },
+ common::visit([&](auto &&s) { CheckThreadprivateOrDeclareTargetVar(s); },
ompObject.u);
}
}
+void OmpStructureChecker::Enter(const parser::OpenMPGroupprivate &x) {
+ PushContextAndClauseSets(
+ x.v.DirName().source, llvm::omp::Directive::OMPD_groupprivate);
+
+ for (const parser::OmpArgument &arg : x.v.Arguments().v) {
+ auto *locator{std::get_if<parser::OmpLocator>(&arg.u)};
+ const Symbol *sym{GetArgumentSymbol(arg)};
+
+ if (!locator || !sym ||
+ (!IsVariableListItem(*sym) && !IsCommonBlock(*sym))) {
+ context_.Say(arg.source,
+ "GROUPPRIVATE argument should be a variable or a named common block"_err_en_US);
+ continue;
+ }
+
+ if (sym->has<AssocEntityDetails>()) {
+ context_.SayWithDecl(*sym, arg.source,
+ "GROUPPRIVATE argument cannot be an ASSOCIATE name"_err_en_US);
+ continue;
+ }
+ if (auto *obj{sym->detailsIf<ObjectEntityDetails>()}) {
+ if (obj->IsCoarray()) {
+ context_.Say(
+ arg.source, "GROUPPRIVATE argument cannot be a coarray"_err_en_US);
+ continue;
+ }
+ if (obj->init()) {
+ context_.SayWithDecl(*sym, arg.source,
+ "GROUPPRIVATE argument cannot be declared with an initializer"_err_en_US);
+ continue;
+ }
+ }
+ if (sym->test(Symbol::Flag::InCommonBlock)) {
+ context_.Say(arg.source,
+ "GROUPPRIVATE argument cannot be a member of a common block"_err_en_US);
+ continue;
+ }
+ if (!IsCommonBlock(*sym)) {
+ const Scope &thisScope{context_.FindScope(x.v.source)};
+ if (thisScope != sym->owner()) {
+ context_.SayWithDecl(*sym, arg.source,
+ "GROUPPRIVATE argument variable must be declared in the same scope as the construct on which it appears"_err_en_US);
+ continue;
+ } else if (!thisScope.IsModule() && !sym->attrs().test(Attr::SAVE)) {
+ context_.SayWithDecl(*sym, arg.source,
+ "GROUPPRIVATE argument variable must be declared in the module scope or have SAVE attribute"_err_en_US);
+ continue;
+ }
+ }
+ }
+}
+
+void OmpStructureChecker::Leave(const parser::OpenMPGroupprivate &x) {
+ dirContext_.pop_back();
+}
+
void OmpStructureChecker::Enter(const parser::OpenMPThreadprivate &c) {
const auto &dir{std::get<parser::Verbatim>(c.t)};
PushContextAndClauseSets(
@@ -2034,41 +2217,87 @@ void OmpStructureChecker::Leave(const parser::OpenMPCancelConstruct &) {
}
void OmpStructureChecker::Enter(const parser::OpenMPCriticalConstruct &x) {
- const auto &dir{std::get<parser::OmpCriticalDirective>(x.t)};
- const auto &dirSource{std::get<parser::Verbatim>(dir.t).source};
- const auto &endDir{std::get<parser::OmpEndCriticalDirective>(x.t)};
- PushContextAndClauseSets(dirSource, llvm::omp::Directive::OMPD_critical);
+ const parser::OmpBeginDirective &beginSpec{x.BeginDir()};
+ const std::optional<parser::OmpEndDirective> &endSpec{x.EndDir()};
+ PushContextAndClauseSets(beginSpec.DirName().source, beginSpec.DirName().v);
+
const auto &block{std::get<parser::Block>(x.t)};
- CheckNoBranching(block, llvm::omp::Directive::OMPD_critical, dir.source);
- const auto &dirName{std::get<std::optional<parser::Name>>(dir.t)};
- const auto &endDirName{std::get<std::optional<parser::Name>>(endDir.t)};
- const auto &ompClause{std::get<parser::OmpClauseList>(dir.t)};
- if (dirName && endDirName &&
- dirName->ToString().compare(endDirName->ToString())) {
- context_
- .Say(endDirName->source,
- parser::MessageFormattedText{
- "CRITICAL directive names do not match"_err_en_US})
- .Attach(dirName->source, "should be "_en_US);
- } else if (dirName && !endDirName) {
- context_
- .Say(dirName->source,
- parser::MessageFormattedText{
- "CRITICAL directive names do not match"_err_en_US})
- .Attach(dirName->source, "should be NULL"_en_US);
- } else if (!dirName && endDirName) {
- context_
- .Say(endDirName->source,
- parser::MessageFormattedText{
- "CRITICAL directive names do not match"_err_en_US})
- .Attach(endDirName->source, "should be NULL"_en_US);
- }
- if (!dirName && !ompClause.source.empty() &&
- ompClause.source.NULTerminatedToString() != "hint(omp_sync_hint_none)") {
- context_.Say(dir.source,
- parser::MessageFormattedText{
- "Hint clause other than omp_sync_hint_none cannot be specified for "
- "an unnamed CRITICAL directive"_err_en_US});
+ CheckNoBranching(
+ block, llvm::omp::Directive::OMPD_critical, beginSpec.DirName().source);
+
+ auto getNameFromArg{[](const parser::OmpArgument &arg) {
+ if (auto *object{parser::Unwrap<parser::OmpObject>(arg.u)}) {
+ if (auto *designator{omp::GetDesignatorFromObj(*object)}) {
+ return getDesignatorNameIfDataRef(*designator);
+ }
+ }
+ return static_cast<const parser::Name *>(nullptr);
+ }};
+
+ auto checkArgumentList{[&](const parser::OmpArgumentList &args) {
+ if (args.v.size() > 1) {
+ context_.Say(args.source,
+ "Only a single argument is allowed in CRITICAL directive"_err_en_US);
+ } else if (!args.v.empty()) {
+ if (!getNameFromArg(args.v.front())) {
+ context_.Say(args.v.front().source,
+ "CRITICAL argument should be a name"_err_en_US);
+ }
+ }
+ }};
+
+ const parser::Name *beginName{nullptr};
+ const parser::Name *endName{nullptr};
+
+ auto &beginArgs{beginSpec.Arguments()};
+ checkArgumentList(beginArgs);
+
+ if (!beginArgs.v.empty()) {
+ beginName = getNameFromArg(beginArgs.v.front());
+ }
+
+ if (endSpec) {
+ auto &endArgs{endSpec->Arguments()};
+ checkArgumentList(endArgs);
+
+ if (beginArgs.v.empty() != endArgs.v.empty()) {
+ parser::CharBlock source{
+ beginArgs.v.empty() ? endArgs.source : beginArgs.source};
+ context_.Say(source,
+ "Either both CRITICAL and END CRITICAL should have an argument, or none of them should"_err_en_US);
+ } else if (!beginArgs.v.empty()) {
+ endName = getNameFromArg(endArgs.v.front());
+ if (beginName && endName) {
+ if (beginName->ToString() != endName->ToString()) {
+ context_.Say(endName->source,
+ "The names on CRITICAL and END CRITICAL must match"_err_en_US);
+ }
+ }
+ }
+ }
+
+ for (auto &clause : beginSpec.Clauses().v) {
+ auto *hint{std::get_if<parser::OmpClause::Hint>(&clause.u)};
+ if (!hint) {
+ continue;
+ }
+ const int64_t OmpSyncHintNone = 0; // omp_sync_hint_none
+ std::optional<int64_t> hintValue{GetIntValue(hint->v.v)};
+ if (hintValue && *hintValue != OmpSyncHintNone) {
+ // Emit a diagnostic if the name is missing, and point to the directive
+ // with a missing name.
+ parser::CharBlock source;
+ if (!beginName) {
+ source = beginSpec.DirName().source;
+ } else if (endSpec && !endName) {
+ source = endSpec->DirName().source;
+ }
+
+ if (!source.empty()) {
+ context_.Say(source,
+ "When HINT other than 'omp_sync_hint_none' is present, CRITICAL directive should have a name"_err_en_US);
+ }
+ }
}
}
@@ -2511,8 +2740,9 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
void OmpStructureChecker::Enter(const parser::OmpClause &x) {
SetContextClause(x);
+ llvm::omp::Clause id{x.Id()};
// The visitors for these clauses do their own checks.
- switch (x.Id()) {
+ switch (id) {
case llvm::omp::Clause::OMPC_copyprivate:
case llvm::omp::Clause::OMPC_enter:
case llvm::omp::Clause::OMPC_lastprivate:
@@ -2523,11 +2753,25 @@ void OmpStructureChecker::Enter(const parser::OmpClause &x) {
break;
}
+ // Named constants are OK to be used within 'shared' and 'firstprivate'
+ // clauses. The check for this happens a few lines below.
+ bool SharedOrFirstprivate = false;
+ switch (id) {
+ case llvm::omp::Clause::OMPC_shared:
+ case llvm::omp::Clause::OMPC_firstprivate:
+ SharedOrFirstprivate = true;
+ break;
+ default:
+ break;
+ }
+
if (const parser::OmpObjectList *objList{GetOmpObjectList(x)}) {
+ AnalyzeObjects(*objList);
SymbolSourceMap symbols;
GetSymbolsInObjectList(*objList, symbols);
for (const auto &[symbol, source] : symbols) {
- if (!IsVariableListItem(*symbol)) {
+ if (!IsVariableListItem(*symbol) &&
+ !(IsNamedConstant(*symbol) && SharedOrFirstprivate)) {
deferredNonVariables_.insert({symbol, source});
}
}
@@ -2543,6 +2787,7 @@ CHECK_SIMPLE_CLAUSE(Default, OMPC_default)
CHECK_SIMPLE_CLAUSE(Depobj, OMPC_depobj)
CHECK_SIMPLE_CLAUSE(DeviceType, OMPC_device_type)
CHECK_SIMPLE_CLAUSE(DistSchedule, OMPC_dist_schedule)
+CHECK_SIMPLE_CLAUSE(DynGroupprivate, OMPC_dyn_groupprivate)
CHECK_SIMPLE_CLAUSE(Exclusive, OMPC_exclusive)
CHECK_SIMPLE_CLAUSE(Final, OMPC_final)
CHECK_SIMPLE_CLAUSE(Flush, OMPC_flush)
@@ -2853,7 +3098,8 @@ static bool CheckSymbolSupportsType(const Scope &scope,
static bool IsReductionAllowedForType(
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type,
- const Scope &scope, SemanticsContext &context) {
+ bool cannotBeBuiltinReduction, const Scope &scope,
+ SemanticsContext &context) {
auto isLogical{[](const DeclTypeSpec &type) -> bool {
return type.category() == DeclTypeSpec::Logical;
}};
@@ -2864,6 +3110,10 @@ static bool IsReductionAllowedForType(
auto checkOperator{[&](const parser::DefinedOperator &dOpr) {
if (const auto *intrinsicOp{
std::get_if<parser::DefinedOperator::IntrinsicOperator>(&dOpr.u)}) {
+ if (cannotBeBuiltinReduction) {
+ return false;
+ }
+
// OMP5.2: The type [...] of a list item that appears in a
// reduction clause must be valid for the combiner expression
// See F2023: Table 10.2
@@ -2915,7 +3165,8 @@ static bool IsReductionAllowedForType(
// IAND: arguments must be integers: F2023 16.9.100
// IEOR: arguments must be integers: F2023 16.9.106
// IOR: arguments must be integers: F2023 16.9.111
- if (type.IsNumeric(TypeCategory::Integer)) {
+ if (type.IsNumeric(TypeCategory::Integer) &&
+ !cannotBeBuiltinReduction) {
return true;
}
} else if (realName == "max" || realName == "min") {
@@ -2923,8 +3174,9 @@ static bool IsReductionAllowedForType(
// F2023 16.9.135
// MIN: arguments must be integer, real, or character:
// F2023 16.9.141
- if (type.IsNumeric(TypeCategory::Integer) ||
- type.IsNumeric(TypeCategory::Real) || isCharacter(type)) {
+ if ((type.IsNumeric(TypeCategory::Integer) ||
+ type.IsNumeric(TypeCategory::Real) || isCharacter(type)) &&
+ !cannotBeBuiltinReduction) {
return true;
}
}
@@ -2957,9 +3209,16 @@ void OmpStructureChecker::CheckReductionObjectTypes(
GetSymbolsInObjectList(objects, symbols);
for (auto &[symbol, source] : symbols) {
+ // Built in reductions require types which can be used in their initializer
+ // and combiner expressions. For example, for +:
+ // r = 0; r = r + r2
+ // But it might be valid to use these with DECLARE REDUCTION.
+ // Assumed size is already caught elsewhere.
+ bool cannotBeBuiltinReduction{IsAssumedRank(*symbol)};
if (auto *type{symbol->GetType()}) {
const auto &scope{context_.FindScope(symbol->name())};
- if (!IsReductionAllowedForType(ident, *type, scope, context_)) {
+ if (!IsReductionAllowedForType(
+ ident, *type, cannotBeBuiltinReduction, scope, context_)) {
context_.Say(source,
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
symbol->name());
@@ -3238,9 +3497,14 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Aligned &x) {
x.v, llvm::omp::OMPC_aligned, GetContext().clauseSource, context_)) {
auto &modifiers{OmpGetModifiers(x.v)};
if (auto *align{OmpGetUniqueModifier<parser::OmpAlignment>(modifiers)}) {
- if (const auto &v{GetIntValue(align->v)}; !v || *v <= 0) {
+ const auto &v{GetIntValue(align->v)};
+ if (!v || *v <= 0) {
context_.Say(OmpGetModifierSource(modifiers, align),
"The alignment value should be a constant positive integer"_err_en_US);
+ } else if (((*v) & (*v - 1)) != 0) {
+ context_.Warn(common::UsageWarning::OpenMPUsage,
+ OmpGetModifierSource(modifiers, align),
+ "Alignment is not a power of 2, Aligned clause will be ignored"_warn_en_US);
}
}
}
@@ -4349,7 +4613,7 @@ bool OmpStructureChecker::CheckTargetBlockOnlyTeams(
if (const auto *ompConstruct{
parser::Unwrap<parser::OpenMPConstruct>(*it)}) {
if (const auto *ompBlockConstruct{
- std::get_if<parser::OpenMPBlockConstruct>(&ompConstruct->u)}) {
+ std::get_if<parser::OmpBlockConstruct>(&ompConstruct->u)}) {
llvm::omp::Directive dirId{ompBlockConstruct->BeginDir().DirId()};
if (dirId == llvm::omp::Directive::OMPD_teams) {
nestedTeams = true;
@@ -4396,7 +4660,7 @@ void OmpStructureChecker::CheckWorkshareBlockStmts(
// 'Parallel' constructs
auto currentDir{llvm::omp::Directive::OMPD_unknown};
if (const auto *ompBlockConstruct{
- std::get_if<parser::OpenMPBlockConstruct>(&ompConstruct->u)}) {
+ std::get_if<parser::OmpBlockConstruct>(&ompConstruct->u)}) {
currentDir = ompBlockConstruct->BeginDir().DirId();
} else if (const auto *ompLoopConstruct{
std::get_if<parser::OpenMPLoopConstruct>(
@@ -4432,6 +4696,27 @@ void OmpStructureChecker::CheckWorkshareBlockStmts(
}
}
+void OmpStructureChecker::CheckWorkdistributeBlockStmts(
+ const parser::Block &block, parser::CharBlock source) {
+ unsigned version{context_.langOptions().OpenMPVersion};
+ unsigned since{60};
+ if (version < since)
+ context_.Say(source,
+ "WORKDISTRIBUTE construct is not allowed in %s, %s"_err_en_US,
+ ThisVersion(version), TryVersion(since));
+
+ OmpWorkdistributeBlockChecker ompWorkdistributeBlockChecker{context_, source};
+
+ for (auto it{block.begin()}; it != block.end(); ++it) {
+ if (parser::Unwrap<parser::AssignmentStmt>(*it)) {
+ parser::Walk(*it, ompWorkdistributeBlockChecker);
+ } else {
+ context_.Say(source,
+ "The structured block in a WORKDISTRIBUTE construct may consist of only SCALAR or ARRAY assignments"_err_en_US);
+ }
+ }
+}
+
void OmpStructureChecker::CheckIfContiguous(const parser::OmpObject &object) {
if (auto contig{IsContiguous(context_, object)}; contig && !*contig) {
const parser::Name *name{GetObjectName(object)};
@@ -4475,42 +4760,6 @@ const parser::Name *OmpStructureChecker::GetObjectName(
return NameHelper::Visit(object);
}
-const parser::OmpObjectList *OmpStructureChecker::GetOmpObjectList(
- const parser::OmpClause &clause) {
-
- // Clauses with OmpObjectList as its data member
- using MemberObjectListClauses =
- std::tuple<parser::OmpClause::Copyprivate, parser::OmpClause::Copyin,
- parser::OmpClause::Firstprivate, parser::OmpClause::Link,
- parser::OmpClause::Private, parser::OmpClause::Shared,
- parser::OmpClause::UseDevicePtr, parser::OmpClause::UseDeviceAddr>;
-
- // Clauses with OmpObjectList in the tuple
- using TupleObjectListClauses =
- std::tuple<parser::OmpClause::Aligned, parser::OmpClause::Allocate,
- parser::OmpClause::From, parser::OmpClause::Lastprivate,
- parser::OmpClause::Map, parser::OmpClause::Reduction,
- parser::OmpClause::To, parser::OmpClause::Enter>;
-
- // TODO:: Generate the tuples using TableGen.
- // Handle other constructs with OmpObjectList such as OpenMPThreadprivate.
- return common::visit(
- common::visitors{
- [&](const auto &x) -> const parser::OmpObjectList * {
- using Ty = std::decay_t<decltype(x)>;
- if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
- return &x.v;
- } else if constexpr (common::HasMember<Ty,
- TupleObjectListClauses>) {
- return &(std::get<parser::OmpObjectList>(x.v.t));
- } else {
- return nullptr;
- }
- },
- },
- clause.u);
-}
-
void OmpStructureChecker::Enter(
const parser::OmpClause::AtomicDefaultMemOrder &x) {
CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_atomic_default_mem_order);
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 6b33ca6..ce074f5 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -88,8 +88,8 @@ public:
void Leave(const parser::OpenMPAssumeConstruct &);
void Enter(const parser::OpenMPDeclarativeAssumes &);
void Leave(const parser::OpenMPDeclarativeAssumes &);
- void Enter(const parser::OpenMPBlockConstruct &);
- void Leave(const parser::OpenMPBlockConstruct &);
+ void Enter(const parser::OmpBlockConstruct &);
+ void Leave(const parser::OmpBlockConstruct &);
void Leave(const parser::OmpBeginDirective &);
void Enter(const parser::OmpEndDirective &);
void Leave(const parser::OmpEndDirective &);
@@ -126,6 +126,8 @@ public:
void Leave(const parser::OpenMPAllocatorsConstruct &);
void Enter(const parser::OpenMPRequiresConstruct &);
void Leave(const parser::OpenMPRequiresConstruct &);
+ void Enter(const parser::OpenMPGroupprivate &);
+ void Leave(const parser::OpenMPGroupprivate &);
void Enter(const parser::OpenMPThreadprivate &);
void Leave(const parser::OpenMPThreadprivate &);
@@ -165,6 +167,8 @@ private:
void CheckVariableListItem(const SymbolSourceMap &symbols);
void CheckDirectiveSpelling(
parser::CharBlock spelling, llvm::omp::Directive id);
+ void AnalyzeObject(const parser::OmpObject &object);
+ void AnalyzeObjects(const parser::OmpObjectList &objects);
void CheckMultipleOccurrence(semantics::UnorderedSymbolSet &listVars,
const std::list<parser::Name> &nameList, const parser::CharBlock &item,
const std::string &clauseName);
@@ -222,8 +226,9 @@ private:
const parser::OmpObject &obj, llvm::StringRef clause = "");
void CheckVarIsNotPartOfAnotherVar(const parser::CharBlock &source,
const parser::OmpObjectList &objList, llvm::StringRef clause = "");
- void CheckThreadprivateOrDeclareTargetVar(
- const parser::OmpObjectList &objList);
+ void CheckThreadprivateOrDeclareTargetVar(const parser::Designator &);
+ void CheckThreadprivateOrDeclareTargetVar(const parser::Name &);
+ void CheckThreadprivateOrDeclareTargetVar(const parser::OmpObjectList &);
void CheckSymbolNames(
const parser::CharBlock &source, const parser::OmpObjectList &objList);
void CheckIntentInPointer(SymbolSourceMap &, const llvm::omp::Clause);
@@ -242,6 +247,7 @@ private:
llvmOmpClause clause, const parser::OmpObjectList &ompObjectList);
bool CheckTargetBlockOnlyTeams(const parser::Block &);
void CheckWorkshareBlockStmts(const parser::Block &, parser::CharBlock);
+ void CheckWorkdistributeBlockStmts(const parser::Block &, parser::CharBlock);
void CheckIteratorRange(const parser::OmpIteratorSpecifier &x);
void CheckIteratorModifier(const parser::OmpIterator &x);
@@ -267,8 +273,10 @@ private:
const evaluate::Assignment &read, parser::CharBlock source);
void CheckAtomicWriteAssignment(
const evaluate::Assignment &write, parser::CharBlock source);
- void CheckAtomicUpdateAssignment(
+ std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment(
const evaluate::Assignment &update, parser::CharBlock source);
+ std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom,
+ const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics);
void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond,
parser::CharBlock condSource, const evaluate::Assignment &assign,
parser::CharBlock assignSource);
@@ -307,7 +315,7 @@ private:
const parser::OmpReductionIdentifier &ident);
void CheckReductionModifier(const parser::OmpReductionModifier &);
void CheckLastprivateModifier(const parser::OmpLastprivateModifier &);
- void CheckMasterNesting(const parser::OpenMPBlockConstruct &x);
+ void CheckMasterNesting(const parser::OmpBlockConstruct &x);
void ChecksOnOrderedAsBlock();
void CheckBarrierNesting(const parser::OpenMPSimpleStandaloneConstruct &x);
void CheckScan(const parser::OpenMPSimpleStandaloneConstruct &x);
@@ -321,7 +329,6 @@ private:
const parser::OmpObjectList &ompObjectList);
void CheckIfContiguous(const parser::OmpObject &object);
const parser::Name *GetObjectName(const parser::OmpObject &object);
- const parser::OmpObjectList *GetOmpObjectList(const parser::OmpClause &);
void CheckPredefinedAllocatorRestriction(const parser::CharBlock &source,
const parser::OmpObjectList &ompObjectList);
void CheckPredefinedAllocatorRestriction(
diff --git a/flang/lib/Semantics/check-select-rank.cpp b/flang/lib/Semantics/check-select-rank.cpp
index b227bba..5dade2c 100644
--- a/flang/lib/Semantics/check-select-rank.cpp
+++ b/flang/lib/Semantics/check-select-rank.cpp
@@ -32,7 +32,7 @@ void SelectRankConstructChecker::Leave(
const Symbol *saveSelSymbol{nullptr};
if (const auto selExpr{GetExprFromSelector(selectRankStmtSel)}) {
if (const Symbol * sel{evaluate::UnwrapWholeSymbolDataRef(*selExpr)}) {
- if (!evaluate::IsAssumedRank(*sel)) { // C1150
+ if (!semantics::IsAssumedRank(*sel)) { // C1150
context_.Say(parser::FindSourceLocation(selectRankStmtSel),
"Selector '%s' is not an assumed-rank array variable"_err_en_US,
sel->name().ToString());
diff --git a/flang/lib/Semantics/check-select-type.cpp b/flang/lib/Semantics/check-select-type.cpp
index 94d16a7..b1b22c3 100644
--- a/flang/lib/Semantics/check-select-type.cpp
+++ b/flang/lib/Semantics/check-select-type.cpp
@@ -252,7 +252,7 @@ void SelectTypeChecker::Enter(const parser::SelectTypeConstruct &construct) {
if (IsProcedure(*selector)) {
context_.Say(
selectTypeStmt.source, "Selector may not be a procedure"_err_en_US);
- } else if (evaluate::IsAssumedRank(*selector)) {
+ } else if (IsAssumedRank(*selector)) {
context_.Say(selectTypeStmt.source,
"Assumed-rank variable may only be used as actual argument"_err_en_US);
} else if (auto exprType{selector->GetType()}) {
diff --git a/flang/lib/Semantics/compute-offsets.cpp b/flang/lib/Semantics/compute-offsets.cpp
index 6d4fce2..1c48d33 100644
--- a/flang/lib/Semantics/compute-offsets.cpp
+++ b/flang/lib/Semantics/compute-offsets.cpp
@@ -239,7 +239,9 @@ void ComputeOffsetsHelper::DoCommonBlock(Symbol &commonBlock) {
std::size_t minAlignment{0};
UnorderedSymbolSet previous;
for (auto object : details.objects()) {
- Symbol &symbol{*object};
+ // Allow for host association when the common block is
+ // OpenMP firstprivate.
+ Symbol &symbol{object->GetUltimate()};
auto errorSite{
commonBlock.name().empty() ? symbol.name() : commonBlock.name()};
if (std::size_t padding{DoSymbol(symbol.GetUltimate())}) {
diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp
index b4c83ba..1c45438 100644
--- a/flang/lib/Semantics/data-to-inits.cpp
+++ b/flang/lib/Semantics/data-to-inits.cpp
@@ -285,21 +285,22 @@ template <typename DSV>
std::optional<std::pair<SomeExpr, bool>>
DataInitializationCompiler<DSV>::ConvertElement(
const SomeExpr &expr, const evaluate::DynamicType &type) {
+ evaluate::FoldingContext &foldingContext{exprAnalyzer_.GetFoldingContext()};
+ evaluate::CheckRealWidening(expr, type, foldingContext);
if (auto converted{evaluate::ConvertToType(type, SomeExpr{expr})}) {
return {std::make_pair(std::move(*converted), false)};
}
// Allow DATA initialization with Hollerith and kind=1 CHARACTER like
// (most) other Fortran compilers do.
- if (auto converted{evaluate::HollerithToBOZ(
- exprAnalyzer_.GetFoldingContext(), expr, type)}) {
+ if (auto converted{evaluate::HollerithToBOZ(foldingContext, expr, type)}) {
return {std::make_pair(std::move(*converted), true)};
}
SemanticsContext &context{exprAnalyzer_.context()};
if (context.IsEnabled(common::LanguageFeature::LogicalIntegerAssignment)) {
if (MaybeExpr converted{evaluate::DataConstantConversionExtension(
- exprAnalyzer_.GetFoldingContext(), type, expr)}) {
+ foldingContext, type, expr)}) {
context.Warn(common::LanguageFeature::LogicalIntegerAssignment,
- exprAnalyzer_.GetFoldingContext().messages().at(),
+ foldingContext.messages().at(),
"nonstandard usage: initialization of %s with %s"_port_en_US,
type.AsFortran(), expr.GetType().value().AsFortran());
return {std::make_pair(std::move(*converted), false)};
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 92dbe0e..ccccf60 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -828,7 +828,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(
template <typename TYPE>
Constant<TYPE> ReadRealLiteral(
- parser::CharBlock source, FoldingContext &context) {
+ parser::CharBlock source, FoldingContext &context, bool isDefaultKind) {
const char *p{source.begin()};
auto valWithFlags{
Scalar<TYPE>::Read(p, context.targetCharacteristics().roundingMode())};
@@ -838,19 +838,24 @@ Constant<TYPE> ReadRealLiteral(
if (context.targetCharacteristics().areSubnormalsFlushedToZero()) {
value = value.FlushSubnormalToZero();
}
- return {value};
+ typename Constant<TYPE>::Result resultInfo;
+ resultInfo.set_isFromInexactLiteralConversion(
+ isDefaultKind && valWithFlags.flags.test(RealFlag::Inexact));
+ return {value, resultInfo};
}
struct RealTypeVisitor {
using Result = std::optional<Expr<SomeReal>>;
using Types = RealTypes;
- RealTypeVisitor(int k, parser::CharBlock lit, FoldingContext &ctx)
- : kind{k}, literal{lit}, context{ctx} {}
+ RealTypeVisitor(
+ int k, parser::CharBlock lit, FoldingContext &ctx, bool isDeftKind)
+ : kind{k}, literal{lit}, context{ctx}, isDefaultKind{isDeftKind} {}
template <typename T> Result Test() {
if (kind == T::kind) {
- return {AsCategoryExpr(ReadRealLiteral<T>(literal, context))};
+ return {
+ AsCategoryExpr(ReadRealLiteral<T>(literal, context, isDefaultKind))};
}
return std::nullopt;
}
@@ -858,6 +863,7 @@ struct RealTypeVisitor {
int kind;
parser::CharBlock literal;
FoldingContext &context;
+ bool isDefaultKind;
};
// Reads a real literal constant and encodes it with the right kind.
@@ -909,8 +915,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::RealLiteralConstant &x) {
"Explicit kind parameter together with non-'E' exponent letter is not standard"_port_en_US);
}
}
- auto result{common::SearchTypes(
- RealTypeVisitor{kind, x.real.source, GetFoldingContext()})};
+ bool isDefaultKind{!x.kind && letterKind.value_or('e') == 'e'};
+ auto result{common::SearchTypes(RealTypeVisitor{
+ kind, x.real.source, GetFoldingContext(), isDefaultKind})};
if (!result) { // C717
Say("Unsupported REAL(KIND=%d)"_err_en_US, kind);
}
@@ -1841,8 +1848,7 @@ void ArrayConstructorContext::Push(MaybeExpr &&x) {
if (*thisLen != *constantLength_ && !(messageDisplayedSet_ & 1)) {
exprAnalyzer_.Warn(
common::LanguageFeature::DistinctArrayConstructorLengths,
- "Character literal in array constructor without explicit "
- "type has different length than earlier elements"_port_en_US);
+ "Character literal in array constructor without explicit type has different length than earlier elements"_port_en_US);
messageDisplayedSet_ |= 1;
}
if (*thisLen > *constantLength_) {
@@ -1862,17 +1868,17 @@ void ArrayConstructorContext::Push(MaybeExpr &&x) {
} else {
if (!(messageDisplayedSet_ & 2)) {
exprAnalyzer_.Say(
- "Values in array constructor must have the same declared type "
- "when no explicit type appears"_err_en_US); // C7110
+ "Values in array constructor must have the same declared type when no explicit type appears"_err_en_US); // C7110
messageDisplayedSet_ |= 2;
}
}
} else {
+ CheckRealWidening(*x, *type_, exprAnalyzer_.GetFoldingContext());
if (auto cast{ConvertToType(*type_, std::move(*x))}) {
values_.Push(std::move(*cast));
} else if (!(messageDisplayedSet_ & 4)) {
- exprAnalyzer_.Say("Value in array constructor of type '%s' could not "
- "be converted to the type of the array '%s'"_err_en_US,
+ exprAnalyzer_.Say(
+ "Value in array constructor of type '%s' could not be converted to the type of the array '%s'"_err_en_US,
x->GetType()->AsFortran(), type_->AsFortran()); // C7111, C7112
messageDisplayedSet_ |= 4;
}
@@ -2065,8 +2071,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::ArrayConstructor &array) {
// Check if implicit conversion of expr to the symbol type is legal (if needed),
// and make it explicit if requested.
-static MaybeExpr ImplicitConvertTo(const semantics::Symbol &sym,
- Expr<SomeType> &&expr, bool keepConvertImplicit) {
+static MaybeExpr ImplicitConvertTo(const Symbol &sym, Expr<SomeType> &&expr,
+ bool keepConvertImplicit, FoldingContext &foldingContext) {
+ CheckRealWidening(expr, DynamicType::From(sym), foldingContext);
if (!keepConvertImplicit) {
return ConvertToType(sym, std::move(expr));
} else {
@@ -2191,7 +2198,8 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor(
}
if (symbol) {
const semantics::Scope &innermost{context_.FindScope(exprSource)};
- if (auto msg{CheckAccessibleSymbol(innermost, *symbol)}) {
+ if (auto msg{CheckAccessibleSymbol(
+ innermost, *symbol, /*inStructureConstructor=*/true)}) {
Say(exprSource, std::move(*msg));
}
if (checkConflicts) {
@@ -2293,10 +2301,12 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor(
// convert would cause a segfault. Lowering will deal with
// conditionally converting and preserving the lower bounds in this
// case.
- if (MaybeExpr converted{ImplicitConvertTo(
- *symbol, std::move(value), IsAllocatable(*symbol))}) {
- if (auto componentShape{GetShape(GetFoldingContext(), *symbol)}) {
- if (auto valueShape{GetShape(GetFoldingContext(), *converted)}) {
+ FoldingContext &foldingContext{GetFoldingContext()};
+ if (MaybeExpr converted{ImplicitConvertTo(*symbol, std::move(value),
+ /*keepConvertImplicit=*/IsAllocatable(*symbol),
+ foldingContext)}) {
+ if (auto componentShape{GetShape(foldingContext, *symbol)}) {
+ if (auto valueShape{GetShape(foldingContext, *converted)}) {
if (GetRank(*componentShape) == 0 && GetRank(*valueShape) > 0) {
AttachDeclaration(
Say(exprSource,
@@ -2310,7 +2320,7 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor(
if (checked && *checked && GetRank(*componentShape) > 0 &&
GetRank(*valueShape) == 0 &&
(IsDeferredShape(*symbol) ||
- !IsExpandableScalar(*converted, GetFoldingContext(),
+ !IsExpandableScalar(*converted, foldingContext,
*componentShape, true /*admit PURE call*/))) {
AttachDeclaration(
Say(exprSource,
@@ -3774,10 +3784,9 @@ MaybeExpr NumericBinaryHelper(
analyzer.CheckForNullPointer();
analyzer.CheckForAssumedRank();
analyzer.CheckConformance();
- constexpr bool canBeUnsigned{opr != NumericOperator::Power};
- return NumericOperation<OPR, canBeUnsigned>(
- context.GetContextualMessages(), analyzer.MoveExpr(0),
- analyzer.MoveExpr(1), context.GetDefaultKind(TypeCategory::Real));
+ return NumericOperation<OPR>(context.GetContextualMessages(),
+ analyzer.MoveExpr(0), analyzer.MoveExpr(1),
+ context.GetDefaultKind(TypeCategory::Real));
} else {
return analyzer.TryDefinedOp(AsFortran(opr),
"Operands of %s must be numeric; have %s and %s"_err_en_US);
@@ -4623,7 +4632,7 @@ bool ArgumentAnalyzer::CheckForNullPointer(const char *where) {
bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) {
for (const std::optional<ActualArgument> &arg : actuals_) {
- if (arg && IsAssumedRank(arg->UnwrapExpr())) {
+ if (arg && semantics::IsAssumedRank(arg->UnwrapExpr())) {
context_.Say(source_,
"An assumed-rank dummy argument is not allowed %s"_err_en_US, where);
fatalErrors_ = true;
@@ -4827,6 +4836,11 @@ std::optional<ProcedureRef> ArgumentAnalyzer::TryDefinedAssignment() {
// conversion in this case.
if (lhsType) {
if (rhsType) {
+ FoldingContext &foldingContext{context_.GetFoldingContext()};
+ auto restorer{foldingContext.messages().SetLocation(
+ actuals_.at(1).value().sourceLocation().value_or(
+ foldingContext.messages().at()))};
+ CheckRealWidening(rhs, lhsType, foldingContext);
if (!IsAllocatableDesignator(lhs) || context_.inWhereBody()) {
AddAssignmentConversion(*lhsType, *rhsType);
}
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 7a492a4..e8df346c 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
-#include "openmp-utils.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Common/indirection.h"
#include "flang/Common/reference.h"
diff --git a/flang/lib/Semantics/openmp-utils.h b/flang/lib/Semantics/openmp-utils.h
deleted file mode 100644
index b8ad9ed..0000000
--- a/flang/lib/Semantics/openmp-utils.h
+++ /dev/null
@@ -1,81 +0,0 @@
-//===-- lib/Semantics/openmp-utils.h --------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Common utilities used in OpenMP semantic checks.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef FORTRAN_SEMANTICS_OPENMP_UTILS_H
-#define FORTRAN_SEMANTICS_OPENMP_UTILS_H
-
-#include "flang/Evaluate/type.h"
-#include "flang/Parser/char-block.h"
-#include "flang/Parser/parse-tree.h"
-#include "flang/Semantics/tools.h"
-
-#include "llvm/ADT/ArrayRef.h"
-
-#include <optional>
-#include <string>
-
-namespace Fortran::semantics {
-class SemanticsContext;
-class Symbol;
-
-// Add this namespace to avoid potential conflicts
-namespace omp {
-// There is no consistent way to get the source of an ActionStmt, but there
-// is "source" in Statement<T>. This structure keeps the ActionStmt with the
-// extracted source for further use.
-struct SourcedActionStmt {
- const parser::ActionStmt *stmt{nullptr};
- parser::CharBlock source;
-
- operator bool() const { return stmt != nullptr; }
-};
-
-SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x);
-SourcedActionStmt GetActionStmt(const parser::Block &block);
-
-std::string ThisVersion(unsigned version);
-std::string TryVersion(unsigned version);
-
-const parser::Designator *GetDesignatorFromObj(const parser::OmpObject &object);
-const parser::DataRef *GetDataRefFromObj(const parser::OmpObject &object);
-const parser::ArrayElement *GetArrayElementFromObj(
- const parser::OmpObject &object);
-const Symbol *GetObjectSymbol(const parser::OmpObject &object);
-const Symbol *GetArgumentSymbol(const parser::OmpArgument &argument);
-std::optional<parser::CharBlock> GetObjectSource(
- const parser::OmpObject &object);
-
-bool IsCommonBlock(const Symbol &sym);
-bool IsExtendedListItem(const Symbol &sym);
-bool IsVariableListItem(const Symbol &sym);
-bool IsVarOrFunctionRef(const MaybeExpr &expr);
-
-bool IsMapEnteringType(parser::OmpMapType::Value type);
-bool IsMapExitingType(parser::OmpMapType::Value type);
-
-std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr);
-std::optional<evaluate::DynamicType> GetDynamicType(
- const parser::Expr &parserExpr);
-
-std::optional<bool> IsContiguous(
- SemanticsContext &semaCtx, const parser::OmpObject &object);
-
-std::vector<SomeExpr> GetAllDesignators(const SomeExpr &expr);
-const SomeExpr *HasStorageOverlap(
- const SomeExpr &base, llvm::ArrayRef<SomeExpr> exprs);
-bool IsAssignment(const parser::ActionStmt *x);
-bool IsPointerAssignment(const evaluate::Assignment &x);
-const parser::Block &GetInnermostExecPart(const parser::Block &block);
-} // namespace omp
-} // namespace Fortran::semantics
-
-#endif // FORTRAN_SEMANTICS_OPENMP_UTILS_H
diff --git a/flang/lib/Semantics/pointer-assignment.cpp b/flang/lib/Semantics/pointer-assignment.cpp
index e767bf8..5508ba8 100644
--- a/flang/lib/Semantics/pointer-assignment.cpp
+++ b/flang/lib/Semantics/pointer-assignment.cpp
@@ -159,7 +159,7 @@ bool PointerAssignmentChecker::CheckLeftHandSide(const SomeExpr &lhs) {
msg->Attach(std::move(whyNot->set_severity(parser::Severity::Because)));
}
return false;
- } else if (evaluate::IsAssumedRank(lhs)) {
+ } else if (IsAssumedRank(lhs)) {
Say("The left-hand side of a pointer assignment must not be an assumed-rank dummy argument"_err_en_US);
return false;
} else if (evaluate::ExtractCoarrayRef(lhs)) { // F'2023 C1027
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 0557b08..a08e764 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -10,7 +10,6 @@
#include "check-acc-structure.h"
#include "check-omp-structure.h"
-#include "openmp-utils.h"
#include "resolve-names-utils.h"
#include "flang/Common/idioms.h"
#include "flang/Evaluate/fold.h"
@@ -22,6 +21,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/openmp-dsa.h"
#include "flang/Semantics/openmp-modifiers.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Support/Flags.h"
@@ -29,7 +29,6 @@
#include "llvm/Support/Debug.h"
#include <list>
#include <map>
-#include <sstream>
template <typename T>
static Fortran::semantics::Scope *GetScope(
@@ -61,6 +60,13 @@ protected:
parser::OmpDefaultmapClause::ImplicitBehavior>
defaultMap;
+ std::optional<Symbol::Flag> FindSymbolWithDSA(const Symbol &symbol) { // Look up `symbol` (by address) in this context's objectWithDSA map.
+ if (auto it{objectWithDSA.find(&symbol)}; it != objectWithDSA.end()) {
+ return it->second; // Entry found: return the value recorded for the symbol.
+ }
+ return std::nullopt; // This context has no entry for the symbol.
+ }
+
bool withinConstruct{false};
std::int64_t associatedLoopLevel{0};
};
@@ -75,10 +81,19 @@ protected:
: std::make_optional<DirContext>(dirContext_.back());
}
void PushContext(const parser::CharBlock &source, T dir, Scope &scope) {
- dirContext_.emplace_back(source, dir, scope);
+ if constexpr (std::is_same_v<T, llvm::acc::Directive>) { // Only when T is llvm::acc::Directive does the new context inherit state.
+ dirContext_.emplace_back(source, dir, scope);
+ if (std::size_t size{dirContext_.size()}; size > 1) { // There is an enclosing context below the one just pushed.
+ std::size_t lastIndex{size - 1};
+ dirContext_[lastIndex].defaultDSA =
+ dirContext_[lastIndex - 1].defaultDSA; // Copy defaultDSA from the enclosing context into the new one.
+ }
+ } else {
+ dirContext_.emplace_back(source, dir, scope); // Any other directive type: push the new context without copying anything.
+ }
}
void PushContext(const parser::CharBlock &source, T dir) {
- dirContext_.emplace_back(source, dir, context_.FindScope(source));
+ PushContext(source, dir, context_.FindScope(source));
}
void PopContext() { dirContext_.pop_back(); }
void SetContextDirectiveSource(parser::CharBlock &dir) {
@@ -100,9 +115,21 @@ protected:
AddToContextObjectWithDSA(symbol, flag, GetContext());
}
bool IsObjectWithDSA(const Symbol &symbol) {
- auto it{GetContext().objectWithDSA.find(&symbol)};
- return it != GetContext().objectWithDSA.end();
+ return GetContext().FindSymbolWithDSA(symbol).has_value();
}
+ bool IsObjectWithVisibleDSA(const Symbol &symbol) { // True if any context on the dirContext_ stack has an entry for `symbol`.
+ for (std::size_t i{dirContext_.size()}; i != 0; i--) { // Scan from the most recently pushed context back to the first; i stays unsigned-safe.
+ if (dirContext_[i - 1].FindSymbolWithDSA(symbol).has_value()) {
+ return true;
+ }
+ }
+ return false; // No context on the stack (or an empty stack) records the symbol.
+ }
+
+ bool WithinConstruct() { // True when a context exists and the current one has withinConstruct set.
+ return !dirContext_.empty() && GetContext().withinConstruct; // Emptiness is tested first, so GetContext() is only reached with a non-empty stack.
+ }
+
void SetContextAssociatedLoopLevel(std::int64_t level) {
GetContext().associatedLoopLevel = level;
}
@@ -384,13 +411,16 @@ public:
}
void Post(const parser::OmpMetadirectiveDirective &) { PopContext(); }
- bool Pre(const parser::OpenMPBlockConstruct &);
- void Post(const parser::OpenMPBlockConstruct &);
+ bool Pre(const parser::OmpBlockConstruct &);
+ void Post(const parser::OmpBlockConstruct &);
void Post(const parser::OmpBeginDirective &x) {
GetContext().withinConstruct = true;
}
+ bool Pre(const parser::OpenMPGroupprivate &);
+ void Post(const parser::OpenMPGroupprivate &) { PopContext(); }
+
bool Pre(const parser::OpenMPStandaloneConstruct &x) {
common::visit(
[&](auto &&s) {
@@ -528,6 +558,9 @@ public:
bool Pre(const parser::OpenMPDeclarativeAllocate &);
void Post(const parser::OpenMPDeclarativeAllocate &) { PopContext(); }
+ bool Pre(const parser::OpenMPAssumeConstruct &);
+ void Post(const parser::OpenMPAssumeConstruct &) { PopContext(); }
+
bool Pre(const parser::OpenMPAtomicConstruct &);
void Post(const parser::OpenMPAtomicConstruct &) { PopContext(); }
@@ -793,7 +826,8 @@ public:
if (name->symbol) {
name->symbol->set(
ompFlag.value_or(Symbol::Flag::OmpMapStorage));
- AddToContextObjectWithDSA(*name->symbol, *ompFlag);
+ AddToContextObjectWithDSA(*name->symbol,
+ ompFlag.value_or(Symbol::Flag::OmpMapStorage));
if (semantics::IsAssumedSizeArray(*name->symbol)) {
context_.Say(designator.source,
"Assumed-size whole arrays may not appear on the %s "
@@ -841,7 +875,8 @@ private:
Symbol::Flags ompFlagsRequireMark{Symbol::Flag::OmpThreadprivate,
Symbol::Flag::OmpDeclareTarget, Symbol::Flag::OmpExclusiveScan,
- Symbol::Flag::OmpInclusiveScan, Symbol::Flag::OmpInScanReduction};
+ Symbol::Flag::OmpInclusiveScan, Symbol::Flag::OmpInScanReduction,
+ Symbol::Flag::OmpGroupPrivate};
Symbol::Flags dataCopyingAttributeFlags{
Symbol::Flag::OmpCopyIn, Symbol::Flag::OmpCopyPrivate};
@@ -876,6 +911,9 @@ private:
bool IsNestedInDirective(llvm::omp::Directive directive);
void ResolveOmpObjectList(const parser::OmpObjectList &, Symbol::Flag);
+ void ResolveOmpDesignator(
+ const parser::Designator &designator, Symbol::Flag ompFlag);
+ void ResolveOmpCommonBlock(const parser::Name &name, Symbol::Flag ompFlag);
void ResolveOmpObject(const parser::OmpObject &, Symbol::Flag);
Symbol *ResolveOmp(const parser::Name &, Symbol::Flag, Scope &);
Symbol *ResolveOmp(Symbol &, Symbol::Flag, Scope &);
@@ -1562,10 +1600,10 @@ void AccAttributeVisitor::Post(const parser::AccDefaultClause &x) {
// and adjust the symbol for each Name if necessary
void AccAttributeVisitor::Post(const parser::Name &name) {
auto *symbol{name.symbol};
- if (symbol && !dirContext_.empty() && GetContext().withinConstruct) {
+ if (symbol && WithinConstruct()) {
symbol = &symbol->GetUltimate();
if (!symbol->owner().IsDerivedType() && !symbol->has<ProcEntityDetails>() &&
- !symbol->has<SubprogramDetails>() && !IsObjectWithDSA(*symbol)) {
+ !symbol->has<SubprogramDetails>() && !IsObjectWithVisibleDSA(*symbol)) {
if (Symbol * found{currScope().FindSymbol(name.source)}) {
if (symbol != found) {
name.symbol = found; // adjust the symbol within region
@@ -1715,7 +1753,7 @@ static std::string ScopeSourcePos(const Fortran::semantics::Scope &scope);
#endif
-bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
+bool OmpAttributeVisitor::Pre(const parser::OmpBlockConstruct &x) {
const parser::OmpDirectiveSpecification &dirSpec{x.BeginDir()};
llvm::omp::Directive dirId{dirSpec.DirId()};
switch (dirId) {
@@ -1732,10 +1770,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
case llvm::omp::Directive::OMPD_task:
case llvm::omp::Directive::OMPD_taskgroup:
case llvm::omp::Directive::OMPD_teams:
+ case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_workshare:
case llvm::omp::Directive::OMPD_parallel_workshare:
case llvm::omp::Directive::OMPD_target_teams:
+ case llvm::omp::Directive::OMPD_target_teams_workdistribute:
case llvm::omp::Directive::OMPD_target_parallel:
+ case llvm::omp::Directive::OMPD_teams_workdistribute:
PushContext(dirSpec.source, dirId);
break;
default:
@@ -1751,7 +1792,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
return true;
}
-void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
+void OmpAttributeVisitor::Post(const parser::OmpBlockConstruct &x) {
const parser::OmpDirectiveSpecification &dirSpec{x.BeginDir()};
llvm::omp::Directive dirId{dirSpec.DirId()};
switch (dirId) {
@@ -1765,9 +1806,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
case llvm::omp::Directive::OMPD_target:
case llvm::omp::Directive::OMPD_task:
case llvm::omp::Directive::OMPD_teams:
+ case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_parallel_workshare:
case llvm::omp::Directive::OMPD_target_teams:
- case llvm::omp::Directive::OMPD_target_parallel: {
+ case llvm::omp::Directive::OMPD_target_parallel:
+ case llvm::omp::Directive::OMPD_target_teams_workdistribute:
+ case llvm::omp::Directive::OMPD_teams_workdistribute: {
bool hasPrivate;
for (const auto *allocName : allocateNames_) {
hasPrivate = false;
@@ -1942,7 +1986,7 @@ void OmpAttributeVisitor::ResolveSeqLoopIndexInParallelOrTaskConstruct(
// till OpenMP-5.0 standard.
// In above both cases we skip the privatization of iteration variables.
bool OmpAttributeVisitor::Pre(const parser::DoConstruct &x) {
- if (!dirContext_.empty() && GetContext().withinConstruct) {
+ if (WithinConstruct()) {
llvm::SmallVector<const parser::Name *> ivs;
if (x.IsDoNormal()) {
const parser::Name *iv{GetLoopIndex(x)};
@@ -2114,6 +2158,18 @@ void OmpAttributeVisitor::CheckAssocLoopLevel(
}
}
+bool OmpAttributeVisitor::Pre(const parser::OpenMPGroupprivate &x) { // Push a GROUPPRIVATE context and resolve the directive's object arguments.
+ PushContext(x.source, llvm::omp::Directive::OMPD_groupprivate);
+ for (const parser::OmpArgument &arg : x.v.Arguments().v) {
+ if (auto *locator{std::get_if<parser::OmpLocator>(&arg.u)}) { // Arguments holding any other alternative are ignored here.
+ if (auto *object{std::get_if<parser::OmpObject>(&locator->u)}) { // Likewise, only locators that hold an OmpObject are handled.
+ ResolveOmpObject(*object, Symbol::Flag::OmpGroupPrivate);
+ }
+ }
+ }
+ return true; // Always true, regardless of how many arguments were resolved.
+}
+
bool OmpAttributeVisitor::Pre(const parser::OpenMPSectionsConstruct &x) {
const auto &beginSectionsDir{
std::get<parser::OmpBeginSectionsDirective>(x.t)};
@@ -2139,8 +2195,8 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPSectionConstruct &x) {
}
bool OmpAttributeVisitor::Pre(const parser::OpenMPCriticalConstruct &x) {
- const auto &beginCriticalDir{std::get<parser::OmpCriticalDirective>(x.t)};
- PushContext(beginCriticalDir.source, llvm::omp::Directive::OMPD_critical);
+ const parser::OmpBeginDirective &beginSpec{x.BeginDir()};
+ PushContext(beginSpec.DirName().source, beginSpec.DirName().v);
GetContext().withinConstruct = true;
return true;
}
@@ -2194,6 +2250,11 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPDeclarativeAllocate &x) {
return false;
}
+bool OmpAttributeVisitor::Pre(const parser::OpenMPAssumeConstruct &x) { // Push an ASSUME directive context for this construct.
+ PushContext(x.source, llvm::omp::Directive::OMPD_assume); // Context is keyed on the construct's source range.
+ return true; // Always true.
+}
+
bool OmpAttributeVisitor::Pre(const parser::OpenMPAtomicConstruct &x) {
PushContext(x.source, llvm::omp::Directive::OMPD_atomic);
return true;
@@ -2435,7 +2496,7 @@ static bool IsTargetCaptureImplicitlyFirstprivatizeable(const Symbol &symbol,
// investigate the flags we can intermix with.
if (!(dsa & (dataSharingAttributeFlags | dataMappingAttributeFlags))
.none() ||
- !checkSym.flags().none() || semantics::IsAssumedShape(checkSym) ||
+ !checkSym.flags().none() || IsAssumedShape(checkSym) ||
semantics::IsAllocatableOrPointer(checkSym)) {
return false;
}
@@ -2651,7 +2712,7 @@ void OmpAttributeVisitor::CreateImplicitSymbols(const Symbol *symbol) {
void OmpAttributeVisitor::Post(const parser::Name &name) {
auto *symbol{name.symbol};
- if (symbol && !dirContext_.empty() && GetContext().withinConstruct) {
+ if (symbol && WithinConstruct()) {
if (IsPrivatizable(symbol) && !IsObjectWithDSA(*symbol)) {
// TODO: create a separate function to go through the rules for
// predetermined, explicitly determined, and implicitly
@@ -2786,196 +2847,182 @@ static bool SymbolOrEquivalentIsInNamelist(const Symbol &symbol) {
});
}
-void OmpAttributeVisitor::ResolveOmpObject(
- const parser::OmpObject &ompObject, Symbol::Flag ompFlag) {
+void OmpAttributeVisitor::ResolveOmpDesignator(
+ const parser::Designator &designator, Symbol::Flag ompFlag) {
unsigned version{context_.langOptions().OpenMPVersion};
- common::visit(
- common::visitors{
- [&](const parser::Designator &designator) {
- if (const auto *name{
- semantics::getDesignatorNameIfDataRef(designator)}) {
- if (auto *symbol{ResolveOmp(*name, ompFlag, currScope())}) {
- auto checkExclusivelists =
- [&](const Symbol *symbol1, Symbol::Flag firstOmpFlag,
- const Symbol *symbol2, Symbol::Flag secondOmpFlag) {
- if ((symbol1->test(firstOmpFlag) &&
- symbol2->test(secondOmpFlag)) ||
- (symbol1->test(secondOmpFlag) &&
- symbol2->test(firstOmpFlag))) {
- context_.Say(designator.source,
- "Variable '%s' may not "
- "appear on both %s and %s "
- "clauses on a %s construct"_err_en_US,
- symbol2->name(),
- Symbol::OmpFlagToClauseName(firstOmpFlag),
- Symbol::OmpFlagToClauseName(secondOmpFlag),
- parser::ToUpperCaseLetters(
- llvm::omp::getOpenMPDirectiveName(
- GetContext().directive, version)
- .str()));
- }
- };
- if (dataCopyingAttributeFlags.test(ompFlag)) {
- CheckDataCopyingClause(*name, *symbol, ompFlag);
- } else {
- AddToContextObjectWithExplicitDSA(*symbol, ompFlag);
- if (dataSharingAttributeFlags.test(ompFlag)) {
- CheckMultipleAppearances(*name, *symbol, ompFlag);
- }
- if (privateDataSharingAttributeFlags.test(ompFlag)) {
- CheckObjectIsPrivatizable(*name, *symbol, ompFlag);
- }
+ llvm::omp::Directive directive{GetContext().directive};
- if (ompFlag == Symbol::Flag::OmpAllocate) {
- AddAllocateName(name);
- }
- }
- if (ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective &&
- IsAllocatable(*symbol) &&
- !IsNestedInDirective(llvm::omp::Directive::OMPD_allocate)) {
- context_.Say(designator.source,
- "List items specified in the ALLOCATE directive must not "
- "have the ALLOCATABLE attribute unless the directive is "
- "associated with an ALLOCATE statement"_err_en_US);
- }
- if ((ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective ||
- ompFlag ==
- Symbol::Flag::OmpExecutableAllocateDirective) &&
- ResolveOmpObjectScope(name) == nullptr) {
- context_.Say(designator.source, // 2.15.3
- "List items must be declared in the same scoping unit "
- "in which the %s directive appears"_err_en_US,
- parser::ToUpperCaseLetters(
- llvm::omp::getOpenMPDirectiveName(
- GetContext().directive, version)
- .str()));
- }
- if (ompFlag == Symbol::Flag::OmpReduction) {
- // Using variables inside of a namelist in OpenMP reductions
- // is allowed by the standard, but is not allowed for
- // privatisation. This looks like an oversight. If the
- // namelist is hoisted to a global, we cannot apply the
- // mapping for the reduction variable: resulting in incorrect
- // results. Disabling this hoisting could make some real
- // production code go slower. See discussion in #109303
- if (SymbolOrEquivalentIsInNamelist(*symbol)) {
- context_.Say(name->source,
- "Variable '%s' in NAMELIST cannot be in a REDUCTION clause"_err_en_US,
- name->ToString());
- }
- }
- if (ompFlag == Symbol::Flag::OmpInclusiveScan ||
- ompFlag == Symbol::Flag::OmpExclusiveScan) {
- if (!symbol->test(Symbol::Flag::OmpInScanReduction)) {
- context_.Say(name->source,
- "List item %s must appear in REDUCTION clause "
- "with the INSCAN modifier of the parent "
- "directive"_err_en_US,
- name->ToString());
- }
- }
- if (ompFlag == Symbol::Flag::OmpDeclareTarget) {
- if (symbol->IsFuncResult()) {
- if (Symbol * func{currScope().symbol()}) {
- CHECK(func->IsSubprogram());
- func->set(ompFlag);
- name->symbol = func;
- }
- }
- }
- if (GetContext().directive ==
- llvm::omp::Directive::OMPD_target_data) {
- checkExclusivelists(symbol, Symbol::Flag::OmpUseDevicePtr,
- symbol, Symbol::Flag::OmpUseDeviceAddr);
- }
- if (llvm::omp::allDistributeSet.test(GetContext().directive)) {
- checkExclusivelists(symbol, Symbol::Flag::OmpFirstPrivate,
- symbol, Symbol::Flag::OmpLastPrivate);
- }
- if (llvm::omp::allTargetSet.test(GetContext().directive)) {
- checkExclusivelists(symbol, Symbol::Flag::OmpIsDevicePtr,
- symbol, Symbol::Flag::OmpHasDeviceAddr);
- const auto *hostAssocSym{symbol};
- if (!(symbol->test(Symbol::Flag::OmpIsDevicePtr) ||
- symbol->test(Symbol::Flag::OmpHasDeviceAddr))) {
- if (const auto *details{
- symbol->detailsIf<HostAssocDetails>()}) {
- hostAssocSym = &details->symbol();
- }
- }
- Symbol::Flag dataMappingAttributeFlags[] = {
- Symbol::Flag::OmpMapTo, Symbol::Flag::OmpMapFrom,
- Symbol::Flag::OmpMapToFrom, Symbol::Flag::OmpMapStorage,
- Symbol::Flag::OmpMapDelete, Symbol::Flag::OmpIsDevicePtr,
- Symbol::Flag::OmpHasDeviceAddr};
-
- Symbol::Flag dataSharingAttributeFlags[] = {
- Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate,
- Symbol::Flag::OmpLastPrivate, Symbol::Flag::OmpShared,
- Symbol::Flag::OmpLinear};
-
- // For OMP TARGET TEAMS directive some sharing attribute
- // flags and mapping attribute flags can co-exist.
- if (!(llvm::omp::allTeamsSet.test(GetContext().directive) ||
- llvm::omp::allParallelSet.test(
- GetContext().directive))) {
- for (Symbol::Flag ompFlag1 : dataMappingAttributeFlags) {
- for (Symbol::Flag ompFlag2 : dataSharingAttributeFlags) {
- if ((hostAssocSym->test(ompFlag2) &&
- hostAssocSym->test(
- Symbol::Flag::OmpExplicit)) ||
- (symbol->test(ompFlag2) &&
- symbol->test(Symbol::Flag::OmpExplicit))) {
- checkExclusivelists(
- hostAssocSym, ompFlag1, symbol, ompFlag2);
- }
- }
- }
- }
- }
- }
- } else {
- // Array sections to be changed to substrings as needed
- if (AnalyzeExpr(context_, designator)) {
- if (std::holds_alternative<parser::Substring>(designator.u)) {
- context_.Say(designator.source,
- "Substrings are not allowed on OpenMP "
- "directives or clauses"_err_en_US);
- }
- }
- // other checks, more TBD
- }
- },
- [&](const parser::Name &name) { // common block
- if (auto *symbol{ResolveOmpCommonBlockName(&name)}) {
- if (!dataCopyingAttributeFlags.test(ompFlag)) {
- CheckMultipleAppearances(
- name, *symbol, Symbol::Flag::OmpCommonBlock);
- }
- // 2.15.3 When a named common block appears in a list, it has the
- // same meaning as if every explicit member of the common block
- // appeared in the list
- auto &details{symbol->get<CommonBlockDetails>()};
- unsigned index{0};
- for (auto &object : details.objects()) {
- if (auto *resolvedObject{
- ResolveOmp(*object, ompFlag, currScope())}) {
- if (dataCopyingAttributeFlags.test(ompFlag)) {
- CheckDataCopyingClause(name, *resolvedObject, ompFlag);
- } else {
- AddToContextObjectWithExplicitDSA(*resolvedObject, ompFlag);
- }
- details.replace_object(*resolvedObject, index);
- }
- index++;
- }
- } else {
- context_.Say(name.source, // 2.15.3
- "COMMON block must be declared in the same scoping unit "
- "in which the OpenMP directive or clause appears"_err_en_US);
+ const auto *name{semantics::getDesignatorNameIfDataRef(designator)};
+ if (!name) {
+ // Array sections to be changed to substrings as needed
+ if (AnalyzeExpr(context_, designator)) {
+ if (std::holds_alternative<parser::Substring>(designator.u)) {
+ context_.Say(designator.source,
+ "Substrings are not allowed on OpenMP directives or clauses"_err_en_US);
+ }
+ }
+ // other checks, more TBD
+ return;
+ }
+
+ if (auto *symbol{ResolveOmp(*name, ompFlag, currScope())}) {
+ auto checkExclusivelists{//
+ [&](const Symbol *symbol1, Symbol::Flag firstOmpFlag,
+ const Symbol *symbol2, Symbol::Flag secondOmpFlag) {
+ if ((symbol1->test(firstOmpFlag) && symbol2->test(secondOmpFlag)) ||
+ (symbol1->test(secondOmpFlag) && symbol2->test(firstOmpFlag))) {
+ context_.Say(designator.source,
+ "Variable '%s' may not appear on both %s and %s clauses on a %s construct"_err_en_US,
+ symbol2->name(), Symbol::OmpFlagToClauseName(firstOmpFlag),
+ Symbol::OmpFlagToClauseName(secondOmpFlag),
+ parser::ToUpperCaseLetters(
+ llvm::omp::getOpenMPDirectiveName(directive, version)));
+ }
+ }};
+ if (dataCopyingAttributeFlags.test(ompFlag)) {
+ CheckDataCopyingClause(*name, *symbol, ompFlag);
+ } else {
+ AddToContextObjectWithExplicitDSA(*symbol, ompFlag);
+ if (dataSharingAttributeFlags.test(ompFlag)) {
+ CheckMultipleAppearances(*name, *symbol, ompFlag);
+ }
+ if (privateDataSharingAttributeFlags.test(ompFlag)) {
+ CheckObjectIsPrivatizable(*name, *symbol, ompFlag);
+ }
+
+ if (ompFlag == Symbol::Flag::OmpAllocate) {
+ AddAllocateName(name);
+ }
+ }
+ if (ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective &&
+ IsAllocatable(*symbol) &&
+ !IsNestedInDirective(llvm::omp::Directive::OMPD_allocate)) {
+ context_.Say(designator.source,
+ "List items specified in the ALLOCATE directive must not have the ALLOCATABLE attribute unless the directive is associated with an ALLOCATE statement"_err_en_US);
+ }
+ if ((ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective ||
+ ompFlag == Symbol::Flag::OmpExecutableAllocateDirective) &&
+ ResolveOmpObjectScope(name) == nullptr) {
+ context_.Say(designator.source, // 2.15.3
+ "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US,
+ parser::ToUpperCaseLetters(
+ llvm::omp::getOpenMPDirectiveName(directive, version)));
+ }
+ if (ompFlag == Symbol::Flag::OmpReduction) {
+ // Using variables inside of a namelist in OpenMP reductions
+ // is allowed by the standard, but is not allowed for
+ // privatisation. This looks like an oversight. If the
+ // namelist is hoisted to a global, we cannot apply the
+ // mapping for the reduction variable: resulting in incorrect
+ // results. Disabling this hoisting could make some real
+ // production code go slower. See discussion in #109303
+ if (SymbolOrEquivalentIsInNamelist(*symbol)) {
+ context_.Say(name->source,
+ "Variable '%s' in NAMELIST cannot be in a REDUCTION clause"_err_en_US,
+ name->ToString());
+ }
+ }
+ if (ompFlag == Symbol::Flag::OmpInclusiveScan ||
+ ompFlag == Symbol::Flag::OmpExclusiveScan) {
+ if (!symbol->test(Symbol::Flag::OmpInScanReduction)) {
+ context_.Say(name->source,
+ "List item %s must appear in REDUCTION clause with the INSCAN modifier of the parent directive"_err_en_US,
+ name->ToString());
+ }
+ }
+ if (ompFlag == Symbol::Flag::OmpDeclareTarget) {
+ if (symbol->IsFuncResult()) {
+ if (Symbol * func{currScope().symbol()}) {
+ CHECK(func->IsSubprogram());
+ func->set(ompFlag);
+ name->symbol = func;
+ }
+ }
+ }
+ if (directive == llvm::omp::Directive::OMPD_target_data) {
+ checkExclusivelists(symbol, Symbol::Flag::OmpUseDevicePtr, symbol,
+ Symbol::Flag::OmpUseDeviceAddr);
+ }
+ if (llvm::omp::allDistributeSet.test(directive)) {
+ checkExclusivelists(symbol, Symbol::Flag::OmpFirstPrivate, symbol,
+ Symbol::Flag::OmpLastPrivate);
+ }
+ if (llvm::omp::allTargetSet.test(directive)) {
+ checkExclusivelists(symbol, Symbol::Flag::OmpIsDevicePtr, symbol,
+ Symbol::Flag::OmpHasDeviceAddr);
+ const auto *hostAssocSym{symbol};
+ if (!symbol->test(Symbol::Flag::OmpIsDevicePtr) &&
+ !symbol->test(Symbol::Flag::OmpHasDeviceAddr)) {
+ if (const auto *details{symbol->detailsIf<HostAssocDetails>()}) {
+ hostAssocSym = &details->symbol();
+ }
+ }
+ static Symbol::Flag dataMappingAttributeFlags[] = {//
+ Symbol::Flag::OmpMapTo, Symbol::Flag::OmpMapFrom,
+ Symbol::Flag::OmpMapToFrom, Symbol::Flag::OmpMapStorage,
+ Symbol::Flag::OmpMapDelete, Symbol::Flag::OmpIsDevicePtr,
+ Symbol::Flag::OmpHasDeviceAddr};
+
+ static Symbol::Flag dataSharingAttributeFlags[] = {//
+ Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate,
+ Symbol::Flag::OmpLastPrivate, Symbol::Flag::OmpShared,
+ Symbol::Flag::OmpLinear};
+
+ // For OMP TARGET TEAMS directive some sharing attribute
+ // flags and mapping attribute flags can co-exist.
+ if (!llvm::omp::allTeamsSet.test(directive) &&
+ !llvm::omp::allParallelSet.test(directive)) {
+ for (Symbol::Flag ompFlag1 : dataMappingAttributeFlags) {
+ for (Symbol::Flag ompFlag2 : dataSharingAttributeFlags) {
+ if ((hostAssocSym->test(ompFlag2) &&
+ hostAssocSym->test(Symbol::Flag::OmpExplicit)) ||
+ (symbol->test(ompFlag2) &&
+ symbol->test(Symbol::Flag::OmpExplicit))) {
+ checkExclusivelists(hostAssocSym, ompFlag1, symbol, ompFlag2);
}
- },
- },
+ }
+ }
+ }
+ }
+ }
+}
+
+void OmpAttributeVisitor::ResolveOmpCommonBlock(
+ const parser::Name &name, Symbol::Flag ompFlag) { // Resolve a common-block name from an OpenMP object list and apply `ompFlag` to each of its objects.
+ if (auto *symbol{ResolveOmpCommonBlockName(&name)}) {
+ if (!dataCopyingAttributeFlags.test(ompFlag)) { // The multiple-appearance check is skipped for data-copying flags.
+ CheckMultipleAppearances(name, *symbol, Symbol::Flag::OmpCommonBlock);
+ }
+ // 2.15.3 When a named common block appears in a list, it has the
+ // same meaning as if every explicit member of the common block
+ // appeared in the list
+ auto &details{symbol->get<CommonBlockDetails>()};
+ for (auto [index, object] : llvm::enumerate(details.objects())) { // `index` is the object's position, used for replace_object below.
+ if (auto *resolvedObject{ResolveOmp(*object, ompFlag, currScope())}) {
+ if (dataCopyingAttributeFlags.test(ompFlag)) {
+ CheckDataCopyingClause(name, *resolvedObject, ompFlag);
+ } else {
+ AddToContextObjectWithExplicitDSA(*resolvedObject, ompFlag);
+ }
+ details.replace_object(*resolvedObject, index); // Store the resolved symbol back at the same position in the common block.
+ }
+ }
+ } else { // The name did not resolve to a common block: report an error at the name.
+ context_.Say(name.source, // 2.15.3
+ "COMMON block must be declared in the same scoping unit in which the OpenMP directive or clause appears"_err_en_US);
+ }
+}
+
+void OmpAttributeVisitor::ResolveOmpObject(
+ const parser::OmpObject &ompObject, Symbol::Flag ompFlag) {
+ common::visit(common::visitors{
+ [&](const parser::Designator &designator) {
+ ResolveOmpDesignator(designator, ompFlag);
+ },
+ [&](const parser::Name &name) { // common block
+ ResolveOmpCommonBlock(name, ompFlag);
+ },
+ },
ompObject.u);
}
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 66a45dd..4720932 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -30,6 +30,7 @@
#include "flang/Semantics/attr.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/openmp-modifiers.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/program-tree.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
@@ -487,6 +488,10 @@ public:
// Result symbol
Symbol *resultSymbol{nullptr};
bool inFunctionStmt{false}; // true between Pre/Post of FunctionStmt
+ // Functions with previous implicitly-typed references get those types
+ // checked against their later definitions.
+ const DeclTypeSpec *previousImplicitType{nullptr};
+ SourceName previousName;
};
// Completes the definition of the top function's result.
@@ -942,7 +947,7 @@ private:
// Edits an existing symbol created for earlier calls to a subprogram or ENTRY
// so that it can be replaced by a later definition.
bool HandlePreviousCalls(const parser::Name &, Symbol &, Symbol::Flag);
- void CheckExtantProc(const parser::Name &, Symbol::Flag);
+ const Symbol *CheckExtantProc(const parser::Name &, Symbol::Flag);
// Create a subprogram symbol in the current scope and push a new scope.
Symbol &PushSubprogramScope(const parser::Name &, Symbol::Flag,
const parser::LanguageBindingSpec * = nullptr,
@@ -1465,7 +1470,7 @@ class OmpVisitor : public virtual DeclarationVisitor {
public:
void AddOmpSourceRange(const parser::CharBlock &);
- static bool NeedsScope(const parser::OpenMPBlockConstruct &);
+ static bool NeedsScope(const parser::OmpBlockConstruct &);
static bool NeedsScope(const parser::OmpClause &);
bool Pre(const parser::OmpMetadirectiveDirective &x) { //
@@ -1482,10 +1487,20 @@ public:
AddOmpSourceRange(x.source);
return true;
}
- bool Pre(const parser::OpenMPBlockConstruct &);
- void Post(const parser::OpenMPBlockConstruct &);
+ bool Pre(const parser::OmpBlockConstruct &);
+ void Post(const parser::OmpBlockConstruct &);
bool Pre(const parser::OmpBeginDirective &x) {
AddOmpSourceRange(x.source);
+ // Manually resolve names in CRITICAL directives. This is because these
+ // names do not denote Fortran objects, and the CRITICAL directive causes
+ // them to be "auto-declared", i.e. inserted into the global scope.
+ // More specifically, they are not expected to have explicit declarations,
+ // and if they do the behavior is unspecified.
+ if (x.DirName().v == llvm::omp::Directive::OMPD_critical) {
+ for (const parser::OmpArgument &arg : x.Arguments().v) {
+ ResolveCriticalName(arg);
+ }
+ }
return true;
}
void Post(const parser::OmpBeginDirective &) {
@@ -1493,6 +1508,12 @@ public:
}
bool Pre(const parser::OmpEndDirective &x) {
AddOmpSourceRange(x.source);
+ // Manually resolve names in CRITICAL directives.
+ if (x.DirName().v == llvm::omp::Directive::OMPD_critical) {
+ for (const parser::OmpArgument &arg : x.Arguments().v) {
+ ResolveCriticalName(arg);
+ }
+ }
return true;
}
void Post(const parser::OmpEndDirective &) {
@@ -1591,32 +1612,6 @@ public:
void Post(const parser::OmpEndSectionsDirective &) {
messageHandler().set_currStmtSource(std::nullopt);
}
- bool Pre(const parser::OmpCriticalDirective &x) {
- AddOmpSourceRange(x.source);
- // Manually resolve names in CRITICAL directives. This is because these
- // names do not denote Fortran objects, and the CRITICAL directive causes
- // them to be "auto-declared", i.e. inserted into the global scope.
- // More specifically, they are not expected to have explicit declarations,
- // and if they do the behavior is unspeficied.
- if (auto &maybeName{std::get<std::optional<parser::Name>>(x.t)}) {
- ResolveCriticalName(*maybeName);
- }
- return true;
- }
- void Post(const parser::OmpCriticalDirective &) {
- messageHandler().set_currStmtSource(std::nullopt);
- }
- bool Pre(const parser::OmpEndCriticalDirective &x) {
- AddOmpSourceRange(x.source);
- // Manually resolve names in CRITICAL directives.
- if (auto &maybeName{std::get<std::optional<parser::Name>>(x.t)}) {
- ResolveCriticalName(*maybeName);
- }
- return true;
- }
- void Post(const parser::OmpEndCriticalDirective &) {
- messageHandler().set_currStmtSource(std::nullopt);
- }
bool Pre(const parser::OpenMPThreadprivate &) {
SkipImplicitTyping(true);
return true;
@@ -1732,13 +1727,13 @@ private:
const std::optional<parser::OmpClauseList> &clauses,
const T &wholeConstruct);
- void ResolveCriticalName(const parser::Name &name);
+ void ResolveCriticalName(const parser::OmpArgument &arg);
int metaLevel_{0};
const parser::OmpMetadirectiveDirective *metaDirective_{nullptr};
};
-bool OmpVisitor::NeedsScope(const parser::OpenMPBlockConstruct &x) {
+bool OmpVisitor::NeedsScope(const parser::OmpBlockConstruct &x) {
switch (x.BeginDir().DirId()) {
case llvm::omp::Directive::OMPD_master:
case llvm::omp::Directive::OMPD_ordered:
@@ -1759,14 +1754,14 @@ void OmpVisitor::AddOmpSourceRange(const parser::CharBlock &source) {
currScope().AddSourceRange(source);
}
-bool OmpVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
+bool OmpVisitor::Pre(const parser::OmpBlockConstruct &x) {
if (NeedsScope(x)) {
PushScope(Scope::Kind::OtherConstruct, nullptr);
}
return true;
}
-void OmpVisitor::Post(const parser::OpenMPBlockConstruct &x) {
+void OmpVisitor::Post(const parser::OmpBlockConstruct &x) {
if (NeedsScope(x)) {
PopScope();
}
@@ -1961,7 +1956,7 @@ void OmpVisitor::ProcessReductionSpecifier(
}
}
-void OmpVisitor::ResolveCriticalName(const parser::Name &name) {
+void OmpVisitor::ResolveCriticalName(const parser::OmpArgument &arg) {
auto &globalScope{[&]() -> Scope & {
for (Scope *s{&currScope()};; s = &s->parent()) {
if (s->IsTopLevel()) {
@@ -1971,15 +1966,21 @@ void OmpVisitor::ResolveCriticalName(const parser::Name &name) {
llvm_unreachable("Cannot find global scope");
}()};
- if (auto *symbol{FindInScope(globalScope, name)}) {
- if (!symbol->test(Symbol::Flag::OmpCriticalLock)) {
- SayWithDecl(name, *symbol,
- "CRITICAL construct name '%s' conflicts with a previous declaration"_warn_en_US,
- name.ToString());
+ if (auto *object{parser::Unwrap<parser::OmpObject>(arg.u)}) {
+ if (auto *desg{omp::GetDesignatorFromObj(*object)}) {
+ if (auto *name{getDesignatorNameIfDataRef(*desg)}) {
+ if (auto *symbol{FindInScope(globalScope, *name)}) {
+ if (!symbol->test(Symbol::Flag::OmpCriticalLock)) {
+ SayWithDecl(*name, *symbol,
+ "CRITICAL construct name '%s' conflicts with a previous declaration"_warn_en_US,
+ name->ToString());
+ }
+ } else {
+ name->symbol = &MakeSymbol(globalScope, name->source, Attrs{});
+ name->symbol->set(Symbol::Flag::OmpCriticalLock);
+ }
+ }
}
- } else {
- name.symbol = &MakeSymbol(globalScope, name.source, Attrs{});
- name.symbol->set(Symbol::Flag::OmpCriticalLock);
}
}
@@ -2694,11 +2695,24 @@ void ArraySpecVisitor::PostAttrSpec() {
FuncResultStack::~FuncResultStack() { CHECK(stack_.empty()); }
+// True only when both types are present and are not equivalent for
+// interface compatibility purposes; false when either type is absent.
+static bool TypesMismatchIfNonNull(
+ const DeclTypeSpec *type1, const DeclTypeSpec *type2) {
+ if (auto t1{evaluate::DynamicType::From(type1)}) {
+ if (auto t2{evaluate::DynamicType::From(type2)}) {
+ return !t1->IsEquivalentTo(*t2);
+ }
+ }
+ return false;
+}
+
void FuncResultStack::CompleteFunctionResultType() {
// If the function has a type in the prefix, process it now.
FuncInfo *info{Top()};
- if (info && &info->scope == &scopeHandler_.currScope()) {
- if (info->parsedType && info->resultSymbol) {
+ if (info && &info->scope == &scopeHandler_.currScope() &&
+ info->resultSymbol) {
+ if (info->parsedType) {
scopeHandler_.messageHandler().set_currStmtSource(info->source);
if (const auto *type{
scopeHandler_.ProcessTypeSpec(*info->parsedType, true)}) {
@@ -2715,6 +2729,16 @@ void FuncResultStack::CompleteFunctionResultType() {
}
info->parsedType = nullptr;
}
+ if (TypesMismatchIfNonNull(
+ info->resultSymbol->GetType(), info->previousImplicitType)) {
+ scopeHandler_
+ .Say(info->resultSymbol->name(),
+ "Function '%s' has a result type that differs from the implicit type it obtained in a previous reference"_err_en_US,
+ info->previousName)
+ .Attach(info->previousName,
+ "Previous reference implicitly typed as %s\n"_en_US,
+ info->previousImplicitType->AsFortran());
+ }
}
}
@@ -4764,9 +4788,7 @@ void SubprogramVisitor::Post(const parser::FunctionStmt &stmt) {
if (info.resultName && !distinctResultName) {
context().Warn(common::UsageWarning::HomonymousResult,
info.resultName->source,
- "The function name should not appear in RESULT; references to '%s' "
- "inside the function will be considered as references to the "
- "result only"_warn_en_US,
+ "The function name should not appear in RESULT; references to '%s' inside the function will be considered as references to the result only"_warn_en_US,
name.source);
// RESULT name was ignored above, the only side effect from doing so will be
// the inability to make recursive calls. The related parser::Name is still
@@ -5077,8 +5099,7 @@ bool SubprogramVisitor::BeginSubprogram(const parser::Name &name,
if (hasModulePrefix && !currScope().IsModule() &&
!currScope().IsSubmodule()) { // C1547
Say(name,
- "'%s' is a MODULE procedure which must be declared within a "
- "MODULE or SUBMODULE"_err_en_US);
+ "'%s' is a MODULE procedure which must be declared within a MODULE or SUBMODULE"_err_en_US);
// Don't return here because it can be useful to have the scope set for
// other semantic checks run before we print the errors
isValid = false;
@@ -5199,9 +5220,10 @@ bool SubprogramVisitor::HandlePreviousCalls(
}
}
-void SubprogramVisitor::CheckExtantProc(
+const Symbol *SubprogramVisitor::CheckExtantProc(
const parser::Name &name, Symbol::Flag subpFlag) {
- if (auto *prev{FindSymbol(name)}) {
+ Symbol *prev{FindSymbol(name)};
+ if (prev) {
if (IsDummy(*prev)) {
} else if (auto *entity{prev->detailsIf<EntityDetails>()};
IsPointer(*prev) && entity && !entity->type()) {
@@ -5213,12 +5235,15 @@ void SubprogramVisitor::CheckExtantProc(
SayAlreadyDeclared(name, *prev);
}
}
+ return prev;
}
Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name,
Symbol::Flag subpFlag, const parser::LanguageBindingSpec *bindingSpec,
bool hasModulePrefix) {
Symbol *symbol{GetSpecificFromGeneric(name)};
+ const DeclTypeSpec *previousImplicitType{nullptr};
+ SourceName previousName;
if (!symbol) {
if (bindingSpec && currScope().IsGlobal() &&
std::get<std::optional<parser::ScalarDefaultCharConstantExpr>>(
@@ -5231,14 +5256,25 @@ Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name,
&MakeSymbol(context().GetTempName(currScope()), Attrs{},
MiscDetails{MiscDetails::Kind::ScopeName}));
}
- CheckExtantProc(name, subpFlag);
+ if (const Symbol *previous{CheckExtantProc(name, subpFlag)}) {
+ if (previous->test(Symbol::Flag::Function) &&
+ previous->test(Symbol::Flag::Implicit)) {
+ // Function was implicitly typed in previous compilation unit.
+ previousImplicitType = previous->GetType();
+ previousName = previous->name();
+ }
+ }
symbol = &MakeSymbol(name, SubprogramDetails{});
}
symbol->ReplaceName(name.source);
symbol->set(subpFlag);
PushScope(Scope::Kind::Subprogram, symbol);
if (subpFlag == Symbol::Flag::Function) {
- funcResultStack().Push(currScope(), name.source);
+ auto &funcResultTop{funcResultStack().Push(currScope(), name.source)};
+ funcResultTop.previousImplicitType = previousImplicitType;
+ ;
+ funcResultTop.previousName = previousName;
+ ;
}
if (inInterfaceBlock()) {
auto &details{symbol->get<SubprogramDetails>()};
@@ -7916,7 +7952,7 @@ void ConstructVisitor::Post(const parser::AssociateStmt &x) {
if (ExtractCoarrayRef(expr)) { // C1103
Say("Selector must not be a coindexed object"_err_en_US);
}
- if (evaluate::IsAssumedRank(expr)) {
+ if (IsAssumedRank(expr)) {
Say("Selector must not be assumed-rank"_err_en_US);
}
SetTypeFromAssociation(*symbol);
@@ -8672,11 +8708,6 @@ const parser::Name *DeclarationVisitor::ResolveDataRef(
x.u);
}
-static bool TypesMismatchIfNonNull(
- const DeclTypeSpec *type1, const DeclTypeSpec *type2) {
- return type1 && type2 && *type1 != *type2;
-}
-
// If implicit types are allowed, ensure name is in the symbol table.
// Otherwise, report an error if it hasn't been declared.
const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) {
diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 4eeb1b9..eae22dc 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -12,6 +12,7 @@
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Parser/tools.h"
+#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
@@ -41,11 +42,23 @@ public:
void Post(parser::Name &);
bool Pre(parser::MainProgram &);
+ bool Pre(parser::Module &);
bool Pre(parser::FunctionSubprogram &);
bool Pre(parser::SubroutineSubprogram &);
bool Pre(parser::SeparateModuleSubprogram &);
bool Pre(parser::BlockConstruct &);
+ bool Pre(parser::Block &);
+ bool Pre(parser::DoConstruct &);
+ bool Pre(parser::IfConstruct &);
bool Pre(parser::ActionStmt &);
+ void Post(parser::MainProgram &);
+ void Post(parser::FunctionSubprogram &);
+ void Post(parser::SubroutineSubprogram &);
+ void Post(parser::SeparateModuleSubprogram &);
+ void Post(parser::BlockConstruct &);
+ void Post(parser::Block &);
+ void Post(parser::DoConstruct &);
+ void Post(parser::IfConstruct &);
void Post(parser::ReadStmt &);
void Post(parser::WriteStmt &);
@@ -67,8 +80,15 @@ public:
bool Pre(parser::EndSubroutineStmt &) { return false; }
bool Pre(parser::EndTypeStmt &) { return false; }
+ bool Pre(parser::OmpBlockConstruct &);
+ bool Pre(parser::OpenMPLoopConstruct &);
+ void Post(parser::OmpBlockConstruct &);
+ void Post(parser::OpenMPLoopConstruct &);
+
private:
void FixMisparsedStmtFuncs(parser::SpecificationPart &, parser::Block &);
+ void OpenMPSimdOnly(parser::Block &, bool);
+ void OpenMPSimdOnly(parser::SpecificationPart &);
SemanticsContext &context_;
bool errorOnUnresolvedName_{true};
@@ -96,6 +116,132 @@ static bool ReturnsDataPointer(const Symbol &symbol) {
return false;
}
+static bool LoopConstructIsSIMD(parser::OpenMPLoopConstruct *ompLoop) {
+ auto &begin = std::get<parser::OmpBeginLoopDirective>(ompLoop->t);
+ auto directive = std::get<parser::OmpLoopDirective>(begin.t).v;
+ return llvm::omp::allSimdSet.test(directive);
+}
+
+// Remove non-SIMD OpenMPConstructs once they are parsed.
+// This massively simplifies the logic inside the SimdOnlyPass for
+// -fopenmp-simd.
+void RewriteMutator::OpenMPSimdOnly(parser::SpecificationPart &specPart) {
+ auto &list{std::get<std::list<parser::DeclarationConstruct>>(specPart.t)};
+ for (auto it{list.begin()}; it != list.end();) {
+ if (auto *specConstr{std::get_if<parser::SpecificationConstruct>(&it->u)}) {
+ if (auto *ompDecl{std::get_if<
+ common::Indirection<parser::OpenMPDeclarativeConstruct>>(
+ &specConstr->u)}) {
+ if (std::holds_alternative<parser::OpenMPThreadprivate>(
+ ompDecl->value().u) ||
+ std::holds_alternative<parser::OpenMPDeclareMapperConstruct>(
+ ompDecl->value().u)) {
+ it = list.erase(it);
+ continue;
+ }
+ }
+ }
+ ++it;
+ }
+}
+
+// Remove non-SIMD OpenMPConstructs once they are parsed.
+// This massively simplifies the logic inside the SimdOnlyPass for
+// -fopenmp-simd. `isNonSimdLoopBody` should be set to true if `block` is the
+// body of a non-simd OpenMP loop. This is to indicate that scan constructs
+// should be removed from the body, where they would be kept if it were a simd
+// loop.
+void RewriteMutator::OpenMPSimdOnly(
+ parser::Block &block, bool isNonSimdLoopBody = false) {
+ auto replaceInlineBlock =
+ [&](std::list<parser::ExecutionPartConstruct> &innerBlock,
+ auto it) -> auto {
+ auto insertPos = std::next(it);
+ block.splice(insertPos, innerBlock);
+ block.erase(it);
+ return insertPos;
+ };
+
+ for (auto it{block.begin()}; it != block.end();) {
+ if (auto *stmt{std::get_if<parser::ExecutableConstruct>(&it->u)}) {
+ if (auto *omp{std::get_if<common::Indirection<parser::OpenMPConstruct>>(
+ &stmt->u)}) {
+ if (auto *ompStandalone{std::get_if<parser::OpenMPStandaloneConstruct>(
+ &omp->value().u)}) {
+ if (std::holds_alternative<parser::OpenMPCancelConstruct>(
+ ompStandalone->u) ||
+ std::holds_alternative<parser::OpenMPFlushConstruct>(
+ ompStandalone->u) ||
+ std::holds_alternative<parser::OpenMPCancellationPointConstruct>(
+ ompStandalone->u)) {
+ it = block.erase(it);
+ continue;
+ }
+ if (auto *constr{std::get_if<parser::OpenMPSimpleStandaloneConstruct>(
+ &ompStandalone->u)}) {
+ auto directive = constr->v.DirId();
+ // Scan should only be removed from non-simd loops
+ if (llvm::omp::simpleStandaloneNonSimdOnlySet.test(directive) ||
+ (isNonSimdLoopBody && directive == llvm::omp::OMPD_scan)) {
+ it = block.erase(it);
+ continue;
+ }
+ }
+ } else if (auto *ompBlock{std::get_if<parser::OmpBlockConstruct>(
+ &omp->value().u)}) {
+ it = replaceInlineBlock(std::get<parser::Block>(ompBlock->t), it);
+ continue;
+ } else if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(
+ &omp->value().u)}) {
+ if (LoopConstructIsSIMD(ompLoop)) {
+ ++it;
+ continue;
+ }
+ auto &nest =
+ std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
+
+ if (auto *doConstruct =
+ std::get_if<parser::DoConstruct>(&nest.value())) {
+ auto &loopBody = std::get<parser::Block>(doConstruct->t);
+ // We can only remove some constructs from a loop when it's _not_ an
+ // OpenMP simd loop
+ OpenMPSimdOnly(loopBody, /*isNonSimdLoopBody=*/true);
+ auto newDoConstruct = std::move(*doConstruct);
+ auto newLoop = parser::ExecutionPartConstruct{
+ parser::ExecutableConstruct{std::move(newDoConstruct)}};
+ it = block.erase(it);
+ block.insert(it, std::move(newLoop));
+ continue;
+ }
+ } else if (auto *ompCon{std::get_if<parser::OpenMPSectionsConstruct>(
+ &omp->value().u)}) {
+ auto &sections =
+ std::get<std::list<parser::OpenMPConstruct>>(ompCon->t);
+ auto insertPos = std::next(it);
+ for (auto &sectionCon : sections) {
+ auto &section =
+ std::get<parser::OpenMPSectionConstruct>(sectionCon.u);
+ auto &innerBlock = std::get<parser::Block>(section.t);
+ block.splice(insertPos, innerBlock);
+ }
+ block.erase(it);
+ it = insertPos;
+ continue;
+ } else if (auto *atomic{std::get_if<parser::OpenMPAtomicConstruct>(
+ &omp->value().u)}) {
+ it = replaceInlineBlock(std::get<parser::Block>(atomic->t), it);
+ continue;
+ } else if (auto *critical{std::get_if<parser::OpenMPCriticalConstruct>(
+ &omp->value().u)}) {
+ it = replaceInlineBlock(std::get<parser::Block>(critical->t), it);
+ continue;
+ }
+ }
+ }
+ ++it;
+ }
+}
+
// Finds misparsed statement functions in a specification part, rewrites
// them into array element assignment statements, and moves them into the
// beginning of the corresponding (execution part's) block.
@@ -133,33 +279,155 @@ void RewriteMutator::FixMisparsedStmtFuncs(
bool RewriteMutator::Pre(parser::MainProgram &program) {
FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(program.t),
std::get<parser::ExecutionPart>(program.t).v);
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(program.t).v);
+ OpenMPSimdOnly(std::get<parser::SpecificationPart>(program.t));
+ }
+ return true;
+}
+
+void RewriteMutator::Post(parser::MainProgram &program) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(program.t).v);
+ }
+}
+
+bool RewriteMutator::Pre(parser::Module &module) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::SpecificationPart>(module.t));
+ }
return true;
}
bool RewriteMutator::Pre(parser::FunctionSubprogram &func) {
FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(func.t),
std::get<parser::ExecutionPart>(func.t).v);
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(func.t).v);
+ }
return true;
}
+void RewriteMutator::Post(parser::FunctionSubprogram &func) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(func.t).v);
+ }
+}
+
bool RewriteMutator::Pre(parser::SubroutineSubprogram &subr) {
FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(subr.t),
std::get<parser::ExecutionPart>(subr.t).v);
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(subr.t).v);
+ }
return true;
}
+void RewriteMutator::Post(parser::SubroutineSubprogram &subr) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(subr.t).v);
+ }
+}
+
bool RewriteMutator::Pre(parser::SeparateModuleSubprogram &subp) {
FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(subp.t),
std::get<parser::ExecutionPart>(subp.t).v);
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(subp.t).v);
+ }
return true;
}
+void RewriteMutator::Post(parser::SeparateModuleSubprogram &subp) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::ExecutionPart>(subp.t).v);
+ }
+}
+
bool RewriteMutator::Pre(parser::BlockConstruct &block) {
FixMisparsedStmtFuncs(std::get<parser::BlockSpecificationPart>(block.t).v,
std::get<parser::Block>(block.t));
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::Block>(block.t));
+ }
+ return true;
+}
+
+void RewriteMutator::Post(parser::BlockConstruct &block) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(std::get<parser::Block>(block.t));
+ }
+}
+
+bool RewriteMutator::Pre(parser::Block &block) {
+ if (context_.langOptions().OpenMPSimd) {
+ OpenMPSimdOnly(block);
+ }
return true;
}
+void RewriteMutator::Post(parser::Block &block) { this->Pre(block); }
+
+bool RewriteMutator::Pre(parser::OmpBlockConstruct &block) {
+ if (context_.langOptions().OpenMPSimd) {
+ auto &innerBlock = std::get<parser::Block>(block.t);
+ OpenMPSimdOnly(innerBlock);
+ }
+ return true;
+}
+
+void RewriteMutator::Post(parser::OmpBlockConstruct &block) {
+ this->Pre(block);
+}
+
+bool RewriteMutator::Pre(parser::OpenMPLoopConstruct &ompLoop) {
+ if (context_.langOptions().OpenMPSimd) {
+ if (LoopConstructIsSIMD(&ompLoop)) {
+ return true;
+ }
+ // If we're looking at a non-simd OpenMP loop, we need to explicitly
+ // call OpenMPSimdOnly on the nested loop block while indicating where
+ // the block comes from.
+ auto &nest = std::get<std::optional<parser::NestedConstruct>>(ompLoop.t);
+ if (!nest.has_value()) {
+ return true;
+ }
+ if (auto *doConstruct = std::get_if<parser::DoConstruct>(&*nest)) {
+ auto &innerBlock = std::get<parser::Block>(doConstruct->t);
+ OpenMPSimdOnly(innerBlock, /*isNonSimdLoopBody=*/true);
+ }
+ }
+ return true;
+}
+
+void RewriteMutator::Post(parser::OpenMPLoopConstruct &ompLoop) {
+ this->Pre(ompLoop);
+}
+
+bool RewriteMutator::Pre(parser::DoConstruct &doConstruct) {
+ if (context_.langOptions().OpenMPSimd) {
+ auto &innerBlock = std::get<parser::Block>(doConstruct.t);
+ OpenMPSimdOnly(innerBlock);
+ }
+ return true;
+}
+
+void RewriteMutator::Post(parser::DoConstruct &doConstruct) {
+ this->Pre(doConstruct);
+}
+
+bool RewriteMutator::Pre(parser::IfConstruct &ifConstruct) {
+ if (context_.langOptions().OpenMPSimd) {
+ auto &innerBlock = std::get<parser::Block>(ifConstruct.t);
+ OpenMPSimdOnly(innerBlock);
+ }
+ return true;
+}
+
+void RewriteMutator::Post(parser::IfConstruct &ifConstruct) {
+ this->Pre(ifConstruct);
+}
+
// Rewrite PRINT NML -> WRITE(*,NML=NML)
bool RewriteMutator::Pre(parser::ActionStmt &x) {
if (auto *print{std::get_if<common::Indirection<parser::PrintStmt>>(&x.u)};
diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp
index 2259cfc..a6b402c 100644
--- a/flang/lib/Semantics/symbol.cpp
+++ b/flang/lib/Semantics/symbol.cpp
@@ -611,7 +611,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) {
sep = ',';
}
},
- [](const HostAssocDetails &) {},
+ [&os](const HostAssocDetails &x) { os << " => " << x.symbol(); },
[&](const ProcBindingDetails &x) {
os << " => " << x.symbol().name();
DumpOptional(os, "passName", x.passName());
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index 913bf08..28829d3 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -705,7 +705,7 @@ SymbolVector FinalsForDerivedTypeInstantiation(const DerivedTypeSpec &spec) {
const Symbol *IsFinalizable(const Symbol &symbol,
std::set<const DerivedTypeSpec *> *inProgress, bool withImpureFinalizer) {
- if (IsPointer(symbol) || evaluate::IsAssumedRank(symbol)) {
+ if (IsPointer(symbol) || IsAssumedRank(symbol)) {
return nullptr;
}
if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()}) {
@@ -741,7 +741,7 @@ const Symbol *IsFinalizable(const DerivedTypeSpec &derived,
if (const SubprogramDetails *
subp{symbol->detailsIf<SubprogramDetails>()}) {
if (const auto &args{subp->dummyArgs()}; !args.empty() &&
- args.at(0) && !evaluate::IsAssumedRank(*args.at(0)) &&
+ args.at(0) && !IsAssumedRank(*args.at(0)) &&
args.at(0)->Rank() != *rank) {
continue; // not a finalizer for this rank
}
@@ -790,7 +790,7 @@ const Symbol *HasImpureFinal(const Symbol &original, std::optional<int> rank) {
if (symbol.has<ObjectEntityDetails>()) {
if (const DeclTypeSpec * symType{symbol.GetType()}) {
if (const DerivedTypeSpec * derived{symType->AsDerived()}) {
- if (evaluate::IsAssumedRank(symbol)) {
+ if (IsAssumedRank(symbol)) {
// finalizable assumed-rank not allowed (C839)
return nullptr;
} else {
@@ -1170,7 +1170,7 @@ bool IsAccessible(const Symbol &original, const Scope &scope) {
}
std::optional<parser::MessageFormattedText> CheckAccessibleSymbol(
- const Scope &scope, const Symbol &symbol) {
+ const Scope &scope, const Symbol &symbol, bool inStructureConstructor) {
if (IsAccessible(symbol, scope)) {
return std::nullopt;
} else if (FindModuleFileContaining(scope)) {
@@ -1179,10 +1179,20 @@ std::optional<parser::MessageFormattedText> CheckAccessibleSymbol(
// whose structure constructors reference private components.
return std::nullopt;
} else {
+ const Scope &module{DEREF(FindModuleContaining(symbol.owner()))};
+ // Subtlety: Sometimes we want to be able to convert a generated
+ // module file back into Fortran, perhaps to convert it into a
+ // hermetic module file. Don't emit a fatal error for things like
+ // "__builtin_c_ptr(__address=0)" that came from expansions of
+ // "cptr_null()"; specifically, just warn about structure constructor
+ // component names from intrinsic modules when in a module.
+ parser::MessageFixedText text{FindModuleContaining(scope) &&
+ module.parent().IsIntrinsicModules() &&
+ inStructureConstructor && symbol.owner().IsDerivedType()
+ ? "PRIVATE name '%s' is accessible only within module '%s'"_warn_en_US
+ : "PRIVATE name '%s' is accessible only within module '%s'"_err_en_US};
return parser::MessageFormattedText{
- "PRIVATE name '%s' is accessible only within module '%s'"_err_en_US,
- symbol.name(),
- DEREF(FindModuleContaining(symbol.owner())).GetName().value()};
+ std::move(text), symbol.name(), module.GetName().value()};
}
}
diff --git a/flang/lib/Semantics/unparse-with-symbols.cpp b/flang/lib/Semantics/unparse-with-symbols.cpp
index 3093e39..b199481 100644
--- a/flang/lib/Semantics/unparse-with-symbols.cpp
+++ b/flang/lib/Semantics/unparse-with-symbols.cpp
@@ -47,6 +47,11 @@ public:
return true;
}
void Post(const parser::OmpClause &) { currStmt_ = std::nullopt; }
+ bool Pre(const parser::OpenMPGroupprivate &dir) {
+ currStmt_ = dir.source;
+ return true;
+ }
+ void Post(const parser::OpenMPGroupprivate &) { currStmt_ = std::nullopt; }
bool Pre(const parser::OpenMPThreadprivate &dir) {
currStmt_ = dir.source;
return true;
@@ -70,20 +75,6 @@ public:
currStmt_ = std::nullopt;
}
- bool Pre(const parser::OmpCriticalDirective &x) {
- currStmt_ = x.source;
- return true;
- }
- void Post(const parser::OmpCriticalDirective &) { currStmt_ = std::nullopt; }
-
- bool Pre(const parser::OmpEndCriticalDirective &x) {
- currStmt_ = x.source;
- return true;
- }
- void Post(const parser::OmpEndCriticalDirective &) {
- currStmt_ = std::nullopt;
- }
-
// Directive arguments can be objects with symbols.
bool Pre(const parser::OmpBeginDirective &x) {
currStmt_ = x.source;
diff --git a/flang/lib/Support/Fortran-features.cpp b/flang/lib/Support/Fortran-features.cpp
index df51b3c..4a6fb8d 100644
--- a/flang/lib/Support/Fortran-features.cpp
+++ b/flang/lib/Support/Fortran-features.cpp
@@ -90,6 +90,7 @@ LanguageFeatureControl::LanguageFeatureControl() {
disable_.set(LanguageFeature::OldStyleParameter);
// Possibly an accidental "feature" of nvfortran.
disable_.set(LanguageFeature::AssumedRankPassedToNonAssumedRank);
+ disable_.set(LanguageFeature::Coarray);
// These warnings are enabled by default, but only because they used
// to be unconditional. TODO: prune this list
warnLanguage_.set(LanguageFeature::ExponentMatchingKindParam);
@@ -147,6 +148,7 @@ LanguageFeatureControl::LanguageFeatureControl() {
warnUsage_.set(UsageWarning::UseAssociationIntoSameNameSubprogram);
warnUsage_.set(UsageWarning::HostAssociatedIntentOutInSpecExpr);
warnUsage_.set(UsageWarning::NonVolatilePointerToVolatile);
+ warnUsage_.set(UsageWarning::RealConstantWidening);
// New warnings, on by default
warnLanguage_.set(LanguageFeature::SavedLocalInSpecExpr);
warnLanguage_.set(LanguageFeature::NullActualForAllocatable);
diff --git a/flang/lib/Support/Fortran.cpp b/flang/lib/Support/Fortran.cpp
index 8e286be..3a8ebbb 100644
--- a/flang/lib/Support/Fortran.cpp
+++ b/flang/lib/Support/Fortran.cpp
@@ -103,8 +103,8 @@ std::string AsFortran(IgnoreTKRSet tkr) {
/// dummy argument attribute while `y` represents the actual argument attribute.
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
- std::optional<std::string> *warning, bool allowUnifiedMatchingRule,
- bool isHostDeviceProcedure, const LanguageFeatureControl *features) {
+ bool allowUnifiedMatchingRule, bool isHostDeviceProcedure,
+ const LanguageFeatureControl *features) {
bool isCudaManaged{features
? features->IsEnabled(common::LanguageFeature::CudaManaged)
: false};
@@ -145,9 +145,6 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
*y == CUDADataAttr::Shared ||
*y == CUDADataAttr::Constant)) ||
(!y && (isCudaUnified || isCudaManaged))) {
- if (y && *y == CUDADataAttr::Shared && warning) {
- *warning = "SHARED attribute ignored"s;
- }
return true;
}
} else if (*x == CUDADataAttr::Managed) {
diff --git a/flang/lib/Utils/CMakeLists.txt b/flang/lib/Utils/CMakeLists.txt
new file mode 100644
index 0000000..2119b0e
--- /dev/null
+++ b/flang/lib/Utils/CMakeLists.txt
@@ -0,0 +1,20 @@
+#===-- lib/Utils/CMakeLists.txt --------------------------------------------===#
+#
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+#===------------------------------------------------------------------------===#
+
+add_flang_library(FortranUtils
+ OpenMP.cpp
+
+ DEPENDS
+ FIRDialect
+
+ LINK_LIBS
+ FIRDialect
+
+ MLIR_LIBS
+ MLIROpenMPDialect
+)
diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp
new file mode 100644
index 0000000..e1681e9
--- /dev/null
+++ b/flang/lib/Utils/OpenMP.cpp
@@ -0,0 +1,47 @@
+//===-- lib/Utils/OpenMP.cpp ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Utils/OpenMP.h"
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace Fortran::utils::openmp {
+mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder,
+ mlir::Location loc, mlir::Value baseAddr, mlir::Value varPtrPtr,
+ llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds,
+ llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
+ uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
+ mlir::Type retTy, bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
+ // Creates an omp::MapInfoOp at `loc` for `baseAddr` via `builder` and returns it.
+ if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { // `boxTy` is only used as the condition.
+ baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr); // Boxed input: map the fir::BoxAddrOp result instead.
+ retTy = baseAddr.getType(); // The caller-supplied retTy is overridden on this path.
+ }
+
+ mlir::TypeAttr varType = mlir::TypeAttr::get(
+ llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); // retTy is required to be an omp::PointerLikeType.
+
+ // For types with unknown extents such as <2x?xi32> we discard the incomplete
+ // type info and only retain the base type. The correct dimensions are later
+ // recovered through the bounds info.
+ if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue()))
+ if (seqType.hasDynamicExtents())
+ varType = mlir::TypeAttr::get(seqType.getEleTy());
+
+ mlir::omp::MapInfoOp op =
+ mlir::omp::MapInfoOp::create(builder, loc, retTy, baseAddr, varType,
+ builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), // mapType as a 64-bit integer attribute (isSigned = false).
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
+ varPtrPtr, members, membersIndex, bounds, mapperId, // Forwarded unchanged from the parameters.
+ builder.getStringAttr(name), builder.getBoolAttr(partialMap));
+ return op;
+}
+} // namespace Fortran::utils::openmp