diff options
author | Diego Caballero <diegocaballero@google.com> | 2024-02-14 11:38:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-14 11:38:52 -0800 |
commit | d592c8ec8f7138dcbde6f0890d048e59cba95041 (patch) | |
tree | c1637af86795b5156d08beeeacda8f029da5778a | |
parent | 9b80ab4332bbe336ab8b9f2082eadf6b8d223150 (diff) |
Reapply "[mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes." (#80712) (#81778)
This reverts commit b4c7152eb4f7971c111e3e2f60b55892def58d5d.
Downstream regression due to another issue that this PR exposes. We have identified the work-items to fix the new issue here: https://github.com/openxla/iree/issues/16406
Co-authored-by: Han-Chung Wang <hanchung@google.com>
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 14 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir | 19 |
2 files changed, 27 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 53ae138d1e43..74dd1dfaca0d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1237,7 +1237,7 @@ class DropInnerMostUnitDimsTransferRead return failure(); auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType()); - if (!srcType || !srcType.hasStaticShape()) + if (!srcType) return failure(); if (!readOp.getPermutationMap().isMinorIdentity()) @@ -1261,19 +1261,21 @@ class DropInnerMostUnitDimsTransferRead targetType.getElementType()); auto loc = readOp.getLoc(); + SmallVector<OpFoldResult> sizes = + memref::getMixedSizes(rewriter, loc, readOp.getSource()); + SmallVector<OpFoldResult> offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector<OpFoldResult> strides(srcType.getRank(), + rewriter.getIndexAttr(1)); MemRefType resultMemrefType = getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); - SmallVector<int64_t> offsets(srcType.getRank(), 0); - SmallVector<int64_t> strides(srcType.getRank(), 1); - ArrayAttr inBoundsAttr = readOp.getInBounds() ? rewriter.getArrayAttr( readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) : ArrayAttr(); Value rankedReducedView = rewriter.create<memref::SubViewOp>( - loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), - strides); + loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); Value result = rewriter.create<vector::TransferReadOp>( diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index 750879df129b..3984f17f9e8c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, // ----- +func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> + return %0 : vector<1x8x1xf32> +} +// CHECK: func @contiguous_outer_dyn_inner_most_view( +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]] +// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1] +// CHECK-SAME: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>> +// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]] +// CHECK-SAME: memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32> +// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] +// CHECK: return %[[RESULT]] + +// ----- + func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 |