summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSemyon Khechnev <91785625+s-khechnev@users.noreply.github.com>2024-04-27 08:29:11 +0300
committerGitHub <noreply@github.com>2024-04-27 01:29:11 -0400
commit9145514fde484916971e6bb147c18f9235a9f2b5 (patch)
treedeeffe39f3fe762947993866af5e7a33f316dd01
parent85a9528aa1f2d54379bf972908e12ee2a6f07b4b (diff)
[mlir][arith] fix canonicalization of mulsi_extended for i1 (#90150)
There is the `MulSIExtendedRHSOne` canonicalization for arith.mulsi_extended that is defined as follows: `mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]`. In the implementation of this, there is a `IsScalarOrSplatOne` constraint for the second argument. However, this constraint does not correctly handle situation when multiplying i1 values. Therefore, an additional constraint has been added which checks the second argument for strict positivity. fix #88732
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td1
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir22
2 files changed, 23 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index caca2ff81964..02d05780a7ac 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -175,6 +175,7 @@ def MulSIExtendedToMulI :
def IsScalarOrSplatOne :
Constraint<And<[
CPred<"succeeded(getIntOrSplatIntValue($0))">,
+ CPred<"getIntOrSplatIntValue($0)->isStrictlyPositive()">,
CPred<"*getIntOrSplatIntValue($0) == 1">]>>;
// mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 79a318565e98..6c4193bc06ca 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1223,6 +1223,28 @@ func.func @mulsiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vec
return %low, %high : vector<3xi32>, vector<3xi32>
}
+// CHECK-LABEL: @mulsiExtendedOneRhsI1
+// CHECK-SAME: (%[[ARG:.+]]: i1) -> (i1, i1)
+// CHECK-NEXT: %[[T:.+]] = arith.constant true
+// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[ARG]], %[[T]] : i1
+// CHECK-NEXT: return %[[LOW]], %[[HIGH]] : i1, i1
+func.func @mulsiExtendedOneRhsI1(%arg0: i1) -> (i1, i1) {
+ %one = arith.constant true
+ %low, %high = arith.mulsi_extended %arg0, %one: i1
+ return %low, %high : i1, i1
+}
+
+// CHECK-LABEL: @mulsiExtendedOneRhsSplatI1
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>)
+// CHECK-NEXT: %[[TS:.+]] = arith.constant dense<true> : vector<3xi1>
+// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[ARG]], %[[TS]] : vector<3xi1>
+// CHECK-NEXT: return %[[LOW]], %[[HIGH]] : vector<3xi1>, vector<3xi1>
+func.func @mulsiExtendedOneRhsSplatI1(%arg0: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+ %one = arith.constant dense<true> : vector<3xi1>
+ %low, %high = arith.mulsi_extended %arg0, %one: vector<3xi1>
+ return %low, %high : vector<3xi1>, vector<3xi1>
+}
+
// CHECK-LABEL: @mulsiExtendedUnusedHigh
// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32
// CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32