diff options
author | Jakub Kuderski <jakub@nod-labs.com> | 2024-01-31 20:28:01 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-31 20:28:01 -0500 |
commit | 730f498c961f29691a605028f9b1cd6d9e232460 (patch) | |
tree | 68452154766530bca5cdfda3a1fcfe6d62c919b6 | |
parent | 8ba018d72a67050a9c37065ea2668814ebf513a9 (diff) |
[mlir][arith] Improve `truncf` folding (#80206)
* Use APFloat conversion function instead of going through double to
check if fold results in information loss.
* Support folding vector constants.
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 32 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/canonicalize.mlir | 9 |
2 files changed, 25 insertions, 16 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index ff72becc8dfa..270df3f3f9e9 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -22,8 +22,10 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" @@ -1393,23 +1395,21 @@ LogicalResult arith::TruncIOp::verify() { // TruncFOp //===----------------------------------------------------------------------===// -/// Perform safe const propagation for truncf, i.e. only propagate if FP value -/// can be represented without precision loss or rounding. +/// Perform safe const propagation for truncf, i.e., only propagate if FP value +/// can be represented without precision loss or rounding. This is because the +/// semantics of `arith.truncf` do not assume a specific rounding mode. OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { - auto constOperand = adaptor.getIn(); - if (!constOperand || !llvm::isa<FloatAttr>(constOperand)) - return {}; - - // Convert to target type via 'double'. - double sourceValue = - llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble(); - auto targetAttr = FloatAttr::get(getType(), sourceValue); - - // Propagate if constant's value does not change after truncation. - if (sourceValue == targetAttr.getValue().convertToDouble()) - return targetAttr; - - return {}; + auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); + return constFoldCastOp<FloatAttr, FloatAttr>( + adaptor.getOperands(), getType(), + [&targetSemantics](APFloat a, bool &castStatus) { + bool losesInfo = false; + auto status = a.convert( + targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo); + castStatus = !losesInfo && status == APFloat::opOK; + return a; + }); } bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 10050d87d756..44df11ab2433 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -825,6 +825,15 @@ func.func @truncFPConstant() -> bf16 { return %0 : bf16 } +// CHECK-LABEL: @truncFPVectorConstant +// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16> +// CHECK: return %[[cres]] +func.func @truncFPVectorConstant() -> vector<2xbf16> { + %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf32> + %0 = arith.truncf %cst : vector<2xf32> to vector<2xbf16> + return %0 : vector<2xbf16> +} + // Test that cases with rounding are NOT propagated // CHECK-LABEL: @truncFPConstantRounding // CHECK: arith.constant 1.444000e+25 : f32 |