diff options
Diffstat (limited to 'mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 48 |
1 files changed, 8 insertions, 40 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index a2d4e2166331..6f6b6dcdad20 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1255,42 +1255,6 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { return result; } -/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from -/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is -/// two, it returns memref<512x16x16> type. -static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder, - MemRefType srcType, - size_t dimsToDrop) { - MemRefLayoutAttrInterface layout = srcType.getLayout(); - if (isa<AffineMapAttr>(layout) && layout.isIdentity()) { - return MemRefType::get(srcType.getShape().drop_back(dimsToDrop), - srcType.getElementType(), nullptr, - srcType.getMemorySpace()); - } - MemRefLayoutAttrInterface updatedLayout; - if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) { - auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); - updatedLayout = StridedLayoutAttr::get(strided.getContext(), - strided.getOffset(), strides); - return MemRefType::get(srcType.getShape().drop_back(dimsToDrop), - srcType.getElementType(), updatedLayout, - srcType.getMemorySpace()); - } - - // Non-strided layout case. - AffineMap map = srcType.getLayout().getAffineMap(); - int numSymbols = map.getNumSymbols(); - for (size_t i = 0; i < dimsToDrop; ++i) { - int dim = srcType.getRank() - i - 1; - map = map.replace(builder.getAffineDimExpr(dim), - builder.getAffineConstantExpr(0), map.getNumDims() - 1, - numSymbols); - } - return MemRefType::get(srcType.getShape().drop_back(dimsToDrop), - srcType.getElementType(), updatedLayout, - srcType.getMemorySpace()); -} - /// Drop inner most contiguous unit dimensions from transfer_read operand. class DropInnerMostUnitDimsTransferRead : public OpRewritePattern<vector::TransferReadOp> { @@ -1337,8 +1301,10 @@ class DropInnerMostUnitDimsTransferRead rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(srcType.getRank(), rewriter.getIndexAttr(1)); - MemRefType resultMemrefType = - getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); + auto resultMemrefType = + cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( + srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, + strides)); ArrayAttr inBoundsAttr = readOp.getInBounds() ? rewriter.getArrayAttr( @@ -1421,8 +1387,10 @@ class DropInnerMostUnitDimsTransferWrite rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(srcType.getRank(), rewriter.getIndexAttr(1)); - MemRefType resultMemrefType = - getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); + auto resultMemrefType = + cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( + srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, + strides)); ArrayAttr inBoundsAttr = writeOp.getInBounds() ? rewriter.getArrayAttr( |