diff options
author | Kojo Acquah <KoolJBlack@users.noreply.github.com> | 2024-04-03 16:27:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-03 19:27:01 -0400 |
commit | 66fed33db014bd705433e4b4f1ea766a8d71cadf (patch) | |
tree | f25d9ff5165be7d333dfe30917bc292ece8ddb29 | |
parent | c511c90680eecae2e4adb87f442f41d465feb0f2 (diff) |
[mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (#85694)
Updates `castAwayContractionLeadingOneDim` to check for leading unit
dimensions before inserting `vector.transpose` ops.
Currently `castAwayContractionLeadingOneDim` removes all leading unit
dims based on the accumulator and transpose any subsequent operands to
match the accumulator indexing. This does not take into account if the
transpose is strictly necessary, for instance when given this
vector-matrix contract:
```mlir
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
```
Passing this through `castAwayContractionLeadingOneDim` pattern produces
the following:
```mlir
%0 = vector.transpose %arg0, [1, 0, 2] : vector<1x1x8xi32> to vector<1x1x8xi32>
%1 = vector.extract %0[0] : vector<1x8xi32> from vector<1x1x8xi32>
%2 = vector.extract %arg2[0] : vector<8xi32> from vector<1x8xi32>
%3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %arg1, %2 : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
%4 = vector.broadcast %3 : vector<8xi32> to vector<1x8xi32>
```
The `vector.transpose` introduced does not affect the underlying data
layout (effectively a no op), but it cannot be folded automatically.
This change avoids inserting transposes when only leading unit
dimensions are involved.
Fixes #85691
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 21 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir | 12 |
2 files changed, 30 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 593c1e53557a..8d733c5a8849 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -398,13 +398,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, transposeResults.push_back(targetExpr); } } + + // Checks if only the outer, unit dimensions (of size 1) are permuted. + // Such transposes do not materially effect the underlying vector and can + // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32> + bool transposeNonOuterUnitDims = false; + auto operandShape = operands[it.index()].getType().cast<ShapedType>(); + for (auto [index, dim] : + llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) { + if (dim != static_cast<int64_t>(index) && + operandShape.getDimSize(index) != 1) { + transposeNonOuterUnitDims = true; + break; + } + } + // Do the tranpose now if needed so that we can drop the // correct dim using extract later. if (tranposeNeeded) { map = AffineMap::get(map.getNumDims(), 0, transposeResults, contractOp.getContext()); - operands[it.index()] = rewriter.create<vector::TransposeOp>( - loc, operands[it.index()], perm); + if (transposeNonOuterUnitDims) { + operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>( + loc, operands[it.index()], perm); + } } } // We have taken care to have the dim to be dropped be diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index 3a120a56056c..252aeb0c15cb 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -238,6 +238,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra return %0: vector<1x1x2x16xf32> } +// ----- + +// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims +// CHECK-NOT vector.transpose +// CHECK: vector.contract +func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>, + %rhs: vector<1x8x8xi32>, + %acc: vector<1x8xi32>) -> vector<1x8xi32> { + %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32> + return %result : vector<1x8xi32> +} // ----- // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims @@ -663,4 +674,3 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1> return %sel : vector<1x16xi1> } - |