summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-02-23 15:52:26 +0100
committerGitHub <noreply@github.com>2024-02-23 15:52:26 +0100
commit3b232f066d40a3e91ac27e421a3baeaca0cd59ec (patch)
treefb67a6143b6bfd7114b64979c4167098356347d7
parentbe083dba95dfbbb0286d798cc06fbe021715bc03 (diff)
[mlir][linalg] `LinalgOp`: Disallow mixed tensor/buffer semantics (#80660)
Related discussion: https://github.com/llvm/llvm-project/pull/73908/files#r1414913030. This change fixes #73547.
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp5
-rw-r--r--mlir/test/Dialect/Linalg/canonicalize.mlir55
-rw-r--r--mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir40
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir10
4 files changed, 29 insertions, 81 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7eed7928456d..3627ff6617ed 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
+ // Mixed tensor/buffer operands are not allowed.
+ if (!linalgOp.hasPureTensorSemantics() &&
+ !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
+ return op->emitOpError("expected to have pure tensor or buffer semantics");
+
// Before checking indexing maps, we need to make sure the attributes
// referenced by it are valid.
if (linalgOp.hasDynamicIndexingMaps())
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7adde3117dee..206d7e9f1ce8 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten
// -----
// CHECK-LABEL: func @linalg_effects(
-// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
-// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
-// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor<?x?xf32>
-func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
+func.func @linalg_effects(
+ %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
+ %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
// CHECK-NOT: %{{.*}} = linalg.matmul
- %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+ %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: linalg.matmul
- linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%b : memref<?x?xf32>)
+ linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%f : memref<?x?xf32>)
return
}
@@ -889,11 +888,11 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
// -----
#map = affine_map<(d0) -> (d0)>
-func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
+func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel"]
- } ins(%arg0 : tensor<?xf32>)
+ } ins(%arg0 : memref<?xf32>)
outs(%arg1 : memref<?xf32>) {
^bb0(%arg2 : f32, %arg3 : f32):
linalg.yield %arg2 : f32
@@ -901,14 +900,13 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
return
}
-// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
-// For now, check generic remained unchanged.
-// CHECK-LABEL: func @identity_mixed
-// CHECK-SAME: (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// Do not erase ops with buffer semantics.
+// CHECK-LABEL: func @identity_buffer
+// CHECK-SAME: (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#map, #map],
// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: } ins(%[[ARG1]] : tensor<?xf32>)
+// CHECK-SAME: } ins(%[[ARG1]] : memref<?xf32>)
// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
// -----
@@ -916,12 +914,12 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
// Just make sure that we don't crash.
// CHECK-LABEL: func @dedeplicate_regression_test
-func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
+func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
%36 = linalg.generic
{indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
- ins(%1, %1 : memref<4xf32>, memref<4xf32>)
+ ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
outs(%0 : tensor<4xf32>) {
^bb0(%in: f32, %in_24: f32, %out: f32):
linalg.yield %in : f32
@@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
// -----
-#map = affine_map<(d0) -> (d0)>
-func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
- %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
- linalg.generic {
- indexing_maps = [#map, #map],
- iterator_types = ["parallel"]
- } ins(%0 : tensor<?xf32>)
- outs(%arg1 : memref<?xf32>) {
- ^bb0(%arg2 : f32, %arg3 : f32):
- linalg.yield %arg2 : f32
- }
- return
-}
-
-// We need a mixed linalg as a bridge between tensor and memref worlds.
-// CHECK-LABEL: func @cast_producer_mixed
-// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
-// CHECK: linalg.generic {
-// CHECK-SAME: indexing_maps = [#map, #map],
-// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
-// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
-
-// -----
-
// CHECK-LABEL: dead_softmax
func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
%0 = tensor.empty() : tensor<16x64x256xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421cbab49..15a4f6cdd3bb 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1110,43 +1110,3 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]
-
-// -----
-
-// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @mixed_fusion
-func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
-{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
- %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%2 : tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
- %4 = arith.addf %arg3, %arg4 : f32
- linalg.yield %4 : f32
- } -> tensor<?x?xf32>
- // CHECK: linalg.generic {
- // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
- linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg8 : memref<?x?xf32>) {
- // CHECK: ^{{[a-zA-Z0-9_]*}}
- // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
- // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
- // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
- ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
- // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
- // CHECK-NOT: linalg.yield
- // CHECK: arith.mulf [[T1]], [[ARG2]]
- // CHECK: linalg.yield
- %5 = arith.mulf %arg5, %arg6 : f32
- linalg.yield %5 : f32
- }
- return
-}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 916c04f33e9c..44c81c31ace0 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
-> tensor<8x8xf32>
return %res : tensor<8x8xf32>
}
+
+// -----
+
+func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) {
+ // expected-error @+1 {{expected to have pure tensor or buffer semantics}}
+ linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
+ return
+}
+