summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPeter Klausler <35819229+klausler@users.noreply.github.com>2024-05-01 14:06:32 -0700
committerGitHub <noreply@github.com>2024-05-01 14:06:32 -0700
commit3502d340c9276f1828da9db72f83e5e25b163b8b (patch)
treebc14977a08dcc34ca600063a4eee069da7b23da8
parenta1c12794226ffde0a84c96b9188a266eafd85fb3 (diff)
[flang] Adjust transformational folding to match runtime (#90132)
The transformational intrinsic functions MATMUL, DOT_PRODUCT, and NORM2 all involve summing up intermediate products into accumulators. In the constant folding library, this is done with extended precision Kahan summation for REAL and COMPLEX arguments, but in the runtime implementations it is not, and this leads to discrepancies between folded results and dynamic results. Disable the use of Kahan summation in folding to resolve these discrepancies, but don't discard the code, in case we want to add Kahan summation in the runtime for some or all of these intrinsic functions.
-rw-r--r--flang/lib/Evaluate/fold-implementation.h6
-rw-r--r--flang/lib/Evaluate/fold-matmul.h27
-rw-r--r--flang/lib/Evaluate/fold-real.cpp33
-rw-r--r--flang/lib/Evaluate/fold-reduction.h48
4 files changed, 71 insertions, 43 deletions
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 093f26bea1a4..2c0e0883207e 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -45,6 +45,12 @@
namespace Fortran::evaluate {
+// Don't use Kahan extended precision summation any more when folding
+// transformational intrinsic functions other than SUM, since it is
+// not used in the runtime implementations of those functions and we
+// want results to match.
+static constexpr bool useKahanSummation{false};
+
// Utilities
template <typename T> class Folder {
public:
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index 27b6db1fd8bf..bd61969a822c 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
Element bElt{mb->At(bAt)};
if constexpr (T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex) {
- // Kahan summation
- auto product{aElt.Multiply(bElt, rounding)};
+ auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
- auto next{correction.Add(product.value, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
- overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ if constexpr (useKahanSummation) {
+ auto next{correction.Add(product.value, rounding)};
+ overflow |= next.flags.test(RealFlag::Overflow);
+ auto added{sum.Add(next.value, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ correction = added.value.Subtract(sum, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
+ sum = std::move(added.value);
+ } else {
+ auto added{sum.Add(product.value)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ sum = std::move(added.value);
+ }
} else if constexpr (T::category == TypeCategory::Integer) {
+ // Don't use Kahan summation in numeric MATMUL folding;
+ // the runtime doesn't use it, and results should match.
auto product{aElt.MultiplySigned(bElt)};
overflow |= product.SignedMultiplicationOverflowed();
auto added{sum.AddSigned(product.lower)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index fd37437c643a..4df709d3d2c2 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -54,7 +54,7 @@ public:
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
void operator()(
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
- // Kahan summation of scaled elements:
+ // Summation of scaled elements:
// Naively,
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
// For any T > 0, we have mathematically
@@ -76,24 +76,27 @@ public:
auto item{array_.At(at)};
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
- auto next{square.Add(correction_, rounding_)};
- overflow_ |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding_)};
- overflow_ |= sum.flags.test(RealFlag::Overflow);
- correction_ = sum.value.Subtract(element, rounding_)
- .value.Subtract(next.value, rounding_)
- .value;
- element = sum.value;
+ if constexpr (useKahanSummation) {
+ auto next{square.Add(correction_, rounding_)};
+ overflow_ |= next.flags.test(RealFlag::Overflow);
+ auto sum{element.Add(next.value, rounding_)};
+ overflow_ |= sum.flags.test(RealFlag::Overflow);
+ correction_ = sum.value.Subtract(element, rounding_)
+ .value.Subtract(next.value, rounding_)
+ .value;
+ element = sum.value;
+ } else {
+ auto sum{element.Add(square, rounding_)};
+ overflow_ |= sum.flags.test(RealFlag::Overflow);
+ element = sum.value;
+ }
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &result) {
- // result+correction == SUM((data(:)/maxAbs)**2)
- // result = maxAbs * SQRT(result+correction)
- auto corrected{result.Add(correction_, rounding_)};
- overflow_ |= corrected.flags.test(RealFlag::Overflow);
- correction_ = Scalar<T>{};
- auto root{corrected.value.SQRT().value};
+ // incoming result = SUM((data(:)/maxAbs)**2)
+ // outgoing result = maxAbs * SQRT(result)
+ auto root{result.SQRT().value};
auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
maxAbs_.IncrementSubscripts(maxAbsAt_);
overflow_ |= product.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index c84d35734ab5..ae17770dc296 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
- Element correction{}; // Use Kahan summation for greater precision.
+ [[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
- auto next{correction.Add(x, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
- overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ if constexpr (useKahanSummation) {
+ auto next{correction.Add(x, rounding)};
+ overflow |= next.flags.test(RealFlag::Overflow);
+ auto added{sum.Add(next.value, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ correction = added.value.Subtract(sum, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
+ sum = std::move(added.value);
+ } else {
+ auto added{sum.Add(x, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ sum = std::move(added.value);
+ }
}
} else if constexpr (T::category == TypeCategory::Logical) {
Expr<T> conjunctions{Fold(context,
@@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
- Element correction{}; // Use Kahan summation for greater precision.
+ [[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
- auto next{correction.Add(x, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
- overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ if constexpr (useKahanSummation) {
+ auto next{correction.Add(x, rounding)};
+ overflow |= next.flags.test(RealFlag::Overflow);
+ auto added{sum.Add(next.value, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ correction = added.value.Subtract(sum, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
+ sum = std::move(added.value);
+ } else {
+ auto added{sum.Add(x, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ sum = std::move(added.value);
+ }
}
}
if (overflow) {