summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRafael Ubal <rubal@mathworks.com>2024-04-03 13:49:55 -0400
committerGitHub <noreply@github.com>2024-04-03 13:49:55 -0400
commitfbcd0c65f7b2f65e0ee58e5448b88af39faf10f1 (patch)
tree8687e944da4a0c0aa4ac632c4e8e819070ed6835
parentcd29126b6333c28cc4df7b932ed0d6d6c13983d1 (diff)
Updates to 'tosa.reshape' verifier (#87416)
This addition catches common cases of malformed `tosa.reshape` ops. This prevents the `--tosa-to-tensor` pass from asserting when fed invalid operations, as these will be caught ahead of time by the verifier. Closes #87396
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp17
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir58
2 files changed, 58 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6e6e84350738..e06ac9a27ae4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -955,25 +955,34 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
mlir::LogicalResult tosa::ReshapeOp::verify() {
- ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
- ShapedType outputType = llvm::cast<ShapedType>(getType());
+ TensorType inputType = getInput1().getType();
+ RankedTensorType outputType = getType();
if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
return emitOpError() << "tensor has a dimension with size zero. Each "
"dimension of a tensor must have size >= 1";
+ if ((int64_t) getNewShape().size() != outputType.getRank())
+ return emitOpError() << "new shape does not match result rank";
+
+ for (auto [newShapeDim, outputShapeDim] :
+ zip(getNewShape(), outputType.getShape()))
+ if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
+ newShapeDim != outputShapeDim)
+ return emitOpError() << "new shape is inconsistent with result shape";
+
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
int64_t outputElementsNum = outputType.getNumElements();
if (inputElementsNum != outputElementsNum) {
- return emitOpError() << "Cannot reshape " << inputElementsNum
+ return emitOpError() << "cannot reshape " << inputElementsNum
<< " elements into " << outputElementsNum;
}
}
int missingDims = llvm::count(getNewShape(), -1);
if (missingDims > 1)
- return emitOpError() << "At most one target dimension can be -1";
+ return emitOpError() << "expected at most one target dimension to be -1";
return mlir::success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 38ba48f365ea..730ac41dd7a8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -243,38 +243,70 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// -----
-func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
- // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
- %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
return
}
// -----
-func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
- // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
- %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
- return %0 : tensor<100x100xf32>
+func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
+ return
}
// -----
-func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
- // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
+func.func @test_reshape_rank_mismatch(%arg0 : tensor<?xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op new shape does not match result rank}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4>} : (tensor<?xf32>) -> tensor<?xf32>
return
}
// -----
-func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
- // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
+func.func @test_reshape_inconsistent_result_type(%arg0 : tensor<?xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op new shape is inconsistent with result shape}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4, -1>} : (tensor<?xf32>) -> tensor<?x3x5xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 15}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 3, 5>} : (tensor<2x4xf32>) -> tensor<3x5xf32>
+ return
+}
+
+// -----
+
+func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
+ // expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>
return
}
// -----
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+ // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+ %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+ return
+}
+
+// -----
+
+func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
+ // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
+ %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
+ return %0 : tensor<100x100xf32>
+}
+
+// -----
+
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
// expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}