diff options
author | Vivian <zhyuhang88@gmail.com> | 2023-12-01 16:36:03 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-01 16:36:03 -0800 |
commit | fc74db466b0d2b87d2013d5e24be137f0d8b6f0a (patch) | |
tree | 1b4eabc124775941a10cacffe7525bc58af5490a | |
parent | 005c83380a907becbf5a6b4522fc43652c9536cd (diff) |
[mlir][Linalg] Fix foldFillPackIntoFillOp to work for general cases (#74148)
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/canonicalize.mlir | 19 |
2 files changed, 21 insertions, 16 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 58af9995548e..9a4d5e8845b2 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -765,26 +765,12 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter, if (!isEqualConstantIntOrValue(paddingValue, fillOp.value())) return failure(); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(fillOp); - Value packOpDest = packOp.getDest(); if (!packOpDest.hasOneUse()) return failure(); - if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { - packOpDest = tensor::PackOp::createDestinationTensor( - rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(), - packOp.getMixedTiles(), packOp.getInnerDimsPos(), - packOp.getOuterDimsPerm()); - } else { - DominanceInfo dom(fillOp); - if (!dom.properlyDominates(packOpDest, fillOp)) - return failure(); - } - Value fillDest = packOpDest; - return clone(rewriter, fillOp, packOpDest.getType(), - {fillOp.value(), fillDest}); + return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(), + packOp.getDest()); } /// Wrapper pattern that applies foldFillPackIntoFillOp method. diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index e875bae47309..052dc367ca67 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -368,6 +368,25 @@ func.func @fill_pack() -> tensor<24x32x16x16xf32> { // ----- +func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{ + %c0_i32 = arith.constant 0 : i32 + %alloc = memref.alloc() : memref<1x1x8x4x4x8xi32> + %9 = tensor.empty() : tensor<1x1x16x64xi32> + %extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32> + %16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32> + %0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32> + %pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32> + return %pack_18 : tensor<1x1x8x4x4x8xi32> +} + +// CHECK-LABEL: func.func @fill_pack_general +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32> +// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]] +// CHECK: return %[[FILL]] + +// ----- + #map = affine_map<()[s0] -> (s0 ceildiv 16)> func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> { %cst = arith.constant 0.000000e+00 : f32 |