diff options
author | Felix Schneider <fx.schn@gmail.com> | 2024-02-18 10:17:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-18 10:17:03 +0100 |
commit | 1a8c6130f60fe517fb722ab4309997ce7b638234 (patch) | |
tree | 798e322f183da2efe1399a3c87209ea63572cdfe | |
parent | 833fea40d22ff7265a8331c88a7dc3b32a84c6a8 (diff) |
[mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (#82133)
This patch aligns the shift Ops in `arith` with respective LLVM instructions.
Specifically, shifting by an amount equal to the bitwidth of the operand
is now defined to return poison.
Relevant discussion:
https://discourse.llvm.org/t/some-question-on-the-semantics-of-the-arith-dialect/74861/10
Relevant issue: https://github.com/llvm/llvm-project/issues/80960
-rw-r--r-- | mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 12 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/canonicalize.mlir | 33 |
3 files changed, 44 insertions, 11 deletions
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 4babbe80e285..c9df50d0395d 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> { The `shli` operation shifts the integer value of the first operand to the left by the integer value of the second operand. The second operand is interpreted as unsigned. The low order bits are filled with zeros. If the value of the second - operand is greater than the bitwidth of the first operand, then the + operand is greater or equal than the bitwidth of the first operand, then the operation returns poison. This op supports `nuw`/`nsw` overflow flags which stands stand for @@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { The `shrui` operation shifts an integer value of the first operand to the right by the value of the second operand. The first operand is interpreted as unsigned, and the second operand is interpreted as unsigned. The high order bits are always - filled with zeros. If the value of the second operand is greater than the bitwidth - of the first operand, then the operation returns poison. + filled with zeros. If the value of the second operand is greater or equal than the + bitwidth of the first operand, then the operation returns poison. Example: @@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { and the second operand is interpreter as unsigned. The high order bits in the output are filled with copies of the most-significant bit of the shifted value (which means that the sign of the value is preserved). If the value of the second - operand is greater than bitwidth of the first operand, then the operation returns - poison. + operand is greater or equal than bitwidth of the first operand, then the operation + returns poison. Example: diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 275c2debe9a6..0f71c19c23b6 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { // shli(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); - // Don't fold if shifting more than the bit width. + // Don't fold if shifting more or equal than the bit width. bool bounded = false; auto result = constFoldBinaryOp<IntegerAttr>( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { - bounded = b.ule(b.getBitWidth()); + bounded = b.ult(b.getBitWidth()); return a.shl(b); }); return bounded ? result : Attribute(); @@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { // shrui(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); - // Don't fold if shifting more than the bit width. + // Don't fold if shifting more or equal than the bit width. bool bounded = false; auto result = constFoldBinaryOp<IntegerAttr>( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { - bounded = b.ule(b.getBitWidth()); + bounded = b.ult(b.getBitWidth()); return a.lshr(b); }); return bounded ? result : Attribute(); @@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { // shrsi(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); - // Don't fold if shifting more than the bit width. + // Don't fold if shifting more or equal than the bit width. bool bounded = false; auto result = constFoldBinaryOp<IntegerAttr>( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { - bounded = b.ule(b.getBitWidth()); + bounded = b.ult(b.getBitWidth()); return a.ashr(b); }); return bounded ? result : Attribute(); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index f128b13e9f73..cb98a10048a3 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 { return %r : i64 } +// CHECK-LABEL: @nofoldShl3( +// CHECK: %[[res:.+]] = arith.shli +// CHECK: return %[[res]] +func.func @nofoldShl3() -> i64 { + %c1 = arith.constant 1 : i64 + %c64 = arith.constant 64 : i64 + // Note that this should return Poison in the future. + %r = arith.shli %c1, %c64 : i64 + return %r : i64 +} + // CHECK-LABEL: @foldShru( // CHECK: %[[res:.+]] = arith.constant 2 : i64 // CHECK: return %[[res]] @@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 { return %r : i64 } +// CHECK-LABEL: @nofoldShru3( +// CHECK: %[[res:.+]] = arith.shrui +// CHECK: return %[[res]] +func.func @nofoldShru3() -> i64 { + %c1 = arith.constant 8 : i64 + %c64 = arith.constant 64 : i64 + // Note that this should return Poison in the future. + %r = arith.shrui %c1, %c64 : i64 + return %r : i64 +} + // CHECK-LABEL: @foldShrs( // CHECK: %[[res:.+]] = arith.constant 2 : i64 // CHECK: return %[[res]] @@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 { return %r : i64 } +// CHECK-LABEL: @nofoldShrs3( +// CHECK: %[[res:.+]] = arith.shrsi +// CHECK: return %[[res]] +func.func @nofoldShrs3() -> i64 { + %c1 = arith.constant 8 : i64 + %c64 = arith.constant 64 : i64 + // Note that this should return Poison in the future. + %r = arith.shrsi %c1, %c64 : i64 + return %r : i64 +} + // ----- // CHECK-LABEL: @test_negf( |