diff options
author | Diego Caballero <diegocaballero@google.com> | 2024-01-31 17:26:50 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-31 17:26:50 -0800 |
commit | 8ba018d72a67050a9c37065ea2668814ebf513a9 (patch) | |
tree | 4c4a220fa50caa6cd02166bee7c0c408a249154a | |
parent | 0e8eb445db0cc2552d9d077b527a43c779785cb9 (diff) |
[mlir][Vector] Add support for sub-byte transpose emulation (#80110)
This PR adds patterns to convert a sub-byte vector transpose into a
sequence of instructions that perform the transpose on i8 vector
elements. While this rewrite may not achieve the absolute peak
performance, it should ensure correctness when dealing with sub-byte
transposes.
5 files changed, 80 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 3ac6f28dcb93..ce88360aa52e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -151,7 +151,7 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect, "apply_patterns.vector.lower_masked_transfers", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { let description = [{ - Apply opt-in patterns that lower vector.mask operations surrounding + Apply opt-in patterns that lower vector.mask operations surrounding side-effecting ops: - MaskedTransferReadOpPattern - MaskedTransferWriteOpPattern @@ -376,7 +376,7 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect, - ReorderCastOpsOnBroadcast - ReorderElementwiseOpsOnTranspose - These patterns have the effect of rewriting a vector.multi_reduce into a + These patterns have the effect of rewriting a vector.multi_reduce into a vector.contract. }]; diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 49b74c0c466d..f5941d32e683 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -371,6 +371,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp, void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Appends patterns for emulating a sub-byte vector transpose. 
+void populateVectorTransposeNarrowTypeRewritePatterns( + RewritePatternSet &patterns, PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 37127ea70f1e..19922c4295fe 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -162,6 +162,7 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns( void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorNarrowTypeRewritePatterns(patterns); + populateVectorTransposeNarrowTypeRewritePatterns(patterns); } void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 0110a8df89ae..36fb66708407 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1052,6 +1052,53 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> { } }; +/// Rewrite a sub-byte vector transpose into a sequence of instructions that +/// perform the transpose on wider (byte) element types. 
+/// For example: +/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4> +/// +/// is rewritten as: +/// +/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8> +/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8> +/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> +/// +struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> { + using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; + + RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit) + : OpRewritePattern<vector::TransposeOp>(context, benefit) {} + + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + // Precondition: sub-byte integer transpose. + constexpr unsigned minNativeBitwidth = 8; + VectorType srcSubByteVecType = transposeOp.getSourceVectorType(); + if (!srcSubByteVecType.getElementType().isSignlessInteger() || + srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) { + return rewriter.notifyMatchFailure(transposeOp, + "not a sub-byte transpose"); + } + + // Perform the rewrite. + Location loc = transposeOp.getLoc(); + // Signed/unsigned interpretation shouldn't matter here as we are just + // transposing the elements and truncating them back to the original size. + // TODO: Use unsigned extension (more efficient) when emulation or backend + // support is available. 
+ auto srcNativeVecType = srcSubByteVecType.cloneWith( + std::nullopt, rewriter.getIntegerType(minNativeBitwidth)); + Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType, + transposeOp.getVector()); + Value newTranspose = rewriter.create<vector::TransposeOp>( + loc, extOp, transposeOp.getPermutation()); + VectorType dstSubByteVecType = transposeOp.getResultVectorType(); + rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType, + newTranspose); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1080,3 +1127,8 @@ void vector::populateVectorNarrowTypeRewritePatterns( RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>( patterns.getContext(), benefit.getBenefit() + 1); } + +void vector::populateVectorTransposeNarrowTypeRewritePatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit); +} diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir index c4fbb4c219b9..02063a81664b 100644 --- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -226,6 +226,26 @@ func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { return %0 : vector<8xf32> } +// CHECK-LABEL: func.func @i4_transpose( +// CHECK-SAME: %[[A:[0-9a-z]*]] +func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> { + // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8> + // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> + // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4> + %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4> + return %0 : vector<16x8xi4> +} + +// CHECK-LABEL: func.func @i7_transpose( +// CHECK-SAME: %[[A:[0-9a-z]*]] 
+func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> { + // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8> + // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> + // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7> + %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7> + return %0 : vector<16x8xi7> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %f = transform.structured.match ops{["func.func"]} in %module_op @@ -237,3 +257,4 @@ module attributes {transform.with_named_sequence} { transform.yield } } + |