summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorQuinn Dawkins <quinn.dawkins@gmail.com>2023-12-01 15:05:29 -0500
committerGitHub <noreply@github.com>2023-12-01 15:05:29 -0500
commitf310a5d2c13455f1d68f5654fa4258357bafeff6 (patch)
treefacec77faa031c5aad059607d2d9d46d899e1a8e
parent4c44dcffd5f1557bde2c21773221081437308895 (diff)
[mlir][tensor] Add a tensor.concat operation (#72779)
This adds an operation for concatenating ranked tensors along a static dimension, as well as a decomposition mirroring the existing lowering from TOSA to Tensor. This offers a convergence point for "input" like dialects that include various lowerings for concatenation operations, easing later analysis. In the future, this op can implement the necessary interfaces for tiling, as well as potentially add conversions to some kind of linalg and/or memref counterpart. This patch adds the op, the decomposition, and some basic folding/canonicalization. Replacing lowerings with the op (such as the TOSA lowering) will come as a follow up. See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
-rw-r--r--mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td64
-rw-r--r--mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td12
-rw-r--r--mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h7
-rw-r--r--mlir/include/mlir/Dialect/Utils/StaticValueUtils.h33
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp76
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp186
-rw-r--r--mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp5
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp93
-rw-r--r--mlir/test/Dialect/Tensor/canonicalize.mlir12
-rw-r--r--mlir/test/Dialect/Tensor/decompose-concat.mlir57
-rw-r--r--mlir/test/Dialect/Tensor/invalid.mlir48
-rw-r--r--mlir/test/Dialect/Tensor/ops.mlir17
13 files changed, 554 insertions, 58 deletions
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7ae27407a952..f50e3464867b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -122,6 +122,70 @@ def Tensor_CastOp : Tensor_Op<"cast", [
}
//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ConcatOp : Tensor_Op<"concat",
+ [Pure,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ let summary = "tensor concatenation operation";
+ let description = [{
+ The "concat" operation constructs a tensor out of a variadic list of input
+ tensors, concatenated along a static dimension number. All inputs and the
+ result type must share the same rank.
+
+ `dim` specifies the dimension along which to concatenate. The size of the
+ concatenated dimension in the result must be equal to the sum of the sizes
+ of the inputs along that dimension. All other dimensions in both the inputs
+ and result must be the same size.
+
+ Example:
+
+ ```mlir
+ %0 = tensor.concat dim(0) %0, %1, %2 :
+ (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32) -> tensor<7x6xf32>
+
+ // Dynamic + dynamic -> static
+ %0 = tensor.concat dim(1) %0, %1, %2 :
+ (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32) -> tensor<3x10xf32>
+ ```
+ }];
+ let arguments = (ins I64Attr:$dim,
+ Variadic<AnyRankedTensor>:$inputs);
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ `dim` `(` $dim `)` $inputs attr-dict
+ `:` functional-type(operands, results)
+ }];
+
+ let builders = [
+ // Builder with an inferred result type.
+ OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
+ ];
+
+ let extraClassDeclaration = [{
+ // Helper to infer the concatenated result type for the given list of input
+ // types, being concatenated along `dim`. Because concatenation can specify
+ // more static information than can automatically be inferred,
+ // InferTypeOpInterface is not used.
+ static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);
+
+ RankedTensorType getResultType() {
+ return ::llvm::cast<RankedTensorType>(getResult().getType());
+ }
+
+ int64_t getRank() {
+ return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+ }
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 66c6021418b4..8556d9570fd1 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.tensor.decompose_concat",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that tensor.concat ops should be decomposed into a chain of
+ tensor.insert_slice operations inserting into a materialized destination.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 705b30e7ded4..44b8377bd6aa 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,13 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
bool foldSingleUseOnly = false);
+/// Populates `patterns` with patterns that decompose `tensor.concat` into
+/// `tensor.empty` of a tensor of the concatenated size, followed by a chain
+/// of `tensor.insert_slice` operations on the inputs. This is intended to be
+/// used as a fallback tensor -> tensor lowering that decomposes concat such
+/// that it can be bufferized into a sequence of copies.
+void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index c2fbaea726ab..502ab93ddbfa 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);
+/// Idiomatic saturated operations on values like offsets, sizes, and strides.
+struct SaturatedInteger {
+ static SaturatedInteger wrap(int64_t v) {
+ return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
+ : SaturatedInteger{false, v};
+ }
+ int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
+ FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
+ if (saturated && !other.saturated)
+ return other;
+ if (!saturated && !other.saturated && v != other.v)
+ return failure();
+ return *this;
+ }
+ bool operator==(SaturatedInteger other) {
+ return (saturated && other.saturated) ||
+ (!saturated && !other.saturated && v == other.v);
+ }
+ bool operator!=(SaturatedInteger other) { return !(*this == other); }
+ SaturatedInteger operator+(SaturatedInteger other) {
+ if (saturated || other.saturated)
+ return SaturatedInteger{true, 0};
+ return SaturatedInteger{false, other.v + v};
+ }
+ SaturatedInteger operator*(SaturatedInteger other) {
+ if (saturated || other.saturated)
+ return SaturatedInteger{true, 0};
+ return SaturatedInteger{false, other.v * v};
+ }
+ bool saturated = true;
+ int64_t v = 0;
+};
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07f..dce96cca016f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,43 +26,6 @@
using namespace mlir;
using namespace mlir::memref;
-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
- static Wrapper stride(int64_t v) {
- return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
- }
- static Wrapper offset(int64_t v) {
- return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
- }
- static Wrapper size(int64_t v) {
- return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
- }
- int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
- int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
- int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
- bool operator==(Wrapper other) {
- return (saturated && other.saturated) ||
- (!saturated && !other.saturated && v == other.v);
- }
- bool operator!=(Wrapper other) { return !(*this == other); }
- Wrapper operator+(Wrapper other) {
- if (saturated || other.saturated)
- return Wrapper{true, 0};
- return Wrapper{false, other.v + v};
- }
- Wrapper operator*(Wrapper other) {
- if (saturated || other.saturated)
- return Wrapper{true, 0};
- return Wrapper{false, other.v * v};
- }
- bool saturated;
- int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
ReassociationIndices reassoc = std::get<0>(it);
int64_t currentStrideToExpand = std::get<1>(it);
for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
- using saturated_arith::Wrapper;
reverseResultStrides.push_back(currentStrideToExpand);
- currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
- Wrapper::size(resultShape[shapeIndex--]))
- .asStride();
+ currentStrideToExpand =
+ (SaturatedInteger::wrap(currentStrideToExpand) *
+ SaturatedInteger::wrap(resultShape[shapeIndex--]))
+ .asInteger();
}
}
auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
unsigned resultStrideIndex = resultStrides.size() - 1;
for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
- using saturated_arith::Wrapper;
- auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+ auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
for (int64_t idx : llvm::reverse(trailingReassocs)) {
- stride = stride * Wrapper::size(srcShape[idx]);
+ stride = stride * SaturatedInteger::wrap(srcShape[idx]);
// Both source and result stride must have the same static value. In that
// case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
// ops where obviously non-contiguous dims are collapsed, but accept ops
// where we cannot be sure statically. Such ops may fail at runtime. See
// the op documentation for details.
- auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+ auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
if (strict && (stride.saturated || srcStride.saturated))
return failure();
@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
SmallVector<int64_t> resultShape;
resultShape.reserve(reassociation.size());
for (const ReassociationIndices &group : reassociation) {
- using saturated_arith::Wrapper;
- auto groupSize = Wrapper::size(1);
+ auto groupSize = SaturatedInteger::wrap(1);
for (int64_t srcDim : group)
- groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
- resultShape.push_back(groupSize.asSize());
+ groupSize =
+ groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+ resultShape.push_back(groupSize.asInteger());
}
if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
int64_t targetOffset = sourceOffset;
for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
- using saturated_arith::Wrapper;
- targetOffset =
- (Wrapper::offset(targetOffset) +
- Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
- .asOffset();
+ targetOffset = (SaturatedInteger::wrap(targetOffset) +
+ SaturatedInteger::wrap(staticOffset) *
+ SaturatedInteger::wrap(targetStride))
+ .asInteger();
}
// Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
targetStrides.reserve(staticOffsets.size());
for (auto it : llvm::zip(sourceStrides, staticStrides)) {
auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
- using saturated_arith::Wrapper;
- targetStrides.push_back(
- (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
- .asStride());
+ targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+ SaturatedInteger::wrap(staticStride))
+ .asInteger());
}
// The type is now known.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd9b82d2c553..02146e8257b3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -473,6 +473,192 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
+ assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
+ auto tensorTypes =
+ llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
+ return llvm::cast<RankedTensorType>(type);
+ }));
+ int64_t concatRank = tensorTypes[0].getRank();
+
+ // The concatenation dim must be in the range [0, rank).
+ assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
+
+ SmallVector<int64_t> sizes(concatRank);
+ for (int64_t i = 0, e = concatRank; i < e; ++i) {
+ if (i == dim)
+ continue;
+ SaturatedInteger size;
+ for (auto tensorType : tensorTypes)
+ size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+ sizes[i] = size.asInteger();
+ }
+ auto concatSize = SaturatedInteger::wrap(0);
+ for (auto tensorType : tensorTypes)
+ concatSize =
+ concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+ sizes[dim] = concatSize.asInteger();
+ return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
+}
+
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+ ValueRange inputs) {
+ FailureOr<RankedTensorType> resultType =
+ inferResultType(dim, inputs.getTypes());
+ assert(succeeded(resultType) && "failed to infer concatenation result type");
+ build(builder, result, *resultType, dim, inputs);
+}
+
+LogicalResult ConcatOp::verify() {
+ if (getInputs().size() < 1)
+ return emitOpError("requires at least one input");
+
+ SmallVector<RankedTensorType> inputTypes;
+ for (auto input : getInputs())
+ inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+ RankedTensorType resultType = getResultType();
+ int64_t resultRank = getRank();
+ if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+ return type.getRank() != resultRank;
+ }))
+ return emitOpError("rank of concatenated inputs must match result rank");
+
+ Type resultElementType = resultType.getElementType();
+ if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+ return type.getElementType() != resultElementType;
+ }))
+ return emitOpError("inputs and result element type must match");
+
+ int64_t dim = getDim();
+ if (dim >= resultRank)
+ return emitOpError("concatenation dim must be less than the tensor rank");
+
+ SmallVector<int64_t> sizes(resultRank);
+ for (int64_t i = 0, e = resultRank; i < e; ++i) {
+ if (i == dim)
+ continue;
+ SaturatedInteger size;
+ for (auto tensorType : inputTypes) {
+ FailureOr<SaturatedInteger> maybeSize =
+ size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+ if (failed(maybeSize))
+ return emitOpError("static concatenation size mismatch along ")
+ << "non-concatenated dimension " << i;
+ size = *maybeSize;
+ }
+ sizes[i] = size.asInteger();
+ }
+ auto concatSize = SaturatedInteger::wrap(0);
+ for (auto tensorType : inputTypes)
+ concatSize =
+ concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+ sizes[dim] = concatSize.asInteger();
+ auto inferredResultType =
+ RankedTensorType::get(sizes, inputTypes[0].getElementType());
+
+ for (auto [inferredSize, actualSize] :
+ llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
+ bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
+ ShapedType::isDynamic(actualSize);
+ if (!hasDynamic && inferredSize != actualSize)
+ return emitOpError("result type ")
+ << resultType << "does not match inferred shape "
+ << inferredResultType << " static sizes";
+ }
+
+ return success();
+}
+
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ ValueRange inputs = getInputs();
+ int64_t dim = getDim();
+ RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
+
+ Value init = inputs[0];
+ int64_t rank = getType().getRank();
+
+ reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+
+ // Pre-populate the result sizes with as much static information as possible
+ // from the given result type, as well as the inferred result type, otherwise
+ // use the dim sizes from the first input.
+ for (int64_t i = 0; i < rank; ++i) {
+ if (i == dim)
+ continue;
+ if (!getType().isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+ } else if (!inferredResultType.isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] =
+ builder.getIndexAttr(inferredResultType.getDimSize(i));
+ } else {
+ reifiedReturnShapes[0][i] =
+ builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
+ }
+ }
+
+ // Take the sum of the input sizes along the concatenated dim.
+ AffineExpr sum = builder.getAffineDimExpr(0);
+ SmallVector<OpFoldResult> sizes = {
+ builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
+ for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+ sum = sum + builder.getAffineDimExpr(idx + 1);
+ sizes.push_back(
+ builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+ }
+ reifiedReturnShapes[0][dim] =
+ affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+
+ // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
+ // returns a Value for dynamic dimensions.
+ for (int64_t i = 0; i < rank; ++i) {
+ if (getType().isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
+ builder, getLoc(), reifiedReturnShapes[0][i]);
+ }
+ }
+ return success();
+}
+
+void ConcatOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "concat");
+}
+
+OpFoldResult ConcatOp::fold(FoldAdaptor) {
+ ValueRange inputs = getInputs();
+ if (inputs.size() == 1 && inputs[0].getType() == getResultType())
+ return inputs[0];
+ return {};
+}
+
+namespace {
+/// Fold a concat op with a single input to a cast.
+struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
+ using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ if (concatOp.getInputs().size() != 1)
+ return failure();
+ rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+ concatOp.getInputs()[0]);
+ return success();
+ }
+};
+} // namespace
+
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SingleInputConcatOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 3cec91389392..ed2742387047 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -83,6 +83,11 @@ void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
+void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ tensor::populateDecomposeTensorConcatPatterns(patterns);
+}
+
void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c5fd4e65bbf7..d233ab7a0e89 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ ConcatOpPatterns.cpp
EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
FoldIntoPackAndUnpackPatterns.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRAffineTransforms
MLIRAffineUtils
MLIRArithDialect
+ MLIRArithUtils
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
new file mode 100644
index 000000000000..2108fc591055
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -0,0 +1,93 @@
+//===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts.
+///
+/// %concat = tensor.concat dim(1) %0, %1 :
+/// (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<2x7xf32>
+/// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
+/// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
+struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
+ using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = concatOp.getLoc();
+ FailureOr<Value> dest =
+ tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+ if (failed(dest))
+ return failure();
+
+ auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+ if (!empty)
+ return failure();
+
+ int64_t dim = concatOp.getDim();
+ Value dimValue = rewriter.createOrFold<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(dim));
+
+ int64_t rank = concatOp.getResultType().getRank();
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+
+ // Compute the partial sums for the slice offsets.
+ AffineExpr sum = rewriter.getAffineDimExpr(0);
+ SmallVector<AffineExpr> partialSums = {sum};
+ SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+ for (auto [idx, input] :
+ llvm::enumerate(concatOp.getInputs().drop_back())) {
+ sum = sum + rewriter.getAffineDimExpr(idx + 1);
+ partialSums.push_back(sum);
+ offsetStrides.push_back(
+ rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+ }
+ auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+ partialSums, rewriter.getContext());
+ SmallVector<OpFoldResult> dimOffsets =
+ affine::makeComposedFoldedMultiResultAffineApply(
+ rewriter, loc, partialSumMap, offsetStrides);
+
+ // Construct the chain of insert_slice ops into the destination.
+ Value result = *dest;
+ for (auto [input, offset] :
+ llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ offsets[dim] = offset;
+ result = rewriter.createOrFold<tensor::InsertSliceOp>(
+ loc, input, result, offsets, sizes, strides);
+ }
+
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ concatOp, concatOp.getResultType(), result);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateDecomposeTensorConcatPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 580c1db60702..84c44a09aa3d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32>
// -----
+// CHECK-LABEL: fold_concat
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
+func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
+ %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
+ // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
+ %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
+ // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+ return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
new file mode 100644
index 000000000000..5712c77a743d
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.tensor.decompose_concat
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
+
+func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @decompose_dynamic_concat(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
+// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.tensor.decompose_concat
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
+
+func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
+ %arg1 : tensor<2xf32>,
+ %arg2 : tensor<3xf32>,
+ %arg3: tensor<4xf32>) -> tensor<10xf32> {
+ %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
+ : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
+ return %0 : tensor<10xf32>
+}
+// CHECK-LABEL: func @decompose_1d_concat
+// CHECK: tensor.empty() : tensor<10xf32>
+// CHECK: tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
+// CHECK: tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+// CHECK: tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
+// CHECK: return %[[CONCAT]] : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e675c0e..9b6c2327879c 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
// -----
+func.func @concat_empty() {
+ // expected-error@+1 {{requires at least one input}}
+ %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32>
+ return
+}
+
+// -----
+
+func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) {
+ // expected-error@+1 {{rank of concatenated inputs must match result rank}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+ return
+}
+
+// -----
+
+func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) {
+ // expected-error@+1 {{concatenation dim must be less than the tensor rank}}
+ %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32>
+ return
+}
+
+// -----
+
+func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) {
+ // expected-error@+1 {{inputs and result element type must match}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
+ return
+}
+
+// -----
+
+func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
+ // expected-error@+1 {{static concatenation size mismatch along non-concatenated dimension 1}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
+ return
+}
+
+// -----
+
+func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
+ // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}}
+ %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32>
+ return
+}
+
+// -----
+
func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[] : tensor<?xf32>
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489b23f5..2282da38803a 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
// -----
+// CHECK-LABEL: func @concat(
+func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) {
+ // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+ %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+ // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+ %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+ // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+ %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @empty(
// CHECK-SAME: %[[sz:.*]]: index
func.func @empty(%sz: index) -> tensor<5x?x6xf32> {