diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2024-04-16 12:54:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-16 12:54:01 +0100 |
commit | dadcaf82274805456b7d85131cf94f921b5398b7 (patch) | |
tree | d5f2ec2b8f3c80bd882b310bcc24da6fa2821f1d | |
parent | 71b9f6648222771470473431bc8ef2a2c25e872c (diff) |
[mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (#88762)
This adds a simple rewrite/legalization to decompose constant splats
larger than a single ArmSME tile into multiple SME virtual tile sized
splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose
into four `vector<[4]x[4]xi32>` splats.
-rw-r--r-- | mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 32 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmSME/vector-legalization.mlir | 11 |
2 files changed, 42 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 31500c62c0d6..b595c6dd8a68 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) { return (vectorRows * vectorCols) / (minNumElts * minNumElts); } +/// Legalize `arith.constant dense<value>` splat operations to fit within SME +/// tiles by decomposing them into tile-sized operations. +struct LegalizeArithConstantOpsByDecomposition + : public OneToNOpConversionPattern<arith::ConstantOp> { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto vectorType = dyn_cast<VectorType>(constantOp.getType()); + auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); + if (!vectorType || !denseAttr || !denseAttr.isSplat()) + return failure(); + + if (!isMultipleOfSMETileVectorType(vectorType)) + return rewriter.notifyMatchFailure(constantOp, + kMatchFailureNotSMETileTypeMultiple); + + auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); + auto tileCount = getNumberOfSMETilesForVectorType(vectorType); + auto tileSplat = rewriter.create<arith::ConstantOp>( + constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); + rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat), + adaptor.getResultMapping()); + + return success(); + } +}; + /// Legalize `vector.outerproduct` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeVectorOuterProductOpsByDecomposition @@ -637,7 +666,8 @@ struct VectorLegalizationPass // Note: High benefit to ensure masked outer products are lowered first. patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>( converter, context, 1024); - patterns.add<LegalizeVectorOuterProductOpsByDecomposition, + patterns.add<LegalizeArithConstantOpsByDecomposition, + LegalizeVectorOuterProductOpsByDecomposition, LegalizeTransferReadOpsByDecomposition, LegalizeTransferWriteOpsByDecomposition>(converter, context); populateFuncTypeConversionPatterns(converter, patterns); diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index f8be697548c1..f43ef1cce787 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -433,3 +433,14 @@ func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: m %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32> return %cast : vector<[4]xf32> } + +// ----- + +// CHECK-LABEL: @multi_tile_splat +func.func @multi_tile_splat() -> vector<[8]x[8]xi32> +{ + // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32> + // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32> + %0 = arith.constant dense<42> : vector<[8]x[8]xi32> + return %0 : vector<[8]x[8]xi32> +} |