//===-- runtime/dot-product.cpp -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "float.h" #include "terminator.h" #include "tools.h" #include "flang/Runtime/cpp-type.h" #include "flang/Runtime/descriptor.h" #include "flang/Runtime/reduction.h" #include #include namespace Fortran::runtime { // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first // argument; MATMUL does not. // General accumulator for any type and stride; this is not used for // contiguous numeric vectors. template class Accumulator { public: using Result = AccumulationType; Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { if constexpr (RCAT == TypeCategory::Logical) { sum_ = sum_ || (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); } else { const XT &xElement{*x_.Element(&xAt)}; const YT &yElement{*y_.Element(&yAt)}; if constexpr (RCAT == TypeCategory::Complex) { sum_ += std::conj(static_cast(xElement)) * static_cast(yElement); } else { sum_ += static_cast(xElement) * static_cast(yElement); } } } Result GetResult() const { return sum_; } private: const Descriptor &x_, &y_; Result sum_{}; }; template static inline CppTypeFor DoDotProduct( const Descriptor &x, const Descriptor &y, Terminator &terminator) { using Result = CppTypeFor; RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); SubscriptValue n{x.GetDimension(0).Extent()}; if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { terminator.Crash( "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", static_cast(n), static_cast(yN)); } if constexpr (RCAT != TypeCategory::Logical) { if (x.GetDimension(0).ByteStride() == sizeof(XT) && y.GetDimension(0).ByteStride() == sizeof(YT)) { // Contiguous numeric vectors if constexpr (std::is_same_v) { // Contiguous homogeneous numeric vectors if constexpr (std::is_same_v) { // TODO: call BLAS-1 SDOT or SDSDOT } else if constexpr (std::is_same_v) { // TODO: call BLAS-1 DDOT } else if constexpr (std::is_same_v>) { // TODO: call BLAS-1 CDOTC } else if constexpr (std::is_same_v>) { // TODO: call BLAS-1 ZDOTC } } XT *xp{x.OffsetElement(0)}; YT *yp{y.OffsetElement(0)}; using AccumType = AccumulationType; AccumType accum{}; if constexpr (RCAT == TypeCategory::Complex) { for (SubscriptValue j{0}; j < n; ++j) { accum += std::conj(static_cast(*xp++)) * static_cast(*yp++); } } else { for (SubscriptValue j{0}; j < n; ++j) { accum += static_cast(*xp++) * static_cast(*yp++); } } return static_cast(accum); } } // Non-contiguous, heterogeneous, & LOGICAL cases SubscriptValue xAt{x.GetDimension(0).LowerBound()}; SubscriptValue yAt{y.GetDimension(0).LowerBound()}; Accumulator accumulator{x, y}; for (SubscriptValue j{0}; j < n; ++j) { accumulator.AccumulateIndexed(xAt++, yAt++); } return static_cast(accumulator.GetResult()); } template struct DotProduct { using Result = CppTypeFor; template struct DP1 { template struct DP2 { Result operator()(const Descriptor &x, const Descriptor &y, Terminator &terminator) const { if constexpr (constexpr auto resultType{ GetResultType(XCAT, XKIND, YCAT, YKIND)}) { if constexpr (resultType->first == RCAT && (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { return DoDotProduct, CppTypeFor>(x, y, terminator); } } terminator.Crash( "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", static_cast(RCAT), RKIND, static_cast(XCAT), XKIND, static_cast(YCAT), YKIND); } }; Result operator()(const Descriptor &x, const Descriptor &y, Terminator &terminator, TypeCategory yCat, int yKind) const { return ApplyType(yCat, yKind, terminator, x, y, terminator); } }; Result operator()(const Descriptor &x, const Descriptor &y, const char *source, int line) const { Terminator terminator{source, line}; if (RCAT != TypeCategory::Logical && x.type() == y.type()) { // No conversions needed, operands and result have same known type return typename DP1::template DP2{}( x, y, terminator); } else { auto xCatKind{x.type().GetCategoryAndKind()}; auto yCatKind{y.type().GetCategoryAndKind()}; RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); return ApplyType(xCatKind->first, xCatKind->second, terminator, x, y, terminator, yCatKind->first, yCatKind->second); } } }; extern "C" { CppTypeFor RTNAME(DotProductInteger1)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } CppTypeFor RTNAME(DotProductInteger2)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } CppTypeFor RTNAME(DotProductInteger4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } CppTypeFor RTNAME(DotProductInteger8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #ifdef __SIZEOF_INT128__ CppTypeFor RTNAME(DotProductInteger16)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #endif // TODO: REAL/COMPLEX(2 & 3) // Intermediate results and operations are at least 64 bits CppTypeFor RTNAME(DotProductReal4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } CppTypeFor RTNAME(DotProductReal8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #if LDBL_MANT_DIG == 64 CppTypeFor RTNAME(DotProductReal10)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #endif #if LDBL_MANT_DIG == 113 || HAS_FLOAT128 CppTypeFor RTNAME(DotProductReal16)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } #endif void RTNAME(CppDotProductComplex4)(CppTypeFor &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } void RTNAME(CppDotProductComplex8)(CppTypeFor &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } #if LDBL_MANT_DIG == 64 void RTNAME(CppDotProductComplex10)( CppTypeFor &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } #endif #if LDBL_MANT_DIG == 113 || HAS_FLOAT128 void RTNAME(CppDotProductComplex16)( CppTypeFor &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { result = DotProduct{}(x, y, source, line); } #endif bool RTNAME(DotProductLogical)( const Descriptor &x, const Descriptor &y, const char *source, int line) { return DotProduct{}(x, y, source, line); } } // extern "C" } // namespace Fortran::runtime