diff options
author | Han-Chung Wang <hanhan0912@gmail.com> | 2024-04-03 17:00:56 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-03 17:00:56 -0700 |
commit | ef5a7109116c1615a9c99c8dba6577853beb6c73 (patch) | |
tree | 709a737538e30a9ac3f7836502b471f838d70d2e | |
parent | 622851a9059694487811a7f6078312fc2cce5486 (diff) |
[mlir][vector] Skip 0D vectors in vector linearization. (#87577)
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/linearize.mlir | 10 |
2 files changed, 13 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 4fa5b8a4865b..b59e9062e5a0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { // Reject index since getElementTypeBitWidth will abort for Index types. if (!vecType || vecType.getElementType().isIndex()) return false; + // There are no dimension to fold if it is a 0-D vector. + if (vecType.getRank() == 0) + return false; unsigned trailingVecDimBitWidth = vecType.getShape().back() * vecType.getElementTypeBitWidth(); if (trailingVecDimBitWidth >= targetBitWidth) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index f0e9b3a05c06..212541c79565 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -146,6 +146,16 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x // ----- +// ALL-LABEL: func.func @test_0d_vector +func.func @test_0d_vector() -> vector<f32> { + // ALL: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<f32> + %0 = arith.constant dense<0.0> : vector<f32> + // ALL: return %[[CST]] + return %0 : vector<f32> +} + +// ----- + func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> { // expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32> |