summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKojo Acquah <KoolJBlack@users.noreply.github.com>2024-04-03 16:27:01 -0700
committerGitHub <noreply@github.com>2024-04-03 19:27:01 -0400
commit66fed33db014bd705433e4b4f1ea766a8d71cadf (patch)
treef25d9ff5165be7d333dfe30917bc292ece8ddb29
parentc511c90680eecae2e4adb87f442f41d465feb0f2 (diff)
[mlir][vector] Update `castAwayContractionLeadingOneDim` to omit transposes solely on leading unit dims. (#85694)
Updates `castAwayContractionLeadingOneDim` to check for leading unit dimensions before inserting `vector.transpose` ops. Currently `castAwayContractionLeadingOneDim` removes all leading unit dims based on the accumulator and transposes any subsequent operands to match the accumulator indexing. This does not take into account whether the transpose is strictly necessary, for instance when given this vector-matrix contract: ```mlir %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32> ``` Passing this through the `castAwayContractionLeadingOneDim` pattern produces the following: ```mlir %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x1x8xi32> to vector<1x1x8xi32> %1 = vector.extract %0[0] : vector<1x8xi32> from vector<1x1x8xi32> %2 = vector.extract %arg2[0] : vector<8xi32> from vector<1x8xi32> %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %arg1, %2 : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32> %4 = vector.broadcast %3 : vector<8xi32> to vector<1x8xi32> ``` The `vector.transpose` introduced does not affect the underlying data layout (effectively a no-op), but it cannot be folded automatically. This change avoids inserting transposes when only leading unit dimensions are involved. Fixes #85691
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp21
-rw-r--r--mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir12
2 files changed, 30 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 593c1e53557a..8d733c5a8849 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -398,13 +398,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
transposeResults.push_back(targetExpr);
}
}
+
+ // Checks if only the outer, unit dimensions (of size 1) are permuted.
+ // Such transposes do not materially affect the underlying vector and can
+ // be omitted, e.g. perm [1, 0, 2] applied to vector<1x1x8xi32>.
+ bool transposeNonOuterUnitDims = false;
+ auto operandShape = operands[it.index()].getType().cast<ShapedType>();
+ for (auto [index, dim] :
+ llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+ if (dim != static_cast<int64_t>(index) &&
+ operandShape.getDimSize(index) != 1) {
+ transposeNonOuterUnitDims = true;
+ break;
+ }
+ }
+
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- loc, operands[it.index()], perm);
+ if (transposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ loc, operands[it.index()], perm);
+ }
}
}
// We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 3a120a56056c..252aeb0c15cb 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -238,6 +238,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
return %0: vector<1x1x2x16xf32>
}
+// -----
+
+// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
+// CHECK-NOT: vector.transpose
+// CHECK: vector.contract
+func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
+ %rhs: vector<1x8x8xi32>,
+ %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+ return %result : vector<1x8xi32>
+}
// -----
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
@@ -663,4 +674,3 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
%sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
return %sel : vector<1x16xi1>
}
-