summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHan-Chung Wang <hanhan0912@gmail.com>2024-04-03 17:00:56 -0700
committerGitHub <noreply@github.com>2024-04-03 17:00:56 -0700
commitef5a7109116c1615a9c99c8dba6577853beb6c73 (patch)
tree709a737538e30a9ac3f7836502b471f838d70d2e
parent622851a9059694487811a7f6078312fc2cce5486 (diff)
[mlir][vector] Skip 0D vectors in vector linearization. (#87577)
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp3
-rw-r--r--mlir/test/Dialect/Vector/linearize.mlir10
2 files changed, 13 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4fa5b8a4865b..b59e9062e5a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
return false;
+ // There are no dimension to fold if it is a 0-D vector.
+ if (vecType.getRank() == 0)
+ return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
if (trailingVecDimBitWidth >= targetBitWidth)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index f0e9b3a05c06..212541c79565 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -146,6 +146,16 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x
// -----
+// ALL-LABEL: func.func @test_0d_vector
+func.func @test_0d_vector() -> vector<f32> {
+ // ALL: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
+ %0 = arith.constant dense<0.0> : vector<f32>
+ // ALL: return %[[CST]]
+ return %0 : vector<f32>
+}
+
+// -----
+
func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
// expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
%0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>