diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2024-02-17 08:47:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-17 08:47:10 +0000 |
commit | 9478bf0ce625a5845139b0c9e3bb41ef88d2f005 (patch) | |
tree | 3971fdbf65d07160fe7c98df7325f9802d8942e1 | |
parent | 44436a9c6b8d92500f0d94bb2e0df9029d735de6 (diff) |
[mlir] Introduce `trailingNDimsContiguous` for MemRefs (#78247)
Extracts logic from `vector::isContiguousSlice` to check whether
the trailing dim of a memref are contiguous into a dedicated hook
in BuiitinTypes.{h|cpp}.
Follow-up for https://github.com/llvm/llvm-project/pull/76848.
-rw-r--r-- | mlir/include/mlir/IR/BuiltinTypes.h | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 33 | ||||
-rw-r--r-- | mlir/lib/IR/BuiltinTypes.cpp | 32 |
3 files changed, 46 insertions, 29 deletions
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 92ce053ad5c8..2361cf137123 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -518,6 +518,16 @@ bool isStrided(MemRefType t); /// stride. Also return "true" for types with no strides. bool isLastMemrefDimUnitStride(MemRefType type); +/// Return "true" if the last N dimensions of the given type are contiguous. +/// +/// Examples: +/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when +/// considering both _all_ and _only_ the trailing 3 dims, +/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when +/// considering the trailing 3 dims. +/// +bool trailingNDimsContiguous(MemRefType type, int64_t n); + } // namespace mlir #endif // MLIR_IR_BUILTINTYPES_H diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 377f3d8c5574..cfa4a6e93a4a 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) { ArrayRef<int64_t> vectorShape = vectorType.getShape(); auto vecRank = vectorType.getRank(); - // Extract the trailing dims and strides of the input memref - auto memrefShape = memrefType.getShape().take_back(vecRank); - int64_t offset; - SmallVector<int64_t> stridesFull; - if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset))) - return false; - auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank); - memrefType.getLayout().isIdentity(); - - // TODO: Add support for memref with trailing dynamic shapes. Memrefs - // with leading dynamic dimensions are already supported. - if (ShapedType::isDynamicShape(memrefShape)) + if (!trailingNDimsContiguous(memrefType, vecRank)) return false; - // Cond 1: Check whether `memrefType` is contiguous. - if (!strides.empty()) { - // Cond 1.1: A contiguous memref will always have a unit trailing stride. - if (strides.back() != 1) - return false; - - // Cond 1.2: Strides of a contiguous memref have to match the flattened - // dims. - strides = strides.drop_back(1); - SmallVector<int64_t> flattenedDims; - for (size_t i = 1; i < memrefShape.size(); i++) - flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i))); - - if (!llvm::equal(strides, llvm::reverse(flattenedDims))) - return false; - } + // Extract the trailing dims and strides of the input memref + auto memrefShape = memrefType.getShape().take_back(vecRank); - // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse). + // Compare the dims of `vectorType` against `memrefType` (in reverse). // In the most basic case, all dims will match. auto firstNonMatchingDim = std::mismatch(vectorShape.rbegin(), vectorShape.rend(), diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 1794b38478a7..a2738946de41 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -967,3 +967,35 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) { auto successStrides = getStridesAndOffset(type, strides, offset); return succeeded(successStrides) && (strides.empty() || strides.back() == 1); } + +bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) { + if (!isLastMemrefDimUnitStride(type)) + return false; + + auto memrefShape = type.getShape().take_back(n); + if (ShapedType::isDynamicShape(memrefShape)) + return false; + + if (type.getLayout().isIdentity()) + return true; + + int64_t offset; + SmallVector<int64_t> stridesFull; + if (!succeeded(getStridesAndOffset(type, stridesFull, offset))) + return false; + auto strides = ArrayRef<int64_t>(stridesFull).take_back(n); + + if (strides.empty()) + return true; + + // Check whether strides match "flattened" dims. + SmallVector<int64_t> flattenedDims; + auto dimProduct = 1; + for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { + dimProduct *= dim; + flattenedDims.push_back(dimProduct); + } + + strides = strides.drop_back(1); + return llvm::equal(strides, llvm::reverse(flattenedDims)); +} |