diff options
| author | Peter Klausler <pklausler@nvidia.com> | 2026-01-12 15:40:44 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-01-12 15:40:44 -0800 |
| commit | a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da (patch) | |
| tree | aa488d58663dc908865a6c672ab26aef74dfe58f | |
| parent | 3874c4541a2e7fe044f02cda962657be35314deb (diff) | |
| download | llvm-a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da.tar.gz llvm-a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da.tar.bz2 llvm-a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da.zip | |
[flang] Fix spurious NaN result from infinite Kahan summation (#175373)
There are six instances of Kahan's extended precision summation
algorithm in flang/flang-rt, and they share a bug: the calculation of
the correction value produces a NaN due to the subtraction Inf-Inf after
the accumulation saturates to Inf. This leads to the surprising NaN
result from SUM([Inf, 0.]).
This bug doesn't affect run-time calculation of SUM when optimization is
enabled -- lowering emits an open-coded SUM that lacks Kahan summation
-- but it does affect compilation-time folding and -O0 runtime results.
Fix the one instance of Kahan summation in the runtime, and consolidate
the other five instances in Evaluate into one new member function, also
corrected.
Fixes https://github.com/llvm/llvm-project/issues/89528.
| -rw-r--r-- | flang-rt/lib/runtime/sum.cpp | 12 | ||||
| -rw-r--r-- | flang-rt/unittests/Runtime/Reduction.cpp | 16 | ||||
| -rw-r--r-- | flang/include/flang/Evaluate/complex.h | 3 | ||||
| -rw-r--r-- | flang/include/flang/Evaluate/real.h | 2 | ||||
| -rw-r--r-- | flang/lib/Evaluate/complex.cpp | 11 | ||||
| -rw-r--r-- | flang/lib/Evaluate/fold-matmul.h | 11 | ||||
| -rw-r--r-- | flang/lib/Evaluate/fold-real.cpp | 7 | ||||
| -rw-r--r-- | flang/lib/Evaluate/fold-reduction.h | 30 | ||||
| -rw-r--r-- | flang/lib/Evaluate/real.cpp | 18 | ||||
| -rw-r--r-- | flang/test/Evaluate/bug89528.f90 | 8 |
10 files changed, 78 insertions, 40 deletions
diff --git a/flang-rt/lib/runtime/sum.cpp b/flang-rt/lib/runtime/sum.cpp index a76e228f18a4..0c540606f27c 100644 --- a/flang-rt/lib/runtime/sum.cpp +++ b/flang-rt/lib/runtime/sum.cpp @@ -54,9 +54,15 @@ public: template <typename A> RT_API_ATTRS bool Accumulate(A x) { // Kahan summation auto next{x - correction_}; - auto oldSum{sum_}; - sum_ += next; - correction_ = (sum_ - oldSum) - next; // algebraically zero + if (next != next) { + // Avoid propagating an accidental Nan from Inf-Inf in corrections + sum_ += x; + correction_ = 0; + } else { + auto oldSum{sum_}; + sum_ += next; + correction_ = (sum_ - oldSum) - next; // algebraically zero + } return true; } template <typename A> diff --git a/flang-rt/unittests/Runtime/Reduction.cpp b/flang-rt/unittests/Runtime/Reduction.cpp index 3701a32042c5..ac6fed1e34a9 100644 --- a/flang-rt/unittests/Runtime/Reduction.cpp +++ b/flang-rt/unittests/Runtime/Reduction.cpp @@ -672,3 +672,19 @@ TEST(Reductions, ReduceInt4Dim) { EXPECT_EQ(*sums.ZeroBasedIndexedElement<std::int32_t>(1), 6); sums.Destroy(); } + +TEST(Reductions, InfSums) { + float inf{1.0f / 0.0f}; + auto inf0{MakeArray<TypeCategory::Real, 4>( + std::vector<int>{2, 3}, std::vector<float>{inf, 0.0f})}; + auto t1{RTNAME(SumReal4)(*inf0, __FILE__, __LINE__)}; + EXPECT_EQ(t1, inf) << t1; + auto infMinusInf{MakeArray<TypeCategory::Real, 4>( + std::vector<int>{2, 3}, std::vector<float>{inf, -inf})}; + auto t2{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)}; + EXPECT_NE(t2, t2) << t2; + auto minusInfInf{MakeArray<TypeCategory::Real, 4>( + std::vector<int>{2, 3}, std::vector<float>{-inf, inf})}; + auto t3{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)}; + EXPECT_NE(t3, t3) << t3; +} diff --git a/flang/include/flang/Evaluate/complex.h b/flang/include/flang/Evaluate/complex.h index 720ccaf512df..9781db9a25a6 100644 --- a/flang/include/flang/Evaluate/complex.h +++ b/flang/include/flang/Evaluate/complex.h @@ -77,6 +77,9 @@ public: Rounding rounding = 
TargetCharacteristics::defaultRounding) const; ValueWithRealFlags<Complex> Divide(const Complex &, Rounding rounding = TargetCharacteristics::defaultRounding) const; + ValueWithRealFlags<Complex> KahanSummation(const Complex &, + Complex &correction, + Rounding rounding = TargetCharacteristics::defaultRounding) const; // ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2) ValueWithRealFlags<Part> ABS( diff --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h index dcd74073a473..c0a966820d13 100644 --- a/flang/include/flang/Evaluate/real.h +++ b/flang/include/flang/Evaluate/real.h @@ -175,6 +175,8 @@ public: Rounding rounding = TargetCharacteristics::defaultRounding) const; ValueWithRealFlags<Real> MODULO(const Real &, Rounding rounding = TargetCharacteristics::defaultRounding) const; + ValueWithRealFlags<Real> KahanSummation(const Real &, Real &correction, + Rounding rounding = TargetCharacteristics::defaultRounding) const; template <typename INT> constexpr INT EXPONENT() const { if (Exponent() == maxExponent) { diff --git a/flang/lib/Evaluate/complex.cpp b/flang/lib/Evaluate/complex.cpp index ab83f193e3f3..a245fb38c82b 100644 --- a/flang/lib/Evaluate/complex.cpp +++ b/flang/lib/Evaluate/complex.cpp @@ -100,6 +100,17 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide( return {Complex{re, im}, flags}; } +template <typename R> +ValueWithRealFlags<Complex<R>> Complex<R>::KahanSummation( + const Complex &that, Complex &correction, Rounding rounding) const { + RealFlags flags; + Part reSum{re_.KahanSummation(that.re_, correction.re_, rounding) + .AccumulateFlags(flags)}; + Part imSum{im_.KahanSummation(that.im_, correction.im_, rounding) + .AccumulateFlags(flags)}; + return {Complex{reSum, imSum}, flags}; +} + template <typename R> std::string Complex<R>::DumpHexadecimal() const { std::string result{'('}; result += re_.DumpHexadecimal(); diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h index 
ae9221f9ce04..a8a24c09774e 100644 --- a/flang/lib/Evaluate/fold-matmul.h +++ b/flang/lib/Evaluate/fold-matmul.h @@ -61,18 +61,13 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) { auto product{aElt.Multiply(bElt)}; overflow |= product.flags.test(RealFlag::Overflow); if constexpr (useKahanSummation) { - auto next{product.value.Subtract(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; + auto added{sum.KahanSummation(product.value, correction)}; overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + sum = added.value; } else { auto added{sum.Add(product.value)}; overflow |= added.flags.test(RealFlag::Overflow); - sum = std::move(added.value); + sum = added.value; } } else if constexpr (T::category == TypeCategory::Integer) { auto product{aElt.MultiplySigned(bElt)}; diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 907d01b005a0..9c591e2ef36e 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -77,13 +77,8 @@ public: auto scaled{item.Divide(scale).value}; auto square{scaled.Multiply(scaled).value}; if constexpr (useKahanSummation) { - auto next{square.Subtract(correction_, rounding_)}; - overflow_ |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding_)}; + auto sum{element.KahanSummation(square, correction_, rounding_)}; overflow_ |= sum.flags.test(RealFlag::Overflow); - correction_ = sum.value.Subtract(element, rounding_) - .value.Subtract(next.value, rounding_) - .value; element = sum.value; } else { auto sum{element.Add(square, rounding_)}; diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index fe897393fe13..a06836413529 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ 
b/flang/lib/Evaluate/fold-reduction.h @@ -47,18 +47,13 @@ static Expr<T> FoldDotProduct( const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { if constexpr (useKahanSummation) { - auto next{x.Subtract(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; + auto added{sum.KahanSummation(x, correction, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + sum = added.value; } else { auto added{sum.Add(x, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - sum = std::move(added.value); + sum = added.value; } } } else if constexpr (T::category == TypeCategory::Logical) { @@ -97,18 +92,13 @@ static Expr<T> FoldDotProduct( const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { if constexpr (useKahanSummation) { - auto next{x.Subtract(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; + auto added{sum.KahanSummation(x, correction, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + sum = added.value; } else { auto added{sum.Add(x, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - sum = std::move(added.value); + sum = added.value; } } } @@ -357,14 +347,8 @@ public: } else if constexpr (T::category == TypeCategory::Unsigned) { element = element.AddUnsigned(array_.At(at)).value; } else { // Real & Complex: use Kahan summation - auto next{array_.At(at).Subtract(correction_, rounding_)}; - overflow_ |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding_)}; + auto 
sum{element.KahanSummation(array_.At(at), correction_, rounding_)}; overflow_ |= sum.flags.test(RealFlag::Overflow); - // correction = (sum - element) - next; algebraically zero - correction_ = sum.value.Subtract(element, rounding_) - .value.Subtract(next.value, rounding_) - .value; element = sum.value; } } diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp index 6e6b9f3ac77c..eb335ce32851 100644 --- a/flang/lib/Evaluate/real.cpp +++ b/flang/lib/Evaluate/real.cpp @@ -466,6 +466,24 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::MODULO( } template <typename W, int P> +ValueWithRealFlags<Real<W, P>> Real<W, P>::KahanSummation( + const Real &y, Real &correction, Rounding rounding) const { + Real next{y.Subtract(correction, rounding).value}; + if (next.IsNotANumber()) { + // Avoid propagating an accidental NaN from Inf-Inf in corrections + correction = Real{}; // 0. + return Add(y, rounding); + } else { + auto sum{Add(next, rounding)}; + // correction = (sum - *this) - next; algebraically zero + correction = sum.value.Subtract(*this, rounding) + .value.Subtract(next, rounding) + .value; + return sum; + } +} + +template <typename W, int P> ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM( const Real &y, Rounding rounding) const { ValueWithRealFlags<Real> result; diff --git a/flang/test/Evaluate/bug89528.f90 b/flang/test/Evaluate/bug89528.f90 new file mode 100644 index 000000000000..281f35bf653c --- /dev/null +++ b/flang/test/Evaluate/bug89528.f90 @@ -0,0 +1,8 @@ +!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s +!CHECK: REAL :: avoidkahannan = (1._4/0.) +real :: avoidKahanNaN = sum([1./0., 0.]) ! Inf, not NaN +!CHECK: REAL :: expectnan1 = (0._4/0.) +real :: expectNaN1 = sum([1./0., -1./0.]) +!CHECK: REAL :: expectnan2 = (0._4/0.) +real :: expectNaN2 = sum([-1./0., 1./0.]) +end |
