summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDiego Caballero <diegocaballero@google.com>2024-01-31 17:26:50 -0800
committerGitHub <noreply@github.com>2024-01-31 17:26:50 -0800
commit8ba018d72a67050a9c37065ea2668814ebf513a9 (patch)
tree4c4a220fa50caa6cd02166bee7c0c408a249154a
parent0e8eb445db0cc2552d9d077b527a43c779785cb9 (diff)
[mlir][Vector] Add support for sub-byte transpose emulation (#80110)
This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes.
-rw-r--r--mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td4
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h4
-rw-r--r--mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp1
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp52
-rw-r--r--mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir21
5 files changed, 80 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3ac6f28dcb93..ce88360aa52e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -151,7 +151,7 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_masked_transfers",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Apply opt-in patterns that lower vector.mask operations surrounding
+ Apply opt-in patterns that lower vector.mask operations surrounding
side-effecting ops:
- MaskedTransferReadOpPattern
- MaskedTransferWriteOpPattern
@@ -376,7 +376,7 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
- ReorderCastOpsOnBroadcast
- ReorderElementwiseOpsOnTranspose
- These patterns have the effect of rewriting a vector.multi_reduce into a
+ These patterns have the effect of rewriting a vector.multi_reduce into a
vector.contract.
}];
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 49b74c0c466d..f5941d32e683 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -371,6 +371,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Appends patterns for emulating a sub-byte vector transpose.
+void populateVectorTransposeNarrowTypeRewritePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 37127ea70f1e..19922c4295fe 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -162,6 +162,7 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
+ populateVectorTransposeNarrowTypeRewritePatterns(patterns);
}
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0110a8df89ae..36fb66708407 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1052,6 +1052,53 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};
+/// Rewrite a sub-byte vector transpose into a sequence of instructions that
+/// perform the transpose on wider (byte) element types.
+/// For example:
+/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+///
+/// is rewritten as:
+///
+/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
+/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
+///
+struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+ RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ // Precondition: sub-byte integer transpose.
+ constexpr unsigned minNativeBitwidth = 8;
+ VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
+ if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
+ srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
+ return rewriter.notifyMatchFailure(transposeOp,
+ "not a sub-byte transpose");
+ }
+
+ // Perform the rewrite.
+ Location loc = transposeOp.getLoc();
+ // Signed/unsigned interpretation shouldn't matter here as we are just
+ // transposing the elements and truncating them back to the original size.
+ // TODO: Use unsigned extension (more efficient) when emulation or backend
+ // support is available.
+ auto srcNativeVecType = srcSubByteVecType.cloneWith(
+ std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
+ Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
+ transposeOp.getVector());
+ Value newTranspose = rewriter.create<vector::TransposeOp>(
+ loc, extOp, transposeOp.getPermutation());
+ VectorType dstSubByteVecType = transposeOp.getResultVectorType();
+ rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
+ newTranspose);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1080,3 +1127,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
patterns.getContext(), benefit.getBenefit() + 1);
}
+
+void vector::populateVectorTransposeNarrowTypeRewritePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index c4fbb4c219b9..02063a81664b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -226,6 +226,26 @@ func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @i4_transpose(
+// CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
+ // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
+ // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+ // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+ %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+ return %0 : vector<16x8xi4>
+}
+
+// CHECK-LABEL: func.func @i7_transpose(
+// CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
+ // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
+ // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+ // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+ %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
+ return %0 : vector<16x8xi7>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -237,3 +257,4 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+