aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Klausler <pklausler@nvidia.com>2026-01-12 15:40:44 -0800
committerGitHub <noreply@github.com>2026-01-12 15:40:44 -0800
commita8ba9c4fa2ae1ae80738e7848c28c19a9aa564da (patch)
treeaa488d58663dc908865a6c672ab26aef74dfe58f
parent3874c4541a2e7fe044f02cda962657be35314deb (diff)
downloadllvm-a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da.tar.gz
llvm-a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da.tar.bz2
llvm-a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da.zip
[flang] Fix spurious NaN result from infinite Kahan summation (#175373)
There are six instances of Kahan's extended precision summation algorithm in flang/flang-rt, and they share a bug: the calculation of the correction value produces a Nan due to the subtraction Inf-Inf after the accumulation saturates to Inf. This leads to the surprising Nan result from SUM([Inf, 0.]). This bug doesn't affect run-time calculation of SUM when optimization is enabled -- lowering emits an open-coded SUM that lacks Kahan summation -- but it does affect compilation-time folding and -O0 runtime results. Fix the one instance of Kahan summation in the runtime, and consolidate the other five instances in Evaluate into one new member function, also corrected. Fixes https://github.com/llvm/llvm-project/issues/89528.
-rw-r--r--flang-rt/lib/runtime/sum.cpp12
-rw-r--r--flang-rt/unittests/Runtime/Reduction.cpp16
-rw-r--r--flang/include/flang/Evaluate/complex.h3
-rw-r--r--flang/include/flang/Evaluate/real.h2
-rw-r--r--flang/lib/Evaluate/complex.cpp11
-rw-r--r--flang/lib/Evaluate/fold-matmul.h11
-rw-r--r--flang/lib/Evaluate/fold-real.cpp7
-rw-r--r--flang/lib/Evaluate/fold-reduction.h30
-rw-r--r--flang/lib/Evaluate/real.cpp18
-rw-r--r--flang/test/Evaluate/bug89528.f908
10 files changed, 78 insertions, 40 deletions
diff --git a/flang-rt/lib/runtime/sum.cpp b/flang-rt/lib/runtime/sum.cpp
index a76e228f18a4..0c540606f27c 100644
--- a/flang-rt/lib/runtime/sum.cpp
+++ b/flang-rt/lib/runtime/sum.cpp
@@ -54,9 +54,15 @@ public:
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
// Kahan summation
auto next{x - correction_};
- auto oldSum{sum_};
- sum_ += next;
- correction_ = (sum_ - oldSum) - next; // algebraically zero
+ if (next != next) {
+ // Avoid propagating an accidental Nan from Inf-Inf in corrections
+ sum_ += x;
+ correction_ = 0;
+ } else {
+ auto oldSum{sum_};
+ sum_ += next;
+ correction_ = (sum_ - oldSum) - next; // algebraically zero
+ }
return true;
}
template <typename A>
diff --git a/flang-rt/unittests/Runtime/Reduction.cpp b/flang-rt/unittests/Runtime/Reduction.cpp
index 3701a32042c5..ac6fed1e34a9 100644
--- a/flang-rt/unittests/Runtime/Reduction.cpp
+++ b/flang-rt/unittests/Runtime/Reduction.cpp
@@ -672,3 +672,19 @@ TEST(Reductions, ReduceInt4Dim) {
EXPECT_EQ(*sums.ZeroBasedIndexedElement<std::int32_t>(1), 6);
sums.Destroy();
}
+
+TEST(Reductions, InfSums) {
+ float inf{1.0f / 0.0f};
+ auto inf0{MakeArray<TypeCategory::Real, 4>(
+ std::vector<int>{2, 3}, std::vector<float>{inf, 0.0f})};
+ auto t1{RTNAME(SumReal4)(*inf0, __FILE__, __LINE__)};
+ EXPECT_EQ(t1, inf) << t1;
+ auto infMinusInf{MakeArray<TypeCategory::Real, 4>(
+ std::vector<int>{2, 3}, std::vector<float>{inf, -inf})};
+ auto t2{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)};
+ EXPECT_NE(t2, t2) << t2;
+ auto minusInfInf{MakeArray<TypeCategory::Real, 4>(
+ std::vector<int>{2, 3}, std::vector<float>{-inf, inf})};
+ auto t3{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)};
+ EXPECT_NE(t3, t3) << t3;
+}
diff --git a/flang/include/flang/Evaluate/complex.h b/flang/include/flang/Evaluate/complex.h
index 720ccaf512df..9781db9a25a6 100644
--- a/flang/include/flang/Evaluate/complex.h
+++ b/flang/include/flang/Evaluate/complex.h
@@ -77,6 +77,9 @@ public:
Rounding rounding = TargetCharacteristics::defaultRounding) const;
ValueWithRealFlags<Complex> Divide(const Complex &,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
+ ValueWithRealFlags<Complex> KahanSummation(const Complex &,
+ Complex &correction,
+ Rounding rounding = TargetCharacteristics::defaultRounding) const;
// ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
ValueWithRealFlags<Part> ABS(
diff --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h
index dcd74073a473..c0a966820d13 100644
--- a/flang/include/flang/Evaluate/real.h
+++ b/flang/include/flang/Evaluate/real.h
@@ -175,6 +175,8 @@ public:
Rounding rounding = TargetCharacteristics::defaultRounding) const;
ValueWithRealFlags<Real> MODULO(const Real &,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
+ ValueWithRealFlags<Real> KahanSummation(const Real &, Real &correction,
+ Rounding rounding = TargetCharacteristics::defaultRounding) const;
template <typename INT> constexpr INT EXPONENT() const {
if (Exponent() == maxExponent) {
diff --git a/flang/lib/Evaluate/complex.cpp b/flang/lib/Evaluate/complex.cpp
index ab83f193e3f3..a245fb38c82b 100644
--- a/flang/lib/Evaluate/complex.cpp
+++ b/flang/lib/Evaluate/complex.cpp
@@ -100,6 +100,17 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
return {Complex{re, im}, flags};
}
+template <typename R>
+ValueWithRealFlags<Complex<R>> Complex<R>::KahanSummation(
+ const Complex &that, Complex &correction, Rounding rounding) const {
+ RealFlags flags;
+ Part reSum{re_.KahanSummation(that.re_, correction.re_, rounding)
+ .AccumulateFlags(flags)};
+ Part imSum{im_.KahanSummation(that.im_, correction.im_, rounding)
+ .AccumulateFlags(flags)};
+ return {Complex{reSum, imSum}, flags};
+}
+
template <typename R> std::string Complex<R>::DumpHexadecimal() const {
std::string result{'('};
result += re_.DumpHexadecimal();
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index ae9221f9ce04..a8a24c09774e 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,18 +61,13 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
- auto next{product.value.Subtract(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
+ auto added{sum.KahanSummation(product.value, correction)};
overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ sum = added.value;
} else {
auto added{sum.Add(product.value)};
overflow |= added.flags.test(RealFlag::Overflow);
- sum = std::move(added.value);
+ sum = added.value;
}
} else if constexpr (T::category == TypeCategory::Integer) {
auto product{aElt.MultiplySigned(bElt)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 907d01b005a0..9c591e2ef36e 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -77,13 +77,8 @@ public:
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
- auto next{square.Subtract(correction_, rounding_)};
- overflow_ |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding_)};
+ auto sum{element.KahanSummation(square, correction_, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
- correction_ = sum.value.Subtract(element, rounding_)
- .value.Subtract(next.value, rounding_)
- .value;
element = sum.value;
} else {
auto sum{element.Add(square, rounding_)};
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index fe897393fe13..a06836413529 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,18 +47,13 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{x.Subtract(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
+ auto added{sum.KahanSummation(x, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ sum = added.value;
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- sum = std::move(added.value);
+ sum = added.value;
}
}
} else if constexpr (T::category == TypeCategory::Logical) {
@@ -97,18 +92,13 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{x.Subtract(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
+ auto added{sum.KahanSummation(x, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ sum = added.value;
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- sum = std::move(added.value);
+ sum = added.value;
}
}
}
@@ -357,14 +347,8 @@ public:
} else if constexpr (T::category == TypeCategory::Unsigned) {
element = element.AddUnsigned(array_.At(at)).value;
} else { // Real & Complex: use Kahan summation
- auto next{array_.At(at).Subtract(correction_, rounding_)};
- overflow_ |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding_)};
+ auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
- // correction = (sum - element) - next; algebraically zero
- correction_ = sum.value.Subtract(element, rounding_)
- .value.Subtract(next.value, rounding_)
- .value;
element = sum.value;
}
}
diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp
index 6e6b9f3ac77c..eb335ce32851 100644
--- a/flang/lib/Evaluate/real.cpp
+++ b/flang/lib/Evaluate/real.cpp
@@ -466,6 +466,24 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::MODULO(
}
template <typename W, int P>
+ValueWithRealFlags<Real<W, P>> Real<W, P>::KahanSummation(
+ const Real &y, Real &correction, Rounding rounding) const {
+ Real next{y.Subtract(correction, rounding).value};
+ if (next.IsNotANumber()) {
+ // Avoid propagating an accidental NaN from Inf-Inf in corrections
+ correction = Real{}; // 0.
+ return Add(y, rounding);
+ } else {
+ auto sum{Add(next, rounding)};
+ // correction = (sum - *this) - next; algebraically zero
+ correction = sum.value.Subtract(*this, rounding)
+ .value.Subtract(next, rounding)
+ .value;
+ return sum;
+ }
+}
+
+template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM(
const Real &y, Rounding rounding) const {
ValueWithRealFlags<Real> result;
diff --git a/flang/test/Evaluate/bug89528.f90 b/flang/test/Evaluate/bug89528.f90
new file mode 100644
index 000000000000..281f35bf653c
--- /dev/null
+++ b/flang/test/Evaluate/bug89528.f90
@@ -0,0 +1,8 @@
+!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+!CHECK: REAL :: avoidkahannan = (1._4/0.)
+real :: avoidKahanNaN = sum([1./0., 0.]) ! Inf, not NaN
+!CHECK: REAL :: expectnan1 = (0._4/0.)
+real :: expectNaN1 = sum([1./0., -1./0.])
+!CHECK: REAL :: expectnan2 = (0._4/0.)
+real :: expectNaN2 = sum([-1./0., 1./0.])
+end