summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp')
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp48
1 files changed, 8 insertions, 40 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a2d4e2166331..6f6b6dcdad20 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1255,42 +1255,6 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
return result;
}
-/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
-/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
-/// two, it returns memref<512x16x16> type.
-static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
- MemRefType srcType,
- size_t dimsToDrop) {
- MemRefLayoutAttrInterface layout = srcType.getLayout();
- if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
- return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
- srcType.getElementType(), nullptr,
- srcType.getMemorySpace());
- }
- MemRefLayoutAttrInterface updatedLayout;
- if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
- auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
- updatedLayout = StridedLayoutAttr::get(strided.getContext(),
- strided.getOffset(), strides);
- return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
- srcType.getElementType(), updatedLayout,
- srcType.getMemorySpace());
- }
-
- // Non-strided layout case.
- AffineMap map = srcType.getLayout().getAffineMap();
- int numSymbols = map.getNumSymbols();
- for (size_t i = 0; i < dimsToDrop; ++i) {
- int dim = srcType.getRank() - i - 1;
- map = map.replace(builder.getAffineDimExpr(dim),
- builder.getAffineConstantExpr(0), map.getNumDims() - 1,
- numSymbols);
- }
- return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
- srcType.getElementType(), updatedLayout,
- srcType.getMemorySpace());
-}
-
/// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
: public OpRewritePattern<vector::TransferReadOp> {
@@ -1337,8 +1301,10 @@ class DropInnerMostUnitDimsTransferRead
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
- MemRefType resultMemrefType =
- getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+ auto resultMemrefType =
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+ strides));
ArrayAttr inBoundsAttr =
readOp.getInBounds()
? rewriter.getArrayAttr(
@@ -1421,8 +1387,10 @@ class DropInnerMostUnitDimsTransferWrite
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
- MemRefType resultMemrefType =
- getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+ auto resultMemrefType =
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+ strides));
ArrayAttr inBoundsAttr =
writeOp.getInBounds()
? rewriter.getArrayAttr(