//===-- MPCUtils.h ----------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H #define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H #include "hdr/stdint_proxy.h" #include "src/__support/CPP/type_traits.h" #include "src/__support/complex_type.h" #include "src/__support/macros/config.h" #include "src/__support/macros/properties/complex_types.h" #include "src/__support/macros/properties/types.h" #include "test/UnitTest/RoundingModeUtils.h" #include "test/UnitTest/Test.h" namespace LIBC_NAMESPACE_DECL { namespace testing { namespace mpc { enum class Operation { // Operations which take a single complex floating point number as input // and produce a single floating point number as output which has the same // floating point type as the real/imaginary part of the input. BeginUnaryOperationsSingleOutputDifferentOutputType, Carg, Cabs, EndUnaryOperationsSingleOutputDifferentOutputType, // Operations which take a single complex floating point number as input // and produce a single complex floating point number of the same kind // as output. BeginUnaryOperationsSingleOutputSameOutputType, Cproj, Csqrt, Clog, Cexp, Csinh, Ccosh, Ctanh, Casinh, Cacosh, Catanh, Csin, Ccos, Ctan, Casin, Cacos, Catan, EndUnaryOperationsSingleOutputSameOutputType, // Operations which take two complex floating point numbers as input // and produce a single complex floating point number of the same kind // as output. BeginBinaryOperationsSingleOutput, Cpow, EndBinaryOperationsSingleOutput, }; using LIBC_NAMESPACE::fputil::testing::RoundingMode; template struct BinaryInput { static_assert(LIBC_NAMESPACE::cpp::is_complex_v, "Template parameter of BinaryInput must be a complex floating " "point type."); using Type = T; T x, y; }; namespace internal { template bool compare_unary_operation_single_output_same_type(Operation op, InputType input, OutputType libc_output, double ulp_tolerance, RoundingMode rounding); template bool compare_unary_operation_single_output_different_type( Operation op, InputType input, OutputType libc_output, double ulp_tolerance, RoundingMode rounding); template bool compare_binary_operation_one_output(Operation op, const BinaryInput &input, OutputType libc_output, double ulp_tolerance, RoundingMode rounding); template void explain_unary_operation_single_output_same_type_error( Operation op, InputType input, OutputType match_value, double ulp_tolerance, RoundingMode rounding); template void explain_unary_operation_single_output_different_type_error( Operation op, InputType input, OutputType match_value, double ulp_tolerance, RoundingMode rounding); template void explain_binary_operation_one_output_error( Operation op, const BinaryInput &input, OutputType match_value, double ulp_tolerance, RoundingMode rounding); template class MPCMatcher : public testing::Matcher { private: InputType input; OutputType match_value; double ulp_tolerance; RoundingMode rounding; public: MPCMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding) : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {} bool match(OutputType libcResult) { match_value = libcResult; return match(input, match_value); } void explainError() override { // NOLINT explain_error(input, match_value); } private: template bool match(InType in, OutType out) { if (cpp::is_same_v) { return compare_unary_operation_single_output_same_type( op, in, out, ulp_tolerance, rounding); } else { return compare_unary_operation_single_output_different_type( op, in, out, ulp_tolerance, rounding); } } template bool match(const BinaryInput &in, U out) { return compare_binary_operation_one_output(op, in, out, ulp_tolerance, rounding); } template void explain_error(InType in, OutType out) { if (cpp::is_same_v) { explain_unary_operation_single_output_same_type_error( op, in, out, ulp_tolerance, rounding); } else { explain_unary_operation_single_output_different_type_error( op, in, out, ulp_tolerance, rounding); } } template void explain_error(const BinaryInput &in, U out) { explain_binary_operation_one_output_error(op, in, out, ulp_tolerance, rounding); } }; } // namespace internal // Return true if the input and ouput types for the operation op are valid // types. template constexpr bool is_valid_operation() { return (Operation::BeginBinaryOperationsSingleOutput < op && op < Operation::EndBinaryOperationsSingleOutput && cpp::is_complex_type_same() && cpp::is_complex_v) || (Operation::BeginUnaryOperationsSingleOutputSameOutputType < op && op < Operation::EndUnaryOperationsSingleOutputSameOutputType && cpp::is_complex_type_same() && cpp::is_complex_v) || (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType < op && op < Operation::EndUnaryOperationsSingleOutputDifferentOutputType && cpp::is_same_v, OutputType> && cpp::is_complex_v); } template cpp::enable_if_t(), internal::MPCMatcher> get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output, double ulp_tolerance, RoundingMode rounding) { return internal::MPCMatcher(input, ulp_tolerance, rounding); } } // namespace mpc } // namespace testing } // namespace LIBC_NAMESPACE_DECL #define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ EXPECT_THAT(match_value, \ LIBC_NAMESPACE::testing::mpc::get_mpc_matcher( \ input, match_value, ulp_tolerance, \ LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest)) #define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ rounding) \ EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher( \ input, match_value, ulp_tolerance, rounding)) #define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ ulp_tolerance, rounding) \ { \ MPCRND::ForceRoundingMode __r(rounding); \ if (__r.success) { \ EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ rounding); \ } \ } #define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ { \ namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \ for (int i = 0; i < 4; i++) { \ MPCRND::RoundingMode r_mode = static_cast(i); \ EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ ulp_tolerance, r_mode); \ } \ } #define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ rounding) \ LIBC_NAMESPACE::testing::mpc::get_mpc_matcher(input, match_value, \ ulp_tolerance, rounding) \ .match(match_value) #define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ ASSERT_THAT(match_value, \ LIBC_NAMESPACE::testing::mpc::get_mpc_matcher( \ input, match_value, ulp_tolerance, \ LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest)) #define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ rounding) \ ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher( \ input, match_value, ulp_tolerance, rounding)) #define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ ulp_tolerance, rounding) \ { \ MPCRND::ForceRoundingMode __r(rounding); \ if (__r.success) { \ ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ rounding); \ } \ } #define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ { \ namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \ for (int i = 0; i < 4; i++) { \ MPCRND::RoundingMode r_mode = static_cast(i); \ ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \ ulp_tolerance, r_mode); \ } \ } #endif // LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H