summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorFelix Schneider <fx.schn@gmail.com>2024-02-18 10:17:03 +0100
committerGitHub <noreply@github.com>2024-02-18 10:17:03 +0100
commit1a8c6130f60fe517fb722ab4309997ce7b638234 (patch)
tree798e322f183da2efe1399a3c87209ea63572cdfe
parent833fea40d22ff7265a8331c88a7dc3b32a84c6a8 (diff)
[mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (#82133)
This patch aligns the shift Ops in `arith` with respective LLVM instructions. Specifically, shifting by an amount equal to the bitwidth of the operand is now defined to return poison. Relevant discussion: https://discourse.llvm.org/t/some-question-on-the-semantics-of-the-arith-dialect/74861/10 Relevant issue: https://github.com/llvm/llvm-project/issues/80960
-rw-r--r--mlir/include/mlir/Dialect/Arith/IR/ArithOps.td10
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp12
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir33
3 files changed, 44 insertions, 11 deletions
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4babbe80e285..c9df50d0395d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
The `shli` operation shifts the integer value of the first operand to the left
by the integer value of the second operand. The second operand is interpreted as
unsigned. The low order bits are filled with zeros. If the value of the second
- operand is greater than the bitwidth of the first operand, then the
+ operand is greater or equal than the bitwidth of the first operand, then the
operation returns poison.
This op supports `nuw`/`nsw` overflow flags which stands stand for
@@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
The `shrui` operation shifts an integer value of the first operand to the right
by the value of the second operand. The first operand is interpreted as unsigned,
and the second operand is interpreted as unsigned. The high order bits are always
- filled with zeros. If the value of the second operand is greater than the bitwidth
- of the first operand, then the operation returns poison.
+ filled with zeros. If the value of the second operand is greater or equal than the
+ bitwidth of the first operand, then the operation returns poison.
Example:
@@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
and the second operand is interpreter as unsigned. The high order bits in the
output are filled with copies of the most-significant bit of the shifted value
(which means that the sign of the value is preserved). If the value of the second
- operand is greater than bitwidth of the first operand, then the operation returns
- poison.
+ operand is greater or equal than bitwidth of the first operand, then the operation
+ returns poison.
Example:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 275c2debe9a6..0f71c19c23b6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// shli(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
- // Don't fold if shifting more than the bit width.
+ // Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
- bounded = b.ule(b.getBitWidth());
+ bounded = b.ult(b.getBitWidth());
return a.shl(b);
});
return bounded ? result : Attribute();
@@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// shrui(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
- // Don't fold if shifting more than the bit width.
+ // Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
- bounded = b.ule(b.getBitWidth());
+ bounded = b.ult(b.getBitWidth());
return a.lshr(b);
});
return bounded ? result : Attribute();
@@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// shrsi(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
- // Don't fold if shifting more than the bit width.
+ // Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
- bounded = b.ule(b.getBitWidth());
+ bounded = b.ult(b.getBitWidth());
return a.ashr(b);
});
return bounded ? result : Attribute();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f128b13e9f73..cb98a10048a3 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 {
return %r : i64
}
+// CHECK-LABEL: @nofoldShl3(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func.func @nofoldShl3() -> i64 {
+ %c1 = arith.constant 1 : i64
+ %c64 = arith.constant 64 : i64
+ // Note that this should return Poison in the future.
+ %r = arith.shli %c1, %c64 : i64
+ return %r : i64
+}
+
// CHECK-LABEL: @foldShru(
// CHECK: %[[res:.+]] = arith.constant 2 : i64
// CHECK: return %[[res]]
@@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 {
return %r : i64
}
+// CHECK-LABEL: @nofoldShru3(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func.func @nofoldShru3() -> i64 {
+ %c1 = arith.constant 8 : i64
+ %c64 = arith.constant 64 : i64
+ // Note that this should return Poison in the future.
+ %r = arith.shrui %c1, %c64 : i64
+ return %r : i64
+}
+
// CHECK-LABEL: @foldShrs(
// CHECK: %[[res:.+]] = arith.constant 2 : i64
// CHECK: return %[[res]]
@@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 {
return %r : i64
}
+// CHECK-LABEL: @nofoldShrs3(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func.func @nofoldShrs3() -> i64 {
+ %c1 = arith.constant 8 : i64
+ %c64 = arith.constant 64 : i64
+ // Note that this should return Poison in the future.
+ %r = arith.shrsi %c1, %c64 : i64
+ return %r : i64
+}
+
// -----
// CHECK-LABEL: @test_negf(