diff options
author | Peter Klausler <35819229+klausler@users.noreply.github.com> | 2024-05-01 14:06:32 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-01 14:06:32 -0700 |
commit | 3502d340c9276f1828da9db72f83e5e25b163b8b (patch) | |
tree | bc14977a08dcc34ca600063a4eee069da7b23da8 | |
parent | a1c12794226ffde0a84c96b9188a266eafd85fb3 (diff) |
[flang] Adjust transformational folding to match runtime (#90132)
The transformational intrinsic functions MATMUL, DOT_PRODUCT, and NORM2
all involve summing up intermediate products into accumulators. In the
constant folding library, this is done with extended precision Kahan
summation for REAL and COMPLEX arguments, but in the runtime
implementations it is not, and this leads to discrepancies between
folded results and dynamic results.
Disable the use of Kahan summation in folding to resolve these
discrepancies, but don't discard the code, in case we want to add Kahan
summation in the runtime for some or all of these intrinsic functions.
-rw-r--r-- | flang/lib/Evaluate/fold-implementation.h | 6 | ||||
-rw-r--r-- | flang/lib/Evaluate/fold-matmul.h | 27 | ||||
-rw-r--r-- | flang/lib/Evaluate/fold-real.cpp | 33 | ||||
-rw-r--r-- | flang/lib/Evaluate/fold-reduction.h | 48 |
4 files changed, 71 insertions, 43 deletions
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index 093f26bea1a4..2c0e0883207e 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -45,6 +45,12 @@ namespace Fortran::evaluate { +// Don't use Kahan extended precision summation any more when folding +// transformational intrinsic functions other than SUM, since it is +// not used in the runtime implementations of those functions and we +// want results to match. +static constexpr bool useKahanSummation{false}; + // Utilities template <typename T> class Folder { public: diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h index 27b6db1fd8bf..bd61969a822c 100644 --- a/flang/lib/Evaluate/fold-matmul.h +++ b/flang/lib/Evaluate/fold-matmul.h @@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) { Element bElt{mb->At(bAt)}; if constexpr (T::category == TypeCategory::Real || T::category == TypeCategory::Complex) { - // Kahan summation - auto product{aElt.Multiply(bElt, rounding)}; + auto product{aElt.Multiply(bElt)}; overflow |= product.flags.test(RealFlag::Overflow); - auto next{correction.Add(product.value, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; - overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + if constexpr (useKahanSummation) { + auto next{correction.Add(product.value, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else { + auto added{sum.Add(product.value)}; + overflow |= added.flags.test(RealFlag::Overflow); + sum = std::move(added.value); + } } else if constexpr (T::category == TypeCategory::Integer) { + // Don't use Kahan summation in numeric MATMUL folding; + // the runtime doesn't use it, and results should match. auto product{aElt.MultiplySigned(bElt)}; overflow |= product.SignedMultiplicationOverflowed(); auto added{sum.AddSigned(product.lower)}; diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index fd37437c643a..4df709d3d2c2 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -54,7 +54,7 @@ public: : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}; void operator()( Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { - // Kahan summation of scaled elements: + // Summation of scaled elements: // Naively, // NORM2(A(:)) = SQRT(SUM(A(:)**2)) // For any T > 0, we have mathematically @@ -76,24 +76,27 @@ public: auto item{array_.At(at)}; auto scaled{item.Divide(scale).value}; auto square{scaled.Multiply(scaled).value}; - auto next{square.Add(correction_, rounding_)}; - overflow_ |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding_)}; - overflow_ |= sum.flags.test(RealFlag::Overflow); - correction_ = sum.value.Subtract(element, rounding_) - .value.Subtract(next.value, rounding_) - .value; - element = sum.value; + if constexpr (useKahanSummation) { + auto next{square.Add(correction_, rounding_)}; + overflow_ |= next.flags.test(RealFlag::Overflow); + auto sum{element.Add(next.value, rounding_)}; + overflow_ |= sum.flags.test(RealFlag::Overflow); + correction_ = sum.value.Subtract(element, rounding_) + .value.Subtract(next.value, rounding_) + .value; + element = sum.value; + } else { + auto sum{element.Add(square, rounding_)}; + overflow_ |= sum.flags.test(RealFlag::Overflow); + element = sum.value; + } } } bool overflow() const { return overflow_; } void Done(Scalar<T> &result) { - // result+correction == SUM((data(:)/maxAbs)**2) - // result = maxAbs * SQRT(result+correction) - auto corrected{result.Add(correction_, rounding_)}; - overflow_ |= corrected.flags.test(RealFlag::Overflow); - correction_ = Scalar<T>{}; - auto root{corrected.value.SQRT().value}; + // incoming result = SUM((data(:)/maxAbs)**2) + // outgoing result = maxAbs * SQRT(result) + auto root{result.SQRT().value}; auto product{root.Multiply(maxAbs_.At(maxAbsAt_))}; maxAbs_.IncrementSubscripts(maxAbsAt_); overflow_ |= product.flags.test(RealFlag::Overflow); diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index c84d35734ab5..ae17770dc296 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct( Expr<T> products{Fold( context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})}; Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; - Element correction{}; // Use Kahan summation for greater precision. + [[maybe_unused]] Element correction{}; const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { - auto next{correction.Add(x, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; - overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + if constexpr (useKahanSummation) { + auto next{correction.Add(x, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else { + auto added{sum.Add(x, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + sum = std::move(added.value); + } } } else if constexpr (T::category == TypeCategory::Logical) { Expr<T> conjunctions{Fold(context, @@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct( Expr<T> products{ Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; - Element correction{}; // Use Kahan summation for greater precision. + [[maybe_unused]] Element correction{}; const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { - auto next{correction.Add(x, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; - overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + if constexpr (useKahanSummation) { + auto next{correction.Add(x, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else { + auto added{sum.Add(x, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + sum = std::move(added.value); + } } } if (overflow) { |