author    Vivian <zhyuhang88@gmail.com>    2023-12-01 16:36:03 -0800
committer GitHub <noreply@github.com>      2023-12-01 16:36:03 -0800
commit    fc74db466b0d2b87d2013d5e24be137f0d8b6f0a (patch)
tree      1b4eabc124775941a10cacffe7525bc58af5490a
parent    005c83380a907becbf5a6b4522fc43652c9536cd (diff)
[mlir][Linalg] Fix foldFillPackIntoFillOp to work for general cases (#74148)
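Previously the replacement fill was cloned at the original fill's position, so the pack destination either had to come from a tensor.empty or had to dominate the fill; a destination defined after the fill, such as the bufferization.to_tensor in the added test, failed that check. The rewrite now creates the new linalg.fill at the pack itself, directly on the pack destination. A rough before/after sketch of the canonicalization, using illustrative names and shapes rather than anything taken from this patch:

    // Before: the fill result is only consumed by the pack, and the pack's
    // padding value (if any) matches the fill value.
    %cst   = arith.constant 0.0 : f32
    %empty = tensor.empty() : tensor<16x64xf32>
    %dest  = tensor.empty() : tensor<1x4x16x16xf32>
    %fill  = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x64xf32>) -> tensor<16x64xf32>
    %pack  = tensor.pack %fill inner_dims_pos = [0, 1] inner_tiles = [16, 16]
               into %dest : tensor<16x64xf32> -> tensor<1x4x16x16xf32>

    // After: a single fill on the pack destination replaces the fill + pack pair.
    %res   = linalg.fill ins(%cst : f32) outs(%dest : tensor<1x4x16x16xf32>) -> tensor<1x4x16x16xf32>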
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 18
-rw-r--r--  mlir/test/Dialect/Linalg/canonicalize.mlir | 19
2 files changed, 21 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e..9a4d5e8845b2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -765,26 +765,12 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
return failure();
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(fillOp);
-
Value packOpDest = packOp.getDest();
if (!packOpDest.hasOneUse())
return failure();
- if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
- packOpDest = tensor::PackOp::createDestinationTensor(
- rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
- packOp.getMixedTiles(), packOp.getInnerDimsPos(),
- packOp.getOuterDimsPerm());
- } else {
- DominanceInfo dom(fillOp);
- if (!dom.properlyDominates(packOpDest, fillOp))
- return failure();
- }
- Value fillDest = packOpDest;
- return clone(rewriter, fillOp, packOpDest.getType(),
- {fillOp.value(), fillDest});
+ return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
+ packOp.getDest());
}
/// Wrapper pattern that applies foldFillPackIntoFillOp method.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e875bae47309..052dc367ca67 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -368,6 +368,25 @@ func.func @fill_pack() -> tensor<24x32x16x16xf32> {
// -----
+func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
+ %c0_i32 = arith.constant 0 : i32
+ %alloc = memref.alloc() : memref<1x1x8x4x4x8xi32>
+ %9 = tensor.empty() : tensor<1x1x16x64xi32>
+ %extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32>
+ %16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32>
+ %0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32>
+ %pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32>
+ return %pack_18 : tensor<1x1x8x4x4x8xi32>
+}
+
+// CHECK-LABEL: func.func @fill_pack_general
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32>
+// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]]
+// CHECK: return %[[FILL]]
+
+// -----
+
#map = affine_map<()[s0] -> (s0 ceildiv 16)>
func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
%cst = arith.constant 0.000000e+00 : f32