//===-- runtime/product.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // Implements PRODUCT for all required operand types and shapes. #include "reduction-templates.h" #include "flang/Common/float128.h" #include "flang/Runtime/reduction.h" #include #include #include namespace Fortran::runtime { template class NonComplexProductAccumulator { public: explicit RT_API_ATTRS NonComplexProductAccumulator(const Descriptor &array) : array_{array} {} RT_API_ATTRS void Reinitialize() { product_ = 1; } template RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const { *p = static_cast(product_); } template RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) { product_ *= *array_.Element(at); return product_ != 0; } private: const Descriptor &array_; INTERMEDIATE product_{1}; }; // Suppress the warnings about calling __host__-only std::complex operators, // defined in C++ STD header files, from __device__ code. 
RT_DIAG_PUSH
RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN

// Running-product accumulator for PRODUCT over complex element types.
// In this file it backs the CppProductComplex* entry points. The running
// product is a std::complex value that starts at (1, 0).
// NOTE(review): the template parameter list was stripped along with every
// other angle-bracketed span. It names the part type given to std::complex.
template class ComplexProductAccumulator {
public:
  // Binds to the array being reduced. The descriptor is held by reference,
  // so it must outlive this accumulator.
  explicit RT_API_ATTRS ComplexProductAccumulator(const Descriptor &array)
      : array_{array} {}

  // Resets the running product to the complex multiplicative identity (1, 0).
  RT_API_ATTRS void Reinitialize() { product_ = std::complex{1, 0}; }

  // Stores the running product through `p`. The real and imaginary parts are
  // converted separately to the result's value_type. The dimension argument
  // is accepted but unused.
  template
  RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
    using ResultPart = typename A::value_type;
    *p = {static_cast(product_.real()), static_cast(product_.imag())};
  }

  // Multiplies the element of `array_` at subscripts `at` into the running
  // product. Always returns true. Unlike NonComplexProductAccumulator, this
  // accumulator makes no zero test.
  template RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
    product_ *= *array_.Element(at);
    return true;
  }

private:
  const Descriptor &array_;     // array being reduced (non-owning)
  std::complex product_{1, 0};  // running product, starts at (1, 0)
};

// Ends the warning-suppression region opened by RT_DIAG_PUSH above.
RT_DIAG_POP

extern "C" {
RT_EXT_API_GROUP_BEGIN

// Total PRODUCT entry points for INTEGER and REAL arrays. Each one returns
// the result of GetTotalReduction over `x`, using a
// NonComplexProductAccumulator and passing "PRODUCT" as the intrinsic name.
// Parameters shared by these entry points:
//   x            - descriptor of the array being reduced.
//   source, line - presumably the Fortran source location used in runtime
//                  error messages. TODO: confirm.
//   dim          - forwarded unchanged to GetTotalReduction.
//   mask         - optional MASK= descriptor; presumably nullptr when absent.
// NOTE(review): several template arguments were stripped here: the result
// type, GetTotalReduction's category/kind, and the accumulator's
// INTERMEDIATE type (its closing `>` is still visible). These arguments
// determine the accumulation precision and must be restored from upstream.
CppTypeFor RTDEF(ProductInteger1)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
CppTypeFor RTDEF(ProductInteger2)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
CppTypeFor RTDEF(ProductInteger4)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
CppTypeFor RTDEF(ProductInteger8)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
// INTEGER(16) is compiled only when the host compiler defines
// __SIZEOF_INT128__, which means it provides a 128-bit integer type.
#ifdef __SIZEOF_INT128__
CppTypeFor RTDEF(ProductInteger16)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
#endif

// TODO: real/complex(2 & 3). There are no entry points yet for REAL/COMPLEX
// kinds 2 and 3.
CppTypeFor RTDEF(ProductReal4)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
CppTypeFor RTDEF(ProductReal8)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
// REAL(10) is compiled only when long double has a 64-bit significand
// (LDBL_MANT_DIG == 64), i.e. the x87 80-bit extended format.
#if LDBL_MANT_DIG == 64
CppTypeFor RTDEF(ProductReal10)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
#endif
// REAL(16) is compiled when long double has a 113-bit significand (IEEE
// binary128), or when HAS_FLOAT128 is set. HAS_FLOAT128 presumably comes
// from flang/Common/float128.h.
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CppTypeFor RTDEF(ProductReal16)(const Descriptor &x, const char *source,
    int line, int dim, const Descriptor *mask) {
  return GetTotalReduction(x, source, line, dim, mask,
      NonComplexProductAccumulator>{x}, "PRODUCT");
}
#endif

// Total PRODUCT entry points for COMPLEX arrays. They take the same
// parameters as above, but write the result through the `result` reference
// instead of returning it, and they use a ComplexProductAccumulator.
void RTDEF(CppProductComplex4)(CppTypeFor &result, const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  result = GetTotalReduction(x, source, line, dim, mask,
      ComplexProductAccumulator>{x}, "PRODUCT");
}
void RTDEF(CppProductComplex8)(CppTypeFor &result, const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  result = GetTotalReduction(x, source, line, dim, mask,
      ComplexProductAccumulator>{x}, "PRODUCT");
}
// COMPLEX(10): same guard as REAL(10).
#if LDBL_MANT_DIG == 64
void RTDEF(CppProductComplex10)(CppTypeFor &result, const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  result = GetTotalReduction(x, source, line, dim, mask,
      ComplexProductAccumulator>{x}, "PRODUCT");
}
#endif
// COMPLEX(16): same guard as REAL(16).
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDEF(CppProductComplex16)(CppTypeFor &result, const Descriptor &x,
    const char *source, int line, int dim, const Descriptor *mask) {
  result = GetTotalReduction(x, source, line, dim, mask,
      ComplexProductAccumulator>{x}, "PRODUCT");
}
#endif

// PRODUCT with DIM=: reduces `x` along dimension `dim` into the `result`
// descriptor by calling TypedPartialNumericReduction.
// The parameter order differs from the total-reduction entry points above:
// here `dim` precedes `source` and `line`.
// NOTE(review): TypedPartialNumericReduction's template argument list was
// stripped. It presumably names the accumulator templates to use for each
// type category. Restore it from upstream.
void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
    const char *source, int line, const Descriptor *mask) {
  TypedPartialNumericReduction(result, x, dim, source, line, mask, "PRODUCT");
}

RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime