Diffstat (limited to 'flang/lib')
126 files changed, 5718 insertions, 2774 deletions
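Most of the hunks below follow one recurring pattern: call sites that used to guard context.messages().Say(...) behind an explicit context.languageFeatures().ShouldWarn(...) check are rewritten to a single context.Warn(...) call that performs the check itself. The following minimal, standalone C++ sketch shows the shape of that consolidation; the stub classes, members, and signatures here are simplified stand-ins for illustration, not flang's real Fortran::evaluate API.

// Minimal standalone sketch of the ShouldWarn()/Say() -> Warn() consolidation
// applied throughout this diff. These are simplified stand-ins (assumptions),
// not the real flang classes.
#include <cstdio>
#include <string>

enum class UsageWarning { FoldingException, FoldingFailure };

// Stand-in for the per-compilation warning controls (assumption).
struct LanguageFeatures {
  bool ShouldWarn(UsageWarning) const { return warningsEnabled; }
  bool warningsEnabled{true};
};

// Stand-in for the diagnostic sink (assumption).
struct Messages {
  void Say(const std::string &text) {
    std::printf("warning: %s\n", text.c_str());
  }
};

class FoldingContext {
public:
  LanguageFeatures &languageFeatures() { return features_; }
  Messages &messages() { return messages_; }

  // Convenience wrapper: folds the ShouldWarn() guard into the emission,
  // so call sites shrink from an if-block to one call.
  template <typename... A>
  void Warn(UsageWarning warning, const char *format, A... args) {
    if (features_.ShouldWarn(warning)) {
      char buffer[256];
      std::snprintf(buffer, sizeof buffer, format, args...);
      messages_.Say(buffer);
    }
  }

private:
  LanguageFeatures features_;
  Messages messages_;
};

int main() {
  FoldingContext context;
  // Before (old call sites in this diff):
  if (context.languageFeatures().ShouldWarn(UsageWarning::FoldingException)) {
    context.messages().Say("INTEGER(4) addition overflowed");
  }
  // After (new call sites):
  context.Warn(UsageWarning::FoldingException,
               "INTEGER(%d) addition overflowed", 4);
}

Note that one hunk in fold-integer.cpp (the NUMERIC_STORAGE_SIZE warning) deliberately keeps a lower-level messages().Warn(...) call instead, per its comment, to bypass the module-file check that the convenience wrapper would otherwise apply.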
diff --git a/flang/lib/CMakeLists.txt b/flang/lib/CMakeLists.txt index 8b201d9..528e7b5 100644 --- a/flang/lib/CMakeLists.txt +++ b/flang/lib/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(Semantics) add_subdirectory(Support) add_subdirectory(Frontend) add_subdirectory(FrontendTool) +add_subdirectory(Utils) add_subdirectory(Optimizer) diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp index 8954773..37c62c9 100644 --- a/flang/lib/Evaluate/characteristics.cpp +++ b/flang/lib/Evaluate/characteristics.cpp @@ -400,7 +400,7 @@ bool DummyDataObject::IsCompatibleWith(const DummyDataObject &actual, } if (!attrs.test(Attr::Value) && !common::AreCompatibleCUDADataAttrs(cudaDataAttr, actual.cudaDataAttr, - ignoreTKR, warning, + ignoreTKR, /*allowUnifiedMatchingRule=*/false, /*=isHostDeviceProcedure*/ false)) { if (whyNot) { @@ -1816,7 +1816,7 @@ bool DistinguishUtils::Distinguishable( x.intent != common::Intent::In) { return true; } else if (!common::AreCompatibleCUDADataAttrs(x.cudaDataAttr, y.cudaDataAttr, - x.ignoreTKR | y.ignoreTKR, nullptr, + x.ignoreTKR | y.ignoreTKR, /*allowUnifiedMatchingRule=*/false, /*=isHostDeviceProcedure*/ false)) { return true; diff --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp index 3d7f01d..394a033 100644 --- a/flang/lib/Evaluate/check-expression.cpp +++ b/flang/lib/Evaluate/check-expression.cpp @@ -405,6 +405,88 @@ bool IsInitialProcedureTarget(const Expr<SomeType> &expr) { } } +class SuspiciousRealLiteralFinder + : public AnyTraverse<SuspiciousRealLiteralFinder> { +public: + using Base = AnyTraverse<SuspiciousRealLiteralFinder>; + SuspiciousRealLiteralFinder(int kind, FoldingContext &c) + : Base{*this}, kind_{kind}, context_{c} {} + using Base::operator(); + template <int KIND> + bool operator()(const Constant<Type<TypeCategory::Real, KIND>> &x) const { + if (kind_ > KIND && x.result().isFromInexactLiteralConversion()) { + context_.Warn(common::UsageWarning::RealConstantWidening, + "Default real literal in REAL(%d) context might need a kind suffix, as its rounded value %s is inexact"_warn_en_US, + kind_, x.AsFortran()); + return true; + } else { + return false; + } + } + template <int KIND> + bool operator()(const Constant<Type<TypeCategory::Complex, KIND>> &x) const { + if (kind_ > KIND && x.result().isFromInexactLiteralConversion()) { + context_.Warn(common::UsageWarning::RealConstantWidening, + "Default real literal in COMPLEX(%d) context might need a kind suffix, as its rounded value %s is inexact"_warn_en_US, + kind_, x.AsFortran()); + return true; + } else { + return false; + } + } + template <TypeCategory TOCAT, int TOKIND, TypeCategory FROMCAT> + bool operator()(const Convert<Type<TOCAT, TOKIND>, FROMCAT> &x) const { + if constexpr ((TOCAT == TypeCategory::Real || + TOCAT == TypeCategory::Complex) && + (FROMCAT == TypeCategory::Real || FROMCAT == TypeCategory::Complex)) { + auto fromType{x.left().GetType()}; + if (!fromType || fromType->kind() < TOKIND) { + return false; + } + } + return (*this)(x.left()); + } + +private: + int kind_; + FoldingContext &context_; +}; + +void CheckRealWidening(const Expr<SomeType> &expr, const DynamicType &toType, + FoldingContext &context) { + if (toType.category() == TypeCategory::Real || + toType.category() == TypeCategory::Complex) { + if (auto fromType{expr.GetType()}) { + if ((fromType->category() == TypeCategory::Real || + fromType->category() == TypeCategory::Complex) && + toType.kind() > fromType->kind()) { + 
SuspiciousRealLiteralFinder{toType.kind(), context}(expr); + } + } + } +} + +void CheckRealWidening(const Expr<SomeType> &expr, + const std::optional<DynamicType> &toType, FoldingContext &context) { + if (toType) { + CheckRealWidening(expr, *toType, context); + } +} + +class InexactLiteralConversionFlagClearer + : public AnyTraverse<InexactLiteralConversionFlagClearer> { +public: + using Base = AnyTraverse<InexactLiteralConversionFlagClearer>; + InexactLiteralConversionFlagClearer() : Base(*this) {} + using Base::operator(); + template <int KIND> + bool operator()(const Constant<Type<TypeCategory::Real, KIND>> &x) const { + auto &mut{const_cast<Type<TypeCategory::Real, KIND> &>(x.result())}; + mut.set_isFromInexactLiteralConversion(false); + return false; + } +}; + // Converts, folds, and then checks type, rank, and shape of an // initialization expression for a named constant, a non-pointer // variable static initialization, a component default initializer, @@ -416,16 +498,14 @@ std::optional<Expr<SomeType>> NonPointerInitializationExpr(const Symbol &symbol, if (auto symTS{ characteristics::TypeAndShape::Characterize(symbol, context)}) { auto xType{x.GetType()}; + CheckRealWidening(x, symTS->type(), context); auto converted{ConvertToType(symTS->type(), Expr<SomeType>{x})}; if (!converted && symbol.owner().context().IsEnabled( common::LanguageFeature::LogicalIntegerAssignment)) { converted = DataConstantConversionExtension(context, symTS->type(), x); - if (converted && - symbol.owner().context().ShouldWarn( - common::LanguageFeature::LogicalIntegerAssignment)) { - context.messages().Say( - common::LanguageFeature::LogicalIntegerAssignment, + if (converted) { + context.Warn(common::LanguageFeature::LogicalIntegerAssignment, "nonstandard usage: initialization of %s with %s"_port_en_US, symTS->type().AsFortran(), x.GetType().value().AsFortran()); } @@ -433,6 +513,7 @@ std::optional<Expr<SomeType>> NonPointerInitializationExpr(const Symbol &symbol, if (converted) { auto folded{Fold(context, std::move(*converted))}; if (IsActuallyConstant(folded)) { + InexactLiteralConversionFlagClearer{}(folded); int symRank{symTS->Rank()}; if (IsImpliedShape(symbol)) { if (folded.Rank() == symRank) { @@ -579,10 +660,8 @@ public: // host-associated dummy argument, and that doesn't seem like a // good idea. 
if (!inInquiry_ && hasHostAssociation && - ultimate.attrs().test(semantics::Attr::INTENT_OUT) && - context_.languageFeatures().ShouldWarn( - common::UsageWarning::HostAssociatedIntentOutInSpecExpr)) { - context_.messages().Say( + ultimate.attrs().test(semantics::Attr::INTENT_OUT)) { + context_.Warn(common::UsageWarning::HostAssociatedIntentOutInSpecExpr, "specification expression refers to host-associated INTENT(OUT) dummy argument '%s'"_port_en_US, ultimate.name()); } @@ -593,13 +672,9 @@ public: } else if (isInitialized && context_.languageFeatures().IsEnabled( common::LanguageFeature::SavedLocalInSpecExpr)) { - if (!scope_.IsModuleFile() && - context_.languageFeatures().ShouldWarn( - common::LanguageFeature::SavedLocalInSpecExpr)) { - context_.messages().Say(common::LanguageFeature::SavedLocalInSpecExpr, - "specification expression refers to local object '%s' (initialized and saved)"_port_en_US, - ultimate.name()); - } + context_.Warn(common::LanguageFeature::SavedLocalInSpecExpr, + "specification expression refers to local object '%s' (initialized and saved)"_port_en_US, + ultimate.name()); return std::nullopt; } else if (const auto *object{ ultimate.detailsIf<semantics::ObjectEntityDetails>()}) { @@ -917,8 +992,8 @@ public: } else { return Base::operator()(ultimate); // use expr } - } else if (semantics::IsPointer(ultimate) || - semantics::IsAssumedShape(ultimate) || IsAssumedRank(ultimate)) { + } else if (semantics::IsPointer(ultimate) || IsAssumedShape(ultimate) || + IsAssumedRank(ultimate)) { return std::nullopt; } else if (ultimate.has<semantics::ObjectEntityDetails>()) { return true; @@ -1198,9 +1273,21 @@ std::optional<bool> IsContiguous(const A &x, FoldingContext &context, } } +std::optional<bool> IsContiguous(const ActualArgument &actual, + FoldingContext &fc, bool namedConstantSectionsAreContiguous, + bool firstDimensionStride1) { + auto *expr{actual.UnwrapExpr()}; + return expr && + IsContiguous( + *expr, fc, namedConstantSectionsAreContiguous, firstDimensionStride1); +} + template std::optional<bool> IsContiguous(const Expr<SomeType> &, FoldingContext &, bool namedConstantSectionsAreContiguous, bool firstDimensionStride1); +template std::optional<bool> IsContiguous(const ActualArgument &, + FoldingContext &, bool namedConstantSectionsAreContiguous, + bool firstDimensionStride1); template std::optional<bool> IsContiguous(const ArrayRef &, FoldingContext &, bool namedConstantSectionsAreContiguous, bool firstDimensionStride1); template std::optional<bool> IsContiguous(const Substring &, FoldingContext &, @@ -1350,4 +1437,177 @@ std::optional<parser::Message> CheckStatementFunction( return StmtFunctionChecker{sf, context}(expr); } +// Helper class for checking differences between actual and dummy arguments +class CopyInOutExplicitInterface { +public: + explicit CopyInOutExplicitInterface(FoldingContext &fc, + const ActualArgument &actual, + const characteristics::DummyDataObject &dummyObj) + : fc_{fc}, actual_{actual}, dummyObj_{dummyObj} {} + + // Returns true, if actual and dummy have different contiguity requirements + bool HaveContiguityDifferences() const { + // Check actual contiguity, unless dummy doesn't care + bool dummyTreatAsArray{dummyObj_.ignoreTKR.test(common::IgnoreTKR::Rank)}; + bool actualTreatAsContiguous{ + dummyObj_.ignoreTKR.test(common::IgnoreTKR::Contiguous) || + IsSimplyContiguous(actual_, fc_)}; + bool dummyIsExplicitShape{dummyObj_.type.IsExplicitShape()}; + bool dummyIsAssumedSize{dummyObj_.type.attrs().test( + 
characteristics::TypeAndShape::Attr::AssumedSize)}; + bool dummyIsPolymorphic{dummyObj_.type.type().IsPolymorphic()}; + // type(*) with IGNORE_TKR(tkr) is often used to interface with C "void*". + // Since the other languages don't know about Fortran's discontiguity + // handling, such cases should require contiguity. + bool dummyIsVoidStar{dummyObj_.type.type().IsAssumedType() && + dummyObj_.ignoreTKR.test(common::IgnoreTKR::Type) && + dummyObj_.ignoreTKR.test(common::IgnoreTKR::Rank) && + dummyObj_.ignoreTKR.test(common::IgnoreTKR::Kind)}; + // Explicit shape and assumed size arrays must be contiguous + bool dummyNeedsContiguity{dummyIsExplicitShape || dummyIsAssumedSize || + (dummyTreatAsArray && !dummyIsPolymorphic) || dummyIsVoidStar || + dummyObj_.attrs.test( + characteristics::DummyDataObject::Attr::Contiguous)}; + return !actualTreatAsContiguous && dummyNeedsContiguity; + } + + // Returns true, if actual and dummy have polymorphic differences + bool HavePolymorphicDifferences() const { + bool dummyIsAssumedRank{dummyObj_.type.attrs().test( + characteristics::TypeAndShape::Attr::AssumedRank)}; + bool actualIsAssumedRank{semantics::IsAssumedRank(actual_)}; + bool dummyIsAssumedShape{dummyObj_.type.attrs().test( + characteristics::TypeAndShape::Attr::AssumedShape)}; + bool actualIsAssumedShape{semantics::IsAssumedShape(actual_)}; + if ((actualIsAssumedRank && dummyIsAssumedRank) || + (actualIsAssumedShape && dummyIsAssumedShape)) { + // Assumed-rank and assumed-shape arrays are represented by descriptors, + // so don't need to do polymorphic check. + } else if (!dummyObj_.ignoreTKR.test(common::IgnoreTKR::Type)) { + // flang supports limited cases of passing polymorphic to non-polimorphic. + // These cases require temporary of non-polymorphic type. (For example, + // the actual argument could be polymorphic array of child type, + // while the dummy argument could be non-polymorphic array of parent + // type.) + bool dummyIsPolymorphic{dummyObj_.type.type().IsPolymorphic()}; + auto actualType{ + characteristics::TypeAndShape::Characterize(actual_, fc_)}; + bool actualIsPolymorphic{ + actualType && actualType->type().IsPolymorphic()}; + if (actualIsPolymorphic && !dummyIsPolymorphic) { + return true; + } + } + return false; + } + + bool HaveArrayOrAssumedRankArgs() const { + bool dummyTreatAsArray{dummyObj_.ignoreTKR.test(common::IgnoreTKR::Rank)}; + return IsArrayOrAssumedRank(actual_) && + (IsArrayOrAssumedRank(dummyObj_) || dummyTreatAsArray); + } + + bool PassByValue() const { + return dummyObj_.attrs.test(characteristics::DummyDataObject::Attr::Value); + } + + bool HaveCoarrayDifferences() const { + return ExtractCoarrayRef(actual_) && dummyObj_.type.corank() == 0; + } + + bool HasIntentOut() const { return dummyObj_.intent == common::Intent::Out; } + + bool HasIntentIn() const { return dummyObj_.intent == common::Intent::In; } + + static bool IsArrayOrAssumedRank(const ActualArgument &actual) { + return semantics::IsAssumedRank(actual) || actual.Rank() > 0; + } + + static bool IsArrayOrAssumedRank( + const characteristics::DummyDataObject &dummy) { + return dummy.type.attrs().test( + characteristics::TypeAndShape::Attr::AssumedRank) || + dummy.type.Rank() > 0; + } + +private: + FoldingContext &fc_; + const ActualArgument &actual_; + const characteristics::DummyDataObject &dummyObj_; +}; + +// If forCopyOut is false, returns if a particular actual/dummy argument +// combination may need a temporary creation with copy-in operation. 
If +// forCopyOut is true, returns the same for copy-out operation. For +// procedures with explicit interface, it's expected that "dummy" is not null. +// For procedures with implicit interface dummy may be null. +// +// Note that these copy-in and copy-out checks are done from the caller's +// perspective, meaning that for copy-in the caller need to do the copy +// before calling the callee. Similarly, for copy-out the caller is expected +// to do the copy after the callee returns. +bool MayNeedCopy(const ActualArgument *actual, + const characteristics::DummyArgument *dummy, FoldingContext &fc, + bool forCopyOut) { + if (!actual) { + return false; + } + if (actual->isAlternateReturn()) { + return false; + } + const auto *dummyObj{dummy + ? std::get_if<characteristics::DummyDataObject>(&dummy->u) + : nullptr}; + const bool forCopyIn = !forCopyOut; + if (!evaluate::IsVariable(*actual)) { + // Actual argument expressions that aren’t variables are copy-in, but + // not copy-out. + return forCopyIn; + } + if (dummyObj) { // Explict interface + CopyInOutExplicitInterface check{fc, *actual, *dummyObj}; + if (forCopyOut && check.HasIntentIn()) { + // INTENT(IN) dummy args never need copy-out + return false; + } + if (forCopyIn && check.HasIntentOut()) { + // INTENT(OUT) dummy args never need copy-in + return false; + } + if (check.PassByValue()) { + // Pass by value, always copy-in, never copy-out + return forCopyIn; + } + if (check.HaveCoarrayDifferences()) { + return true; + } + // Note: contiguity and polymorphic checks deal with array or assumed rank + // arguments + if (!check.HaveArrayOrAssumedRankArgs()) { + return false; + } + if (check.HaveContiguityDifferences()) { + return true; + } + if (check.HavePolymorphicDifferences()) { + return true; + } + } else { // Implicit interface + if (ExtractCoarrayRef(*actual)) { + // Coindexed actual args may need copy-in and copy-out with implicit + // interface + return true; + } + if (!IsSimplyContiguous(*actual, fc)) { + // Copy-in: actual arguments that are variables are copy-in when + // non-contiguous. + // Copy-out: vector subscripts could refer to duplicate elements, can't + // copy out. 
+ return !(forCopyOut && HasVectorSubscript(*actual)); + } + } + // For everything else, no copy-in or copy-out + return false; +} + } // namespace Fortran::evaluate diff --git a/flang/lib/Evaluate/common.cpp b/flang/lib/Evaluate/common.cpp index 6a960d4..46c75a5 100644 --- a/flang/lib/Evaluate/common.cpp +++ b/flang/lib/Evaluate/common.cpp @@ -16,26 +16,22 @@ namespace Fortran::evaluate { void RealFlagWarnings( FoldingContext &context, const RealFlags &flags, const char *operation) { static constexpr auto warning{common::UsageWarning::FoldingException}; - if (context.languageFeatures().ShouldWarn(warning)) { - if (flags.test(RealFlag::Overflow)) { - context.messages().Say(warning, "overflow on %s"_warn_en_US, operation); - } - if (flags.test(RealFlag::DivideByZero)) { - if (std::strcmp(operation, "division") == 0) { - context.messages().Say(warning, "division by zero"_warn_en_US); - } else { - context.messages().Say( - warning, "division by zero on %s"_warn_en_US, operation); - } - } - if (flags.test(RealFlag::InvalidArgument)) { - context.messages().Say( - warning, "invalid argument on %s"_warn_en_US, operation); - } - if (flags.test(RealFlag::Underflow)) { - context.messages().Say(warning, "underflow on %s"_warn_en_US, operation); + if (flags.test(RealFlag::Overflow)) { + context.Warn(warning, "overflow on %s"_warn_en_US, operation); + } + if (flags.test(RealFlag::DivideByZero)) { + if (std::strcmp(operation, "division") == 0) { + context.Warn(warning, "division by zero"_warn_en_US); + } else { + context.Warn(warning, "division by zero on %s"_warn_en_US, operation); } } + if (flags.test(RealFlag::InvalidArgument)) { + context.Warn(warning, "invalid argument on %s"_warn_en_US, operation); + } + if (flags.test(RealFlag::Underflow)) { + context.Warn(warning, "underflow on %s"_warn_en_US, operation); + } } ConstantSubscript &FoldingContext::StartImpliedDo( diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp index 76ac497..a43742a 100644 --- a/flang/lib/Evaluate/fold-character.cpp +++ b/flang/lib/Evaluate/fold-character.cpp @@ -58,13 +58,10 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction( return FoldElementalIntrinsic<T, IntT>(context, std::move(funcRef), ScalarFunc<T, IntT>([&](const Scalar<IntT> &i) { if (i.IsNegative() || i.BGE(Scalar<IntT>{0}.IBSET(8 * KIND))) { - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingValueChecks)) { - context.messages().Say(common::UsageWarning::FoldingValueChecks, - "%s(I=%jd) is out of range for CHARACTER(KIND=%d)"_warn_en_US, - parser::ToUpperCaseLetters(name), - static_cast<std::intmax_t>(i.ToInt64()), KIND); - } + context.Warn(common::UsageWarning::FoldingValueChecks, + "%s(I=%jd) is out of range for CHARACTER(KIND=%d)"_warn_en_US, + parser::ToUpperCaseLetters(name), + static_cast<std::intmax_t>(i.ToInt64()), KIND); } return CharacterUtils<KIND>::CHAR(i.ToUInt64()); })); @@ -106,12 +103,9 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction( static_cast<std::intmax_t>(n)); } else if (static_cast<double>(n) * str.size() > (1 << 20)) { // sanity limit of 1MiB - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingLimit)) { - context.messages().Say(common::UsageWarning::FoldingLimit, - "Result of REPEAT() is too large to compute at compilation time (%g characters)"_port_en_US, - static_cast<double>(n) * str.size()); - } + context.Warn(common::UsageWarning::FoldingLimit, + "Result of REPEAT() is too large to compute at compilation time (%g 
characters)"_port_en_US, + static_cast<double>(n) * str.size()); } else { return Expr<T>{Constant<T>{CharacterUtils<KIND>::REPEAT(str, n)}}; } diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp index 3eb8e1f..84066ee 100644 --- a/flang/lib/Evaluate/fold-complex.cpp +++ b/flang/lib/Evaluate/fold-complex.cpp @@ -29,9 +29,8 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction( if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) { return FoldElementalIntrinsic<T, T>( context, std::move(funcRef), *callable); - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, + } else { + context.Warn(common::UsageWarning::FoldingFailure, "%s(complex(kind=%d)) cannot be folded on host"_warn_en_US, name, KIND); } @@ -83,12 +82,21 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation( if (auto array{ApplyElementwise(context, x)}) { return *array; } - using Result = Type<TypeCategory::Complex, KIND>; + using ComplexType = Type<TypeCategory::Complex, KIND>; if (auto folded{OperandsAreConstants(x)}) { - return Expr<Result>{ - Constant<Result>{Scalar<Result>{folded->first, folded->second}}}; + using RealType = typename ComplexType::Part; + Constant<ComplexType> result{ + Scalar<ComplexType>{folded->first, folded->second}}; + if (const auto *re{UnwrapConstantValue<RealType>(x.left())}; + re && re->result().isFromInexactLiteralConversion()) { + result.result().set_isFromInexactLiteralConversion(); + } else if (const auto *im{UnwrapConstantValue<RealType>(x.right())}; + im && im->result().isFromInexactLiteralConversion()) { + result.result().set_isFromInexactLiteralConversion(); + } + return Expr<ComplexType>{std::move(result)}; } - return Expr<Result>{std::move(x)}; + return Expr<ComplexType>{std::move(x)}; } #ifdef _MSC_VER // disable bogus warning about missing definitions diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index 52e954d..3fdf3a6 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -1321,8 +1321,8 @@ public: *charLength_, std::move(elements_), ConstantSubscripts{n}}}; } } else { - return Expr<T>{ - Constant<T>{std::move(elements_), ConstantSubscripts{n}}}; + return Expr<T>{Constant<T>{ + std::move(elements_), ConstantSubscripts{n}, resultInfo_}}; } } return Expr<T>{std::move(array)}; @@ -1343,6 +1343,11 @@ private: if (!knownCharLength_) { charLength_ = std::max(c->LEN(), charLength_.value_or(-1)); } + } else if constexpr (T::category == TypeCategory::Real || + T::category == TypeCategory::Complex) { + if (c->result().isFromInexactLiteralConversion()) { + resultInfo_.set_isFromInexactLiteralConversion(); + } } return true; } else { @@ -1395,6 +1400,7 @@ private: std::vector<Scalar<T>> elements_; std::optional<ConstantSubscript> charLength_; bool knownCharLength_{false}; + typename Constant<T>::Result resultInfo_; }; template <typename T> @@ -1779,7 +1785,7 @@ common::IfNoLvalue<std::optional<TO>, FROM> ConvertString(FROM &&s) { if (static_cast<std::uint64_t>(*iter) > 127) { return std::nullopt; } - str.push_back(*iter); + str.push_back(static_cast<typename TO::value_type>(*iter)); } return std::make_optional<TO>(std::move(str)); } @@ -1808,10 +1814,8 @@ Expr<TO> FoldOperation( if constexpr (TO::category == TypeCategory::Integer) { if constexpr (FromCat == TypeCategory::Integer) { auto converted{Scalar<TO>::ConvertSigned(*value)}; - if 
(converted.overflow && - msvcWorkaround.context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - ctx.messages().Say(common::UsageWarning::FoldingException, + if (converted.overflow) { + ctx.Warn(common::UsageWarning::FoldingException, "conversion of %s_%d to INTEGER(%d) overflowed; result is %s"_warn_en_US, value->SignedDecimal(), Operand::kind, TO::kind, converted.value.SignedDecimal()); @@ -1819,10 +1823,8 @@ Expr<TO> FoldOperation( return ScalarConstantToExpr(std::move(converted.value)); } else if constexpr (FromCat == TypeCategory::Unsigned) { auto converted{Scalar<TO>::ConvertUnsigned(*value)}; - if ((converted.overflow || converted.value.IsNegative()) && - msvcWorkaround.context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - ctx.messages().Say(common::UsageWarning::FoldingException, + if ((converted.overflow || converted.value.IsNegative())) { + ctx.Warn(common::UsageWarning::FoldingException, "conversion of %s_U%d to INTEGER(%d) overflowed; result is %s"_warn_en_US, value->UnsignedDecimal(), Operand::kind, TO::kind, converted.value.SignedDecimal()); @@ -1830,17 +1832,14 @@ Expr<TO> FoldOperation( return ScalarConstantToExpr(std::move(converted.value)); } else if constexpr (FromCat == TypeCategory::Real) { auto converted{value->template ToInteger<Scalar<TO>>()}; - if (msvcWorkaround.context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - if (converted.flags.test(RealFlag::InvalidArgument)) { - ctx.messages().Say(common::UsageWarning::FoldingException, - "REAL(%d) to INTEGER(%d) conversion: invalid argument"_warn_en_US, - Operand::kind, TO::kind); - } else if (converted.flags.test(RealFlag::Overflow)) { - ctx.messages().Say( - "REAL(%d) to INTEGER(%d) conversion overflowed"_warn_en_US, - Operand::kind, TO::kind); - } + if (converted.flags.test(RealFlag::InvalidArgument)) { + ctx.Warn(common::UsageWarning::FoldingException, + "REAL(%d) to INTEGER(%d) conversion: invalid argument"_warn_en_US, + Operand::kind, TO::kind); + } else if (converted.flags.test(RealFlag::Overflow)) { + ctx.Warn(common::UsageWarning::FoldingException, + "REAL(%d) to INTEGER(%d) conversion overflowed"_warn_en_US, + Operand::kind, TO::kind); } return ScalarConstantToExpr(std::move(converted.value)); } @@ -1960,10 +1959,8 @@ Expr<T> FoldOperation(FoldingContext &context, Negate<T> &&x) { } else if (auto value{GetScalarConstantValue<T>(operand)}) { if constexpr (T::category == TypeCategory::Integer) { auto negated{value->Negate()}; - if (negated.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (negated.overflow) { + context.Warn(common::UsageWarning::FoldingException, "INTEGER(%d) negation overflowed"_warn_en_US, T::kind); } return Expr<T>{Constant<T>{std::move(negated.value)}}; @@ -2004,10 +2001,8 @@ Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) { if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto sum{folded->first.AddSigned(folded->second)}; - if (sum.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (sum.overflow) { + context.Warn(common::UsageWarning::FoldingException, "INTEGER(%d) addition overflowed"_warn_en_US, T::kind); } return Expr<T>{Constant<T>{sum.value}}; @@ -2035,10 +2030,8 @@ Expr<T> FoldOperation(FoldingContext 
&context, Subtract<T> &&x) { if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto difference{folded->first.SubtractSigned(folded->second)}; - if (difference.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (difference.overflow) { + context.Warn(common::UsageWarning::FoldingException, "INTEGER(%d) subtraction overflowed"_warn_en_US, T::kind); } return Expr<T>{Constant<T>{difference.value}}; @@ -2066,10 +2059,8 @@ Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) { if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto product{folded->first.MultiplySigned(folded->second)}; - if (product.SignedMultiplicationOverflowed() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (product.SignedMultiplicationOverflowed()) { + context.Warn(common::UsageWarning::FoldingException, "INTEGER(%d) multiplication overflowed"_warn_en_US, T::kind); } return Expr<T>{Constant<T>{product.lower}}; @@ -2116,28 +2107,20 @@ Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) { if constexpr (T::category == TypeCategory::Integer) { auto quotAndRem{folded->first.DivideSigned(folded->second)}; if (quotAndRem.divisionByZero) { - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, - "INTEGER(%d) division by zero"_warn_en_US, T::kind); - } + context.Warn(common::UsageWarning::FoldingException, + "INTEGER(%d) division by zero"_warn_en_US, T::kind); return Expr<T>{std::move(x)}; } - if (quotAndRem.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (quotAndRem.overflow) { + context.Warn(common::UsageWarning::FoldingException, "INTEGER(%d) division overflowed"_warn_en_US, T::kind); } return Expr<T>{Constant<T>{quotAndRem.quotient}}; } else if constexpr (T::category == TypeCategory::Unsigned) { auto quotAndRem{folded->first.DivideUnsigned(folded->second)}; if (quotAndRem.divisionByZero) { - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, - "UNSIGNED(%d) division by zero"_warn_en_US, T::kind); - } + context.Warn(common::UsageWarning::FoldingException, + "UNSIGNED(%d) division by zero"_warn_en_US, T::kind); return Expr<T>{std::move(x)}; } return Expr<T>{Constant<T>{quotAndRem.quotient}}; @@ -2177,24 +2160,21 @@ Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) { if (auto folded{OperandsAreConstants(x)}) { if constexpr (T::category == TypeCategory::Integer) { auto power{folded->first.Power(folded->second)}; - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - if (power.divisionByZero) { - context.messages().Say(common::UsageWarning::FoldingException, - "INTEGER(%d) zero to negative power"_warn_en_US, T::kind); - } else if (power.overflow) { - context.messages().Say(common::UsageWarning::FoldingException, - "INTEGER(%d) power overflowed"_warn_en_US, T::kind); - } else if (power.zeroToZero) { - context.messages().Say(common::UsageWarning::FoldingException, - "INTEGER(%d) 0**0 is not defined"_warn_en_US, T::kind); - } + if 
(power.divisionByZero) { + context.Warn(common::UsageWarning::FoldingException, + "INTEGER(%d) zero to negative power"_warn_en_US, T::kind); + } else if (power.overflow) { + context.Warn(common::UsageWarning::FoldingException, + "INTEGER(%d) power overflowed"_warn_en_US, T::kind); + } else if (power.zeroToZero) { + context.Warn(common::UsageWarning::FoldingException, + "INTEGER(%d) 0**0 is not defined"_warn_en_US, T::kind); } return Expr<T>{Constant<T>{power.power}}; } else { if (folded->first.IsZero()) { if (folded->second.IsZero()) { - context.messages().Say(common::UsageWarning::FoldingException, + context.Warn(common::UsageWarning::FoldingException, "REAL/COMPLEX 0**0 is not defined"_warn_en_US); } else { return Expr<T>(Constant<T>{folded->first}); // 0. ** nonzero -> 0. @@ -2202,9 +2182,8 @@ Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) { } else if (auto callable{GetHostRuntimeWrapper<T, T, T>("pow")}) { return Expr<T>{ Constant<T>{(*callable)(context, folded->first, folded->second)}}; - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, + } else { + context.Warn(common::UsageWarning::FoldingFailure, "Power for %s cannot be folded on host"_warn_en_US, T{}.AsFortran()); } @@ -2291,10 +2270,8 @@ Expr<Type<TypeCategory::Real, KIND>> ToReal( CHECK(constant); Scalar<Result> real{constant->GetScalarValue().value()}; From converted{From::ConvertUnsigned(real.RawBits()).value}; - if (original != converted && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingValueChecks)) { // C1601 - context.messages().Say(common::UsageWarning::FoldingValueChecks, + if (original != converted) { // C1601 + context.Warn(common::UsageWarning::FoldingValueChecks, "Nonzero bits truncated from BOZ literal constant in REAL intrinsic"_warn_en_US); } } else if constexpr (IsNumericCategoryExpr<From>()) { diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp index 352dec4..3628497 100644 --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -38,13 +38,13 @@ static bool CheckDimArg(const std::optional<ActualArgument> &dimArg, const Expr<SomeType> &array, parser::ContextualMessages &messages, bool isLBound, std::optional<int> &dimVal) { dimVal.reset(); - if (int rank{array.Rank()}; rank > 0 || IsAssumedRank(array)) { + if (int rank{array.Rank()}; rank > 0 || semantics::IsAssumedRank(array)) { auto named{ExtractNamedEntity(array)}; if (auto dim64{ToInt64(dimArg)}) { if (*dim64 < 1) { messages.Say("DIM=%jd dimension must be positive"_err_en_US, *dim64); return false; - } else if (!IsAssumedRank(array) && *dim64 > rank) { + } else if (!semantics::IsAssumedRank(array) && *dim64 > rank) { messages.Say( "DIM=%jd dimension is out of range for rank-%d array"_err_en_US, *dim64, rank); @@ -56,7 +56,7 @@ static bool CheckDimArg(const std::optional<ActualArgument> &dimArg, "DIM=%jd dimension is out of range for rank-%d assumed-size array"_err_en_US, *dim64, rank); return false; - } else if (IsAssumedRank(array)) { + } else if (semantics::IsAssumedRank(array)) { if (*dim64 > common::maxRank) { messages.Say( "DIM=%jd dimension is too large for any array (maximum rank %d)"_err_en_US, @@ -189,7 +189,7 @@ Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context, return Expr<T>{std::move(funcRef)}; } } - if (IsAssumedRank(*array)) { + if (semantics::IsAssumedRank(*array)) { // Would like to return 1 if DIM=.. 
is present, but that would be // hiding a runtime error if the DIM= were too large (including // the case of an assumed-rank argument that's scalar). @@ -240,7 +240,7 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context, return Expr<T>{std::move(funcRef)}; } } - if (IsAssumedRank(*array)) { + if (semantics::IsAssumedRank(*array)) { } else if (int rank{array->Rank()}; rank > 0) { bool takeBoundsFromShape{true}; if (auto named{ExtractNamedEntity(*array)}) { @@ -350,10 +350,8 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) { CountAccumulator<T, maskKind> accumulator{arrayAndMask->array}; Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask, dim, Scalar<T>{}, accumulator)}; - if (accumulator.overflow() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (accumulator.overflow()) { + context.Warn(common::UsageWarning::FoldingException, "Result of intrinsic function COUNT overflows its result type"_warn_en_US); } return Expr<T>{std::move(result)}; @@ -965,10 +963,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( auto FromInt64{[&name, &context](std::int64_t n) { Scalar<T> result{n}; - if (result.ToInt64() != n && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (result.ToInt64() != n) { + context.Warn(common::UsageWarning::FoldingException, "Result of intrinsic function '%s' (%jd) overflows its result type"_warn_en_US, name, std::intmax_t{n}); } @@ -979,10 +975,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> { typename Scalar<T>::ValueWithOverflow j{i.ABS()}; - if (j.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (j.overflow) { + context.Warn(common::UsageWarning::FoldingException, "abs(integer(kind=%d)) folding overflowed"_warn_en_US, KIND); } return j.value; @@ -999,11 +993,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef), ScalarFunc<T, TR>([&](const Scalar<TR> &x) { auto y{x.template ToInteger<Scalar<T>>(mode)}; - if (y.flags.test(RealFlag::Overflow) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say( - common::UsageWarning::FoldingException, + if (y.flags.test(RealFlag::Overflow)) { + context.Warn(common::UsageWarning::FoldingException, "%s intrinsic folding overflow"_warn_en_US, name); } return y.value; @@ -1029,10 +1020,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T, T>( [&context](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { auto result{x.DIM(y)}; - if (result.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (result.overflow) { + context.Warn(common::UsageWarning::FoldingException, "DIM intrinsic folding overflow"_warn_en_US); } return result.value; @@ -1061,14 +1050,13 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( context.messages().Say( "Character in intrinsic function %s must have length one"_err_en_US, 
name); - } else if (len.value() > 1 && - context.languageFeatures().ShouldWarn( - common::UsageWarning::Portability)) { - // Do not die, this was not checked before - context.messages().Say(common::UsageWarning::Portability, - "Character in intrinsic function %s should have length one"_port_en_US, - name); } else { + // Do not die, this was not checked before + if (len.value() > 1) { + context.Warn(common::UsageWarning::Portability, + "Character in intrinsic function %s should have length one"_port_en_US, + name); + } return common::visit( [&funcRef, &context, &FromInt64](const auto &str) -> Expr<T> { using Char = typename std::decay_t<decltype(str)>::Result; @@ -1256,11 +1244,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( bool badPConst{false}; if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { *pExpr = Fold(context, std::move(*pExpr)); - if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && - pConst->IsZero() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; + pConst && pConst->IsZero()) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "MOD: P argument is zero"_warn_en_US); badPConst = true; } @@ -1270,17 +1256,12 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( [badPConst](FoldingContext &context, const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { auto quotRem{x.DivideSigned(y)}; - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - if (!badPConst && quotRem.divisionByZero) { - context.messages().Say( - common::UsageWarning::FoldingAvoidsRuntimeCrash, - "mod() by zero"_warn_en_US); - } else if (quotRem.overflow) { - context.messages().Say( - common::UsageWarning::FoldingAvoidsRuntimeCrash, - "mod() folding overflowed"_warn_en_US); - } + if (!badPConst && quotRem.divisionByZero) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, + "mod() by zero"_warn_en_US); + } else if (quotRem.overflow) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, + "mod() folding overflowed"_warn_en_US); } return quotRem.remainder; })); @@ -1288,11 +1269,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( bool badPConst{false}; if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { *pExpr = Fold(context, std::move(*pExpr)); - if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && - pConst->IsZero() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; + pConst && pConst->IsZero()) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "MODULO: P argument is zero"_warn_en_US); badPConst = true; } @@ -1302,10 +1281,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { auto result{x.MODULO(y)}; - if (!badPConst && result.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (!badPConst && result.overflow) { + context.Warn(common::UsageWarning::FoldingException, "modulo() folding overflowed"_warn_en_US); } return result.value; @@ -1405,10 +1382,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T, 
T>([&context](const Scalar<T> &j, const Scalar<T> &k) -> Scalar<T> { typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)}; - if (result.overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (result.overflow) { + context.Warn(common::UsageWarning::FoldingException, "sign(integer(kind=%d)) folding overflowed"_warn_en_US, KIND); } return result.value; @@ -1465,11 +1440,11 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( auto realBytes{ context.targetCharacteristics().GetByteSize(TypeCategory::Real, context.defaults().GetDefaultKind(TypeCategory::Real))}; - if (intBytes != realBytes && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingValueChecks)) { - context.messages().Say(common::UsageWarning::FoldingValueChecks, - *context.moduleFileName(), + if (intBytes != realBytes) { + // Using the low-level API to bypass the module file check in this case. + context.messages().Warn( + /*isInModuleFile=*/false, context.languageFeatures(), + common::UsageWarning::FoldingValueChecks, *context.moduleFileName(), "NUMERIC_STORAGE_SIZE from ISO_FORTRAN_ENV is not well-defined when default INTEGER and REAL are not consistent due to compiler options"_warn_en_US); } return Expr<T>{8 * std::min(intBytes, realBytes)}; @@ -1496,11 +1471,9 @@ Expr<Type<TypeCategory::Unsigned, KIND>> FoldIntrinsicFunction( bool badPConst{false}; if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { *pExpr = Fold(context, std::move(*pExpr)); - if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && - pConst->IsZero() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; + pConst && pConst->IsZero()) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "%s: P argument is zero"_warn_en_US, name); badPConst = true; } @@ -1510,13 +1483,9 @@ Expr<Type<TypeCategory::Unsigned, KIND>> FoldIntrinsicFunction( [badPConst, &name](FoldingContext &context, const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { auto quotRem{x.DivideUnsigned(y)}; - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - if (!badPConst && quotRem.divisionByZero) { - context.messages().Say( - common::UsageWarning::FoldingAvoidsRuntimeCrash, - "%s() by zero"_warn_en_US, name); - } + if (!badPConst && quotRem.divisionByZero) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, + "%s() by zero"_warn_en_US, name); } return quotRem.remainder; })); diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp index 6950caf..c64f79e 100644 --- a/flang/lib/Evaluate/fold-logical.cpp +++ b/flang/lib/Evaluate/fold-logical.cpp @@ -530,13 +530,11 @@ static Expr<Type<TypeCategory::Logical, KIND>> RewriteOutOfRange( if (args.size() >= 3) { // Bounds depend on round= value if (auto *round{UnwrapExpr<Expr<SomeType>>(args[2])}) { - if (const Symbol * whole{UnwrapWholeSymbolDataRef(*round)}; - whole && semantics::IsOptional(whole->GetUltimate()) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::OptionalMustBePresent)) { + if (const Symbol *whole{UnwrapWholeSymbolDataRef(*round)}; + whole && semantics::IsOptional(whole->GetUltimate())) { if (auto source{args[2]->sourceLocation()}) { - context.messages().Say( - common::UsageWarning::OptionalMustBePresent, 
*source, + context.Warn(common::UsageWarning::OptionalMustBePresent, + *source, "ROUND= argument to OUT_OF_RANGE() is an optional dummy argument that must be present at execution"_warn_en_US); } } diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h index 9237d6e..ae9221f 100644 --- a/flang/lib/Evaluate/fold-matmul.h +++ b/flang/lib/Evaluate/fold-matmul.h @@ -92,10 +92,8 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) { elements.push_back(sum); } } - if (overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (overflow) { + context.Warn(common::UsageWarning::FoldingException, "MATMUL of %s data overflowed during computation"_warn_en_US, T::AsFortran()); } diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 6fb5249..225e340 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -35,9 +35,8 @@ static Expr<T> FoldTransformationalBessel( } return Expr<T>{Constant<T>{ std::move(results), ConstantSubscripts{std::max(n2 - n1 + 1, 0)}}}; - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, + } else { + context.Warn(common::UsageWarning::FoldingFailure, "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US, name, T::kind); } @@ -131,10 +130,8 @@ static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context, context.targetCharacteristics().roundingMode()}; Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask, dim, identity, norm2Accumulator)}; - if (norm2Accumulator.overflow() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (norm2Accumulator.overflow()) { + context.Warn(common::UsageWarning::FoldingException, "NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND); } return Expr<T>{std::move(result)}; @@ -165,9 +162,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) { return FoldElementalIntrinsic<T, T>( context, std::move(funcRef), *callable); - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, + } else { + context.Warn(common::UsageWarning::FoldingFailure, "%s(real(kind=%d)) cannot be folded on host"_warn_en_US, name, KIND); } } else if (name == "amax0" || name == "amin0" || name == "amin1" || @@ -179,9 +175,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( if (auto callable{GetHostRuntimeWrapper<T, T, T>(localName)}) { return FoldElementalIntrinsic<T, T, T>( context, std::move(funcRef), *callable); - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, + } else { + context.Warn(common::UsageWarning::FoldingFailure, "%s(real(kind=%d), real(kind%d)) cannot be folded on host"_warn_en_US, name, KIND, KIND); } @@ -191,9 +186,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( if (auto callable{GetHostRuntimeWrapper<T, Int4, T>(name)}) { return FoldElementalIntrinsic<T, Int4, T>( context, std::move(funcRef), *callable); - } else if (context.languageFeatures().ShouldWarn( - 
common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, + } else { + context.Warn(common::UsageWarning::FoldingFailure, "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US, name, KIND); } @@ -210,10 +204,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( ScalarFunc<T, ComplexT>([&name, &context]( const Scalar<ComplexT> &z) -> Scalar<T> { ValueWithRealFlags<Scalar<T>> y{z.ABS()}; - if (y.flags.test(RealFlag::Overflow) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (y.flags.test(RealFlag::Overflow)) { + context.Warn(common::UsageWarning::FoldingException, "complex ABS intrinsic folding overflow"_warn_en_US, name); } return y.value; @@ -234,10 +226,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T>( [&name, &context, mode](const Scalar<T> &x) -> Scalar<T> { ValueWithRealFlags<Scalar<T>> y{x.ToWholeNumber(mode)}; - if (y.flags.test(RealFlag::Overflow) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (y.flags.test(RealFlag::Overflow)) { + context.Warn(common::UsageWarning::FoldingException, "%s intrinsic folding overflow"_warn_en_US, name); } return y.value; @@ -247,10 +237,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T, T>([&context](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { ValueWithRealFlags<Scalar<T>> result{x.DIM(y)}; - if (result.flags.test(RealFlag::Overflow) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (result.flags.test(RealFlag::Overflow)) { + context.Warn(common::UsageWarning::FoldingException, "DIM intrinsic folding overflow"_warn_en_US); } return result.value; @@ -282,10 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T, T>( [&](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { ValueWithRealFlags<Scalar<T>> result{x.HYPOT(y)}; - if (result.flags.test(RealFlag::Overflow) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (result.flags.test(RealFlag::Overflow)) { + context.Warn(common::UsageWarning::FoldingException, "HYPOT intrinsic folding overflow"_warn_en_US); } return result.value; @@ -307,11 +293,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( bool badPConst{false}; if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { *pExpr = Fold(context, std::move(*pExpr)); - if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && - pConst->IsZero() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; + pConst && pConst->IsZero()) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "MOD: P argument is zero"_warn_en_US); badPConst = true; } @@ -320,11 +304,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T, T>([&context, badPConst](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { auto result{x.MOD(y)}; - if (!badPConst && result.flags.test(RealFlag::DivideByZero) && - context.languageFeatures().ShouldWarn( - 
common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say( - common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (!badPConst && result.flags.test(RealFlag::DivideByZero)) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "second argument to MOD must not be zero"_warn_en_US); } return result.value; @@ -334,11 +315,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( bool badPConst{false}; if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { *pExpr = Fold(context, std::move(*pExpr)); - if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && - pConst->IsZero() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; + pConst && pConst->IsZero()) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "MODULO: P argument is zero"_warn_en_US); badPConst = true; } @@ -347,11 +326,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( ScalarFunc<T, T, T>([&context, badPConst](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { auto result{x.MODULO(y)}; - if (!badPConst && result.flags.test(RealFlag::DivideByZero) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingAvoidsRuntimeCrash)) { - context.messages().Say( - common::UsageWarning::FoldingAvoidsRuntimeCrash, + if (!badPConst && result.flags.test(RealFlag::DivideByZero)) { + context.Warn(common::UsageWarning::FoldingAvoidsRuntimeCrash, "second argument to MODULO must not be zero"_warn_en_US); } return result.value; @@ -363,11 +339,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( [&](const auto &sVal) { using TS = ResultType<decltype(sVal)>; bool badSConst{false}; - if (auto sConst{GetScalarConstantValue<TS>(sVal)}; sConst && - (sConst->IsZero() || sConst->IsNotANumber()) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingValueChecks)) { - context.messages().Say(common::UsageWarning::FoldingValueChecks, + if (auto sConst{GetScalarConstantValue<TS>(sVal)}; + sConst && (sConst->IsZero() || sConst->IsNotANumber())) { + context.Warn(common::UsageWarning::FoldingValueChecks, "NEAREST: S argument is %s"_warn_en_US, sConst->IsZero() ? "zero" : "NaN"); badSConst = true; @@ -375,22 +349,15 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( return FoldElementalIntrinsic<T, T, TS>(context, std::move(funcRef), ScalarFunc<T, T, TS>([&](const Scalar<T> &x, const Scalar<TS> &s) -> Scalar<T> { - if (!badSConst && (s.IsZero() || s.IsNotANumber()) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingValueChecks)) { - context.messages().Say( - common::UsageWarning::FoldingValueChecks, + if (!badSConst && (s.IsZero() || s.IsNotANumber())) { + context.Warn(common::UsageWarning::FoldingValueChecks, "NEAREST: S argument is %s"_warn_en_US, s.IsZero() ? 
"zero" : "NaN"); } auto result{x.NEAREST(!s.IsNegative())}; - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - if (result.flags.test(RealFlag::InvalidArgument)) { - context.messages().Say( - common::UsageWarning::FoldingException, - "NEAREST intrinsic folding: bad argument"_warn_en_US); - } + if (result.flags.test(RealFlag::InvalidArgument)) { + context.Warn(common::UsageWarning::FoldingException, + "NEAREST intrinsic folding: bad argument"_warn_en_US); } return result.value; })); @@ -427,11 +394,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( template #endif SCALE<Scalar<TBY>>(y)}; - if (result.flags.test(RealFlag::Overflow) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say( - common::UsageWarning::FoldingException, + if (result.flags.test(RealFlag::Overflow)) { + context.Warn(common::UsageWarning::FoldingException, "SCALE/IEEE_SCALB intrinsic folding overflow"_warn_en_US); } return result.value; @@ -481,12 +445,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( auto yBig{Scalar<LargestReal>::Convert(y).value}; switch (xBig.Compare(yBig)) { case Relation::Unordered: - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingValueChecks)) { - context.messages().Say( - common::UsageWarning::FoldingValueChecks, - "IEEE_NEXT_AFTER intrinsic folding: arguments are unordered"_warn_en_US); - } + context.Warn(common::UsageWarning::FoldingValueChecks, + "IEEE_NEXT_AFTER intrinsic folding: arguments are unordered"_warn_en_US); return x.NotANumber(); case Relation::Equal: break; @@ -507,12 +467,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), ScalarFunc<T, T>([&](const Scalar<T> &x) -> Scalar<T> { auto result{x.NEAREST(upward)}; - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - if (result.flags.test(RealFlag::InvalidArgument)) { - context.messages().Say(common::UsageWarning::FoldingException, - "%s intrinsic folding: argument is NaN"_warn_en_US, iName); - } + if (result.flags.test(RealFlag::InvalidArgument)) { + context.Warn(common::UsageWarning::FoldingException, + "%s intrinsic folding: argument is NaN"_warn_en_US, iName); } return result.value; })); diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index b6f2d21..fe89739 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -112,10 +112,8 @@ static Expr<T> FoldDotProduct( } } } - if (overflow && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (overflow) { + context.Warn(common::UsageWarning::FoldingException, "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US, T::AsFortran()); } @@ -334,10 +332,8 @@ static Expr<T> FoldProduct( ProductAccumulator accumulator{arrayAndMask->array}; auto result{Expr<T>{DoReduction<T>( arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; - if (accumulator.overflow() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (accumulator.overflow()) { + context.Warn(common::UsageWarning::FoldingException, "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran()); } return result; @@ -406,10 +402,8 @@ static Expr<T> 
FoldSum(FoldingContext &context, FunctionRef<T> &&ref) { arrayAndMask->array, context.targetCharacteristics().roundingMode()}; auto result{Expr<T>{DoReduction<T>( arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}}; - if (accumulator.overflow() && - context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingException)) { - context.messages().Say(common::UsageWarning::FoldingException, + if (accumulator.overflow()) { + context.Warn(common::UsageWarning::FoldingException, "SUM() of %s data overflowed"_warn_en_US, T::AsFortran()); } return result; diff --git a/flang/lib/Evaluate/fold.cpp b/flang/lib/Evaluate/fold.cpp index 71ead1b..1fbbbba 100644 --- a/flang/lib/Evaluate/fold.cpp +++ b/flang/lib/Evaluate/fold.cpp @@ -290,11 +290,8 @@ std::optional<Expr<SomeType>> FoldTransfer( } else if (source && moldType) { if (const auto *boz{std::get_if<BOZLiteralConstant>(&source->u)}) { // TRANSFER(BOZ, MOLD=integer or real) extension - if (context.languageFeatures().ShouldWarn( - common::LanguageFeature::TransferBOZ)) { - context.messages().Say(common::LanguageFeature::TransferBOZ, - "TRANSFER(BOZ literal) is not standard"_port_en_US); - } + context.Warn(common::LanguageFeature::TransferBOZ, + "TRANSFER(BOZ literal) is not standard"_port_en_US); return Fold(context, ConvertToType(*moldType, Expr<SomeType>{*boz})); } } diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp index 121afc6..ec5dc0b 100644 --- a/flang/lib/Evaluate/formatting.cpp +++ b/flang/lib/Evaluate/formatting.cpp @@ -98,6 +98,14 @@ llvm::raw_ostream &ConstantBase<RESULT, VALUE>::AsFortran( return o; } +template <typename RESULT, typename VALUE> +std::string ConstantBase<RESULT, VALUE>::AsFortran() const { + std::string result; + llvm::raw_string_ostream sstream(result); + AsFortran(sstream); + return result; +} + template <int KIND> llvm::raw_ostream &Constant<Type<TypeCategory::Character, KIND>>::AsFortran( llvm::raw_ostream &o) const { @@ -126,6 +134,14 @@ llvm::raw_ostream &Constant<Type<TypeCategory::Character, KIND>>::AsFortran( return o; } +template <int KIND> +std::string Constant<Type<TypeCategory::Character, KIND>>::AsFortran() const { + std::string result; + llvm::raw_string_ostream sstream(result); + AsFortran(sstream); + return result; +} + llvm::raw_ostream &EmitVar(llvm::raw_ostream &o, const Symbol &symbol, std::optional<parser::CharBlock> name = std::nullopt) { const auto &renamings{symbol.owner().context().moduleFileOutputRenamings()}; diff --git a/flang/lib/Evaluate/host.cpp b/flang/lib/Evaluate/host.cpp index 187bb2f..25409ac 100644 --- a/flang/lib/Evaluate/host.cpp +++ b/flang/lib/Evaluate/host.cpp @@ -100,13 +100,8 @@ void HostFloatingPointEnvironment::SetUpHostFloatingPointEnvironment( break; case common::RoundingMode::TiesAwayFromZero: fesetround(FE_TONEAREST); - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::FoldingFailure)) { - context.messages().Say(common::UsageWarning::FoldingFailure, - "TiesAwayFromZero rounding mode is not available when folding " - "constants" - " with host runtime; using TiesToEven instead"_warn_en_US); - } + context.Warn(common::UsageWarning::FoldingFailure, + "TiesAwayFromZero rounding mode is not available when folding constants with host runtime; using TiesToEven instead"_warn_en_US); break; } flags_.clear(); diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp index c37a7f90..abe53c3 100644 --- a/flang/lib/Evaluate/intrinsics.cpp +++ b/flang/lib/Evaluate/intrinsics.cpp @@ 
-666,7 +666,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {ArgFlag::canBeMoldNull, ArgFlag::onlyConstantInquiry}}}, DefaultInt, Rank::elemental, IntrinsicClass::inquiryFunction}, {"lbound", - {{"array", AnyData, Rank::anyOrAssumedRank}, RequiredDIM, + {{"array", AnyData, Rank::arrayOrAssumedRank}, RequiredDIM, SizeDefaultKIND}, KINDInt, Rank::scalar, IntrinsicClass::inquiryFunction}, {"lbound", {{"array", AnyData, Rank::arrayOrAssumedRank}, SizeDefaultKIND}, @@ -921,6 +921,10 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"back", AnyLogical, Rank::elemental, Optionality::optional}, DefaultingKIND}, KINDInt}, + {"secnds", + {{"refTime", TypePattern{RealType, KindCode::exactKind, 4}, + Rank::scalar}}, + TypePattern{RealType, KindCode::exactKind, 4}, Rank::scalar}, {"second", {}, DefaultReal, Rank::scalar}, {"selected_char_kind", {{"name", DefaultChar, Rank::scalar}}, DefaultInt, Rank::scalar, IntrinsicClass::transformationalFunction}, @@ -1034,7 +1038,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"trim", {{"string", SameCharNoLen, Rank::scalar}}, SameCharNoLen, Rank::scalar, IntrinsicClass::transformationalFunction}, {"ubound", - {{"array", AnyData, Rank::anyOrAssumedRank}, RequiredDIM, + {{"array", AnyData, Rank::arrayOrAssumedRank}, RequiredDIM, SizeDefaultKIND}, KINDInt, Rank::scalar, IntrinsicClass::inquiryFunction}, {"ubound", {{"array", AnyData, Rank::arrayOrAssumedRank}, SizeDefaultKIND}, @@ -2256,7 +2260,7 @@ std::optional<SpecificCall> IntrinsicInterface::Match( for (std::size_t j{0}; j < dummies; ++j) { const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]}; if (const ActualArgument *arg{actualForDummy[j]}) { - bool isAssumedRank{IsAssumedRank(*arg)}; + bool isAssumedRank{semantics::IsAssumedRank(*arg)}; if (isAssumedRank && d.rank != Rank::anyOrAssumedRank && d.rank != Rank::arrayOrAssumedRank) { messages.Say(arg->sourceLocation(), @@ -2617,15 +2621,12 @@ std::optional<SpecificCall> IntrinsicInterface::Match( if (const Symbol *whole{ UnwrapWholeSymbolOrComponentDataRef(actualForDummy[*dimArg])}) { if (IsOptional(*whole) || IsAllocatableOrObjectPointer(whole)) { - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::OptionalMustBePresent)) { - if (rank == Rank::scalarIfDim || arrayRank.value_or(-1) == 1) { - messages.Say(common::UsageWarning::OptionalMustBePresent, - "The actual argument for DIM= is optional, pointer, or allocatable, and it is assumed to be present and equal to 1 at execution time"_warn_en_US); - } else { - messages.Say(common::UsageWarning::OptionalMustBePresent, - "The actual argument for DIM= is optional, pointer, or allocatable, and may not be absent during execution; parenthesize to silence this warning"_warn_en_US); - } + if (rank == Rank::scalarIfDim || arrayRank.value_or(-1) == 1) { + context.Warn(common::UsageWarning::OptionalMustBePresent, + "The actual argument for DIM= is optional, pointer, or allocatable, and it is assumed to be present and equal to 1 at execution time"_warn_en_US); + } else { + context.Warn(common::UsageWarning::OptionalMustBePresent, + "The actual argument for DIM= is optional, pointer, or allocatable, and may not be absent during execution; parenthesize to silence this warning"_warn_en_US); } } } @@ -3002,7 +3003,7 @@ SpecificCall IntrinsicProcTable::Implementation::HandleNull( mold = nullptr; } if (mold) { - if (IsAssumedRank(*arguments[0])) { + if (semantics::IsAssumedRank(*arguments[0])) { 
context.messages().Say(arguments[0]->sourceLocation(), "MOLD= argument to NULL() must not be assumed-rank"_err_en_US); } @@ -3109,16 +3110,12 @@ IntrinsicProcTable::Implementation::HandleC_F_Pointer( context.messages().Say(at, "FPTR= argument to C_F_POINTER() may not have a deferred type parameter"_err_en_US); } else if (type->category() == TypeCategory::Derived) { - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::Interoperability) && - type->IsUnlimitedPolymorphic()) { - context.messages().Say(common::UsageWarning::Interoperability, at, + if (type->IsUnlimitedPolymorphic()) { + context.Warn(common::UsageWarning::Interoperability, at, "FPTR= argument to C_F_POINTER() should not be unlimited polymorphic"_warn_en_US); } else if (!type->GetDerivedTypeSpec().typeSymbol().attrs().test( - semantics::Attr::BIND_C) && - context.languageFeatures().ShouldWarn( - common::UsageWarning::Portability)) { - context.messages().Say(common::UsageWarning::Portability, at, + semantics::Attr::BIND_C)) { + context.Warn(common::UsageWarning::Portability, at, "FPTR= argument to C_F_POINTER() should not have a derived type that is not BIND(C)"_port_en_US); } } else if (!IsInteroperableIntrinsicType( @@ -3126,16 +3123,11 @@ IntrinsicProcTable::Implementation::HandleC_F_Pointer( .value_or(true)) { if (type->category() == TypeCategory::Character && type->kind() == 1) { - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::CharacterInteroperability)) { - context.messages().Say( - common::UsageWarning::CharacterInteroperability, at, - "FPTR= argument to C_F_POINTER() should not have the non-interoperable character length %s"_warn_en_US, - type->AsFortran()); - } - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::Interoperability)) { - context.messages().Say(common::UsageWarning::Interoperability, at, + context.Warn(common::UsageWarning::CharacterInteroperability, at, + "FPTR= argument to C_F_POINTER() should not have the non-interoperable character length %s"_warn_en_US, + type->AsFortran()); + } else { + context.Warn(common::UsageWarning::Interoperability, at, "FPTR= argument to C_F_POINTER() should not have the non-interoperable intrinsic type or kind %s"_warn_en_US, type->AsFortran()); } @@ -3274,16 +3266,11 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::HandleC_Loc( if (typeAndShape->type().category() == TypeCategory::Character && typeAndShape->type().kind() == 1) { // Default character kind, but length is not known to be 1 - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::CharacterInteroperability)) { - context.messages().Say( - common::UsageWarning::CharacterInteroperability, - arguments[0]->sourceLocation(), - "C_LOC() argument has non-interoperable character length"_warn_en_US); - } - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::Interoperability)) { - context.messages().Say(common::UsageWarning::Interoperability, + context.Warn(common::UsageWarning::CharacterInteroperability, + arguments[0]->sourceLocation(), + "C_LOC() argument has non-interoperable character length"_warn_en_US); + } else { + context.Warn(common::UsageWarning::Interoperability, arguments[0]->sourceLocation(), "C_LOC() argument has non-interoperable intrinsic type or kind"_warn_en_US); } @@ -3341,16 +3328,11 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::HandleC_Devloc( if (typeAndShape->type().category() == TypeCategory::Character && typeAndShape->type().kind() == 1) { // Default character kind, but 
length is not known to be 1 - if (context.languageFeatures().ShouldWarn( - common::UsageWarning::CharacterInteroperability)) { - context.messages().Say( - common::UsageWarning::CharacterInteroperability, - arguments[0]->sourceLocation(), - "C_DEVLOC() argument has non-interoperable character length"_warn_en_US); - } - } else if (context.languageFeatures().ShouldWarn( - common::UsageWarning::Interoperability)) { - context.messages().Say(common::UsageWarning::Interoperability, + context.Warn(common::UsageWarning::CharacterInteroperability, + arguments[0]->sourceLocation(), + "C_DEVLOC() argument has non-interoperable character length"_warn_en_US); + } else { + context.Warn(common::UsageWarning::Interoperability, arguments[0]->sourceLocation(), "C_DEVLOC() argument has non-interoperable intrinsic type or kind"_warn_en_US); } @@ -3673,15 +3655,10 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe( genericType.category() == TypeCategory::Real) && (newType.category() == TypeCategory::Integer || newType.category() == TypeCategory::Real))) { - if (context.languageFeatures().ShouldWarn( - common::LanguageFeature:: - UseGenericIntrinsicWhenSpecificDoesntMatch)) { - context.messages().Say( - common::LanguageFeature:: - UseGenericIntrinsicWhenSpecificDoesntMatch, - "Argument types do not match specific intrinsic '%s' requirements; using '%s' generic instead and converting the result to %s if needed"_port_en_US, - call.name, genericName, newType.AsFortran()); - } + context.Warn(common::LanguageFeature:: + UseGenericIntrinsicWhenSpecificDoesntMatch, + "Argument types do not match specific intrinsic '%s' requirements; using '%s' generic instead and converting the result to %s if needed"_port_en_US, + call.name, genericName, newType.AsFortran()); specificCall->specificIntrinsic.name = call.name; specificCall->specificIntrinsic.characteristics.value() .functionResult.value() diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp index 2c0f283..6e6b9f3 100644 --- a/flang/lib/Evaluate/real.cpp +++ b/flang/lib/Evaluate/real.cpp @@ -750,6 +750,14 @@ llvm::raw_ostream &Real<W, P>::AsFortran( return o; } +template <typename W, int P> +std::string Real<W, P>::AsFortran(int kind, bool minimal) const { + std::string result; + llvm::raw_string_ostream sstream(result); + AsFortran(sstream, kind, minimal); + return result; +} + // 16.9.180 template <typename W, int P> Real<W, P> Real<W, P>::RRSPACING() const { if (IsNotANumber()) { diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp index 776866d..07bff10 100644 --- a/flang/lib/Evaluate/shape.cpp +++ b/flang/lib/Evaluate/shape.cpp @@ -623,7 +623,7 @@ MaybeExtentExpr GetRawUpperBound( } else if (semantics::IsAssumedSizeArray(symbol) && dimension + 1 == symbol.Rank()) { return std::nullopt; - } else { + } else if (IsSafelyCopyable(base, /*admitPureCall=*/true)) { return ComputeUpperBound( GetRawLowerBound(base, dimension), GetExtent(base, dimension)); } @@ -678,9 +678,11 @@ static MaybeExtentExpr GetUBOUND(FoldingContext *context, } else if (semantics::IsAssumedSizeArray(symbol) && dimension + 1 == symbol.Rank()) { return std::nullopt; // UBOUND() folding replaces with -1 - } else if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) { - return ComputeUpperBound( - std::move(*lb), GetExtent(base, dimension, invariantOnly)); + } else if (IsSafelyCopyable(base, /*admitPureCall=*/true)) { + if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) { + return ComputeUpperBound( + std::move(*lb), GetExtent(base, 
dimension, invariantOnly)); + } } } } else if (const auto *assoc{ @@ -947,7 +949,7 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result { intrinsic->name == "ubound") { // For LBOUND/UBOUND, these are the array-valued cases (no DIM=) if (!call.arguments().empty() && call.arguments().front()) { - if (IsAssumedRank(*call.arguments().front())) { + if (semantics::IsAssumedRank(*call.arguments().front())) { return Shape{MaybeExtentExpr{}}; } else { return Shape{ diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index 9c059b0..1f3cbbf 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -495,7 +495,7 @@ Expr<SomeComplex> PromoteMixedComplexReal( // N.B. When a "typeless" BOZ literal constant appears as one (not both!) of // the operands to a dyadic operation where one is permitted, it assumes the // type and kind of the other operand. -template <template <typename> class OPR, bool CAN_BE_UNSIGNED> +template <template <typename> class OPR> std::optional<Expr<SomeType>> NumericOperation( parser::ContextualMessages &messages, Expr<SomeType> &&x, Expr<SomeType> &&y, int defaultRealKind) { @@ -510,13 +510,8 @@ std::optional<Expr<SomeType>> NumericOperation( std::move(rx), std::move(ry))); }, [&](Expr<SomeUnsigned> &&ix, Expr<SomeUnsigned> &&iy) { - if constexpr (CAN_BE_UNSIGNED) { - return Package(PromoteAndCombine<OPR, TypeCategory::Unsigned>( - std::move(ix), std::move(iy))); - } else { - messages.Say("Operands must not be UNSIGNED"_err_en_US); - return NoExpr(); - } + return Package(PromoteAndCombine<OPR, TypeCategory::Unsigned>( + std::move(ix), std::move(iy))); }, // Mixed REAL/INTEGER operations [](Expr<SomeReal> &&rx, Expr<SomeInteger> &&iy) { @@ -575,34 +570,31 @@ std::optional<Expr<SomeType>> NumericOperation( }, // Operations with one typeless operand [&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) { - return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages, + return NumericOperation<OPR>(messages, AsGenericExpr(ConvertTo(iy, std::move(bx))), std::move(y), defaultRealKind); }, [&](BOZLiteralConstant &&bx, Expr<SomeUnsigned> &&iy) { - return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages, + return NumericOperation<OPR>(messages, AsGenericExpr(ConvertTo(iy, std::move(bx))), std::move(y), defaultRealKind); }, [&](BOZLiteralConstant &&bx, Expr<SomeReal> &&ry) { - return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages, + return NumericOperation<OPR>(messages, AsGenericExpr(ConvertTo(ry, std::move(bx))), std::move(y), defaultRealKind); }, [&](Expr<SomeInteger> &&ix, BOZLiteralConstant &&by) { - return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages, - std::move(x), AsGenericExpr(ConvertTo(ix, std::move(by))), - defaultRealKind); + return NumericOperation<OPR>(messages, std::move(x), + AsGenericExpr(ConvertTo(ix, std::move(by))), defaultRealKind); }, [&](Expr<SomeUnsigned> &&ix, BOZLiteralConstant &&by) { - return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages, - std::move(x), AsGenericExpr(ConvertTo(ix, std::move(by))), - defaultRealKind); + return NumericOperation<OPR>(messages, std::move(x), + AsGenericExpr(ConvertTo(ix, std::move(by))), defaultRealKind); }, [&](Expr<SomeReal> &&rx, BOZLiteralConstant &&by) { - return NumericOperation<OPR, CAN_BE_UNSIGNED>(messages, - std::move(x), AsGenericExpr(ConvertTo(rx, std::move(by))), - defaultRealKind); + return NumericOperation<OPR>(messages, std::move(x), + AsGenericExpr(ConvertTo(rx, std::move(by))), defaultRealKind); }, // Error cases [&](Expr<SomeUnsigned> &&, auto &&) 
{ @@ -621,7 +613,7 @@ std::optional<Expr<SomeType>> NumericOperation( std::move(x.u), std::move(y.u)); } -template std::optional<Expr<SomeType>> NumericOperation<Power, false>( +template std::optional<Expr<SomeType>> NumericOperation<Power>( parser::ContextualMessages &, Expr<SomeType> &&, Expr<SomeType> &&, int defaultRealKind); template std::optional<Expr<SomeType>> NumericOperation<Multiply>( @@ -890,29 +882,6 @@ std::optional<Expr<SomeType>> ConvertToType( } } -bool IsAssumedRank(const Symbol &original) { - if (const auto *assoc{original.detailsIf<semantics::AssocEntityDetails>()}) { - if (assoc->rank()) { - return false; // in RANK(n) or RANK(*) - } else if (assoc->IsAssumedRank()) { - return true; // RANK DEFAULT - } - } - const Symbol &symbol{semantics::ResolveAssociations(original)}; - const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()}; - return object && object->IsAssumedRank(); -} - -bool IsAssumedRank(const ActualArgument &arg) { - if (const auto *expr{arg.UnwrapExpr()}) { - return IsAssumedRank(*expr); - } else { - const Symbol *assumedTypeDummy{arg.GetAssumedTypeDummy()}; - CHECK(assumedTypeDummy); - return IsAssumedRank(*assumedTypeDummy); - } -} - int GetCorank(const ActualArgument &arg) { const auto *expr{arg.UnwrapExpr()}; return GetCorank(*expr); @@ -1129,7 +1098,7 @@ struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper, CollectCudaSymbolsHelper() : Base{*this} {} using Base::operator(); semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const { - return {symbol}; + return {symbol.GetUltimate()}; } // Overload some of the operator() to filter out the symbols that are not // of interest for CUDA data transfer logic. @@ -1203,6 +1172,15 @@ bool HasVectorSubscript(const Expr<SomeType> &expr) { return HasVectorSubscriptHelper{}(expr); } +bool HasVectorSubscript(const ActualArgument &actual) { + auto expr{actual.UnwrapExpr()}; + return expr && HasVectorSubscript(*expr); +} + +bool IsArraySection(const Expr<SomeType> &expr) { + return expr.Rank() > 0 && IsVariable(expr) && !UnwrapWholeSymbolDataRef(expr); +} + // HasConstant() struct HasConstantHelper : public AnyTraverse<HasConstantHelper, bool, /*TraverseAssocEntityDetails=*/false> { @@ -2312,9 +2290,22 @@ bool IsDummy(const Symbol &symbol) { ResolveAssociations(symbol).details()); } +bool IsAssumedRank(const Symbol &original) { + if (const auto *assoc{original.detailsIf<semantics::AssocEntityDetails>()}) { + if (assoc->rank()) { + return false; // in RANK(n) or RANK(*) + } else if (assoc->IsAssumedRank()) { + return true; // RANK DEFAULT + } + } + const Symbol &symbol{semantics::ResolveAssociations(original)}; + const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()}; + return object && object->IsAssumedRank(); +} + bool IsAssumedShape(const Symbol &symbol) { const Symbol &ultimate{ResolveAssociations(symbol)}; - const auto *object{ultimate.detailsIf<ObjectEntityDetails>()}; + const auto *object{ultimate.detailsIf<semantics::ObjectEntityDetails>()}; return object && object->IsAssumedShape() && !semantics::IsAllocatableOrObjectPointer(&ultimate); } diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp index d1bff03..b9b34d4 100644 --- a/flang/lib/Evaluate/variable.cpp +++ b/flang/lib/Evaluate/variable.cpp @@ -212,21 +212,17 @@ std::optional<Expr<SomeCharacter>> Substring::Fold(FoldingContext &context) { } if (!result) { // error cases if (*lbi < 1) { - if (context.languageFeatures().ShouldWarn(common::UsageWarning::Bounds)) { - 
context.messages().Say(common::UsageWarning::Bounds, - "Lower bound (%jd) on substring is less than one"_warn_en_US, - static_cast<std::intmax_t>(*lbi)); - } + context.Warn(common::UsageWarning::Bounds, + "Lower bound (%jd) on substring is less than one"_warn_en_US, + static_cast<std::intmax_t>(*lbi)); *lbi = 1; lower_ = AsExpr(Constant<SubscriptInteger>{1}); } if (length && *ubi > *length) { - if (context.languageFeatures().ShouldWarn(common::UsageWarning::Bounds)) { - context.messages().Say(common::UsageWarning::Bounds, - "Upper bound (%jd) on substring is greater than character length (%jd)"_warn_en_US, - static_cast<std::intmax_t>(*ubi), - static_cast<std::intmax_t>(*length)); - } + context.Warn(common::UsageWarning::Bounds, + "Upper bound (%jd) on substring is greater than character length (%jd)"_warn_en_US, + static_cast<std::intmax_t>(*ubi), + static_cast<std::intmax_t>(*length)); *ubi = *length; upper_ = AsExpr(Constant<SubscriptInteger>{*ubi}); } diff --git a/flang/lib/Frontend/CompilerInstance.cpp b/flang/lib/Frontend/CompilerInstance.cpp index cd8ddda..d97b4b8 100644 --- a/flang/lib/Frontend/CompilerInstance.cpp +++ b/flang/lib/Frontend/CompilerInstance.cpp @@ -253,18 +253,15 @@ getExplicitAndImplicitAMDGPUTargetFeatures(clang::DiagnosticsEngine &diags, const TargetOptions &targetOpts, const llvm::Triple triple) { llvm::StringRef cpu = targetOpts.cpu; - llvm::StringMap<bool> implicitFeaturesMap; - // Get the set of implicit target features - llvm::AMDGPU::fillAMDGPUFeatureMap(cpu, triple, implicitFeaturesMap); + llvm::StringMap<bool> FeaturesMap; // Add target features specified by the user for (auto &userFeature : targetOpts.featuresAsWritten) { std::string userKeyString = userFeature.substr(1); - implicitFeaturesMap[userKeyString] = (userFeature[0] == '+'); + FeaturesMap[userKeyString] = (userFeature[0] == '+'); } - auto HasError = - llvm::AMDGPU::insertWaveSizeFeature(cpu, triple, implicitFeaturesMap); + auto HasError = llvm::AMDGPU::fillAMDGPUFeatureMap(cpu, triple, FeaturesMap); if (HasError.first) { unsigned diagID = diags.getCustomDiagID(clang::DiagnosticsEngine::Error, "Unsupported feature ID: %0"); @@ -273,9 +270,9 @@ getExplicitAndImplicitAMDGPUTargetFeatures(clang::DiagnosticsEngine &diags, } llvm::SmallVector<std::string> featuresVec; - for (auto &implicitFeatureItem : implicitFeaturesMap) { - featuresVec.push_back((llvm::Twine(implicitFeatureItem.second ? "+" : "-") + - implicitFeatureItem.first().str()) + for (auto &FeatureItem : FeaturesMap) { + featuresVec.push_back((llvm::Twine(FeatureItem.second ? 
"+" : "-") + + FeatureItem.first().str()) .str()); } llvm::sort(featuresVec); diff --git a/flang/lib/Frontend/CompilerInvocation.cpp b/flang/lib/Frontend/CompilerInvocation.cpp index 111c5aa4..6295a58 100644 --- a/flang/lib/Frontend/CompilerInvocation.cpp +++ b/flang/lib/Frontend/CompilerInvocation.cpp @@ -22,6 +22,7 @@ #include "flang/Tools/TargetSetup.h" #include "flang/Version.inc" #include "clang/Basic/DiagnosticDriver.h" +#include "clang/Basic/DiagnosticFrontend.h" #include "clang/Basic/DiagnosticOptions.h" #include "clang/Driver/CommonArgs.h" #include "clang/Driver/Driver.h" @@ -1152,6 +1153,17 @@ static bool parseDialectArgs(CompilerInvocation &res, llvm::opt::ArgList &args, diags.Report(diagID); } } + // -fcoarray + if (args.hasArg(clang::driver::options::OPT_fcoarray)) { + res.getFrontendOpts().features.Enable( + Fortran::common::LanguageFeature::Coarray); + const unsigned diagID = + diags.getCustomDiagID(clang::DiagnosticsEngine::Warning, + "Support for multi image Fortran features is " + "still experimental and in development."); + diags.Report(diagID); + } + return diags.getNumErrors() == numErrorsBefore; } @@ -1162,13 +1174,21 @@ static bool parseOpenMPArgs(CompilerInvocation &res, llvm::opt::ArgList &args, clang::DiagnosticsEngine &diags) { llvm::opt::Arg *arg = args.getLastArg(clang::driver::options::OPT_fopenmp, clang::driver::options::OPT_fno_openmp); - if (!arg || arg->getOption().matches(clang::driver::options::OPT_fno_openmp)) - return true; + if (!arg || + arg->getOption().matches(clang::driver::options::OPT_fno_openmp)) { + bool isSimdSpecified = args.hasFlag( + clang::driver::options::OPT_fopenmp_simd, + clang::driver::options::OPT_fno_openmp_simd, /*Default=*/false); + if (!isSimdSpecified) + return true; + res.getLangOpts().OpenMPSimd = 1; + } unsigned numErrorsBefore = diags.getNumErrors(); llvm::Triple t(res.getTargetOpts().triple); constexpr unsigned newestFullySupported = 31; + constexpr unsigned latestFinalized = 60; // By default OpenMP is set to the most recent fully supported version res.getLangOpts().OpenMPVersion = newestFullySupported; res.getFrontendOpts().features.Enable( @@ -1191,12 +1211,26 @@ static bool parseOpenMPArgs(CompilerInvocation &res, llvm::opt::ArgList &args, diags.Report(diagID) << value << arg->getAsString(args) << versions.str(); }; + auto reportFutureVersion = [&](llvm::StringRef value) { + const unsigned diagID = diags.getCustomDiagID( + clang::DiagnosticsEngine::Warning, + "The specification for OpenMP version %0 is still under development; " + "the syntax and semantics of new features may be subject to change"); + std::string buffer; + llvm::raw_string_ostream versions(buffer); + llvm::interleaveComma(ompVersions, versions); + + diags.Report(diagID) << value; + }; + llvm::StringRef value = arg->getValue(); if (!value.getAsInteger(/*radix=*/10, version)) { if (llvm::is_contained(ompVersions, version)) { res.getLangOpts().OpenMPVersion = version; - if (version > newestFullySupported) + if (version > latestFinalized) + reportFutureVersion(value); + else if (version > newestFullySupported) diags.Report(clang::diag::warn_openmp_incomplete) << version; } else if (llvm::is_contained(oldVersions, version)) { const unsigned diagID = @@ -1225,7 +1259,7 @@ static bool parseOpenMPArgs(CompilerInvocation &res, llvm::opt::ArgList &args, clang::driver::options::OPT_fopenmp_host_ir_file_path)) { res.getLangOpts().OMPHostIRFile = arg->getValue(); if (!llvm::sys::fs::exists(res.getLangOpts().OMPHostIRFile)) - 
diags.Report(clang::diag::err_drv_omp_host_ir_file_not_found) + diags.Report(clang::diag::err_omp_host_ir_file_not_found) << res.getLangOpts().OMPHostIRFile; } @@ -1696,6 +1730,20 @@ void CompilerInvocation::setDefaultPredefinitions() { fortranOptions.predefinitions.emplace_back("__flang_patchlevel__", FLANG_VERSION_PATCHLEVEL_STRING); + // Add predefinitions based on the relocation model + if (unsigned PICLevel = getCodeGenOpts().PICLevel) { + fortranOptions.predefinitions.emplace_back("__PIC__", + std::to_string(PICLevel)); + fortranOptions.predefinitions.emplace_back("__pic__", + std::to_string(PICLevel)); + if (getCodeGenOpts().IsPIE) { + fortranOptions.predefinitions.emplace_back("__PIE__", + std::to_string(PICLevel)); + fortranOptions.predefinitions.emplace_back("__pie__", + std::to_string(PICLevel)); + } + } + // Add predefinitions based on extensions enabled if (frontendOptions.features.IsEnabled( Fortran::common::LanguageFeature::OpenACC)) { @@ -1707,6 +1755,11 @@ void CompilerInvocation::setDefaultPredefinitions() { fortranOptions.predefinitions); } + if (frontendOptions.features.IsEnabled( + Fortran::common::LanguageFeature::CUDA)) { + fortranOptions.predefinitions.emplace_back("_CUDA", "1"); + } + llvm::Triple targetTriple{llvm::Triple(this->targetOpts.triple)}; if (targetTriple.isOSLinux()) { fortranOptions.predefinitions.emplace_back("__linux__", "1"); diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 5c66ecf..3bef6b1 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -298,6 +298,7 @@ bool CodeGenAction::beginSourceFileAction() { bool isOpenMPEnabled = ci.getInvocation().getFrontendOpts().features.IsEnabled( Fortran::common::LanguageFeature::OpenMP); + bool isOpenMPSimd = ci.getInvocation().getLangOpts().OpenMPSimd; fir::OpenMPFIRPassPipelineOpts opts; @@ -329,12 +330,13 @@ bool CodeGenAction::beginSourceFileAction() { if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>( mlirModule->getOperation())) opts.isTargetDevice = offloadMod.getIsTargetDevice(); + } - // WARNING: This pipeline must be run immediately after the lowering to - // ensure that the FIR is correct with respect to OpenMP operations/ - // attributes. + // WARNING: This pipeline must be run immediately after the lowering to + // ensure that the FIR is correct with respect to OpenMP operations/ + // attributes. 
+ if (isOpenMPEnabled || isOpenMPSimd) fir::createOpenMPFIRPassPipeline(pm, opts); - } pm.enableVerifier(/*verifyPasses=*/true); pm.addPass(std::make_unique<Fortran::lower::VerifierPass>()); @@ -617,12 +619,14 @@ void CodeGenAction::lowerHLFIRToFIR() { pm.addPass(std::make_unique<Fortran::lower::VerifierPass>()); pm.enableVerifier(/*verifyPasses=*/true); + fir::EnableOpenMP enableOpenMP = fir::EnableOpenMP::None; + if (ci.getInvocation().getFrontendOpts().features.IsEnabled( + Fortran::common::LanguageFeature::OpenMP)) + enableOpenMP = fir::EnableOpenMP::Full; + if (ci.getInvocation().getLangOpts().OpenMPSimd) + enableOpenMP = fir::EnableOpenMP::Simd; // Create the pass pipeline - fir::createHLFIRToFIRPassPipeline( - pm, - ci.getInvocation().getFrontendOpts().features.IsEnabled( - Fortran::common::LanguageFeature::OpenMP), - level); + fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP, level); (void)mlir::applyPassManagerCLOptions(pm); mlir::TimingScope timingScopeMLIRPasses = timingScopeRoot.nest( @@ -748,6 +752,9 @@ void CodeGenAction::generateLLVMIR() { Fortran::common::LanguageFeature::OpenMP)) config.EnableOpenMP = true; + if (ci.getInvocation().getLangOpts().OpenMPSimd) + config.EnableOpenMPSimd = true; + if (ci.getInvocation().getLoweringOpts().getIntegerWrapAround()) config.NSWOnLoopVarInc = false; diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp index 219f920..444b5b6 100644 --- a/flang/lib/Lower/Allocatable.cpp +++ b/flang/lib/Lower/Allocatable.cpp @@ -13,9 +13,9 @@ #include "flang/Lower/Allocatable.h" #include "flang/Evaluate/tools.h" #include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/CUDA.h" #include "flang/Lower/ConvertType.h" #include "flang/Lower/ConvertVariable.h" -#include "flang/Lower/Cuda.h" #include "flang/Lower/IterationSpace.h" #include "flang/Lower/Mangler.h" #include "flang/Lower/OpenACC.h" @@ -445,10 +445,14 @@ private: /*mustBeHeap=*/true); } - void postAllocationAction(const Allocation &alloc) { + void postAllocationAction(const Allocation &alloc, + const fir::MutableBoxValue &box) { if (alloc.getSymbol().test(Fortran::semantics::Symbol::Flag::AccDeclare)) Fortran::lower::attachDeclarePostAllocAction(converter, builder, alloc.getSymbol()); + if (Fortran::semantics::HasCUDAComponent(alloc.getSymbol())) + Fortran::lower::initializeDeviceComponentAllocator( + converter, alloc.getSymbol(), box); } void setPinnedToFalse() { @@ -481,11 +485,21 @@ private: // Pointers must use PointerAllocate so that their deallocations // can be validated. genInlinedAllocation(alloc, box); - postAllocationAction(alloc); + postAllocationAction(alloc, box); setPinnedToFalse(); return; } + // Preserve characters' dynamic length. + if (lenParams.empty() && box.isCharacter() && + !box.hasNonDeferredLenParams()) { + auto charTy = mlir::dyn_cast<fir::CharacterType>(box.getEleTy()); + if (charTy && charTy.hasDynamicLen()) { + fir::ExtendedValue exv{box}; + lenParams.push_back(fir::factory::readCharLen(builder, loc, exv)); + } + } + // Generate a sequence of runtime calls. 
errorManager.genStatCheck(builder, loc); genAllocateObjectInit(box, allocatorIdx); @@ -504,7 +518,7 @@ private: genCudaAllocate(builder, loc, box, errorManager, alloc.getSymbol()); } fir::factory::syncMutableBoxFromIRBox(builder, loc, box); - postAllocationAction(alloc); + postAllocationAction(alloc, box); errorManager.assignStat(builder, loc, stat); } @@ -647,7 +661,7 @@ private: setPinnedToFalse(); } fir::factory::syncMutableBoxFromIRBox(builder, loc, box); - postAllocationAction(alloc); + postAllocationAction(alloc, box); errorManager.assignStat(builder, loc, stat); } diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 6b7efe6..c003a5b 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -13,6 +13,7 @@ #include "flang/Lower/Bridge.h" #include "flang/Lower/Allocatable.h" +#include "flang/Lower/CUDA.h" #include "flang/Lower/CallInterface.h" #include "flang/Lower/Coarray.h" #include "flang/Lower/ConvertCall.h" @@ -20,7 +21,6 @@ #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertType.h" #include "flang/Lower/ConvertVariable.h" -#include "flang/Lower/Cuda.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/HostAssociations.h" #include "flang/Lower/IO.h" @@ -475,7 +475,9 @@ public: fir::runtime::genMain(*builder, toLocation(), bridge.getEnvironmentDefaults(), getFoldingContext().languageFeatures().IsEnabled( - Fortran::common::LanguageFeature::CUDA)); + Fortran::common::LanguageFeature::CUDA), + getFoldingContext().languageFeatures().IsEnabled( + Fortran::common::LanguageFeature::Coarray)); }); finalizeOpenMPLowering(globalOmpRequiresSymbol); @@ -1400,21 +1402,23 @@ private: mlir::Value genLoopVariableAddress(mlir::Location loc, const Fortran::semantics::Symbol &sym, bool isUnordered) { - if (isUnordered || sym.has<Fortran::semantics::HostAssocDetails>() || - sym.has<Fortran::semantics::UseDetails>()) { - if (!shallowLookupSymbol(sym) && - !GetSymbolDSA(sym).test( - Fortran::semantics::Symbol::Flag::OmpShared)) { - // Do concurrent loop variables are not mapped yet since they are local - // to the Do concurrent scope (same for OpenMP loops). - mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint(); - builder->setInsertionPointToStart(builder->getAllocaBlock()); - mlir::Type tempTy = genType(sym); - mlir::Value temp = - builder->createTemporaryAlloc(loc, tempTy, toStringRef(sym.name())); - bindIfNewSymbol(sym, temp); - builder->restoreInsertionPoint(insPt); - } + if (!shallowLookupSymbol(sym) && + (isUnordered || + GetSymbolDSA(sym).test(Fortran::semantics::Symbol::Flag::OmpPrivate) || + GetSymbolDSA(sym).test( + Fortran::semantics::Symbol::Flag::OmpFirstPrivate) || + GetSymbolDSA(sym).test( + Fortran::semantics::Symbol::Flag::OmpLastPrivate) || + GetSymbolDSA(sym).test(Fortran::semantics::Symbol::Flag::OmpLinear))) { + // Do concurrent loop variables are not mapped yet since they are + // local to the Do concurrent scope (same for OpenMP loops). 
+ mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint(); + builder->setInsertionPointToStart(builder->getAllocaBlock()); + mlir::Type tempTy = genType(sym); + mlir::Value temp = + builder->createTemporaryAlloc(loc, tempTy, toStringRef(sym.name())); + bindIfNewSymbol(sym, temp); + builder->restoreInsertionPoint(insPt); } auto entry = lookupSymbol(sym); (void)entry; @@ -2060,10 +2064,10 @@ private: // TODO Promote to using `enableDelayedPrivatization` (which is enabled by // default unlike the staging flag) once the implementation of this is more // complete. - bool useDelayedPriv = - enableDelayedPrivatizationStaging && doConcurrentLoopOp; + bool useDelayedPriv = enableDelayedPrivatization && doConcurrentLoopOp; llvm::SetVector<const Fortran::semantics::Symbol *> allPrivatizedSymbols; - llvm::SmallSet<const Fortran::semantics::Symbol *, 16> mightHaveReadHostSym; + llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 16> + mightHaveReadHostSym; for (const Fortran::semantics::Symbol *symToPrivatize : info.localSymList) { if (useDelayedPriv) { @@ -2122,6 +2126,9 @@ private: } } + if (!doConcurrentLoopOp) + return; + llvm::SmallVector<bool> reduceVarByRef; llvm::SmallVector<mlir::Attribute> reductionDeclSymbols; llvm::SmallVector<mlir::Attribute> nestReduceAttrs; @@ -4824,7 +4831,9 @@ private: void genCUDADataTransfer(fir::FirOpBuilder &builder, mlir::Location loc, const Fortran::evaluate::Assignment &assign, - hlfir::Entity &lhs, hlfir::Entity &rhs) { + hlfir::Entity &lhs, hlfir::Entity &rhs, + bool isWholeAllocatableAssignment, + bool keepLhsLengthInAllocatableAssignment) { bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs); bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs); @@ -4889,6 +4898,28 @@ private: // host = device if (!lhsIsDevice && rhsIsDevice) { + if (Fortran::lower::isTransferWithConversion(rhs)) { + mlir::OpBuilder::InsertionGuard insertionGuard(builder); + auto elementalOp = + mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()); + assert(elementalOp && "expect elemental op"); + auto designateOp = + *elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin(); + builder.setInsertionPoint(elementalOp); + // Create a temp to transfer the rhs before applying the conversion. 
+ hlfir::Entity entity{designateOp.getMemref()}; + auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, entity); + auto transferKindAttr = cuf::DataTransferKindAttr::get( + builder.getContext(), cuf::DataTransferKind::DeviceHost); + cuf::DataTransferOp::create(builder, loc, designateOp.getMemref(), temp, + /*shape=*/mlir::Value{}, transferKindAttr); + designateOp.getMemrefMutable().assign(temp); + builder.setInsertionPointAfter(elementalOp); + hlfir::AssignOp::create(builder, loc, elementalOp, lhs, + isWholeAllocatableAssignment, + keepLhsLengthInAllocatableAssignment); + return; + } auto transferKindAttr = cuf::DataTransferKindAttr::get( builder.getContext(), cuf::DataTransferKind::DeviceHost); cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape, @@ -4898,7 +4929,6 @@ private: // device = device if (lhsIsDevice && rhsIsDevice) { - assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal"); auto transferKindAttr = cuf::DataTransferKindAttr::get( builder.getContext(), cuf::DataTransferKind::DeviceDevice); cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape, @@ -5037,7 +5067,9 @@ private: hlfir::Entity rhs = evaluateRhs(localStmtCtx); hlfir::Entity lhs = evaluateLhs(localStmtCtx); if (isCUDATransfer && !hasCUDAImplicitTransfer) - genCUDADataTransfer(builder, loc, assign, lhs, rhs); + genCUDADataTransfer(builder, loc, assign, lhs, rhs, + isWholeAllocatableAssignment, + keepLhsLengthInAllocatableAssignment); else hlfir::AssignOp::create(builder, loc, rhs, lhs, isWholeAllocatableAssignment, diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt index 8e20abf..eb4d57d 100644 --- a/flang/lib/Lower/CMakeLists.txt +++ b/flang/lib/Lower/CMakeLists.txt @@ -15,6 +15,7 @@ add_flang_library(FortranLower ConvertProcedureDesignator.cpp ConvertType.cpp ConvertVariable.cpp + CUDA.cpp CustomIntrinsicCall.cpp HlfirIntrinsics.cpp HostAssociations.cpp @@ -59,6 +60,7 @@ add_flang_library(FortranLower FortranParser FortranEvaluate FortranSemantics + FortranUtils LINK_COMPONENTS Support diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp new file mode 100644 index 0000000..1293d2c --- /dev/null +++ b/flang/lib/Lower/CUDA.cpp @@ -0,0 +1,167 @@ +//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "flang/Lower/CUDA.h" +#include "flang/Lower/AbstractConverter.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" + +#define DEBUG_TYPE "flang-lower-cuda" + +void Fortran::lower::initializeDeviceComponentAllocator( + Fortran::lower::AbstractConverter &converter, + const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) { + if (const auto *details{ + sym.GetUltimate() + .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) { + const Fortran::semantics::DeclTypeSpec *type{details->type()}; + const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived() + : nullptr}; + if (derived) { + if (!FindCUDADeviceAllocatableUltimateComponent(*derived)) + return; // No device components. 
+ + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + mlir::Location loc = converter.getCurrentLocation(); + + mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType()); + + // Only pointer and allocatable needs post allocation initialization + // of components descriptors. + if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy)) + return; + + // Extract the derived type. + mlir::Type ty = fir::getDerivedType(baseTy); + auto recTy = mlir::dyn_cast<fir::RecordType>(ty); + assert(recTy && "expected fir::RecordType"); + + if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy)) + baseTy = boxTy.getEleTy(); + baseTy = fir::unwrapRefType(baseTy); + + Fortran::semantics::UltimateComponentIterator components{*derived}; + mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr()); + mlir::Value addr; + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) { + mlir::Type idxTy = builder.getIndexType(); + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); + mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); + llvm::SmallVector<fir::DoLoopOp> loops; + llvm::SmallVector<mlir::Value> indices; + llvm::SmallVector<mlir::Value> extents; + for (unsigned i = 0; i < seqTy.getDimension(); ++i) { + mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); + auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, + idxTy, loadedBox, dim); + mlir::Value lbub = mlir::arith::AddIOp::create( + builder, loc, dimInfo.getResult(0), dimInfo.getResult(1)); + mlir::Value ext = + mlir::arith::SubIOp::create(builder, loc, lbub, one); + mlir::Value cmp = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero); + ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero); + extents.push_back(ext); + + auto loop = fir::DoLoopOp::create( + builder, loc, dimInfo.getResult(0), dimInfo.getResult(1), + dimInfo.getResult(2), /*isUnordered=*/true, + /*finalCount=*/false, mlir::ValueRange{}); + loops.push_back(loop); + indices.push_back(loop.getInductionVar()); + builder.setInsertionPointToStart(loop.getBody()); + } + mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox); + auto shape = fir::ShapeOp::create(builder, loc, extents); + addr = fir::ArrayCoorOp::create( + builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape, + /*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{}); + } else { + addr = fir::BoxAddrOp::create(builder, loc, loadedBox); + } + for (const auto &compSym : components) { + if (Fortran::semantics::IsDeviceAllocatable(compSym)) { + llvm::SmallVector<mlir::Value> coord; + mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType( + builder, loc, compSym, recTy, coord); + assert(coord.size() == 1 && "expect one coordinate"); + mlir::Value comp = fir::CoordinateOp::create( + builder, loc, builder.getRefType(fieldTy), addr, coord[0]); + cuf::DataAttributeAttr dataAttr = + Fortran::lower::translateSymbolCUFDataAttribute( + builder.getContext(), compSym); + cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr); + } + } + } + } +} + +mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType( + fir::FirOpBuilder &builder, mlir::Location loc, + const Fortran::semantics::Symbol &sym, fir::RecordType recTy, + llvm::SmallVector<mlir::Value> &coordinates) { + unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); + mlir::Type fieldTy; + if (fieldIdx != std::numeric_limits<unsigned>::max()) { + // Field found in the base record 
type. + auto fieldName = recTy.getTypeList()[fieldIdx].first; + fieldTy = recTy.getTypeList()[fieldIdx].second; + mlir::Value fieldIndex = fir::FieldIndexOp::create( + builder, loc, fir::FieldType::get(fieldTy.getContext()), fieldName, + recTy, + /*typeParams=*/mlir::ValueRange{}); + coordinates.push_back(fieldIndex); + } else { + // Field not found in base record type, search in potential + // record type components. + for (auto component : recTy.getTypeList()) { + if (auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second)) { + fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); + if (fieldIdx != std::numeric_limits<unsigned>::max()) { + mlir::Value parentFieldIndex = fir::FieldIndexOp::create( + builder, loc, fir::FieldType::get(childRecTy.getContext()), + component.first, recTy, + /*typeParams=*/mlir::ValueRange{}); + coordinates.push_back(parentFieldIndex); + auto fieldName = childRecTy.getTypeList()[fieldIdx].first; + fieldTy = childRecTy.getTypeList()[fieldIdx].second; + mlir::Value childFieldIndex = fir::FieldIndexOp::create( + builder, loc, fir::FieldType::get(fieldTy.getContext()), + fieldName, childRecTy, + /*typeParams=*/mlir::ValueRange{}); + coordinates.push_back(childFieldIndex); + break; + } + } + } + } + if (coordinates.empty()) + TODO(loc, "device resident component in complex derived-type hierarchy"); + return fieldTy; +} + +cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute( + mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) { + std::optional<Fortran::common::CUDADataAttr> cudaAttr = + Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate()); + return cuf::getDataAttribute(mlirContext, cudaAttr); +} + +bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) { + if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp())) + if (llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) && + llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 && + llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == 1) + return true; + return false; +} diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index bf713f5..04dcc92 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -880,9 +880,10 @@ struct CallContext { std::optional<mlir::Type> resultType, mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symMap, - Fortran::lower::StatementContext &stmtCtx) + Fortran::lower::StatementContext &stmtCtx, bool doCopyIn = true) : procRef{procRef}, converter{converter}, symMap{symMap}, - stmtCtx{stmtCtx}, resultType{resultType}, loc{loc} {} + stmtCtx{stmtCtx}, resultType{resultType}, loc{loc}, doCopyIn{doCopyIn} { + } fir::FirOpBuilder &getBuilder() { return converter.getFirOpBuilder(); } @@ -924,6 +925,7 @@ struct CallContext { Fortran::lower::StatementContext &stmtCtx; std::optional<mlir::Type> resultType; mlir::Location loc; + bool doCopyIn; }; using ExvAndCleanup = @@ -1161,18 +1163,6 @@ mlir::Value static getZeroLowerBounds(mlir::Location loc, return builder.genShift(loc, lowerBounds); } -static bool -isSimplyContiguous(const Fortran::evaluate::ActualArgument &arg, - Fortran::evaluate::FoldingContext &foldingContext) { - if (const auto *expr = arg.UnwrapExpr()) - return Fortran::evaluate::IsSimplyContiguous(*expr, foldingContext); - const Fortran::semantics::Symbol *sym = arg.GetAssumedTypeDummy(); - assert(sym && - "expect ActualArguments to be expression or assumed-type 
symbols"); - return sym->Rank() == 0 || - Fortran::evaluate::IsSimplyContiguous(*sym, foldingContext); -} - static bool isParameterObjectOrSubObject(hlfir::Entity entity) { mlir::Value base = entity; bool foundParameter = false; @@ -1204,6 +1194,10 @@ static bool isParameterObjectOrSubObject(hlfir::Entity entity) { /// fir.box_char...). /// This function should only be called with an actual that is present. /// The optional aspects must be handled by this function user. +/// +/// Note: while Fortran::lower::CallerInterface::PassedEntity (the type of arg) +/// is technically a template type, in the prepare*ActualArgument() calls +/// it resolves to Fortran::evaluate::ActualArgument * static PreparedDummyArgument preparePresentUserCallActualArgument( mlir::Location loc, fir::FirOpBuilder &builder, const Fortran::lower::PreparedActualArgument &preparedActual, @@ -1211,9 +1205,6 @@ static PreparedDummyArgument preparePresentUserCallActualArgument( const Fortran::lower::CallerInterface::PassedEntity &arg, CallContext &callContext) { - Fortran::evaluate::FoldingContext &foldingContext = - callContext.converter.getFoldingContext(); - // Step 1: get the actual argument, which includes addressing the // element if this is an array in an elemental call. hlfir::Entity actual = preparedActual.getActual(loc, builder); @@ -1254,13 +1245,20 @@ static PreparedDummyArgument preparePresentUserCallActualArgument( passingPolymorphicToNonPolymorphic && (actual.isArray() || mlir::isa<fir::BaseBoxType>(dummyType)); - // The simple contiguity of the actual is "lost" when passing a polymorphic - // to a non polymorphic entity because the dummy dynamic type matters for - // the contiguity. - const bool mustDoCopyInOut = - actual.isArray() && arg.mustBeMadeContiguous() && - (passingPolymorphicToNonPolymorphic || - !isSimplyContiguous(*arg.entity, foldingContext)); + bool mustDoCopyIn{false}; + bool mustDoCopyOut{false}; + + if (callContext.doCopyIn) { + Fortran::evaluate::FoldingContext &foldingContext{ + callContext.converter.getFoldingContext()}; + + bool suggestCopyIn = Fortran::evaluate::MayNeedCopy( + arg.entity, arg.characteristics, foldingContext, /*forCopyOut=*/false); + bool suggestCopyOut = Fortran::evaluate::MayNeedCopy( + arg.entity, arg.characteristics, foldingContext, /*forCopyOut=*/true); + mustDoCopyIn = actual.isArray() && suggestCopyIn; + mustDoCopyOut = actual.isArray() && suggestCopyOut; + } const bool actualIsAssumedRank = actual.isAssumedRank(); // Create dummy type with actual argument rank when the dummy is an assumed @@ -1370,8 +1368,14 @@ static PreparedDummyArgument preparePresentUserCallActualArgument( entity = hlfir::Entity{associate.getBase()}; // Register the temporary destruction after the call. preparedDummy.pushExprAssociateCleanUp(associate); - } else if (mustDoCopyInOut) { + } else if (mustDoCopyIn || mustDoCopyOut) { // Copy-in non contiguous variables. + // + // TODO: copy-in and copy-out are now determined separately, in order + // to allow more fine grained copying. While currently both copy-in + // and copy-out are must be done together, these copy operations could + // be separated in the future. (This is related to TODO comment below.) + // // TODO: for non-finalizable monomorphic derived type actual // arguments associated with INTENT(OUT) dummy arguments // we may avoid doing the copy and only allocate the temporary. @@ -1379,7 +1383,7 @@ static PreparedDummyArgument preparePresentUserCallActualArgument( // allocation for the temp in this case. 
We can communicate // this to the codegen via some CopyInOp flag. // This is a performance concern. - entity = genCopyIn(entity, arg.mayBeModifiedByCall()); + entity = genCopyIn(entity, mustDoCopyOut); } } else { const Fortran::lower::SomeExpr *expr = arg.entity->UnwrapExpr(); @@ -2966,8 +2970,11 @@ void Fortran::lower::convertUserDefinedAssignmentToHLFIR( const evaluate::ProcedureRef &procRef, hlfir::Entity lhs, hlfir::Entity rhs, Fortran::lower::SymMap &symMap) { Fortran::lower::StatementContext definedAssignmentContext; + // For defined assignment, don't use regular copy-in/copy-out mechanism: + // defined assignment generates hlfir.region_assign construct, and this + // construct automatically handles any copy-in. CallContext callContext(procRef, /*resultType=*/std::nullopt, loc, converter, - symMap, definedAssignmentContext); + symMap, definedAssignmentContext, /*doCopyIn=*/false); Fortran::lower::CallerInterface caller(procRef, converter); mlir::FunctionType callSiteType = caller.genFunctionType(); PreparedActualArgument preparedLhs{lhs, /*isPresent=*/std::nullopt}; diff --git a/flang/lib/Lower/ConvertConstant.cpp b/flang/lib/Lower/ConvertConstant.cpp index 768a237..376ec12 100644 --- a/flang/lib/Lower/ConvertConstant.cpp +++ b/flang/lib/Lower/ConvertConstant.cpp @@ -145,6 +145,9 @@ private: fir::FirOpBuilder &builder, const Fortran::evaluate::Constant<Fortran::evaluate::Type<TC, KIND>> &constant) { + using Element = + Fortran::evaluate::Scalar<Fortran::evaluate::Type<TC, KIND>>; + static_assert(TC != Fortran::common::TypeCategory::Character, "must be numerical or logical"); auto attrTc = TC == Fortran::common::TypeCategory::Logical @@ -152,7 +155,24 @@ private: : TC; attributeElementType = Fortran::lower::getFIRType(builder.getContext(), attrTc, KIND, {}); - for (auto element : constant.values()) + + const std::vector<Element> &values = constant.values(); + auto sameElements = [&]() -> bool { + if (values.empty()) + return false; + + return std::all_of(values.begin(), values.end(), + [&](const auto &v) { return v == values.front(); }); + }; + + if (sameElements()) { + auto attr = convertToAttribute<TC, KIND>(builder, values.front(), + attributeElementType); + attributes.assign(values.size(), attr); + return; + } + + for (auto element : values) attributes.push_back( convertToAttribute<TC, KIND>(builder, element, attributeElementType)); } diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp index 5588f62..d7f94e1 100644 --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -2750,7 +2750,7 @@ public: fir::unwrapSequenceType(fir::unwrapPassByRefType(argTy)))) TODO(loc, "passing to an OPTIONAL CONTIGUOUS derived type argument " "with length parameters"); - if (Fortran::evaluate::IsAssumedRank(*expr)) + if (Fortran::semantics::IsAssumedRank(*expr)) TODO(loc, "passing an assumed rank entity to an OPTIONAL " "CONTIGUOUS argument"); // Assumed shape VALUE are currently TODO in the call interface diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp index 9930dd6..81e09a1 100644 --- a/flang/lib/Lower/ConvertExprToHLFIR.cpp +++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp @@ -26,7 +26,6 @@ #include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/IntrinsicCall.h" #include "flang/Optimizer/Builder/MutableBox.h" -#include "flang/Optimizer/Builder/Runtime/Character.h" #include "flang/Optimizer/Builder/Runtime/Derived.h" #include "flang/Optimizer/Builder/Runtime/Pointer.h" #include 
"flang/Optimizer/Builder/Todo.h" @@ -1286,16 +1285,8 @@ struct BinaryOp<Fortran::evaluate::Relational< fir::FirOpBuilder &builder, const Op &op, hlfir::Entity lhs, hlfir::Entity rhs) { - auto [lhsExv, lhsCleanUp] = - hlfir::translateToExtendedValue(loc, builder, lhs); - auto [rhsExv, rhsCleanUp] = - hlfir::translateToExtendedValue(loc, builder, rhs); - auto cmp = fir::runtime::genCharCompare( - builder, loc, translateSignedRelational(op.opr), lhsExv, rhsExv); - if (lhsCleanUp) - (*lhsCleanUp)(); - if (rhsCleanUp) - (*rhsCleanUp)(); + auto cmp = hlfir::CmpCharOp::create( + builder, loc, translateSignedRelational(op.opr), lhs, rhs); return hlfir::EntityWithAttributes{cmp}; } }; diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index a4a8a69..80af7f4 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -14,12 +14,12 @@ #include "flang/Lower/AbstractConverter.h" #include "flang/Lower/Allocatable.h" #include "flang/Lower/BoxAnalyzer.h" +#include "flang/Lower/CUDA.h" #include "flang/Lower/CallInterface.h" #include "flang/Lower/ConvertConstant.h" #include "flang/Lower/ConvertExpr.h" #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertProcedureDesignator.h" -#include "flang/Lower/Cuda.h" #include "flang/Lower/Mangler.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" @@ -814,81 +814,24 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, baseTy = boxTy.getEleTy(); baseTy = fir::unwrapRefType(baseTy); - if (mlir::isa<fir::SequenceType>(baseTy) && - (fir::isAllocatableType(fir::getBase(exv).getType()) || - fir::isPointerType(fir::getBase(exv).getType()))) + if (fir::isAllocatableType(fir::getBase(exv).getType()) || + fir::isPointerType(fir::getBase(exv).getType())) return; // Allocator index need to be set after allocation. auto recTy = mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy)); assert(recTy && "expected fir::RecordType"); - llvm::SmallVector<mlir::Value> coordinates; Fortran::semantics::UltimateComponentIterator components{*derived}; for (const auto &sym : components) { if (Fortran::semantics::IsDeviceAllocatable(sym)) { - unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); - mlir::Type fieldTy; - llvm::SmallVector<mlir::Value> coordinates; - - if (fieldIdx != std::numeric_limits<unsigned>::max()) { - // Field found in the base record type. - auto fieldName = recTy.getTypeList()[fieldIdx].first; - fieldTy = recTy.getTypeList()[fieldIdx].second; - mlir::Value fieldIndex = fir::FieldIndexOp::create( - builder, loc, fir::FieldType::get(fieldTy.getContext()), - fieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(fieldIndex); - } else { - // Field not found in base record type, search in potential - // record type components. 
- for (auto component : recTy.getTypeList()) { - if (auto childRecTy = - mlir::dyn_cast<fir::RecordType>(component.second)) { - fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); - if (fieldIdx != std::numeric_limits<unsigned>::max()) { - mlir::Value parentFieldIndex = fir::FieldIndexOp::create( - builder, loc, - fir::FieldType::get(childRecTy.getContext()), - component.first, recTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(parentFieldIndex); - auto fieldName = childRecTy.getTypeList()[fieldIdx].first; - fieldTy = childRecTy.getTypeList()[fieldIdx].second; - mlir::Value childFieldIndex = fir::FieldIndexOp::create( - builder, loc, fir::FieldType::get(fieldTy.getContext()), - fieldName, childRecTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(childFieldIndex); - break; - } - } - } - } - - if (coordinates.empty()) - TODO(loc, "device resident component in complex derived-type " - "hierarchy"); - + llvm::SmallVector<mlir::Value> coord; + mlir::Type fieldTy = + Fortran::lower::gatherDeviceComponentCoordinatesAndType( + builder, loc, sym, recTy, coord); mlir::Value base = fir::getBase(exv); - mlir::Value comp; - if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(base.getType()))) { - mlir::Value box = fir::LoadOp::create(builder, loc, base); - mlir::Value addr = fir::BoxAddrOp::create(builder, loc, box); - llvm::SmallVector<mlir::Value> lenParams; - assert(coordinates.size() == 1 && "expect one coordinate"); - auto field = mlir::dyn_cast<fir::FieldIndexOp>( - coordinates[0].getDefiningOp()); - comp = hlfir::DesignateOp::create( - builder, loc, builder.getRefType(fieldTy), addr, - /*component=*/field.getFieldName(), - /*componentShape=*/mlir::Value{}, - hlfir::DesignateOp::Subscripts{}); - } else { - comp = fir::CoordinateOp::create( - builder, loc, builder.getRefType(fieldTy), base, coordinates); - } + mlir::Value comp = fir::CoordinateOp::create( + builder, loc, builder.getRefType(fieldTy), base, coord); cuf::DataAttributeAttr dataAttr = Fortran::lower::translateSymbolCUFDataAttribute( builder.getContext(), sym); @@ -1777,7 +1720,7 @@ static bool lowerToBoxValue(const Fortran::semantics::Symbol &sym, return true; // Assumed rank and optional fir.box cannot yet be read while lowering the // specifications. 
- if (Fortran::evaluate::IsAssumedRank(sym) || + if (Fortran::semantics::IsAssumedRank(sym) || Fortran::semantics::IsOptional(sym)) return true; // Polymorphic entity should be tracked through a fir.box that has the @@ -1950,13 +1893,6 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes( return fir::FortranVariableFlagsAttr::get(mlirContext, flags); } -cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute( - mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) { - std::optional<Fortran::common::CUDADataAttr> cudaAttr = - Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate()); - return cuf::getDataAttribute(mlirContext, cudaAttr); -} - static bool isCapturedInInternalProcedure(Fortran::lower::AbstractConverter &converter, const Fortran::semantics::Symbol &sym) { @@ -2236,7 +2172,7 @@ void Fortran::lower::mapSymbolAttributes( return; } - const bool isAssumedRank = Fortran::evaluate::IsAssumedRank(sym); + const bool isAssumedRank = Fortran::semantics::IsAssumedRank(sym); if (isAssumedRank && !allowAssumedRank) TODO(loc, "assumed-rank variable in procedure implemented in Fortran"); diff --git a/flang/lib/Lower/HlfirIntrinsics.cpp b/flang/lib/Lower/HlfirIntrinsics.cpp index 6e1d06a..b9731e9 100644 --- a/flang/lib/Lower/HlfirIntrinsics.cpp +++ b/flang/lib/Lower/HlfirIntrinsics.cpp @@ -159,6 +159,18 @@ protected: hlfir::CharExtremumPredicate pred; }; +class HlfirCharTrimLowering : public HlfirTransformationalIntrinsic { +public: + HlfirCharTrimLowering(fir::FirOpBuilder &builder, mlir::Location loc) + : HlfirTransformationalIntrinsic(builder, loc) {} + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; +}; + class HlfirCShiftLowering : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; @@ -170,6 +182,17 @@ protected: mlir::Type stmtResultType) override; }; +class HlfirEOShiftLowering : public HlfirTransformationalIntrinsic { +public: + using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; +}; + class HlfirReshapeLowering : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; @@ -410,6 +433,15 @@ mlir::Value HlfirCharExtremumLowering::lowerImpl( return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands}); } +mlir::Value HlfirCharTrimLowering::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto operands = getOperandVector(loweredActuals, argLowering); + assert(operands.size() == 1); + return createOp<hlfir::CharTrimOp>(operands[0]); +} + mlir::Value HlfirCShiftLowering::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, @@ -430,6 +462,46 @@ mlir::Value HlfirCShiftLowering::lowerImpl( return createOp<hlfir::CShiftOp>(resultType, operands); } +mlir::Value HlfirEOShiftLowering::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto 
operands = getOperandVector(loweredActuals, argLowering); + assert(operands.size() == 4); + mlir::Value array = operands[0]; + mlir::Value shift = operands[1]; + mlir::Value boundary = operands[2]; + mlir::Value dim = operands[3]; + // If DIM is present, then dereference it if it is a ref. + if (dim) + dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); + + mlir::Type resultType = computeResultType(array, stmtResultType); + + if (boundary && fir::isa_trivial(boundary.getType())) { + mlir::Type elementType = hlfir::getFortranElementType(resultType); + if (auto logicalTy = mlir::dyn_cast<fir::LogicalType>(elementType)) { + // Scalar logical constant boundary might be represented using i1, i2, ... + // type. We need to cast it to fir.logical type of the ARRAY/result. + if (boundary.getType() != logicalTy) + boundary = builder.createConvert(loc, logicalTy, boundary); + } else { + // When the boundary is a constant like '1u', the lowering converts + // it into a signless arith.constant value (which is a requirement + // of the Arith dialect). If the ARRAY/RESULT is also UNSIGNED, + // we have to cast the boundary to the same unsigned type. + auto resultIntTy = mlir::dyn_cast<mlir::IntegerType>(elementType); + auto boundaryIntTy = + mlir::dyn_cast<mlir::IntegerType>(boundary.getType()); + if (resultIntTy && boundaryIntTy && + resultIntTy.getSignedness() != boundaryIntTy.getSignedness()) + boundary = builder.createConvert(loc, resultIntTy, boundary); + } + } + + return createOp<hlfir::EOShiftOp>(resultType, array, shift, boundary, dim); +} + mlir::Value HlfirReshapeLowering::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, @@ -489,6 +561,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic( if (name == "cshift") return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); + if (name == "eoshift") + return HlfirEOShiftLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); if (name == "reshape") return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); @@ -501,6 +576,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic( return HlfirCharExtremumLowering{builder, loc, hlfir::CharExtremumPredicate::max} .lower(loweredActuals, argLowering, stmtResultType); + if (name == "trim") + return HlfirCharTrimLowering{builder, loc}.lower( + loweredActuals, argLowering, stmtResultType); } return std::nullopt; } diff --git a/flang/lib/Lower/HostAssociations.cpp b/flang/lib/Lower/HostAssociations.cpp index 2a330cc..ad6aba1 100644 --- a/flang/lib/Lower/HostAssociations.cpp +++ b/flang/lib/Lower/HostAssociations.cpp @@ -431,7 +431,7 @@ public: mlir::Value box = args.valueInTuple; mlir::IndexType idxTy = builder.getIndexType(); llvm::SmallVector<mlir::Value> lbounds; - if (!ba.lboundIsAllOnes() && !Fortran::evaluate::IsAssumedRank(sym)) { + if (!ba.lboundIsAllOnes() && !Fortran::semantics::IsAssumedRank(sym)) { if (ba.isStaticArray()) { for (std::int64_t lb : ba.staticLBound()) lbounds.emplace_back(builder.createIntegerConstant(loc, idxTy, lb)); @@ -490,7 +490,7 @@ private: bool isPolymorphic = type && type->IsPolymorphic(); return isScalarOrContiguous && !isPolymorphic && !isDerivedWithLenParameters(sym) && - !Fortran::evaluate::IsAssumedRank(sym); + !Fortran::semantics::IsAssumedRank(sym); } }; } // namespace diff --git a/flang/lib/Lower/OpenACC.cpp 
b/flang/lib/Lower/OpenACC.cpp index 35edcb0..7a84b21 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1575,7 +1575,7 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, if (bounds.empty()) { llvm::SmallVector<mlir::Value> extents; mlir::Type idxTy = builder.getIndexType(); - for (auto extent : seqTy.getShape()) { + for (auto extent : llvm::reverse(seqTy.getShape())) { mlir::Value lb = mlir::arith::ConstantOp::create( builder, loc, idxTy, builder.getIntegerAttr(idxTy, 0)); mlir::Value ub = mlir::arith::ConstantOp::create( @@ -1607,12 +1607,11 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, } } else { // Lowerbound, upperbound and step are passed as block arguments. - [[maybe_unused]] unsigned nbRangeArgs = + unsigned nbRangeArgs = recipe.getCombinerRegion().getArguments().size() - 2; assert((nbRangeArgs / 3 == seqTy.getDimension()) && "Expect 3 block arguments per dimension"); - for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size(); - i += 3) { + for (int i = nbRangeArgs - 1; i >= 2; i -= 3) { mlir::Value lb = recipe.getCombinerRegion().getArgument(i); mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1); mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2); @@ -1623,8 +1622,11 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, ivs.push_back(loop.getInductionVar()); } } - auto addr1 = fir::CoordinateOp::create(builder, loc, refTy, value1, ivs); - auto addr2 = fir::CoordinateOp::create(builder, loc, refTy, value2, ivs); + llvm::SmallVector<mlir::Value> reversedIvs(ivs.rbegin(), ivs.rend()); + auto addr1 = + fir::CoordinateOp::create(builder, loc, refTy, value1, reversedIvs); + auto addr2 = + fir::CoordinateOp::create(builder, loc, refTy, value2, reversedIvs); auto load1 = fir::LoadOp::create(builder, loc, addr1); auto load2 = fir::LoadOp::create(builder, loc, addr2); mlir::Value res = diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index ed0bff0..ff82a36 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -43,179 +43,6 @@ namespace omp { using namespace Fortran::lower::omp; } -namespace { -// An example of a type that can be used to get the return value from -// the visitor: -// visitor(type_identity<Xyz>) -> result_type -using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>; - -struct GetProc - : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *, - false> { - using Result = const evaluate::ProcedureDesignator *; - using Base = evaluate::Traverse<GetProc, Result, false>; - GetProc() : Base(*this) {} - - using Base::operator(); - - static Result Default() { return nullptr; } - - Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; } - static Result Combine(Result a, Result b) { return a != nullptr ? 
a : b; } -}; - -struct WithType { - WithType(const evaluate::DynamicType &t) : type(t) { - assert(type.category() != common::TypeCategory::Derived && - "Type cannot be a derived type"); - } - - template <typename VisitorTy> // - auto visit(VisitorTy &&visitor) const - -> std::invoke_result_t<VisitorTy, SomeArgType> { - switch (type.category()) { - case common::TypeCategory::Integer: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{}); - case 2: - return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{}); - } - break; - case common::TypeCategory::Unsigned: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{}); - case 2: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{}); - } - break; - case common::TypeCategory::Real: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{}); - case 3: - return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{}); - case 10: - return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{}); - } - break; - case common::TypeCategory::Complex: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{}); - case 3: - return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{}); - case 10: - return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{}); - } - break; - case common::TypeCategory::Logical: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{}); - case 2: - return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{}); - } - break; - case common::TypeCategory::Character: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{}); - case 2: - return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{}); - } - break; - case common::TypeCategory::Derived: - (void)Derived; - break; - } - llvm_unreachable("Unhandled type"); - } - - const evaluate::DynamicType &type; - -private: - // Shorter names. 
- static constexpr auto Character = common::TypeCategory::Character; - static constexpr auto Complex = common::TypeCategory::Complex; - static constexpr auto Derived = common::TypeCategory::Derived; - static constexpr auto Integer = common::TypeCategory::Integer; - static constexpr auto Logical = common::TypeCategory::Logical; - static constexpr auto Real = common::TypeCategory::Real; - static constexpr auto Unsigned = common::TypeCategory::Unsigned; -}; - -template <typename T, typename U = std::remove_const_t<T>> -U AsRvalue(T &t) { - U copy{t}; - return std::move(copy); -} - -template <typename T> -T &&AsRvalue(T &&t) { - return std::move(t); -} - -struct ArgumentReplacer - : public evaluate::Traverse<ArgumentReplacer, bool, false> { - using Base = evaluate::Traverse<ArgumentReplacer, bool, false>; - using Result = bool; - - Result Default() const { return false; } - - ArgumentReplacer(evaluate::ActualArguments &&newArgs) - : Base(*this), args_(std::move(newArgs)) {} - - using Base::operator(); - - template <typename T> - Result operator()(const evaluate::FunctionRef<T> &x) { - assert(!done_); - auto &mut = const_cast<evaluate::FunctionRef<T> &>(x); - mut.arguments() = args_; - done_ = true; - return true; - } - - Result Combine(Result &&a, Result &&b) { return a || b; } - -private: - bool done_{false}; - evaluate::ActualArguments &&args_; -}; -} // namespace - [[maybe_unused]] static void dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) { auto whatStr = [](int k) { @@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter, return nullptr; } -static bool replaceArgs(semantics::SomeExpr &expr, - evaluate::ActualArguments &&newArgs) { - return ArgumentReplacer(std::move(newArgs))(expr); -} - -static semantics::SomeExpr makeCall(const evaluate::DynamicType &type, - const evaluate::ProcedureDesignator &proc, - const evaluate::ActualArguments &args) { - return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr { - using Type = typename llvm::remove_cvref_t<decltype(s)>::type; - return evaluate::AsGenericExpr( - evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args))); - }); -} - -static const evaluate::ProcedureDesignator & -getProcedureDesignator(const semantics::SomeExpr &call) { - const evaluate::ProcedureDesignator *proc = GetProc{}(call); - assert(proc && "Call has no procedure designator"); - return *proc; -} - -static semantics::SomeExpr // -genReducedMinMax(const semantics::SomeExpr &orig, - const semantics::SomeExpr *atomArg, - const std::vector<semantics::SomeExpr> &args) { - // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] - // One of the a_i's, say a_t, must be atomArg. - // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate - // call = min/max(a_t, tmp). - // Return "call". - - // The min/max intrinsics have 2 mandatory arguments, the rest is optional. - // Make sure that the "tmp = min/max(...)" doesn't promote an optional - // argument to a non-optional position. This could happen if a_t is at - // position 0 or 1. - if (args.size() <= 2) - return orig; - - evaluate::ActualArguments nonAtoms; - - auto AsActual = [](const semantics::SomeExpr &x) { - semantics::SomeExpr copy = x; - return evaluate::ActualArgument(std::move(copy)); - }; - // Semantic checks guarantee that the "atom" shows exactly once in the - // argument list (with potential conversions around it). 
- // For the first two (non-optional) arguments, if "atom" is among them, - // replace it with another occurrence of the other non-optional argument. - if (atomArg == &args[0]) { - // (atom, x, y...) -> (x, x, y...) - nonAtoms.push_back(AsActual(args[1])); - nonAtoms.push_back(AsActual(args[1])); - } else if (atomArg == &args[1]) { - // (x, atom, y...) -> (x, x, y...) - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[0])); - } else { - // (x, y, z...) -> unchanged - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[1])); - } - - // The rest of arguments are optional, so we can just skip "atom". - for (size_t i = 2, e = args.size(); i != e; ++i) { - if (atomArg != &args[i]) - nonAtoms.push_back(AsActual(args[i])); - } - - // The type of the intermediate min/max is the same as the type of its - // arguments, which may be different from the type of the original - // expression. The original expression may have additional coverts. - auto tmp = - makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms); - semantics::SomeExpr call = orig; - replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)}); - return call; -} - static mlir::Operation * // genAtomicRead(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, mlir::Location loc, @@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter, auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input); assert(!args.empty() && "Update operation without arguments"); - // Pass args as an argument to avoid capturing a structured binding. - const semantics::SomeExpr *atomArg = [&](auto &args) { - for (const semantics::SomeExpr &e : args) { - if (evaluate::IsSameOrConvertOf(e, atom)) - return &e; - } - llvm_unreachable("Atomic variable not in argument list"); - }(args); - - if (opcode == evaluate::operation::Operator::Min || - opcode == evaluate::operation::Operator::Max) { - // Min and max operations are expanded inline, so reduce them to - // operations with exactly two (non-optional) arguments. - rhs = genReducedMinMax(rhs, atomArg, args); - input = *evaluate::GetConvertInput(rhs); - std::tie(opcode, args) = - evaluate::GetTopLevelOperationIgnoreResizing(input); - atomArg = nullptr; // No longer valid. - } for (auto &arg : args) { if (!evaluate::IsSameOrConvertOf(arg, atom)) { mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc)); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index b98ad3c..6b9bd66 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -19,6 +19,7 @@ #include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Parser/tools.h" #include "flang/Semantics/tools.h" +#include "flang/Utils/OpenMP.h" #include "llvm/Frontend/OpenMP/OMP.h.inc" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" @@ -647,10 +648,8 @@ addAlignedClause(lower::AbstractConverter &converter, // The default alignment for some targets is equal to 0. // Do not generate alignment assumption if alignment is less than or equal to - // 0. 
- if (alignment > 0) { - // alignment value must be power of 2 - assert((alignment & (alignment - 1)) == 0 && "alignment is not power of 2"); + // 0 or not a power of two + if (alignment > 0 && ((alignment & (alignment - 1)) == 0)) { auto &objects = std::get<omp::ObjectList>(clause.t); if (!objects.empty()) genObjectList(objects, converter, alignedVars); @@ -1179,12 +1178,13 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { } bool ClauseProcessor::processLink( - llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { + llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const { return findRepeatableClause<omp::clause::Link>( [&](const omp::clause::Link &clause, const parser::CharBlock &) { // Case: declare target link(var1, var2)... gatherFuncAndVarSyms( - clause.v, mlir::omp::DeclareTargetCaptureClause::link, result); + clause.v, mlir::omp::DeclareTargetCaptureClause::link, result, + /*automap=*/false); }); } @@ -1280,7 +1280,7 @@ void ClauseProcessor::processMapObjects( auto location = mlir::NameLoc::get( mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()), baseOp.getLoc()); - mlir::omp::MapInfoOp mapOp = createMapInfoOp( + mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp( firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds, /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, @@ -1507,26 +1507,27 @@ bool ClauseProcessor::processTaskReduction( } bool ClauseProcessor::processTo( - llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { + llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const { return findRepeatableClause<omp::clause::To>( [&](const omp::clause::To &clause, const parser::CharBlock &) { // Case: declare target to(func, var1, var2)... gatherFuncAndVarSyms(std::get<ObjectList>(clause.t), - mlir::omp::DeclareTargetCaptureClause::to, result); + mlir::omp::DeclareTargetCaptureClause::to, result, + /*automap=*/false); }); } bool ClauseProcessor::processEnter( - llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { + llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const { return findRepeatableClause<omp::clause::Enter>( [&](const omp::clause::Enter &clause, const parser::CharBlock &source) { - mlir::Location currentLocation = converter.genLocation(source); - if (std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t)) - TODO(currentLocation, "Declare target enter AUTOMAP modifier"); + bool automap = + std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t) + .has_value(); // Case: declare target enter(func, var1, var2)... 
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t), mlir::omp::DeclareTargetCaptureClause::enter, - result); + result, automap); }); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index f8a1f79..c46bdb3 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -118,7 +118,7 @@ public: bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx, mlir::omp::DependClauseOps &result) const; bool - processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; + processEnter(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const; bool processIf(omp::clause::If::DirectiveNameModifier directiveName, mlir::omp::IfClauseOps &result) const; bool processInReduction( @@ -129,7 +129,7 @@ public: llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; bool processLinear(mlir::omp::LinearClauseOps &result) const; bool - processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; + processLink(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const; // This method is used to process a map clause. // The optional parameter mapSyms is used to store the original Fortran symbol @@ -150,7 +150,7 @@ public: bool processTaskReduction( mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; - bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; + bool processTo(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const; bool processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, @@ -208,11 +208,15 @@ void ClauseProcessor::processTODO(mlir::Location currentLocation, if (!x) return; unsigned version = semaCtx.langOptions().OpenMPVersion; - TODO(currentLocation, - "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + - " in " + - llvm::omp::getOpenMPDirectiveName(directive, version).upper() + - " construct"); + bool isSimdDirective = llvm::omp::getOpenMPDirectiveName(directive, version) + .upper() + .find("SIMD") != llvm::StringRef::npos; + if (!semaCtx.langOptions().OpenMPSimd || isSimdDirective) + TODO(currentLocation, + "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() + + " in " + + llvm::omp::getOpenMPDirectiveName(directive, version).upper() + + " construct"); }; for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it) diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 7f75aae..1a16e1c 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -396,6 +396,8 @@ makePrescriptiveness(parser::OmpPrescriptiveness::Value v) { switch (v) { case parser::OmpPrescriptiveness::Value::Strict: return clause::Prescriptiveness::Strict; + case parser::OmpPrescriptiveness::Value::Fallback: + return clause::Prescriptiveness::Fallback; } llvm_unreachable("Unexpected prescriptiveness"); } @@ -770,6 +772,27 @@ Doacross make(const parser::OmpClause::Doacross &inp, // DynamicAllocators: empty +DynGroupprivate make(const parser::OmpClause::DynGroupprivate &inp, + semantics::SemanticsContext &semaCtx) { + // imp.v -> OmpDyngroupprivateClause + CLAUSET_ENUM_CONVERT( // + convert, parser::OmpAccessGroup::Value, DynGroupprivate::AccessGroup, + // clang-format off + MS(Cgroup, Cgroup) + // clang-format on + ); + + auto &mods = semantics::OmpGetModifiers(inp.v); + auto *m0 = 
semantics::OmpGetUniqueModifier<parser::OmpAccessGroup>(mods); + auto *m1 = semantics::OmpGetUniqueModifier<parser::OmpPrescriptiveness>(mods); + auto &size = std::get<parser::ScalarIntExpr>(inp.v.t); + + return DynGroupprivate{ + {/*AccessGroup=*/maybeApplyToV(convert, m0), + /*Prescriptiveness=*/maybeApplyToV(makePrescriptiveness, m1), + /*Size=*/makeExpr(size, semaCtx)}}; +} + Enter make(const parser::OmpClause::Enter &inp, semantics::SemanticsContext &semaCtx) { // inp.v -> parser::OmpEnterClause diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index 67a9a46..146a252 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -30,18 +30,27 @@ #include "flang/Semantics/tools.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/Frontend/OpenMP/OMP.h" +#include <variant> namespace Fortran { namespace lower { namespace omp { bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy( const semantics::Symbol *symbol, lower::pft::Evaluation &eval) const { - return eval.visit( - common::visitors{[&](const parser::OpenMPConstruct &functionParserNode) { - return symDefMap.count(symbol) && - symDefMap.at(symbol) == &functionParserNode; - }, - [](const auto &functionParserNode) { return false; }}); + return eval.visit(common::visitors{ + [&](const parser::OpenMPConstruct &functionParserNode) { + return symDefMap.count(symbol) && + symDefMap.at(symbol) == ConstructPtr(&functionParserNode); + }, + [](const auto &functionParserNode) { return false; }}); +} + +bool DataSharingProcessor::OMPConstructSymbolVisitor:: + isSymbolDefineByNestedDeclaration(const semantics::Symbol *symbol) const { + return symDefMap.count(symbol) && + std::holds_alternative<const parser::DeclarationConstruct *>( + symDefMap.at(symbol)); } static bool isConstructWithTopLevelTarget(lower::pft::Evaluation &eval) { @@ -81,13 +90,14 @@ DataSharingProcessor::DataSharingProcessor(lower::AbstractConverter &converter, isTargetPrivatization) {} void DataSharingProcessor::processStep1( - mlir::omp::PrivateClauseOps *clauseOps) { + mlir::omp::PrivateClauseOps *clauseOps, + std::optional<llvm::omp::Directive> dir) { collectSymbolsForPrivatization(); collectDefaultSymbols(); collectImplicitSymbols(); collectPreDeterminedSymbols(); - privatize(clauseOps); + privatize(clauseOps, dir); insertBarrier(clauseOps); } @@ -414,47 +424,10 @@ static parser::CharBlock getSource(const semantics::SemanticsContext &semaCtx, }); } -static void collectPrivatizingConstructs( - llvm::SmallSet<llvm::omp::Directive, 16> &constructs, unsigned version) { - using Clause = llvm::omp::Clause; - using Directive = llvm::omp::Directive; - - static const Clause privatizingClauses[] = { - Clause::OMPC_private, - Clause::OMPC_lastprivate, - Clause::OMPC_firstprivate, - Clause::OMPC_in_reduction, - Clause::OMPC_reduction, - Clause::OMPC_linear, - // TODO: Clause::OMPC_induction, - Clause::OMPC_task_reduction, - Clause::OMPC_detach, - Clause::OMPC_use_device_ptr, - Clause::OMPC_is_device_ptr, - }; - - for (auto dir : llvm::enum_seq_inclusive<Directive>(Directive::First_, - Directive::Last_)) { - bool allowsPrivatizing = llvm::any_of(privatizingClauses, [&](Clause cls) { - return llvm::omp::isAllowedClauseForDirective(dir, cls, version); - }); - if (allowsPrivatizing) - constructs.insert(dir); - } -} - bool DataSharingProcessor::isOpenMPPrivatizingConstruct( const parser::OpenMPConstruct &omp, unsigned version) { - 
static llvm::SmallSet<llvm::omp::Directive, 16> privatizing; - [[maybe_unused]] static bool init = - (collectPrivatizingConstructs(privatizing, version), true); - - // As of OpenMP 6.0, privatizing constructs (with the test being if they - // allow a privatizing clause) are: dispatch, distribute, do, for, loop, - // parallel, scope, sections, simd, single, target, target_data, task, - // taskgroup, taskloop, and teams. - return llvm::is_contained(privatizing, - parser::omp::GetOmpDirectiveName(omp).v); + return llvm::omp::isPrivatizingConstruct( + parser::omp::GetOmpDirectiveName(omp).v, version); } bool DataSharingProcessor::isOpenMPPrivatizingEvaluation( @@ -550,11 +523,23 @@ void DataSharingProcessor::collectSymbols( return false; } - return sym->test(semantics::Symbol::Flag::OmpImplicit); + // Collect implicit symbols only if they are not defined by a nested + // `DeclarationConstruct`. If `sym` is not defined by the current OpenMP + // evaluation then it is defined by a block nested within the OpenMP + // construct. This, in turn, means that the private allocation for the + // symbol will be emitted as part of the nested block and there is no need + // to privatize it within the OpenMP construct. + return !visitor.isSymbolDefineByNestedDeclaration(sym) && + sym->test(semantics::Symbol::Flag::OmpImplicit); } - if (collectPreDetermined) - return sym->test(semantics::Symbol::Flag::OmpPreDetermined); + if (collectPreDetermined) { + // Similar to implicit symbols, collect pre-determined symbols only if + // they are not defined by a nested `DeclarationConstruct` + return visitor.isSymbolDefineBy(sym, eval) && + !visitor.isSymbolDefineByNestedDeclaration(sym) && + sym->test(semantics::Symbol::Flag::OmpPreDetermined); + } return !sym->test(semantics::Symbol::Flag::OmpImplicit) && !sym->test(semantics::Symbol::Flag::OmpPreDetermined); @@ -597,14 +582,15 @@ void DataSharingProcessor::collectPreDeterminedSymbols() { preDeterminedSymbols); } -void DataSharingProcessor::privatize(mlir::omp::PrivateClauseOps *clauseOps) { +void DataSharingProcessor::privatize(mlir::omp::PrivateClauseOps *clauseOps, + std::optional<llvm::omp::Directive> dir) { for (const semantics::Symbol *sym : allPrivatizedSymbols) { if (const auto *commonDet = sym->detailsIf<semantics::CommonBlockDetails>()) { for (const auto &mem : commonDet->objects()) - privatizeSymbol(&*mem, clauseOps); + privatizeSymbol(&*mem, clauseOps, dir); } else - privatizeSymbol(sym, clauseOps); + privatizeSymbol(sym, clauseOps, dir); } } @@ -623,7 +609,8 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) { void DataSharingProcessor::privatizeSymbol( const semantics::Symbol *symToPrivatize, - mlir::omp::PrivateClauseOps *clauseOps) { + mlir::omp::PrivateClauseOps *clauseOps, + std::optional<llvm::omp::Directive> dir) { if (!useDelayedPrivatization) { cloneSymbol(symToPrivatize); copyFirstPrivateSymbol(symToPrivatize); @@ -633,7 +620,7 @@ void DataSharingProcessor::privatizeSymbol( Fortran::lower::privatizeSymbol<mlir::omp::PrivateClauseOp, mlir::omp::PrivateClauseOps>( converter, firOpBuilder, symTable, allPrivatizedSymbols, - mightHaveReadHostSym, symToPrivatize, clauseOps); + mightHaveReadHostSym, symToPrivatize, clauseOps, dir); } } // namespace omp } // namespace lower diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h index 96e7fa6..f6aa865 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -19,6 +19,7 @@ 
#include "flang/Parser/parse-tree.h" #include "flang/Semantics/symbol.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include <variant> namespace mlir { namespace omp { @@ -58,20 +59,35 @@ private: } void Post(const parser::Name &name) { - auto *current = !constructs.empty() ? constructs.back() : nullptr; + auto current = !constructs.empty() ? constructs.back() : ConstructPtr(); symDefMap.try_emplace(name.symbol, current); } - llvm::SmallVector<const parser::OpenMPConstruct *> constructs; - llvm::DenseMap<semantics::Symbol *, const parser::OpenMPConstruct *> - symDefMap; + bool Pre(const parser::DeclarationConstruct &decl) { + constructs.push_back(&decl); + return true; + } + + void Post(const parser::DeclarationConstruct &decl) { + constructs.pop_back(); + } /// Given a \p symbol and an \p eval, returns true if eval is the OMP /// construct that defines symbol. bool isSymbolDefineBy(const semantics::Symbol *symbol, lower::pft::Evaluation &eval) const; + // Given a \p symbol, returns true if it is defined by a nested + // `DeclarationConstruct`. + bool + isSymbolDefineByNestedDeclaration(const semantics::Symbol *symbol) const; + private: + using ConstructPtr = std::variant<const parser::OpenMPConstruct *, + const parser::DeclarationConstruct *>; + llvm::SmallVector<ConstructPtr> constructs; + llvm::DenseMap<semantics::Symbol *, ConstructPtr> symDefMap; + unsigned version; }; @@ -91,7 +107,7 @@ private: lower::pft::Evaluation &eval; bool shouldCollectPreDeterminedSymbols; bool useDelayedPrivatization; - llvm::SmallSet<const semantics::Symbol *, 16> mightHaveReadHostSym; + llvm::SmallPtrSet<const semantics::Symbol *, 16> mightHaveReadHostSym; lower::SymMap &symTable; bool isTargetPrivatization; OMPConstructSymbolVisitor visitor; @@ -110,7 +126,8 @@ private: void collectDefaultSymbols(); void collectImplicitSymbols(); void collectPreDeterminedSymbols(); - void privatize(mlir::omp::PrivateClauseOps *clauseOps); + void privatize(mlir::omp::PrivateClauseOps *clauseOps, + std::optional<llvm::omp::Directive> dir = std::nullopt); void copyLastPrivatize(mlir::Operation *op); void insertLastPrivateCompare(mlir::Operation *op); void cloneSymbol(const semantics::Symbol *sym); @@ -151,7 +168,8 @@ public: // Step2 performs the copying for lastprivates and requires knowledge of the // MLIR operation to insert the last private update. Step2 adds // dealocation code as well. 
- void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr); + void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr, + std::optional<llvm::omp::Directive> dir = std::nullopt); void processStep2(mlir::Operation *op, bool isLoop); void pushLoopIV(mlir::Value iv) { loopIVs.push_back(iv); } @@ -168,7 +186,8 @@ public: } void privatizeSymbol(const semantics::Symbol *symToPrivatize, - mlir::omp::PrivateClauseOps *clauseOps); + mlir::omp::PrivateClauseOps *clauseOps, + std::optional<llvm::omp::Directive> dir = std::nullopt); }; } // namespace omp diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index db6a0e2..574c322 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -34,9 +34,11 @@ #include "flang/Parser/openmp-utils.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/tools.h" #include "flang/Support/Flags.h" #include "flang/Support/OpenMP-utils.h" +#include "flang/Utils/OpenMP.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Support/StateStack.h" @@ -46,6 +48,7 @@ using namespace Fortran::lower::omp; using namespace Fortran::common::openmp; +using namespace Fortran::utils::openmp; //===----------------------------------------------------------------------===// // Code generation helper functions @@ -406,7 +409,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, const parser::OmpClauseList *endClauseList = nullptr; common::visit( common::visitors{ - [&](const parser::OpenMPBlockConstruct &ompConstruct) { + [&](const parser::OmpBlockConstruct &ompConstruct) { beginClauseList = &ompConstruct.BeginDir().Clauses(); if (auto &endSpec = ompConstruct.EndDir()) endClauseList = &endSpec->Clauses(); @@ -533,6 +536,13 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; + case OMPD_teams_workdistribute: + cp.processThreadLimit(stmtCtx, hostInfo->ops); + [[fallthrough]]; + case OMPD_target_teams_workdistribute: + cp.processNumTeams(stmtCtx, hostInfo->ops); + break; + // Standalone 'target' case. 
case OMPD_target: { processSingleNestedIf( @@ -764,14 +774,14 @@ static void getDeclareTargetInfo( lower::pft::Evaluation &eval, const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, mlir::omp::DeclareTargetOperands &clauseOps, - llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { + llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause) { const auto &spec = std::get<parser::OmpDeclareTargetSpecifier>(declareTargetConstruct.t); if (const auto *objectList{parser::Unwrap<parser::OmpObjectList>(spec.u)}) { ObjectList objects{makeObjects(*objectList, semaCtx)}; // Case: declare target(func, var1, var2) gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, - symbolAndClause); + symbolAndClause, /*automap=*/false); } else if (const auto *clauseList{ parser::Unwrap<parser::OmpClauseList>(spec.u)}) { List<Clause> clauses = makeClauses(*clauseList, semaCtx); @@ -804,21 +814,20 @@ static void collectDeferredDeclareTargets( llvm::SmallVectorImpl<lower::OMPDeferredDeclareTargetInfo> &deferredDeclareTarget) { mlir::omp::DeclareTargetOperands clauseOps; - llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; + llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause; getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, clauseOps, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function or subroutine mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - for (const DeclareTargetCapturePair &symClause : symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol( - converter.mangleName(std::get<const semantics::Symbol &>(symClause))); + for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) { + mlir::Operation *op = + mod.lookupSymbol(converter.mangleName(symClause.symbol)); if (!op) { - deferredDeclareTarget.push_back({std::get<0>(symClause), - clauseOps.deviceType, - std::get<1>(symClause)}); + deferredDeclareTarget.push_back({symClause.clause, clauseOps.deviceType, + symClause.automap, symClause.symbol}); } } } @@ -829,16 +838,16 @@ getDeclareTargetFunctionDevice( lower::pft::Evaluation &eval, const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { mlir::omp::DeclareTargetOperands clauseOps; - llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; + llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause; getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, clauseOps, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function or subroutine mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - for (const DeclareTargetCapturePair &symClause : symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol( - converter.mangleName(std::get<const semantics::Symbol &>(symClause))); + for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) { + mlir::Operation *op = + mod.lookupSymbol(converter.mangleName(symClause.symbol)); if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op)) return clauseOps.deviceType; @@ -1055,7 +1064,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, static void markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter, mlir::omp::DeclareTargetCaptureClause captureClause, - mlir::omp::DeclareTargetDeviceType deviceType) { + mlir::omp::DeclareTargetDeviceType deviceType, bool automap) { // TODO: Add support for program local variables with declare target applied auto declareTargetOp = 
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op); if (!declareTargetOp) @@ -1070,11 +1079,11 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter, if (declareTargetOp.isDeclareTarget()) { if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, - captureClause); + captureClause, automap); return; } - declareTargetOp.setDeclareTarget(deviceType, captureClause); + declareTargetOp.setDeclareTarget(deviceType, captureClause, automap); } //===----------------------------------------------------------------------===// @@ -2262,7 +2271,8 @@ genOrderedOp(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - TODO(loc, "OMPD_ordered"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(loc, "OMPD_ordered"); return nullptr; } @@ -2449,7 +2459,8 @@ genScopeOp(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - TODO(loc, "Scope construct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(loc, "Scope construct"); return nullptr; } @@ -2818,6 +2829,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } +static mlir::omp::WorkdistributeOp genWorkdistributeOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + return genOpWithBody<mlir::omp::WorkdistributeOp>( + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, + llvm::omp::Directive::OMPD_workdistribute), + queue, item); +} + //===----------------------------------------------------------------------===// // Code generation functions for the standalone version of constructs that can // also be a leaf of a composite construct @@ -3235,7 +3257,7 @@ static mlir::omp::WsloopOp genCompositeDoSimd( DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, /*useDelayedPrivatization=*/true, symTable); - simdItemDSP.processStep1(&simdClauseOps); + simdItemDSP.processStep1(&simdClauseOps, simdItem->id); // Pass the innermost leaf construct's clauses because that's where COLLAPSE // is placed by construct decomposition. 
@@ -3276,7 +3298,8 @@ static mlir::omp::TaskloopOp genCompositeTaskloopSimd( lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs"); - TODO(loc, "Composite TASKLOOP SIMD"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(loc, "Composite TASKLOOP SIMD"); return nullptr; } @@ -3448,13 +3471,18 @@ static void genOMPDispatch(lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_tile: { unsigned version = semaCtx.langOptions().OpenMPVersion; - TODO(loc, "Unhandled loop directive (" + - llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(loc, "Unhandled loop directive (" + + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); + break; } case llvm::omp::Directive::OMPD_unroll: genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - // case llvm::omp::Directive::OMPD_workdistribute: + case llvm::omp::Directive::OMPD_workdistribute: + newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, + item); + break; case llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); @@ -3484,35 +3512,40 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclarativeAllocate &declarativeAllocate) { - TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclarativeAssumes &assumesConstruct) { - TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMP ASSUMES declaration"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OmpDeclareVariantDirective &declareVariantDirective) { - TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); } static void genOMP( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclareSimdConstruct &declareSimdConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); } static void @@ -3563,14 +3596,14 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclareTargetConstruct 
&declareTargetConstruct) { mlir::omp::DeclareTargetOperands clauseOps; - llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; + llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause; mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, clauseOps, symbolAndClause); - for (const DeclareTargetCapturePair &symClause : symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol( - converter.mangleName(std::get<const semantics::Symbol &>(symClause))); + for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) { + mlir::Operation *op = + mod.lookupSymbol(converter.mangleName(symClause.symbol)); // Some symbols are deferred until later in the module, these are handled // upon finalization of the module for OpenMP inside of Bridge, so we simply @@ -3578,16 +3611,21 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, if (!op) continue; - markDeclareTarget( - op, converter, - std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), - clauseOps.deviceType); + markDeclareTarget(op, converter, symClause.clause, clauseOps.deviceType, + symClause.automap); } } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + const parser::OpenMPGroupprivate &directive) { + TODO(converter.getCurrentLocation(), "GROUPPRIVATE"); +} + +static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, const parser::OpenMPRequiresConstruct &requiresConstruct) { // Requires directives are gathered and processed in semantics and // then combined in the lowering bridge before triggering codegen @@ -3708,14 +3746,16 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, (void)objects; (void)clauses; - TODO(converter.getCurrentLocation(), "OpenMPDepobjConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPDepobjConstruct"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPInteropConstruct &interopConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPInteropConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPInteropConstruct"); } static void @@ -3731,7 +3771,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPAllocatorsConstruct &allocsConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct"); } //===----------------------------------------------------------------------===// @@ -3748,7 +3789,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - const parser::OpenMPBlockConstruct &blockConstruct) { + const parser::OmpBlockConstruct &blockConstruct) { const parser::OmpDirectiveSpecification &beginSpec = blockConstruct.BeginDir(); List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx); @@ -3797,7 +3838,8 @@ static void 
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, !std::holds_alternative<clause::Detach>(clause.u)) { std::string name = parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id)); - TODO(clauseLocation, name + " clause is not implemented yet"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(clauseLocation, name + " clause is not implemented yet"); } } @@ -3813,46 +3855,61 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, const parser::OpenMPAssumeConstruct &assumeConstruct) { mlir::Location clauseLocation = converter.genLocation(assumeConstruct.source); - TODO(clauseLocation, "OpenMP ASSUME construct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(clauseLocation, "OpenMP ASSUME construct"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPCriticalConstruct &criticalConstruct) { - const auto &cd = std::get<parser::OmpCriticalDirective>(criticalConstruct.t); - List<Clause> clauses = - makeClauses(std::get<parser::OmpClauseList>(cd.t), semaCtx); + const parser::OmpDirectiveSpecification &beginSpec = + criticalConstruct.BeginDir(); + List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx); ConstructQueue queue{buildConstructQueue( - converter.getFirOpBuilder().getModule(), semaCtx, eval, cd.source, + converter.getFirOpBuilder().getModule(), semaCtx, eval, beginSpec.source, llvm::omp::Directive::OMPD_critical, clauses)}; - const auto &name = std::get<std::optional<parser::Name>>(cd.t); + std::optional<parser::Name> critName; + const parser::OmpArgumentList &args = beginSpec.Arguments(); + if (!args.v.empty()) { + // All of these things should be guaranteed to exist after semantic checks. 
+ auto *object = parser::Unwrap<parser::OmpObject>(args.v.front()); + assert(object && "Expecting object as argument"); + auto *designator = semantics::omp::GetDesignatorFromObj(*object); + assert(designator && "Expecting designator in argument"); + auto *name = semantics::getDesignatorNameIfDataRef(*designator); + assert(name && "Expecting dataref in designator"); + critName = *name; + } mlir::Location currentLocation = converter.getCurrentLocation(); genCriticalOp(converter, symTable, semaCtx, eval, currentLocation, queue, - queue.begin(), name); + queue.begin(), critName); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPUtilityConstruct &) { - TODO(converter.getCurrentLocation(), "OpenMPUtilityConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPUtilityConstruct"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDispatchConstruct &) { - TODO(converter.getCurrentLocation(), "OpenMPDispatchConstruct"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPDispatchConstruct"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPExecutableAllocate &execAllocConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); + if (!semaCtx.langOptions().OpenMPSimd) + TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, @@ -3924,9 +3981,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, List<Clause> clauses = makeClauses( std::get<parser::OmpClauseList>(beginSectionsDirective.t), semaCtx); const auto &endSectionsDirective = - std::get<parser::OmpEndSectionsDirective>(sectionsConstruct.t); + std::get<std::optional<parser::OmpEndSectionsDirective>>( + sectionsConstruct.t); + assert(endSectionsDirective && + "Missing end section directive should have been handled in semantics"); clauses.append(makeClauses( - std::get<parser::OmpClauseList>(endSectionsDirective.t), semaCtx)); + std::get<parser::OmpClauseList>(endSectionsDirective->t), semaCtx)); mlir::Location currentLocation = converter.getCurrentLocation(); llvm::omp::Directive directive = @@ -4090,7 +4150,7 @@ void Fortran::lower::genDeclareTargetIntGlobal( bool Fortran::lower::isOpenMPTargetConstruct( const parser::OpenMPConstruct &omp) { llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown; - if (const auto *block = std::get_if<parser::OpenMPBlockConstruct>(&omp.u)) { + if (const auto *block = std::get_if<parser::OmpBlockConstruct>(&omp.u)) { dir = block->BeginDir().DirId(); } else if (const auto *loop = std::get_if<parser::OpenMPLoopConstruct>(&omp.u)) { @@ -4164,7 +4224,7 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions( deviceCodeFound = true; markDeclareTarget(op, converter, declTar.declareTargetCaptureClause, - devType); + devType, declTar.automap); } return deviceCodeFound; diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 13fda97..cb6dd57 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -24,6 +24,7 @@ #include <flang/Parser/parse-tree.h> #include 
<flang/Parser/tools.h> #include <flang/Semantics/tools.h> +#include <flang/Utils/OpenMP.h> #include <llvm/Support/CommandLine.h> #include <iterator> @@ -102,41 +103,10 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval) { void gatherFuncAndVarSyms( const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { + llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause, + bool automap) { for (const Object &object : objects) - symbolAndClause.emplace_back(clause, *object.sym()); -} - -mlir::omp::MapInfoOp -createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value baseAddr, mlir::Value varPtrPtr, - llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, - llvm::ArrayRef<mlir::Value> members, - mlir::ArrayAttr membersIndex, uint64_t mapType, - mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap, mlir::FlatSymbolRefAttr mapperId) { - if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { - baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr); - retTy = baseAddr.getType(); - } - - mlir::TypeAttr varType = mlir::TypeAttr::get( - llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); - - // For types with unknown extents such as <2x?xi32> we discard the incomplete - // type info and only retain the base type. The correct dimensions are later - // recovered through the bounds info. - if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue())) - if (seqType.hasDynamicExtents()) - varType = mlir::TypeAttr::get(seqType.getEleTy()); - - mlir::omp::MapInfoOp op = mlir::omp::MapInfoOp::create( - builder, loc, retTy, baseAddr, varType, - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), - builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), - varPtrPtr, members, membersIndex, bounds, mapperId, - builder.getStringAttr(name), builder.getBoolAttr(partialMap)); - return op; + symbolAndClause.emplace_back(clause, *object.sym(), automap); } // This function gathers the individual omp::Object's that make up a @@ -402,7 +372,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( // Create a map for the intermediate member and insert it and it's // indices into the parentMemberIndices list to track it. 
- mlir::omp::MapInfoOp mapOp = createMapInfoOp( + mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp( firOpBuilder, clauseLocation, curValue, /*varPtrPtr=*/mlir::Value{}, asFortran, /*bounds=*/interimBounds, @@ -562,7 +532,7 @@ void insertChildMapInfoIntoParent( converter.getCurrentLocation(), asFortran, bounds, treatIndexAsSection); - mlir::omp::MapInfoOp mapOp = createMapInfoOp( + mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp( firOpBuilder, info.rawInput.getLoc(), info.rawInput, /*varPtrPtr=*/mlir::Value(), asFortran.str(), bounds, members, firOpBuilder.create2DI64ArrayAttr( diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 11641ba..88371ab 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -42,8 +42,15 @@ class AbstractConverter; namespace omp { -using DeclareTargetCapturePair = - std::pair<mlir::omp::DeclareTargetCaptureClause, const semantics::Symbol &>; +struct DeclareTargetCaptureInfo { + mlir::omp::DeclareTargetCaptureClause clause; + bool automap = false; + const semantics::Symbol &symbol; + + DeclareTargetCaptureInfo(mlir::omp::DeclareTargetCaptureClause c, + const semantics::Symbol &s, bool a = false) + : clause(c), automap(a), symbol(s) {} +}; // A small helper structure for keeping track of a component members MapInfoOp // and index data when lowering OpenMP map clauses. Keeps track of the @@ -107,16 +114,6 @@ struct OmpMapParentAndMemberData { semantics::SemanticsContext &semaCtx); }; -mlir::omp::MapInfoOp -createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value baseAddr, mlir::Value varPtrPtr, - llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, - llvm::ArrayRef<mlir::Value> members, - mlir::ArrayAttr membersIndex, uint64_t mapType, - mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap = false, - mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()); - void insertChildMapInfoIntoParent( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, @@ -150,7 +147,8 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval); void gatherFuncAndVarSyms( const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause); + llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause, + bool automap = false); int64_t getCollapseValue(const List<Clause> &clauses); diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp index a28cc01..80f31c2 100644 --- a/flang/lib/Lower/PFTBuilder.cpp +++ b/flang/lib/Lower/PFTBuilder.cpp @@ -1742,11 +1742,11 @@ private: layeredVarList[i].end()); } - llvm::SmallSet<const semantics::Symbol *, 32> seen; + llvm::SmallPtrSet<const semantics::Symbol *, 32> seen; std::vector<Fortran::lower::pft::VariableList> layeredVarList; - llvm::SmallSet<const semantics::Symbol *, 32> aliasSyms; + llvm::SmallPtrSet<const semantics::Symbol *, 32> aliasSyms; /// Set of scopes that have been analyzed for aliases. 
- llvm::SmallSet<const semantics::Scope *, 4> analyzedScopes; + llvm::SmallPtrSet<const semantics::Scope *, 4> analyzedScopes; std::vector<Fortran::lower::pft::Variable::AggregateStore> stores; }; } // namespace diff --git a/flang/lib/Lower/Runtime.cpp b/flang/lib/Lower/Runtime.cpp index fc59a24..494dd49 100644 --- a/flang/lib/Lower/Runtime.cpp +++ b/flang/lib/Lower/Runtime.cpp @@ -39,8 +39,7 @@ static void genUnreachable(fir::FirOpBuilder &builder, mlir::Location loc) { if (parentOp->getDialect()->getNamespace() == mlir::omp::OpenMPDialect::getDialectNamespace()) Fortran::lower::genOpenMPTerminator(builder, parentOp, loc); - else if (parentOp->getDialect()->getNamespace() == - mlir::acc::OpenACCDialect::getDialectNamespace()) + else if (Fortran::lower::isInsideOpenACCComputeConstruct(builder)) Fortran::lower::genOpenACCTerminator(builder, parentOp, loc); else fir::UnreachableOp::create(builder, loc); diff --git a/flang/lib/Lower/Support/PrivateReductionUtils.cpp b/flang/lib/Lower/Support/PrivateReductionUtils.cpp index fff060b..1b09801 100644 --- a/flang/lib/Lower/Support/PrivateReductionUtils.cpp +++ b/flang/lib/Lower/Support/PrivateReductionUtils.cpp @@ -616,6 +616,8 @@ void PopulateInitAndCleanupRegionsHelper::populateByRefInitAndCleanupRegions() { assert(sym && "Symbol information is required to privatize derived types"); assert(!scalarInitValue && "ScalarInitvalue is unused for privatization"); } + if (hlfir::Entity{moldArg}.isAssumedRank()) + TODO(loc, "Privatization of assumed rank variable"); mlir::Type valTy = fir::unwrapRefType(argType); if (fir::isa_trivial(valTy)) { diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp index 881401e..1b4d37e 100644 --- a/flang/lib/Lower/Support/Utils.cpp +++ b/flang/lib/Lower/Support/Utils.cpp @@ -654,8 +654,9 @@ void privatizeSymbol( lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, lower::SymMap &symTable, llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols, - llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, - const semantics::Symbol *symToPrivatize, OperandsStructType *clauseOps) { + llvm::SmallPtrSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, + const semantics::Symbol *symToPrivatize, OperandsStructType *clauseOps, + std::optional<llvm::omp::Directive> dir) { constexpr bool isDoConcurrent = std::is_same_v<OpType, fir::LocalitySpecifierOp>; mlir::OpBuilder::InsertPoint dcIP; @@ -676,6 +677,13 @@ void privatizeSymbol( bool emitCopyRegion = symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) || symToPrivatize->test(semantics::Symbol::Flag::LocalityLocalInit); + // A symbol attached to the simd directive can have the firstprivate flag set + // on it when it is also used in a non-firstprivate privatization clause. + // For instance: $omp do simd lastprivate(a) firstprivate(a) + // We cannot apply the firstprivate privatizer to simd, so make sure we do + // not emit the copy region when dealing with the SIMD directive. 
+ if (dir && dir == llvm::omp::Directive::OMPD_simd) + emitCopyRegion = false; mlir::Value privVal = hsb.getAddr(); mlir::Type allocType = privVal.getType(); @@ -846,17 +854,19 @@ privatizeSymbol<mlir::omp::PrivateClauseOp, mlir::omp::PrivateClauseOps>( lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, lower::SymMap &symTable, llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols, - llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, + llvm::SmallPtrSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, const semantics::Symbol *symToPrivatize, - mlir::omp::PrivateClauseOps *clauseOps); + mlir::omp::PrivateClauseOps *clauseOps, + std::optional<llvm::omp::Directive> dir); template void privatizeSymbol<fir::LocalitySpecifierOp, fir::LocalitySpecifierOperands>( lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, lower::SymMap &symTable, llvm::SetVector<const semantics::Symbol *> &allPrivatizedSymbols, - llvm::SmallSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, + llvm::SmallPtrSet<const semantics::Symbol *, 16> &mightHaveReadHostSym, const semantics::Symbol *symToPrivatize, - fir::LocalitySpecifierOperands *clauseOps); + fir::LocalitySpecifierOperands *clauseOps, + std::optional<llvm::omp::Directive> dir); } // end namespace Fortran::lower diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt index 31ae395..404afd1 100644 --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -16,6 +16,7 @@ add_flang_library(FIRBuilder Runtime/Allocatable.cpp Runtime/ArrayConstructor.cpp Runtime/Assign.cpp + Runtime/Coarray.cpp Runtime/Character.cpp Runtime/Command.cpp Runtime/CUDA/Descriptor.cpp @@ -49,6 +50,7 @@ add_flang_library(FIRBuilder FIRDialectSupport FIRSupport FortranEvaluate + FortranSupport HLFIRDialect MLIR_DEPS diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 87a52ff..b6501fd 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -147,8 +147,20 @@ mlir::Value fir::FirOpBuilder::createIntegerConstant(mlir::Location loc, assert((cst >= 0 || mlir::isa<mlir::IndexType>(ty) || mlir::cast<mlir::IntegerType>(ty).getWidth() <= 64) && "must use APint"); - return mlir::arith::ConstantOp::create(*this, loc, ty, - getIntegerAttr(ty, cst)); + + mlir::Type cstType = ty; + if (auto intType = mlir::dyn_cast<mlir::IntegerType>(ty)) { + // Signed and unsigned constants must be encoded as signless + // arith.constant followed by fir.convert cast. 
+ if (intType.isUnsigned()) + cstType = mlir::IntegerType::get(getContext(), intType.getWidth()); + else if (intType.isSigned()) + TODO(loc, "signed integer constant"); + } + + mlir::Value cstValue = mlir::arith::ConstantOp::create( + *this, loc, cstType, getIntegerAttr(cstType, cst)); + return createConvert(loc, ty, cstValue); } mlir::Value fir::FirOpBuilder::createAllOnesInteger(mlir::Location loc, @@ -411,10 +423,11 @@ mlir::Value fir::FirOpBuilder::genTempDeclareOp( llvm::ArrayRef<mlir::Value> typeParams, fir::FortranVariableFlagsAttr fortranAttrs) { auto nameAttr = mlir::StringAttr::get(builder.getContext(), name); - return fir::DeclareOp::create(builder, loc, memref.getType(), memref, shape, - typeParams, - /*dummy_scope=*/nullptr, nameAttr, fortranAttrs, - cuf::DataAttributeAttr{}); + return fir::DeclareOp::create( + builder, loc, memref.getType(), memref, shape, typeParams, + /*dummy_scope=*/nullptr, + /*storage=*/nullptr, + /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{}); } mlir::Value fir::FirOpBuilder::genStackSave(mlir::Location loc) { @@ -1947,17 +1960,17 @@ void fir::factory::genDimInfoFromBox( mlir::Value fir::factory::genLifetimeStart(mlir::OpBuilder &builder, mlir::Location loc, - fir::AllocaOp alloc, int64_t size, + fir::AllocaOp alloc, const mlir::DataLayout *dl) { mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get( alloc.getContext(), getAllocaAddressSpace(dl)); mlir::Value cast = fir::ConvertOp::create(builder, loc, ptrTy, alloc.getResult()); - mlir::LLVM::LifetimeStartOp::create(builder, loc, size, cast); + mlir::LLVM::LifetimeStartOp::create(builder, loc, cast); return cast; } void fir::factory::genLifetimeEnd(mlir::OpBuilder &builder, mlir::Location loc, - mlir::Value cast, int64_t size) { - mlir::LLVM::LifetimeEndOp::create(builder, loc, size, cast); + mlir::Value cast) { + mlir::LLVM::LifetimeEndOp::create(builder, loc, cast); } diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index b6d692a..086dd66 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -416,7 +416,10 @@ hlfir::Entity hlfir::loadTrivialScalar(mlir::Location loc, entity = derefPointersAndAllocatables(loc, builder, entity); if (entity.isVariable() && entity.isScalar() && fir::isa_trivial(entity.getFortranElementType())) { - return Entity{fir::LoadOp::create(builder, loc, entity)}; + // Optional entities may be represented with !fir.box<i32/f32/...>. + // We need to take the data pointer before loading the scalar. 
+ mlir::Value base = genVariableRawAddress(loc, builder, entity); + return Entity{fir::LoadOp::create(builder, loc, base)}; } return entity; } diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index bfa470d..e1c9520 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -25,6 +25,7 @@ #include "flang/Optimizer/Builder/Runtime/Allocatable.h" #include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h" #include "flang/Optimizer/Builder/Runtime/Character.h" +#include "flang/Optimizer/Builder/Runtime/Coarray.h" #include "flang/Optimizer/Builder/Runtime/Command.h" #include "flang/Optimizer/Builder/Runtime/Derived.h" #include "flang/Optimizer/Builder/Runtime/Exceptions.h" @@ -137,7 +138,7 @@ static const char __ldlu_r8x2[] = "__ldlu_r8x2_"; /// Table that drives the fir generation depending on the intrinsic or intrinsic /// module procedure one to one mapping with Fortran arguments. If no mapping is /// defined here for a generic intrinsic, genRuntimeCall will be called -/// to look for a match in the runtime a emit a call. Note that the argument +/// to look for a match in the runtime and emit a call. Note that the argument /// lowering rules for an intrinsic need to be provided only if at least one /// argument must not be lowered by value. In which case, the lowering rules /// should be provided for all the intrinsic arguments for completeness. @@ -778,6 +779,10 @@ static constexpr IntrinsicHandler handlers[]{ /*isElemental=*/false}, {"not", &I::genNot}, {"null", &I::genNull, {{{"mold", asInquired}}}, /*isElemental=*/false}, + {"num_images", + &I::genNumImages, + {{{"team", asAddr}, {"team_number", asAddr}}}, + /*isElemental*/ false}, {"pack", &I::genPack, {{{"array", asBox}, @@ -864,6 +869,10 @@ static constexpr IntrinsicHandler handlers[]{ {"back", asValue, handleDynamicOptional}, {"kind", asValue}}}, /*isElemental=*/true}, + {"secnds", + &I::genSecnds, + {{{"refTime", asAddr}}}, + /*isElemental=*/false}, {"second", &I::genSecond, {{{"time", asAddr}}}, @@ -947,6 +956,12 @@ static constexpr IntrinsicHandler handlers[]{ {"tand", &I::genTand}, {"tanpi", &I::genTanpi}, {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false}, + {"this_image", + &I::genThisImage, + {{{"coarray", asBox}, + {"dim", asAddr}, + {"team", asBox, handleDynamicOptional}}}, + /*isElemental=*/false}, {"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false}, {"this_warp", &I::genThisWarp, {}, /*isElemental=*/false}, {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false}, @@ -1047,7 +1062,7 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc, llvm::StringRef suffix, mlir::FunctionType funcType) { std::string output = prefix.str(); llvm::raw_string_ostream sstream(output); - if (name == "pow") { + if (name == "pow" || name == "pow-unsigned") { assert(funcType.getNumInputs() == 2 && "power operator has two arguments"); std::string displayName{" ** "}; sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc, @@ -1276,6 +1291,26 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc, return result; } +mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, + const MathOperation &mathOp, + mlir::FunctionType mathLibFuncType, + llvm::ArrayRef<mlir::Value> args) { + bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); + if (!isAMDGPU) + return genLibCall(builder, loc, mathOp, mathLibFuncType, 
args); + + auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0)); + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, args[1]); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + mlir::Value complexExp = + builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero); + mlir::Value result = + builder.create<mlir::complex::PowOp>(loc, args[0], complexExp); + result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); + return result; +} + /// Mapping between mathematical intrinsic operations and MLIR operations /// of some appropriate dialect (math, complex, etc.) or libm calls. /// TODO: support remaining Fortran math intrinsics. @@ -1625,17 +1660,29 @@ static constexpr MathOperation mathOperations[] = { genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>, genMathOp<mlir::math::FPowIOp>}, {"pow", RTNAME_STRING(cpowi), - genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall}, + genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, + genComplexPow}, {"pow", RTNAME_STRING(zpowi), - genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall}, + genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, + genComplexPow}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, genLibF128Call}, {"pow", RTNAME_STRING(cpowk), - genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall}, + genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, + genComplexPow}, {"pow", RTNAME_STRING(zpowk), - genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall}, + genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, + genComplexPow}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, genLibF128Call}, + {"pow-unsigned", RTNAME_STRING(UPow1), + genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall}, + {"pow-unsigned", RTNAME_STRING(UPow2), + genFuncType<Ty::Integer<2>, Ty::Integer<2>, Ty::Integer<2>>, genLibCall}, + {"pow-unsigned", RTNAME_STRING(UPow4), + genFuncType<Ty::Integer<4>, Ty::Integer<4>, Ty::Integer<4>>, genLibCall}, + {"pow-unsigned", RTNAME_STRING(UPow8), + genFuncType<Ty::Integer<8>, Ty::Integer<8>, Ty::Integer<8>>, genLibCall}, {"remainder", "remainderf", genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Real<4>>, genLibCall}, {"remainder", "remainder", @@ -2672,10 +2719,11 @@ mlir::Value IntrinsicLibrary::genAcosd(mlir::Type resultType, mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); mlir::Value result = getRuntimeCallGenerator("acos", ftype)(builder, loc, {args[0]}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, llvm::APFloat(fltSem, "180.0") / pi); return mlir::arith::MulFOp::create(builder, loc, result, factor); } @@ -2687,10 +2735,10 @@ mlir::Value IntrinsicLibrary::genAcospi(mlir::Type resultType, mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); mlir::Value acos = getRuntimeCallGenerator("acos", ftype)(builder, loc, args); - llvm::APFloat inv_pi = 
llvm::APFloat(llvm::numbers::inv_pi); - mlir::Value dfactor = - builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi); - mlir::Value factor = builder.createConvert(loc, resultType, dfactor); + llvm::APFloat inv_pi = + llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(), + llvm::numbers::inv_pis); + mlir::Value factor = builder.createRealConstant(loc, resultType, inv_pi); return mlir::arith::MulFOp::create(builder, loc, acos, factor); } @@ -2840,10 +2888,11 @@ mlir::Value IntrinsicLibrary::genAsind(mlir::Type resultType, mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); mlir::Value result = getRuntimeCallGenerator("asin", ftype)(builder, loc, {args[0]}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, llvm::APFloat(fltSem, "180.0") / pi); return mlir::arith::MulFOp::create(builder, loc, result, factor); } @@ -2855,10 +2904,10 @@ mlir::Value IntrinsicLibrary::genAsinpi(mlir::Type resultType, mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); mlir::Value asin = getRuntimeCallGenerator("asin", ftype)(builder, loc, args); - llvm::APFloat inv_pi = llvm::APFloat(llvm::numbers::inv_pi); - mlir::Value dfactor = - builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi); - mlir::Value factor = builder.createConvert(loc, resultType, dfactor); + llvm::APFloat inv_pi = + llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(), + llvm::numbers::inv_pis); + mlir::Value factor = builder.createRealConstant(loc, resultType, inv_pi); return mlir::arith::MulFOp::create(builder, loc, asin, factor); } @@ -2880,10 +2929,11 @@ mlir::Value IntrinsicLibrary::genAtand(mlir::Type resultType, mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); atan = getRuntimeCallGenerator("atan", ftype)(builder, loc, args); } - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), llvm::APFloat(180.0) / pi); - mlir::Value factor = builder.createConvert(loc, resultType, dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, llvm::APFloat(fltSem, "180.0") / pi); return mlir::arith::MulFOp::create(builder, loc, atan, factor); } @@ -2905,10 +2955,10 @@ mlir::Value IntrinsicLibrary::genAtanpi(mlir::Type resultType, mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); atan = getRuntimeCallGenerator("atan", ftype)(builder, loc, args); } - llvm::APFloat inv_pi = llvm::APFloat(llvm::numbers::inv_pi); - mlir::Value dfactor = - builder.createRealConstant(loc, mlir::Float64Type::get(context), inv_pi); - mlir::Value factor = builder.createConvert(loc, resultType, dfactor); + llvm::APFloat inv_pi = + llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(), + llvm::numbers::inv_pis); + mlir::Value factor = builder.createRealConstant(loc, resultType, 
inv_pi); return mlir::arith::MulFOp::create(builder, loc, atan, factor); } @@ -3669,10 +3719,11 @@ mlir::Value IntrinsicLibrary::genCosd(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0)); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, pi / llvm::APFloat(fltSem, "180.0")); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("cos", ftype)(builder, loc, {arg}); } @@ -3684,10 +3735,10 @@ mlir::Value IntrinsicLibrary::genCospi(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = - builder.createRealConstant(loc, mlir::Float64Type::get(context), pi); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + llvm::APFloat pi = + llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(), + llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant(loc, resultType, pi); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("cos", ftype)(builder, loc, {arg}); } @@ -4031,21 +4082,20 @@ void IntrinsicLibrary::genExecuteCommandLine( mlir::Value waitAddr = fir::getBase(wait); mlir::Value waitIsPresentAtRuntime = builder.genIsNotNullAddr(loc, waitAddr); - waitBool = builder - .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime, - /*withElseRegion=*/true) - .genThen([&]() { - auto waitLoad = - fir::LoadOp::create(builder, loc, waitAddr); - mlir::Value cast = - builder.createConvert(loc, i1Ty, waitLoad); - fir::ResultOp::create(builder, loc, cast); - }) - .genElse([&]() { - mlir::Value trueVal = builder.createBool(loc, true); - fir::ResultOp::create(builder, loc, trueVal); - }) - .getResults()[0]; + waitBool = + builder + .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime, + /*withElseRegion=*/true) + .genThen([&]() { + auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr); + mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad); + fir::ResultOp::create(builder, loc, cast); + }) + .genElse([&]() { + mlir::Value trueVal = builder.createBool(loc, true); + fir::ResultOp::create(builder, loc, trueVal); + }) + .getResults()[0]; } mlir::Value exitstatBox = @@ -7277,6 +7327,19 @@ IntrinsicLibrary::genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue> args) { return fir::MutableBoxValue(boxStorage, mold->nonDeferredLenParams(), {}); } +// NUM_IMAGES +fir::ExtendedValue +IntrinsicLibrary::genNumImages(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + checkCoarrayEnabled(); + assert(args.size() == 0 || args.size() == 1); + + if (args.size()) + return fir::runtime::getNumImagesWithTeam(builder, loc, + fir::getBase(args[0])); + return fir::runtime::getNumImages(builder, loc); +} + // CLOCK, CLOCK64, GLOBALTIMER template <typename OpTy> mlir::Value IntrinsicLibrary::genNVVMTime(mlir::Type 
resultType, @@ -7813,6 +7876,22 @@ IntrinsicLibrary::genScan(mlir::Type resultType, return readAndAddCleanUp(resultMutableBox, resultType, "SCAN"); } +// SECNDS +fir::ExtendedValue +IntrinsicLibrary::genSecnds(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1 && "SECNDS expects one argument"); + + mlir::Value refTime = fir::getBase(args[0]); + + if (!refTime) + fir::emitFatalError(loc, "expected REFERENCE TIME parameter"); + + mlir::Value result = fir::runtime::genSecnds(builder, loc, refTime); + + return builder.createConvert(loc, resultType, result); +} + // SECOND fir::ExtendedValue IntrinsicLibrary::genSecond(std::optional<mlir::Type> resultType, @@ -8121,10 +8200,11 @@ mlir::Value IntrinsicLibrary::genSind(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0)); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, pi / llvm::APFloat(fltSem, "180.0")); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("sin", ftype)(builder, loc, {arg}); } @@ -8136,10 +8216,10 @@ mlir::Value IntrinsicLibrary::genSinpi(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = - builder.createRealConstant(loc, mlir::Float64Type::get(context), pi); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + llvm::APFloat pi = + llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(), + llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant(loc, resultType, pi); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("sin", ftype)(builder, loc, {arg}); } @@ -8218,10 +8298,11 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0)); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, pi / llvm::APFloat(fltSem, "180.0")); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg}); } @@ -8233,10 +8314,10 @@ mlir::Value IntrinsicLibrary::genTanpi(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - 
llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = - builder.createRealConstant(loc, mlir::Float64Type::get(context), pi); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + llvm::APFloat pi = + llvm::APFloat(llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(), + llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant(loc, resultType, pi); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg}); } @@ -8327,6 +8408,27 @@ mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType, return res; } +// THIS_IMAGE +fir::ExtendedValue +IntrinsicLibrary::genThisImage(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + checkCoarrayEnabled(); + assert(args.size() >= 1 && args.size() <= 3); + const bool coarrayIsAbsent = args.size() == 1; + mlir::Value team = + !isStaticallyAbsent(args, args.size() - 1) + ? fir::getBase(args[args.size() - 1]) + : builder + .create<fir::AbsentOp>(loc, + fir::BoxType::get(builder.getNoneType())) + .getResult(); + + if (!coarrayIsAbsent) + TODO(loc, "this_image with coarray argument."); + mlir::Value res = fir::runtime::getThisImage(builder, loc, team); + return builder.createConvert(loc, resultType, res); +} + // THIS_THREAD_BLOCK mlir::Value IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType, @@ -9347,6 +9449,14 @@ mlir::Value genPow(fir::FirOpBuilder &builder, mlir::Location loc, // implementation and mark it 'strictfp'. // Another option is to implement it in Fortran runtime library // (just like matmul). + if (type.isUnsignedInteger()) { + assert(x.getType().isUnsignedInteger() && y.getType().isUnsignedInteger() && + "unsigned pow requires unsigned arguments"); + return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow-unsigned", type, + {x, y}); + } + assert(!x.getType().isUnsignedInteger() && !y.getType().isUnsignedInteger() && + "non-unsigned pow requires non-unsigned arguments"); return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow", type, {x, y}); } diff --git a/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp new file mode 100644 index 0000000..fb72fc2 --- /dev/null +++ b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp @@ -0,0 +1,86 @@ +//===-- Coarray.cpp -- runtime API for coarray intrinsics -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/Runtime/Coarray.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace Fortran::runtime; +using namespace Fortran::semantics; + +/// Generate Call to runtime prif_init +mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder, + mlir::Location loc) { + mlir::Type i32Ty = builder.getI32Type(); + mlir::Value result = builder.createTemporary(loc, i32Ty); + mlir::FunctionType ftype = PRIF_FUNCTYPE(builder.getRefType(i32Ty)); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, PRIFNAME_SUB("init"), ftype); + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, result); + builder.create<fir::CallOp>(loc, funcOp, args); + return builder.create<fir::LoadOp>(loc, result); +} + +/// Generate Call to runtime prif_num_images +mlir::Value fir::runtime::getNumImages(fir::FirOpBuilder &builder, + mlir::Location loc) { + mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); + mlir::FunctionType ftype = + PRIF_FUNCTYPE(builder.getRefType(builder.getI32Type())); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, PRIFNAME_SUB("num_images"), ftype); + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, result); + builder.create<fir::CallOp>(loc, funcOp, args); + return builder.create<fir::LoadOp>(loc, result); +} + +/// Generate Call to runtime prif_num_images_with_{team|team_number} +mlir::Value fir::runtime::getNumImagesWithTeam(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value team) { + bool isTeamNumber = fir::unwrapPassByRefType(team.getType()).isInteger(); + std::string numImagesName = isTeamNumber + ? PRIFNAME_SUB("num_images_with_team_number") + : PRIFNAME_SUB("num_images_with_team"); + + mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); + mlir::Type refTy = builder.getRefType(builder.getI32Type()); + mlir::FunctionType ftype = + isTeamNumber + ? PRIF_FUNCTYPE(builder.getRefType(builder.getI64Type()), refTy) + : PRIF_FUNCTYPE(fir::BoxType::get(builder.getNoneType()), refTy); + mlir::func::FuncOp funcOp = builder.createFunction(loc, numImagesName, ftype); + + if (!isTeamNumber) + team = builder.createBox(loc, team); + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, team, result); + builder.create<fir::CallOp>(loc, funcOp, args); + return builder.create<fir::LoadOp>(loc, result); +} + +/// Generate Call to runtime prif_this_image_no_coarray +mlir::Value fir::runtime::getThisImage(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value team) { + mlir::Type refTy = builder.getRefType(builder.getI32Type()); + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::FunctionType ftype = PRIF_FUNCTYPE(boxTy, refTy); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, PRIFNAME_SUB("this_image_no_coarray"), ftype); + + mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); + mlir::Value teamArg = + !team ? 
builder.create<fir::AbsentOp>(loc, boxTy) : team; + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, teamArg, result); + builder.create<fir::CallOp>(loc, funcOp, args); + return builder.create<fir::LoadOp>(loc, result); +} diff --git a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp index ee15157..dc61903 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp @@ -276,6 +276,23 @@ void fir::runtime::genRename(fir::FirOpBuilder &builder, mlir::Location loc, fir::CallOp::create(builder, loc, runtimeFunc, args); } +mlir::Value fir::runtime::genSecnds(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value refTime) { + auto runtimeFunc = + fir::runtime::getRuntimeFunc<mkRTKey(Secnds)>(loc, builder); + + mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType(); + + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, runtimeFuncTy.getInput(2)); + + llvm::SmallVector<mlir::Value> args = {refTime, sourceFile, sourceLine}; + args = fir::runtime::createArguments(builder, loc, runtimeFuncTy, args); + + return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0); +} + /// generate runtime call to time intrinsic mlir::Value fir::runtime::genTime(fir::FirOpBuilder &builder, mlir::Location loc) { diff --git a/flang/lib/Optimizer/Builder/Runtime/Main.cpp b/flang/lib/Optimizer/Builder/Runtime/Main.cpp index d35f687..d303e0a 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Main.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Main.cpp @@ -10,6 +10,7 @@ #include "flang/Lower/EnvironmentDefault.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Runtime/Coarray.h" #include "flang/Optimizer/Builder/Runtime/EnvironmentDefaults.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/Dialect/FIROps.h" @@ -23,8 +24,8 @@ using namespace Fortran::runtime; /// Create a `int main(...)` that calls the Fortran entry point void fir::runtime::genMain( fir::FirOpBuilder &builder, mlir::Location loc, - const std::vector<Fortran::lower::EnvironmentDefault> &defs, - bool initCuda) { + const std::vector<Fortran::lower::EnvironmentDefault> &defs, bool initCuda, + bool initCoarrayEnv) { auto *context = builder.getContext(); auto argcTy = builder.getDefaultIntegerType(); auto ptrTy = mlir::LLVM::LLVMPointerType::get(context); @@ -69,6 +70,8 @@ void fir::runtime::genMain( loc, RTNAME_STRING(CUFInit), mlir::FunctionType::get(context, {}, {})); fir::CallOp::create(builder, loc, initFn); } + if (initCoarrayEnv) + fir::runtime::genInitCoarray(builder, loc); fir::CallOp::create(builder, loc, qqMainFn); fir::CallOp::create(builder, loc, stopFn); diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 1b289ae..76f3cbd 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -87,14 +87,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) { return mlir::IntegerType::get(context, 8); } -static mlir::LLVM::ConstantOp -genConstantIndex(mlir::Location loc, mlir::Type ity, - mlir::ConversionPatternRewriter &rewriter, - std::int64_t offset) { - auto cattr = rewriter.getI64IntegerAttr(offset); - return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr); 
-} - static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter, mlir::Block *insertBefore) { assert(insertBefore && "expected valid insertion block"); @@ -208,39 +200,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op, TODO(op.getLoc(), "did not find allocation function"); } -// Compute the alloc scale size (constant factors encoded in the array type). -// We do this for arrays without a constant interior or arrays of character with -// dynamic length arrays, since those are the only ones that get decayed to a -// pointer to the element type. -template <typename OP> -static mlir::Value -genAllocationScaleSize(OP op, mlir::Type ity, - mlir::ConversionPatternRewriter &rewriter) { - mlir::Location loc = op.getLoc(); - mlir::Type dataTy = op.getInType(); - auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy); - fir::SequenceType::Extent constSize = 1; - if (seqTy) { - int constRows = seqTy.getConstantRows(); - const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); - if (constRows != static_cast<int>(shape.size())) { - for (auto extent : shape) { - if (constRows-- > 0) - continue; - if (extent != fir::SequenceType::getUnknownExtent()) - constSize *= extent; - } - } - } - - if (constSize != 1) { - mlir::Value constVal{ - genConstantIndex(loc, ity, rewriter, constSize).getResult()}; - return constVal; - } - return nullptr; -} - namespace { struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> { public: @@ -275,7 +234,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> { auto loc = alloc.getLoc(); mlir::Type ity = lowerTy().indexType(); unsigned i = 0; - mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult(); + mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult(); mlir::Type firObjType = fir::unwrapRefType(alloc.getType()); mlir::Type llvmObjectType = convertObjectType(firObjType); if (alloc.hasLenParams()) { @@ -307,7 +266,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> { << scalarType << " with type parameters"; } } - if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter)) + if (auto scaleSize = fir::genAllocationScaleSize( + alloc.getLoc(), alloc.getInType(), ity, rewriter)) size = rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); if (alloc.hasShapeOperands()) { @@ -484,7 +444,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> { auto loc = boxisarray.getLoc(); TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType()); mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter); - mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0); + mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0); rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0); return mlir::success(); @@ -820,7 +780,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { // Do folding for constant inputs. if (auto constVal = fir::getIntIfConstant(op0)) { mlir::Value normVal = - genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); + fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); rewriter.replaceOp(convert, normVal); return mlir::success(); } @@ -833,7 +793,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { } // Compare the input with zero. 
- mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0); + mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0); auto isTrue = mlir::LLVM::ICmpOp::create( rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero); @@ -1082,21 +1042,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op, return getMallocInModule(mod, op, rewriter, indexType); } -/// Helper function for generating the LLVM IR that computes the distance -/// in bytes between adjacent elements pointed to by a pointer -/// of type \p ptrTy. The result is returned as a value of \p idxTy integer -/// type. -static mlir::Value -computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, - mlir::Type idxTy, - mlir::ConversionPatternRewriter &rewriter, - const mlir::DataLayout &dataLayout) { - llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); - unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); - std::int64_t distance = llvm::alignTo(size, alignment); - return genConstantIndex(loc, idxTy, rewriter, distance); -} - /// Return value of the stride in bytes between adjacent elements /// of LLVM type \p llTy. The result is returned as a value of /// \p idxTy integer type. @@ -1105,7 +1050,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy, mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy, const mlir::DataLayout &dataLayout) { // Create a pointer type and use computeElementDistance(). - return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); + return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); } namespace { @@ -1124,8 +1069,9 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> { if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) TODO(loc, "fir.allocmem codegen of derived type with length parameters"); mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); - if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) - size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); + if (auto scaleSize = + fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter)) + size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands()) size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, integerCast(loc, rewriter, ity, opnd)); @@ -1133,8 +1079,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> { // As the return value of malloc(0) is implementation defined, allocate one // byte to ensure the allocation status being true. This behavior aligns to // what the runtime has. 
- mlir::Value zero = genConstantIndex(loc, ity, rewriter, 0); - mlir::Value one = genConstantIndex(loc, ity, rewriter, 1); + mlir::Value zero = fir::genConstantIndex(loc, ity, rewriter, 0); + mlir::Value one = fir::genConstantIndex(loc, ity, rewriter, 1); mlir::Value cmp = mlir::LLVM::ICmpOp::create( rewriter, loc, mlir::LLVM::ICmpPredicate::sgt, size, zero); size = mlir::LLVM::SelectOp::create(rewriter, loc, cmp, size, one); @@ -1157,7 +1103,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> { mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy) const { - return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); + return fir::computeElementDistance(loc, llTy, idxTy, rewriter, + getDataLayout()); } }; } // namespace @@ -1344,7 +1291,7 @@ genCUFAllocDescriptor(mlir::Location loc, mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; mlir::Value sizeInBytes = - genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); + fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; return mlir::LLVM::CallOp::create(rewriter, loc, fctTy, RTNAME_STRING(CUFAllocDescriptor), args) @@ -1599,7 +1546,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> { // representation of derived types with pointer/allocatable components. // This has been seen in hashing algorithms using TRANSFER. mlir::Value zero = - genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); + fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); descriptor = insertField(rewriter, loc, descriptor, {getLenParamFieldId(boxTy), 0}, zero); } @@ -1944,8 +1891,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> { bool hasSlice = !xbox.getSlice().empty(); unsigned sliceOffset = xbox.getSliceOperandIndex(); mlir::Location loc = xbox.getLoc(); - mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0); - mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1); + mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0); + mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1); mlir::Value prevPtrOff = one; mlir::Type eleTy = boxTy.getEleTy(); const unsigned rank = xbox.getRank(); @@ -1994,7 +1941,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> { prevDimByteStride = getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams()); } else { - prevDimByteStride = genConstantIndex( + prevDimByteStride = fir::genConstantIndex( loc, i64Ty, rewriter, charTy.getLen() * lowerTy().characterBitsize(charTy) / 8); } @@ -2152,7 +2099,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> { if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) { if (charTy.hasConstantLen()) { mlir::Value len = - genConstantIndex(loc, idxTy, rewriter, charTy.getLen()); + fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen()); lenParams.emplace_back(len); } else { mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair, @@ -2161,7 +2108,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> { assert(!isInGlobalOp(rewriter) && "character target in global op must have constant length"); mlir::Value width = - genConstantIndex(loc, idxTy, rewriter, charTy.getFKind()); + fir::genConstantIndex(loc, idxTy, rewriter, 
charTy.getFKind()); len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width); } lenParams.emplace_back(len); @@ -2215,8 +2162,9 @@ private: mlir::ConversionPatternRewriter &rewriter) const { mlir::Location loc = rebox.getLoc(); mlir::Value zero = - genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); - mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1); + fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); + mlir::Value one = + fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1); for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) { mlir::Value extent = std::get<0>(iter.value()); unsigned dim = iter.index(); @@ -2249,7 +2197,7 @@ private: mlir::Location loc = rebox.getLoc(); mlir::Type byteTy = ::getI8Type(rebox.getContext()); mlir::Type idxTy = lowerTy().indexType(); - mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0); // Apply subcomponent and substring shift on base address. if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) { // Cast to inputEleTy* so that a GEP can be used. @@ -2277,7 +2225,7 @@ private: // and strides. llvm::SmallVector<mlir::Value> slicedExtents; llvm::SmallVector<mlir::Value> slicedStrides; - mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1); const bool sliceHasOrigins = !rebox.getShift().empty(); unsigned sliceOps = rebox.getSliceOperandIndex(); unsigned shiftOps = rebox.getShiftOperandIndex(); @@ -2350,7 +2298,7 @@ private: // which may be OK if all new extents are ones, the stride does not // matter, use one. mlir::Value stride = inputStrides.empty() - ? genConstantIndex(loc, idxTy, rewriter, 1) + ? fir::genConstantIndex(loc, idxTy, rewriter, 1) : inputStrides[0]; for (unsigned i = 0; i < rebox.getShape().size(); ++i) { mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i]; @@ -2585,9 +2533,9 @@ struct XArrayCoorOpConversion unsigned shiftOffset = coor.getShiftOperandIndex(); unsigned sliceOffset = coor.getSliceOperandIndex(); auto sliceOps = coor.getSlice().begin(); - mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1); mlir::Value prevExt = one; - mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0); const bool isShifted = !coor.getShift().empty(); const bool isSliced = !coor.getSlice().empty(); const bool baseIsBoxed = @@ -2918,7 +2866,7 @@ private: // of lower bound aspects. This both accounts for dynamically sized // types and non contiguous arrays. auto idxTy = lowerTy().indexType(); - mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0); unsigned arrayDim = arrTy.getDimension(); for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) { mlir::Value stride = @@ -3525,114 +3473,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> { } }; -/// Helper function for converting select ops. This function converts the -/// signature of the given block. If the new block signature is different from -/// `expectedTypes`, returns "failure". 
-static llvm::FailureOr<mlir::Block *> -getConvertedBlock(mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter, - mlir::Operation *branchOp, mlir::Block *block, - mlir::TypeRange expectedTypes) { - assert(converter && "expected non-null type converter"); - assert(!block->isEntryBlock() && "entry blocks have no predecessors"); - - // There is nothing to do if the types already match. - if (block->getArgumentTypes() == expectedTypes) - return block; - - // Compute the new block argument types and convert the block. - std::optional<mlir::TypeConverter::SignatureConversion> conversion = - converter->convertBlockSignature(block); - if (!conversion) - return rewriter.notifyMatchFailure(branchOp, - "could not compute block signature"); - if (expectedTypes != conversion->getConvertedTypes()) - return rewriter.notifyMatchFailure( - branchOp, - "mismatch between adaptor operand types and computed block signature"); - return rewriter.applySignatureConversion(block, *conversion, converter); -} - +/// Base class for SelectOpConversion and SelectRankOpConversion. template <typename OP> -static llvm::LogicalResult -selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select, - typename OP::Adaptor adaptor, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { - unsigned conds = select.getNumConditions(); - auto cases = select.getCases().getValue(); - mlir::Value selector = adaptor.getSelector(); - auto loc = select.getLoc(); - assert(conds > 0 && "select must have cases"); - - llvm::SmallVector<mlir::Block *> destinations; - llvm::SmallVector<mlir::ValueRange> destinationsOperands; - mlir::Block *defaultDestination; - mlir::ValueRange defaultOperands; - llvm::SmallVector<int32_t> caseValues; - - for (unsigned t = 0; t != conds; ++t) { - mlir::Block *dest = select.getSuccessor(t); - auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); - const mlir::Attribute &attr = cases[t]; - if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) { - destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{}); - auto convertedBlock = - getConvertedBlock(rewriter, converter, select, dest, - mlir::TypeRange(destinationsOperands.back())); +struct SelectOpConversionBase : public fir::FIROpConversion<OP> { + using fir::FIROpConversion<OP>::FIROpConversion; + +private: + /// Helper function for converting select ops. This function converts the + /// signature of the given block. If the new block signature is different from + /// `expectedTypes`, returns "failure". + llvm::FailureOr<mlir::Block *> + getConvertedBlock(mlir::ConversionPatternRewriter &rewriter, + mlir::Operation *branchOp, mlir::Block *block, + mlir::TypeRange expectedTypes) const { + const mlir::TypeConverter *converter = this->getTypeConverter(); + assert(converter && "expected non-null type converter"); + assert(!block->isEntryBlock() && "entry blocks have no predecessors"); + + // There is nothing to do if the types already match. + if (block->getArgumentTypes() == expectedTypes) + return block; + + // Compute the new block argument types and convert the block. 
+ std::optional<mlir::TypeConverter::SignatureConversion> conversion = + converter->convertBlockSignature(block); + if (!conversion) + return rewriter.notifyMatchFailure(branchOp, + "could not compute block signature"); + if (expectedTypes != conversion->getConvertedTypes()) + return rewriter.notifyMatchFailure(branchOp, + "mismatch between adaptor operand " + "types and computed block signature"); + return rewriter.applySignatureConversion(block, *conversion, converter); + } + +protected: + llvm::LogicalResult + selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + unsigned conds = select.getNumConditions(); + auto cases = select.getCases().getValue(); + mlir::Value selector = adaptor.getSelector(); + auto loc = select.getLoc(); + assert(conds > 0 && "select must have cases"); + + llvm::SmallVector<mlir::Block *> destinations; + llvm::SmallVector<mlir::ValueRange> destinationsOperands; + mlir::Block *defaultDestination; + mlir::ValueRange defaultOperands; + // LLVM::SwitchOp selector type and the case values types + // must have the same bit width, so cast the selector to i64, + // and use i64 for the case values. It is hard to imagine + // a computed GO TO with the number of labels in the label-list + // bigger than INT_MAX, but let's use i64 to be on the safe side. + // Moreover, fir.select operation is more relaxed than + // a Fortran computed GO TO, so it may specify such a case value + // even if there is just a single label/case. + llvm::SmallVector<int64_t> caseValues; + + for (unsigned t = 0; t != conds; ++t) { + mlir::Block *dest = select.getSuccessor(t); + auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); + const mlir::Attribute &attr = cases[t]; + if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) { + destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{}); + auto convertedBlock = + getConvertedBlock(rewriter, select, dest, + mlir::TypeRange(destinationsOperands.back())); + if (mlir::failed(convertedBlock)) + return mlir::failure(); + destinations.push_back(*convertedBlock); + caseValues.push_back(intAttr.getInt()); + continue; + } + assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)); + assert((t + 1 == conds) && "unit must be last"); + defaultOperands = destOps ? *destOps : mlir::ValueRange{}; + auto convertedBlock = getConvertedBlock(rewriter, select, dest, + mlir::TypeRange(defaultOperands)); if (mlir::failed(convertedBlock)) return mlir::failure(); - destinations.push_back(*convertedBlock); - caseValues.push_back(intAttr.getInt()); - continue; + defaultDestination = *convertedBlock; } - assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)); - assert((t + 1 == conds) && "unit must be last"); - defaultOperands = destOps ? *destOps : mlir::ValueRange{}; - auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest, - mlir::TypeRange(defaultOperands)); - if (mlir::failed(convertedBlock)) - return mlir::failure(); - defaultDestination = *convertedBlock; - } - - // LLVM::SwitchOp takes a i32 type for the selector. 
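For reference, a standalone C++ sketch (not part of the patch) of why the selector and the case values are widened to i64 here rather than truncated to i32 as before: a 64-bit case value that does not fit in 32 bits can collide with a small case value after truncation, while the comparison stays lossless at 64 bits. The values below are made up for illustration.

#include <cassert>
#include <cstdint>

int main() {
  std::int64_t small = 1;
  std::int64_t big = (std::int64_t{1} << 32) + 1; // does not fit in 32 bits
  // After truncation to 32 bits the two distinct case values collide...
  assert(static_cast<std::uint32_t>(small) == static_cast<std::uint32_t>(big));
  // ...but compared at 64 bits they stay distinct.
  assert(small != big);
  return 0;
}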
- if (select.getSelector().getType() != rewriter.getI32Type()) - selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), - selector); - - rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( - select, selector, - /*defaultDestination=*/defaultDestination, - /*defaultOperands=*/defaultOperands, - /*caseValues=*/caseValues, - /*caseDestinations=*/destinations, - /*caseOperands=*/destinationsOperands, - /*branchWeights=*/llvm::ArrayRef<std::int32_t>()); - return mlir::success(); -} + selector = + this->integerCast(loc, rewriter, rewriter.getI64Type(), selector); + + rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( + select, selector, + /*defaultDestination=*/defaultDestination, + /*defaultOperands=*/defaultOperands, + /*caseValues=*/rewriter.getI64VectorAttr(caseValues), + /*caseDestinations=*/destinations, + /*caseOperands=*/destinationsOperands, + /*branchWeights=*/llvm::ArrayRef<std::int32_t>()); + return mlir::success(); + } +}; /// conversion of fir::SelectOp to an if-then-else ladder -struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> { - using FIROpConversion::FIROpConversion; +struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> { + using SelectOpConversionBase::SelectOpConversionBase; llvm::LogicalResult matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, - rewriter, getTypeConverter()); + return this->selectMatchAndRewrite(op, adaptor, rewriter); } }; /// conversion of fir::SelectRankOp to an if-then-else ladder -struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> { - using FIROpConversion::FIROpConversion; +struct SelectRankOpConversion + : public SelectOpConversionBase<fir::SelectRankOp> { + using SelectOpConversionBase::SelectOpConversionBase; llvm::LogicalResult matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - return selectMatchAndRewrite<fir::SelectRankOp>( - lowerTy(), op, adaptor, rewriter, getTypeConverter()); + return this->selectMatchAndRewrite(op, adaptor, rewriter); } }; @@ -3837,7 +3794,7 @@ struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> { ptr = mlir::LLVM::ExtractValueOp::create(rewriter, loc, ptr, 0); } mlir::LLVM::ConstantOp c0 = - genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0); + fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0); auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr); rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0); diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 37f1c9f..97912bd 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -21,6 +21,7 @@ #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Support/FatalError.h" #include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Optimizer/Support/Utils.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -125,10 +126,58 @@ struct PrivateClauseOpConversion return mlir::success(); } }; + +// Convert FIR type to LLVM without turning fir.box<T> into memory +// reference. 
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, + mlir::Type firType) { + if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType)) + return converter.convertBoxTypeAsStruct(boxTy); + return converter.convertType(firType); +} + +// FIR Op specific conversion for TargetAllocMemOp +struct TargetAllocMemOpConversion + : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> { + using OpenMPFIROpConversion::OpenMPFIROpConversion; + + llvm::LogicalResult + matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type heapTy = allocmemOp.getAllocatedType(); + mlir::Location loc = allocmemOp.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "omp.target_allocmem codegen of derived type with length " + "parameters"); + mlir::Value size = fir::computeElementDistance( + loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); + if (auto scaleSize = fir::genAllocationScaleSize( + loc, allocmemOp.getInType(), ity, rewriter)) + size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands().drop_front()) + size = rewriter.create<mlir::LLVM::MulOp>( + loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + rewriter.modifyOpInPlace(allocmemOp, [&]() { + allocmemOp.setInType(rewriter.getI8Type()); + allocmemOp.getTypeparamsMutable().clear(); + allocmemOp.getTypeparamsMutable().append(size); + }); + return mlir::success(); + } +}; } // namespace void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add<MapInfoOpConversion>(converter); patterns.add<PrivateClauseOpConversion>(converter); + patterns.add<TargetAllocMemOpConversion>(converter); } diff --git a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp index 52c733d..bd0499f 100644 --- a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" #include "flang/Optimizer/Dialect/CUF/Attributes/CUFEnumAttr.cpp.inc" @@ -29,4 +30,26 @@ void CUFDialect::registerAttributes() { LaunchBoundsAttr, ProcAttributeAttr>(); } +cuf::DataAttributeAttr getDataAttr(mlir::Operation *op) { + if (!op) + return {}; + + if (auto dataAttr = + op->getAttrOfType<cuf::DataAttributeAttr>(cuf::getDataAttrName())) + return dataAttr; + + // When the attribute is declared on the operation, it doesn't have a prefix. 
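For reference, a standalone C++ sketch (not part of the patch) of the allocation-size folding performed by TargetAllocMemOpConversion above, under the simplifying assumption that the element size, the compile-time scale derived from the allocated type, and the runtime extent operands can be modelled as plain integers; totalAllocSize and its parameters are illustrative names only.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model: total bytes = element size * static scale from the
// allocated type * product of the dynamic extent operands.
std::uint64_t totalAllocSize(std::uint64_t elementSize, std::uint64_t typeScale,
                             const std::vector<std::uint64_t> &extents) {
  std::uint64_t size = elementSize * typeScale;
  for (std::uint64_t e : extents)
    size *= e;
  return size; // the real lowering also adjusts the integer width here
}

int main() {
  // e.g. a 10-element array of 8-byte elements allocated with a dynamic
  // extent of 4 in another dimension.
  assert(totalAllocSize(/*elementSize=*/8, /*typeScale=*/10, {4}) == 320);
  return 0;
}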
+ if (auto dataAttr = + op->getAttrOfType<cuf::DataAttributeAttr>(cuf::dataAttrName)) + return dataAttr; + + return {}; +} + +bool hasDataAttr(mlir::Operation *op, cuf::DataAttribute value) { + if (auto dataAttr = getDataAttr(op)) + return dataAttr.getValue() == value; + return false; +} + } // namespace cuf diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 01975f3..87f9899 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { } /// Parser shared by Alloca and Allocmem -/// /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? /// attr-dict-without-keyword diff --git a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp index 034f8c7..f16072a 100644 --- a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp +++ b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp @@ -68,3 +68,31 @@ fir::FortranVariableOpInterface::verifyDeclareLikeOpImpl(mlir::Value memref) { } return mlir::success(); } + +mlir::LogicalResult +fir::detail::verifyFortranVariableStorageOpInterface(mlir::Operation *op) { + auto storageIface = mlir::cast<fir::FortranVariableStorageOpInterface>(op); + mlir::Value storage = storageIface.getStorage(); + std::uint64_t storageOffset = storageIface.getStorageOffset(); + if (!storage) { + if (storageOffset != 0) + return op->emitOpError( + "storage offset specified without the storage reference"); + return mlir::success(); + } + + auto storageType = + mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(storage.getType())); + if (!storageType || storageType.getDimension() != 1) + return op->emitOpError("storage must be a vector"); + if (storageType.hasDynamicExtents()) + return op->emitOpError("storage must have known extent"); + if (storageType.getEleTy() != mlir::IntegerType::get(op->getContext(), 8)) + return op->emitOpError("storage must be an array of i8 elements"); + if (storageOffset > storageType.getConstantArraySize()) + return op->emitOpError("storage offset exceeds the storage size"); + // TODO: we should probably verify that the (offset + sizeof(var)) + // is within the storage object, but this requires mlir::DataLayout. + // Can we make it available during the verification? 
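For reference, a minimal standalone sketch (not part of the patch) of the same storage constraints expressed on plain values rather than MLIR types: the backing storage must be a rank-1 array of i8 with a statically known extent, and the offset must stay within that extent. StorageInfo and isValidStorage are hypothetical names used only for illustration.

#include <cassert>
#include <cstdint>
#include <optional>

struct StorageInfo {
  unsigned rank;                     // must be 1 (a vector)
  unsigned elementBits;              // must be 8 (i8 elements)
  std::optional<std::uint64_t> size; // must be known (no dynamic extent)
};

bool isValidStorage(const StorageInfo &s, std::uint64_t offset) {
  return s.rank == 1 && s.elementBits == 8 && s.size && offset <= *s.size;
}

int main() {
  assert(isValidStorage({1, 8, 64}, 16));   // e.g. a 64 x i8 buffer, offset 16
  assert(!isValidStorage({1, 8, 64}, 100)); // offset past the end
  assert(!isValidStorage({2, 8, 64}, 0));   // not a vector
  return 0;
}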
+ return mlir::success(); +} diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index ed102db..2971a72 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -279,7 +279,8 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, auto [hlfirVariableType, firVarType] = getDeclareOutputTypes(inputType, hasExplicitLbs); build(builder, result, {hlfirVariableType, firVarType}, memref, shape, - typeparams, dummy_scope, nameAttr, fortran_attrs, data_attr); + typeparams, dummy_scope, /*storage=*/nullptr, /*storage_offset=*/0, + nameAttr, fortran_attrs, data_attr); } llvm::LogicalResult hlfir::DeclareOp::verify() { @@ -821,6 +822,62 @@ void hlfir::ConcatOp::getEffects( } //===----------------------------------------------------------------------===// +// CmpCharOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult hlfir::CmpCharOp::verify() { + mlir::Value lchr = getLchr(); + mlir::Value rchr = getRchr(); + + unsigned kind = getCharacterKind(lchr.getType()); + if (kind != getCharacterKind(rchr.getType())) + return emitOpError("character arguments must have the same KIND"); + + switch (getPredicate()) { + case mlir::arith::CmpIPredicate::slt: + case mlir::arith::CmpIPredicate::sle: + case mlir::arith::CmpIPredicate::eq: + case mlir::arith::CmpIPredicate::ne: + case mlir::arith::CmpIPredicate::sgt: + case mlir::arith::CmpIPredicate::sge: + break; + default: + return emitOpError("expected signed predicate"); + } + + return mlir::success(); +} + +void hlfir::CmpCharOp::getEffects( + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> + &effects) { + getIntrinsicEffects(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// +// CharTrimOp +//===----------------------------------------------------------------------===// + +void hlfir::CharTrimOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value chr) { + unsigned kind = getCharacterKind(chr.getType()); + auto resultType = hlfir::ExprType::get( + builder.getContext(), hlfir::ExprType::Shape{}, + fir::CharacterType::get(builder.getContext(), kind, + fir::CharacterType::unknownLen()), + /*polymorphic=*/false); + build(builder, result, resultType, chr); +} + +void hlfir::CharTrimOp::getEffects( + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> + &effects) { + getIntrinsicEffects(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// // NumericalReductionOp //===----------------------------------------------------------------------===// @@ -1440,44 +1497,46 @@ void hlfir::MatmulTransposeOp::getEffects( } //===----------------------------------------------------------------------===// -// CShiftOp +// Array shifts: CShiftOp/EOShiftOp //===----------------------------------------------------------------------===// -llvm::LogicalResult hlfir::CShiftOp::verify() { - mlir::Value array = getArray(); +template <typename Op> +static llvm::LogicalResult verifyArrayShift(Op op) { + mlir::Value array = op.getArray(); fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>( hlfir::getFortranElementOrSequenceType(array.getType())); llvm::ArrayRef<int64_t> inShape = arrayTy.getShape(); std::size_t arrayRank = inShape.size(); mlir::Type eleTy = arrayTy.getEleTy(); - hlfir::ExprType resultTy = 
mlir::cast<hlfir::ExprType>(getResult().getType()); + hlfir::ExprType resultTy = + mlir::cast<hlfir::ExprType>(op.getResult().getType()); llvm::ArrayRef<int64_t> resultShape = resultTy.getShape(); std::size_t resultRank = resultShape.size(); mlir::Type resultEleTy = resultTy.getEleTy(); - mlir::Value shift = getShift(); + mlir::Value shift = op.getShift(); mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType()); - // TODO: turn allowCharacterLenMismatch into true. - if (auto match = areMatchingTypes(*this, eleTy, resultEleTy, - /*allowCharacterLenMismatch=*/false); + if (auto match = areMatchingTypes( + op, eleTy, resultEleTy, + /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier); match.failed()) - return emitOpError( + return op.emitOpError( "input and output arrays should have the same element type"); if (arrayRank != resultRank) - return emitOpError("input and output arrays should have the same rank"); + return op.emitOpError("input and output arrays should have the same rank"); constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent(); for (auto [inDim, resultDim] : llvm::zip(inShape, resultShape)) if (inDim != unknownExtent && resultDim != unknownExtent && inDim != resultDim) - return emitOpError( + return op.emitOpError( "output array's shape conflicts with the input array's shape"); int64_t dimVal = -1; - if (!getDim()) + if (!op.getDim()) dimVal = 1; - else if (auto dim = fir::getIntIfConstant(getDim())) + else if (auto dim = fir::getIntIfConstant(op.getDim())) dimVal = *dim; // The DIM argument may be statically invalid (e.g. exceed the @@ -1485,44 +1544,79 @@ llvm::LogicalResult hlfir::CShiftOp::verify() { // so avoid some checks unless useStrictIntrinsicVerifier is true. if (useStrictIntrinsicVerifier && dimVal != -1) { if (dimVal < 1) - return emitOpError("DIM must be >= 1"); + return op.emitOpError("DIM must be >= 1"); if (dimVal > static_cast<int64_t>(arrayRank)) - return emitOpError("DIM must be <= input array's rank"); + return op.emitOpError("DIM must be <= input array's rank"); } - if (auto shiftSeqTy = mlir::dyn_cast<fir::SequenceType>(shiftTy)) { - // SHIFT is an array. Verify the rank and the shape (if DIM is constant). - llvm::ArrayRef<int64_t> shiftShape = shiftSeqTy.getShape(); - std::size_t shiftRank = shiftShape.size(); - if (shiftRank != arrayRank - 1) - return emitOpError( - "SHIFT's rank must be 1 less than the input array's rank"); - - if (useStrictIntrinsicVerifier && dimVal != -1) { - // SHIFT's shape must be [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)], - // where [d(1), d(2), ..., d(n)] is the shape of the ARRAY. - int64_t arrayDimIdx = 0; - int64_t shiftDimIdx = 0; - for (auto shiftDim : shiftShape) { - if (arrayDimIdx == dimVal - 1) + // A helper lambda to verify the shape of the array types of + // certain operands of the array shift (e.g. the SHIFT and BOUNDARY operands). + auto verifyOperandTypeShape = [&](mlir::Type type, + llvm::Twine name) -> llvm::LogicalResult { + if (auto opndSeqTy = mlir::dyn_cast<fir::SequenceType>(type)) { + // The operand is an array. Verify the rank and the shape (if DIM is + // constant). 
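For reference, a standalone C++ illustration (not the verifier itself) of the shape rule being checked here: for a constant DIM, the SHIFT and BOUNDARY operands must have the shape of ARRAY with dimension DIM removed. conformsWithArray below is a hypothetical helper used only to show the rule.

#include <cassert>
#include <vector>

bool conformsWithArray(const std::vector<int> &arrayShape, int dim,
                       const std::vector<int> &operandShape) {
  std::vector<int> expected = arrayShape;
  expected.erase(expected.begin() + (dim - 1)); // drop dimension DIM (1-based)
  return operandShape == expected;
}

int main() {
  // ARRAY of shape (2,3,4) shifted along DIM=2: SHIFT must have shape (2,4).
  assert(conformsWithArray({2, 3, 4}, 2, {2, 4}));
  assert(!conformsWithArray({2, 3, 4}, 2, {3, 4}));
  return 0;
}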
+ llvm::ArrayRef<int64_t> opndShape = opndSeqTy.getShape(); + std::size_t opndRank = opndShape.size(); + if (opndRank != arrayRank - 1) + return op.emitOpError( + name + "'s rank must be 1 less than the input array's rank"); + + if (useStrictIntrinsicVerifier && dimVal != -1) { + // The operand's shape must be + // [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)], + // where [d(1), d(2), ..., d(n)] is the shape of the ARRAY. + int64_t arrayDimIdx = 0; + int64_t opndDimIdx = 0; + for (auto opndDim : opndShape) { + if (arrayDimIdx == dimVal - 1) + ++arrayDimIdx; + + if (inShape[arrayDimIdx] != unknownExtent && + opndDim != unknownExtent && inShape[arrayDimIdx] != opndDim) + return op.emitOpError("SHAPE(ARRAY)(" + + llvm::Twine(arrayDimIdx + 1) + + ") must be equal to SHAPE(" + name + ")(" + + llvm::Twine(opndDimIdx + 1) + + "): " + llvm::Twine(inShape[arrayDimIdx]) + + " != " + llvm::Twine(opndDim)); ++arrayDimIdx; - - if (inShape[arrayDimIdx] != unknownExtent && - shiftDim != unknownExtent && inShape[arrayDimIdx] != shiftDim) - return emitOpError("SHAPE(ARRAY)(" + llvm::Twine(arrayDimIdx + 1) + - ") must be equal to SHAPE(SHIFT)(" + - llvm::Twine(shiftDimIdx + 1) + - "): " + llvm::Twine(inShape[arrayDimIdx]) + - " != " + llvm::Twine(shiftDim)); - ++arrayDimIdx; - ++shiftDimIdx; + ++opndDimIdx; + } } } + return mlir::success(); + }; + + if (failed(verifyOperandTypeShape(shiftTy, "SHIFT"))) + return mlir::failure(); + + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) { + if (mlir::Value boundary = op.getBoundary()) { + mlir::Type boundaryTy = + hlfir::getFortranElementOrSequenceType(boundary.getType()); + if (auto match = areMatchingTypes( + op, eleTy, hlfir::getFortranElementType(boundaryTy), + /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier); + match.failed()) + return op.emitOpError( + "ARRAY and BOUNDARY operands must have the same element type"); + if (failed(verifyOperandTypeShape(boundaryTy, "BOUNDARY"))) + return mlir::failure(); + } } return mlir::success(); } +//===----------------------------------------------------------------------===// +// CShiftOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult hlfir::CShiftOp::verify() { + return verifyArrayShift(*this); +} + void hlfir::CShiftOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> @@ -1531,6 +1625,21 @@ void hlfir::CShiftOp::getEffects( } //===----------------------------------------------------------------------===// +// EOShiftOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult hlfir::EOShiftOp::verify() { + return verifyArrayShift(*this); +} + +void hlfir::EOShiftOp::getEffects( + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> + &effects) { + getIntrinsicEffects(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -1543,7 +1652,8 @@ llvm::LogicalResult hlfir::ReshapeOp::verify() { hlfir::getFortranElementOrSequenceType(array.getType())); if (auto match = areMatchingTypes( *this, hlfir::getFortranElementType(resultType), - arrayType.getElementType(), /*allowCharacterLenMismatch=*/true); + arrayType.getElementType(), + /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier); match.failed()) return emitOpError("ARRAY and the result must have the same 
element type"); if (hlfir::isPolymorphicType(resultType) != @@ -1565,9 +1675,9 @@ llvm::LogicalResult hlfir::ReshapeOp::verify() { if (mlir::Value pad = getPad()) { auto padArrayType = mlir::cast<fir::SequenceType>( hlfir::getFortranElementOrSequenceType(pad.getType())); - if (auto match = areMatchingTypes(*this, arrayType.getElementType(), - padArrayType.getElementType(), - /*allowCharacterLenMismatch=*/true); + if (auto match = areMatchingTypes( + *this, arrayType.getElementType(), padArrayType.getElementType(), + /*allowCharacterLenMismatch=*/!useStrictIntrinsicVerifier); match.failed()) return emitOpError("ARRAY and PAD must be of the same type"); } @@ -1847,8 +1957,7 @@ hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf, // shape information is not available at compile time return llvm::LogicalResult::failure(); - rewriter.replaceAllUsesWith(shapeOf.getResult(), shape); - rewriter.eraseOp(shapeOf); + rewriter.replaceOp(shapeOf, shape); return llvm::LogicalResult::success(); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 9109f2b..886a8a5 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -455,12 +455,8 @@ struct AssociateOpConversion mlir::Type associateHlfirVarType = associate.getResultTypes()[0]; hlfirVar = adjustVar(hlfirVar, associateHlfirVarType); - associate.getResult(0).replaceAllUsesWith(hlfirVar); - mlir::Type associateFirVarType = associate.getResultTypes()[1]; firVar = adjustVar(firVar, associateFirVarType); - associate.getResult(1).replaceAllUsesWith(firVar); - associate.getResult(2).replaceAllUsesWith(flag); // FIXME: note that the AssociateOp that is being erased // here will continue to be a user of the original Source // operand (e.g. a result of hlfir.elemental), because @@ -472,7 +468,7 @@ struct AssociateOpConversion // the conversions, so that we can analyze HLFIR in its // original form and decide which of the AssociateOp // users of hlfir.expr can reuse the buffer (if it can). - rewriter.eraseOp(associate); + rewriter.replaceOp(associate, {hlfirVar, firVar, flag}); }; // If this is the last use of the expression value and this is an hlfir.expr diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index cc74273..3775a13 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -27,6 +27,8 @@ add_flang_library(HLFIRTransforms FIRSupport FIRTransforms FlangOpenMPTransforms + FortranEvaluate + FortranSupport HLFIRDialect LINK_COMPONENTS diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 2e27324..8104e53 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -305,6 +305,8 @@ public: auto firDeclareOp = fir::DeclareOp::create( rewriter, loc, memref.getType(), memref, declareOp.getShape(), declareOp.getTypeparams(), declareOp.getDummyScope(), + /*storage=*/declareOp.getStorage(), + /*storage_offset=*/declareOp.getStorageOffset(), declareOp.getUniqName(), fortranAttrs, dataAttr); // Propagate other attributes from hlfir.declare to fir.declare. 
@@ -490,15 +492,18 @@ public: } baseEleTy = hlfir::getFortranElementType(componentType); shape = designate.getComponentShape(); - } else { - // array%component[(indices) substring|complex part] cases. - // Component ref of array bases are dealt with below in embox/rebox. - assert(mlir::isa<fir::BaseBoxType>(designateResultType)); } } - if (mlir::isa<fir::BaseBoxType>(designateResultType)) { - // Generate embox or rebox. + if (mlir::isa<fir::BaseBoxType>(designateResultType) || + // Convert the component array slices using embox/rebox + // even if the result is a contiguous array section, e.g.: + // hlfir.designate %base{"i"} shape %shape : + // (!fir.box<!fir.array<2x!fir.type<_QMtypesTt{i:i32}>>>, + // !fir.shape<1>) -> !fir.ref<!fir.array<2xi32>> + // fir.coordinate_of should probably be a better option, though. + (fieldIndex && baseEntity.isArray())) { + // Generate embox or rebox for slicing. mlir::Type eleTy = fir::unwrapPassByRefType(designateResultType); bool isScalarDesignator = !mlir::isa<fir::SequenceType>(eleTy); mlir::Value sourceBox; @@ -575,8 +580,13 @@ public: else assert(sliceFields.empty() && substring.empty()); - llvm::SmallVector<mlir::Type> resultType{ - fir::updateTypeWithVolatility(designateResultType, isVolatile)}; + // If the designate's result type is not a box, then create + // a box type to be used for the result of the embox/rebox. + mlir::Type resultType = designateResultType; + if (!mlir::isa<fir::BaseBoxType>(resultType)) + resultType = fir::wrapInClassOrBoxType(resultType); + + resultType = fir::updateTypeWithVolatility(resultType, isVolatile); mlir::Value resultBox; if (mlir::isa<fir::BaseBoxType>(base.getType())) { @@ -587,6 +597,13 @@ public: fir::EmboxOp::create(builder, loc, resultType, base, shape, slice, firBaseTypeParameters, sourceBox); } + + if (!mlir::isa<fir::BaseBoxType>(designateResultType)) { + // If the designate's result is not a box, use the raw address + // as the new result. + resultBox = fir::BoxAddrOp::create(rewriter, loc, resultBox); + resultBox = builder.createConvert(loc, designateResultType, resultBox); + } rewriter.replaceOp(designate, resultBox); return mlir::success(); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp index c42b895..ff84a3c 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp @@ -101,9 +101,8 @@ public: elemental.getLoc(), builder, elemental, apply.getIndices()); // remove the old elemental and all of the bookkeeping - rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue()); + rewriter.replaceOp(apply, {yield.getElementValue()}); rewriter.eraseOp(yield); - rewriter.eraseOp(apply); rewriter.eraseOp(destroy); rewriter.eraseOp(elemental); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index 3c29d68..a913cfa 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -469,33 +469,49 @@ struct MatmulTransposeOpConversion } }; -class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> { - using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion; +// A converter for hlfir.cshift and hlfir.eoshift. 
+template <typename T> +class ArrayShiftOpConversion : public HlfirIntrinsicConversion<T> { + using HlfirIntrinsicConversion<T>::HlfirIntrinsicConversion; + using HlfirIntrinsicConversion<T>::lowerArguments; + using HlfirIntrinsicConversion<T>::processReturnValue; + using typename HlfirIntrinsicConversion<T>::IntrinsicArgument; llvm::LogicalResult - matchAndRewrite(hlfir::CShiftOp cshift, - mlir::PatternRewriter &rewriter) const override { - fir::FirOpBuilder builder{rewriter, cshift.getOperation()}; - const mlir::Location &loc = cshift->getLoc(); + matchAndRewrite(T op, mlir::PatternRewriter &rewriter) const override { + fir::FirOpBuilder builder{rewriter, op.getOperation()}; + const mlir::Location &loc = op->getLoc(); - llvm::SmallVector<IntrinsicArgument, 3> inArgs; - mlir::Value array = cshift.getArray(); + llvm::SmallVector<IntrinsicArgument, 4> inArgs; + llvm::StringRef intrinsicName{[]() { + if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) + return "eoshift"; + else if constexpr (std::is_same_v<T, hlfir::CShiftOp>) + return "cshift"; + else + llvm_unreachable("unsupported array shift"); + }()}; + + mlir::Value array = op.getArray(); inArgs.push_back({array, array.getType()}); - mlir::Value shift = cshift.getShift(); + mlir::Value shift = op.getShift(); inArgs.push_back({shift, shift.getType()}); - inArgs.push_back({cshift.getDim(), builder.getI32Type()}); + if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) { + mlir::Value boundary = op.getBoundary(); + inArgs.push_back({boundary, boundary ? boundary.getType() : nullptr}); + } + inArgs.push_back({op.getDim(), builder.getI32Type()}); - auto *argLowering = fir::getIntrinsicArgumentLowering("cshift"); + auto *argLowering = fir::getIntrinsicArgumentLowering(intrinsicName); llvm::SmallVector<fir::ExtendedValue, 3> args = - lowerArguments(cshift, inArgs, rewriter, argLowering); + lowerArguments(op, inArgs, rewriter, argLowering); - mlir::Type scalarResultType = - hlfir::getFortranElementType(cshift.getType()); + mlir::Type scalarResultType = hlfir::getFortranElementType(op.getType()); - auto [resultExv, mustBeFreed] = - fir::genIntrinsicCall(builder, loc, "cshift", scalarResultType, args); + auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( + builder, loc, intrinsicName, scalarResultType, args); - processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter); + processReturnValue(op, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; @@ -535,6 +551,68 @@ class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> { } }; +class CmpCharOpConversion : public HlfirIntrinsicConversion<hlfir::CmpCharOp> { + using HlfirIntrinsicConversion<hlfir::CmpCharOp>::HlfirIntrinsicConversion; + + llvm::LogicalResult + matchAndRewrite(hlfir::CmpCharOp cmp, + mlir::PatternRewriter &rewriter) const override { + fir::FirOpBuilder builder{rewriter, cmp.getOperation()}; + const mlir::Location &loc = cmp->getLoc(); + hlfir::Entity lhs{cmp.getLchr()}; + hlfir::Entity rhs{cmp.getRchr()}; + + auto [lhsExv, lhsCleanUp] = + hlfir::translateToExtendedValue(loc, builder, lhs); + auto [rhsExv, rhsCleanUp] = + hlfir::translateToExtendedValue(loc, builder, rhs); + + auto resultVal = fir::runtime::genCharCompare( + builder, loc, cmp.getPredicate(), lhsExv, rhsExv); + if (lhsCleanUp || rhsCleanUp) { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(cmp); + if (lhsCleanUp) + (*lhsCleanUp)(); + if (rhsCleanUp) + (*rhsCleanUp)(); + } + auto resultEntity = 
hlfir::EntityWithAttributes{resultVal}; + + processReturnValue(cmp, resultEntity, /*mustBeFreed=*/false, builder, + rewriter); + return mlir::success(); + } +}; + +class CharTrimOpConversion + : public HlfirIntrinsicConversion<hlfir::CharTrimOp> { + using HlfirIntrinsicConversion<hlfir::CharTrimOp>::HlfirIntrinsicConversion; + + llvm::LogicalResult + matchAndRewrite(hlfir::CharTrimOp trim, + mlir::PatternRewriter &rewriter) const override { + fir::FirOpBuilder builder{rewriter, trim.getOperation()}; + const mlir::Location &loc = trim->getLoc(); + + llvm::SmallVector<IntrinsicArgument, 1> inArgs; + mlir::Value chr = trim.getChr(); + inArgs.push_back({chr, chr.getType()}); + + auto *argLowering = fir::getIntrinsicArgumentLowering("trim"); + llvm::SmallVector<fir::ExtendedValue, 1> args = + lowerArguments(trim, inArgs, rewriter, argLowering); + + mlir::Type resultType = hlfir::getFortranElementType(trim.getType()); + + auto [resultExv, mustBeFreed] = + fir::genIntrinsicCall(builder, loc, "trim", resultType, args); + + processReturnValue(trim, resultExv, mustBeFreed, builder, rewriter); + return mlir::success(); + } +}; + class LowerHLFIRIntrinsics : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> { public: @@ -547,7 +625,9 @@ public: AnyOpConversion, SumOpConversion, ProductOpConversion, TransposeOpConversion, CountOpConversion, DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion, - MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context); + MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>, + ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion, + CmpCharOpConversion, CharTrimOpConversion>(context); // While conceptually this pass is performing dialect conversion, we use // pattern rewrites here instead of dialect conversion because this pass diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp index 8e25298..32998ab 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp @@ -96,7 +96,7 @@ struct MaskedArrayExpr { /// hlfir.elemental_addr that form the elemental tree producing /// the expression value. hlfir.elemental that produce values /// used inside transformational operations are not part of this set. - llvm::SmallSet<mlir::Operation *, 4> elementalParts{}; + llvm::SmallPtrSet<mlir::Operation *, 4> elementalParts{}; /// Was generateNoneElementalPart called? bool noneElementalPartWasGenerated = false; /// Is this expression the mask expression of the outer where statement? diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp index 722cd8a..a48b7ba 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp @@ -137,7 +137,7 @@ private: // Schedule being built. hlfir::Schedule schedule; /// Leaf regions that have been saved so far. - llvm::SmallSet<mlir::Region *, 16> savedRegions; + llvm::SmallPtrSet<mlir::Region *, 16> savedRegions; /// Is schedule.back() a schedule that is only saving region with read /// effects? 
bool currentRunIsReadOnly = false; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index b27c3a8..d8e36ea 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -10,6 +10,7 @@ // into the calling function. //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" @@ -1269,64 +1270,91 @@ public: } }; -class CShiftConversion : public mlir::OpRewritePattern<hlfir::CShiftOp> { +template <typename Op> +class ArrayShiftConversion : public mlir::OpRewritePattern<Op> { public: - using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern; + // The implementation below only supports CShiftOp and EOShiftOp. + static_assert(std::is_same_v<Op, hlfir::CShiftOp> || + std::is_same_v<Op, hlfir::EOShiftOp>); + + using mlir::OpRewritePattern<Op>::OpRewritePattern; llvm::LogicalResult - matchAndRewrite(hlfir::CShiftOp cshift, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { - hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType()); + hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(op.getType()); assert(expr && - "expected an expression type for the result of hlfir.cshift"); + "expected an expression type for the result of the array shift"); unsigned arrayRank = expr.getRank(); - // When it is a 1D CSHIFT, we may assume that the DIM argument + // When it is a 1D CSHIFT/EOSHIFT, we may assume that the DIM argument // (whether it is present or absent) is equal to 1, otherwise, // the program is illegal. int64_t dimVal = 1; if (arrayRank != 1) - if (mlir::Value dim = cshift.getDim()) { + if (mlir::Value dim = op.getDim()) { auto constDim = fir::getIntIfConstant(dim); if (!constDim) - return rewriter.notifyMatchFailure(cshift, - "Nonconstant DIM for CSHIFT"); + return rewriter.notifyMatchFailure( + op, "Nonconstant DIM for CSHIFT/EOSHIFT"); dimVal = *constDim; } if (dimVal <= 0 || dimVal > arrayRank) - return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT"); + return rewriter.notifyMatchFailure(op, "Invalid DIM for CSHIFT/EOSHIFT"); + + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) { + // TODO: the EOSHIFT inlining code is not ready to produce + // fir.if selecting between ARRAY and BOUNDARY (or the default + // boundary value), when they are expressions of type CHARACTER. + // This needs more work. + if (mlir::isa<fir::CharacterType>(expr.getEleTy())) { + if (!hlfir::Entity{op.getArray()}.isVariable()) + return rewriter.notifyMatchFailure( + op, "EOSHIFT with ARRAY being CHARACTER expression"); + if (op.getBoundary() && !hlfir::Entity{op.getBoundary()}.isVariable()) + return rewriter.notifyMatchFailure( + op, "EOSHIFT with BOUNDARY being CHARACTER expression"); + } + // TODO: selecting between ARRAY and BOUNDARY values with derived types + // needs more work. + if (fir::isa_derived(expr.getEleTy())) + return rewriter.notifyMatchFailure(op, "EOSHIFT of derived type"); + } // When DIM==1 and the contiguity of the input array is not statically // known, try to exploit the fact that the leading dimension might be // contiguous.
We can do this now using hlfir.eval_in_mem with // a dynamic check for the leading dimension contiguity. - // Otherwise, convert hlfir.cshift to hlfir.elemental. + // Otherwise, convert hlfir.cshift/eoshift to hlfir.elemental. // // Note that the hlfir.elemental can be inlined into other hlfir.elemental, // while hlfir.eval_in_mem prevents this, and we will end up creating // a temporary array for the result. We may need to come up with // a more sophisticated logic for picking the most efficient // representation. - hlfir::Entity array = hlfir::Entity{cshift.getArray()}; + hlfir::Entity array = hlfir::Entity{op.getArray()}; mlir::Type elementType = array.getFortranElementType(); if (dimVal == 1 && fir::isa_trivial(elementType) && - // genInMemCShift() only works for variables currently. + // genInMemArrayShift() only works for variables currently. array.isVariable()) - rewriter.replaceOp(cshift, genInMemCShift(rewriter, cshift, dimVal)); + rewriter.replaceOp(op, genInMemArrayShift(rewriter, op, dimVal)); else - rewriter.replaceOp(cshift, genElementalCShift(rewriter, cshift, dimVal)); + rewriter.replaceOp(op, genElementalArrayShift(rewriter, op, dimVal)); return mlir::success(); } private: - /// Generate MODULO(\p shiftVal, \p extent). + /// For CSHIFT, generate MODULO(\p shiftVal, \p extent). + /// For EOSHIFT, return \p shiftVal casted to \p calcType. static mlir::Value normalizeShiftValue(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shiftVal, mlir::Value extent, mlir::Type calcType) { shiftVal = builder.createConvert(loc, calcType, shiftVal); + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) + return shiftVal; + extent = builder.createConvert(loc, calcType, extent); // Make sure that we do not divide by zero. When the dimension // has zero size, turn the extent into 1. Note that the computed @@ -1342,24 +1370,227 @@ private: return builder.createConvert(loc, calcType, shiftVal); } - /// Convert \p cshift into an hlfir.elemental using + /// The indices computations for the array shifts are done using I64 type. + /// For CSHIFT, all computations do not overflow signed and unsigned I64. + /// For EOSHIFT, some computations may involve negative shift values, + /// so using no-unsigned wrap flag would be incorrect. + static void setArithOverflowFlags(Op op, fir::FirOpBuilder &builder) { + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) + builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw); + else + builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw | + mlir::arith::IntegerOverflowFlags::nuw); + } + + /// Return the element type of the EOSHIFT boundary that may be omitted + /// statically or dynamically. This element type might be used + /// to generate MLIR where we have to select between the default + /// boundary value and the dynamically absent/present boundary value. + /// If the boundary has a type not defined in Table 16.4 in 16.9.77 + /// of F2023, then the return value is nullptr. + static mlir::Type getDefaultBoundaryValueType(mlir::Type elementType) { + // To be able to generate a "select" between the default boundary value + // and the dynamic boundary value, use BoxCharType for the CHARACTER + // cases. This might be a little bit inefficient, because we may + // create unnecessary tuples, but it simplifies the inlining code. 
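For reference, a standalone sketch (not part of the patch) of the CSHIFT shift normalization computed by normalizeShiftValue above, assuming the extent has already been clamped to at least 1 to avoid division by zero; the result always satisfies 0 <= sh < extent. EOSHIFT skips this normalization and keeps the raw, possibly negative, shift.

#include <cassert>
#include <cstdint>

// Fortran MODULO(shift, extent) for a positive extent.
std::int64_t normalizeShift(std::int64_t shift, std::int64_t extent) {
  return ((shift % extent) + extent) % extent;
}

int main() {
  assert(normalizeShift(7, 5) == 2);
  assert(normalizeShift(-1, 5) == 4); // negative shifts wrap around
  return 0;
}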
+ if (auto charTy = mlir::dyn_cast<fir::CharacterType>(elementType)) + return fir::BoxCharType::get(charTy.getContext(), charTy.getFKind()); + + if (mlir::isa<fir::LogicalType>(elementType) || + fir::isa_integer(elementType) || fir::isa_real(elementType) || + fir::isa_complex(elementType)) + return elementType; + + return nullptr; + } + + /// Generate the default boundary value as defined in Table 16.4 in 16.9.77 + /// of F2023. + static mlir::Value genDefaultBoundary(mlir::Location loc, + fir::FirOpBuilder &builder, + mlir::Type elementType) { + assert(getDefaultBoundaryValueType(elementType) && + "default boundary value cannot be computed for the given type"); + if (mlir::isa<fir::CharacterType>(elementType)) { + // Create an empty CHARACTER of the same kind. The assignment + // of this empty CHARACTER into the result will add the padding + // if necessary. + fir::factory::CharacterExprHelper charHelper{builder, loc}; + mlir::Value zeroLen = builder.createIntegerConstant( + loc, builder.getCharacterLengthType(), 0); + fir::CharBoxValue emptyCharTemp = + charHelper.createCharacterTemp(elementType, zeroLen); + return charHelper.createEmbox(emptyCharTemp); + } + + return fir::factory::createZeroValue(builder, loc, elementType); + } + + /// \p entity represents the boundary operand of hlfir.eoshift. + /// This method generates a scalar boundary value fetched + /// from the boundary entity using \p indices (which may be empty, + /// if the boundary operand is scalar). + static mlir::Value loadEoshiftVal(mlir::Location loc, + fir::FirOpBuilder &builder, + hlfir::Entity entity, + mlir::ValueRange indices = {}) { + hlfir::Entity boundaryVal = + hlfir::loadElementAt(loc, builder, entity, indices); + + mlir::Type boundaryValTy = + getDefaultBoundaryValueType(entity.getFortranElementType()); + + // Boxed !fir.char<KIND,LEN> with known LEN are loaded + // as raw references to !fir.char<KIND,LEN>. + // We need to wrap them into the !fir.boxchar. + if (boundaryVal.isVariable() && boundaryValTy && + mlir::isa<fir::BoxCharType>(boundaryValTy)) + return hlfir::genVariableBoxChar(loc, builder, boundaryVal); + return boundaryVal; + } + + /// This method generates a scalar boundary value for the given hlfir.eoshift + /// \p op that can be used to initialize cells of the result + /// if the scalar/array boundary operand is statically or dynamically + /// absent. The first result is the scalar boundary value. The second result + /// is a dynamic predicate indicating whether the scalar boundary value + /// should actually be used. + [[maybe_unused]] static std::pair<mlir::Value, mlir::Value> + genScalarBoundaryForEOShift(mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::EOShiftOp op) { + hlfir::Entity array{op.getArray()}; + mlir::Type elementType = array.getFortranElementType(); + + if (!op.getBoundary()) { + // Boundary operand is statically absent. + mlir::Value defaultVal = genDefaultBoundary(loc, builder, elementType); + mlir::Value boundaryIsScalarPred = builder.createBool(loc, true); + return {defaultVal, boundaryIsScalarPred}; + } + + hlfir::Entity boundary{op.getBoundary()}; + mlir::Type boundaryValTy = getDefaultBoundaryValueType(elementType); + + if (boundary.isScalar()) { + if (!boundaryValTy || !boundary.mayBeOptional()) { + // The boundary must be present. 
+ mlir::Value boundaryVal = loadEoshiftVal(loc, builder, boundary); + mlir::Value boundaryIsScalarPred = builder.createBool(loc, true); + return {boundaryVal, boundaryIsScalarPred}; + } + + // Boundary is a scalar that may be dynamically absent. + // If boundary is not present dynamically, we must use the default + // value. + assert(mlir::isa<fir::BaseBoxType>(boundary.getType())); + mlir::Value isPresentPred = + fir::IsPresentOp::create(builder, loc, builder.getI1Type(), boundary); + mlir::Value boundaryVal = + builder + .genIfOp(loc, {boundaryValTy}, isPresentPred, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value boundaryVal = + loadEoshiftVal(loc, builder, boundary); + fir::ResultOp::create(builder, loc, boundaryVal); + }) + .genElse([&]() { + mlir::Value defaultVal = + genDefaultBoundary(loc, builder, elementType); + fir::ResultOp::create(builder, loc, defaultVal); + }) + .getResults()[0]; + mlir::Value boundaryIsScalarPred = builder.createBool(loc, true); + return {boundaryVal, boundaryIsScalarPred}; + } + if (!boundaryValTy || !boundary.mayBeOptional()) { + // The boundary must be present + mlir::Value boundaryIsScalarPred = builder.createBool(loc, false); + return {nullptr, boundaryIsScalarPred}; + } + + // Boundary is an array that may be dynamically absent. + mlir::Value defaultVal = genDefaultBoundary(loc, builder, elementType); + mlir::Value isPresentPred = + fir::IsPresentOp::create(builder, loc, builder.getI1Type(), boundary); + // If the array is present, then boundaryIsScalarPred must be equal + // to false, otherwise, it should be true. + mlir::Value trueVal = builder.createBool(loc, true); + mlir::Value falseVal = builder.createBool(loc, false); + mlir::Value boundaryIsScalarPred = mlir::arith::SelectOp::create( + builder, loc, isPresentPred, falseVal, trueVal); + return {defaultVal, boundaryIsScalarPred}; + } + + /// Generate code that produces the final boundary value to be assigned + /// to the result of hlfir.eoshift \p op. \p precomputedScalarBoundary + /// specifies the scalar boundary value pre-computed before the elemental + /// or the assignment loop. If it is nullptr, then the boundary operand + /// of \p op must be a present array. \p boundaryIsScalarPred is a dynamic + /// predicate that is true, when the pre-computed scalar value must be used. + /// \p oneBasedIndices specify the indices to address into the boundary + /// array - they may be empty, if the boundary is scalar. + [[maybe_unused]] static mlir::Value selectBoundaryValue( + mlir::Location loc, fir::FirOpBuilder &builder, hlfir::EOShiftOp op, + mlir::Value precomputedScalarBoundary, mlir::Value boundaryIsScalarPred, + mlir::ValueRange oneBasedIndices) { + // Boundary is statically absent: a default value has been precomputed. + if (!op.getBoundary()) + return precomputedScalarBoundary; + + // Boundary is statically present and is a scalar: boundary does not depend + // upon the indices and so it has been precomputed. + hlfir::Entity boundary{op.getBoundary()}; + if (boundary.isScalar()) + return precomputedScalarBoundary; + + // Boundary is statically present and is an array: if the scalar + // boundary has not been precomputed, this means that the data type + // of the shifted values does not provide a way to compute + // the default boundary value, so the array boundary must be dynamically + // present, and we can load the boundary values from it. 
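For reference, a simplified standalone model (not part of the patch) of the decision genScalarBoundaryForEOShift makes; it collapses the "must be present" and "no default boundary type" distinctions into plain booleans, and all names below are illustrative only.

#include <cassert>
#include <optional>

struct ScalarBoundary {
  std::optional<int> value; // pre-computed scalar (default or loaded)
  bool useScalar;           // true -> use 'value'; false -> index the array
};

ScalarBoundary pickScalarBoundary(bool hasBoundary, bool isScalar,
                                  bool dynamicallyPresent, int defaultValue,
                                  int loadedValue) {
  if (!hasBoundary)
    return {defaultValue, true}; // statically absent: default value
  if (isScalar)
    return {dynamicallyPresent ? loadedValue : defaultValue, true};
  if (dynamicallyPresent)
    return {std::nullopt, false}; // present array: index it per element
  return {defaultValue, true};    // absent optional array: default value
}

int main() {
  assert(pickScalarBoundary(false, false, false, 0, 7).useScalar);
  assert(!pickScalarBoundary(true, false, true, 0, 7).useScalar);
  return 0;
}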
+ bool mustBePresent = !precomputedScalarBoundary; + if (mustBePresent) + return loadEoshiftVal(loc, builder, boundary, oneBasedIndices); + + // The array boundary may be dynamically absent. + // In this case, precomputedScalarBoundary is a pre-computed scalar + // boundary value that has to be used if boundaryIsScalarPred + // is true, otherwise, the boundary value has to be loaded + // from the boundary array. + mlir::Type boundaryValTy = precomputedScalarBoundary.getType(); + mlir::Value newBoundaryVal = + builder + .genIfOp(loc, {boundaryValTy}, boundaryIsScalarPred, + /*withElseRegion=*/true) + .genThen([&]() { + fir::ResultOp::create(builder, loc, precomputedScalarBoundary); + }) + .genElse([&]() { + mlir::Value elem = + loadEoshiftVal(loc, builder, boundary, oneBasedIndices); + fir::ResultOp::create(builder, loc, elem); + }) + .getResults()[0]; + return newBoundaryVal; + } + + /// Convert \p op into an hlfir.elemental using /// the pre-computed constant \p dimVal. - static mlir::Operation *genElementalCShift(mlir::PatternRewriter &rewriter, - hlfir::CShiftOp cshift, - int64_t dimVal) { + static mlir::Operation * + genElementalArrayShift(mlir::PatternRewriter &rewriter, Op op, + int64_t dimVal) { using Fortran::common::maxRank; - hlfir::Entity shift = hlfir::Entity{cshift.getShift()}; - hlfir::Entity array = hlfir::Entity{cshift.getArray()}; + hlfir::Entity shift = hlfir::Entity{op.getShift()}; + hlfir::Entity array = hlfir::Entity{op.getArray()}; - mlir::Location loc = cshift.getLoc(); - fir::FirOpBuilder builder{rewriter, cshift.getOperation()}; + mlir::Location loc = op.getLoc(); + fir::FirOpBuilder builder{rewriter, op.getOperation()}; // The new index computation involves MODULO, which is not implemented // for IndexType, so use I64 instead. mlir::Type calcType = builder.getI64Type(); - // All the indices arithmetic used below does not overflow - // signed and unsigned I64. - builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw | - mlir::arith::IntegerOverflowFlags::nuw); + // Set the indices arithmetic overflow flags. + setArithOverflowFlags(op, builder); mlir::Value arrayShape = hlfir::genShape(loc, builder, array); llvm::SmallVector<mlir::Value, maxRank> arrayExtents = @@ -1374,6 +1605,17 @@ private: shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType); } + // The boundary operand of hlfir.eoshift may be statically or + // dynamically absent. + // In both cases, it is assumed to be a scalar with the value + // corresponding to the array element type. + // boundaryIsScalarPred is a dynamic predicate that identifies + // these cases. If boundaryIsScalarPred is dynamically false, + // then the boundary operand must be a present array.
+ mlir::Value boundaryVal, boundaryIsScalarPred; + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) + std::tie(boundaryVal, boundaryIsScalarPred) = + genScalarBoundaryForEOShift(loc, builder, op); auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange inputIndices) -> hlfir::Entity { @@ -1394,34 +1636,84 @@ private: shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType); } + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) { + llvm::SmallVector<mlir::Value, maxRank> boundaryIndices{indices}; + boundaryIndices.erase(boundaryIndices.begin() + dimVal - 1); + boundaryVal = + selectBoundaryValue(loc, builder, op, boundaryVal, + boundaryIsScalarPred, boundaryIndices); + } - // Element i of the result (1-based) is element - // 'MODULO(i + SH - 1, SIZE(ARRAY,DIM)) + 1' (1-based) of the original - // ARRAY (or its section, when ARRAY is not a vector). - - // Compute the index into the original array using the normalized - // shift value, which satisfies (SH >= 0 && SH < SIZE(ARRAY,DIM)): - // newIndex = - // i + ((i <= SIZE(ARRAY,DIM) - SH) ? SH : SH - SIZE(ARRAY,DIM)) - // - // Such index computation allows for further loop vectorization - // in LLVM. - mlir::Value wrapBound = - mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal); - mlir::Value adjustedShiftVal = - mlir::arith::SubIOp::create(builder, loc, shiftVal, shiftDimExtent); - mlir::Value index = - builder.createConvert(loc, calcType, inputIndices[dimVal - 1]); - mlir::Value wrapCheck = mlir::arith::CmpIOp::create( - builder, loc, mlir::arith::CmpIPredicate::sle, index, wrapBound); - mlir::Value actualShift = mlir::arith::SelectOp::create( - builder, loc, wrapCheck, shiftVal, adjustedShiftVal); - mlir::Value newIndex = - mlir::arith::AddIOp::create(builder, loc, index, actualShift); - newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex); - indices[dimVal - 1] = newIndex; - hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices); - return hlfir::loadTrivialScalar(loc, builder, element); + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) { + // EOSHIFT: + // Element i of the result (1-based) is the element of the original + // array (or its section, when ARRAY is not a vector) with index + // (i + SH), if (1 <= i + SH <= SIZE(ARRAY,DIM)), otherwise + // it is the BOUNDARY value. 
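For reference, a scalar standalone sketch (not generated IR) of the EOSHIFT rule described in the comment above, for a rank-1 array with 1-based index i and raw shift sh.

#include <cassert>
#include <vector>

int eoshiftElement(const std::vector<int> &array, long i, long sh,
                   int boundary) {
  long n = static_cast<long>(array.size());
  long srcIndex = i + sh;            // indexPlusShift in the lowering above
  if (srcIndex >= 1 && srcIndex <= n)
    return array[srcIndex - 1];      // in bounds: load from ARRAY
  return boundary;                   // out of bounds: BOUNDARY value
}

int main() {
  std::vector<int> a{1, 2, 3, 4, 5};
  assert(eoshiftElement(a, 1, 2, 0) == 3); // EOSHIFT(a, 2) element 1
  assert(eoshiftElement(a, 5, 2, 0) == 0); // past the end -> boundary
  return 0;
}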
+ mlir::Value index = + builder.createConvert(loc, calcType, inputIndices[dimVal - 1]); + mlir::arith::IntegerOverflowFlags savedFlags = + builder.getIntegerOverflowFlags(); + builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw); + mlir::Value indexPlusShift = + mlir::arith::AddIOp::create(builder, loc, index, shiftVal); + builder.setIntegerOverflowFlags(savedFlags); + mlir::Value one = builder.createIntegerConstant(loc, calcType, 1); + mlir::Value cmp1 = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sge, indexPlusShift, one); + mlir::Value cmp2 = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sle, indexPlusShift, + shiftDimExtent); + mlir::Value loadFromArray = + mlir::arith::AndIOp::create(builder, loc, cmp1, cmp2); + mlir::Type boundaryValTy = boundaryVal.getType(); + mlir::Value result = + builder + .genIfOp(loc, {boundaryValTy}, loadFromArray, + /*withElseRegion=*/true) + .genThen([&]() { + indices[dimVal - 1] = builder.createConvert( + loc, builder.getIndexType(), indexPlusShift); + ; + mlir::Value elem = + loadEoshiftVal(loc, builder, array, indices); + fir::ResultOp::create(builder, loc, elem); + }) + .genElse( + [&]() { fir::ResultOp::create(builder, loc, boundaryVal); }) + .getResults()[0]; + return hlfir::Entity{result}; + } else { + // CSHIFT: + // Element i of the result (1-based) is element + // 'MODULO(i + SH - 1, SIZE(ARRAY,DIM)) + 1' (1-based) of the original + // ARRAY (or its section, when ARRAY is not a vector). + + // Compute the index into the original array using the normalized + // shift value, which satisfies (SH >= 0 && SH < SIZE(ARRAY,DIM)): + // newIndex = + // i + ((i <= SIZE(ARRAY,DIM) - SH) ? SH : SH - SIZE(ARRAY,DIM)) + // + // Such index computation allows for further loop vectorization + // in LLVM. + mlir::Value wrapBound = + mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal); + mlir::Value adjustedShiftVal = + mlir::arith::SubIOp::create(builder, loc, shiftVal, shiftDimExtent); + mlir::Value index = + builder.createConvert(loc, calcType, inputIndices[dimVal - 1]); + mlir::Value wrapCheck = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sle, index, wrapBound); + mlir::Value actualShift = mlir::arith::SelectOp::create( + builder, loc, wrapCheck, shiftVal, adjustedShiftVal); + mlir::Value newIndex = + mlir::arith::AddIOp::create(builder, loc, index, actualShift); + newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex); + indices[dimVal - 1] = newIndex; + hlfir::Entity element = + hlfir::getElementAt(loc, builder, array, indices); + return hlfir::loadTrivialScalar(loc, builder, element); + } }; mlir::Type elementType = array.getFortranElementType(); @@ -1429,19 +1721,42 @@ private: loc, builder, elementType, arrayShape, typeParams, genKernel, /*isUnordered=*/true, array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr, - cshift.getResult().getType()); + op.getResult().getType()); return elementalOp.getOperation(); } - /// Convert \p cshift into an hlfir.eval_in_mem using the pre-computed + /// Convert \p op into an hlfir.eval_in_mem using the pre-computed /// constant \p dimVal. 
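For reference, a standalone check (not part of the patch) that the branch-free index used in the CSHIFT kernel above, i + ((i <= n - sh) ? sh : sh - n) with 0 <= sh < n, agrees with the MODULO(i + sh - 1, n) + 1 formulation quoted in the comment.

#include <cassert>

int main() {
  for (long n = 1; n <= 8; ++n)
    for (long sh = 0; sh < n; ++sh)
      for (long i = 1; i <= n; ++i) {
        long wrapped = i + (i <= n - sh ? sh : sh - n);
        long modulo = (i + sh - 1) % n + 1; // operands are non-negative here
        assert(wrapped == modulo);
      }
  return 0;
}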
- /// The converted code looks like this:
- /// do i=1,SH
- /// result(i + (SIZE(ARRAY,DIM) - SH)) = array(i)
+ /// The converted code for CSHIFT looks like this:
+ /// DEST_OFFSET = SIZE(ARRAY,DIM) - SH
+ /// COPY_END1 = SH
+ /// do i=1,COPY_END1
+ /// result(i + DEST_OFFSET) = array(i)
/// end
- /// do i=1,SIZE(ARRAY,DIM) - SH
- /// result(i) = array(i + SH)
+ /// SOURCE_OFFSET = SH
+ /// COPY_END2 = SIZE(ARRAY,DIM) - SH
+ /// do i=1,COPY_END2
+ /// result(i) = array(i + SOURCE_OFFSET)
/// end
+ /// Where SH is the normalized shift value, which satisfies
+ /// (SH >= 0 && SH < SIZE(ARRAY,DIM)).
+ ///
+ /// The converted code for EOSHIFT looks like this:
+ /// EXTENT = SIZE(ARRAY,DIM)
+ /// DEST_OFFSET = SH < 0 ? -SH : 0
+ /// SOURCE_OFFSET = SH < 0 ? 0 : SH
+ /// COPY_END = SH < 0 ?
+ /// (-EXTENT > SH ? 0 : EXTENT + SH) :
+ /// (EXTENT < SH ? 0 : EXTENT - SH)
+ /// do i=1,COPY_END
+ /// result(i + DEST_OFFSET) = array(i + SOURCE_OFFSET)
+ /// end
+ /// INIT_END = EXTENT - COPY_END
+ /// INIT_OFFSET = SH < 0 ? 0 : COPY_END
+ /// do i=1,INIT_END
+ /// result(i + INIT_OFFSET) = BOUNDARY
+ /// end
+ /// Where SH is the original shift value.
///
/// When \p dimVal is 1, we generate the same code twice
/// under a dynamic check for the contiguity of the leading
@@ -1450,24 +1765,21 @@ private:
/// as a contiguous slice of the original array.
/// This allows recognizing the above two loops as memcpy
/// loop idioms in LLVM.
- static mlir::Operation *genInMemCShift(mlir::PatternRewriter &rewriter,
- hlfir::CShiftOp cshift,
- int64_t dimVal) {
+ static mlir::Operation *genInMemArrayShift(mlir::PatternRewriter &rewriter,
+ Op op, int64_t dimVal) {
using Fortran::common::maxRank;
- hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
- hlfir::Entity array = hlfir::Entity{cshift.getArray()};
+ hlfir::Entity shift = hlfir::Entity{op.getShift()};
+ hlfir::Entity array = hlfir::Entity{op.getArray()};
assert(array.isVariable() && "array must be a variable");
assert(!array.isPolymorphic() &&
- "genInMemCShift does not support polymorphic types");
- mlir::Location loc = cshift.getLoc();
- fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
+ "genInMemArrayShift does not support polymorphic types");
+ mlir::Location loc = op.getLoc();
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
// The new index computation involves MODULO, which is not implemented
// for IndexType, so use I64 instead.
mlir::Type calcType = builder.getI64Type();
- // All the indices arithmetic used below does not overflow
- // signed and unsigned I64.
- builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
- mlir::arith::IntegerOverflowFlags::nuw);
+ // Set the indices arithmetic overflow flags.
+ setArithOverflowFlags(op, builder);
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value, maxRank> arrayExtents =
@@ -1482,10 +1794,20 @@ private:
shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent,
calcType);
}
+ // The boundary operand of hlfir.eoshift may be statically or
+ // dynamically absent.
+ // In both cases, it is assumed to be a scalar with the value
+ // corresponding to the array element type.
+ // boundaryIsScalarPred is a dynamic predicate that identifies
+ // these cases. If boundaryIsScalarPred is dynamically false,
+ // then the boundary operand must be a present array.
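The EOSHIFT bound computations in the pseudo-code above can be cross-checked with a small scalar sketch (illustration only; names are hypothetical and not part of the patch). For example, EXTENT=5 and SH=-2 yield DEST_OFFSET=2, SOURCE_OFFSET=0, COPY_END=3, INIT_END=2, INIT_OFFSET=0, i.e. result(3:5)=array(1:3) and result(1:2)=BOUNDARY:

    #include <cstdint>

    // Scalar model of the copy/init bounds used by the in-memory EOSHIFT
    // lowering sketched in the comment above.
    struct EoshiftBounds {
      std::int64_t destOffset, sourceOffset, copyEnd, initOffset, initEnd;
    };

    EoshiftBounds eoshiftBounds(std::int64_t extent, std::int64_t sh) {
      EoshiftBounds b;
      b.destOffset = sh < 0 ? -sh : 0;
      b.sourceOffset = sh < 0 ? 0 : sh;
      b.copyEnd = sh < 0 ? (-extent > sh ? 0 : extent + sh)
                         : (extent < sh ? 0 : extent - sh);
      b.initEnd = extent - b.copyEnd;
      b.initOffset = sh < 0 ? 0 : b.copyEnd;
      return b;
    }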
+ mlir::Value boundaryVal, boundaryIsScalarPred;
+ if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
+ std::tie(boundaryVal, boundaryIsScalarPred) =
+ genScalarBoundaryForEOShift(loc, builder, op);
hlfir::EvaluateInMemoryOp evalOp = hlfir::EvaluateInMemoryOp::create(
- builder, loc, mlir::cast<hlfir::ExprType>(cshift.getType()),
- arrayShape);
+ builder, loc, mlir::cast<hlfir::ExprType>(op.getType()), arrayShape);
builder.setInsertionPointToStart(&evalOp.getBody().front());
mlir::Value resultArray = evalOp.getMemory();
@@ -1499,11 +1821,12 @@ private:
// (if any). If exposeContiguity is true, the array's section
// array(s(1), ..., s(dim-1), :, s(dim+1), ..., s(n)) is represented
// as a contiguous 1D array.
- // shiftVal is the normalized shift value that satisfies (SH >= 0 && SH <
- // SIZE(ARRAY,DIM)).
+ // For CSHIFT, shiftVal is the normalized shift value that satisfies
+ // (SH >= 0 && SH < SIZE(ARRAY,DIM)).
//
auto genDimensionShift = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::Value shiftVal, bool exposeContiguity,
+ mlir::Value shiftVal, mlir::Value boundary,
+ bool exposeContiguity,
mlir::ValueRange oneBasedIndices)
-> llvm::SmallVector<mlir::Value, 0> {
// Create a vector of indices (s(1), ..., s(dim-1), nullptr, s(dim+1),
@@ -1536,63 +1859,143 @@ private:
srcIndices.resize(1);
}
- // Copy first portion of the array:
- // do i=1,SH
- // result(i + (SIZE(ARRAY,DIM) - SH)) = array(i)
- // end
- auto genAssign1 = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange index,
- mlir::ValueRange reductionArgs)
+ // The genCopy lambda generates the body of a generic copy loop.
+ // do i=1,COPY_END
+ // result(i + DEST_OFFSET) = array(i + SOURCE_OFFSET)
+ // end
+ //
+ // It is parameterized by DEST_OFFSET and SOURCE_OFFSET.
+ mlir::Value dstOffset, srcOffset;
+ auto genCopy = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index, mlir::ValueRange reductionArgs)
-> llvm::SmallVector<mlir::Value, 0> {
assert(index.size() == 1 && "expected single loop");
mlir::Value srcIndex = builder.createConvert(loc, calcType, index[0]);
+ mlir::Value dstIndex = srcIndex;
+ if (srcOffset)
+ srcIndex =
+ mlir::arith::AddIOp::create(builder, loc, srcIndex, srcOffset);
srcIndices[dimVal - 1] = srcIndex;
hlfir::Entity srcElementValue =
hlfir::loadElementAt(loc, builder, srcArray, srcIndices);
- mlir::Value dstIndex = mlir::arith::AddIOp::create(
- builder, loc, srcIndex,
- mlir::arith::SubIOp::create(builder, loc, shiftDimExtent,
- shiftVal));
+ if (dstOffset)
+ dstIndex =
+ mlir::arith::AddIOp::create(builder, loc, dstIndex, dstOffset);
dstIndices[dimVal - 1] = dstIndex;
hlfir::Entity dstElement = hlfir::getElementAt(
loc, builder, hlfir::Entity{resultArray}, dstIndices);
hlfir::AssignOp::create(builder, loc, srcElementValue, dstElement);
+ // Reset the external parameters' values to make sure
+ // they are properly updated between the lambda calls.
+ // WARNING: if genLoopNestWithReductions() calls the lambda
+ // multiple times, this is going to be a problem.
+ dstOffset = nullptr;
+ srcOffset = nullptr;
return {};
};
- // Generate the first loop.
- hlfir::genLoopNestWithReductions(loc, builder, {shiftVal}, - /*reductionInits=*/{}, genAssign1, - /*isUnordered=*/true); - - // Copy second portion of the array: - // do i=1,SIZE(ARRAY,DIM)-SH - // result(i) = array(i + SH) - // end - auto genAssign2 = [&](mlir::Location loc, fir::FirOpBuilder &builder, - mlir::ValueRange index, - mlir::ValueRange reductionArgs) - -> llvm::SmallVector<mlir::Value, 0> { - assert(index.size() == 1 && "expected single loop"); - mlir::Value dstIndex = builder.createConvert(loc, calcType, index[0]); - mlir::Value srcIndex = - mlir::arith::AddIOp::create(builder, loc, dstIndex, shiftVal); - srcIndices[dimVal - 1] = srcIndex; - hlfir::Entity srcElementValue = - hlfir::loadElementAt(loc, builder, srcArray, srcIndices); - dstIndices[dimVal - 1] = dstIndex; - hlfir::Entity dstElement = hlfir::getElementAt( - loc, builder, hlfir::Entity{resultArray}, dstIndices); - hlfir::AssignOp::create(builder, loc, srcElementValue, dstElement); - return {}; - }; - - // Generate the second loop. - mlir::Value bound = - mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal); - hlfir::genLoopNestWithReductions(loc, builder, {bound}, - /*reductionInits=*/{}, genAssign2, - /*isUnordered=*/true); + if constexpr (std::is_same_v<Op, hlfir::CShiftOp>) { + // Copy first portion of the array: + // DEST_OFFSET = SIZE(ARRAY,DIM) - SH + // COPY_END1 = SH + // do i=1,COPY_END1 + // result(i + DEST_OFFSET) = array(i) + // end + dstOffset = + mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal); + srcOffset = nullptr; + hlfir::genLoopNestWithReductions(loc, builder, {shiftVal}, + /*reductionInits=*/{}, genCopy, + /*isUnordered=*/true); + + // Copy second portion of the array: + // SOURCE_OFFSET = SH + // COPY_END2 = SIZE(ARRAY,DIM) - SH + // do i=1,COPY_END2 + // result(i) = array(i + SOURCE_OFFSET) + // end + mlir::Value bound = + mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal); + dstOffset = nullptr; + srcOffset = shiftVal; + hlfir::genLoopNestWithReductions(loc, builder, {bound}, + /*reductionInits=*/{}, genCopy, + /*isUnordered=*/true); + } else { + // Do the copy: + // EXTENT = SIZE(ARRAY,DIM) + // DEST_OFFSET = SH < 0 ? -SH : 0 + // SOURCE_OFFSET = SH < 0 ? 0 : SH + // COPY_END = SH < 0 ? + // (-EXTENT > SH ? 0 : EXTENT + SH) : + // (EXTENT < SH ? 
0 : EXTENT - SH) + // do i=1,COPY_END + // result(i + DEST_OFFSET) = array(i + SOURCE_OFFSET) + // end + mlir::arith::IntegerOverflowFlags savedFlags = + builder.getIntegerOverflowFlags(); + builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw); + + mlir::Value zero = builder.createIntegerConstant(loc, calcType, 0); + mlir::Value isNegativeShift = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::slt, shiftVal, zero); + mlir::Value shiftNeg = + mlir::arith::SubIOp::create(builder, loc, zero, shiftVal); + dstOffset = mlir::arith::SelectOp::create(builder, loc, isNegativeShift, + shiftNeg, zero); + srcOffset = mlir::arith::SelectOp::create(builder, loc, isNegativeShift, + zero, shiftVal); + mlir::Value extentNeg = + mlir::arith::SubIOp::create(builder, loc, zero, shiftDimExtent); + mlir::Value extentPlusShift = + mlir::arith::AddIOp::create(builder, loc, shiftDimExtent, shiftVal); + mlir::Value extentNegShiftCmp = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sgt, extentNeg, shiftVal); + mlir::Value negativeShiftBound = mlir::arith::SelectOp::create( + builder, loc, extentNegShiftCmp, zero, extentPlusShift); + mlir::Value extentMinusShift = + mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, shiftVal); + mlir::Value extentShiftCmp = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::slt, shiftDimExtent, + shiftVal); + mlir::Value positiveShiftBound = mlir::arith::SelectOp::create( + builder, loc, extentShiftCmp, zero, extentMinusShift); + mlir::Value copyEnd = mlir::arith::SelectOp::create( + builder, loc, isNegativeShift, negativeShiftBound, + positiveShiftBound); + hlfir::genLoopNestWithReductions(loc, builder, {copyEnd}, + /*reductionInits=*/{}, genCopy, + /*isUnordered=*/true); + + // Do the init: + // INIT_END = EXTENT - COPY_END + // INIT_OFFSET = SH < 0 ? 
0 : COPY_END + // do i=1,INIT_END + // result(i + INIT_OFFSET) = BOUNDARY + // end + assert(boundary && "boundary cannot be null"); + mlir::Value initEnd = + mlir::arith::SubIOp::create(builder, loc, shiftDimExtent, copyEnd); + mlir::Value initOffset = mlir::arith::SelectOp::create( + builder, loc, isNegativeShift, zero, copyEnd); + auto genInit = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange index, + mlir::ValueRange reductionArgs) + -> llvm::SmallVector<mlir::Value, 0> { + mlir::Value dstIndex = builder.createConvert(loc, calcType, index[0]); + dstIndex = + mlir::arith::AddIOp::create(builder, loc, dstIndex, initOffset); + dstIndices[dimVal - 1] = dstIndex; + hlfir::Entity dstElement = hlfir::getElementAt( + loc, builder, hlfir::Entity{resultArray}, dstIndices); + hlfir::AssignOp::create(builder, loc, boundary, dstElement); + return {}; + }; + hlfir::genLoopNestWithReductions(loc, builder, {initEnd}, + /*reductionInits=*/{}, genInit, + /*isUnordered=*/true); + builder.setIntegerOverflowFlags(savedFlags); + } return {}; }; @@ -1614,6 +2017,10 @@ private: shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType); } + if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) + boundaryVal = + selectBoundaryValue(loc, builder, op, boundaryVal, + boundaryIsScalarPred, oneBasedIndices); // If we can fetch the byte stride of the leading dimension, // and the byte size of the element, then we can generate @@ -1635,8 +2042,8 @@ private: } if (array.isSimplyContiguous() || !elemSize || !stride) { - genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/false, - oneBasedIndices); + genDimensionShift(loc, builder, shiftVal, boundaryVal, + /*exposeContiguity=*/false, oneBasedIndices); return {}; } @@ -1644,11 +2051,11 @@ private: builder, loc, mlir::arith::CmpIPredicate::eq, elemSize, stride); builder.genIfOp(loc, {}, isContiguous, /*withElseRegion=*/true) .genThen([&]() { - genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/true, - oneBasedIndices); + genDimensionShift(loc, builder, shiftVal, boundaryVal, + /*exposeContiguity=*/true, oneBasedIndices); }) .genElse([&]() { - genDimensionShift(loc, builder, shiftVal, + genDimensionShift(loc, builder, shiftVal, boundaryVal, /*exposeContiguity=*/false, oneBasedIndices); }); @@ -1671,6 +2078,212 @@ private: } }; +class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> { +public: + using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::CmpCharOp cmp, + mlir::PatternRewriter &rewriter) const override { + + fir::FirOpBuilder builder{rewriter, cmp.getOperation()}; + const mlir::Location &loc = cmp->getLoc(); + + auto toVariable = + [&builder, + &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> { + mlir::Value opnd; + hlfir::AssociateOp associate; + if (mlir::isa<hlfir::ExprType>(val.getType())) { + hlfir::Entity entity{val}; + mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder); + associate = hlfir::genAssociateExpr(loc, builder, entity, + entity.getType(), "", byRefAttr); + opnd = associate.getBase(); + } else { + opnd = val; + } + return {opnd, associate}; + }; + + auto [lhsOpnd, lhsAssociate] = toVariable(cmp.getLchr()); + auto [rhsOpnd, rhsAssociate] = toVariable(cmp.getRchr()); + + hlfir::Entity lhs{lhsOpnd}; + hlfir::Entity rhs{rhsOpnd}; + + auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType()); + unsigned kind = charTy.getFKind(); + + auto bits = 
builder.getKindMap().getCharacterBitsize(kind); + auto intTy = builder.getIntegerType(bits); + + auto idxTy = builder.getIndexType(); + auto charLen1Ty = + fir::CharacterType::getSingleton(builder.getContext(), kind); + mlir::Type designatorType = + fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy)); + auto idxAttr = builder.getIntegerAttr(idxTy, 0); + + auto genExtractAndConvertToInt = + [&idxAttr, &intTy, &designatorType]( + mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::Entity &charStr, mlir::Value index, mlir::Value length) { + auto singleChr = hlfir::DesignateOp::create( + builder, loc, designatorType, charStr, /*component=*/{}, + /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{}, + /*substring=*/mlir::ValueRange{index, index}, + /*complexPart=*/std::nullopt, + /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{length}, + fir::FortranVariableFlagsAttr{}); + auto chrVal = fir::LoadOp::create(builder, loc, singleChr); + mlir::Value intVal = fir::ExtractValueOp::create( + builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr)); + return intVal; + }; + + mlir::arith::CmpIPredicate predicate = cmp.getPredicate(); + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + + mlir::Value lhsLen = builder.createConvert( + loc, idxTy, hlfir::genCharLength(loc, builder, lhs)); + mlir::Value rhsLen = builder.createConvert( + loc, idxTy, hlfir::genCharLength(loc, builder, rhs)); + + enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight }; + + mlir::Value zeroInt = builder.createIntegerConstant(loc, intTy, 0); + mlir::Value oneInt = builder.createIntegerConstant(loc, intTy, 1); + mlir::Value negOneInt = builder.createIntegerConstant(loc, intTy, -1); + mlir::Value blankInt = builder.createIntegerConstant(loc, intTy, ' '); + + auto step = GenCmp::LeftToRight; + auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange index, mlir::ValueRange reductionArgs) + -> llvm::SmallVector<mlir::Value, 1> { + assert(index.size() == 1 && "expected single loop"); + assert(reductionArgs.size() == 1 && "expected single reduction value"); + mlir::Value inRes = reductionArgs[0]; + auto accEQzero = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt); + + mlir::Value res = + builder + .genIfOp(loc, {intTy}, accEQzero, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value offset = + builder.createConvert(loc, idxTy, index[0]); + mlir::Value lhsInt; + mlir::Value rhsInt; + if (step == GenCmp::LeftToRight) { + lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset, + oneIdx); + rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset, + oneIdx); + } else if (step == GenCmp::LeftToBlank) { + // lhsLen > rhsLen + offset = + mlir::arith::AddIOp::create(builder, loc, rhsLen, offset); + + lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset, + oneIdx); + rhsInt = blankInt; + } else if (step == GenCmp::BlankToRight) { + // rhsLen > lhsLen + offset = + mlir::arith::AddIOp::create(builder, loc, lhsLen, offset); + + lhsInt = blankInt; + rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset, + oneIdx); + } else { + llvm_unreachable( + "unknown compare step for CmpCharOp lowering"); + } + + mlir::Value newVal = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create(builder, loc, + mlir::arith::CmpIPredicate::ult, + lhsInt, rhsInt), + negOneInt, inRes); + newVal = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create(builder, 
loc,
+ mlir::arith::CmpIPredicate::ugt,
+ lhsInt, rhsInt),
+ oneInt, newVal);
+ fir::ResultOp::create(builder, loc, newVal);
+ })
+ .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
+ .getResults()[0];
+
+ return {res};
+ };
+
+ // First generate comparison of two strings for the length of the shorter
+ // one.
+ mlir::Value minLen = mlir::arith::SelectOp::create(
+ builder, loc,
+ mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen),
+ lhsLen, rhsLen);
+
+ llvm::SmallVector<mlir::Value, 1> loopOut =
+ hlfir::genLoopNestWithReductions(loc, builder, {minLen},
+ /*reductionInits=*/{zeroInt}, genCmp,
+ /*isUnordered=*/false);
+ mlir::Value partRes = loopOut[0];
+
+ auto lhsLonger = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen);
+ mlir::Value tempRes =
+ builder
+ .genIfOp(loc, {intTy}, lhsLonger,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ // If left is the longer string generate compare left to blank.
+ step = GenCmp::LeftToBlank;
+ auto lenDiff =
+ mlir::arith::SubIOp::create(builder, loc, lhsLen, rhsLen);
+
+ llvm::SmallVector<mlir::Value, 1> output =
+ hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
+ /*reductionInits=*/{partRes},
+ genCmp,
+ /*isUnordered=*/false);
+ mlir::Value res = output[0];
+ fir::ResultOp::create(builder, loc, res);
+ })
+ .genElse([&]() {
+ // If right is the longer string generate compare blank to
+ // right.
+ step = GenCmp::BlankToRight;
+ auto lenDiff =
+ mlir::arith::SubIOp::create(builder, loc, rhsLen, lhsLen);
+ llvm::SmallVector<mlir::Value, 1> output =
+ hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
+ /*reductionInits=*/{partRes},
+ genCmp,
+ /*isUnordered=*/false);
+
+ mlir::Value res = output[0];
+ fir::ResultOp::create(builder, loc, res);
+ })
+ .getResults()[0];
+ if (lhsAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, lhsAssociate);
+ if (rhsAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, rhsAssociate);
+
+ auto finalCmpResult =
+ mlir::arith::CmpIOp::create(builder, loc, predicate, tempRes, zeroInt);
+ rewriter.replaceOp(cmp, finalCmpResult);
+ return mlir::success();
+ }
+};
+
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
@@ -2339,9 +2952,10 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
- patterns.insert<CShiftConversion>(context);
+ patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
+ patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
+ patterns.insert<CmpCharOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
- patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index e5fd19d..c9aff59 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -271,8 +271,6 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var,
mlir::Value extent = val;
mlir::Value upperbound =
mlir::arith::SubIOp::create(builder, loc, extent, one);
- upperbound = mlir::arith::AddIOp::create(builder, loc, lowerbound,
-
upperbound); mlir::Value stride = one; if (strideIncludeLowerExtent) { stride = cummulativeExtent; @@ -591,7 +589,8 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( hlfir::AssignOp::create(firBuilder, loc, initVal, declareOp.getBase()); } else { - for (auto ext : seqTy.getShape()) { + // Generate loop nest from slowest to fastest running dimension + for (auto ext : llvm::reverse(seqTy.getShape())) { auto lb = firBuilder.createIntegerConstant(loc, idxTy, 0); auto ub = firBuilder.createIntegerConstant(loc, idxTy, ext - 1); auto step = firBuilder.createIntegerConstant(loc, idxTy, 1); @@ -614,6 +613,11 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy()); if (fir::isa_trivial(innerTy)) { retVal = getDeclareOpForType(unwrappedTy).getBase(); + mlir::Value allocatedScalar = + fir::AllocMemOp::create(builder, loc, innerTy); + mlir::Value firClass = + fir::EmboxOp::create(builder, loc, boxTy, allocatedScalar); + fir::StoreOp::create(builder, loc, firClass, retVal); } else if (mlir::isa<fir::SequenceType>(innerTy)) { hlfir::Entity source = hlfir::Entity{var}; auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source); diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp new file mode 100644 index 0000000..8b99913 --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp @@ -0,0 +1,159 @@ +//===- AutomapToTargetData.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/DirectivesCommon.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" + +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/Frontend/OpenMP/OMPConstants.h" + +namespace flangomp { +#define GEN_PASS_DEF_AUTOMAPTOTARGETDATAPASS +#include "flang/Optimizer/OpenMP/Passes.h.inc" +} // namespace flangomp + +using namespace mlir; + +namespace { +class AutomapToTargetDataPass + : public flangomp::impl::AutomapToTargetDataPassBase< + AutomapToTargetDataPass> { + + // Returns true if the variable has a dynamic size and therefore requires + // bounds operations to describe its extents. + inline bool needsBoundsOps(mlir::Value var) { + assert(mlir::isa<mlir::omp::PointerLikeType>(var.getType()) && + "only pointer like types expected"); + mlir::Type t = fir::unwrapRefType(var.getType()); + if (mlir::Type inner = fir::dyn_cast_ptrOrBoxEleTy(t)) + return fir::hasDynamicSize(inner); + return fir::hasDynamicSize(t); + } + + // Generate MapBoundsOp operations for the variable if required. 
+ inline void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value var, + llvm::SmallVectorImpl<mlir::Value> &boundsOps) { + mlir::Location loc = var.getLoc(); + fir::factory::AddrAndBoundsInfo info = + fir::factory::getDataOperandBaseAddr(builder, var, + /*isOptional=*/false, loc); + fir::ExtendedValue exv = + hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{info.addr}, + /*contiguousHint=*/true) + .first; + llvm::SmallVector<mlir::Value> tmp = + fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>( + builder, info, exv, /*dataExvIsAssumedSize=*/false, loc); + llvm::append_range(boundsOps, tmp); + } + + void findRelatedAllocmemFreemem(fir::AddrOfOp addressOfOp, + llvm::DenseSet<fir::StoreOp> &allocmems, + llvm::DenseSet<fir::LoadOp> &freemems) { + assert(addressOfOp->hasOneUse() && "op must have single use"); + + auto declaredRef = + cast<hlfir::DeclareOp>(*addressOfOp->getUsers().begin())->getResult(0); + + for (Operation *refUser : declaredRef.getUsers()) { + if (auto storeOp = dyn_cast<fir::StoreOp>(refUser)) + if (auto emboxOp = storeOp.getValue().getDefiningOp<fir::EmboxOp>()) + if (auto allocmemOp = + emboxOp.getOperand(0).getDefiningOp<fir::AllocMemOp>()) + allocmems.insert(storeOp); + + if (auto loadOp = dyn_cast<fir::LoadOp>(refUser)) + for (Operation *loadUser : loadOp.getResult().getUsers()) + if (auto boxAddrOp = dyn_cast<fir::BoxAddrOp>(loadUser)) + for (Operation *boxAddrUser : boxAddrOp.getResult().getUsers()) + if (auto freememOp = dyn_cast<fir::FreeMemOp>(boxAddrUser)) + freemems.insert(loadOp); + } + } + + void runOnOperation() override { + ModuleOp module = getOperation()->getParentOfType<ModuleOp>(); + if (!module) + module = dyn_cast<ModuleOp>(getOperation()); + if (!module) + return; + + // Build FIR builder for helper utilities. + fir::KindMapping kindMap = fir::getKindMapping(module); + fir::FirOpBuilder builder{module, std::move(kindMap)}; + + // Collect global variables with AUTOMAP flag. + llvm::DenseSet<fir::GlobalOp> automapGlobals; + module.walk([&](fir::GlobalOp globalOp) { + if (auto iface = + dyn_cast<omp::DeclareTargetInterface>(globalOp.getOperation())) + if (iface.isDeclareTarget() && iface.getDeclareTargetAutomap() && + iface.getDeclareTargetDeviceType() != + omp::DeclareTargetDeviceType::host) + automapGlobals.insert(globalOp); + }); + + auto addMapInfo = [&](auto globalOp, auto memOp) { + builder.setInsertionPointAfter(memOp); + SmallVector<Value> bounds; + if (needsBoundsOps(memOp.getMemref())) + genBoundsOps(builder, memOp.getMemref(), bounds); + + omp::TargetEnterExitUpdateDataOperands clauses; + mlir::omp::MapInfoOp mapInfo = mlir::omp::MapInfoOp::create( + builder, memOp.getLoc(), memOp.getMemref().getType(), + memOp.getMemref(), + TypeAttr::get(fir::unwrapRefType(memOp.getMemref().getType())), + builder.getIntegerAttr( + builder.getIntegerType(64, false), + static_cast<unsigned>( + isa<fir::StoreOp>(memOp) + ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO + : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)), + builder.getAttr<omp::VariableCaptureKindAttr>( + omp::VariableCaptureKind::ByCopy), + /*var_ptr_ptr=*/mlir::Value{}, + /*members=*/SmallVector<Value>{}, + /*members_index=*/ArrayAttr{}, bounds, + /*mapperId=*/mlir::FlatSymbolRefAttr(), globalOp.getSymNameAttr(), + builder.getBoolAttr(false)); + clauses.mapVars.push_back(mapInfo); + isa<fir::StoreOp>(memOp) + ? 
builder.create<omp::TargetEnterDataOp>(memOp.getLoc(), clauses) + : builder.create<omp::TargetExitDataOp>(memOp.getLoc(), clauses); + }; + + for (fir::GlobalOp globalOp : automapGlobals) { + if (auto uses = globalOp.getSymbolUses(module.getOperation())) { + llvm::DenseSet<fir::StoreOp> allocmemStores; + llvm::DenseSet<fir::LoadOp> freememLoads; + for (auto &x : *uses) + if (auto addrOp = dyn_cast<fir::AddrOfOp>(x.getUser())) + findRelatedAllocmemFreemem(addrOp, allocmemStores, freememLoads); + + for (auto storeOp : allocmemStores) + addMapInfo(globalOp, storeOp); + + for (auto loadOp : freememLoads) + addMapInfo(globalOp, loadOp); + } + } + } +}; +} // namespace diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index e315433..e0aebd0 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -1,6 +1,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_flang_library(FlangOpenMPTransforms + AutomapToTargetData.cpp DoConcurrentConversion.cpp FunctionFiltering.cpp GenericLoopConversion.cpp @@ -9,6 +10,7 @@ add_flang_library(FlangOpenMPTransforms MarkDeclareTarget.cpp LowerWorkshare.cpp LowerNontemporal.cpp + SimdOnly.cpp DEPENDS FIRDialect diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 2b3ac16..c928b76 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -173,9 +173,11 @@ public: DoConcurrentConversion( mlir::MLIRContext *context, bool mapToDevice, - llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip) + llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip, + mlir::SymbolTable &moduleSymbolTable) : OpConversionPattern(context), mapToDevice(mapToDevice), - concurrentLoopsToSkip(concurrentLoopsToSkip) {} + concurrentLoopsToSkip(concurrentLoopsToSkip), + moduleSymbolTable(moduleSymbolTable) {} mlir::LogicalResult matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor, @@ -332,8 +334,8 @@ private: loop.getLocalVars(), loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(), loop.getRegionLocalArgs())) { - auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom< - fir::LocalitySpecifierOp>(loop, sym); + auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>( + sym.getLeafReference()); if (localizer.getLocalitySpecifierType() == fir::LocalitySpecifierType::LocalInit) TODO(localizer.getLoc(), @@ -352,6 +354,8 @@ private: cloneFIRRegionToOMP(localizer.getDeallocRegion(), privatizer.getDeallocRegion()); + moduleSymbolTable.insert(privatizer); + wsloopClauseOps.privateVars.push_back(op); wsloopClauseOps.privateSyms.push_back( mlir::SymbolRefAttr::get(privatizer)); @@ -362,28 +366,34 @@ private: loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(), loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(), loop.getRegionReduceArgs())) { - auto firReducer = - mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>( - loop, sym); + auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>( + sym.getLeafReference()); mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(firReducer); - - auto ompReducer = mlir::omp::DeclareReductionOp::create( - rewriter, firReducer.getLoc(), - sym.getLeafReference().str() + ".omp", - firReducer.getTypeAttr().getValue()); - - cloneFIRRegionToOMP(firReducer.getAllocRegion(), - ompReducer.getAllocRegion()); - 
cloneFIRRegionToOMP(firReducer.getInitializerRegion(), - ompReducer.getInitializerRegion()); - cloneFIRRegionToOMP(firReducer.getReductionRegion(), - ompReducer.getReductionRegion()); - cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(), - ompReducer.getAtomicReductionRegion()); - cloneFIRRegionToOMP(firReducer.getCleanupRegion(), - ompReducer.getCleanupRegion()); + std::string ompReducerName = sym.getLeafReference().str() + ".omp"; + + auto ompReducer = + moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>( + rewriter.getStringAttr(ompReducerName)); + + if (!ompReducer) { + ompReducer = mlir::omp::DeclareReductionOp::create( + rewriter, firReducer.getLoc(), ompReducerName, + firReducer.getTypeAttr().getValue()); + + cloneFIRRegionToOMP(firReducer.getAllocRegion(), + ompReducer.getAllocRegion()); + cloneFIRRegionToOMP(firReducer.getInitializerRegion(), + ompReducer.getInitializerRegion()); + cloneFIRRegionToOMP(firReducer.getReductionRegion(), + ompReducer.getReductionRegion()); + cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(), + ompReducer.getAtomicReductionRegion()); + cloneFIRRegionToOMP(firReducer.getCleanupRegion(), + ompReducer.getCleanupRegion()); + moduleSymbolTable.insert(ompReducer); + } wsloopClauseOps.reductionVars.push_back(op); wsloopClauseOps.reductionByref.push_back(byRef); @@ -431,6 +441,7 @@ private: bool mapToDevice; llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip; + mlir::SymbolTable &moduleSymbolTable; }; class DoConcurrentConversionPass @@ -444,12 +455,9 @@ public: : DoConcurrentConversionPassBase(options) {} void runOnOperation() override { - mlir::func::FuncOp func = getOperation(); - - if (func.isDeclaration()) - return; - + mlir::ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); + mlir::SymbolTable moduleSymbolTable(module); if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host && mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) { @@ -463,7 +471,7 @@ public: mlir::RewritePatternSet patterns(context); patterns.insert<DoConcurrentConversion>( context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device, - concurrentLoopsToSkip); + concurrentLoopsToSkip, moduleSymbolTable); mlir::ConversionTarget target(*context); target.addDynamicallyLegalOp<fir::DoConcurrentOp>( [&](fir::DoConcurrentOp op) { @@ -472,8 +480,8 @@ public: target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); - if (mlir::failed(mlir::applyFullConversion(getOperation(), target, - std::move(patterns)))) { + if (mlir::failed( + mlir::applyFullConversion(module, target, std::move(patterns)))) { signalPassFailure(); } } diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp index ae5c0ec..3031bb5 100644 --- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp +++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp @@ -95,8 +95,9 @@ public: return WalkResult::skip(); } if (declareTargetOp) - declareTargetOp.setDeclareTarget(declareType, - omp::DeclareTargetCaptureClause::to); + declareTargetOp.setDeclareTarget( + declareType, omp::DeclareTargetCaptureClause::to, + declareTargetOp.getDeclareTargetAutomap()); } return WalkResult::advance(); }); diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 970f7d7..3032857 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -53,6 +53,7 @@ class 
MapsForPrivatizedSymbolsPass : public flangomp::impl::MapsForPrivatizedSymbolsPassBase< MapsForPrivatizedSymbolsPass> { + // TODO Use `createMapInfoOp` from `flang/Utils/OpenMP.h`. omp::MapInfoOp createMapInfo(Location loc, Value var, fir::FirOpBuilder &builder) { // Check if a value of type `type` can be passed to the kernel by value. diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp index a7ffd5f..0b0e6bd 100644 --- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp +++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp @@ -33,7 +33,7 @@ class MarkDeclareTargetPass void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy, mlir::omp::DeclareTargetCaptureClause parentCapClause, - mlir::Operation *currOp, + bool parentAutomap, mlir::Operation *currOp, llvm::SmallPtrSet<mlir::Operation *, 16> visited) { if (visited.contains(currOp)) return; @@ -57,13 +57,16 @@ class MarkDeclareTargetPass currentDt != mlir::omp::DeclareTargetDeviceType::any) { current.setDeclareTarget( mlir::omp::DeclareTargetDeviceType::any, - current.getDeclareTargetCaptureClause()); + current.getDeclareTargetCaptureClause(), + current.getDeclareTargetAutomap()); } } else { - current.setDeclareTarget(parentDevTy, parentCapClause); + current.setDeclareTarget(parentDevTy, parentCapClause, + parentAutomap); } - markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited); + markNestedFuncs(parentDevTy, parentCapClause, parentAutomap, + currFOp, visited); } } } @@ -81,7 +84,8 @@ class MarkDeclareTargetPass llvm::SmallPtrSet<mlir::Operation *, 16> visited; markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(), declareTargetOp.getDeclareTargetCaptureClause(), - functionOp, visited); + declareTargetOp.getDeclareTargetAutomap(), functionOp, + visited); } } @@ -92,9 +96,10 @@ class MarkDeclareTargetPass // the contents of the device clause getOperation()->walk([&](mlir::omp::TargetOp tarOp) { llvm::SmallPtrSet<mlir::Operation *, 16> visited; - markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost, - mlir::omp::DeclareTargetCaptureClause::to, tarOp, - visited); + markNestedFuncs( + /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost, + /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to, + /*parentAutomap=*/false, tarOp, visited); }); } }; diff --git a/flang/lib/Optimizer/OpenMP/SimdOnly.cpp b/flang/lib/Optimizer/OpenMP/SimdOnly.cpp new file mode 100644 index 0000000..4a559d2 --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/SimdOnly.cpp @@ -0,0 +1,209 @@ +//===-- SimdOnly.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace flangomp { +#define GEN_PASS_DEF_SIMDONLYPASS +#include "flang/Optimizer/OpenMP/Passes.h.inc" +} // namespace flangomp + +namespace { + +#define DEBUG_TYPE "omp-simd-only-pass" + +/// Rewrite and remove OpenMP operations left after the parse tree rewriting for +/// -fopenmp-simd is done. If possible, OpenMP constructs should be rewritten at +/// the parse tree stage. This pass is supposed to only handle complexities +/// around untangling composite simd constructs, and perform the necessary +/// cleanup. +class SimdOnlyConversionPattern : public mlir::RewritePattern { +public: + SimdOnlyConversionPattern(mlir::MLIRContext *ctx) + : mlir::RewritePattern(MatchAnyOpTypeTag{}, 1, ctx) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (op->getDialect()->getNamespace() != + mlir::omp::OpenMPDialect::getDialectNamespace()) + return rewriter.notifyMatchFailure(op, "Not an OpenMP op"); + + if (auto simdOp = mlir::dyn_cast<mlir::omp::SimdOp>(op)) { + // Remove the composite attr given that the op will no longer be composite + if (simdOp.isComposite()) { + simdOp.setComposite(false); + return mlir::success(); + } + + return rewriter.notifyMatchFailure(op, "Op is a plain SimdOp"); + } + + if (op->getParentOfType<mlir::omp::SimdOp>() && + (mlir::isa<mlir::omp::YieldOp>(op) || + mlir::isa<mlir::omp::ScanOp>(op) || + mlir::isa<mlir::omp::LoopNestOp>(op) || + mlir::isa<mlir::omp::TerminatorOp>(op))) + return rewriter.notifyMatchFailure(op, "Op is part of a simd construct"); + + if (!mlir::isa<mlir::func::FuncOp>(op->getParentOp()) && + (mlir::isa<mlir::omp::TerminatorOp>(op) || + mlir::isa<mlir::omp::YieldOp>(op))) + return rewriter.notifyMatchFailure(op, + "Non top-level yield or terminator"); + + LLVM_DEBUG(llvm::dbgs() << "SimdOnlyPass matched OpenMP op:\n"); + LLVM_DEBUG(op->dump()); + + auto eraseUnlessUsedBySimd = [&](mlir::Operation *ompOp, + mlir::StringAttr name) { + if (auto uses = + mlir::SymbolTable::getSymbolUses(name, op->getParentOp())) { + for (auto &use : *uses) + if (mlir::isa<mlir::omp::SimdOp>(use.getUser())) + return rewriter.notifyMatchFailure(op, + "Op used by a simd construct"); + } + rewriter.eraseOp(ompOp); + return mlir::success(); + }; + + if (auto ompOp = mlir::dyn_cast<mlir::omp::PrivateClauseOp>(op)) + return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr()); + if (auto ompOp = mlir::dyn_cast<mlir::omp::DeclareReductionOp>(op)) + return eraseUnlessUsedBySimd(ompOp, ompOp.getSymNameAttr()); + + // Might be left over from rewriting composite simd with target map + if (mlir::isa<mlir::omp::MapBoundsOp>(op)) { + rewriter.eraseOp(op); + return mlir::success(); + } + if (auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(op)) { + rewriter.replaceOp(mapInfoOp, {mapInfoOp.getVarPtr()}); + return mlir::success(); + } + + // Might be leftover after parse tree 
rewriting
+ if (auto threadPrivateOp = mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op)) {
+ rewriter.replaceOp(threadPrivateOp, {threadPrivateOp.getSymAddr()});
+ return mlir::success();
+ }
+
+ fir::FirOpBuilder builder(rewriter, op);
+ mlir::Location loc = op->getLoc();
+
+ auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
+ if (!ompOp)
+ return false;
+
+ assert(ompOp->getNumRegions() == 1 && "OpenMP operation has one region");
+
+ llvm::SmallVector<std::pair<mlir::Value, mlir::BlockArgument>>
+ blockArgsPairs;
+ if (auto iface =
+ mlir::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(op)) {
+ iface.getBlockArgsPairs(blockArgsPairs);
+ for (auto [value, argument] : blockArgsPairs)
+ rewriter.replaceAllUsesWith(argument, value);
+ }
+
+ if (ompOp->getRegion(0).getBlocks().size() == 1) {
+ auto &block = *ompOp->getRegion(0).getBlocks().begin();
+ // This block is about to be removed so any arguments should have been
+ // replaced by now.
+ block.eraseArguments(0, block.getNumArguments());
+ if (auto terminatorOp =
+ mlir::dyn_cast<mlir::omp::TerminatorOp>(block.back())) {
+ rewriter.eraseOp(terminatorOp);
+ }
+ rewriter.inlineBlockBefore(&block, ompOp, {});
+ } else {
+ // When dealing with multi-block regions we need to fix up the control
+ // flow
+ auto *origBlock = ompOp->getBlock();
+ auto *newBlock = rewriter.splitBlock(origBlock, ompOp->getIterator());
+ auto *innerFrontBlock = &ompOp->getRegion(0).getBlocks().front();
+ builder.setInsertionPointToEnd(origBlock);
+ mlir::cf::BranchOp::create(builder, loc, innerFrontBlock);
+ // We are no longer passing any arguments to the first block in the
+ // region, so this should be safe to erase.
+ innerFrontBlock->eraseArguments(0, innerFrontBlock->getNumArguments());
+
+ for (auto &innerBlock : ompOp->getRegion(0).getBlocks()) {
+ // Remove now-unused block arguments
+ for (auto arg : innerBlock.getArguments()) {
+ if (arg.getUses().empty())
+ innerBlock.eraseArgument(arg.getArgNumber());
+ }
+ if (auto terminatorOp =
+ mlir::dyn_cast<mlir::omp::TerminatorOp>(innerBlock.back())) {
+ builder.setInsertionPointToEnd(&innerBlock);
+ mlir::cf::BranchOp::create(builder, loc, newBlock);
+ rewriter.eraseOp(terminatorOp);
+ }
+ }
+
+ rewriter.inlineRegionBefore(ompOp->getRegion(0), newBlock);
+ }
+
+ rewriter.eraseOp(op);
+ return true;
+ };
+
+ // Remove ops that will be surrounding simd once a composite simd construct
+ // goes through the codegen stage. All of the other ones should have already
+ // been removed in the parse tree rewriting stage.
+ if (inlineSimpleOp(mlir::dyn_cast<mlir::omp::TeamsOp>(op)) ||
+ inlineSimpleOp(mlir::dyn_cast<mlir::omp::ParallelOp>(op)) ||
+ inlineSimpleOp(mlir::dyn_cast<mlir::omp::TargetOp>(op)) ||
+ inlineSimpleOp(mlir::dyn_cast<mlir::omp::WsloopOp>(op)) ||
+ inlineSimpleOp(mlir::dyn_cast<mlir::omp::DistributeOp>(op)))
+ return mlir::success();
+
+ op->emitOpError("left unhandled after SimdOnly pass.");
+ return mlir::failure();
+ }
+};
+
+class SimdOnlyPass : public flangomp::impl::SimdOnlyPassBase<SimdOnlyPass> {
+
+public:
+ SimdOnlyPass() = default;
+
+ void runOnOperation() override {
+ mlir::ModuleOp module = getOperation();
+
+ mlir::MLIRContext *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ patterns.insert<SimdOnlyConversionPattern>(context);
+
+ mlir::GreedyRewriteConfig config;
+ // Prevent the pattern driver from merging blocks.
+ config.setRegionSimplificationLevel( + mlir::GreedySimplifyRegionLevel::Disabled); + + if (mlir::failed( + mlir::applyPatternsGreedily(module, std::move(patterns), config))) { + mlir::emitError(module.getLoc(), "Error in SimdOnly conversion pass"); + signalPassFailure(); + } + } +}; + +} // namespace diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index ca8e8206..7c2777b 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -14,7 +14,7 @@ /// Force setting the no-alias attribute on fuction arguments when possible. static llvm::cl::opt<bool> forceNoAlias("force-no-alias", llvm::cl::Hidden, - llvm::cl::init(false)); + llvm::cl::init(true)); namespace fir { @@ -217,9 +217,6 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, pm.addPass(fir::createSimplifyFIROperations( {/*preferInlineImplementation=*/pc.OptLevel.isOptimizingForSpeed()})); - if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags) - pm.addPass(fir::createAddAliasTags()); - addNestedPassToAllTopLevelOperations<PassConstructor>( pm, fir::createStackReclaim); // convert control flow to CFG form @@ -242,7 +239,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, /// \param pm - MLIR pass manager that will hold the pipeline definition /// \param optLevel - optimization level used for creating FIR optimization /// passes pipeline -void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, +void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, + EnableOpenMP enableOpenMP, llvm::OptimizationLevel optLevel) { if (optLevel.isOptimizingForSpeed()) { addCanonicalizerPassWithoutRegionSimplification(pm); @@ -294,8 +292,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, addNestedPassToAllTopLevelOperations<PassConstructor>( pm, hlfir::createInlineHLFIRAssign); pm.addPass(hlfir::createConvertHLFIRtoFIR()); - if (enableOpenMP) + if (enableOpenMP != EnableOpenMP::None) pm.addPass(flangomp::createLowerWorkshare()); + if (enableOpenMP == EnableOpenMP::Simd) + pm.addPass(flangomp::createSimdOnlyPass()); } /// Create a pass pipeline for handling certain OpenMP transformations needed @@ -316,13 +316,13 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm, pm.addPass(flangomp::createDoConcurrentConversionPass( opts.doConcurrentMappingKind == DoConcurrentMappingKind::DCMK_Device)); - // The MapsForPrivatizedSymbols pass needs to run before - // MapInfoFinalizationPass because the former creates new - // MapInfoOp instances, typically for descriptors. - // MapInfoFinalizationPass adds MapInfoOp instances for the descriptors - // underlying data which is necessary to access the data on the offload - // target device. + // The MapsForPrivatizedSymbols and AutomapToTargetDataPass pass need to run + // before MapInfoFinalizationPass because they create new MapInfoOp + // instances, typically for descriptors. MapInfoFinalizationPass adds + // MapInfoOp instances for the descriptors underlying data which is necessary + // to access the data on the offload target device. 
pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass()); + pm.addPass(flangomp::createAutomapToTargetDataPass()); pm.addPass(flangomp::createMapInfoFinalizationPass()); pm.addPass(flangomp::createMarkDeclareTargetPass()); pm.addPass(flangomp::createGenericLoopConversionPass()); @@ -342,6 +342,9 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config, llvm::StringRef inputFilename) { fir::addBoxedProcedurePass(pm); + if (config.OptLevel.isOptimizingForSpeed() && config.AliasAnalysis && + !disableFirAliasTags && !useOldAliasTags) + pm.addPass(fir::createAddAliasTags()); addNestedPassToAllTopLevelOperations<PassConstructor>( pm, fir::createAbstractResultOpt); addPassToGPUModuleOperations<PassConstructor>(pm, @@ -396,7 +399,12 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm, void createMLIRToLLVMPassPipeline(mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig &config, llvm::StringRef inputFilename) { - fir::createHLFIRToFIRPassPipeline(pm, config.EnableOpenMP, config.OptLevel); + fir::EnableOpenMP enableOpenMP = fir::EnableOpenMP::None; + if (config.EnableOpenMP) + enableOpenMP = fir::EnableOpenMP::Full; + if (config.EnableOpenMPSimd) + enableOpenMP = fir::EnableOpenMP::Simd; + fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP, config.OptLevel); // Add default optimizer pass pipeline. fir::createDefaultFIROptimizerPassPipeline(pm, config); diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index 5d663e2..c71642c 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -50,3 +50,74 @@ std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault( return componentInfo.getLowerBounds(); return std::nullopt; } + +mlir::LLVM::ConstantOp +fir::genConstantIndex(mlir::Location loc, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter, + std::int64_t offset) { + auto cattr = rewriter.getI64IntegerAttr(offset); + return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); +} + +mlir::Value +fir::computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, + mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + const mlir::DataLayout &dataLayout) { + llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); + unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); + std::int64_t distance = llvm::alignTo(size, alignment); + return fir::genConstantIndex(loc, idxTy, rewriter, distance); +} + +mlir::Value +fir::genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy, + mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter) { + auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy); + fir::SequenceType::Extent constSize = 1; + if (seqTy) { + int constRows = seqTy.getConstantRows(); + const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); + if (constRows != static_cast<int>(shape.size())) { + for (auto extent : shape) { + if (constRows-- > 0) + continue; + if (extent != fir::SequenceType::getUnknownExtent()) + constSize *= extent; + } + } + } + + if (constSize != 1) { + mlir::Value constVal{ + fir::genConstantIndex(loc, ity, rewriter, constSize).getResult()}; + return constVal; + } + return nullptr; +} + +mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, bool fold) { + auto valTy = val.getType(); + // If the value was not yet lowered, lower its type so 
that it can + // be used in getPrimitiveTypeSizeInBits. + if (!mlir::isa<mlir::IntegerType>(valTy)) + valTy = converter.convertType(valTy); + auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); + auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); + if (fold) { + if (toSize < fromSize) + return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val); + if (toSize > fromSize) + return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); + } else { + if (toSize < fromSize) + return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); + if (toSize > fromSize) + return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); + } + return val; +} diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp index f1c66a5..430ef62 100644 --- a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp @@ -117,10 +117,7 @@ public: op.getValue()); return success(); } - rewriter.startOpModification(op->getParentOp()); - op.getResult().replaceAllUsesWith(op.getValue()); - rewriter.finalizeOpModification(op->getParentOp()); - rewriter.eraseOp(op); + rewriter.replaceOp(op, op.getValue()); } return success(); } diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp index b032767..061a7d2 100644 --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Visitors.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" #include <optional> @@ -451,10 +451,10 @@ static void rewriteStore(fir::StoreOp storeOp, } static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { - for (auto &bodyOp : block->getOperations()) { + for (auto &bodyOp : llvm::make_early_inc_range(block->getOperations())) { if (isa<fir::LoadOp>(bodyOp)) rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); - if (isa<fir::StoreOp>(bodyOp)) + else if (isa<fir::StoreOp>(bodyOp)) rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); } } @@ -476,6 +476,8 @@ public: loop.dump();); LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = functionAnalysis.getChildLoopAnalysis(loop); + if (!loopAnalysis.canPromoteToAffine()) + return rewriter.notifyMatchFailure(loop, "cannot promote to affine"); auto &loopOps = loop.getBody()->getOperations(); auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator()); auto results = resultOp.getOperands(); @@ -576,12 +578,14 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { public: using OpRewritePattern::OpRewritePattern; AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) - : OpRewritePattern(context) {} + : OpRewritePattern(context), functionAnalysis(afa) {} llvm::LogicalResult matchAndRewrite(fir::IfOp op, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; op.dump();); + if (!functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()) + return rewriter.notifyMatchFailure(op, "cannot promote to affine"); auto &ifOps = op.getThenRegion().front().getOperations(); auto affineCondition = AffineIfCondition(op.getCondition()); if (!affineCondition.hasIntegerSet()) { @@ -611,6 +615,8 @@ public: rewriter.replaceOp(op, affineIf.getOperation()->getResults()); 
return success(); } + + AffineFunctionAnalysis &functionAnalysis; }; /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases @@ -627,28 +633,11 @@ public: mlir::RewritePatternSet patterns(context); patterns.insert<AffineIfConversion>(context, functionAnalysis); patterns.insert<AffineLoopConversion>(context, functionAnalysis); - mlir::ConversionTarget target = *context; - target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect, - mlir::scf::SCFDialect, mlir::arith::ArithDialect, - mlir::func::FuncDialect>(); - target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { - return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); - }); - target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( - fir::DoLoopOp op) { - return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); - }); - LLVM_DEBUG(llvm::dbgs() << "AffineDialectPromotion: running promotion on: \n"; function.print(llvm::dbgs());); // apply the patterns - if (mlir::failed(mlir::applyPartialConversion(function, target, - std::move(patterns)))) { - mlir::emitError(mlir::UnknownLoc::get(context), - "error in converting to affine dialect\n"); - signalPassFailure(); - } + walkAndApplyPatterns(function, std::move(patterns)); } }; } // namespace diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp index 247ba95..ed9a2ae 100644 --- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -1264,7 +1264,6 @@ public: auto lhsEltRefType = toRefType(update.getMerge().getType()); auto [_, lhsLoadResult] = materializeAssignment( loc, rewriter, update, assignElement, lhsEltRefType); - update.replaceAllUsesWith(lhsLoadResult); rewriter.replaceOp(update, lhsLoadResult); return mlir::success(); } @@ -1287,7 +1286,6 @@ public: auto lhsEltRefType = modify.getResult(0).getType(); auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( loc, rewriter, modify, assignElement, lhsEltRefType); - modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); return mlir::success(); } @@ -1339,7 +1337,6 @@ public: // This array_access is associated with an array_amend and there is a // conflict. Make a copy to store into. 
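A note on the ArrayValueCopy.cpp hunks in this patch (including the one continued just below): the deleted replaceAllUsesWith calls were redundant because mlir::RewriterBase::replaceOp already replaces all uses of the op's results with the supplied values and then erases the op. A minimal sketch of the idiom (illustration only; the helper name is hypothetical and not part of the patch):

    #include "mlir/IR/PatternMatch.h"

    // replaceOp rewires every use of `op`'s results to `newValues` and erases
    // `op`, so an explicit op->replaceAllUsesWith(newValues) beforehand is
    // unnecessary.
    static void replaceAndErase(mlir::RewriterBase &rewriter,
                                mlir::Operation *op,
                                mlir::ValueRange newValues) {
      rewriter.replaceOp(op, newValues);
    }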
auto result = referenceToClone(loc, rewriter, access); - access.replaceAllUsesWith(result); rewriter.replaceOp(access, result); return mlir::success(); } diff --git a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp b/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp index 5e910f7..6e04c71 100644 --- a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp +++ b/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp @@ -38,6 +38,15 @@ using namespace Fortran::runtime::cuda; namespace { +static bool isAssumedSize(mlir::ValueRange shape) { + if (shape.size() != 1) + return false; + std::optional<std::int64_t> val = fir::getIntIfConstant(shape[0]); + if (val && *val == -1) + return true; + return false; +} + struct CUFComputeSharedMemoryOffsetsAndSize : public fir::impl::CUFComputeSharedMemoryOffsetsAndSizeBase< CUFComputeSharedMemoryOffsetsAndSize> { @@ -82,12 +91,12 @@ struct CUFComputeSharedMemoryOffsetsAndSize alignment = std::max(alignment, align); uint64_t tySize = dl->getTypeSize(ty); ++nbDynamicSharedVariables; - if (crtDynOffset) { - sharedOp.getOffsetMutable().assign( - builder.createConvert(loc, i32Ty, crtDynOffset)); - } else { + if (isAssumedSize(sharedOp.getShape()) || !crtDynOffset) { mlir::Value zero = builder.createIntegerConstant(loc, i32Ty, 0); sharedOp.getOffsetMutable().assign(zero); + } else { + sharedOp.getOffsetMutable().assign( + builder.createConvert(loc, i32Ty, crtDynOffset)); } mlir::Value dynSize = diff --git a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp index 5dcb54e..d038c46 100644 --- a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp +++ b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp @@ -178,8 +178,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertBoxedSequenceType( context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr, /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, - elements, dataLocation, rank, /*allocated=*/nullptr, - /*associated=*/nullptr); + dataLocation, rank, /*allocated=*/nullptr, + /*associated=*/nullptr, elements); } addOp(llvm::dwarf::DW_OP_push_object_address, {}); @@ -255,8 +255,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertBoxedSequenceType( return mlir::LLVM::DICompositeTypeAttr::get( context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr, /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, - mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, elements, - dataLocation, /*rank=*/nullptr, allocated, associated); + mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, + dataLocation, /*rank=*/nullptr, allocated, associated, elements); } std::pair<std::uint64_t, unsigned short> @@ -389,8 +389,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( context, recId, /*isRecSelf=*/true, llvm::dwarf::DW_TAG_structure_type, mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, - /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elements); DerivedTypeCache::ActiveLevels nestedRecursions = derivedTypeCache.startTranslating(Ty, placeHolder); @@ -429,8 +429,8 @@ mlir::LLVM::DITypeAttr 
DebugTypeGenerator::convertRecordType( /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, convertType(seqTy.getEleTy(), fileAttr, scope, declOp), mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, - arrayElements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, arrayElements); } else elemTy = convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr); offset = llvm::alignTo(offset, byteAlign); @@ -448,8 +448,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( context, recId, /*isRecSelf=*/false, llvm::dwarf::DW_TAG_structure_type, mlir::StringAttr::get(context, sourceName.name), fileAttr, line, scope, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8, - /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elements); derivedTypeCache.finalize(Ty, finalAttr, std::move(nestedRecursions)); @@ -490,8 +490,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType( context, llvm::dwarf::DW_TAG_structure_type, mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8, - /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elements); derivedTypeCache.finalize(Ty, typeAttr, std::move(nestedRecursions)); return typeAttr; } @@ -554,9 +554,9 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertSequenceType( return mlir::LLVM::DICompositeTypeAttr::get( context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr, /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, - mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, elements, + mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, /*allocated=*/nullptr, - /*associated=*/nullptr); + /*associated=*/nullptr, elements); } mlir::LLVM::DITypeAttr DebugTypeGenerator::convertVectorType( @@ -587,9 +587,9 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertVectorType( context, llvm::dwarf::DW_TAG_array_type, mlir::StringAttr::get(context, name), /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, - mlir::LLVM::DIFlags::Vector, sizeInBits, /*alignInBits=*/0, elements, + mlir::LLVM::DIFlags::Vector, sizeInBits, /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, /*allocated=*/nullptr, - /*associated=*/nullptr); + /*associated=*/nullptr, elements); } mlir::LLVM::DITypeAttr DebugTypeGenerator::convertCharacterType( diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp index 2fcff87..031a5ae 100644 --- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp +++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp @@ -76,12 +76,49 @@ void ExternalNameConversionPass::runOnOperation() { auto *context = &getContext(); llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings; + mlir::SymbolTable symbolTable(op); auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) { auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>( mlir::SymbolTable::getSymbolAttrName()); auto 
deconstructedName = fir::NameUniquer::deconstruct(symName); if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { + // Check if this is a private function that would conflict with a common + // block and get its mangled name. + if (auto funcOp = llvm::dyn_cast<mlir::func::FuncOp>(funcOrGlobal)) { + if (funcOp.isPrivate()) { + std::string mangledName = + mangleExternalName(deconstructedName, appendUnderscoreOpt); + auto mod = funcOp->getParentOfType<mlir::ModuleOp>(); + bool hasConflictingCommonBlock = false; + + // Check if any existing global has the same mangled name. + if (symbolTable.lookup<fir::GlobalOp>(mangledName)) + hasConflictingCommonBlock = true; + + // Skip externalization if the function has a conflicting common block + // and is not directly called (i.e. procedure pointers or type + // specifications) + if (hasConflictingCommonBlock) { + bool isDirectlyCalled = false; + std::optional<SymbolTable::UseRange> uses = + funcOp.getSymbolUses(mod); + if (uses.has_value()) { + for (auto use : *uses) { + mlir::Operation *user = use.getUser(); + if (mlir::isa<fir::CallOp>(user) || + mlir::isa<mlir::func::CallOp>(user)) { + isDirectlyCalled = true; + break; + } + } + } + if (!isDirectlyCalled) + return; + } + } + } + auto newName = mangleExternalName(deconstructedName, appendUnderscoreOpt); auto newAttr = mlir::StringAttr::get(context, newName); mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr); diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index 1902757..70d6ebb 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -9,36 +9,34 @@ #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace fir { #define GEN_PASS_DEF_FIRTOSCFPASS #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir -using namespace fir; -using namespace mlir; - namespace { class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> { public: void runOnOperation() override; }; -struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> { +struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern; - LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp, - PatternRewriter &rewriter) const override { - auto loc = doLoopOp.getLoc(); + mlir::LogicalResult + matchAndRewrite(fir::DoLoopOp doLoopOp, + mlir::PatternRewriter &rewriter) const override { + mlir::Location loc = doLoopOp.getLoc(); bool hasFinalValue = doLoopOp.getFinalValue().has_value(); // Get loop values from the DoLoopOp - auto low = doLoopOp.getLowerBound(); - auto high = doLoopOp.getUpperBound(); + mlir::Value low = doLoopOp.getLowerBound(); + mlir::Value high = doLoopOp.getUpperBound(); assert(low && high && "must be a Value"); - auto step = doLoopOp.getStep(); - llvm::SmallVector<Value> iterArgs; + mlir::Value step = doLoopOp.getStep(); + mlir::SmallVector<mlir::Value> iterArgs; if (hasFinalValue) iterArgs.push_back(low); iterArgs.append(doLoopOp.getIterOperands().begin(), @@ -49,31 +47,33 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> { // must be a positive value. // For easier conversion, we calculate the trip count and use a canonical // induction variable. 
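The DoLoopConversion rewrite that follows computes a trip count so that the generated scf.for can run on a canonical zero-based induction variable. A small standalone sketch of that arithmetic, assuming a positive step as the comment above requires (the helper name is illustrative, not part of the patch):
    #include <cstdint>
    // Iteration count of "do i = low, high, step": e.g. low = 1, high = 10,
    // step = 3 gives (10 - 1 + 3) / 3 = 4 iterations, and the loop body
    // recovers the original induction variable as iv = low + i * step,
    // i.e. 1, 4, 7, 10 for the canonical i = 0, 1, 2, 3.
    static std::int64_t tripCount(std::int64_t low, std::int64_t high,
                                  std::int64_t step) {
      return (high - low + step) / step;
    }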
- auto diff = arith::SubIOp::create(rewriter, loc, high, low); - auto distance = arith::AddIOp::create(rewriter, loc, diff, step); - auto tripCount = arith::DivSIOp::create(rewriter, loc, distance, step); - auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); - auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto diff = mlir::arith::SubIOp::create(rewriter, loc, high, low); + auto distance = mlir::arith::AddIOp::create(rewriter, loc, diff, step); + auto tripCount = + mlir::arith::DivSIOp::create(rewriter, loc, distance, step); + auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); + auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1); auto scfForOp = - scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs); + mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs); auto &loopOps = doLoopOp.getBody()->getOperations(); - auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator()); + auto resultOp = + mlir::cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator()); auto results = resultOp.getOperands(); - Block *loweredBody = scfForOp.getBody(); + mlir::Block *loweredBody = scfForOp.getBody(); loweredBody->getOperations().splice(loweredBody->begin(), loopOps, loopOps.begin(), std::prev(loopOps.end())); rewriter.setInsertionPointToStart(loweredBody); - Value iv = - arith::MulIOp::create(rewriter, loc, scfForOp.getInductionVar(), step); - iv = arith::AddIOp::create(rewriter, loc, low, iv); + mlir::Value iv = mlir::arith::MulIOp::create( + rewriter, loc, scfForOp.getInductionVar(), step); + iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv); if (!results.empty()) { rewriter.setInsertionPointToEnd(loweredBody); - scf::YieldOp::create(rewriter, resultOp->getLoc(), results); + mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results); } doLoopOp.getInductionVar().replaceAllUsesWith(iv); rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(), @@ -84,34 +84,103 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> { // Copy all the attributes from the old to new op. 
scfForOp->setAttrs(doLoopOp->getAttrs()); rewriter.replaceOp(doLoopOp, scfForOp); - return success(); + return mlir::success(); + } +}; + +struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> { + using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(fir::IterWhileOp iterWhileOp, + mlir::PatternRewriter &rewriter) const override { + + mlir::Location loc = iterWhileOp.getLoc(); + mlir::Value lowerBound = iterWhileOp.getLowerBound(); + mlir::Value upperBound = iterWhileOp.getUpperBound(); + mlir::Value step = iterWhileOp.getStep(); + + mlir::Value okInit = iterWhileOp.getIterateIn(); + mlir::ValueRange iterArgs = iterWhileOp.getInitArgs(); + + mlir::SmallVector<mlir::Value> initVals; + initVals.push_back(lowerBound); + initVals.push_back(okInit); + initVals.append(iterArgs.begin(), iterArgs.end()); + + mlir::SmallVector<mlir::Type> loopTypes; + loopTypes.push_back(lowerBound.getType()); + loopTypes.push_back(okInit.getType()); + for (auto val : iterArgs) + loopTypes.push_back(val.getType()); + + auto scfWhileOp = + mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals); + + auto &beforeBlock = *rewriter.createBlock( + &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes, + mlir::SmallVector<mlir::Location>(loopTypes.size(), loc)); + + mlir::Region::BlockArgListType argsInBefore = + scfWhileOp.getBefore().getArguments(); + auto ivInBefore = argsInBefore[0]; + auto earlyExitInBefore = argsInBefore[1]; + + rewriter.setInsertionPointToStart(&beforeBlock); + + mlir::Value inductionCmp = mlir::arith::CmpIOp::create( + rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound); + mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, + earlyExitInBefore); + + mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore); + + rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(), + scfWhileOp.getAfter().begin()); + + auto *afterBody = scfWhileOp.getAfterBody(); + auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator()); + mlir::SmallVector<mlir::Value> results(resultOp->getOperands()); + mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0]; + + rewriter.setInsertionPointToStart(afterBody); + results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step); + + rewriter.setInsertionPointToEnd(afterBody); + rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results); + + scfWhileOp->setAttrs(iterWhileOp->getAttrs()); + rewriter.replaceOp(iterWhileOp, scfWhileOp); + return mlir::success(); } }; -void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock, - Block &dstBlock) { - Operation *srcTerminator = srcBlock.getTerminator(); - auto resultOp = cast<fir::ResultOp>(srcTerminator); +void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter, + mlir::Block &srcBlock, mlir::Block &dstBlock) { + mlir::Operation *srcTerminator = srcBlock.getTerminator(); + auto resultOp = mlir::cast<fir::ResultOp>(srcTerminator); dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(), srcBlock.begin(), std::prev(srcBlock.end())); if (!resultOp->getOperands().empty()) { rewriter.setInsertionPointToEnd(&dstBlock); - scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands()); + mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), + resultOp->getOperands()); } rewriter.eraseOp(srcTerminator); } -struct IfConversion : public OpRewritePattern<fir::IfOp> { +struct IfConversion : public 
mlir::OpRewritePattern<fir::IfOp> { using OpRewritePattern<fir::IfOp>::OpRewritePattern; - LogicalResult matchAndRewrite(fir::IfOp ifOp, - PatternRewriter &rewriter) const override { + mlir::LogicalResult + matchAndRewrite(fir::IfOp ifOp, + mlir::PatternRewriter &rewriter) const override { bool hasElse = !ifOp.getElseRegion().empty(); auto scfIfOp = - scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(), - ifOp.getCondition(), hasElse); + mlir::scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(), + ifOp.getCondition(), hasElse); copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(), scfIfOp.getThenRegion().front()); @@ -123,22 +192,18 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> { scfIfOp->setAttrs(ifOp->getAttrs()); rewriter.replaceOp(ifOp, scfIfOp); - return success(); + return mlir::success(); } }; } // namespace void FIRToSCFPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - patterns.add<DoLoopConversion, IfConversion>(patterns.getContext()); - ConversionTarget target(getContext()); - target.addIllegalOp<fir::DoLoopOp, fir::IfOp>(); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); + mlir::RewritePatternSet patterns(&getContext()); + patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>( + patterns.getContext()); + walkAndApplyPatterns(getOperation(), std::move(patterns)); } -std::unique_ptr<Pass> fir::createFIRToSCFPass() { +std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() { return std::make_unique<FIRToSCFPass>(); } diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp index 5ac4ed8..9dfe26cb 100644 --- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp +++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp @@ -95,10 +95,6 @@ void FunctionAttrPass::runOnOperation() { func->setAttr( mlir::LLVM::LLVMFuncOp::getNoNansFpMathAttrName(llvmFuncOpName), mlir::BoolAttr::get(context, true)); - if (approxFuncFPMath) - func->setAttr( - mlir::LLVM::LLVMFuncOp::getApproxFuncFpMathAttrName(llvmFuncOpName), - mlir::BoolAttr::get(context, true)); if (noSignedZerosFPMath) func->setAttr( mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName), diff --git a/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp b/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp index 1688f28..68f5b5a 100644 --- a/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp +++ b/flang/lib/Optimizer/Transforms/OptimizeArrayRepacking.cpp @@ -26,6 +26,8 @@ namespace fir { #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir +#define DEBUG_TYPE "optimize-array-repacking" + namespace { class OptimizeArrayRepackingPass : public fir::impl::OptimizeArrayRepackingBase<OptimizeArrayRepackingPass> { @@ -56,8 +58,7 @@ PackingOfContiguous::matchAndRewrite(fir::PackArrayOp op, mlir::PatternRewriter &rewriter) const { mlir::Value box = op.getArray(); if (hlfir::isSimplyContiguous(box, !op.getInnermost())) { - rewriter.replaceAllUsesWith(op, box); - rewriter.eraseOp(op); + rewriter.replaceOp(op, box); return mlir::success(); } return mlir::failure(); @@ -78,13 +79,19 @@ void OptimizeArrayRepackingPass::runOnOperation() { mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); mlir::GreedyRewriteConfig config; - config.setRegionSimplificationLevel( - 
mlir::GreedySimplifyRegionLevel::Disabled); + config + .setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled) + // Traverse the operations top-down, so that fir.pack_array + // operations are optimized before the fir.pack_array operations + // that use them. This way the rewrite may converge faster. + .setUseTopDownTraversal(); patterns.insert<PackingOfContiguous>(context); patterns.insert<NoopUnpacking>(context); if (mlir::failed( mlir::applyPatternsGreedily(funcOp, std::move(patterns), config))) { - mlir::emitError(funcOp.getLoc(), "failure in array repacking optimization"); - signalPassFailure(); + // Failure may happen if the rewriter does not converge soon enough. + // That is not an error, so just report a diagnostic under debug. + LLVM_DEBUG(mlir::emitError(funcOp.getLoc(), + "failure in array repacking optimization")); } } diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp index c6aec96..03f97eb 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp @@ -210,19 +210,33 @@ public: mapper.map(region.getArguments(), regionArgs); for (mlir::Operation &op : region.front().without_terminator()) (void)rewriter.clone(op, mapper); + + auto yield = mlir::cast<fir::YieldOp>(region.front().getTerminator()); + assert(yield.getResults().size() < 2); + + return yield.getResults().empty() + ? mlir::Value{} + : mapper.lookup(yield.getResults()[0]); }; - if (!localizer.getInitRegion().empty()) - cloneLocalizerRegion(localizer.getInitRegion(), {localVar, localArg}, - rewriter.getInsertionPoint()); + if (!localizer.getInitRegion().empty()) { + // Prefer the value yielded from the init region to the allocated + // private variable in case the region is operating on arguments + // by-value (e.g. Fortran character boxes).
+ localAlloc = cloneLocalizerRegion(localizer.getInitRegion(), + {localVar, localAlloc}, + rewriter.getInsertionPoint()); + assert(localAlloc); + } if (localizer.getLocalitySpecifierType() == fir::LocalitySpecifierType::LocalInit) - cloneLocalizerRegion(localizer.getCopyRegion(), {localVar, localArg}, + cloneLocalizerRegion(localizer.getCopyRegion(), + {localVar, localAlloc}, rewriter.getInsertionPoint()); if (!localizer.getDeallocRegion().empty()) - cloneLocalizerRegion(localizer.getDeallocRegion(), {localArg}, + cloneLocalizerRegion(localizer.getDeallocRegion(), {localAlloc}, rewriter.getInsertionBlock()->end()); rewriter.replaceAllUsesWith(localArg, localAlloc); diff --git a/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp b/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp index 7d1f86f..0cd2858 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp @@ -26,22 +26,16 @@ class SimplifyRegionLitePass public: void runOnOperation() override; }; - -class DummyRewriter : public mlir::PatternRewriter { -public: - DummyRewriter(mlir::MLIRContext *ctx) : mlir::PatternRewriter(ctx) {} -}; - } // namespace void SimplifyRegionLitePass::runOnOperation() { auto op = getOperation(); auto regions = op->getRegions(); mlir::RewritePatternSet patterns(op.getContext()); - DummyRewriter rewriter(op.getContext()); if (regions.empty()) return; + mlir::PatternRewriter rewriter(op.getContext()); (void)mlir::eraseUnreachableBlocks(rewriter, regions); (void)mlir::runRegionDCE(rewriter, regions); } diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 0d13129..80b3f68 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -600,10 +600,7 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, // replace references to heap allocation with references to stack allocation mlir::Value newValue = convertAllocationType( rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult()); - rewriter.replaceAllUsesWith(allocmem.getResult(), newValue); - - // remove allocmem operation - rewriter.eraseOp(allocmem.getOperation()); + rewriter.replaceOp(allocmem, newValue); return mlir::success(); } @@ -813,10 +810,10 @@ void AllocMemConversion::insertLifetimeMarkers( mlir::OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPoint(oldAlloc); mlir::Value ptr = fir::factory::genLifetimeStart( - rewriter, newAlloc.getLoc(), newAlloc, *size, &*dl); + rewriter, newAlloc.getLoc(), newAlloc, &*dl); visitFreeMemOp(oldAlloc, [&](mlir::Operation *op) { rewriter.setInsertionPoint(op); - fir::factory::genLifetimeEnd(rewriter, op->getLoc(), ptr, *size); + fir::factory::genLifetimeEnd(rewriter, op->getLoc(), ptr); }); newAlloc->setAttr(attrName, rewriter.getUnitAttr()); } diff --git a/flang/lib/Parser/CMakeLists.txt b/flang/lib/Parser/CMakeLists.txt index 1855b8a..20c6c2a 100644 --- a/flang/lib/Parser/CMakeLists.txt +++ b/flang/lib/Parser/CMakeLists.txt @@ -12,6 +12,7 @@ add_flang_library(FortranParser message.cpp openacc-parsers.cpp openmp-parsers.cpp + openmp-utils.cpp parse-tree.cpp parsing.cpp preprocessor.cpp diff --git a/flang/lib/Parser/characters.cpp b/flang/lib/Parser/characters.cpp index f6ac777..1a00b16 100644 --- a/flang/lib/Parser/characters.cpp +++ b/flang/lib/Parser/characters.cpp @@ -289,7 +289,8 @@ RESULT DecodeString(const std::string &s, bool backslashEscapes) { DecodeCharacter<ENCODING>(p, 
bytes, backslashEscapes)}; if (decoded.bytes > 0) { if (static_cast<std::size_t>(decoded.bytes) <= bytes) { - result.append(1, decoded.codepoint); + result.append( + 1, static_cast<typename RESULT::value_type>(decoded.codepoint)); bytes -= decoded.bytes; p += decoded.bytes; continue; diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 84d1e81..ce46a86 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -469,6 +469,9 @@ TYPE_PARSER(sourced(construct<OmpContextSelectorSpecification>( // --- Parsers for clause modifiers ----------------------------------- +TYPE_PARSER(construct<OmpAccessGroup>( // + "CGROUP" >> pure(OmpAccessGroup::Value::Cgroup))) + TYPE_PARSER(construct<OmpAlignment>(scalarIntExpr)) TYPE_PARSER(construct<OmpAlignModifier>( // @@ -573,7 +576,8 @@ TYPE_PARSER(construct<OmpOrderingModifier>( "SIMD" >> pure(OmpOrderingModifier::Value::Simd))) TYPE_PARSER(construct<OmpPrescriptiveness>( - "STRICT" >> pure(OmpPrescriptiveness::Value::Strict))) + "STRICT" >> pure(OmpPrescriptiveness::Value::Strict) || + "FALLBACK" >> pure(OmpPrescriptiveness::Value::Fallback))) TYPE_PARSER(construct<OmpPresentModifier>( // "PRESENT" >> pure(OmpPresentModifier::Value::Present))) @@ -636,6 +640,12 @@ TYPE_PARSER(sourced(construct<OmpDependClause::TaskDep::Modifier>(sourced( construct<OmpDependClause::TaskDep::Modifier>( Parser<OmpTaskDependenceType>{}))))) +TYPE_PARSER( // + sourced(construct<OmpDynGroupprivateClause::Modifier>( + Parser<OmpAccessGroup>{})) || + sourced(construct<OmpDynGroupprivateClause::Modifier>( + Parser<OmpPrescriptiveness>{}))) + TYPE_PARSER( sourced(construct<OmpDeviceClause::Modifier>(Parser<OmpDeviceModifier>{}))) @@ -777,6 +787,10 @@ TYPE_PARSER(construct<OmpDefaultClause>( Parser<OmpDefaultClause::DataSharingAttribute>{}) || construct<OmpDefaultClause>(indirect(Parser<OmpDirectiveSpecification>{})))) +TYPE_PARSER(construct<OmpDynGroupprivateClause>( + maybe(nonemptyList(Parser<OmpDynGroupprivateClause::Modifier>{}) / ":"), + scalarIntExpr)) + TYPE_PARSER(construct<OmpEnterClause>( maybe(nonemptyList(Parser<OmpEnterClause::Modifier>{}) / ":"), Parser<OmpObjectList>{})) @@ -1068,6 +1082,9 @@ TYPE_PARSER( // construct<OmpClause>(parenthesized(Parser<OmpDoacrossClause>{})) || "DYNAMIC_ALLOCATORS" >> construct<OmpClause>(construct<OmpClause::DynamicAllocators>()) || + "DYN_GROUPPRIVATE" >> + construct<OmpClause>(construct<OmpClause::DynGroupprivate>( + parenthesized(Parser<OmpDynGroupprivateClause>{}))) || "ENTER" >> construct<OmpClause>(construct<OmpClause::Enter>( parenthesized(Parser<OmpEnterClause>{}))) || "EXCLUSIVE" >> construct<OmpClause>(construct<OmpClause::Exclusive>( @@ -1264,6 +1281,16 @@ static bool IsFortranBlockConstruct(const ExecutionPartConstruct &epc) { } } +static bool IsStandaloneOrdered(const OmpDirectiveSpecification &dirSpec) { + // An ORDERED construct is standalone if it has DOACROSS or DEPEND clause. 
+ return dirSpec.DirId() == llvm::omp::Directive::OMPD_ordered && + llvm::any_of(dirSpec.Clauses().v, [](const OmpClause &clause) { + llvm::omp::Clause id{clause.Id()}; + return id == llvm::omp::Clause::OMPC_depend || + id == llvm::omp::Clause::OMPC_doacross; + }); +} + struct StrictlyStructuredBlockParser { using resultType = Block; @@ -1272,9 +1299,9 @@ struct StrictlyStructuredBlockParser { if (lookAhead(skipStuffBeforeStatement >> "BLOCK"_tok).Parse(state)) { if (auto epc{Parser<ExecutionPartConstruct>{}.Parse(state)}) { if (IsFortranBlockConstruct(*epc)) { - Block block; - block.emplace_back(std::move(*epc)); - return std::move(block); + Block body; + body.emplace_back(std::move(*epc)); + return std::move(body); } } } @@ -1290,22 +1317,11 @@ struct LooselyStructuredBlockParser { if (lookAhead(skipStuffBeforeStatement >> "BLOCK"_tok).Parse(state)) { return std::nullopt; } - Block body; - if (auto epc{attempt(Parser<ExecutionPartConstruct>{}).Parse(state)}) { - if (!IsFortranBlockConstruct(*epc)) { - body.emplace_back(std::move(*epc)); - if (auto &&blk{attempt(block).Parse(state)}) { - for (auto &&s : *blk) { - body.emplace_back(std::move(s)); - } - } - } else { - // Fail if the first construct is BLOCK. - return std::nullopt; - } + if (auto &&body{block.Parse(state)}) { + // Empty body is ok. + return std::move(body); } - // Empty body is ok. - return std::move(body); + return std::nullopt; } }; @@ -1458,6 +1474,9 @@ struct OmpBlockConstructParser { std::optional<resultType> Parse(ParseState &state) const { if (auto &&begin{OmpBeginDirectiveParser(dir_).Parse(state)}) { + if (IsStandaloneOrdered(*begin)) { + return std::nullopt; + } if (auto &&body{attempt(StrictlyStructuredBlockParser{}).Parse(state)}) { // Try strictly-structured block with an optional end-directive auto end{maybe(OmpEndDirectiveParser{dir_}).Parse(state)}; @@ -1467,11 +1486,14 @@ struct OmpBlockConstructParser { [](auto &&s) { return OmpEndDirective(std::move(s)); })}; } else if (auto &&body{ attempt(LooselyStructuredBlockParser{}).Parse(state)}) { - // Try loosely-structured block with a mandatory end-directive - if (auto end{OmpEndDirectiveParser{dir_}.Parse(state)}) { - return OmpBlockConstruct{OmpBeginDirective(std::move(*begin)), - std::move(*body), OmpEndDirective{std::move(*end)}}; - } + // Try loosely-structured block with a mandatory end-directive. + auto end{maybe(OmpEndDirectiveParser{dir_}).Parse(state)}; + // Delay the error for a missing end-directive until semantics so that + // we have better control over the output. 
+ return OmpBlockConstruct{OmpBeginDirective(std::move(*begin)), + std::move(*body), + llvm::transformOptional(std::move(*end), + [](auto &&s) { return OmpEndDirective(std::move(s)); })}; } } return std::nullopt; @@ -1622,7 +1644,6 @@ TYPE_PARSER(sourced( // static bool IsSimpleStandalone(const OmpDirectiveName &name) { switch (name.v) { case llvm::omp::Directive::OMPD_barrier: - case llvm::omp::Directive::OMPD_ordered: case llvm::omp::Directive::OMPD_scan: case llvm::omp::Directive::OMPD_target_enter_data: case llvm::omp::Directive::OMPD_target_exit_data: @@ -1638,7 +1659,9 @@ static bool IsSimpleStandalone(const OmpDirectiveName &name) { TYPE_PARSER(sourced( // construct<OpenMPSimpleStandaloneConstruct>( predicated(OmpDirectiveNameParser{}, IsSimpleStandalone) >= - Parser<OmpDirectiveSpecification>{}))) + Parser<OmpDirectiveSpecification>{}) || + construct<OpenMPSimpleStandaloneConstruct>( + predicated(Parser<OmpDirectiveSpecification>{}, IsStandaloneOrdered)))) TYPE_PARSER(sourced( // construct<OpenMPFlushConstruct>( @@ -1758,17 +1781,8 @@ TYPE_PARSER(sourced(construct<OpenMPDeclareMapperConstruct>( TYPE_PARSER(construct<OmpReductionCombiner>(Parser<AssignmentStmt>{}) || construct<OmpReductionCombiner>(Parser<FunctionReference>{})) -// 2.13.2 OMP CRITICAL -TYPE_PARSER(startOmpLine >> - sourced(construct<OmpEndCriticalDirective>( - verbatim("END CRITICAL"_tok), maybe(parenthesized(name)))) / - endOmpLine) -TYPE_PARSER(sourced(construct<OmpCriticalDirective>(verbatim("CRITICAL"_tok), - maybe(parenthesized(name)), Parser<OmpClauseList>{})) / - endOmpLine) - TYPE_PARSER(construct<OpenMPCriticalConstruct>( - Parser<OmpCriticalDirective>{}, block, Parser<OmpEndCriticalDirective>{})) + OmpBlockConstructParser{llvm::omp::Directive::OMPD_critical})) // 2.11.3 Executable Allocate directive TYPE_PARSER( @@ -1782,6 +1796,12 @@ TYPE_PARSER(sourced(construct<OpenMPDeclareSimdConstruct>( verbatim("DECLARE SIMD"_tok) || verbatim("DECLARE_SIMD"_tok), maybe(parenthesized(name)), Parser<OmpClauseList>{}))) +TYPE_PARSER(sourced( // + construct<OpenMPGroupprivate>( + predicated(OmpDirectiveNameParser{}, + IsDirective(llvm::omp::Directive::OMPD_groupprivate)) >= + Parser<OmpDirectiveSpecification>{}))) + // 2.4 Requires construct TYPE_PARSER(sourced(construct<OpenMPRequiresConstruct>( verbatim("REQUIRES"_tok), Parser<OmpClauseList>{}))) @@ -1818,6 +1838,8 @@ TYPE_PARSER( construct<OpenMPDeclarativeConstruct>( Parser<OpenMPDeclarativeAllocate>{}) || construct<OpenMPDeclarativeConstruct>( + Parser<OpenMPGroupprivate>{}) || + construct<OpenMPDeclarativeConstruct>( Parser<OpenMPRequiresConstruct>{}) || construct<OpenMPDeclarativeConstruct>( Parser<OpenMPThreadprivate>{}) || @@ -1827,20 +1849,12 @@ TYPE_PARSER( Parser<OmpMetadirectiveDirective>{})) / endOmpLine)) -// Assume Construct -TYPE_PARSER(sourced(construct<OmpAssumeDirective>( - verbatim("ASSUME"_tok), Parser<OmpClauseList>{}))) - -TYPE_PARSER(sourced(construct<OmpEndAssumeDirective>( - startOmpLine >> verbatim("END ASSUME"_tok)))) - -TYPE_PARSER(sourced( - construct<OpenMPAssumeConstruct>(Parser<OmpAssumeDirective>{} / endOmpLine, - block, maybe(Parser<OmpEndAssumeDirective>{} / endOmpLine)))) +TYPE_PARSER(construct<OpenMPAssumeConstruct>( + sourced(OmpBlockConstructParser{llvm::omp::Directive::OMPD_assume}))) // Block Construct #define MakeBlockConstruct(dir) \ - construct<OpenMPBlockConstruct>(OmpBlockConstructParser{dir}) + construct<OmpBlockConstruct>(OmpBlockConstructParser{dir}) TYPE_PARSER( // MakeBlockConstruct(llvm::omp::Directive::OMPD_masked) 
|| MakeBlockConstruct(llvm::omp::Directive::OMPD_master) || @@ -1854,11 +1868,15 @@ TYPE_PARSER( // MakeBlockConstruct(llvm::omp::Directive::OMPD_target_data) || MakeBlockConstruct(llvm::omp::Directive::OMPD_target_parallel) || MakeBlockConstruct(llvm::omp::Directive::OMPD_target_teams) || + MakeBlockConstruct( + llvm::omp::Directive::OMPD_target_teams_workdistribute) || MakeBlockConstruct(llvm::omp::Directive::OMPD_target) || MakeBlockConstruct(llvm::omp::Directive::OMPD_task) || MakeBlockConstruct(llvm::omp::Directive::OMPD_taskgroup) || MakeBlockConstruct(llvm::omp::Directive::OMPD_teams) || - MakeBlockConstruct(llvm::omp::Directive::OMPD_workshare)) + MakeBlockConstruct(llvm::omp::Directive::OMPD_teams_workdistribute) || + MakeBlockConstruct(llvm::omp::Directive::OMPD_workshare) || + MakeBlockConstruct(llvm::omp::Directive::OMPD_workdistribute)) #undef MakeBlockConstruct // OMP SECTIONS Directive @@ -1887,7 +1905,7 @@ TYPE_PARSER(sourced(construct<OpenMPSectionsConstruct>( construct<OpenMPSectionConstruct>(maybe(sectionDir), block))), many(construct<OpenMPConstruct>( sourced(construct<OpenMPSectionConstruct>(sectionDir, block))))), - Parser<OmpEndSectionsDirective>{} / endOmpLine))) + maybe(Parser<OmpEndSectionsDirective>{} / endOmpLine)))) static bool IsExecutionPart(const OmpDirectiveName &name) { return name.IsExecutionPart(); @@ -1901,8 +1919,8 @@ TYPE_CONTEXT_PARSER("OpenMP construct"_en_US, withMessage("expected OpenMP construct"_err_en_US, first(construct<OpenMPConstruct>(Parser<OpenMPSectionsConstruct>{}), construct<OpenMPConstruct>(Parser<OpenMPLoopConstruct>{}), - construct<OpenMPConstruct>(Parser<OpenMPBlockConstruct>{}), - // OpenMPBlockConstruct is attempted before + construct<OpenMPConstruct>(Parser<OmpBlockConstruct>{}), + // OmpBlockConstruct is attempted before // OpenMPStandaloneConstruct to resolve !$OMP ORDERED construct<OpenMPConstruct>(Parser<OpenMPStandaloneConstruct>{}), construct<OpenMPConstruct>(Parser<OpenMPAtomicConstruct>{}), diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp new file mode 100644 index 0000000..ef7e4fc --- /dev/null +++ b/flang/lib/Parser/openmp-utils.cpp @@ -0,0 +1,64 @@ +//===-- flang/Parser/openmp-utils.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Common OpenMP utilities. 
+// +//===----------------------------------------------------------------------===// + +#include "flang/Parser/openmp-utils.h" + +#include "flang/Common/template.h" +#include "flang/Common/visit.h" + +#include <tuple> +#include <type_traits> +#include <variant> + +namespace Fortran::parser::omp { + +const OmpObjectList *GetOmpObjectList(const OmpClause &clause) { + // Clauses with OmpObjectList as its data member + using MemberObjectListClauses = std::tuple<OmpClause::Copyin, + OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate, + OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr, + OmpClause::Link, OmpClause::Private, OmpClause::Shared, + OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>; + + // Clauses with OmpObjectList in the tuple + using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs, + OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate, + OmpClause::Enter, OmpClause::From, OmpClause::InReduction, + OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map, + OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>; + + // TODO:: Generate the tuples using TableGen. + return common::visit( + common::visitors{ + [&](const OmpClause::Depend &x) -> const OmpObjectList * { + if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) { + return &std::get<OmpObjectList>(taskDep->t); + } else { + return nullptr; + } + }, + [&](const auto &x) -> const OmpObjectList * { + using Ty = std::decay_t<decltype(x)>; + if constexpr (common::HasMember<Ty, MemberObjectListClauses>) { + return &x.v; + } else if constexpr (common::HasMember<Ty, + TupleObjectListClauses>) { + return &std::get<OmpObjectList>(x.v.t); + } else { + return nullptr; + } + }, + }, + clause.u); +} + +} // namespace Fortran::parser::omp diff --git a/flang/lib/Parser/parsing.cpp b/flang/lib/Parser/parsing.cpp index ceea747..8a8c6ef 100644 --- a/flang/lib/Parser/parsing.cpp +++ b/flang/lib/Parser/parsing.cpp @@ -96,9 +96,6 @@ const SourceFile *Parsing::Prescan(const std::string &path, Options options) { prescanner.AddCompilerDirectiveSentinel("$cuf"); prescanner.AddCompilerDirectiveSentinel("@cuf"); } - if (options.features.IsEnabled(LanguageFeature::CUDA)) { - preprocessor_.Define("_CUDA", "1"); - } ProvenanceRange range{allSources.AddIncludedFile( *sourceFile, ProvenanceRange{}, options.isModuleFile)}; prescanner.Prescan(range); diff --git a/flang/lib/Parser/preprocessor.cpp b/flang/lib/Parser/preprocessor.cpp index 0aadc41..9176b4d 100644 --- a/flang/lib/Parser/preprocessor.cpp +++ b/flang/lib/Parser/preprocessor.cpp @@ -414,7 +414,7 @@ std::optional<TokenSequence> Preprocessor::MacroReplacement( const TokenSequence &input, Prescanner &prescanner, std::optional<std::size_t> *partialFunctionLikeMacro, bool inIfExpression) { // Do quick scan for any use of a defined name. 
- if (definitions_.empty()) { + if (!inIfExpression && definitions_.empty()) { return std::nullopt; } std::size_t tokens{input.SizeInTokens()}; @@ -742,12 +742,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) { "# missing or invalid name"_err_en_US); } else { if (dir.IsAnythingLeft(++j)) { - if (prescanner.features().ShouldWarn( - common::UsageWarning::Portability)) { - prescanner.Say(common::UsageWarning::Portability, - dir.GetIntervalProvenanceRange(j, tokens - j), - "#undef: excess tokens at end of directive"_port_en_US); - } + prescanner.Warn(common::UsageWarning::Portability, + dir.GetIntervalProvenanceRange(j, tokens - j), + "#undef: excess tokens at end of directive"_port_en_US); } else { definitions_.erase(nameToken); } @@ -760,12 +757,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) { "#%s: missing name"_err_en_US, dirName); } else { if (dir.IsAnythingLeft(++j)) { - if (prescanner.features().ShouldWarn( - common::UsageWarning::Portability)) { - prescanner.Say(common::UsageWarning::Portability, - dir.GetIntervalProvenanceRange(j, tokens - j), - "#%s: excess tokens at end of directive"_port_en_US, dirName); - } + prescanner.Warn(common::UsageWarning::Portability, + dir.GetIntervalProvenanceRange(j, tokens - j), + "#%s: excess tokens at end of directive"_port_en_US, dirName); } doThen = IsNameDefined(nameToken) == (dirName == "ifdef"); } @@ -784,11 +778,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) { } } else if (dirName == "else") { if (dir.IsAnythingLeft(j)) { - if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) { - prescanner.Say(common::UsageWarning::Portability, - dir.GetIntervalProvenanceRange(j, tokens - j), - "#else: excess tokens at end of directive"_port_en_US); - } + prescanner.Warn(common::UsageWarning::Portability, + dir.GetIntervalProvenanceRange(j, tokens - j), + "#else: excess tokens at end of directive"_port_en_US); } if (ifStack_.empty()) { prescanner.Say(dir.GetTokenProvenanceRange(dirOffset), @@ -815,11 +807,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) { } } else if (dirName == "endif") { if (dir.IsAnythingLeft(j)) { - if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) { - prescanner.Say(common::UsageWarning::Portability, - dir.GetIntervalProvenanceRange(j, tokens - j), - "#endif: excess tokens at end of directive"_port_en_US); - } + prescanner.Warn(common::UsageWarning::Portability, + dir.GetIntervalProvenanceRange(j, tokens - j), + "#endif: excess tokens at end of directive"_port_en_US); } else if (ifStack_.empty()) { prescanner.Say(dir.GetTokenProvenanceRange(dirOffset), "#endif: no #if, #ifdef, or #ifndef"_err_en_US); @@ -866,12 +856,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) { ++k; } if (k >= pathTokens) { - if (prescanner.features().ShouldWarn( - common::UsageWarning::Portability)) { - prescanner.Say(common::UsageWarning::Portability, - dir.GetIntervalProvenanceRange(j, tokens - j), - "#include: expected '>' at end of included file"_port_en_US); - } + prescanner.Warn(common::UsageWarning::Portability, + dir.GetIntervalProvenanceRange(j, tokens - j), + "#include: expected '>' at end of included file"_port_en_US); } TokenSequence braced{path, 1, k - 1}; include = braced.ToString(); @@ -897,11 +884,9 @@ void Preprocessor::Directive(const TokenSequence &dir, Prescanner &prescanner) { } k = path.SkipBlanks(k + 1); if (k < pathTokens && 
path.TokenAt(k).ToString() != "!") { - if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) { - prescanner.Say(common::UsageWarning::Portability, - dir.GetIntervalProvenanceRange(j, tokens - j), - "#include: extra stuff ignored after file name"_port_en_US); - } + prescanner.Warn(common::UsageWarning::Portability, + dir.GetIntervalProvenanceRange(j, tokens - j), + "#include: extra stuff ignored after file name"_port_en_US); } std::string buf; llvm::raw_string_ostream error{buf}; diff --git a/flang/lib/Parser/prescan.h b/flang/lib/Parser/prescan.h index f650d54..c181c03 100644 --- a/flang/lib/Parser/prescan.h +++ b/flang/lib/Parser/prescan.h @@ -91,6 +91,15 @@ public: return messages_.Say(std::forward<A>(a)...); } + template <typename... A> + Message *Warn(common::UsageWarning warning, A &&...a) { + return messages_.Warn(false, features_, warning, std::forward<A>(a)...); + } + template <typename... A> + Message *Warn(common::LanguageFeature feature, A &&...a) { + return messages_.Warn(false, features_, feature, std::forward<A>(a)...); + } + private: struct LineClassification { enum class Kind { diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 46141e2..dc6d336 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2250,6 +2250,11 @@ public: Walk(std::get<OmpObjectList>(x.t)); Walk(": ", std::get<std::optional<std::list<Modifier>>>(x.t)); } + void Unparse(const OmpDynGroupprivateClause &x) { + using Modifier = OmpDynGroupprivateClause::Modifier; + Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ": "); + Walk(std::get<ScalarIntExpr>(x.t)); + } void Unparse(const OmpEnterClause &x) { using Modifier = OmpEnterClause::Modifier; Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ": "); @@ -2575,40 +2580,14 @@ public: Put("\n"); EndOpenMP(); } - void Unparse(const OpenMPAllocatorsConstruct &x) { // + void Unparse(const OpenMPAllocatorsConstruct &x) { Unparse(static_cast<const OmpBlockConstruct &>(x)); } - void Unparse(const OmpAssumeDirective &x) { - BeginOpenMP(); - Word("!$OMP ASSUME"); - Walk(" ", std::get<OmpClauseList>(x.t).v); - Put("\n"); - EndOpenMP(); - } - void Unparse(const OmpEndAssumeDirective &x) { - BeginOpenMP(); - Word("!$OMP END ASSUME\n"); - EndOpenMP(); - } - void Unparse(const OmpCriticalDirective &x) { - BeginOpenMP(); - Word("!$OMP CRITICAL"); - Walk(" (", std::get<std::optional<Name>>(x.t), ")"); - Walk(std::get<OmpClauseList>(x.t)); - Put("\n"); - EndOpenMP(); - } - void Unparse(const OmpEndCriticalDirective &x) { - BeginOpenMP(); - Word("!$OMP END CRITICAL"); - Walk(" (", std::get<std::optional<Name>>(x.t), ")"); - Put("\n"); - EndOpenMP(); + void Unparse(const OpenMPAssumeConstruct &x) { + Unparse(static_cast<const OmpBlockConstruct &>(x)); } void Unparse(const OpenMPCriticalConstruct &x) { - Walk(std::get<OmpCriticalDirective>(x.t)); - Walk(std::get<Block>(x.t), ""); - Walk(std::get<OmpEndCriticalDirective>(x.t)); + Unparse(static_cast<const OmpBlockConstruct &>(x)); } void Unparse(const OmpDeclareTargetWithList &x) { Put("("), Walk(x.v), Put(")"); @@ -2718,6 +2697,13 @@ public: void Unparse(const OpenMPDispatchConstruct &x) { // Unparse(static_cast<const OmpBlockConstruct &>(x)); } + void Unparse(const OpenMPGroupprivate &x) { + BeginOpenMP(); + Word("!$OMP "); + Walk(x.v); + Put("\n"); + EndOpenMP(); + } void Unparse(const OpenMPRequiresConstruct &y) { BeginOpenMP(); Word("!$OMP REQUIRES "); @@ -2778,7 +2764,7 @@ public: Walk(std::get<std::list<OpenMPConstruct>>(x.t), ""); 
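The Prescanner::Warn overloads added to prescan.h above let the preprocessor call sites drop their explicit ShouldWarn guards, since the warning gating now happens inside messages_.Warn. A side-by-side of the call-site change, taking one of the #undef diagnostics from this patch as the example (both forms are quoted from the diff itself):
    // Before: every site guards the diagnostic itself.
    //   if (prescanner.features().ShouldWarn(common::UsageWarning::Portability)) {
    //     prescanner.Say(common::UsageWarning::Portability,
    //         dir.GetIntervalProvenanceRange(j, tokens - j),
    //         "#undef: excess tokens at end of directive"_port_en_US);
    //   }
    // After: the guard is folded into the helper.
    //   prescanner.Warn(common::UsageWarning::Portability,
    //       dir.GetIntervalProvenanceRange(j, tokens - j),
    //       "#undef: excess tokens at end of directive"_port_en_US);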
BeginOpenMP(); Word("!$OMP END "); - Walk(std::get<OmpEndSectionsDirective>(x.t)); + Walk(std::get<std::optional<OmpEndSectionsDirective>>(x.t)); Put("\n"); EndOpenMP(); } @@ -2847,9 +2833,6 @@ public: Put("\n"); EndOpenMP(); } - void Unparse(const OpenMPBlockConstruct &x) { - Unparse(static_cast<const OmpBlockConstruct &>(x)); - } void Unparse(const OpenMPLoopConstruct &x) { BeginOpenMP(); Word("!$OMP "); @@ -2943,6 +2926,7 @@ public: WALK_NESTED_ENUM(OmpTaskDependenceType, Value) // OMP task-dependence-type WALK_NESTED_ENUM(OmpScheduleClause, Kind) // OMP schedule-kind WALK_NESTED_ENUM(OmpSeverityClause, Severity) // OMP severity + WALK_NESTED_ENUM(OmpAccessGroup, Value) WALK_NESTED_ENUM(OmpDeviceModifier, Value) // OMP device modifier WALK_NESTED_ENUM( OmpDeviceTypeClause, DeviceTypeDescription) // OMP device_type diff --git a/flang/lib/Semantics/check-acc-structure.cpp b/flang/lib/Semantics/check-acc-structure.cpp index 051abdc..6cb7e5e 100644 --- a/flang/lib/Semantics/check-acc-structure.cpp +++ b/flang/lib/Semantics/check-acc-structure.cpp @@ -983,24 +983,26 @@ void AccStructureChecker::Enter(const parser::AccClause::Reduction &reduction) { [&](const parser::Designator &designator) { if (const auto *name = getDesignatorNameIfDataRef(designator)) { if (name->symbol) { - const auto *type{name->symbol->GetType()}; - if (type->IsNumeric(TypeCategory::Integer) && - !reductionIntegerSet.test(op.v)) { - context_.Say(GetContext().clauseSource, - "reduction operator not supported for integer type"_err_en_US); - } else if (type->IsNumeric(TypeCategory::Real) && - !reductionRealSet.test(op.v)) { - context_.Say(GetContext().clauseSource, - "reduction operator not supported for real type"_err_en_US); - } else if (type->IsNumeric(TypeCategory::Complex) && - !reductionComplexSet.test(op.v)) { - context_.Say(GetContext().clauseSource, - "reduction operator not supported for complex type"_err_en_US); - } else if (type->category() == - Fortran::semantics::DeclTypeSpec::Category::Logical && - !reductionLogicalSet.test(op.v)) { - context_.Say(GetContext().clauseSource, - "reduction operator not supported for logical type"_err_en_US); + if (const auto *type{name->symbol->GetType()}) { + if (type->IsNumeric(TypeCategory::Integer) && + !reductionIntegerSet.test(op.v)) { + context_.Say(GetContext().clauseSource, + "reduction operator not supported for integer type"_err_en_US); + } else if (type->IsNumeric(TypeCategory::Real) && + !reductionRealSet.test(op.v)) { + context_.Say(GetContext().clauseSource, + "reduction operator not supported for real type"_err_en_US); + } else if (type->IsNumeric(TypeCategory::Complex) && + !reductionComplexSet.test(op.v)) { + context_.Say(GetContext().clauseSource, + "reduction operator not supported for complex type"_err_en_US); + } else if (type->category() == + Fortran::semantics::DeclTypeSpec::Category:: + Logical && + !reductionLogicalSet.test(op.v)) { + context_.Say(GetContext().clauseSource, + "reduction operator not supported for logical type"_err_en_US); + } } // TODO: check composite type. 
} diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp index 0805359..823aa4e 100644 --- a/flang/lib/Semantics/check-allocate.cpp +++ b/flang/lib/Semantics/check-allocate.cpp @@ -548,7 +548,7 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { } } // Shape related checks - if (ultimate_ && evaluate::IsAssumedRank(*ultimate_)) { + if (ultimate_ && IsAssumedRank(*ultimate_)) { context.Say(name_.source, "An assumed-rank dummy argument may not appear in an ALLOCATE statement"_err_en_US); return false; diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp index 6f250328..f0078fd 100644 --- a/flang/lib/Semantics/check-call.cpp +++ b/flang/lib/Semantics/check-call.cpp @@ -67,7 +67,7 @@ static void CheckImplicitInterfaceArg(evaluate::ActualArgument &arg, "Null pointer argument requires an explicit interface"_err_en_US); } else if (auto named{evaluate::ExtractNamedEntity(*expr)}) { const Symbol &symbol{named->GetLastSymbol()}; - if (evaluate::IsAssumedRank(symbol)) { + if (IsAssumedRank(symbol)) { messages.Say( "Assumed rank argument requires an explicit interface"_err_en_US); } @@ -131,7 +131,7 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual, dummy.type.type().kind() == actualType.type().kind() && !dummy.attrs.test( characteristics::DummyDataObject::Attr::DeducedFromActual)) { - bool actualIsAssumedRank{evaluate::IsAssumedRank(actual)}; + bool actualIsAssumedRank{IsAssumedRank(actual)}; if (actualIsAssumedRank && !dummy.type.attrs().test( characteristics::TypeAndShape::Attr::AssumedRank)) { @@ -140,7 +140,8 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual, messages.Say( "Assumed-rank character array may not be associated with a dummy argument that is not assumed-rank"_err_en_US); } else { - context.Warn(common::LanguageFeature::AssumedRankPassedToNonAssumedRank, + context.Warn(messages, + common::LanguageFeature::AssumedRankPassedToNonAssumedRank, messages.at(), "Assumed-rank character array should not be associated with a dummy argument that is not assumed-rank"_port_en_US); } @@ -187,9 +188,9 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual, "Actual argument has fewer characters remaining in storage sequence (%jd) than %s (%jd)"_err_en_US, static_cast<std::intmax_t>(actualChars), dummyName, static_cast<std::intmax_t>(dummyChars)); - } else if (context.ShouldWarn( - common::UsageWarning::ShortCharacterActual)) { - messages.Say(common::UsageWarning::ShortCharacterActual, + } else { + context.Warn(messages, + common::UsageWarning::ShortCharacterActual, "Actual argument has fewer characters remaining in storage sequence (%jd) than %s (%jd)"_warn_en_US, static_cast<std::intmax_t>(actualChars), dummyName, static_cast<std::intmax_t>(dummyChars)); @@ -207,9 +208,9 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> &actual, static_cast<std::intmax_t>(*actualSize * *actualLength), dummyName, static_cast<std::intmax_t>(*dummySize * *dummyLength)); - } else if (context.ShouldWarn( - common::UsageWarning::ShortCharacterActual)) { - messages.Say(common::UsageWarning::ShortCharacterActual, + } else { + context.Warn(messages, + common::UsageWarning::ShortCharacterActual, "Actual argument array has fewer characters (%jd) than %s array (%jd)"_warn_en_US, static_cast<std::intmax_t>(*actualSize * *actualLength), dummyName, @@ -229,17 +230,14 @@ static void CheckCharacterActual(evaluate::Expr<evaluate::SomeType> 
&actual, } else if (*actualLength < *dummyLength) { CHECK(dummy.type.Rank() == 0); bool isVariable{evaluate::IsVariable(actual)}; - if (context.ShouldWarn( - common::UsageWarning::ShortCharacterActual)) { - if (isVariable) { - messages.Say(common::UsageWarning::ShortCharacterActual, - "Actual argument variable length '%jd' is less than expected length '%jd'"_warn_en_US, - *actualLength, *dummyLength); - } else { - messages.Say(common::UsageWarning::ShortCharacterActual, - "Actual argument expression length '%jd' is less than expected length '%jd'"_warn_en_US, - *actualLength, *dummyLength); - } + if (isVariable) { + context.Warn(messages, common::UsageWarning::ShortCharacterActual, + "Actual argument variable length '%jd' is less than expected length '%jd'"_warn_en_US, + *actualLength, *dummyLength); + } else { + context.Warn(messages, common::UsageWarning::ShortCharacterActual, + "Actual argument expression length '%jd' is less than expected length '%jd'"_warn_en_US, + *actualLength, *dummyLength); } if (!isVariable) { auto converted{ @@ -279,9 +277,8 @@ static void ConvertIntegerActual(evaluate::Expr<evaluate::SomeType> &actual, messages.Say( "Actual argument scalar expression of type INTEGER(%d) cannot be implicitly converted to smaller dummy argument type INTEGER(%d)"_err_en_US, actualType.type().kind(), dummyType.type().kind()); - } else if (semanticsContext.ShouldWarn(common::LanguageFeature:: - ActualIntegerConvertedToSmallerKind)) { - messages.Say( + } else { + semanticsContext.Warn(messages, common::LanguageFeature::ActualIntegerConvertedToSmallerKind, "Actual argument scalar expression of type INTEGER(%d) was converted to smaller dummy argument type INTEGER(%d)"_port_en_US, actualType.type().kind(), dummyType.type().kind()); @@ -364,20 +361,16 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, if (const auto *constantChar{ evaluate::UnwrapConstantValue<evaluate::Ascii>(actual)}; constantChar && constantChar->wasHollerith() && - dummy.type.type().IsUnlimitedPolymorphic() && - context.ShouldWarn(common::LanguageFeature::HollerithPolymorphic)) { - messages.Say(common::LanguageFeature::HollerithPolymorphic, + dummy.type.type().IsUnlimitedPolymorphic()) { + foldingContext.Warn(common::LanguageFeature::HollerithPolymorphic, "passing Hollerith to unlimited polymorphic as if it were CHARACTER"_port_en_US); } } else if (dummyRank == 0 && allowActualArgumentConversions) { // Extension: pass Hollerith literal to scalar as if it had been BOZ if (auto converted{evaluate::HollerithToBOZ( foldingContext, actual, dummy.type.type())}) { - if (context.ShouldWarn( - common::LanguageFeature::HollerithOrCharacterAsBOZ)) { - messages.Say(common::LanguageFeature::HollerithOrCharacterAsBOZ, - "passing Hollerith or character literal as if it were BOZ"_port_en_US); - } + foldingContext.Warn(common::LanguageFeature::HollerithOrCharacterAsBOZ, + "passing Hollerith or character literal as if it were BOZ"_port_en_US); actual = *converted; actualType.type() = dummy.type.type(); typesCompatible = true; @@ -387,7 +380,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, characteristics::TypeAndShape::Attr::AssumedRank)}; bool actualIsAssumedSize{actualType.attrs().test( characteristics::TypeAndShape::Attr::AssumedSize)}; - bool actualIsAssumedRank{evaluate::IsAssumedRank(actual)}; + bool actualIsAssumedRank{IsAssumedRank(actual)}; bool actualIsPointer{evaluate::IsObjectPointer(actual)}; bool 
actualIsAllocatable{evaluate::IsAllocatableDesignator(actual)}; bool actualMayBeAssumedSize{actualIsAssumedSize || @@ -411,7 +404,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, "%s actual argument may not be associated with INTENT(OUT) assumed-rank dummy argument requiring finalization, destruction, or initialization"_err_en_US, actualDesc); } else { - context.Warn(common::UsageWarning::Portability, messages.at(), + foldingContext.Warn(common::UsageWarning::Portability, messages.at(), "%s actual argument should not be associated with INTENT(OUT) assumed-rank dummy argument"_port_en_US, actualDesc); } @@ -671,9 +664,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, "Actual argument has fewer elements remaining in storage sequence (%jd) than %s array (%jd)"_err_en_US, static_cast<std::intmax_t>(*actualElements), dummyName, static_cast<std::intmax_t>(*dummySize)); - } else if (context.ShouldWarn( - common::UsageWarning::ShortArrayActual)) { - messages.Say(common::UsageWarning::ShortArrayActual, + } else { + context.Warn(common::UsageWarning::ShortArrayActual, "Actual argument has fewer elements remaining in storage sequence (%jd) than %s array (%jd)"_warn_en_US, static_cast<std::intmax_t>(*actualElements), dummyName, static_cast<std::intmax_t>(*dummySize)); @@ -690,9 +682,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, "Actual argument array has fewer elements (%jd) than %s array (%jd)"_err_en_US, static_cast<std::intmax_t>(*actualSize), dummyName, static_cast<std::intmax_t>(*dummySize)); - } else if (context.ShouldWarn( - common::UsageWarning::ShortArrayActual)) { - messages.Say(common::UsageWarning::ShortArrayActual, + } else { + context.Warn(common::UsageWarning::ShortArrayActual, "Actual argument array has fewer elements (%jd) than %s array (%jd)"_warn_en_US, static_cast<std::intmax_t>(*actualSize), dummyName, static_cast<std::intmax_t>(*dummySize)); @@ -779,24 +770,36 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, // Cases when temporaries might be needed but must not be permitted. 
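For orientation, a minimal hypothetical Fortran sketch (illustrative names, not part of this patch) of the association that the F'2023 C1548/C1549 check below rejects: a non-contiguous ASYNCHRONOUS actual argument passed to a contiguous ASYNCHRONOUS dummy.

  subroutine caller
    interface
      subroutine sink(a)
        real, asynchronous, contiguous :: a(:)  ! contiguous ASYNCHRONOUS dummy
      end subroutine
    end interface
    real, asynchronous :: x(10)
    call sink(x(1:10:2))  ! strided section is not simply contiguous: rejected by the check below
  end subroutine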
bool dummyIsAssumedShape{dummy.type.attrs().test( characteristics::TypeAndShape::Attr::AssumedShape)}; - if ((actualIsAsynchronous || actualIsVolatile) && - (dummyIsAsynchronous || dummyIsVolatile) && !dummyIsValue) { - if (actualCoarrayRef) { // C1538 - messages.Say( - "Coindexed ASYNCHRONOUS or VOLATILE actual argument may not be associated with %s with ASYNCHRONOUS or VOLATILE attributes unless VALUE"_err_en_US, - dummyName); - } - if ((actualRank > 0 || actualIsAssumedRank) && !actualIsContiguous) { - if (dummyIsContiguous || - !(dummyIsAssumedShape || dummyIsAssumedRank || - (actualIsPointer && dummyIsPointer))) { // C1539 & C1540 + if (!dummyIsValue && (dummyIsAsynchronous || dummyIsVolatile)) { + if (actualIsAsynchronous || actualIsVolatile) { + if (actualCoarrayRef) { // F'2023 C1547 messages.Say( - "ASYNCHRONOUS or VOLATILE actual argument that is not simply contiguous may not be associated with a contiguous ASYNCHRONOUS or VOLATILE %s"_err_en_US, + "Coindexed ASYNCHRONOUS or VOLATILE actual argument may not be associated with %s with ASYNCHRONOUS or VOLATILE attributes unless VALUE"_err_en_US, dummyName); } + if ((actualRank > 0 || actualIsAssumedRank) && !actualIsContiguous) { + if (dummyIsContiguous || + !(dummyIsAssumedShape || dummyIsAssumedRank || + (actualIsPointer && dummyIsPointer))) { // F'2023 C1548 & C1549 + messages.Say( + "ASYNCHRONOUS or VOLATILE actual argument that is not simply contiguous may not be associated with a contiguous ASYNCHRONOUS or VOLATILE %s"_err_en_US, + dummyName); + } + } + // The vector subscript case is handled by the definability check above. + // The copy-in/copy-out cases are handled by the previous checks. + // Nag, GFortran, and NVFortran all error on this case, even though it is + // ok, possibly as an over-restriction of C1548. + } else if (!(dummyIsAssumedShape || dummyIsAssumedRank || + (actualIsPointer && dummyIsPointer)) && + evaluate::IsArraySection(actual) && + !evaluate::HasVectorSubscript(actual)) { + context.Warn(common::UsageWarning::Portability, messages.at(), + "The array section '%s' should not be associated with %s with %s attribute, unless the dummy is assumed-shape or assumed-rank"_port_en_US, + actual.AsFortran(), dummyName, + dummyIsAsynchronous ? 
"ASYNCHRONOUS" : "VOLATILE"); } } - // 15.5.2.6 -- dummy is ALLOCATABLE bool dummyIsOptional{ dummy.attrs.test(characteristics::DummyDataObject::Attr::Optional)}; @@ -821,10 +824,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, messages.Say( "A null pointer should not be associated with allocatable %s without INTENT(IN)"_warn_en_US, dummyName); - } else if (dummy.intent == common::Intent::In && - context.ShouldWarn( - common::LanguageFeature::NullActualForAllocatable)) { - messages.Say(common::LanguageFeature::NullActualForAllocatable, + } else if (dummy.intent == common::Intent::In) { + foldingContext.Warn(common::LanguageFeature::NullActualForAllocatable, "Allocatable %s is associated with a null pointer"_port_en_US, dummyName); } @@ -878,11 +879,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, checkTypeCompatibility = false; if (dummyIsUnlimited && dummy.intent == common::Intent::In && context.IsEnabled(common::LanguageFeature::RelaxedIntentInChecking)) { - if (context.ShouldWarn( - common::LanguageFeature::RelaxedIntentInChecking)) { - messages.Say(common::LanguageFeature::RelaxedIntentInChecking, - "If a POINTER or ALLOCATABLE dummy or actual argument is unlimited polymorphic, both should be so"_port_en_US); - } + foldingContext.Warn(common::LanguageFeature::RelaxedIntentInChecking, + "If a POINTER or ALLOCATABLE dummy or actual argument is unlimited polymorphic, both should be so"_port_en_US); } else { messages.Say( "If a POINTER or ALLOCATABLE dummy or actual argument is unlimited polymorphic, both must be so"_err_en_US); @@ -890,21 +888,15 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, } else if (dummyIsPolymorphic != actualIsPolymorphic) { if (dummyIsPolymorphic && dummy.intent == common::Intent::In && context.IsEnabled(common::LanguageFeature::RelaxedIntentInChecking)) { - if (context.ShouldWarn( - common::LanguageFeature::RelaxedIntentInChecking)) { - messages.Say(common::LanguageFeature::RelaxedIntentInChecking, - "If a POINTER or ALLOCATABLE dummy or actual argument is polymorphic, both should be so"_port_en_US); - } + foldingContext.Warn(common::LanguageFeature::RelaxedIntentInChecking, + "If a POINTER or ALLOCATABLE dummy or actual argument is polymorphic, both should be so"_port_en_US); } else if (actualIsPolymorphic && context.IsEnabled(common::LanguageFeature:: PolymorphicActualAllocatableOrPointerToMonomorphicDummy)) { - if (context.ShouldWarn(common::LanguageFeature:: - PolymorphicActualAllocatableOrPointerToMonomorphicDummy)) { - messages.Say( - common::LanguageFeature:: - PolymorphicActualAllocatableOrPointerToMonomorphicDummy, - "If a POINTER or ALLOCATABLE actual argument is polymorphic, the corresponding dummy argument should also be so"_port_en_US); - } + foldingContext.Warn( + common::LanguageFeature:: + PolymorphicActualAllocatableOrPointerToMonomorphicDummy, + "If a POINTER or ALLOCATABLE actual argument is polymorphic, the corresponding dummy argument should also be so"_port_en_US); } else { checkTypeCompatibility = false; messages.Say( @@ -916,11 +908,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, if (dummy.intent == common::Intent::In && context.IsEnabled( common::LanguageFeature::RelaxedIntentInChecking)) { - if (context.ShouldWarn( - common::LanguageFeature::RelaxedIntentInChecking)) { - messages.Say(common::LanguageFeature::RelaxedIntentInChecking, - "POINTER or ALLOCATABLE dummy and actual arguments should have 
the same declared type and kind"_port_en_US); - } + foldingContext.Warn(common::LanguageFeature::RelaxedIntentInChecking, + "POINTER or ALLOCATABLE dummy and actual arguments should have the same declared type and kind"_port_en_US); } else { messages.Say( "POINTER or ALLOCATABLE dummy and actual arguments must have the same declared type and kind"_err_en_US); @@ -991,13 +980,13 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, bool actualIsTemp{ !actualIsVariable || HasVectorSubscript(actual) || actualCoarrayRef}; if (actualIsTemp) { - messages.Say(common::UsageWarning::NonTargetPassedToTarget, + foldingContext.Warn(common::UsageWarning::NonTargetPassedToTarget, "Any pointer associated with TARGET %s during this call will not be associated with the value of '%s' afterwards"_warn_en_US, dummyName, actual.AsFortran()); } else { auto actualSymbolVector{GetSymbolVector(actual)}; if (!evaluate::GetLastTarget(actualSymbolVector)) { - messages.Say(common::UsageWarning::NonTargetPassedToTarget, + foldingContext.Warn(common::UsageWarning::NonTargetPassedToTarget, "Any pointer associated with TARGET %s during this call must not be used afterwards, as '%s' is not a target"_warn_en_US, dummyName, actual.AsFortran()); } @@ -1058,12 +1047,11 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, dummyName); } } - std::optional<std::string> warning; bool isHostDeviceProc{procedure.cudaSubprogramAttrs && *procedure.cudaSubprogramAttrs == common::CUDASubprogramAttrs::HostDevice}; if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr, - dummy.ignoreTKR, &warning, /*allowUnifiedMatchingRule=*/true, + dummy.ignoreTKR, /*allowUnifiedMatchingRule=*/true, isHostDeviceProc, &context.languageFeatures())) { auto toStr{[](std::optional<common::CUDADataAttr> x) { return x ? 
"ATTRIBUTES("s + @@ -1074,10 +1062,6 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy, "%s has %s but its associated actual argument has %s"_err_en_US, dummyName, toStr(dummyDataAttr), toStr(actualDataAttr)); } - if (warning && context.ShouldWarn(common::UsageWarning::CUDAUsage)) { - messages.Say(common::UsageWarning::CUDAUsage, "%s"_warn_en_US, - std::move(*warning)); - } } // Warning for breaking F'2023 change with character allocatables @@ -1131,9 +1115,8 @@ static void CheckProcedureArg(evaluate::ActualArgument &arg, evaluate::SayWithDeclaration(messages, *argProcSymbol, "Procedure binding '%s' passed as an actual argument"_err_en_US, argProcSymbol->name()); - } else if (context.ShouldWarn( - common::LanguageFeature::BindingAsProcedure)) { - evaluate::SayWithDeclaration(messages, *argProcSymbol, + } else { + evaluate::WarnWithDeclaration(foldingContext, *argProcSymbol, common::LanguageFeature::BindingAsProcedure, "Procedure binding '%s' passed as an actual argument"_port_en_US, argProcSymbol->name()); @@ -1185,15 +1168,14 @@ static void CheckProcedureArg(evaluate::ActualArgument &arg, messages.Say( "Actual procedure argument for %s of a PURE procedure must have an explicit interface"_err_en_US, dummyName); - } else if (context.ShouldWarn( - common::UsageWarning::ImplicitInterfaceActual)) { - messages.Say(common::UsageWarning::ImplicitInterfaceActual, + } else { + foldingContext.Warn( + common::UsageWarning::ImplicitInterfaceActual, "Actual procedure argument has an implicit interface which is not known to be compatible with %s which has an explicit interface"_warn_en_US, dummyName); } - } else if (warning && - context.ShouldWarn(common::UsageWarning::ProcDummyArgShapes)) { - messages.Say(common::UsageWarning::ProcDummyArgShapes, + } else if (warning) { + foldingContext.Warn(common::UsageWarning::ProcDummyArgShapes, "Actual procedure argument has possible interface incompatibility with %s: %s"_warn_en_US, dummyName, std::move(*warning)); } @@ -1368,16 +1350,14 @@ static void CheckExplicitInterfaceArg(evaluate::ActualArgument &arg, messages.Say( "NULL() actual argument '%s' may not be associated with allocatable dummy argument %s that is INTENT(OUT) or INTENT(IN OUT)"_err_en_US, expr->AsFortran(), dummyName); - } else if (object.intent == common::Intent::Default && - context.ShouldWarn(common::UsageWarning:: - NullActualForDefaultIntentAllocatable)) { - messages.Say(common::UsageWarning:: - NullActualForDefaultIntentAllocatable, + } else if (object.intent == common::Intent::Default) { + foldingContext.Warn( + common::UsageWarning:: + NullActualForDefaultIntentAllocatable, "NULL() actual argument '%s' should not be associated with allocatable dummy argument %s without INTENT(IN)"_warn_en_US, expr->AsFortran(), dummyName); - } else if (context.ShouldWarn(common::LanguageFeature:: - NullActualForAllocatable)) { - messages.Say( + } else { + foldingContext.Warn( common::LanguageFeature::NullActualForAllocatable, "Allocatable %s is associated with %s"_port_en_US, dummyName, expr->AsFortran()); @@ -1395,8 +1375,7 @@ static void CheckExplicitInterfaceArg(evaluate::ActualArgument &arg, assumed.name(), dummyName); } else if (object.type.attrs().test(characteristics:: TypeAndShape::Attr::AssumedRank) && - !IsAssumedShape(assumed) && - !evaluate::IsAssumedRank(assumed)) { + !IsAssumedShape(assumed) && !IsAssumedRank(assumed)) { messages.Say( // C711 "Assumed-type '%s' must be either assumed shape or assumed rank to be associated with assumed rank %s"_err_en_US, 
assumed.name(), dummyName); @@ -1567,7 +1546,7 @@ static void CheckAssociated(evaluate::ActualArguments &arguments, if (semanticsContext.ShouldWarn(common::UsageWarning::Portability)) { if (!evaluate::ExtractDataRef(*pointerExpr) && !evaluate::IsProcedurePointer(*pointerExpr)) { - messages.Say(common::UsageWarning::Portability, + foldingContext.Warn(common::UsageWarning::Portability, pointerArg->sourceLocation(), "POINTER= argument of ASSOCIATED() is required by some other compilers to be a pointer"_port_en_US); } else if (scope && !evaluate::UnwrapProcedureRef(*pointerExpr)) { @@ -1578,7 +1557,8 @@ static void CheckAssociated(evaluate::ActualArguments &arguments, DefinabilityFlag::DoNotNoteDefinition}, *pointerExpr)}) { if (whyNot->IsFatal()) { - if (auto *msg{messages.Say(common::UsageWarning::Portability, + if (auto *msg{foldingContext.Warn( + common::UsageWarning::Portability, pointerArg->sourceLocation(), "POINTER= argument of ASSOCIATED() is required by some other compilers to be a valid left-hand side of a pointer assignment statement"_port_en_US)}) { msg->Attach(std::move( @@ -2005,8 +1985,9 @@ static void CheckReduce( } } } - const auto *result{ - procChars ? procChars->functionResult->GetTypeAndShape() : nullptr}; + const auto *result{procChars && procChars->functionResult + ? procChars->functionResult->GetTypeAndShape() + : nullptr}; if (!procChars || !procChars->IsPure() || procChars->dummyArguments.size() != 2 || !procChars->functionResult) { messages.Say( @@ -2092,10 +2073,8 @@ static void CheckReduce( // TRANSFER (16.9.193) static void CheckTransferOperandType(SemanticsContext &context, const evaluate::DynamicType &type, const char *which) { - if (type.IsPolymorphic() && - context.ShouldWarn(common::UsageWarning::PolymorphicTransferArg)) { - context.foldingContext().messages().Say( - common::UsageWarning::PolymorphicTransferArg, + if (type.IsPolymorphic()) { + context.foldingContext().Warn(common::UsageWarning::PolymorphicTransferArg, "%s of TRANSFER is polymorphic"_warn_en_US, which); } else if (!type.IsUnlimitedPolymorphic() && type.category() == TypeCategory::Derived && @@ -2103,7 +2082,7 @@ static void CheckTransferOperandType(SemanticsContext &context, DirectComponentIterator directs{type.GetDerivedTypeSpec()}; if (auto bad{std::find_if(directs.begin(), directs.end(), IsDescriptor)}; bad != directs.end()) { - evaluate::SayWithDeclaration(context.foldingContext().messages(), *bad, + evaluate::WarnWithDeclaration(context.foldingContext(), *bad, common::UsageWarning::PointerComponentTransferArg, "%s of TRANSFER contains allocatable or pointer component %s"_warn_en_US, which, bad.BuildResultDesignatorName()); @@ -2133,8 +2112,8 @@ static void CheckTransfer(evaluate::ActualArguments &arguments, messages.Say( "Element size of MOLD= array may not be zero when SOURCE= is not empty"_err_en_US); } - } else if (context.ShouldWarn(common::UsageWarning::VoidMold)) { - messages.Say(common::UsageWarning::VoidMold, + } else { + foldingContext.Warn(common::UsageWarning::VoidMold, "Element size of MOLD= array may not be zero unless SOURCE= is empty"_warn_en_US); } } @@ -2150,7 +2129,7 @@ static void CheckTransfer(evaluate::ActualArguments &arguments, } else if (context.ShouldWarn( common::UsageWarning::TransferSizePresence) && IsAllocatableOrObjectPointer(whole)) { - messages.Say(common::UsageWarning::TransferSizePresence, + foldingContext.Warn(common::UsageWarning::TransferSizePresence, "SIZE= argument that is allocatable or pointer must be present at execution; parenthesize to silence 
this warning"_warn_en_US); } } @@ -2373,13 +2352,10 @@ bool CheckArguments(const characteristics::Procedure &proc, /*extentErrors=*/true, ignoreImplicitVsExplicit)}; if (!buffer.empty()) { if (treatingExternalAsImplicit) { - if (context.ShouldWarn( - common::UsageWarning::KnownBadImplicitInterface)) { - if (auto *msg{messages.Say( - common::UsageWarning::KnownBadImplicitInterface, - "If the procedure's interface were explicit, this reference would be in error"_warn_en_US)}) { - buffer.AttachTo(*msg, parser::Severity::Because); - } + if (auto *msg{foldingContext.Warn( + common::UsageWarning::KnownBadImplicitInterface, + "If the procedure's interface were explicit, this reference would be in error"_warn_en_US)}) { + buffer.AttachTo(*msg, parser::Severity::Because); } else { buffer.clear(); } diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index d769f22..84edceb 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -130,21 +130,14 @@ private: } template <typename FeatureOrUsageWarning, typename... A> parser::Message *Warn(FeatureOrUsageWarning warning, A &&...x) { - if (!context_.ShouldWarn(warning) || InModuleFile()) { - return nullptr; - } else { - return messages_.Say(warning, std::forward<A>(x)...); - } + return messages_.Warn(InModuleFile(), context_.languageFeatures(), warning, + std::forward<A>(x)...); } template <typename FeatureOrUsageWarning, typename... A> parser::Message *Warn( FeatureOrUsageWarning warning, parser::CharBlock source, A &&...x) { - if (!context_.ShouldWarn(warning) || - FindModuleFileContaining(context_.FindScope(source))) { - return nullptr; - } else { - return messages_.Say(warning, source, std::forward<A>(x)...); - } + return messages_.Warn(FindModuleFileContaining(context_.FindScope(source)), + context_.languageFeatures(), warning, source, std::forward<A>(x)...); } bool IsResultOkToDiffer(const FunctionResult &); void CheckGlobalName(const Symbol &); @@ -326,7 +319,7 @@ void CheckHelper::Check(const Symbol &symbol) { !IsDummy(symbol)) { if (context_.IsEnabled( common::LanguageFeature::IgnoreIrrelevantAttributes)) { - context_.Warn(common::LanguageFeature::IgnoreIrrelevantAttributes, + Warn(common::LanguageFeature::IgnoreIrrelevantAttributes, "Only a dummy argument should have an INTENT, VALUE, or OPTIONAL attribute"_warn_en_US); } else { messages_.Say( @@ -633,7 +626,7 @@ void CheckHelper::CheckValue( "VALUE attribute may not apply to a type with a coarray ultimate component"_err_en_US); } } - if (evaluate::IsAssumedRank(symbol)) { + if (IsAssumedRank(symbol)) { messages_.Say( "VALUE attribute may not apply to an assumed-rank array"_err_en_US); } @@ -743,7 +736,7 @@ void CheckHelper::CheckObjectEntity( "Coarray '%s' may not have type TEAM_TYPE, C_PTR, or C_FUNPTR"_err_en_US, symbol.name()); } - if (evaluate::IsAssumedRank(symbol)) { + if (IsAssumedRank(symbol)) { messages_.Say("Coarray '%s' may not be an assumed-rank array"_err_en_US, symbol.name()); } @@ -889,7 +882,7 @@ void CheckHelper::CheckObjectEntity( "!DIR$ IGNORE_TKR may not apply to an allocatable or pointer"_err_en_US); } } else if (ignoreTKR.test(common::IgnoreTKR::Rank)) { - if (ignoreTKR.count() == 1 && evaluate::IsAssumedRank(symbol)) { + if (ignoreTKR.count() == 1 && IsAssumedRank(symbol)) { Warn(common::UsageWarning::IgnoreTKRUsage, "!DIR$ IGNORE_TKR(R) is not meaningful for an assumed-rank array"_warn_en_US); } else if (inExplicitExternalInterface) { @@ -1214,7 +1207,7 @@ void 
CheckHelper::CheckObjectEntity( SayWithDeclaration(symbol, "Deferred-shape entity of %s type is not supported"_err_en_US, typeName); - } else if (evaluate::IsAssumedRank(symbol)) { + } else if (IsAssumedRank(symbol)) { SayWithDeclaration(symbol, "Assumed rank entity of %s type is not supported"_err_en_US, typeName); @@ -2428,7 +2421,7 @@ void CheckHelper::CheckVolatile(const Symbol &symbol, void CheckHelper::CheckContiguous(const Symbol &symbol) { if (evaluate::IsVariable(symbol) && ((IsPointer(symbol) && symbol.Rank() > 0) || IsAssumedShape(symbol) || - evaluate::IsAssumedRank(symbol))) { + IsAssumedRank(symbol))) { } else { parser::MessageFixedText msg{symbol.owner().IsDerivedType() ? "CONTIGUOUS component '%s' should be an array with the POINTER attribute"_port_en_US @@ -2957,7 +2950,7 @@ static bool IsSubprogramDefinition(const Symbol &symbol) { static bool IsExternalProcedureDefinition(const Symbol &symbol) { return IsBlockData(symbol) || - (IsSubprogramDefinition(symbol) && + ((IsSubprogramDefinition(symbol) || IsAlternateEntry(&symbol)) && (IsExternal(symbol) || symbol.GetBindName())); } @@ -3141,16 +3134,14 @@ parser::Messages CheckHelper::WhyNotInteroperableDerivedType( *dyType, &context_.languageFeatures()) .value_or(false)) { if (type->category() == DeclTypeSpec::Logical) { - if (context_.ShouldWarn(common::UsageWarning::LogicalVsCBool)) { - msgs.Say(common::UsageWarning::LogicalVsCBool, component.name(), - "A LOGICAL component of an interoperable type should have the interoperable KIND=C_BOOL"_port_en_US); - } + context().Warn(msgs, common::UsageWarning::LogicalVsCBool, + component.name(), + "A LOGICAL component of an interoperable type should have the interoperable KIND=C_BOOL"_port_en_US); } else if (type->category() == DeclTypeSpec::Character && dyType && dyType->kind() == 1) { - if (context_.ShouldWarn(common::UsageWarning::BindCCharLength)) { - msgs.Say(common::UsageWarning::BindCCharLength, component.name(), - "A CHARACTER component of an interoperable type should have length 1"_port_en_US); - } + context().Warn(msgs, common::UsageWarning::BindCCharLength, + component.name(), + "A CHARACTER component of an interoperable type should have length 1"_port_en_US); } else { msgs.Say(component.name(), "Each component of an interoperable derived type must have an interoperable type"_err_en_US); @@ -3165,10 +3156,9 @@ parser::Messages CheckHelper::WhyNotInteroperableDerivedType( } } if (derived->componentNames().empty()) { // F'2023 C1805 - if (context_.ShouldWarn(common::LanguageFeature::EmptyBindCDerivedType)) { - msgs.Say(common::LanguageFeature::EmptyBindCDerivedType, symbol.name(), - "A derived type with the BIND attribute should not be empty"_warn_en_US); - } + context().Warn(msgs, common::LanguageFeature::EmptyBindCDerivedType, + symbol.name(), + "A derived type with the BIND attribute should not be empty"_warn_en_US); } } if (msgs.AnyFatalError()) { @@ -3218,7 +3208,7 @@ parser::Messages CheckHelper::WhyNotInteroperableObject( if (derived && !derived->typeSymbol().attrs().test(Attr::BIND_C)) { if (allowNonInteroperableType) { // portability warning only evaluate::AttachDeclaration( - context_.Warn(common::UsageWarning::Portability, symbol.name(), + Warn(common::UsageWarning::Portability, symbol.name(), "The derived type of this interoperable object should be BIND(C)"_port_en_US), derived->typeSymbol()); } else if (!context_.IsEnabled( @@ -3260,10 +3250,10 @@ parser::Messages CheckHelper::WhyNotInteroperableObject( } else if (type->category() == DeclTypeSpec::Logical) { if 
(context_.ShouldWarn(common::UsageWarning::LogicalVsCBool)) { if (IsDummy(symbol)) { - msgs.Say(common::UsageWarning::LogicalVsCBool, symbol.name(), + Warn(common::UsageWarning::LogicalVsCBool, symbol.name(), "A BIND(C) LOGICAL dummy argument should have the interoperable KIND=C_BOOL"_port_en_US); } else { - msgs.Say(common::UsageWarning::LogicalVsCBool, symbol.name(), + Warn(common::UsageWarning::LogicalVsCBool, symbol.name(), "A BIND(C) LOGICAL object should have the interoperable KIND=C_BOOL"_port_en_US); } } @@ -3459,7 +3449,7 @@ void CheckHelper::CheckBindC(const Symbol &symbol) { bool CheckHelper::CheckDioDummyIsData( const Symbol &subp, const Symbol *arg, std::size_t position) { if (arg && arg->detailsIf<ObjectEntityDetails>()) { - if (evaluate::IsAssumedRank(*arg)) { + if (IsAssumedRank(*arg)) { messages_.Say(arg->name(), "Dummy argument '%s' may not be assumed-rank"_err_en_US, arg->name()); return false; diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index a5fdabf..f25497e 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -11,13 +11,16 @@ //===----------------------------------------------------------------------===// #include "check-omp-structure.h" -#include "openmp-utils.h" #include "flang/Common/indirection.h" +#include "flang/Common/template.h" #include "flang/Evaluate/expression.h" +#include "flang/Evaluate/match.h" +#include "flang/Evaluate/rewrite.h" #include "flang/Evaluate/tools.h" #include "flang/Parser/char-block.h" #include "flang/Parser/parse-tree.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "flang/Semantics/type.h" @@ -42,11 +45,167 @@ using namespace Fortran::semantics::omp; namespace operation = Fortran::evaluate::operation; +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr); + template <typename T, typename U> static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) { return !(e == f); } +namespace { +template <typename...> struct IsIntegral { + static constexpr bool value{false}; +}; + +template <common::TypeCategory C, int K> +struct IsIntegral<evaluate::Type<C, K>> { + static constexpr bool value{// + C == common::TypeCategory::Integer || + C == common::TypeCategory::Unsigned || + C == common::TypeCategory::Logical}; +}; + +template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value}; + +template <typename...> struct IsFloatingPoint { + static constexpr bool value{false}; +}; + +template <common::TypeCategory C, int K> +struct IsFloatingPoint<evaluate::Type<C, K>> { + static constexpr bool value{// + C == common::TypeCategory::Real || C == common::TypeCategory::Complex}; +}; + +template <typename T> +constexpr bool is_floating_point_v{IsFloatingPoint<T>::value}; + +template <typename T> +constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>}; + +template <typename T, typename Op0, typename Op1> +using ReassocOpBase = evaluate::match::AnyOfPattern< // + evaluate::match::Add<T, Op0, Op1>, // + evaluate::match::Mul<T, Op0, Op1>>; + +template <typename T, typename Op0, typename Op1> +struct ReassocOp : public ReassocOpBase<T, Op0, Op1> { + using Base = ReassocOpBase<T, Op0, Op1>; + using Base::Base; +}; + +template <typename T, typename Op0, typename Op1> +ReassocOp<T, Op0, Op1> reassocOp(const Op0 &op0, const Op1 &op1) { + return ReassocOp<T, Op0, Op1>(op0, op1); +} +} // namespace + +struct 
ReassocRewriter : public evaluate::rewrite::Identity { using Id = evaluate::rewrite::Identity; struct NonIntegralTag {}; + + ReassocRewriter(const SomeExpr &atom, const SemanticsContext &context) + : atom_(atom), context_(context) {} + + // Try to find cases where the input expression is of the form + // (1) (a . b) . c, or + // (2) a . (b . c), + // where . denotes an associative operation (currently + or *), and a, b, c + // are some subexpressions. + // If one of the operands in the nested operation is the atomic variable + // (with some possible type conversions applied to it), bring it to the + // top-level operation, and move the top-level operand into the nested + // operation. + // For example, assuming x is the atomic variable: + // (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b. + template <typename T, typename U, + typename = std::enable_if_t<is_numeric_v<T>>> + evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) { + if constexpr (is_floating_point_v<T>) { + if (!context_.langOptions().AssociativeMath) { + return Id::operator()(std::move(x), u); + } + } + // As per the above comment, there are 3 subexpressions involved in this + // transformation. A match::Expr<T> will match evaluate::Expr<U> when T is + // same as U, plus it will store a pointer (ref) to the matched expression. + // When the match is successful, the sub[i].ref will point to a, b, x (in + // some order) from the example above. + evaluate::match::Expr<T> sub[3]; + auto inner{reassocOp<T>(sub[0], sub[1])}; + auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something + auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner +#if !defined(__clang__) && !defined(_MSC_VER) && \ + (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5)) + // If GCC version < 8.5, use this definition. For the other definition + // (which is equivalent), GCC 7.5 emits a somewhat cryptic error: + // use of ‘outer1’ before deduction of ‘auto’ + // inside of the visitor function in common::visit. + // Since this works with clang, MSVC and at least GCC 8.5, I'm assuming + // that this is some kind of a GCC issue. + using MatchTypes = std::tuple<evaluate::Add<T>, evaluate::Multiply<T>>; +#else + using MatchTypes = typename decltype(outer1)::MatchTypes; +#endif + // There is no way to ensure that the outer operation is the same as + // the inner one. They are matched independently, so we need to compare + // the index in the member variant that represents the matched type. + if ((match(outer1, x) && outer1.ref.index() == inner.ref.index()) || + (match(outer2, x) && outer2.ref.index() == inner.ref.index())) { + size_t atomIdx{[&]() { // sub[atomIdx] will be the atom. + size_t idx; + for (idx = 0; idx != 3; ++idx) { + if (IsAtom(*sub[idx].ref)) { + break; + } + } + return idx; + }()}; + + if (atomIdx > 2) { + return Id::operator()(std::move(x), u); + } + return common::visit( + [&](auto &&s) { + using Expr = evaluate::Expr<T>; + using TypeS = llvm::remove_cvref_t<decltype(s)>; + // This visitor has to be semantically correct for all possible + // types of s even though at runtime s will only be one of the + // matched types. + // Limit the construction to the operation types that we tried + // to match (otherwise TypeS(op1, op2) would fail for non-binary + // operations). 
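As a concrete, hypothetical Fortran illustration of the updates this reassociation targets (integer operands, so the rewrite applies unconditionally; for REAL it is only attempted when the AssociativeMath language option checked above is enabled):

  subroutine upd(x, a, b)
    integer :: x, a, b
    !$omp atomic update
    x = (a + x) + b  ! x is nested one level down; conceptually rewritten to x = x + (a + b)
  end subroutine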
+ if constexpr (common::HasMember<TypeS, MatchTypes>) { + Expr atom{*sub[atomIdx].ref}; + Expr op1{*sub[(atomIdx + 1) % 3].ref}; + Expr op2{*sub[(atomIdx + 2) % 3].ref}; + return Expr( + TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2))))); + } else { + return Expr(TypeS(s)); + } + }, + evaluate::match::deparen(x).u); + } + return Id::operator()(std::move(x), u); + } + + template <typename T, typename U, + typename = std::enable_if_t<!is_numeric_v<T>>> + evaluate::Expr<T> operator()( + evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) { + return Id::operator()(std::move(x), u); + } + +private: + template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const { + return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_); + } + + const SomeExpr &atom_; + const SemanticsContext &context_; +}; + struct AnalyzedCondStmt { SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted parser::CharBlock source; @@ -196,6 +355,26 @@ static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource( llvm_unreachable("Could not find assignment operator"); } +static std::vector<SomeExpr> GetNonAtomExpressions( + const SomeExpr &atom, const std::vector<SomeExpr> &exprs) { + std::vector<SomeExpr> nonAtom; + for (const SomeExpr &e : exprs) { + if (!IsSameOrConvertOf(e, atom)) { + nonAtom.push_back(e); + } + } + return nonAtom; +} + +static std::vector<SomeExpr> GetNonAtomArguments( + const SomeExpr &atom, const SomeExpr &expr) { + if (auto &&maybe{GetConvertInput(expr)}) { + return GetNonAtomExpressions( + atom, GetTopLevelOperationIgnoreResizing(*maybe).second); + } + return {}; +} + static bool IsCheckForAssociated(const SomeExpr &cond) { return GetTopLevelOperationIgnoreResizing(cond).first == operation::Operator::Associated; @@ -222,47 +401,85 @@ static void SetAssignment(parser::AssignmentStmt::TypedAssignment &assign, } } -static parser::OpenMPAtomicConstruct::Analysis::Op MakeAtomicAnalysisOp( - int what, - const std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) { - parser::OpenMPAtomicConstruct::Analysis::Op operation; - operation.what = what; - SetAssignment(operation.assign, maybeAssign); - return operation; -} +namespace { +struct AtomicAnalysis { + AtomicAnalysis(const SomeExpr &atom, const MaybeExpr &cond = std::nullopt) + : atom_(atom), cond_(cond) {} -static parser::OpenMPAtomicConstruct::Analysis MakeAtomicAnalysis( - const SomeExpr &atom, const MaybeExpr &cond, - parser::OpenMPAtomicConstruct::Analysis::Op &&op0, - parser::OpenMPAtomicConstruct::Analysis::Op &&op1) { - // Defined in flang/include/flang/Parser/parse-tree.h - // - // struct Analysis { - // struct Kind { - // static constexpr int None = 0; - // static constexpr int Read = 1; - // static constexpr int Write = 2; - // static constexpr int Update = Read | Write; - // static constexpr int Action = 3; // Bits containing N, R, W, U - // static constexpr int IfTrue = 4; - // static constexpr int IfFalse = 8; - // static constexpr int Condition = 12; // Bits containing IfTrue, IfFalse - // }; - // struct Op { - // int what; - // TypedAssignment assign; - // }; - // TypedExpr atom, cond; - // Op op0, op1; - // }; - - parser::OpenMPAtomicConstruct::Analysis an; - SetExpr(an.atom, atom); - SetExpr(an.cond, cond); - an.op0 = std::move(op0); - an.op1 = std::move(op1); - return an; -} + AtomicAnalysis &addOp0(int what, + const std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) { + return addOp(op0_, what, maybeAssign); + } + AtomicAnalysis &addOp1(int what, + const 
std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) { + return addOp(op1_, what, maybeAssign); + } + + operator parser::OpenMPAtomicConstruct::Analysis() const { + // Defined in flang/include/flang/Parser/parse-tree.h + // + // struct Analysis { + // struct Kind { + // static constexpr int None = 0; + // static constexpr int Read = 1; + // static constexpr int Write = 2; + // static constexpr int Update = Read | Write; + // static constexpr int Action = 3; // Bits containing None, Read, + // // Write, Update + // static constexpr int IfTrue = 4; + // static constexpr int IfFalse = 8; + // static constexpr int Condition = 12; // Bits containing IfTrue, + // // IfFalse + // }; + // struct Op { + // int what; + // TypedAssignment assign; + // }; + // TypedExpr atom, cond; + // Op op0, op1; + // }; + + parser::OpenMPAtomicConstruct::Analysis an; + SetExpr(an.atom, atom_); + SetExpr(an.cond, cond_); + an.op0 = std::move(op0_); + an.op1 = std::move(op1_); + return an; + } + +private: + struct Op { + operator parser::OpenMPAtomicConstruct::Analysis::Op() const { + parser::OpenMPAtomicConstruct::Analysis::Op op; + op.what = what; + SetAssignment(op.assign, assign); + return op; + } + + int what; + std::optional<evaluate::Assignment> assign; + }; + + AtomicAnalysis &addOp(Op &op, int what, + const std::optional<evaluate::Assignment> &maybeAssign) { + op.what = what; + if (maybeAssign) { + if (MaybeExpr rewritten{PostSemaRewrite(atom_, maybeAssign->rhs)}) { + op.assign = evaluate::Assignment( + AsRvalue(maybeAssign->lhs), std::move(*rewritten)); + op.assign->u = std::move(maybeAssign->u); + } else { + op.assign = *maybeAssign; + } + } + return *this; + } + + const SomeExpr &atom_; + const MaybeExpr &cond_; + Op op0_, op1_; +}; +} // namespace /// Check if `expr` satisfies the following conditions for x and v: /// @@ -535,6 +752,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment( const evaluate::Assignment &capture, const SomeExpr &atom, parser::CharBlock source) { auto [lsrc, rsrc]{SplitAssignmentSource(source)}; + (void)lsrc; const SomeExpr &cap{capture.lhs}; if (!IsVarOrFunctionRef(atom)) { @@ -551,6 +769,7 @@ void OmpStructureChecker::CheckAtomicCaptureAssignment( void OmpStructureChecker::CheckAtomicReadAssignment( const evaluate::Assignment &read, parser::CharBlock source) { auto [lsrc, rsrc]{SplitAssignmentSource(source)}; + (void)lsrc; if (auto maybe{GetConvertInput(read.rhs)}) { const SomeExpr &atom{*maybe}; @@ -584,7 +803,8 @@ void OmpStructureChecker::CheckAtomicWriteAssignment( } } -void OmpStructureChecker::CheckAtomicUpdateAssignment( +std::optional<evaluate::Assignment> +OmpStructureChecker::CheckAtomicUpdateAssignment( const evaluate::Assignment &update, parser::CharBlock source) { // [6.0:191:1-7] // An update structured block is update-statement, an update statement @@ -600,14 +820,47 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment( if (!IsVarOrFunctionRef(atom)) { ErrorShouldBeVariable(atom, rsrc); // Skip other checks. 
- return; + return std::nullopt; } CheckAtomicVariable(atom, lsrc); + auto [hasErrors, tryReassoc]{CheckAtomicUpdateAssignmentRhs( + atom, update.rhs, source, /*suppressDiagnostics=*/true)}; + + if (!hasErrors) { + CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source); + return std::nullopt; + } else if (tryReassoc) { + ReassocRewriter ra(atom, context_); + SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)}; + + std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs( + atom, raRhs, source, /*suppressDiagnostics=*/true); + if (!hasErrors) { + CheckStorageOverlap(atom, GetNonAtomArguments(atom, raRhs), source); + + evaluate::Assignment raAssign(update); + raAssign.rhs = raRhs; + return raAssign; + } + } + + // This is guaranteed to report errors. + CheckAtomicUpdateAssignmentRhs( + atom, update.rhs, source, /*suppressDiagnostics=*/false); + return std::nullopt; +} + +std::pair<bool, bool> OmpStructureChecker::CheckAtomicUpdateAssignmentRhs( + const SomeExpr &atom, const SomeExpr &rhs, parser::CharBlock source, + bool suppressDiagnostics) { + auto [lsrc, rsrc]{SplitAssignmentSource(source)}; + (void)lsrc; + std::pair<operation::Operator, std::vector<SomeExpr>> top{ operation::Operator::Unknown, {}}; - if (auto &&maybeInput{GetConvertInput(update.rhs)}) { + if (auto &&maybeInput{GetConvertInput(rhs)}) { top = GetTopLevelOperationIgnoreResizing(*maybeInput); } switch (top.first) { @@ -624,29 +877,39 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment( case operation::Operator::Identity: break; case operation::Operator::Call: - context_.Say(source, - "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US); - return; + if (!suppressDiagnostics) { + context_.Say(source, + "A call to this function is not a valid ATOMIC UPDATE operation"_err_en_US); + } + return std::make_pair(true, false); case operation::Operator::Convert: - context_.Say(source, - "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US); - return; + if (!suppressDiagnostics) { + context_.Say(source, + "An implicit or explicit type conversion is not a valid ATOMIC UPDATE operation"_err_en_US); + } + return std::make_pair(true, false); case operation::Operator::Intrinsic: - context_.Say(source, - "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US); - return; + if (!suppressDiagnostics) { + context_.Say(source, + "This intrinsic function is not a valid ATOMIC UPDATE operation"_err_en_US); + } + return std::make_pair(true, false); case operation::Operator::Constant: case operation::Operator::Unknown: - context_.Say( - source, "This is not a valid ATOMIC UPDATE operation"_err_en_US); - return; + if (!suppressDiagnostics) { + context_.Say( + source, "This is not a valid ATOMIC UPDATE operation"_err_en_US); + } + return std::make_pair(true, false); default: assert( top.first != operation::Operator::Identity && "Handle this separately"); - context_.Say(source, - "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US, - operation::ToString(top.first)); - return; + if (!suppressDiagnostics) { + context_.Say(source, + "The %s operator is not a valid ATOMIC UPDATE operation"_err_en_US, + operation::ToString(top.first)); + } + return std::make_pair(true, false); } // Check how many times `atom` occurs as an argument, if it's a subexpression // of an argument, and collect the non-atom arguments. 
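To make the argument-position rules below concrete, a hypothetical Fortran sketch (illustrative names) of update statements this check accepts and rejects:

  subroutine upd_forms(x, a, b)
    integer :: x, a, b
    !$omp atomic update
    x = x + a        ! accepted: x is exactly one argument of the top-level +
    !$omp atomic update
    x = a * (x + 1)  ! rejected: x is a proper subexpression of an argument of *
    !$omp atomic update
    x = a + b        ! rejected: x does not appear as an argument of the top-level +
  end subroutine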
@@ -667,39 +930,48 @@ void OmpStructureChecker::CheckAtomicUpdateAssignment( return count; }()}; - bool hasError{false}; + bool hasError{false}, tryReassoc{false}; if (subExpr) { - context_.Say(rsrc, - "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US, - atom.AsFortran(), subExpr->AsFortran()); + if (!suppressDiagnostics) { + context_.Say(rsrc, + "The atomic variable %s cannot be a proper subexpression of an argument (here: %s) in the update operation"_err_en_US, + atom.AsFortran(), subExpr->AsFortran()); + } hasError = true; } if (top.first == operation::Operator::Identity) { // This is "x = y". assert((atomCount == 0 || atomCount == 1) && "Unexpected count"); if (atomCount == 0) { - context_.Say(rsrc, - "The atomic variable %s should appear as an argument in the update operation"_err_en_US, - atom.AsFortran()); + if (!suppressDiagnostics) { + context_.Say(rsrc, + "The atomic variable %s should appear as an argument in the update operation"_err_en_US, + atom.AsFortran()); + } hasError = true; } } else { if (atomCount == 0) { - context_.Say(rsrc, - "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US, - atom.AsFortran(), operation::ToString(top.first)); + if (!suppressDiagnostics) { + context_.Say(rsrc, + "The atomic variable %s should appear as an argument of the top-level %s operator"_err_en_US, + atom.AsFortran(), operation::ToString(top.first)); + } + // If `atom` is a proper subexpression, and it is not present as an + // argument on its own, reassociation may be able to help. + tryReassoc = subExpr.has_value(); hasError = true; } else if (atomCount > 1) { - context_.Say(rsrc, - "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US, - atom.AsFortran(), operation::ToString(top.first)); + if (!suppressDiagnostics) { + context_.Say(rsrc, + "The atomic variable %s should be exactly one of the arguments of the top-level %s operator"_err_en_US, + atom.AsFortran(), operation::ToString(top.first)); + } hasError = true; } } - if (!hasError) { - CheckStorageOverlap(atom, nonAtom, source); - } + return std::make_pair(hasError, tryReassoc); } void OmpStructureChecker::CheckAtomicConditionalUpdateAssignment( @@ -802,12 +1074,14 @@ void OmpStructureChecker::CheckAtomicUpdateOnly( SourcedActionStmt action{GetActionStmt(&body.front())}; if (auto maybeUpdate{GetEvaluateAssignment(action.stmt)}) { const SomeExpr &atom{maybeUpdate->lhs}; - CheckAtomicUpdateAssignment(*maybeUpdate, action.source); + auto maybeAssign{ + CheckAtomicUpdateAssignment(*maybeUpdate, action.source)}; + auto &updateAssign{maybeAssign.has_value() ? 
maybeAssign : maybeUpdate}; using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Update, maybeUpdate), - MakeAtomicAnalysisOp(Analysis::None)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Update, updateAssign) + .addOp1(Analysis::None); } else if (!IsAssignment(action.stmt)) { context_.Say( source, "ATOMIC UPDATE operation should be an assignment"_err_en_US); @@ -889,9 +1163,11 @@ void OmpStructureChecker::CheckAtomicConditionalUpdate( } using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(assign.lhs, update.cond, - MakeAtomicAnalysisOp(Analysis::Update | Analysis::IfTrue, assign), - MakeAtomicAnalysisOp(Analysis::None)); + const SomeExpr &atom{assign.lhs}; + + x.analysis = AtomicAnalysis(atom, update.cond) + .addOp0(Analysis::Update | Analysis::IfTrue, assign) + .addOp1(Analysis::None); } void OmpStructureChecker::CheckAtomicUpdateCapture( @@ -920,29 +1196,32 @@ void OmpStructureChecker::CheckAtomicUpdateCapture( using Analysis = parser::OpenMPAtomicConstruct::Analysis; int action; + std::optional<evaluate::Assignment> updateAssign{update}; if (IsMaybeAtomicWrite(update)) { action = Analysis::Write; CheckAtomicWriteAssignment(update, uact.source); } else { action = Analysis::Update; - CheckAtomicUpdateAssignment(update, uact.source); + if (auto &&maybe{CheckAtomicUpdateAssignment(update, uact.source)}) { + updateAssign = maybe; + } } CheckAtomicCaptureAssignment(capture, atom, cact.source); - if (IsPointerAssignment(update) != IsPointerAssignment(capture)) { + if (IsPointerAssignment(*updateAssign) != IsPointerAssignment(capture)) { context_.Say(cact.source, "The update and capture assignments should both be pointer-assignments or both be non-pointer-assignments"_err_en_US); return; } if (GetActionStmt(&body.front()).stmt == uact.stmt) { - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(action, update), - MakeAtomicAnalysisOp(Analysis::Read, capture)); + x.analysis = AtomicAnalysis(atom) + .addOp0(action, updateAssign) + .addOp1(Analysis::Read, capture); } else { - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Read, capture), - MakeAtomicAnalysisOp(action, update)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Read, capture) + .addOp1(action, updateAssign); } } @@ -1087,15 +1366,16 @@ void OmpStructureChecker::CheckAtomicConditionalUpdateCapture( evaluate::Assignment updAssign{*GetEvaluateAssignment(update.ift.stmt)}; evaluate::Assignment capAssign{*GetEvaluateAssignment(capture.stmt)}; + const SomeExpr &atom{updAssign.lhs}; if (captureFirst) { - x.analysis = MakeAtomicAnalysis(updAssign.lhs, update.cond, - MakeAtomicAnalysisOp(Analysis::Read | captureWhen, capAssign), - MakeAtomicAnalysisOp(Analysis::Write | updateWhen, updAssign)); + x.analysis = AtomicAnalysis(atom, update.cond) + .addOp0(Analysis::Read | captureWhen, capAssign) + .addOp1(Analysis::Write | updateWhen, updAssign); } else { - x.analysis = MakeAtomicAnalysis(updAssign.lhs, update.cond, - MakeAtomicAnalysisOp(Analysis::Write | updateWhen, updAssign), - MakeAtomicAnalysisOp(Analysis::Read | captureWhen, capAssign)); + x.analysis = AtomicAnalysis(atom, update.cond) + .addOp0(Analysis::Write | updateWhen, updAssign) + .addOp1(Analysis::Read | captureWhen, capAssign); } } @@ -1125,9 +1405,9 @@ void OmpStructureChecker::CheckAtomicRead( if (auto maybe{GetConvertInput(maybeRead->rhs)}) { const SomeExpr 
&atom{*maybe}; using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Read, maybeRead), - MakeAtomicAnalysisOp(Analysis::None)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Read, maybeRead) + .addOp1(Analysis::None); } } else if (!IsAssignment(action.stmt)) { context_.Say( @@ -1159,9 +1439,9 @@ void OmpStructureChecker::CheckAtomicWrite( CheckAtomicWriteAssignment(*maybeWrite, action.source); using Analysis = parser::OpenMPAtomicConstruct::Analysis; - x.analysis = MakeAtomicAnalysis(atom, std::nullopt, - MakeAtomicAnalysisOp(Analysis::Write, maybeWrite), - MakeAtomicAnalysisOp(Analysis::None)); + x.analysis = AtomicAnalysis(atom) + .addOp0(Analysis::Write, maybeWrite) + .addOp1(Analysis::None); } else if (!IsAssignment(action.stmt)) { context_.Say( x.source, "ATOMIC WRITE operation should be an assignment"_err_en_US); @@ -1260,4 +1540,118 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) { dirContext_.pop_back(); } +// Rewrite min/max: +// Min and max intrinsics in Fortran take an arbitrary number of arguments +// (two or more). The first two are mandatory, the rest are optional. That +// means that arguments beyond the first two may be optional dummy arguments +// from the caller. In that case, a reference to such an argument will +// cause a presence test to be emitted, which cannot go inside the atomic +// operation. Since the atom operand must be present, rewrite the min/max +// operation in a way that avoids the presence tests in the atomic code. +// For example, in +// subroutine f(atom, x, y, z) +// integer :: atom, x +// integer, optional :: y, z +// !$omp atomic update +// atom = min(atom, x, y, z) +// end +// the min operation will become +// atom = min(atom, min(x, y, z)) +// and in the final code +// // Presence check is fine here. +// tmp = min(x, y, z) +// atomic update { +// // Both operands are mandatory, no presence check needed. +// atom = min(atom, tmp) +// } +struct MinMaxRewriter : public evaluate::rewrite::Identity { + using Id = evaluate::rewrite::Identity; + using Id::operator(); + + MinMaxRewriter(const SomeExpr &atom) : atom_(atom) {} + + static bool IsMinMax(const evaluate::ProcedureDesignator &p) { + if (auto *intrin{p.GetSpecificIntrinsic()}) { + return intrin->name == "min" || intrin->name == "max"; + } + return false; + } + + // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] + // One of the a_i's, say a_t, must be the atom. + // Generate + // min/max(a_t, min/max(a0, a1, ... [except a_t])) + template <typename T> + evaluate::Expr<T> operator()( + evaluate::Expr<T> &&x, const evaluate::FunctionRef<T> &f) { + const evaluate::ProcedureDesignator &proc = f.proc(); + if (!IsMinMax(proc) || f.arguments().size() <= 2) { + return Id::operator()(std::move(x), f); + } + + // Collect arguments as SomeExpr's and find out which argument + // corresponds to atom. 
+ const SomeExpr *atomArg{nullptr}; + std::vector<const SomeExpr *> args; + for (const std::optional<evaluate::ActualArgument> &a : f.arguments()) { + if (!a) { + continue; + } + if (const SomeExpr *e{a->UnwrapExpr()}) { + if (evaluate::IsSameOrConvertOf(*e, atom_)) { + atomArg = e; + } + args.push_back(e); + } + } + if (!atomArg) { + return Id::operator()(std::move(x), f); + } + + evaluate::ActualArguments nonAtoms; + + auto AsActual = [](const SomeExpr &z) { + SomeExpr copy = z; + return evaluate::ActualArgument(std::move(copy)); + }; + // Semantic checks guarantee that the "atom" shows exactly once in the + // argument list (with potential conversions around it). + // For the first two (non-optional) arguments, if "atom" is among them, + // replace it with another occurrence of the other non-optional argument. + if (atomArg == args[0]) { + // (atom, x, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[1])); + nonAtoms.push_back(AsActual(*args[1])); + } else if (atomArg == args[1]) { + // (x, atom, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[0])); + } else { + // (x, y, z...) -> unchanged + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[1])); + } + + // The rest of arguments are optional, so we can just skip "atom". + for (size_t i = 2, e = args.size(); i != e; ++i) { + if (atomArg != args[i]) + nonAtoms.push_back(AsActual(*args[i])); + } + + SomeExpr tmp = evaluate::AsGenericExpr( + evaluate::FunctionRef<T>(AsRvalue(proc), AsRvalue(nonAtoms))); + + return evaluate::Expr<T>(evaluate::FunctionRef<T>( + AsRvalue(proc), {AsActual(*atomArg), AsActual(tmp)})); + } + +private: + const SomeExpr &atom_; +}; + +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr) { + MinMaxRewriter rewriter(atom); + return evaluate::rewrite::Mutator(rewriter)(expr); +} + } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp index 59d57a2..9384e03 100644 --- a/flang/lib/Semantics/check-omp-loop.cpp +++ b/flang/lib/Semantics/check-omp-loop.cpp @@ -13,7 +13,6 @@ #include "check-omp-structure.h" #include "check-directive-structure.h" -#include "openmp-utils.h" #include "flang/Common/idioms.h" #include "flang/Common/visit.h" @@ -23,6 +22,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Parser/tools.h" #include "flang/Semantics/openmp-modifiers.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" @@ -196,7 +196,7 @@ void OmpStructureChecker::CheckSIMDNest(const parser::OpenMPConstruct &c) { common::visit( common::visitors{ // Allow `!$OMP ORDERED SIMD` - [&](const parser::OpenMPBlockConstruct &c) { + [&](const parser::OmpBlockConstruct &c) { const parser::OmpDirectiveSpecification &beginSpec{c.BeginDir()}; if (beginSpec.DirId() == llvm::omp::Directive::OMPD_ordered) { for (const auto &clause : beginSpec.Clauses().v) { diff --git a/flang/lib/Semantics/check-omp-metadirective.cpp b/flang/lib/Semantics/check-omp-metadirective.cpp index 03487da..cf5ea90 100644 --- a/flang/lib/Semantics/check-omp-metadirective.cpp +++ b/flang/lib/Semantics/check-omp-metadirective.cpp @@ -12,8 +12,6 @@ #include "check-omp-structure.h" -#include "openmp-utils.h" - #include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Common/visit.h" @@ -21,6 +19,7 @@ #include "flang/Parser/message.h" #include 
"flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-modifiers.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/tools.h" #include "llvm/Frontend/OpenMP/OMP.h" diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index a9c56c3..85d79a00 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -10,7 +10,6 @@ #include "check-directive-structure.h" #include "definable.h" -#include "openmp-utils.h" #include "resolve-names-utils.h" #include "flang/Common/idioms.h" @@ -21,12 +20,14 @@ #include "flang/Parser/char-block.h" #include "flang/Parser/characters.h" #include "flang/Parser/message.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Parser/parse-tree-visitor.h" #include "flang/Parser/parse-tree.h" #include "flang/Parser/tools.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/openmp-modifiers.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" @@ -57,6 +58,7 @@ namespace Fortran::semantics { using namespace Fortran::semantics::omp; +using namespace Fortran::parser::omp; // Use when clause falls under 'struct OmpClause' in 'parse-tree.h'. #define CHECK_SIMPLE_CLAUSE(X, Y) \ @@ -141,6 +143,64 @@ private: parser::CharBlock source_; }; +// 'OmpWorkdistributeBlockChecker' is used to check the validity of the +// assignment statements and the expressions enclosed in an OpenMP +// WORKDISTRIBUTE construct +class OmpWorkdistributeBlockChecker { +public: + OmpWorkdistributeBlockChecker( + SemanticsContext &context, parser::CharBlock source) + : context_{context}, source_{source} {} + + template <typename T> bool Pre(const T &) { return true; } + template <typename T> void Post(const T &) {} + + bool Pre(const parser::AssignmentStmt &assignment) { + const auto &var{std::get<parser::Variable>(assignment.t)}; + const auto &expr{std::get<parser::Expr>(assignment.t)}; + const auto *lhs{GetExpr(context_, var)}; + const auto *rhs{GetExpr(context_, expr)}; + if (lhs && rhs) { + Tristate isDefined{semantics::IsDefinedAssignment( + lhs->GetType(), lhs->Rank(), rhs->GetType(), rhs->Rank())}; + if (isDefined == Tristate::Yes) { + context_.Say(expr.source, + "Defined assignment statement is not allowed in a WORKDISTRIBUTE construct"_err_en_US); + } + } + return true; + } + + bool Pre(const parser::Expr &expr) { + if (const auto *e{GetExpr(context_, expr)}) { + if (!e) + return false; + for (const Symbol &symbol : evaluate::CollectSymbols(*e)) { + const Symbol &root{GetAssociationRoot(symbol)}; + if (IsFunction(root)) { + std::vector<std::string> attrs; + if (!IsElementalProcedure(root)) { + attrs.push_back("non-ELEMENTAL"); + } + if (root.attrs().test(Attr::IMPURE)) { + attrs.push_back("IMPURE"); + } + std::string attrsStr = + attrs.empty() ? 
"" : " " + llvm::join(attrs, ", "); + context_.Say(expr.source, + "User defined%s function '%s' is not allowed in a WORKDISTRIBUTE construct"_err_en_US, + attrsStr, root.name()); + } + } + } + return false; + } + +private: + SemanticsContext &context_; + parser::CharBlock source_; +}; + // `OmpUnitedTaskDesignatorChecker` is used to check if the designator // can appear within the TASK construct class OmpUnitedTaskDesignatorChecker { @@ -208,6 +268,41 @@ bool OmpStructureChecker::CheckAllowedClause(llvmOmpClause clause) { return CheckAllowed(clause); } +void OmpStructureChecker::AnalyzeObject(const parser::OmpObject &object) { + if (std::holds_alternative<parser::Name>(object.u)) { + // Do not analyze common block names. The analyzer will flag an error + // on those. + return; + } + if (auto *symbol{GetObjectSymbol(object)}) { + // Eliminate certain kinds of symbols before running the analyzer to + // avoid confusing error messages. The analyzer assumes that the context + // of the object use is an expression, and some diagnostics are tailored + // to that. + if (symbol->has<DerivedTypeDetails>() || symbol->has<MiscDetails>()) { + // Type names, construct names, etc. + return; + } + if (auto *typeSpec{symbol->GetType()}) { + if (typeSpec->category() == DeclTypeSpec::Category::Character) { + // Don't pass character objects to the analyzer, it can emit somewhat + // cryptic errors (e.g. "'obj' is not an array"). Substrings are + // checked elsewhere in OmpStructureChecker. + return; + } + } + } + evaluate::ExpressionAnalyzer ea{context_}; + auto restore{ea.AllowWholeAssumedSizeArray(true)}; + common::visit([&](auto &&s) { ea.Analyze(s); }, object.u); +} + +void OmpStructureChecker::AnalyzeObjects(const parser::OmpObjectList &objects) { + for (const parser::OmpObject &object : objects.v) { + AnalyzeObject(object); + } +} + bool OmpStructureChecker::IsCloselyNestedRegion(const OmpDirectiveSet &set) { // Definition of close nesting: // @@ -529,22 +624,6 @@ template <typename Checker> struct DirectiveSpellingVisitor { checker_(GetDirName(x.t).source, Directive::OMPD_allocators); return false; } - bool Pre(const parser::OmpAssumeDirective &x) { - checker_(std::get<parser::Verbatim>(x.t).source, Directive::OMPD_assume); - return false; - } - bool Pre(const parser::OmpEndAssumeDirective &x) { - checker_(x.v.source, Directive::OMPD_assume); - return false; - } - bool Pre(const parser::OmpCriticalDirective &x) { - checker_(std::get<parser::Verbatim>(x.t).source, Directive::OMPD_critical); - return false; - } - bool Pre(const parser::OmpEndCriticalDirective &x) { - checker_(std::get<parser::Verbatim>(x.t).source, Directive::OMPD_critical); - return false; - } bool Pre(const parser::OmpMetadirectiveDirective &x) { checker_( std::get<parser::Verbatim>(x.t).source, Directive::OMPD_metadirective); @@ -579,6 +658,10 @@ template <typename Checker> struct DirectiveSpellingVisitor { Directive::OMPD_declare_variant); return false; } + bool Pre(const parser::OpenMPGroupprivate &x) { + checker_(x.v.DirName().source, Directive::OMPD_groupprivate); + return false; + } bool Pre(const parser::OpenMPThreadprivate &x) { checker_( std::get<parser::Verbatim>(x.t).source, Directive::OMPD_threadprivate); @@ -731,7 +814,7 @@ void OmpStructureChecker::CheckTargetNest(const parser::OpenMPConstruct &c) { parser::CharBlock source; common::visit( common::visitors{ - [&](const parser::OpenMPBlockConstruct &c) { + [&](const parser::OmpBlockConstruct &c) { const parser::OmpDirectiveSpecification &beginSpec{c.BeginDir()}; source = 
beginSpec.DirName().source; if (beginSpec.DirId() == llvm::omp::Directive::OMPD_target_data) { @@ -781,12 +864,44 @@ void OmpStructureChecker::CheckTargetNest(const parser::OpenMPConstruct &c) { } } -void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { +void OmpStructureChecker::Enter(const parser::OmpBlockConstruct &x) { const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()}; const std::optional<parser::OmpEndDirective> &endSpec{x.EndDir()}; const parser::Block &block{std::get<parser::Block>(x.t)}; PushContextAndClauseSets(beginSpec.DirName().source, beginSpec.DirId()); + + // Missing mandatory end block: this is checked in semantics because that + // makes it easier to control the error messages. + // The end block is mandatory when the construct is not applied to a strictly + // structured block (aka it is applied to a loosely structured block). In + // other words, the body doesn't contain exactly one parser::BlockConstruct. + auto isStrictlyStructuredBlock{[](const parser::Block &block) -> bool { + if (block.size() != 1) { + return false; + } + const parser::ExecutionPartConstruct &contents{block.front()}; + auto *executableConstruct{ + std::get_if<parser::ExecutableConstruct>(&contents.u)}; + if (!executableConstruct) { + return false; + } + return std::holds_alternative<common::Indirection<parser::BlockConstruct>>( + executableConstruct->u); + }}; + if (!endSpec && !isStrictlyStructuredBlock(block)) { + llvm::omp::Directive dirId{beginSpec.DirId()}; + auto &msg{context_.Say(beginSpec.source, + "Expected OpenMP END %s directive"_err_en_US, + parser::ToUpperCaseLetters(getDirectiveName(dirId)))}; + // ORDERED has two variants, so be explicit about which variant we think + // this is. + if (dirId == llvm::omp::Directive::OMPD_ordered) { + msg.Attach( + beginSpec.source, "The ORDERED directive is block-associated"_en_US); + } + } + if (llvm::omp::allTargetSet.test(GetContext().directive)) { EnterDirectiveNest(TargetNest); } @@ -817,6 +932,12 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { "TARGET construct with nested TEAMS region contains statements or " "directives outside of the TEAMS construct"_err_en_US); } + if (GetContext().directive == llvm::omp::Directive::OMPD_workdistribute && + GetContextParent().directive != llvm::omp::Directive::OMPD_teams) { + context_.Say(x.BeginDir().DirName().source, + "%s region can only be strictly nested within TEAMS region"_err_en_US, + ContextDirectiveAsFortran()); + } } CheckNoBranching(block, beginSpec.DirId(), beginSpec.source); @@ -900,6 +1021,17 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { HasInvalidWorksharingNesting( beginSpec.source, llvm::omp::nestedWorkshareErrSet); break; + case llvm::omp::OMPD_workdistribute: + if (!CurrentDirectiveIsNested()) { + context_.Say(beginSpec.source, + "A WORKDISTRIBUTE region must be nested inside TEAMS region only."_err_en_US); + } + CheckWorkdistributeBlockStmts(block, beginSpec.source); + break; + case llvm::omp::OMPD_teams_workdistribute: + case llvm::omp::OMPD_target_teams_workdistribute: + CheckWorkdistributeBlockStmts(block, beginSpec.source); + break; case llvm::omp::Directive::OMPD_scope: case llvm::omp::Directive::OMPD_single: // TODO: This check needs to be extended while implementing nesting of @@ -921,7 +1053,7 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { } void OmpStructureChecker::CheckMasterNesting( - const parser::OpenMPBlockConstruct &x) { + const 
parser::OmpBlockConstruct &x) { // A MASTER region may not be `closely nested` inside a worksharing, loop, // task, taskloop, or atomic region. // TODO: Expand the check to include `LOOP` construct as well when it is @@ -950,7 +1082,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclarativeAssumes &) { dirContext_.pop_back(); } -void OmpStructureChecker::Leave(const parser::OpenMPBlockConstruct &) { +void OmpStructureChecker::Leave(const parser::OmpBlockConstruct &) { if (GetDirectiveNest(TargetBlockOnlyTeams)) { ExitDirectiveNest(TargetBlockOnlyTeams); } @@ -1041,14 +1173,23 @@ void OmpStructureChecker::Leave(const parser::OmpBeginDirective &) { void OmpStructureChecker::Enter(const parser::OpenMPSectionsConstruct &x) { const auto &beginSectionsDir{ std::get<parser::OmpBeginSectionsDirective>(x.t)}; - const auto &endSectionsDir{std::get<parser::OmpEndSectionsDirective>(x.t)}; + const auto &endSectionsDir{ + std::get<std::optional<parser::OmpEndSectionsDirective>>(x.t)}; const auto &beginDir{ std::get<parser::OmpSectionsDirective>(beginSectionsDir.t)}; - const auto &endDir{std::get<parser::OmpSectionsDirective>(endSectionsDir.t)}; + PushContextAndClauseSets(beginDir.source, beginDir.v); + + if (!endSectionsDir) { + context_.Say(beginSectionsDir.source, + "Expected OpenMP END SECTIONS directive"_err_en_US); + // Following code assumes the option is present. + return; + } + + const auto &endDir{std::get<parser::OmpSectionsDirective>(endSectionsDir->t)}; CheckMatching<parser::OmpSectionsDirective>(beginDir, endDir); - PushContextAndClauseSets(beginDir.source, beginDir.v); - AddEndDirectiveClauses(std::get<parser::OmpClauseList>(endSectionsDir.t)); + AddEndDirectiveClauses(std::get<parser::OmpClauseList>(endSectionsDir->t)); const auto §ionBlocks{std::get<std::list<parser::OpenMPConstruct>>(x.t)}; for (const parser::OpenMPConstruct &construct : sectionBlocks) { @@ -1090,113 +1231,155 @@ void OmpStructureChecker::Leave(const parser::OmpEndSectionsDirective &x) { } void OmpStructureChecker::CheckThreadprivateOrDeclareTargetVar( + const parser::Designator &designator) { + auto *name{parser::Unwrap<parser::Name>(designator)}; + // If the symbol is null, return early, CheckSymbolNames + // should have already reported the missing symbol as a + // diagnostic error + if (!name || !name->symbol) { + return; + } + + llvm::omp::Directive directive{GetContext().directive}; + + if (name->symbol->GetUltimate().IsSubprogram()) { + if (directive == llvm::omp::Directive::OMPD_threadprivate) + context_.Say(name->source, + "The procedure name cannot be in a %s directive"_err_en_US, + ContextDirectiveAsFortran()); + // TODO: Check for procedure name in declare target directive. 
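+ // For example (hypothetical snippet): "!$omp threadprivate(f)", where "f" names a + // function, is the case diagnosed just above.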
+ } else if (name->symbol->attrs().test(Attr::PARAMETER)) { + if (directive == llvm::omp::Directive::OMPD_threadprivate) + context_.Say(name->source, + "The entity with PARAMETER attribute cannot be in a %s directive"_err_en_US, + ContextDirectiveAsFortran()); + else if (directive == llvm::omp::Directive::OMPD_declare_target) + context_.Warn(common::UsageWarning::OpenMPUsage, name->source, + "The entity with PARAMETER attribute is used in a %s directive"_warn_en_US, + ContextDirectiveAsFortran()); + } else if (FindCommonBlockContaining(*name->symbol)) { + context_.Say(name->source, + "A variable in a %s directive cannot be an element of a common block"_err_en_US, + ContextDirectiveAsFortran()); + } else if (FindEquivalenceSet(*name->symbol)) { + context_.Say(name->source, + "A variable in a %s directive cannot appear in an EQUIVALENCE statement"_err_en_US, + ContextDirectiveAsFortran()); + } else if (name->symbol->test(Symbol::Flag::OmpThreadprivate) && + directive == llvm::omp::Directive::OMPD_declare_target) { + context_.Say(name->source, + "A THREADPRIVATE variable cannot appear in a %s directive"_err_en_US, + ContextDirectiveAsFortran()); + } else { + const semantics::Scope &useScope{ + context_.FindScope(GetContext().directiveSource)}; + const semantics::Scope &curScope = name->symbol->GetUltimate().owner(); + if (!curScope.IsTopLevel()) { + const semantics::Scope &declScope = + GetProgramUnitOrBlockConstructContaining(curScope); + const semantics::Symbol *sym{ + declScope.parent().FindSymbol(name->symbol->name())}; + if (sym && + (sym->has<MainProgramDetails>() || sym->has<ModuleDetails>())) { + context_.Say(name->source, + "The module name cannot be in a %s directive"_err_en_US, + ContextDirectiveAsFortran()); + } else if (!IsSaved(*name->symbol) && + declScope.kind() != Scope::Kind::MainProgram && + declScope.kind() != Scope::Kind::Module) { + context_.Say(name->source, + "A variable that appears in a %s directive must be declared in the scope of a module or have the SAVE attribute, either explicitly or implicitly"_err_en_US, + ContextDirectiveAsFortran()); + } else if (useScope != declScope) { + context_.Say(name->source, + "The %s directive and the common block or variable in it must appear in the same declaration section of a scoping unit"_err_en_US, + ContextDirectiveAsFortran()); + } + } + } +} + +void OmpStructureChecker::CheckThreadprivateOrDeclareTargetVar( + const parser::Name &name) { + if (!name.symbol) { + return; + } + + if (auto *cb{name.symbol->detailsIf<CommonBlockDetails>()}) { + for (const auto &obj : cb->objects()) { + if (FindEquivalenceSet(*obj)) { + context_.Say(name.source, + "A variable in a %s directive cannot appear in an EQUIVALENCE statement (variable '%s' from common block '/%s/')"_err_en_US, + ContextDirectiveAsFortran(), obj->name(), name.symbol->name()); + } + } + } +} + +void OmpStructureChecker::CheckThreadprivateOrDeclareTargetVar( const parser::OmpObjectList &objList) { for (const auto &ompObject : objList.v) { - common::visit( - common::visitors{ - [&](const parser::Designator &) { - if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) { - // The symbol is null, return early, CheckSymbolNames - // should have already reported the missing symbol as a - // diagnostic error - if (!name->symbol) { - return; - } - - if (name->symbol->GetUltimate().IsSubprogram()) { - if (GetContext().directive == - llvm::omp::Directive::OMPD_threadprivate) - context_.Say(name->source, - "The procedure name cannot be in a %s " - "directive"_err_en_US, - 
ContextDirectiveAsFortran()); - // TODO: Check for procedure name in declare target directive. - } else if (name->symbol->attrs().test(Attr::PARAMETER)) { - if (GetContext().directive == - llvm::omp::Directive::OMPD_threadprivate) - context_.Say(name->source, - "The entity with PARAMETER attribute cannot be in a %s " - "directive"_err_en_US, - ContextDirectiveAsFortran()); - else if (GetContext().directive == - llvm::omp::Directive::OMPD_declare_target) - context_.Warn(common::UsageWarning::OpenMPUsage, - name->source, - "The entity with PARAMETER attribute is used in a %s directive"_warn_en_US, - ContextDirectiveAsFortran()); - } else if (FindCommonBlockContaining(*name->symbol)) { - context_.Say(name->source, - "A variable in a %s directive cannot be an element of a " - "common block"_err_en_US, - ContextDirectiveAsFortran()); - } else if (FindEquivalenceSet(*name->symbol)) { - context_.Say(name->source, - "A variable in a %s directive cannot appear in an " - "EQUIVALENCE statement"_err_en_US, - ContextDirectiveAsFortran()); - } else if (name->symbol->test(Symbol::Flag::OmpThreadprivate) && - GetContext().directive == - llvm::omp::Directive::OMPD_declare_target) { - context_.Say(name->source, - "A THREADPRIVATE variable cannot appear in a %s " - "directive"_err_en_US, - ContextDirectiveAsFortran()); - } else { - const semantics::Scope &useScope{ - context_.FindScope(GetContext().directiveSource)}; - const semantics::Scope &curScope = - name->symbol->GetUltimate().owner(); - if (!curScope.IsTopLevel()) { - const semantics::Scope &declScope = - GetProgramUnitOrBlockConstructContaining(curScope); - const semantics::Symbol *sym{ - declScope.parent().FindSymbol(name->symbol->name())}; - if (sym && - (sym->has<MainProgramDetails>() || - sym->has<ModuleDetails>())) { - context_.Say(name->source, - "The module name cannot be in a %s " - "directive"_err_en_US, - ContextDirectiveAsFortran()); - } else if (!IsSaved(*name->symbol) && - declScope.kind() != Scope::Kind::MainProgram && - declScope.kind() != Scope::Kind::Module) { - context_.Say(name->source, - "A variable that appears in a %s directive must be " - "declared in the scope of a module or have the SAVE " - "attribute, either explicitly or " - "implicitly"_err_en_US, - ContextDirectiveAsFortran()); - } else if (useScope != declScope) { - context_.Say(name->source, - "The %s directive and the common block or variable " - "in it must appear in the same declaration section " - "of a scoping unit"_err_en_US, - ContextDirectiveAsFortran()); - } - } - } - } - }, - [&](const parser::Name &name) { - if (name.symbol) { - if (auto *cb{name.symbol->detailsIf<CommonBlockDetails>()}) { - for (const auto &obj : cb->objects()) { - if (FindEquivalenceSet(*obj)) { - context_.Say(name.source, - "A variable in a %s directive cannot appear in an EQUIVALENCE statement (variable '%s' from common block '/%s/')"_err_en_US, - ContextDirectiveAsFortran(), obj->name(), - name.symbol->name()); - } - } - } - } - }, - }, + common::visit([&](auto &&s) { CheckThreadprivateOrDeclareTargetVar(s); }, ompObject.u); } } +void OmpStructureChecker::Enter(const parser::OpenMPGroupprivate &x) { + PushContextAndClauseSets( + x.v.DirName().source, llvm::omp::Directive::OMPD_groupprivate); + + for (const parser::OmpArgument &arg : x.v.Arguments().v) { + auto *locator{std::get_if<parser::OmpLocator>(&arg.u)}; + const Symbol *sym{GetArgumentSymbol(arg)}; + + if (!locator || !sym || + (!IsVariableListItem(*sym) && !IsCommonBlock(*sym))) { + context_.Say(arg.source, + "GROUPPRIVATE 
argument should be a variable or a named common block"_err_en_US); + continue; + } + + if (sym->has<AssocEntityDetails>()) { + context_.SayWithDecl(*sym, arg.source, + "GROUPPRIVATE argument cannot be an ASSOCIATE name"_err_en_US); + continue; + } + if (auto *obj{sym->detailsIf<ObjectEntityDetails>()}) { + if (obj->IsCoarray()) { + context_.Say( + arg.source, "GROUPPRIVATE argument cannot be a coarray"_err_en_US); + continue; + } + if (obj->init()) { + context_.SayWithDecl(*sym, arg.source, + "GROUPPRIVATE argument cannot be declared with an initializer"_err_en_US); + continue; + } + } + if (sym->test(Symbol::Flag::InCommonBlock)) { + context_.Say(arg.source, + "GROUPPRIVATE argument cannot be a member of a common block"_err_en_US); + continue; + } + if (!IsCommonBlock(*sym)) { + const Scope &thisScope{context_.FindScope(x.v.source)}; + if (thisScope != sym->owner()) { + context_.SayWithDecl(*sym, arg.source, + "GROUPPRIVATE argument variable must be declared in the same scope as the construct on which it appears"_err_en_US); + continue; + } else if (!thisScope.IsModule() && !sym->attrs().test(Attr::SAVE)) { + context_.SayWithDecl(*sym, arg.source, + "GROUPPRIVATE argument variable must be declared in the module scope or have SAVE attribute"_err_en_US); + continue; + } + } + } +} + +void OmpStructureChecker::Leave(const parser::OpenMPGroupprivate &x) { + dirContext_.pop_back(); +} + void OmpStructureChecker::Enter(const parser::OpenMPThreadprivate &c) { const auto &dir{std::get<parser::Verbatim>(c.t)}; PushContextAndClauseSets( @@ -2034,41 +2217,87 @@ void OmpStructureChecker::Leave(const parser::OpenMPCancelConstruct &) { } void OmpStructureChecker::Enter(const parser::OpenMPCriticalConstruct &x) { - const auto &dir{std::get<parser::OmpCriticalDirective>(x.t)}; - const auto &dirSource{std::get<parser::Verbatim>(dir.t).source}; - const auto &endDir{std::get<parser::OmpEndCriticalDirective>(x.t)}; - PushContextAndClauseSets(dirSource, llvm::omp::Directive::OMPD_critical); + const parser::OmpBeginDirective &beginSpec{x.BeginDir()}; + const std::optional<parser::OmpEndDirective> &endSpec{x.EndDir()}; + PushContextAndClauseSets(beginSpec.DirName().source, beginSpec.DirName().v); + const auto &block{std::get<parser::Block>(x.t)}; - CheckNoBranching(block, llvm::omp::Directive::OMPD_critical, dir.source); - const auto &dirName{std::get<std::optional<parser::Name>>(dir.t)}; - const auto &endDirName{std::get<std::optional<parser::Name>>(endDir.t)}; - const auto &ompClause{std::get<parser::OmpClauseList>(dir.t)}; - if (dirName && endDirName && - dirName->ToString().compare(endDirName->ToString())) { - context_ - .Say(endDirName->source, - parser::MessageFormattedText{ - "CRITICAL directive names do not match"_err_en_US}) - .Attach(dirName->source, "should be "_en_US); - } else if (dirName && !endDirName) { - context_ - .Say(dirName->source, - parser::MessageFormattedText{ - "CRITICAL directive names do not match"_err_en_US}) - .Attach(dirName->source, "should be NULL"_en_US); - } else if (!dirName && endDirName) { - context_ - .Say(endDirName->source, - parser::MessageFormattedText{ - "CRITICAL directive names do not match"_err_en_US}) - .Attach(endDirName->source, "should be NULL"_en_US); - } - if (!dirName && !ompClause.source.empty() && - ompClause.source.NULTerminatedToString() != "hint(omp_sync_hint_none)") { - context_.Say(dir.source, - parser::MessageFormattedText{ - "Hint clause other than omp_sync_hint_none cannot be specified for " - "an unnamed CRITICAL directive"_err_en_US}); + 
CheckNoBranching( + block, llvm::omp::Directive::OMPD_critical, beginSpec.DirName().source); + + auto getNameFromArg{[](const parser::OmpArgument &arg) { + if (auto *object{parser::Unwrap<parser::OmpObject>(arg.u)}) { + if (auto *designator{omp::GetDesignatorFromObj(*object)}) { + return getDesignatorNameIfDataRef(*designator); + } + } + return static_cast<const parser::Name *>(nullptr); + }}; + + auto checkArgumentList{[&](const parser::OmpArgumentList &args) { + if (args.v.size() > 1) { + context_.Say(args.source, + "Only a single argument is allowed in CRITICAL directive"_err_en_US); + } else if (!args.v.empty()) { + if (!getNameFromArg(args.v.front())) { + context_.Say(args.v.front().source, + "CRITICAL argument should be a name"_err_en_US); + } + } + }}; + + const parser::Name *beginName{nullptr}; + const parser::Name *endName{nullptr}; + + auto &beginArgs{beginSpec.Arguments()}; + checkArgumentList(beginArgs); + + if (!beginArgs.v.empty()) { + beginName = getNameFromArg(beginArgs.v.front()); + } + + if (endSpec) { + auto &endArgs{endSpec->Arguments()}; + checkArgumentList(endArgs); + + if (beginArgs.v.empty() != endArgs.v.empty()) { + parser::CharBlock source{ + beginArgs.v.empty() ? endArgs.source : beginArgs.source}; + context_.Say(source, + "Either both CRITICAL and END CRITICAL should have an argument, or none of them should"_err_en_US); + } else if (!beginArgs.v.empty()) { + endName = getNameFromArg(endArgs.v.front()); + if (beginName && endName) { + if (beginName->ToString() != endName->ToString()) { + context_.Say(endName->source, + "The names on CRITICAL and END CRITICAL must match"_err_en_US); + } + } + } + } + + for (auto &clause : beginSpec.Clauses().v) { + auto *hint{std::get_if<parser::OmpClause::Hint>(&clause.u)}; + if (!hint) { + continue; + } + const int64_t OmpSyncHintNone = 0; // omp_sync_hint_none + std::optional<int64_t> hintValue{GetIntValue(hint->v.v)}; + if (hintValue && *hintValue != OmpSyncHintNone) { + // Emit a diagnostic if the name is missing, and point to the directive + // with a missing name. + parser::CharBlock source; + if (!beginName) { + source = beginSpec.DirName().source; + } else if (endSpec && !endName) { + source = endSpec->DirName().source; + } + + if (!source.empty()) { + context_.Say(source, + "When HINT other than 'omp_sync_hint_none' is present, CRITICAL directive should have a name"_err_en_US); + } + } } } @@ -2511,8 +2740,9 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) { void OmpStructureChecker::Enter(const parser::OmpClause &x) { SetContextClause(x); + llvm::omp::Clause id{x.Id()}; // The visitors for these clauses do their own checks. - switch (x.Id()) { + switch (id) { case llvm::omp::Clause::OMPC_copyprivate: case llvm::omp::Clause::OMPC_enter: case llvm::omp::Clause::OMPC_lastprivate: @@ -2523,11 +2753,25 @@ void OmpStructureChecker::Enter(const parser::OmpClause &x) { break; } + // Named constants are OK to be used within 'shared' and 'firstprivate' + // clauses. The check for this happens a few lines below. 
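+ // For example (hypothetical snippet): given "real, parameter :: pi = 3.14", a + // "shared(pi)" or "firstprivate(pi)" clause is not reported as a non-variable below.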
+ bool SharedOrFirstprivate = false; + switch (id) { + case llvm::omp::Clause::OMPC_shared: + case llvm::omp::Clause::OMPC_firstprivate: + SharedOrFirstprivate = true; + break; + default: + break; + } + if (const parser::OmpObjectList *objList{GetOmpObjectList(x)}) { + AnalyzeObjects(*objList); SymbolSourceMap symbols; GetSymbolsInObjectList(*objList, symbols); for (const auto &[symbol, source] : symbols) { - if (!IsVariableListItem(*symbol)) { + if (!IsVariableListItem(*symbol) && + !(IsNamedConstant(*symbol) && SharedOrFirstprivate)) { deferredNonVariables_.insert({symbol, source}); } } @@ -2543,6 +2787,7 @@ CHECK_SIMPLE_CLAUSE(Default, OMPC_default) CHECK_SIMPLE_CLAUSE(Depobj, OMPC_depobj) CHECK_SIMPLE_CLAUSE(DeviceType, OMPC_device_type) CHECK_SIMPLE_CLAUSE(DistSchedule, OMPC_dist_schedule) +CHECK_SIMPLE_CLAUSE(DynGroupprivate, OMPC_dyn_groupprivate) CHECK_SIMPLE_CLAUSE(Exclusive, OMPC_exclusive) CHECK_SIMPLE_CLAUSE(Final, OMPC_final) CHECK_SIMPLE_CLAUSE(Flush, OMPC_flush) @@ -2853,7 +3098,8 @@ static bool CheckSymbolSupportsType(const Scope &scope, static bool IsReductionAllowedForType( const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type, - const Scope &scope, SemanticsContext &context) { + bool cannotBeBuiltinReduction, const Scope &scope, + SemanticsContext &context) { auto isLogical{[](const DeclTypeSpec &type) -> bool { return type.category() == DeclTypeSpec::Logical; }}; @@ -2864,6 +3110,10 @@ static bool IsReductionAllowedForType( auto checkOperator{[&](const parser::DefinedOperator &dOpr) { if (const auto *intrinsicOp{ std::get_if<parser::DefinedOperator::IntrinsicOperator>(&dOpr.u)}) { + if (cannotBeBuiltinReduction) { + return false; + } + // OMP5.2: The type [...] of a list item that appears in a // reduction clause must be valid for the combiner expression // See F2023: Table 10.2 @@ -2915,7 +3165,8 @@ static bool IsReductionAllowedForType( // IAND: arguments must be integers: F2023 16.9.100 // IEOR: arguments must be integers: F2023 16.9.106 // IOR: arguments must be integers: F2023 16.9.111 - if (type.IsNumeric(TypeCategory::Integer)) { + if (type.IsNumeric(TypeCategory::Integer) && + !cannotBeBuiltinReduction) { return true; } } else if (realName == "max" || realName == "min") { @@ -2923,8 +3174,9 @@ static bool IsReductionAllowedForType( // F2023 16.9.135 // MIN: arguments must be integer, real, or character: // F2023 16.9.141 - if (type.IsNumeric(TypeCategory::Integer) || - type.IsNumeric(TypeCategory::Real) || isCharacter(type)) { + if ((type.IsNumeric(TypeCategory::Integer) || + type.IsNumeric(TypeCategory::Real) || isCharacter(type)) && + !cannotBeBuiltinReduction) { return true; } } @@ -2957,9 +3209,16 @@ void OmpStructureChecker::CheckReductionObjectTypes( GetSymbolsInObjectList(objects, symbols); for (auto &[symbol, source] : symbols) { + // Built in reductions require types which can be used in their initializer + // and combiner expressions. For example, for +: + // r = 0; r = r + r2 + // But it might be valid to use these with DECLARE REDUCTION. + // Assumed size is already caught elsewhere. 
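+ // An assumed-rank item (e.g. a dummy declared "real :: x(..)", illustrative) cannot + // appear in those expressions, so the built-in combiners are ruled out for it below.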
+ bool cannotBeBuiltinReduction{IsAssumedRank(*symbol)}; if (auto *type{symbol->GetType()}) { const auto &scope{context_.FindScope(symbol->name())}; - if (!IsReductionAllowedForType(ident, *type, scope, context_)) { + if (!IsReductionAllowedForType( + ident, *type, cannotBeBuiltinReduction, scope, context_)) { context_.Say(source, "The type of '%s' is incompatible with the reduction operator."_err_en_US, symbol->name()); @@ -3238,9 +3497,14 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Aligned &x) { x.v, llvm::omp::OMPC_aligned, GetContext().clauseSource, context_)) { auto &modifiers{OmpGetModifiers(x.v)}; if (auto *align{OmpGetUniqueModifier<parser::OmpAlignment>(modifiers)}) { - if (const auto &v{GetIntValue(align->v)}; !v || *v <= 0) { + const auto &v{GetIntValue(align->v)}; + if (!v || *v <= 0) { context_.Say(OmpGetModifierSource(modifiers, align), "The alignment value should be a constant positive integer"_err_en_US); + } else if (((*v) & (*v - 1)) != 0) { + context_.Warn(common::UsageWarning::OpenMPUsage, + OmpGetModifierSource(modifiers, align), + "Alignment is not a power of 2, Aligned clause will be ignored"_warn_en_US); } } } @@ -4349,7 +4613,7 @@ bool OmpStructureChecker::CheckTargetBlockOnlyTeams( if (const auto *ompConstruct{ parser::Unwrap<parser::OpenMPConstruct>(*it)}) { if (const auto *ompBlockConstruct{ - std::get_if<parser::OpenMPBlockConstruct>(&ompConstruct->u)}) { + std::get_if<parser::OmpBlockConstruct>(&ompConstruct->u)}) { llvm::omp::Directive dirId{ompBlockConstruct->BeginDir().DirId()}; if (dirId == llvm::omp::Directive::OMPD_teams) { nestedTeams = true; @@ -4396,7 +4660,7 @@ void OmpStructureChecker::CheckWorkshareBlockStmts( // 'Parallel' constructs auto currentDir{llvm::omp::Directive::OMPD_unknown}; if (const auto *ompBlockConstruct{ - std::get_if<parser::OpenMPBlockConstruct>(&ompConstruct->u)}) { + std::get_if<parser::OmpBlockConstruct>(&ompConstruct->u)}) { currentDir = ompBlockConstruct->BeginDir().DirId(); } else if (const auto *ompLoopConstruct{ std::get_if<parser::OpenMPLoopConstruct>( @@ -4432,6 +4696,27 @@ void OmpStructureChecker::CheckWorkshareBlockStmts( } } +void OmpStructureChecker::CheckWorkdistributeBlockStmts( + const parser::Block &block, parser::CharBlock source) { + unsigned version{context_.langOptions().OpenMPVersion}; + unsigned since{60}; + if (version < since) + context_.Say(source, + "WORKDISTRIBUTE construct is not allowed in %s, %s"_err_en_US, + ThisVersion(version), TryVersion(since)); + + OmpWorkdistributeBlockChecker ompWorkdistributeBlockChecker{context_, source}; + + for (auto it{block.begin()}; it != block.end(); ++it) { + if (parser::Unwrap<parser::AssignmentStmt>(*it)) { + parser::Walk(*it, ompWorkdistributeBlockChecker); + } else { + context_.Say(source, + "The structured block in a WORKDISTRIBUTE construct may consist of only SCALAR or ARRAY assignments"_err_en_US); + } + } +} + void OmpStructureChecker::CheckIfContiguous(const parser::OmpObject &object) { if (auto contig{IsContiguous(context_, object)}; contig && !*contig) { const parser::Name *name{GetObjectName(object)}; @@ -4475,42 +4760,6 @@ const parser::Name *OmpStructureChecker::GetObjectName( return NameHelper::Visit(object); } -const parser::OmpObjectList *OmpStructureChecker::GetOmpObjectList( - const parser::OmpClause &clause) { - - // Clauses with OmpObjectList as its data member - using MemberObjectListClauses = - std::tuple<parser::OmpClause::Copyprivate, parser::OmpClause::Copyin, - parser::OmpClause::Firstprivate, parser::OmpClause::Link, - 
parser::OmpClause::Private, parser::OmpClause::Shared, - parser::OmpClause::UseDevicePtr, parser::OmpClause::UseDeviceAddr>; - - // Clauses with OmpObjectList in the tuple - using TupleObjectListClauses = - std::tuple<parser::OmpClause::Aligned, parser::OmpClause::Allocate, - parser::OmpClause::From, parser::OmpClause::Lastprivate, - parser::OmpClause::Map, parser::OmpClause::Reduction, - parser::OmpClause::To, parser::OmpClause::Enter>; - - // TODO:: Generate the tuples using TableGen. - // Handle other constructs with OmpObjectList such as OpenMPThreadprivate. - return common::visit( - common::visitors{ - [&](const auto &x) -> const parser::OmpObjectList * { - using Ty = std::decay_t<decltype(x)>; - if constexpr (common::HasMember<Ty, MemberObjectListClauses>) { - return &x.v; - } else if constexpr (common::HasMember<Ty, - TupleObjectListClauses>) { - return &(std::get<parser::OmpObjectList>(x.v.t)); - } else { - return nullptr; - } - }, - }, - clause.u); -} - void OmpStructureChecker::Enter( const parser::OmpClause::AtomicDefaultMemOrder &x) { CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_atomic_default_mem_order); diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index 6b33ca6..ce074f5 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -88,8 +88,8 @@ public: void Leave(const parser::OpenMPAssumeConstruct &); void Enter(const parser::OpenMPDeclarativeAssumes &); void Leave(const parser::OpenMPDeclarativeAssumes &); - void Enter(const parser::OpenMPBlockConstruct &); - void Leave(const parser::OpenMPBlockConstruct &); + void Enter(const parser::OmpBlockConstruct &); + void Leave(const parser::OmpBlockConstruct &); void Leave(const parser::OmpBeginDirective &); void Enter(const parser::OmpEndDirective &); void Leave(const parser::OmpEndDirective &); @@ -126,6 +126,8 @@ public: void Leave(const parser::OpenMPAllocatorsConstruct &); void Enter(const parser::OpenMPRequiresConstruct &); void Leave(const parser::OpenMPRequiresConstruct &); + void Enter(const parser::OpenMPGroupprivate &); + void Leave(const parser::OpenMPGroupprivate &); void Enter(const parser::OpenMPThreadprivate &); void Leave(const parser::OpenMPThreadprivate &); @@ -165,6 +167,8 @@ private: void CheckVariableListItem(const SymbolSourceMap &symbols); void CheckDirectiveSpelling( parser::CharBlock spelling, llvm::omp::Directive id); + void AnalyzeObject(const parser::OmpObject &object); + void AnalyzeObjects(const parser::OmpObjectList &objects); void CheckMultipleOccurrence(semantics::UnorderedSymbolSet &listVars, const std::list<parser::Name> &nameList, const parser::CharBlock &item, const std::string &clauseName); @@ -222,8 +226,9 @@ private: const parser::OmpObject &obj, llvm::StringRef clause = ""); void CheckVarIsNotPartOfAnotherVar(const parser::CharBlock &source, const parser::OmpObjectList &objList, llvm::StringRef clause = ""); - void CheckThreadprivateOrDeclareTargetVar( - const parser::OmpObjectList &objList); + void CheckThreadprivateOrDeclareTargetVar(const parser::Designator &); + void CheckThreadprivateOrDeclareTargetVar(const parser::Name &); + void CheckThreadprivateOrDeclareTargetVar(const parser::OmpObjectList &); void CheckSymbolNames( const parser::CharBlock &source, const parser::OmpObjectList &objList); void CheckIntentInPointer(SymbolSourceMap &, const llvm::omp::Clause); @@ -242,6 +247,7 @@ private: llvmOmpClause clause, const parser::OmpObjectList &ompObjectList); bool 
CheckTargetBlockOnlyTeams(const parser::Block &); void CheckWorkshareBlockStmts(const parser::Block &, parser::CharBlock); + void CheckWorkdistributeBlockStmts(const parser::Block &, parser::CharBlock); void CheckIteratorRange(const parser::OmpIteratorSpecifier &x); void CheckIteratorModifier(const parser::OmpIterator &x); @@ -267,8 +273,10 @@ private: const evaluate::Assignment &read, parser::CharBlock source); void CheckAtomicWriteAssignment( const evaluate::Assignment &write, parser::CharBlock source); - void CheckAtomicUpdateAssignment( + std::optional<evaluate::Assignment> CheckAtomicUpdateAssignment( const evaluate::Assignment &update, parser::CharBlock source); + std::pair<bool, bool> CheckAtomicUpdateAssignmentRhs(const SomeExpr &atom, + const SomeExpr &rhs, parser::CharBlock source, bool suppressDiagnostics); void CheckAtomicConditionalUpdateAssignment(const SomeExpr &cond, parser::CharBlock condSource, const evaluate::Assignment &assign, parser::CharBlock assignSource); @@ -307,7 +315,7 @@ private: const parser::OmpReductionIdentifier &ident); void CheckReductionModifier(const parser::OmpReductionModifier &); void CheckLastprivateModifier(const parser::OmpLastprivateModifier &); - void CheckMasterNesting(const parser::OpenMPBlockConstruct &x); + void CheckMasterNesting(const parser::OmpBlockConstruct &x); void ChecksOnOrderedAsBlock(); void CheckBarrierNesting(const parser::OpenMPSimpleStandaloneConstruct &x); void CheckScan(const parser::OpenMPSimpleStandaloneConstruct &x); @@ -321,7 +329,6 @@ private: const parser::OmpObjectList &ompObjectList); void CheckIfContiguous(const parser::OmpObject &object); const parser::Name *GetObjectName(const parser::OmpObject &object); - const parser::OmpObjectList *GetOmpObjectList(const parser::OmpClause &); void CheckPredefinedAllocatorRestriction(const parser::CharBlock &source, const parser::OmpObjectList &ompObjectList); void CheckPredefinedAllocatorRestriction( diff --git a/flang/lib/Semantics/check-select-rank.cpp b/flang/lib/Semantics/check-select-rank.cpp index b227bba..5dade2c 100644 --- a/flang/lib/Semantics/check-select-rank.cpp +++ b/flang/lib/Semantics/check-select-rank.cpp @@ -32,7 +32,7 @@ void SelectRankConstructChecker::Leave( const Symbol *saveSelSymbol{nullptr}; if (const auto selExpr{GetExprFromSelector(selectRankStmtSel)}) { if (const Symbol * sel{evaluate::UnwrapWholeSymbolDataRef(*selExpr)}) { - if (!evaluate::IsAssumedRank(*sel)) { // C1150 + if (!semantics::IsAssumedRank(*sel)) { // C1150 context_.Say(parser::FindSourceLocation(selectRankStmtSel), "Selector '%s' is not an assumed-rank array variable"_err_en_US, sel->name().ToString()); diff --git a/flang/lib/Semantics/check-select-type.cpp b/flang/lib/Semantics/check-select-type.cpp index 94d16a7..b1b22c3 100644 --- a/flang/lib/Semantics/check-select-type.cpp +++ b/flang/lib/Semantics/check-select-type.cpp @@ -252,7 +252,7 @@ void SelectTypeChecker::Enter(const parser::SelectTypeConstruct &construct) { if (IsProcedure(*selector)) { context_.Say( selectTypeStmt.source, "Selector may not be a procedure"_err_en_US); - } else if (evaluate::IsAssumedRank(*selector)) { + } else if (IsAssumedRank(*selector)) { context_.Say(selectTypeStmt.source, "Assumed-rank variable may only be used as actual argument"_err_en_US); } else if (auto exprType{selector->GetType()}) { diff --git a/flang/lib/Semantics/compute-offsets.cpp b/flang/lib/Semantics/compute-offsets.cpp index 6d4fce2..1c48d33 100644 --- a/flang/lib/Semantics/compute-offsets.cpp +++ 
b/flang/lib/Semantics/compute-offsets.cpp @@ -239,7 +239,9 @@ void ComputeOffsetsHelper::DoCommonBlock(Symbol &commonBlock) { std::size_t minAlignment{0}; UnorderedSymbolSet previous; for (auto object : details.objects()) { - Symbol &symbol{*object}; + // Allow for host association when the common block is + // OpenMP firstprivate. + Symbol &symbol{object->GetUltimate()}; auto errorSite{ commonBlock.name().empty() ? symbol.name() : commonBlock.name()}; if (std::size_t padding{DoSymbol(symbol.GetUltimate())}) { diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp index b4c83ba..1c45438 100644 --- a/flang/lib/Semantics/data-to-inits.cpp +++ b/flang/lib/Semantics/data-to-inits.cpp @@ -285,21 +285,22 @@ template <typename DSV> std::optional<std::pair<SomeExpr, bool>> DataInitializationCompiler<DSV>::ConvertElement( const SomeExpr &expr, const evaluate::DynamicType &type) { + evaluate::FoldingContext &foldingContext{exprAnalyzer_.GetFoldingContext()}; + evaluate::CheckRealWidening(expr, type, foldingContext); if (auto converted{evaluate::ConvertToType(type, SomeExpr{expr})}) { return {std::make_pair(std::move(*converted), false)}; } // Allow DATA initialization with Hollerith and kind=1 CHARACTER like // (most) other Fortran compilers do. - if (auto converted{evaluate::HollerithToBOZ( - exprAnalyzer_.GetFoldingContext(), expr, type)}) { + if (auto converted{evaluate::HollerithToBOZ(foldingContext, expr, type)}) { return {std::make_pair(std::move(*converted), true)}; } SemanticsContext &context{exprAnalyzer_.context()}; if (context.IsEnabled(common::LanguageFeature::LogicalIntegerAssignment)) { if (MaybeExpr converted{evaluate::DataConstantConversionExtension( - exprAnalyzer_.GetFoldingContext(), type, expr)}) { + foldingContext, type, expr)}) { context.Warn(common::LanguageFeature::LogicalIntegerAssignment, - exprAnalyzer_.GetFoldingContext().messages().at(), + foldingContext.messages().at(), "nonstandard usage: initialization of %s with %s"_port_en_US, type.AsFortran(), expr.GetType().value().AsFortran()); return {std::make_pair(std::move(*converted), false)}; diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp index 92dbe0e..ccccf60 100644 --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -828,7 +828,7 @@ MaybeExpr ExpressionAnalyzer::Analyze( template <typename TYPE> Constant<TYPE> ReadRealLiteral( - parser::CharBlock source, FoldingContext &context) { + parser::CharBlock source, FoldingContext &context, bool isDefaultKind) { const char *p{source.begin()}; auto valWithFlags{ Scalar<TYPE>::Read(p, context.targetCharacteristics().roundingMode())}; @@ -838,19 +838,24 @@ Constant<TYPE> ReadRealLiteral( if (context.targetCharacteristics().areSubnormalsFlushedToZero()) { value = value.FlushSubnormalToZero(); } - return {value}; + typename Constant<TYPE>::Result resultInfo; + resultInfo.set_isFromInexactLiteralConversion( + isDefaultKind && valWithFlags.flags.test(RealFlag::Inexact)); + return {value, resultInfo}; } struct RealTypeVisitor { using Result = std::optional<Expr<SomeReal>>; using Types = RealTypes; - RealTypeVisitor(int k, parser::CharBlock lit, FoldingContext &ctx) - : kind{k}, literal{lit}, context{ctx} {} + RealTypeVisitor( + int k, parser::CharBlock lit, FoldingContext &ctx, bool isDeftKind) + : kind{k}, literal{lit}, context{ctx}, isDefaultKind{isDeftKind} {} template <typename T> Result Test() { if (kind == T::kind) { - return {AsCategoryExpr(ReadRealLiteral<T>(literal, 
context))}; + return { + AsCategoryExpr(ReadRealLiteral<T>(literal, context, isDefaultKind))}; } return std::nullopt; } @@ -858,6 +863,7 @@ struct RealTypeVisitor { int kind; parser::CharBlock literal; FoldingContext &context; + bool isDefaultKind; }; // Reads a real literal constant and encodes it with the right kind. @@ -909,8 +915,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::RealLiteralConstant &x) { "Explicit kind parameter together with non-'E' exponent letter is not standard"_port_en_US); } } - auto result{common::SearchTypes( - RealTypeVisitor{kind, x.real.source, GetFoldingContext()})}; + bool isDefaultKind{!x.kind && letterKind.value_or('e') == 'e'}; + auto result{common::SearchTypes(RealTypeVisitor{ + kind, x.real.source, GetFoldingContext(), isDefaultKind})}; if (!result) { // C717 Say("Unsupported REAL(KIND=%d)"_err_en_US, kind); } @@ -1841,8 +1848,7 @@ void ArrayConstructorContext::Push(MaybeExpr &&x) { if (*thisLen != *constantLength_ && !(messageDisplayedSet_ & 1)) { exprAnalyzer_.Warn( common::LanguageFeature::DistinctArrayConstructorLengths, - "Character literal in array constructor without explicit " - "type has different length than earlier elements"_port_en_US); + "Character literal in array constructor without explicit type has different length than earlier elements"_port_en_US); messageDisplayedSet_ |= 1; } if (*thisLen > *constantLength_) { @@ -1862,17 +1868,17 @@ void ArrayConstructorContext::Push(MaybeExpr &&x) { } else { if (!(messageDisplayedSet_ & 2)) { exprAnalyzer_.Say( - "Values in array constructor must have the same declared type " - "when no explicit type appears"_err_en_US); // C7110 + "Values in array constructor must have the same declared type when no explicit type appears"_err_en_US); // C7110 messageDisplayedSet_ |= 2; } } } else { + CheckRealWidening(*x, *type_, exprAnalyzer_.GetFoldingContext()); if (auto cast{ConvertToType(*type_, std::move(*x))}) { values_.Push(std::move(*cast)); } else if (!(messageDisplayedSet_ & 4)) { - exprAnalyzer_.Say("Value in array constructor of type '%s' could not " - "be converted to the type of the array '%s'"_err_en_US, + exprAnalyzer_.Say( + "Value in array constructor of type '%s' could not be converted to the type of the array '%s'"_err_en_US, x->GetType()->AsFortran(), type_->AsFortran()); // C7111, C7112 messageDisplayedSet_ |= 4; } @@ -2065,8 +2071,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::ArrayConstructor &array) { // Check if implicit conversion of expr to the symbol type is legal (if needed), // and make it explicit if requested. -static MaybeExpr ImplicitConvertTo(const semantics::Symbol &sym, - Expr<SomeType> &&expr, bool keepConvertImplicit) { +static MaybeExpr ImplicitConvertTo(const Symbol &sym, Expr<SomeType> &&expr, + bool keepConvertImplicit, FoldingContext &foldingContext) { + CheckRealWidening(expr, DynamicType::From(sym), foldingContext); if (!keepConvertImplicit) { return ConvertToType(sym, std::move(expr)); } else { @@ -2191,7 +2198,8 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor( } if (symbol) { const semantics::Scope &innermost{context_.FindScope(exprSource)}; - if (auto msg{CheckAccessibleSymbol(innermost, *symbol)}) { + if (auto msg{CheckAccessibleSymbol( + innermost, *symbol, /*inStructureConstructor=*/true)}) { Say(exprSource, std::move(*msg)); } if (checkConflicts) { @@ -2293,10 +2301,12 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor( // convert would cause a segfault. 
Lowering will deal with // conditionally converting and preserving the lower bounds in this // case. - if (MaybeExpr converted{ImplicitConvertTo( - *symbol, std::move(value), IsAllocatable(*symbol))}) { - if (auto componentShape{GetShape(GetFoldingContext(), *symbol)}) { - if (auto valueShape{GetShape(GetFoldingContext(), *converted)}) { + FoldingContext &foldingContext{GetFoldingContext()}; + if (MaybeExpr converted{ImplicitConvertTo(*symbol, std::move(value), + /*keepConvertImplicit=*/IsAllocatable(*symbol), + foldingContext)}) { + if (auto componentShape{GetShape(foldingContext, *symbol)}) { + if (auto valueShape{GetShape(foldingContext, *converted)}) { if (GetRank(*componentShape) == 0 && GetRank(*valueShape) > 0) { AttachDeclaration( Say(exprSource, @@ -2310,7 +2320,7 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor( if (checked && *checked && GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0 && (IsDeferredShape(*symbol) || - !IsExpandableScalar(*converted, GetFoldingContext(), + !IsExpandableScalar(*converted, foldingContext, *componentShape, true /*admit PURE call*/))) { AttachDeclaration( Say(exprSource, @@ -3774,10 +3784,9 @@ MaybeExpr NumericBinaryHelper( analyzer.CheckForNullPointer(); analyzer.CheckForAssumedRank(); analyzer.CheckConformance(); - constexpr bool canBeUnsigned{opr != NumericOperator::Power}; - return NumericOperation<OPR, canBeUnsigned>( - context.GetContextualMessages(), analyzer.MoveExpr(0), - analyzer.MoveExpr(1), context.GetDefaultKind(TypeCategory::Real)); + return NumericOperation<OPR>(context.GetContextualMessages(), + analyzer.MoveExpr(0), analyzer.MoveExpr(1), + context.GetDefaultKind(TypeCategory::Real)); } else { return analyzer.TryDefinedOp(AsFortran(opr), "Operands of %s must be numeric; have %s and %s"_err_en_US); @@ -4623,7 +4632,7 @@ bool ArgumentAnalyzer::CheckForNullPointer(const char *where) { bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) { for (const std::optional<ActualArgument> &arg : actuals_) { - if (arg && IsAssumedRank(arg->UnwrapExpr())) { + if (arg && semantics::IsAssumedRank(arg->UnwrapExpr())) { context_.Say(source_, "An assumed-rank dummy argument is not allowed %s"_err_en_US, where); fatalErrors_ = true; @@ -4827,6 +4836,11 @@ std::optional<ProcedureRef> ArgumentAnalyzer::TryDefinedAssignment() { // conversion in this case. 
if (lhsType) { if (rhsType) { + FoldingContext &foldingContext{context_.GetFoldingContext()}; + auto restorer{foldingContext.messages().SetLocation( + actuals_.at(1).value().sourceLocation().value_or( + foldingContext.messages().at()))}; + CheckRealWidening(rhs, lhsType, foldingContext); if (!IsAllocatableDesignator(lhs) || context_.inWhereBody()) { AddAssignmentConversion(*lhsType, *rhsType); } diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index 7a492a4..e8df346c 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "openmp-utils.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Common/indirection.h" #include "flang/Common/reference.h" diff --git a/flang/lib/Semantics/openmp-utils.h b/flang/lib/Semantics/openmp-utils.h deleted file mode 100644 index b8ad9ed..0000000 --- a/flang/lib/Semantics/openmp-utils.h +++ /dev/null @@ -1,81 +0,0 @@ -//===-- lib/Semantics/openmp-utils.h --------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Common utilities used in OpenMP semantic checks. -// -//===----------------------------------------------------------------------===// - -#ifndef FORTRAN_SEMANTICS_OPENMP_UTILS_H -#define FORTRAN_SEMANTICS_OPENMP_UTILS_H - -#include "flang/Evaluate/type.h" -#include "flang/Parser/char-block.h" -#include "flang/Parser/parse-tree.h" -#include "flang/Semantics/tools.h" - -#include "llvm/ADT/ArrayRef.h" - -#include <optional> -#include <string> - -namespace Fortran::semantics { -class SemanticsContext; -class Symbol; - -// Add this namespace to avoid potential conflicts -namespace omp { -// There is no consistent way to get the source of an ActionStmt, but there -// is "source" in Statement<T>. This structure keeps the ActionStmt with the -// extracted source for further use. 
-struct SourcedActionStmt { - const parser::ActionStmt *stmt{nullptr}; - parser::CharBlock source; - - operator bool() const { return stmt != nullptr; } -}; - -SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x); -SourcedActionStmt GetActionStmt(const parser::Block &block); - -std::string ThisVersion(unsigned version); -std::string TryVersion(unsigned version); - -const parser::Designator *GetDesignatorFromObj(const parser::OmpObject &object); -const parser::DataRef *GetDataRefFromObj(const parser::OmpObject &object); -const parser::ArrayElement *GetArrayElementFromObj( - const parser::OmpObject &object); -const Symbol *GetObjectSymbol(const parser::OmpObject &object); -const Symbol *GetArgumentSymbol(const parser::OmpArgument &argument); -std::optional<parser::CharBlock> GetObjectSource( - const parser::OmpObject &object); - -bool IsCommonBlock(const Symbol &sym); -bool IsExtendedListItem(const Symbol &sym); -bool IsVariableListItem(const Symbol &sym); -bool IsVarOrFunctionRef(const MaybeExpr &expr); - -bool IsMapEnteringType(parser::OmpMapType::Value type); -bool IsMapExitingType(parser::OmpMapType::Value type); - -std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr); -std::optional<evaluate::DynamicType> GetDynamicType( - const parser::Expr &parserExpr); - -std::optional<bool> IsContiguous( - SemanticsContext &semaCtx, const parser::OmpObject &object); - -std::vector<SomeExpr> GetAllDesignators(const SomeExpr &expr); -const SomeExpr *HasStorageOverlap( - const SomeExpr &base, llvm::ArrayRef<SomeExpr> exprs); -bool IsAssignment(const parser::ActionStmt *x); -bool IsPointerAssignment(const evaluate::Assignment &x); -const parser::Block &GetInnermostExecPart(const parser::Block &block); -} // namespace omp -} // namespace Fortran::semantics - -#endif // FORTRAN_SEMANTICS_OPENMP_UTILS_H diff --git a/flang/lib/Semantics/pointer-assignment.cpp b/flang/lib/Semantics/pointer-assignment.cpp index e767bf8..5508ba8 100644 --- a/flang/lib/Semantics/pointer-assignment.cpp +++ b/flang/lib/Semantics/pointer-assignment.cpp @@ -159,7 +159,7 @@ bool PointerAssignmentChecker::CheckLeftHandSide(const SomeExpr &lhs) { msg->Attach(std::move(whyNot->set_severity(parser::Severity::Because))); } return false; - } else if (evaluate::IsAssumedRank(lhs)) { + } else if (IsAssumedRank(lhs)) { Say("The left-hand side of a pointer assignment must not be an assumed-rank dummy argument"_err_en_US); return false; } else if (evaluate::ExtractCoarrayRef(lhs)) { // F'2023 C1027 diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 0557b08..a08e764 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -10,7 +10,6 @@ #include "check-acc-structure.h" #include "check-omp-structure.h" -#include "openmp-utils.h" #include "resolve-names-utils.h" #include "flang/Common/idioms.h" #include "flang/Evaluate/fold.h" @@ -22,6 +21,7 @@ #include "flang/Semantics/expression.h" #include "flang/Semantics/openmp-dsa.h" #include "flang/Semantics/openmp-modifiers.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "flang/Support/Flags.h" @@ -29,7 +29,6 @@ #include "llvm/Support/Debug.h" #include <list> #include <map> -#include <sstream> template <typename T> static Fortran::semantics::Scope *GetScope( @@ -61,6 +60,13 @@ protected: parser::OmpDefaultmapClause::ImplicitBehavior> defaultMap; + std::optional<Symbol::Flag> 
FindSymbolWithDSA(const Symbol &symbol) { + if (auto it{objectWithDSA.find(&symbol)}; it != objectWithDSA.end()) { + return it->second; + } + return std::nullopt; + } + bool withinConstruct{false}; std::int64_t associatedLoopLevel{0}; }; @@ -75,10 +81,19 @@ protected: : std::make_optional<DirContext>(dirContext_.back()); } void PushContext(const parser::CharBlock &source, T dir, Scope &scope) { - dirContext_.emplace_back(source, dir, scope); + if constexpr (std::is_same_v<T, llvm::acc::Directive>) { + dirContext_.emplace_back(source, dir, scope); + if (std::size_t size{dirContext_.size()}; size > 1) { + std::size_t lastIndex{size - 1}; + dirContext_[lastIndex].defaultDSA = + dirContext_[lastIndex - 1].defaultDSA; + } + } else { + dirContext_.emplace_back(source, dir, scope); + } } void PushContext(const parser::CharBlock &source, T dir) { - dirContext_.emplace_back(source, dir, context_.FindScope(source)); + PushContext(source, dir, context_.FindScope(source)); } void PopContext() { dirContext_.pop_back(); } void SetContextDirectiveSource(parser::CharBlock &dir) { @@ -100,9 +115,21 @@ protected: AddToContextObjectWithDSA(symbol, flag, GetContext()); } bool IsObjectWithDSA(const Symbol &symbol) { - auto it{GetContext().objectWithDSA.find(&symbol)}; - return it != GetContext().objectWithDSA.end(); + return GetContext().FindSymbolWithDSA(symbol).has_value(); } + bool IsObjectWithVisibleDSA(const Symbol &symbol) { + for (std::size_t i{dirContext_.size()}; i != 0; i--) { + if (dirContext_[i - 1].FindSymbolWithDSA(symbol).has_value()) { + return true; + } + } + return false; + } + + bool WithinConstruct() { + return !dirContext_.empty() && GetContext().withinConstruct; + } + void SetContextAssociatedLoopLevel(std::int64_t level) { GetContext().associatedLoopLevel = level; } @@ -384,13 +411,16 @@ public: } void Post(const parser::OmpMetadirectiveDirective &) { PopContext(); } - bool Pre(const parser::OpenMPBlockConstruct &); - void Post(const parser::OpenMPBlockConstruct &); + bool Pre(const parser::OmpBlockConstruct &); + void Post(const parser::OmpBlockConstruct &); void Post(const parser::OmpBeginDirective &x) { GetContext().withinConstruct = true; } + bool Pre(const parser::OpenMPGroupprivate &); + void Post(const parser::OpenMPGroupprivate &) { PopContext(); } + bool Pre(const parser::OpenMPStandaloneConstruct &x) { common::visit( [&](auto &&s) { @@ -528,6 +558,9 @@ public: bool Pre(const parser::OpenMPDeclarativeAllocate &); void Post(const parser::OpenMPDeclarativeAllocate &) { PopContext(); } + bool Pre(const parser::OpenMPAssumeConstruct &); + void Post(const parser::OpenMPAssumeConstruct &) { PopContext(); } + bool Pre(const parser::OpenMPAtomicConstruct &); void Post(const parser::OpenMPAtomicConstruct &) { PopContext(); } @@ -793,7 +826,8 @@ public: if (name->symbol) { name->symbol->set( ompFlag.value_or(Symbol::Flag::OmpMapStorage)); - AddToContextObjectWithDSA(*name->symbol, *ompFlag); + AddToContextObjectWithDSA(*name->symbol, + ompFlag.value_or(Symbol::Flag::OmpMapStorage)); if (semantics::IsAssumedSizeArray(*name->symbol)) { context_.Say(designator.source, "Assumed-size whole arrays may not appear on the %s " @@ -841,7 +875,8 @@ private: Symbol::Flags ompFlagsRequireMark{Symbol::Flag::OmpThreadprivate, Symbol::Flag::OmpDeclareTarget, Symbol::Flag::OmpExclusiveScan, - Symbol::Flag::OmpInclusiveScan, Symbol::Flag::OmpInScanReduction}; + Symbol::Flag::OmpInclusiveScan, Symbol::Flag::OmpInScanReduction, + Symbol::Flag::OmpGroupPrivate}; Symbol::Flags dataCopyingAttributeFlags{ 
Symbol::Flag::OmpCopyIn, Symbol::Flag::OmpCopyPrivate}; @@ -876,6 +911,9 @@ private: bool IsNestedInDirective(llvm::omp::Directive directive); void ResolveOmpObjectList(const parser::OmpObjectList &, Symbol::Flag); + void ResolveOmpDesignator( + const parser::Designator &designator, Symbol::Flag ompFlag); + void ResolveOmpCommonBlock(const parser::Name &name, Symbol::Flag ompFlag); void ResolveOmpObject(const parser::OmpObject &, Symbol::Flag); Symbol *ResolveOmp(const parser::Name &, Symbol::Flag, Scope &); Symbol *ResolveOmp(Symbol &, Symbol::Flag, Scope &); @@ -1562,10 +1600,10 @@ void AccAttributeVisitor::Post(const parser::AccDefaultClause &x) { // and adjust the symbol for each Name if necessary void AccAttributeVisitor::Post(const parser::Name &name) { auto *symbol{name.symbol}; - if (symbol && !dirContext_.empty() && GetContext().withinConstruct) { + if (symbol && WithinConstruct()) { symbol = &symbol->GetUltimate(); if (!symbol->owner().IsDerivedType() && !symbol->has<ProcEntityDetails>() && - !symbol->has<SubprogramDetails>() && !IsObjectWithDSA(*symbol)) { + !symbol->has<SubprogramDetails>() && !IsObjectWithVisibleDSA(*symbol)) { if (Symbol * found{currScope().FindSymbol(name.source)}) { if (symbol != found) { name.symbol = found; // adjust the symbol within region @@ -1715,7 +1753,7 @@ static std::string ScopeSourcePos(const Fortran::semantics::Scope &scope); #endif -bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { +bool OmpAttributeVisitor::Pre(const parser::OmpBlockConstruct &x) { const parser::OmpDirectiveSpecification &dirSpec{x.BeginDir()}; llvm::omp::Directive dirId{dirSpec.DirId()}; switch (dirId) { @@ -1732,10 +1770,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_teams_workdistribute: PushContext(dirSpec.source, dirId); break; default: @@ -1751,7 +1792,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { return true; } -void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { +void OmpAttributeVisitor::Post(const parser::OmpBlockConstruct &x) { const parser::OmpDirectiveSpecification &dirSpec{x.BeginDir()}; llvm::omp::Directive dirId{dirSpec.DirId()}; switch (dirId) { @@ -1765,9 +1806,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: - case llvm::omp::Directive::OMPD_target_parallel: { + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: + case llvm::omp::Directive::OMPD_teams_workdistribute: { bool hasPrivate; for (const auto *allocName : allocateNames_) { hasPrivate = false; @@ -1942,7 +1986,7 @@ void OmpAttributeVisitor::ResolveSeqLoopIndexInParallelOrTaskConstruct( // till OpenMP-5.0 standard. // In above both cases we skip the privatization of iteration variables. 
bool OmpAttributeVisitor::Pre(const parser::DoConstruct &x) { - if (!dirContext_.empty() && GetContext().withinConstruct) { + if (WithinConstruct()) { llvm::SmallVector<const parser::Name *> ivs; if (x.IsDoNormal()) { const parser::Name *iv{GetLoopIndex(x)}; @@ -2114,6 +2158,18 @@ void OmpAttributeVisitor::CheckAssocLoopLevel( } } +bool OmpAttributeVisitor::Pre(const parser::OpenMPGroupprivate &x) { + PushContext(x.source, llvm::omp::Directive::OMPD_groupprivate); + for (const parser::OmpArgument &arg : x.v.Arguments().v) { + if (auto *locator{std::get_if<parser::OmpLocator>(&arg.u)}) { + if (auto *object{std::get_if<parser::OmpObject>(&locator->u)}) { + ResolveOmpObject(*object, Symbol::Flag::OmpGroupPrivate); + } + } + } + return true; +} + bool OmpAttributeVisitor::Pre(const parser::OpenMPSectionsConstruct &x) { const auto &beginSectionsDir{ std::get<parser::OmpBeginSectionsDirective>(x.t)}; @@ -2139,8 +2195,8 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPSectionConstruct &x) { } bool OmpAttributeVisitor::Pre(const parser::OpenMPCriticalConstruct &x) { - const auto &beginCriticalDir{std::get<parser::OmpCriticalDirective>(x.t)}; - PushContext(beginCriticalDir.source, llvm::omp::Directive::OMPD_critical); + const parser::OmpBeginDirective &beginSpec{x.BeginDir()}; + PushContext(beginSpec.DirName().source, beginSpec.DirName().v); GetContext().withinConstruct = true; return true; } @@ -2194,6 +2250,11 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPDeclarativeAllocate &x) { return false; } +bool OmpAttributeVisitor::Pre(const parser::OpenMPAssumeConstruct &x) { + PushContext(x.source, llvm::omp::Directive::OMPD_assume); + return true; +} + bool OmpAttributeVisitor::Pre(const parser::OpenMPAtomicConstruct &x) { PushContext(x.source, llvm::omp::Directive::OMPD_atomic); return true; @@ -2435,7 +2496,7 @@ static bool IsTargetCaptureImplicitlyFirstprivatizeable(const Symbol &symbol, // investigate the flags we can intermix with. 
if (!(dsa & (dataSharingAttributeFlags | dataMappingAttributeFlags)) .none() || - !checkSym.flags().none() || semantics::IsAssumedShape(checkSym) || + !checkSym.flags().none() || IsAssumedShape(checkSym) || semantics::IsAllocatableOrPointer(checkSym)) { return false; } @@ -2651,7 +2712,7 @@ void OmpAttributeVisitor::CreateImplicitSymbols(const Symbol *symbol) { void OmpAttributeVisitor::Post(const parser::Name &name) { auto *symbol{name.symbol}; - if (symbol && !dirContext_.empty() && GetContext().withinConstruct) { + if (symbol && WithinConstruct()) { if (IsPrivatizable(symbol) && !IsObjectWithDSA(*symbol)) { // TODO: create a separate function to go through the rules for // predetermined, explicitly determined, and implicitly @@ -2786,196 +2847,182 @@ static bool SymbolOrEquivalentIsInNamelist(const Symbol &symbol) { }); } -void OmpAttributeVisitor::ResolveOmpObject( - const parser::OmpObject &ompObject, Symbol::Flag ompFlag) { +void OmpAttributeVisitor::ResolveOmpDesignator( + const parser::Designator &designator, Symbol::Flag ompFlag) { unsigned version{context_.langOptions().OpenMPVersion}; - common::visit( - common::visitors{ - [&](const parser::Designator &designator) { - if (const auto *name{ - semantics::getDesignatorNameIfDataRef(designator)}) { - if (auto *symbol{ResolveOmp(*name, ompFlag, currScope())}) { - auto checkExclusivelists = - [&](const Symbol *symbol1, Symbol::Flag firstOmpFlag, - const Symbol *symbol2, Symbol::Flag secondOmpFlag) { - if ((symbol1->test(firstOmpFlag) && - symbol2->test(secondOmpFlag)) || - (symbol1->test(secondOmpFlag) && - symbol2->test(firstOmpFlag))) { - context_.Say(designator.source, - "Variable '%s' may not " - "appear on both %s and %s " - "clauses on a %s construct"_err_en_US, - symbol2->name(), - Symbol::OmpFlagToClauseName(firstOmpFlag), - Symbol::OmpFlagToClauseName(secondOmpFlag), - parser::ToUpperCaseLetters( - llvm::omp::getOpenMPDirectiveName( - GetContext().directive, version) - .str())); - } - }; - if (dataCopyingAttributeFlags.test(ompFlag)) { - CheckDataCopyingClause(*name, *symbol, ompFlag); - } else { - AddToContextObjectWithExplicitDSA(*symbol, ompFlag); - if (dataSharingAttributeFlags.test(ompFlag)) { - CheckMultipleAppearances(*name, *symbol, ompFlag); - } - if (privateDataSharingAttributeFlags.test(ompFlag)) { - CheckObjectIsPrivatizable(*name, *symbol, ompFlag); - } + llvm::omp::Directive directive{GetContext().directive}; - if (ompFlag == Symbol::Flag::OmpAllocate) { - AddAllocateName(name); - } - } - if (ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective && - IsAllocatable(*symbol) && - !IsNestedInDirective(llvm::omp::Directive::OMPD_allocate)) { - context_.Say(designator.source, - "List items specified in the ALLOCATE directive must not " - "have the ALLOCATABLE attribute unless the directive is " - "associated with an ALLOCATE statement"_err_en_US); - } - if ((ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective || - ompFlag == - Symbol::Flag::OmpExecutableAllocateDirective) && - ResolveOmpObjectScope(name) == nullptr) { - context_.Say(designator.source, // 2.15.3 - "List items must be declared in the same scoping unit " - "in which the %s directive appears"_err_en_US, - parser::ToUpperCaseLetters( - llvm::omp::getOpenMPDirectiveName( - GetContext().directive, version) - .str())); - } - if (ompFlag == Symbol::Flag::OmpReduction) { - // Using variables inside of a namelist in OpenMP reductions - // is allowed by the standard, but is not allowed for - // privatisation. This looks like an oversight. 
If the - // namelist is hoisted to a global, we cannot apply the - // mapping for the reduction variable: resulting in incorrect - // results. Disabling this hoisting could make some real - // production code go slower. See discussion in #109303 - if (SymbolOrEquivalentIsInNamelist(*symbol)) { - context_.Say(name->source, - "Variable '%s' in NAMELIST cannot be in a REDUCTION clause"_err_en_US, - name->ToString()); - } - } - if (ompFlag == Symbol::Flag::OmpInclusiveScan || - ompFlag == Symbol::Flag::OmpExclusiveScan) { - if (!symbol->test(Symbol::Flag::OmpInScanReduction)) { - context_.Say(name->source, - "List item %s must appear in REDUCTION clause " - "with the INSCAN modifier of the parent " - "directive"_err_en_US, - name->ToString()); - } - } - if (ompFlag == Symbol::Flag::OmpDeclareTarget) { - if (symbol->IsFuncResult()) { - if (Symbol * func{currScope().symbol()}) { - CHECK(func->IsSubprogram()); - func->set(ompFlag); - name->symbol = func; - } - } - } - if (GetContext().directive == - llvm::omp::Directive::OMPD_target_data) { - checkExclusivelists(symbol, Symbol::Flag::OmpUseDevicePtr, - symbol, Symbol::Flag::OmpUseDeviceAddr); - } - if (llvm::omp::allDistributeSet.test(GetContext().directive)) { - checkExclusivelists(symbol, Symbol::Flag::OmpFirstPrivate, - symbol, Symbol::Flag::OmpLastPrivate); - } - if (llvm::omp::allTargetSet.test(GetContext().directive)) { - checkExclusivelists(symbol, Symbol::Flag::OmpIsDevicePtr, - symbol, Symbol::Flag::OmpHasDeviceAddr); - const auto *hostAssocSym{symbol}; - if (!(symbol->test(Symbol::Flag::OmpIsDevicePtr) || - symbol->test(Symbol::Flag::OmpHasDeviceAddr))) { - if (const auto *details{ - symbol->detailsIf<HostAssocDetails>()}) { - hostAssocSym = &details->symbol(); - } - } - Symbol::Flag dataMappingAttributeFlags[] = { - Symbol::Flag::OmpMapTo, Symbol::Flag::OmpMapFrom, - Symbol::Flag::OmpMapToFrom, Symbol::Flag::OmpMapStorage, - Symbol::Flag::OmpMapDelete, Symbol::Flag::OmpIsDevicePtr, - Symbol::Flag::OmpHasDeviceAddr}; - - Symbol::Flag dataSharingAttributeFlags[] = { - Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate, - Symbol::Flag::OmpLastPrivate, Symbol::Flag::OmpShared, - Symbol::Flag::OmpLinear}; - - // For OMP TARGET TEAMS directive some sharing attribute - // flags and mapping attribute flags can co-exist. 
- if (!(llvm::omp::allTeamsSet.test(GetContext().directive) || - llvm::omp::allParallelSet.test( - GetContext().directive))) { - for (Symbol::Flag ompFlag1 : dataMappingAttributeFlags) { - for (Symbol::Flag ompFlag2 : dataSharingAttributeFlags) { - if ((hostAssocSym->test(ompFlag2) && - hostAssocSym->test( - Symbol::Flag::OmpExplicit)) || - (symbol->test(ompFlag2) && - symbol->test(Symbol::Flag::OmpExplicit))) { - checkExclusivelists( - hostAssocSym, ompFlag1, symbol, ompFlag2); - } - } - } - } - } - } - } else { - // Array sections to be changed to substrings as needed - if (AnalyzeExpr(context_, designator)) { - if (std::holds_alternative<parser::Substring>(designator.u)) { - context_.Say(designator.source, - "Substrings are not allowed on OpenMP " - "directives or clauses"_err_en_US); - } - } - // other checks, more TBD - } - }, - [&](const parser::Name &name) { // common block - if (auto *symbol{ResolveOmpCommonBlockName(&name)}) { - if (!dataCopyingAttributeFlags.test(ompFlag)) { - CheckMultipleAppearances( - name, *symbol, Symbol::Flag::OmpCommonBlock); - } - // 2.15.3 When a named common block appears in a list, it has the - // same meaning as if every explicit member of the common block - // appeared in the list - auto &details{symbol->get<CommonBlockDetails>()}; - unsigned index{0}; - for (auto &object : details.objects()) { - if (auto *resolvedObject{ - ResolveOmp(*object, ompFlag, currScope())}) { - if (dataCopyingAttributeFlags.test(ompFlag)) { - CheckDataCopyingClause(name, *resolvedObject, ompFlag); - } else { - AddToContextObjectWithExplicitDSA(*resolvedObject, ompFlag); - } - details.replace_object(*resolvedObject, index); - } - index++; - } - } else { - context_.Say(name.source, // 2.15.3 - "COMMON block must be declared in the same scoping unit " - "in which the OpenMP directive or clause appears"_err_en_US); + const auto *name{semantics::getDesignatorNameIfDataRef(designator)}; + if (!name) { + // Array sections to be changed to substrings as needed + if (AnalyzeExpr(context_, designator)) { + if (std::holds_alternative<parser::Substring>(designator.u)) { + context_.Say(designator.source, + "Substrings are not allowed on OpenMP directives or clauses"_err_en_US); + } + } + // other checks, more TBD + return; + } + + if (auto *symbol{ResolveOmp(*name, ompFlag, currScope())}) { + auto checkExclusivelists{// + [&](const Symbol *symbol1, Symbol::Flag firstOmpFlag, + const Symbol *symbol2, Symbol::Flag secondOmpFlag) { + if ((symbol1->test(firstOmpFlag) && symbol2->test(secondOmpFlag)) || + (symbol1->test(secondOmpFlag) && symbol2->test(firstOmpFlag))) { + context_.Say(designator.source, + "Variable '%s' may not appear on both %s and %s clauses on a %s construct"_err_en_US, + symbol2->name(), Symbol::OmpFlagToClauseName(firstOmpFlag), + Symbol::OmpFlagToClauseName(secondOmpFlag), + parser::ToUpperCaseLetters( + llvm::omp::getOpenMPDirectiveName(directive, version))); + } + }}; + if (dataCopyingAttributeFlags.test(ompFlag)) { + CheckDataCopyingClause(*name, *symbol, ompFlag); + } else { + AddToContextObjectWithExplicitDSA(*symbol, ompFlag); + if (dataSharingAttributeFlags.test(ompFlag)) { + CheckMultipleAppearances(*name, *symbol, ompFlag); + } + if (privateDataSharingAttributeFlags.test(ompFlag)) { + CheckObjectIsPrivatizable(*name, *symbol, ompFlag); + } + + if (ompFlag == Symbol::Flag::OmpAllocate) { + AddAllocateName(name); + } + } + if (ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective && + IsAllocatable(*symbol) && + 
!IsNestedInDirective(llvm::omp::Directive::OMPD_allocate)) { + context_.Say(designator.source, + "List items specified in the ALLOCATE directive must not have the ALLOCATABLE attribute unless the directive is associated with an ALLOCATE statement"_err_en_US); + } + if ((ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective || + ompFlag == Symbol::Flag::OmpExecutableAllocateDirective) && + ResolveOmpObjectScope(name) == nullptr) { + context_.Say(designator.source, // 2.15.3 + "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US, + parser::ToUpperCaseLetters( + llvm::omp::getOpenMPDirectiveName(directive, version))); + } + if (ompFlag == Symbol::Flag::OmpReduction) { + // Using variables inside of a namelist in OpenMP reductions + // is allowed by the standard, but is not allowed for + // privatisation. This looks like an oversight. If the + // namelist is hoisted to a global, we cannot apply the + // mapping for the reduction variable: resulting in incorrect + // results. Disabling this hoisting could make some real + // production code go slower. See discussion in #109303 + if (SymbolOrEquivalentIsInNamelist(*symbol)) { + context_.Say(name->source, + "Variable '%s' in NAMELIST cannot be in a REDUCTION clause"_err_en_US, + name->ToString()); + } + } + if (ompFlag == Symbol::Flag::OmpInclusiveScan || + ompFlag == Symbol::Flag::OmpExclusiveScan) { + if (!symbol->test(Symbol::Flag::OmpInScanReduction)) { + context_.Say(name->source, + "List item %s must appear in REDUCTION clause with the INSCAN modifier of the parent directive"_err_en_US, + name->ToString()); + } + } + if (ompFlag == Symbol::Flag::OmpDeclareTarget) { + if (symbol->IsFuncResult()) { + if (Symbol * func{currScope().symbol()}) { + CHECK(func->IsSubprogram()); + func->set(ompFlag); + name->symbol = func; + } + } + } + if (directive == llvm::omp::Directive::OMPD_target_data) { + checkExclusivelists(symbol, Symbol::Flag::OmpUseDevicePtr, symbol, + Symbol::Flag::OmpUseDeviceAddr); + } + if (llvm::omp::allDistributeSet.test(directive)) { + checkExclusivelists(symbol, Symbol::Flag::OmpFirstPrivate, symbol, + Symbol::Flag::OmpLastPrivate); + } + if (llvm::omp::allTargetSet.test(directive)) { + checkExclusivelists(symbol, Symbol::Flag::OmpIsDevicePtr, symbol, + Symbol::Flag::OmpHasDeviceAddr); + const auto *hostAssocSym{symbol}; + if (!symbol->test(Symbol::Flag::OmpIsDevicePtr) && + !symbol->test(Symbol::Flag::OmpHasDeviceAddr)) { + if (const auto *details{symbol->detailsIf<HostAssocDetails>()}) { + hostAssocSym = &details->symbol(); + } + } + static Symbol::Flag dataMappingAttributeFlags[] = {// + Symbol::Flag::OmpMapTo, Symbol::Flag::OmpMapFrom, + Symbol::Flag::OmpMapToFrom, Symbol::Flag::OmpMapStorage, + Symbol::Flag::OmpMapDelete, Symbol::Flag::OmpIsDevicePtr, + Symbol::Flag::OmpHasDeviceAddr}; + + static Symbol::Flag dataSharingAttributeFlags[] = {// + Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate, + Symbol::Flag::OmpLastPrivate, Symbol::Flag::OmpShared, + Symbol::Flag::OmpLinear}; + + // For OMP TARGET TEAMS directive some sharing attribute + // flags and mapping attribute flags can co-exist. 
+ if (!llvm::omp::allTeamsSet.test(directive) && + !llvm::omp::allParallelSet.test(directive)) { + for (Symbol::Flag ompFlag1 : dataMappingAttributeFlags) { + for (Symbol::Flag ompFlag2 : dataSharingAttributeFlags) { + if ((hostAssocSym->test(ompFlag2) && + hostAssocSym->test(Symbol::Flag::OmpExplicit)) || + (symbol->test(ompFlag2) && + symbol->test(Symbol::Flag::OmpExplicit))) { + checkExclusivelists(hostAssocSym, ompFlag1, symbol, ompFlag2); } - }, - }, + } + } + } + } + } +} + +void OmpAttributeVisitor::ResolveOmpCommonBlock( + const parser::Name &name, Symbol::Flag ompFlag) { + if (auto *symbol{ResolveOmpCommonBlockName(&name)}) { + if (!dataCopyingAttributeFlags.test(ompFlag)) { + CheckMultipleAppearances(name, *symbol, Symbol::Flag::OmpCommonBlock); + } + // 2.15.3 When a named common block appears in a list, it has the + // same meaning as if every explicit member of the common block + // appeared in the list + auto &details{symbol->get<CommonBlockDetails>()}; + for (auto [index, object] : llvm::enumerate(details.objects())) { + if (auto *resolvedObject{ResolveOmp(*object, ompFlag, currScope())}) { + if (dataCopyingAttributeFlags.test(ompFlag)) { + CheckDataCopyingClause(name, *resolvedObject, ompFlag); + } else { + AddToContextObjectWithExplicitDSA(*resolvedObject, ompFlag); + } + details.replace_object(*resolvedObject, index); + } + } + } else { + context_.Say(name.source, // 2.15.3 + "COMMON block must be declared in the same scoping unit in which the OpenMP directive or clause appears"_err_en_US); + } +} + +void OmpAttributeVisitor::ResolveOmpObject( + const parser::OmpObject &ompObject, Symbol::Flag ompFlag) { + common::visit(common::visitors{ + [&](const parser::Designator &designator) { + ResolveOmpDesignator(designator, ompFlag); + }, + [&](const parser::Name &name) { // common block + ResolveOmpCommonBlock(name, ompFlag); + }, + }, ompObject.u); } diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 66a45dd..4720932 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -30,6 +30,7 @@ #include "flang/Semantics/attr.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/openmp-modifiers.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/program-tree.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" @@ -487,6 +488,10 @@ public: // Result symbol Symbol *resultSymbol{nullptr}; bool inFunctionStmt{false}; // true between Pre/Post of FunctionStmt + // Functions with previous implicitly-typed references get those types + // checked against their later definitions. + const DeclTypeSpec *previousImplicitType{nullptr}; + SourceName previousName; }; // Completes the definition of the top function's result. @@ -942,7 +947,7 @@ private: // Edits an existing symbol created for earlier calls to a subprogram or ENTRY // so that it can be replaced by a later definition. bool HandlePreviousCalls(const parser::Name &, Symbol &, Symbol::Flag); - void CheckExtantProc(const parser::Name &, Symbol::Flag); + const Symbol *CheckExtantProc(const parser::Name &, Symbol::Flag); // Create a subprogram symbol in the current scope and push a new scope. 
Symbol &PushSubprogramScope(const parser::Name &, Symbol::Flag, const parser::LanguageBindingSpec * = nullptr, @@ -1465,7 +1470,7 @@ class OmpVisitor : public virtual DeclarationVisitor { public: void AddOmpSourceRange(const parser::CharBlock &); - static bool NeedsScope(const parser::OpenMPBlockConstruct &); + static bool NeedsScope(const parser::OmpBlockConstruct &); static bool NeedsScope(const parser::OmpClause &); bool Pre(const parser::OmpMetadirectiveDirective &x) { // @@ -1482,10 +1487,20 @@ public: AddOmpSourceRange(x.source); return true; } - bool Pre(const parser::OpenMPBlockConstruct &); - void Post(const parser::OpenMPBlockConstruct &); + bool Pre(const parser::OmpBlockConstruct &); + void Post(const parser::OmpBlockConstruct &); bool Pre(const parser::OmpBeginDirective &x) { AddOmpSourceRange(x.source); + // Manually resolve names in CRITICAL directives. This is because these + // names do not denote Fortran objects, and the CRITICAL directive causes + // them to be "auto-declared", i.e. inserted into the global scope. + // More specifically, they are not expected to have explicit declarations, + // and if they do the behavior is unspeficied. + if (x.DirName().v == llvm::omp::Directive::OMPD_critical) { + for (const parser::OmpArgument &arg : x.Arguments().v) { + ResolveCriticalName(arg); + } + } return true; } void Post(const parser::OmpBeginDirective &) { @@ -1493,6 +1508,12 @@ public: } bool Pre(const parser::OmpEndDirective &x) { AddOmpSourceRange(x.source); + // Manually resolve names in CRITICAL directives. + if (x.DirName().v == llvm::omp::Directive::OMPD_critical) { + for (const parser::OmpArgument &arg : x.Arguments().v) { + ResolveCriticalName(arg); + } + } return true; } void Post(const parser::OmpEndDirective &) { @@ -1591,32 +1612,6 @@ public: void Post(const parser::OmpEndSectionsDirective &) { messageHandler().set_currStmtSource(std::nullopt); } - bool Pre(const parser::OmpCriticalDirective &x) { - AddOmpSourceRange(x.source); - // Manually resolve names in CRITICAL directives. This is because these - // names do not denote Fortran objects, and the CRITICAL directive causes - // them to be "auto-declared", i.e. inserted into the global scope. - // More specifically, they are not expected to have explicit declarations, - // and if they do the behavior is unspeficied. - if (auto &maybeName{std::get<std::optional<parser::Name>>(x.t)}) { - ResolveCriticalName(*maybeName); - } - return true; - } - void Post(const parser::OmpCriticalDirective &) { - messageHandler().set_currStmtSource(std::nullopt); - } - bool Pre(const parser::OmpEndCriticalDirective &x) { - AddOmpSourceRange(x.source); - // Manually resolve names in CRITICAL directives. 
- if (auto &maybeName{std::get<std::optional<parser::Name>>(x.t)}) { - ResolveCriticalName(*maybeName); - } - return true; - } - void Post(const parser::OmpEndCriticalDirective &) { - messageHandler().set_currStmtSource(std::nullopt); - } bool Pre(const parser::OpenMPThreadprivate &) { SkipImplicitTyping(true); return true; @@ -1732,13 +1727,13 @@ private: const std::optional<parser::OmpClauseList> &clauses, const T &wholeConstruct); - void ResolveCriticalName(const parser::Name &name); + void ResolveCriticalName(const parser::OmpArgument &arg); int metaLevel_{0}; const parser::OmpMetadirectiveDirective *metaDirective_{nullptr}; }; -bool OmpVisitor::NeedsScope(const parser::OpenMPBlockConstruct &x) { +bool OmpVisitor::NeedsScope(const parser::OmpBlockConstruct &x) { switch (x.BeginDir().DirId()) { case llvm::omp::Directive::OMPD_master: case llvm::omp::Directive::OMPD_ordered: @@ -1759,14 +1754,14 @@ void OmpVisitor::AddOmpSourceRange(const parser::CharBlock &source) { currScope().AddSourceRange(source); } -bool OmpVisitor::Pre(const parser::OpenMPBlockConstruct &x) { +bool OmpVisitor::Pre(const parser::OmpBlockConstruct &x) { if (NeedsScope(x)) { PushScope(Scope::Kind::OtherConstruct, nullptr); } return true; } -void OmpVisitor::Post(const parser::OpenMPBlockConstruct &x) { +void OmpVisitor::Post(const parser::OmpBlockConstruct &x) { if (NeedsScope(x)) { PopScope(); } @@ -1961,7 +1956,7 @@ void OmpVisitor::ProcessReductionSpecifier( } } -void OmpVisitor::ResolveCriticalName(const parser::Name &name) { +void OmpVisitor::ResolveCriticalName(const parser::OmpArgument &arg) { auto &globalScope{[&]() -> Scope & { for (Scope *s{&currScope()};; s = &s->parent()) { if (s->IsTopLevel()) { @@ -1971,15 +1966,21 @@ void OmpVisitor::ResolveCriticalName(const parser::Name &name) { llvm_unreachable("Cannot find global scope"); }()}; - if (auto *symbol{FindInScope(globalScope, name)}) { - if (!symbol->test(Symbol::Flag::OmpCriticalLock)) { - SayWithDecl(name, *symbol, - "CRITICAL construct name '%s' conflicts with a previous declaration"_warn_en_US, - name.ToString()); + if (auto *object{parser::Unwrap<parser::OmpObject>(arg.u)}) { + if (auto *desg{omp::GetDesignatorFromObj(*object)}) { + if (auto *name{getDesignatorNameIfDataRef(*desg)}) { + if (auto *symbol{FindInScope(globalScope, *name)}) { + if (!symbol->test(Symbol::Flag::OmpCriticalLock)) { + SayWithDecl(*name, *symbol, + "CRITICAL construct name '%s' conflicts with a previous declaration"_warn_en_US, + name->ToString()); + } + } else { + name->symbol = &MakeSymbol(globalScope, name->source, Attrs{}); + name->symbol->set(Symbol::Flag::OmpCriticalLock); + } + } } - } else { - name.symbol = &MakeSymbol(globalScope, name.source, Attrs{}); - name.symbol->set(Symbol::Flag::OmpCriticalLock); } } @@ -2694,11 +2695,24 @@ void ArraySpecVisitor::PostAttrSpec() { FuncResultStack::~FuncResultStack() { CHECK(stack_.empty()); } +// True when either type is absent, or if they are both present and are +// equivalent for interface compatibility purposes. +static bool TypesMismatchIfNonNull( + const DeclTypeSpec *type1, const DeclTypeSpec *type2) { + if (auto t1{evaluate::DynamicType::From(type1)}) { + if (auto t2{evaluate::DynamicType::From(type2)}) { + return !t1->IsEquivalentTo(*t2); + } + } + return false; +} + void FuncResultStack::CompleteFunctionResultType() { // If the function has a type in the prefix, process it now. 
FuncInfo *info{Top()}; - if (info && &info->scope == &scopeHandler_.currScope()) { - if (info->parsedType && info->resultSymbol) { + if (info && &info->scope == &scopeHandler_.currScope() && + info->resultSymbol) { + if (info->parsedType) { scopeHandler_.messageHandler().set_currStmtSource(info->source); if (const auto *type{ scopeHandler_.ProcessTypeSpec(*info->parsedType, true)}) { @@ -2715,6 +2729,16 @@ void FuncResultStack::CompleteFunctionResultType() { } info->parsedType = nullptr; } + if (TypesMismatchIfNonNull( + info->resultSymbol->GetType(), info->previousImplicitType)) { + scopeHandler_ + .Say(info->resultSymbol->name(), + "Function '%s' has a result type that differs from the implicit type it obtained in a previous reference"_err_en_US, + info->previousName) + .Attach(info->previousName, + "Previous reference implicitly typed as %s\n"_en_US, + info->previousImplicitType->AsFortran()); + } } } @@ -4764,9 +4788,7 @@ void SubprogramVisitor::Post(const parser::FunctionStmt &stmt) { if (info.resultName && !distinctResultName) { context().Warn(common::UsageWarning::HomonymousResult, info.resultName->source, - "The function name should not appear in RESULT; references to '%s' " - "inside the function will be considered as references to the " - "result only"_warn_en_US, + "The function name should not appear in RESULT; references to '%s' inside the function will be considered as references to the result only"_warn_en_US, name.source); // RESULT name was ignored above, the only side effect from doing so will be // the inability to make recursive calls. The related parser::Name is still @@ -5077,8 +5099,7 @@ bool SubprogramVisitor::BeginSubprogram(const parser::Name &name, if (hasModulePrefix && !currScope().IsModule() && !currScope().IsSubmodule()) { // C1547 Say(name, - "'%s' is a MODULE procedure which must be declared within a " - "MODULE or SUBMODULE"_err_en_US); + "'%s' is a MODULE procedure which must be declared within a MODULE or SUBMODULE"_err_en_US); // Don't return here because it can be useful to have the scope set for // other semantic checks run before we print the errors isValid = false; @@ -5199,9 +5220,10 @@ bool SubprogramVisitor::HandlePreviousCalls( } } -void SubprogramVisitor::CheckExtantProc( +const Symbol *SubprogramVisitor::CheckExtantProc( const parser::Name &name, Symbol::Flag subpFlag) { - if (auto *prev{FindSymbol(name)}) { + Symbol *prev{FindSymbol(name)}; + if (prev) { if (IsDummy(*prev)) { } else if (auto *entity{prev->detailsIf<EntityDetails>()}; IsPointer(*prev) && entity && !entity->type()) { @@ -5213,12 +5235,15 @@ void SubprogramVisitor::CheckExtantProc( SayAlreadyDeclared(name, *prev); } } + return prev; } Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name, Symbol::Flag subpFlag, const parser::LanguageBindingSpec *bindingSpec, bool hasModulePrefix) { Symbol *symbol{GetSpecificFromGeneric(name)}; + const DeclTypeSpec *previousImplicitType{nullptr}; + SourceName previousName; if (!symbol) { if (bindingSpec && currScope().IsGlobal() && std::get<std::optional<parser::ScalarDefaultCharConstantExpr>>( @@ -5231,14 +5256,25 @@ Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name, &MakeSymbol(context().GetTempName(currScope()), Attrs{}, MiscDetails{MiscDetails::Kind::ScopeName})); } - CheckExtantProc(name, subpFlag); + if (const Symbol *previous{CheckExtantProc(name, subpFlag)}) { + if (previous->test(Symbol::Flag::Function) && + previous->test(Symbol::Flag::Implicit)) { + // Function was implicitly typed in previous 
compilation unit. + previousImplicitType = previous->GetType(); + previousName = previous->name(); + } + } symbol = &MakeSymbol(name, SubprogramDetails{}); } symbol->ReplaceName(name.source); symbol->set(subpFlag); PushScope(Scope::Kind::Subprogram, symbol); if (subpFlag == Symbol::Flag::Function) { - funcResultStack().Push(currScope(), name.source); + auto &funcResultTop{funcResultStack().Push(currScope(), name.source)}; + funcResultTop.previousImplicitType = previousImplicitType; + ; + funcResultTop.previousName = previousName; + ; } if (inInterfaceBlock()) { auto &details{symbol->get<SubprogramDetails>()}; @@ -7916,7 +7952,7 @@ void ConstructVisitor::Post(const parser::AssociateStmt &x) { if (ExtractCoarrayRef(expr)) { // C1103 Say("Selector must not be a coindexed object"_err_en_US); } - if (evaluate::IsAssumedRank(expr)) { + if (IsAssumedRank(expr)) { Say("Selector must not be assumed-rank"_err_en_US); } SetTypeFromAssociation(*symbol); @@ -8672,11 +8708,6 @@ const parser::Name *DeclarationVisitor::ResolveDataRef( x.u); } -static bool TypesMismatchIfNonNull( - const DeclTypeSpec *type1, const DeclTypeSpec *type2) { - return type1 && type2 && *type1 != *type2; -} - // If implicit types are allowed, ensure name is in the symbol table. // Otherwise, report an error if it hasn't been declared. const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) { diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp index 4eeb1b9..eae22dc 100644 --- a/flang/lib/Semantics/rewrite-parse-tree.cpp +++ b/flang/lib/Semantics/rewrite-parse-tree.cpp @@ -12,6 +12,7 @@ #include "flang/Parser/parse-tree-visitor.h" #include "flang/Parser/parse-tree.h" #include "flang/Parser/tools.h" +#include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" @@ -41,11 +42,23 @@ public: void Post(parser::Name &); bool Pre(parser::MainProgram &); + bool Pre(parser::Module &); bool Pre(parser::FunctionSubprogram &); bool Pre(parser::SubroutineSubprogram &); bool Pre(parser::SeparateModuleSubprogram &); bool Pre(parser::BlockConstruct &); + bool Pre(parser::Block &); + bool Pre(parser::DoConstruct &); + bool Pre(parser::IfConstruct &); bool Pre(parser::ActionStmt &); + void Post(parser::MainProgram &); + void Post(parser::FunctionSubprogram &); + void Post(parser::SubroutineSubprogram &); + void Post(parser::SeparateModuleSubprogram &); + void Post(parser::BlockConstruct &); + void Post(parser::Block &); + void Post(parser::DoConstruct &); + void Post(parser::IfConstruct &); void Post(parser::ReadStmt &); void Post(parser::WriteStmt &); @@ -67,8 +80,15 @@ public: bool Pre(parser::EndSubroutineStmt &) { return false; } bool Pre(parser::EndTypeStmt &) { return false; } + bool Pre(parser::OmpBlockConstruct &); + bool Pre(parser::OpenMPLoopConstruct &); + void Post(parser::OmpBlockConstruct &); + void Post(parser::OpenMPLoopConstruct &); + private: void FixMisparsedStmtFuncs(parser::SpecificationPart &, parser::Block &); + void OpenMPSimdOnly(parser::Block &, bool); + void OpenMPSimdOnly(parser::SpecificationPart &); SemanticsContext &context_; bool errorOnUnresolvedName_{true}; @@ -96,6 +116,132 @@ static bool ReturnsDataPointer(const Symbol &symbol) { return false; } +static bool LoopConstructIsSIMD(parser::OpenMPLoopConstruct *ompLoop) { + auto &begin = std::get<parser::OmpBeginLoopDirective>(ompLoop->t); + auto directive = 
std::get<parser::OmpLoopDirective>(begin.t).v; + return llvm::omp::allSimdSet.test(directive); +} + +// Remove non-SIMD OpenMPConstructs once they are parsed. +// This massively simplifies the logic inside the SimdOnlyPass for +// -fopenmp-simd. +void RewriteMutator::OpenMPSimdOnly(parser::SpecificationPart &specPart) { + auto &list{std::get<std::list<parser::DeclarationConstruct>>(specPart.t)}; + for (auto it{list.begin()}; it != list.end();) { + if (auto *specConstr{std::get_if<parser::SpecificationConstruct>(&it->u)}) { + if (auto *ompDecl{std::get_if< + common::Indirection<parser::OpenMPDeclarativeConstruct>>( + &specConstr->u)}) { + if (std::holds_alternative<parser::OpenMPThreadprivate>( + ompDecl->value().u) || + std::holds_alternative<parser::OpenMPDeclareMapperConstruct>( + ompDecl->value().u)) { + it = list.erase(it); + continue; + } + } + } + ++it; + } +} + +// Remove non-SIMD OpenMPConstructs once they are parsed. +// This massively simplifies the logic inside the SimdOnlyPass for +// -fopenmp-simd. `isNonSimdLoopBody` should be set to true if `block` is the +// body of a non-simd OpenMP loop. This is to indicate that scan constructs +// should be removed from the body, where they would be kept if it were a simd +// loop. +void RewriteMutator::OpenMPSimdOnly( + parser::Block &block, bool isNonSimdLoopBody = false) { + auto replaceInlineBlock = + [&](std::list<parser::ExecutionPartConstruct> &innerBlock, + auto it) -> auto { + auto insertPos = std::next(it); + block.splice(insertPos, innerBlock); + block.erase(it); + return insertPos; + }; + + for (auto it{block.begin()}; it != block.end();) { + if (auto *stmt{std::get_if<parser::ExecutableConstruct>(&it->u)}) { + if (auto *omp{std::get_if<common::Indirection<parser::OpenMPConstruct>>( + &stmt->u)}) { + if (auto *ompStandalone{std::get_if<parser::OpenMPStandaloneConstruct>( + &omp->value().u)}) { + if (std::holds_alternative<parser::OpenMPCancelConstruct>( + ompStandalone->u) || + std::holds_alternative<parser::OpenMPFlushConstruct>( + ompStandalone->u) || + std::holds_alternative<parser::OpenMPCancellationPointConstruct>( + ompStandalone->u)) { + it = block.erase(it); + continue; + } + if (auto *constr{std::get_if<parser::OpenMPSimpleStandaloneConstruct>( + &ompStandalone->u)}) { + auto directive = constr->v.DirId(); + // Scan should only be removed from non-simd loops + if (llvm::omp::simpleStandaloneNonSimdOnlySet.test(directive) || + (isNonSimdLoopBody && directive == llvm::omp::OMPD_scan)) { + it = block.erase(it); + continue; + } + } + } else if (auto *ompBlock{std::get_if<parser::OmpBlockConstruct>( + &omp->value().u)}) { + it = replaceInlineBlock(std::get<parser::Block>(ompBlock->t), it); + continue; + } else if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>( + &omp->value().u)}) { + if (LoopConstructIsSIMD(ompLoop)) { + ++it; + continue; + } + auto &nest = + std::get<std::optional<parser::NestedConstruct>>(ompLoop->t); + + if (auto *doConstruct = + std::get_if<parser::DoConstruct>(&nest.value())) { + auto &loopBody = std::get<parser::Block>(doConstruct->t); + // We can only remove some constructs from a loop when it's _not_ a + // OpenMP simd loop + OpenMPSimdOnly(loopBody, /*isNonSimdLoopBody=*/true); + auto newDoConstruct = std::move(*doConstruct); + auto newLoop = parser::ExecutionPartConstruct{ + parser::ExecutableConstruct{std::move(newDoConstruct)}}; + it = block.erase(it); + block.insert(it, std::move(newLoop)); + continue; + } + } else if (auto *ompCon{std::get_if<parser::OpenMPSectionsConstruct>( + 
&omp->value().u)}) { + auto §ions = + std::get<std::list<parser::OpenMPConstruct>>(ompCon->t); + auto insertPos = std::next(it); + for (auto §ionCon : sections) { + auto §ion = + std::get<parser::OpenMPSectionConstruct>(sectionCon.u); + auto &innerBlock = std::get<parser::Block>(section.t); + block.splice(insertPos, innerBlock); + } + block.erase(it); + it = insertPos; + continue; + } else if (auto *atomic{std::get_if<parser::OpenMPAtomicConstruct>( + &omp->value().u)}) { + it = replaceInlineBlock(std::get<parser::Block>(atomic->t), it); + continue; + } else if (auto *critical{std::get_if<parser::OpenMPCriticalConstruct>( + &omp->value().u)}) { + it = replaceInlineBlock(std::get<parser::Block>(critical->t), it); + continue; + } + } + } + ++it; + } +} + // Finds misparsed statement functions in a specification part, rewrites // them into array element assignment statements, and moves them into the // beginning of the corresponding (execution part's) block. @@ -133,33 +279,155 @@ void RewriteMutator::FixMisparsedStmtFuncs( bool RewriteMutator::Pre(parser::MainProgram &program) { FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(program.t), std::get<parser::ExecutionPart>(program.t).v); + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(program.t).v); + OpenMPSimdOnly(std::get<parser::SpecificationPart>(program.t)); + } + return true; +} + +void RewriteMutator::Post(parser::MainProgram &program) { + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(program.t).v); + } +} + +bool RewriteMutator::Pre(parser::Module &module) { + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::SpecificationPart>(module.t)); + } return true; } bool RewriteMutator::Pre(parser::FunctionSubprogram &func) { FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(func.t), std::get<parser::ExecutionPart>(func.t).v); + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(func.t).v); + } return true; } +void RewriteMutator::Post(parser::FunctionSubprogram &func) { + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(func.t).v); + } +} + bool RewriteMutator::Pre(parser::SubroutineSubprogram &subr) { FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(subr.t), std::get<parser::ExecutionPart>(subr.t).v); + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(subr.t).v); + } return true; } +void RewriteMutator::Post(parser::SubroutineSubprogram &subr) { + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(subr.t).v); + } +} + bool RewriteMutator::Pre(parser::SeparateModuleSubprogram &subp) { FixMisparsedStmtFuncs(std::get<parser::SpecificationPart>(subp.t), std::get<parser::ExecutionPart>(subp.t).v); + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(subp.t).v); + } return true; } +void RewriteMutator::Post(parser::SeparateModuleSubprogram &subp) { + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::ExecutionPart>(subp.t).v); + } +} + bool RewriteMutator::Pre(parser::BlockConstruct &block) { FixMisparsedStmtFuncs(std::get<parser::BlockSpecificationPart>(block.t).v, std::get<parser::Block>(block.t)); + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::Block>(block.t)); + } + return true; +} + +void RewriteMutator::Post(parser::BlockConstruct &block) { + if 
(context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(std::get<parser::Block>(block.t)); + } +} + +bool RewriteMutator::Pre(parser::Block &block) { + if (context_.langOptions().OpenMPSimd) { + OpenMPSimdOnly(block); + } return true; } +void RewriteMutator::Post(parser::Block &block) { this->Pre(block); } + +bool RewriteMutator::Pre(parser::OmpBlockConstruct &block) { + if (context_.langOptions().OpenMPSimd) { + auto &innerBlock = std::get<parser::Block>(block.t); + OpenMPSimdOnly(innerBlock); + } + return true; +} + +void RewriteMutator::Post(parser::OmpBlockConstruct &block) { + this->Pre(block); +} + +bool RewriteMutator::Pre(parser::OpenMPLoopConstruct &ompLoop) { + if (context_.langOptions().OpenMPSimd) { + if (LoopConstructIsSIMD(&ompLoop)) { + return true; + } + // If we're looking at a non-simd OpenMP loop, we need to explicitly + // call OpenMPSimdOnly on the nested loop block while indicating where + // the block comes from. + auto &nest = std::get<std::optional<parser::NestedConstruct>>(ompLoop.t); + if (!nest.has_value()) { + return true; + } + if (auto *doConstruct = std::get_if<parser::DoConstruct>(&*nest)) { + auto &innerBlock = std::get<parser::Block>(doConstruct->t); + OpenMPSimdOnly(innerBlock, /*isNonSimdLoopBody=*/true); + } + } + return true; +} + +void RewriteMutator::Post(parser::OpenMPLoopConstruct &ompLoop) { + this->Pre(ompLoop); +} + +bool RewriteMutator::Pre(parser::DoConstruct &doConstruct) { + if (context_.langOptions().OpenMPSimd) { + auto &innerBlock = std::get<parser::Block>(doConstruct.t); + OpenMPSimdOnly(innerBlock); + } + return true; +} + +void RewriteMutator::Post(parser::DoConstruct &doConstruct) { + this->Pre(doConstruct); +} + +bool RewriteMutator::Pre(parser::IfConstruct &ifConstruct) { + if (context_.langOptions().OpenMPSimd) { + auto &innerBlock = std::get<parser::Block>(ifConstruct.t); + OpenMPSimdOnly(innerBlock); + } + return true; +} + +void RewriteMutator::Post(parser::IfConstruct &ifConstruct) { + this->Pre(ifConstruct); +} + // Rewrite PRINT NML -> WRITE(*,NML=NML) bool RewriteMutator::Pre(parser::ActionStmt &x) { if (auto *print{std::get_if<common::Indirection<parser::PrintStmt>>(&x.u)}; diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp index 2259cfc..a6b402c 100644 --- a/flang/lib/Semantics/symbol.cpp +++ b/flang/lib/Semantics/symbol.cpp @@ -611,7 +611,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) { sep = ','; } }, - [](const HostAssocDetails &) {}, + [&os](const HostAssocDetails &x) { os << " => " << x.symbol(); }, [&](const ProcBindingDetails &x) { os << " => " << x.symbol().name(); DumpOptional(os, "passName", x.passName()); diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index 913bf08..28829d3 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -705,7 +705,7 @@ SymbolVector FinalsForDerivedTypeInstantiation(const DerivedTypeSpec &spec) { const Symbol *IsFinalizable(const Symbol &symbol, std::set<const DerivedTypeSpec *> *inProgress, bool withImpureFinalizer) { - if (IsPointer(symbol) || evaluate::IsAssumedRank(symbol)) { + if (IsPointer(symbol) || IsAssumedRank(symbol)) { return nullptr; } if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()}) { @@ -741,7 +741,7 @@ const Symbol *IsFinalizable(const DerivedTypeSpec &derived, if (const SubprogramDetails * subp{symbol->detailsIf<SubprogramDetails>()}) { if (const auto &args{subp->dummyArgs()}; !args.empty() && - args.at(0) && 
!evaluate::IsAssumedRank(*args.at(0)) && + args.at(0) && !IsAssumedRank(*args.at(0)) && args.at(0)->Rank() != *rank) { continue; // not a finalizer for this rank } @@ -790,7 +790,7 @@ const Symbol *HasImpureFinal(const Symbol &original, std::optional<int> rank) { if (symbol.has<ObjectEntityDetails>()) { if (const DeclTypeSpec * symType{symbol.GetType()}) { if (const DerivedTypeSpec * derived{symType->AsDerived()}) { - if (evaluate::IsAssumedRank(symbol)) { + if (IsAssumedRank(symbol)) { // finalizable assumed-rank not allowed (C839) return nullptr; } else { @@ -1170,7 +1170,7 @@ bool IsAccessible(const Symbol &original, const Scope &scope) { } std::optional<parser::MessageFormattedText> CheckAccessibleSymbol( - const Scope &scope, const Symbol &symbol) { + const Scope &scope, const Symbol &symbol, bool inStructureConstructor) { if (IsAccessible(symbol, scope)) { return std::nullopt; } else if (FindModuleFileContaining(scope)) { @@ -1179,10 +1179,20 @@ std::optional<parser::MessageFormattedText> CheckAccessibleSymbol( // whose structure constructors reference private components. return std::nullopt; } else { + const Scope &module{DEREF(FindModuleContaining(symbol.owner()))}; + // Subtlety: Sometimes we want to be able to convert a generated + // module file back into Fortran, perhaps to convert it into a + // hermetic module file. Don't emit a fatal error for things like + // "__builtin_c_ptr(__address=0)" that came from expansions of + // "cptr_null()"; specifically, just warn about structure constructor + // component names from intrinsic modules when in a module. + parser::MessageFixedText text{FindModuleContaining(scope) && + module.parent().IsIntrinsicModules() && + inStructureConstructor && symbol.owner().IsDerivedType() + ? "PRIVATE name '%s' is accessible only within module '%s'"_warn_en_US + : "PRIVATE name '%s' is accessible only within module '%s'"_err_en_US}; return parser::MessageFormattedText{ - "PRIVATE name '%s' is accessible only within module '%s'"_err_en_US, - symbol.name(), - DEREF(FindModuleContaining(symbol.owner())).GetName().value()}; + std::move(text), symbol.name(), module.GetName().value()}; } } diff --git a/flang/lib/Semantics/unparse-with-symbols.cpp b/flang/lib/Semantics/unparse-with-symbols.cpp index 3093e39..b199481 100644 --- a/flang/lib/Semantics/unparse-with-symbols.cpp +++ b/flang/lib/Semantics/unparse-with-symbols.cpp @@ -47,6 +47,11 @@ public: return true; } void Post(const parser::OmpClause &) { currStmt_ = std::nullopt; } + bool Pre(const parser::OpenMPGroupprivate &dir) { + currStmt_ = dir.source; + return true; + } + void Post(const parser::OpenMPGroupprivate &) { currStmt_ = std::nullopt; } bool Pre(const parser::OpenMPThreadprivate &dir) { currStmt_ = dir.source; return true; @@ -70,20 +75,6 @@ public: currStmt_ = std::nullopt; } - bool Pre(const parser::OmpCriticalDirective &x) { - currStmt_ = x.source; - return true; - } - void Post(const parser::OmpCriticalDirective &) { currStmt_ = std::nullopt; } - - bool Pre(const parser::OmpEndCriticalDirective &x) { - currStmt_ = x.source; - return true; - } - void Post(const parser::OmpEndCriticalDirective &) { - currStmt_ = std::nullopt; - } - // Directive arguments can be objects with symbols. 
bool Pre(const parser::OmpBeginDirective &x) { currStmt_ = x.source; diff --git a/flang/lib/Support/Fortran-features.cpp b/flang/lib/Support/Fortran-features.cpp index df51b3c..4a6fb8d 100644 --- a/flang/lib/Support/Fortran-features.cpp +++ b/flang/lib/Support/Fortran-features.cpp @@ -90,6 +90,7 @@ LanguageFeatureControl::LanguageFeatureControl() { disable_.set(LanguageFeature::OldStyleParameter); // Possibly an accidental "feature" of nvfortran. disable_.set(LanguageFeature::AssumedRankPassedToNonAssumedRank); + disable_.set(LanguageFeature::Coarray); // These warnings are enabled by default, but only because they used // to be unconditional. TODO: prune this list warnLanguage_.set(LanguageFeature::ExponentMatchingKindParam); @@ -147,6 +148,7 @@ LanguageFeatureControl::LanguageFeatureControl() { warnUsage_.set(UsageWarning::UseAssociationIntoSameNameSubprogram); warnUsage_.set(UsageWarning::HostAssociatedIntentOutInSpecExpr); warnUsage_.set(UsageWarning::NonVolatilePointerToVolatile); + warnUsage_.set(UsageWarning::RealConstantWidening); // New warnings, on by default warnLanguage_.set(LanguageFeature::SavedLocalInSpecExpr); warnLanguage_.set(LanguageFeature::NullActualForAllocatable); diff --git a/flang/lib/Support/Fortran.cpp b/flang/lib/Support/Fortran.cpp index 8e286be..3a8ebbb 100644 --- a/flang/lib/Support/Fortran.cpp +++ b/flang/lib/Support/Fortran.cpp @@ -103,8 +103,8 @@ std::string AsFortran(IgnoreTKRSet tkr) { /// dummy argument attribute while `y` represents the actual argument attribute. bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x, std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR, - std::optional<std::string> *warning, bool allowUnifiedMatchingRule, - bool isHostDeviceProcedure, const LanguageFeatureControl *features) { + bool allowUnifiedMatchingRule, bool isHostDeviceProcedure, + const LanguageFeatureControl *features) { bool isCudaManaged{features ? features->IsEnabled(common::LanguageFeature::CudaManaged) : false}; @@ -145,9 +145,6 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x, *y == CUDADataAttr::Shared || *y == CUDADataAttr::Constant)) || (!y && (isCudaUnified || isCudaManaged))) { - if (y && *y == CUDADataAttr::Shared && warning) { - *warning = "SHARED attribute ignored"s; - } return true; } } else if (*x == CUDADataAttr::Managed) { diff --git a/flang/lib/Utils/CMakeLists.txt b/flang/lib/Utils/CMakeLists.txt new file mode 100644 index 0000000..2119b0e --- /dev/null +++ b/flang/lib/Utils/CMakeLists.txt @@ -0,0 +1,20 @@ +#===-- lib/Utils/CMakeLists.txt --------------------------------------------===# +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +#===------------------------------------------------------------------------===# + +add_flang_library(FortranUtils + OpenMP.cpp + + DEPENDS + FIRDialect + + LINK_LIBS + FIRDialect + + MLIR_LIBS + MLIROpenMPDialect +) diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp new file mode 100644 index 0000000..e1681e9 --- /dev/null +++ b/flang/lib/Utils/OpenMP.cpp @@ -0,0 +1,47 @@ +//===-- lib/Utisl/OpenMP.cpp ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Utils/OpenMP.h" + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" + +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +namespace Fortran::utils::openmp { +mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder, + mlir::Location loc, mlir::Value baseAddr, mlir::Value varPtrPtr, + llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, + llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex, + uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, + mlir::Type retTy, bool partialMap, mlir::FlatSymbolRefAttr mapperId) { + + if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { + baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr); + retTy = baseAddr.getType(); + } + + mlir::TypeAttr varType = mlir::TypeAttr::get( + llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType()); + + // For types with unknown extents such as <2x?xi32> we discard the incomplete + // type info and only retain the base type. The correct dimensions are later + // recovered through the bounds info. + if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue())) + if (seqType.hasDynamicExtents()) + varType = mlir::TypeAttr::get(seqType.getEleTy()); + + mlir::omp::MapInfoOp op = + mlir::omp::MapInfoOp::create(builder, loc, retTy, baseAddr, varType, + builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), + varPtrPtr, members, membersIndex, bounds, mapperId, + builder.getStringAttr(name), builder.getBoolAttr(partialMap)); + return op; +} +} // namespace Fortran::utils::openmp |
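Editor's note: the new flang/lib/Utils/OpenMP.cpp above introduces a shared createMapInfoOp helper. The following is a minimal usage sketch, not part of the patch: the wrapper name emitToMapForVar is hypothetical, and the choice of the OMP_MAP_TO flag and ByRef capture is purely illustrative. It assumes an existing mlir::OpBuilder and location, and a value `var` whose type implements the OpenMP dialect's pointer-like interface (e.g. a fir.ref), since createMapInfoOp derives the variable type from that interface.

// Illustrative sketch only; assumes flang/MLIR headers from this patch.
#include "flang/Utils/OpenMP.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"

#include <type_traits>

static mlir::omp::MapInfoOp emitToMapForVar(mlir::OpBuilder &builder,
                                            mlir::Location loc,
                                            mlir::Value var,
                                            llvm::StringRef varName) {
  // Map the whole object by reference with the "to" flag; no bounds,
  // member list, or user-defined mapper are attached in this sketch.
  uint64_t mapTypeBits = static_cast<
      std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
  return Fortran::utils::openmp::createMapInfoOp(
      builder, loc, var, /*varPtrPtr=*/mlir::Value{}, varName,
      /*bounds=*/{}, /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
      mapTypeBits, mlir::omp::VariableCaptureKind::ByRef,
      /*retTy=*/var.getType(), /*partialMap=*/false,
      /*mapperId=*/mlir::FlatSymbolRefAttr{});
}

If `var` were instead a fir.box value, the helper's BaseBoxType branch shown in the patch would first extract the box address with fir::BoxAddrOp and use its type as the result type, so callers do not need to unbox descriptors themselves.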