author    Kojo Acquah <KoolJBlack@users.noreply.github.com>  2024-04-03 16:24:18 -0700
committer GitHub <noreply@github.com>                        2024-04-03 19:24:18 -0400
commit    c511c90680eecae2e4adb87f442f41d465feb0f2 (patch)
tree      ffe39303a4582cae2e4f1c87688a8d6fe4f250b6
parent    be57c90feff81d067c83be1ab8927fb345c761cc (diff)
[mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (#86005)
Updates smmla unrolling patterns to handle vecmat contracts where `dimM=1`. This covers explicit vecmats of the form `<1x8xi8> x <8x8xi8> --> <1x8xi32>`, as well as the implied form with the leading dim folded: `<8xi8> x <8x8xi8> --> <8xi32>`. Since the smmla operates on two `<2x8xi8>` input vectors to produce a `<2x2xi32>` accumulator, half of each 2x2 accumulator tile is dummy data not pertinent to the computation, so these vecmat cases run at half throughput.
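To make the half-throughput point concrete, here is a minimal scalar sketch (plain C++, not part of this patch) of what one smmla step computes once the vecmat LHS and accumulator have been zero-padded to the 2x8 and 2x2 shapes the instruction expects; only row 0 of the accumulator carries real data.

    #include <cstdint>
    #include <cstdio>

    int main() {
      // One smmla tile: lhs and rhs are 2x8 signed bytes, acc is 2x2 int32.
      // For a vecmat, only lhs row 0 and acc row 0 are real; row 1 is padding.
      int8_t lhs[2][8] = {{1, -2, 3, -4, 5, -6, 7, -8}, {0}};
      int8_t rhs[2][8] = {{1, 1, 1, 1, 1, 1, 1, 1},
                          {2, 2, 2, 2, 2, 2, 2, 2}};
      int32_t acc[2][2] = {{100, 200}, {0, 0}};
      for (int i = 0; i < 2; ++i)       // rows of lhs/acc
        for (int j = 0; j < 2; ++j)     // rows of rhs (RHS is not transposed)
          for (int k = 0; k < 8; ++k)   // 8-deep i8 dot product into i32
            acc[i][j] += int32_t(lhs[i][k]) * int32_t(rhs[j][k]);
      // Row 0 holds the vecmat result; row 1 stayed zero -- the wasted half.
      printf("acc[0] = {%d, %d}, acc[1] = {%d, %d}\n",
             (int)acc[0][0], (int)acc[0][1], (int)acc[1][0], (int)acc[1][1]);
      return 0;
    }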
-rw-r--r--  mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp |  98
-rw-r--r--  mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir                        | 124
2 files changed, 191 insertions, 31 deletions
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b..13740225749e 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
/// Lowering from a vector::contractOp to an arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with
+/// dimM = 1 (either explicitly or inferred if the LHS has only dimK). If no
+/// unrolling is necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
@@ -49,32 +50,35 @@ public:
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- // Check index maps that represent M N K in contract.
- auto indexingMaps = op.getIndexingMapsArray();
- if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
- return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
- affineMap.getNumResults() != 2;
- })) {
- return failure();
- }
- // Check iterator types for contract.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() != 3 ||
- iteratorTypes[0] != vector::IteratorType::parallel ||
- iteratorTypes[1] != vector::IteratorType::parallel ||
- iteratorTypes[2] != vector::IteratorType::reduction) {
- return failure();
- }
- // Infer tile sizes from operands; Note: RHS is not transposed.
+ // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+ // Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- auto dimM = lhsType.getDimSize(0);
+ auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
- auto dimK = lhsType.getDimSize(1);
-
+ auto dimK = rhsType.getDimSize(1);
+ bool isVecmat = dimM == 1;
+ if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+ rhsType.getDimSize(rhsType.getRank() - 1)) {
+ return failure(); // dimK mismatch
+ }
// Unrolling patterns can handle any [2, 2, 8]-shaped multiple of the input
// shapes for tiling.
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+ return failure();
+ }
+
+ // Check iterator types for contract. All iterators except the inner-most
+ // dimension must be parallel; the inner-most must be a reduction.
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+ vector::IteratorType::reduction) {
+ return failure();
+ }
+ if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
+ [](vector::IteratorType iteratorType) {
+ return iteratorType != vector::IteratorType::parallel;
+ })) {
return failure();
}
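Taken together, the relaxed checks in this hunk admit the following shapes (a summary sketch that only restates the conditions above, nothing new):

    // Admitted iteration spaces after this change:
    //   matmul: [M, N, K] with M % 2 == 0, N % 2 == 0, K % 8 == 0
    //   vecmat: [1, N, K], or [N, K] when the LHS is rank 1,
    //           with N % 2 == 0 and K % 8 == 0 (dimM == 1 is exempt)
    // K runs along the trailing dim of both operands; the RHS is N x K
    // (not transposed), so the two trailing dims must match.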
@@ -120,11 +124,14 @@ public:
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
- SmallVector<int64_t> smmlaShape{2, 2, 8};
- SmallVector<int64_t> loopOrder{0, 1, 2};
+ SmallVector<int64_t> smmlaShape{2, 8};
+ SmallVector<int64_t> loopOrder{0, 1};
+ if (unrolledSize.size() == 3) {
+ smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+ loopOrder.push_back(2);
+ }
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
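For the rank-1 vecmat exercised by the first test below, `unrolledSize` is {8, 8} (N, K) and `smmlaShape` stays {2, 8}, so the StaticTileOffsetRange loop visits four tiles. A hedged sketch of the equivalent enumeration (illustrative helper, not code from the patch):

    #include <cstdint>
    // Offsets yielded for shape {8, 8} (N, K) tiled by {2, 8}:
    // {0,0}, {2,0}, {4,0}, {6,0} -- one smmla per pair of output columns,
    // matching the four arm_neon.intr.smmla calls in the CHECK lines below.
    void visitVecmatTiles(void (*visit)(int64_t n, int64_t k)) {
      for (int64_t n = 0; n < 8; n += 2)
        for (int64_t k = 0; k < 8; k += 8)
          visit(n, k);
    }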
@@ -150,16 +157,40 @@ public:
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+ auto inputElementType =
+ tiledLhs.getType().cast<ShapedType>().getElementType();
+ auto accElementType =
+ tiledAcc.getType().cast<ShapedType>().getElementType();
+ auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
+ auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+ // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+ // rows along dimM. Expand their shapes to match the smmla op.
+ if (isVecmat) {
+ auto expandForSMMLA = [&](Value tiledOperand,
+ VectorType expandedType) {
+ auto emptyOperand = rewriter.create<arith::ConstantOp>(
+ loc, expandedType, rewriter.getZeroAttr(expandedType));
+ SmallVector<int64_t> offsets(
+ emptyOperand.getType().cast<ShapedType>().getRank(), 0);
+ SmallVector<int64_t> strides(
+ tiledOperand.getType().cast<ShapedType>().getRank(), 1);
+ return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledOperand, emptyOperand, offsets, strides);
+ };
+ tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
+ tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+ }
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
- auto collapsedInputType = VectorType::get(
- tiledLhs.getType().cast<ShapedType>().getNumElements(),
- tiledLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType = VectorType::get(
- {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ auto collapsedInputType =
+ VectorType::get(inputExpandedType.getNumElements(), inputElementType);
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedOutputType =
+ VectorType::get(outputExpandedType.getNumElements(), accElementType);
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
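For reference, the operand shapes at the intrinsic boundary are fixed by smmla itself, for matmul and vecmat alike once the expansion above has run (summary comments mirroring the CHECK lines in the tests):

    // lhs, rhs: vector<2x8xi8>  --vector.shape_cast--> vector<16xi8>
    // acc, res: vector<2x2xi32> --vector.shape_cast--> vector<4xi32>
    // arm_neon.intr.smmla %acc, %lhs, %rhs : vector<16xi8> to vector<4xi32>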
@@ -172,6 +203,11 @@ public:
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+ // With vecmat, only one row of the tiled ACC can be inserted into the final result.
+ if (isVecmat) {
+ tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+ }
+
// Insert the tiled result back into the non-tiled result of the
// contract op.
SmallVector<int64_t> strides(
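For context, this pattern is driven like any other rewrite. Below is a minimal sketch of wiring it into a pass; the populate entry point is the one declared in mlir/include/mlir/Dialect/ArmNeon/Transforms.h at the time of this commit (verify against your tree), and the surrounding function is illustrative scaffolding, not code from this patch:

    #include "mlir/Dialect/ArmNeon/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Illustrative driver: greedily applies the SMMLA lowering, unrolling
    // every matching vector.contract under `root`.
    void lowerContractsToSMMLA(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
    }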
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6..46c4026d13b6 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,127 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
return %res : vector<4x4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_13:.*]] = arm_neon.intr.smmla %[[VAL_12]], %[[VAL_10]], %[[VAL_11]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_14:.*]] = vector.shape_cast %[[VAL_13]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_14]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_16:.*]] = vector.insert_strided_slice %[[VAL_15]], %[[VAL_5]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_17:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_22:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_24:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_24]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_25]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_16]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_31:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_30]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_35:.*]] = arm_neon.intr.smmla %[[VAL_34]], %[[VAL_32]], %[[VAL_33]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_35]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_36]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_37]], %[[VAL_27]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_39:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_42:.*]] = vector.insert_strided_slice %[[VAL_40]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_41]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_44:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_42]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_46:.*]] = arm_neon.intr.smmla %[[VAL_45]], %[[VAL_43]], %[[VAL_44]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_47:.*]] = vector.shape_cast %[[VAL_46]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_48:.*]] = vector.extract %[[VAL_47]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: return %[[VAL_49]] : vector<8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+ return %res : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_13:.*]] = arm_neon.intr.smmla %[[VAL_12]], %[[VAL_10]], %[[VAL_11]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_14:.*]] = vector.shape_cast %[[VAL_13]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_14]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_16:.*]] = vector.insert_strided_slice %[[VAL_15]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_17:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_22:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_24:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_24]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_25]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_16]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_31:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_30]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_35:.*]] = arm_neon.intr.smmla %[[VAL_34]], %[[VAL_32]], %[[VAL_33]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_35]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_36]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_37]], %[[VAL_27]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_39:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_42:.*]] = vector.insert_strided_slice %[[VAL_40]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_41]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_44:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_42]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_46:.*]] = arm_neon.intr.smmla %[[VAL_45]], %[[VAL_43]], %[[VAL_44]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_47:.*]] = vector.shape_cast %[[VAL_46]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_48:.*]] = vector.extract %[[VAL_47]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: return %[[VAL_49]] : vector<1x8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+ return %res : vector<1x8xi32>
+}