From 9da9127fec7b0a252b80d60b09b8c0ccedb41672 Mon Sep 17 00:00:00 2001 From: lntue <35648136+lntue@users.noreply.github.com> Date: Fri, 19 Jul 2024 10:40:44 -0400 Subject: [libc][math] Fix signaling nan handling of hypot(f) and improve hypotf performance. (#99432) The errors were reported by Paul Zimmermann with the CORE-MATH project's test suites: ``` zimmerma@tartine:/tmp/core-math$ CORE_MATH_CHECK_STD=true LIBM=$L ./check.sh hypot Running worst cases check in --rndn mode... FAIL x=snan y=inf ref=qnan z=inf Running worst cases check in --rndz mode... FAIL x=snan y=inf ref=qnan z=inf Running worst cases check in --rndu mode... FAIL x=snan y=inf ref=qnan z=inf Running worst cases check in --rndd mode... Spurious inexact exception for x=0x1.ffffffffffffep+24 y=0x1p+0 (z=0x1.0000000000001p+25) ``` --- libc/src/__support/FPUtil/Hypot.h | 63 +++++++++++----------- libc/src/math/generic/CMakeLists.txt | 5 +- libc/src/math/generic/hypotf.cpp | 100 ++++++++++++++++++++++------------- 3 files changed, 96 insertions(+), 72 deletions(-) (limited to 'libc/src') diff --git a/libc/src/__support/FPUtil/Hypot.h b/libc/src/__support/FPUtil/Hypot.h index a5a991e..6aa8084 100644 --- a/libc/src/__support/FPUtil/Hypot.h +++ b/libc/src/__support/FPUtil/Hypot.h @@ -109,45 +109,39 @@ LIBC_INLINE T hypot(T x, T y) { using StorageType = typename FPBits<T>::StorageType; using DStorageType = typename DoubleLength<StorageType>::Type; - FPBits_t x_bits(x), y_bits(y); + FPBits_t x_abs = FPBits_t(x).abs(); + FPBits_t y_abs = FPBits_t(y).abs(); - if (x_bits.is_inf() || y_bits.is_inf()) { - return FPBits_t::inf().get_val(); - } - if (x_bits.is_nan()) { - return x; - } - if (y_bits.is_nan()) { + bool x_abs_larger = x_abs.uintval() >= y_abs.uintval(); + + FPBits_t a_bits = x_abs_larger ? x_abs : y_abs; + FPBits_t b_bits = x_abs_larger ?
y_abs : x_abs; + + if (LIBC_UNLIKELY(a_bits.is_inf_or_nan())) { + if (x_abs.is_signaling_nan() || y_abs.is_signaling_nan()) { + fputil::raise_except_if_required(FE_INVALID); + return FPBits_t::quiet_nan().get_val(); + } + if (x_abs.is_inf() || y_abs.is_inf()) + return FPBits_t::inf().get_val(); + if (x_abs.is_nan()) + return x; + // y is nan return y; } - uint16_t x_exp = x_bits.get_biased_exponent(); - uint16_t y_exp = y_bits.get_biased_exponent(); - uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp); + uint16_t a_exp = a_bits.get_biased_exponent(); + uint16_t b_exp = b_bits.get_biased_exponent(); - if ((exp_diff >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0)) { - return abs(x) + abs(y); - } + if ((a_exp - b_exp >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0)) + return x_abs.get_val() + y_abs.get_val(); - uint16_t a_exp, b_exp, out_exp; - StorageType a_mant, b_mant; + uint64_t out_exp = a_exp; + StorageType a_mant = a_bits.get_mantissa(); + StorageType b_mant = b_bits.get_mantissa(); DStorageType a_mant_sq, b_mant_sq; bool sticky_bits; - if (abs(x) >= abs(y)) { - a_exp = x_exp; - a_mant = x_bits.get_mantissa(); - b_exp = y_exp; - b_mant = y_bits.get_mantissa(); - } else { - a_exp = y_exp; - a_mant = y_bits.get_mantissa(); - b_exp = x_exp; - b_mant = x_bits.get_mantissa(); - } - - out_exp = a_exp; - // Add an extra bit to simplify the final rounding bit computation. 
constexpr StorageType ONE = StorageType(1) << (FPBits_t::FRACTION_LEN + 1); @@ -165,11 +159,10 @@ LIBC_INLINE T hypot(T x, T y) { a_exp = 1; } - if (b_exp != 0) { + if (b_exp != 0) b_mant |= ONE; - } else { + else b_exp = 1; - } a_mant_sq = static_cast<DStorageType>(a_mant) * a_mant; b_mant_sq = static_cast<DStorageType>(b_mant) * b_mant; @@ -260,6 +253,10 @@ LIBC_INLINE T hypot(T x, T y) { } y_new |= static_cast<StorageType>(out_exp) << FPBits_t::FRACTION_LEN; + + if (!(round_bit || sticky_bits || (r != 0))) + fputil::clear_except_if_required(FE_INEXACT); + return cpp::bit_cast<T>(y_new); } diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt index 9c86bac..d775026 100644 --- a/libc/src/math/generic/CMakeLists.txt +++ b/libc/src/math/generic/CMakeLists.txt @@ -2830,9 +2830,12 @@ add_entrypoint_object( HDRS ../hypotf.h DEPENDS - libc.src.__support.FPUtil.basic_operations + libc.src.__support.FPUtil.double_double + libc.src.__support.FPUtil.fenv_impl libc.src.__support.FPUtil.fp_bits + libc.src.__support.FPUtil.multiply_add libc.src.__support.FPUtil.sqrt + libc.src.__support.macros.optimization COMPILE_OPTIONS -O3 ) diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp index 75c55ed..959c042 100644 --- a/libc/src/math/generic/hypotf.cpp +++ b/libc/src/math/generic/hypotf.cpp @@ -6,11 +6,14 @@ // //===----------------------------------------------------------------------===// #include "src/math/hypotf.h" -#include "src/__support/FPUtil/BasicOperations.h" +#include "src/__support/FPUtil/FEnvImpl.h" #include "src/__support/FPUtil/FPBits.h" +#include "src/__support/FPUtil/double_double.h" +#include "src/__support/FPUtil/multiply_add.h" #include "src/__support/FPUtil/sqrt.h" #include "src/__support/common.h" #include "src/__support/macros/config.h" +#include "src/__support/macros/optimization.h" namespace LIBC_NAMESPACE_DECL { @@ -18,54 +21,75 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) { using DoubleBits = fputil::FPBits<double>; using FPBits =
fputil::FPBits<float>; - FPBits x_bits(x), y_bits(y); + FPBits x_abs = FPBits(x).abs(); + FPBits y_abs = FPBits(y).abs(); - uint16_t x_exp = x_bits.get_biased_exponent(); - uint16_t y_exp = y_bits.get_biased_exponent(); - uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp); + bool x_abs_larger = x_abs.uintval() >= y_abs.uintval(); - if (exp_diff >= FPBits::FRACTION_LEN + 2) { - return fputil::abs(x) + fputil::abs(y); - } + FPBits a_bits = x_abs_larger ? x_abs : y_abs; + FPBits b_bits = x_abs_larger ? y_abs : x_abs; - double xd = static_cast<double>(x); - double yd = static_cast<double>(y); + uint32_t a_u = a_bits.uintval(); + uint32_t b_u = b_bits.uintval(); - // These squares are exact. - double x_sq = xd * xd; - double y_sq = yd * yd; + // Note: replacing `a_u >= FPBits::EXP_MASK` with `a_bits.is_inf_or_nan()` + // generates extra exponent bit masking instructions on x86-64. + if (LIBC_UNLIKELY(a_u >= FPBits::EXP_MASK)) { + // x or y is inf or nan + if (a_bits.is_signaling_nan() || b_bits.is_signaling_nan()) { + fputil::raise_except_if_required(FE_INVALID); + return FPBits::quiet_nan().get_val(); + } + if (a_bits.is_inf() || b_bits.is_inf()) + return FPBits::inf().get_val(); + return a_bits.get_val(); + } - // Compute the sum of squares. - double sum_sq = x_sq + y_sq; + if (LIBC_UNLIKELY(a_u - b_u >= + static_cast<uint32_t>((FPBits::FRACTION_LEN + 2) + << FPBits::FRACTION_LEN))) + return x_abs.get_val() + y_abs.get_val(); - // Compute the rounding error with Fast2Sum algorithm: - // x_sq + y_sq = sum_sq - err - double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq; + double ad = static_cast<double>(a_bits.get_val()); + double bd = static_cast<double>(b_bits.get_val()); + + // These squares are exact. + double a_sq = ad * ad; +#ifdef LIBC_TARGET_CPU_HAS_FMA + double sum_sq = fputil::multiply_add(bd, bd, a_sq); +#else + double b_sq = bd * bd; + double sum_sq = a_sq + b_sq; +#endif // Take sqrt in double precision.
DoubleBits result(fputil::sqrt(sum_sq)); + uint64_t r_u = result.uintval(); - if (!DoubleBits(sum_sq).is_inf_or_nan()) { - // Correct rounding. - double r_sq = result.get_val() * result.get_val(); - double diff = sum_sq - r_sq; - constexpr uint64_t MASK = 0x0000'0000'3FFF'FFFFULL; - uint64_t lrs = result.uintval() & MASK; - - if (lrs == 0x0000'0000'1000'0000ULL && err < diff) { - result.set_uintval(result.uintval() | 1ULL); - } else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) { - result.set_uintval(result.uintval() - 1ULL); - } - } else { - FPBits bits_x(x), bits_y(y); - if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) { - if (bits_x.is_inf() || bits_y.is_inf()) - return FPBits::inf().get_val(); - if (bits_x.is_nan()) - return x; - return y; + // If any of the sticky bits of the result are non-zero, except the LSB, then + // the rounded result is correct. + if (LIBC_UNLIKELY(((r_u + 1) & 0x0000'0000'0FFF'FFFE) == 0)) { + double r_d = result.get_val(); + + // Perform rounding correction. +#ifdef LIBC_TARGET_CPU_HAS_FMA + double sum_sq_lo = fputil::multiply_add(bd, bd, a_sq - sum_sq); + double err = sum_sq_lo - fputil::multiply_add(r_d, r_d, -sum_sq); +#else + fputil::DoubleDouble r_sq = fputil::exact_mult(r_d, r_d); + double sum_sq_lo = b_sq - (sum_sq - a_sq); + double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo); +#endif + + if (err > 0) { + r_u |= 1; + } else if ((err < 0) && (r_u & 1) == 0) { + r_u -= 1; + } else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) { + // The rounded result is exact. + fputil::clear_except_if_required(FE_INEXACT); } + return static_cast<float>(DoubleBits(r_u).get_val()); } return static_cast<float>(result.get_val()); -- cgit v1.1