summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Benjamin Maxwell <benjamin.maxwell@arm.com> 2023-12-21 17:46:12 +0000
committer GitHub <noreply@github.com> 2023-12-21 17:46:12 +0000
commit a4e15416b41459b6f69086a22088520ee826f244 (patch)
tree e51ebd4fa6784fb91f00996e855e88cbe71264c6
parent 88151dd4285cdd9feeb24ebb1be9cf5252ab0883 (diff)
[mlir][ArmSME] Move creation of load/store intrinsics to helpers (NFC) (#76168)
Also, for consistency make the ZeroOp lowering switch on the ArmSMETileType, rather than the element bit width.
-rw-r--r-- mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp 227
1 files changed, 108 insertions, 119 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index f9d6f04a811f..0c6e2e80b88a 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,6 +32,95 @@ using namespace mlir;
namespace {
+/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
+///
+/// The intrinsic is picked from the tile `type` (ZAB/ZAH/ZAS/ZAD/ZAQ select
+/// ld1b/ld1h/ld1w/ld1d/ld1q respectively) and from `layout`: the `_horiz`
+/// variants for TileSliceLayout::Horizontal, the `_vert` variants otherwise.
+/// The op is built at `loc` with the operands (maskOp, ptr, tileId,
+/// tileSliceI32) and the newly created operation is returned.
+static Operation *createLoadTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+  // Both switches return for every enumerator listed above. Without this
+  // terminator, control would fall off the end of a non-void function if
+  // `type` ever held any other value.
+  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
+}
+
+/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
+///
+/// The intrinsic is picked from the tile `type` (ZAB/ZAH/ZAS/ZAD/ZAQ select
+/// st1b/st1h/st1w/st1d/st1q respectively) and from `layout`: the `_horiz`
+/// variants for TileSliceLayout::Horizontal, the `_vert` variants otherwise.
+/// The op is built at `loc` with the operands (maskOp, ptr, tileId,
+/// tileSliceI32) and the newly created operation is returned.
+static Operation *createStoreTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+  // Both switches return for every enumerator listed above. Without this
+  // terminator, control would fall off the end of a non-void function if
+  // `type` ever held any other value.
+  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
+}
+
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
auto tileId = op.getTileId();
if (!tileId)
@@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
- unsigned tileElementWidth =
- zero.getVectorType().getElementType().getIntOrFloatBitWidth();
-
auto tileId = getTileIdOrError(zero);
if (!tileId)
return failure();
@@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
// The base mask is just the mask to zero the first tile (of a size).
// These masks are derived from:
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
+ arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
auto baseMaskForSize = [&] {
- switch (tileElementWidth) {
- case 8:
+ switch (tileType) {
+ case arm_sme::ArmSMETileType::ZAB:
// Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
// 64-bit element tiles named ZA0.D to ZA7.D.
return 0b1111'1111;
- case 16:
- // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
- // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
- // Shift this left once for ZA1.H.
+ case arm_sme::ArmSMETileType::ZAH:
+ // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
+ // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
+ // once for ZA1.H.
return 0b0101'0101;
- case 32:
+ case arm_sme::ArmSMETileType::ZAS:
// Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
// element tiles named ZA0.D and ZA4.D.
// Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
return 0b0001'0001;
- case 64:
+ case arm_sme::ArmSMETileType::ZAD:
// Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
// setting the bit for that tile.
return 0b0000'0001;
@@ -172,63 +259,13 @@ struct LoadTileSliceConversion
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
- auto tileType = loadTileSliceOp.getVectorType();
- auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+ auto tileVectorType = loadTileSliceOp.getVectorType();
+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
- if (layout == arm_sme::TileSliceLayout::Horizontal) {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 16:
- rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 32:
- rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 64:
- rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 128:
- rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- }
- } else {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 16:
- rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 32:
- rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 64:
- rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 128:
- rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- }
- }
+ createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
+ tileId, tileSliceI32);
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
// the input tile to preserve dataflow.
@@ -249,9 +286,7 @@ struct StoreTileSliceConversion
arm_sme::StoreTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = storeTileSliceOp.getLoc();
- auto tileType = storeTileSliceOp.getVectorType();
- auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+ auto tileVectorType = storeTileSliceOp.getVectorType();
auto tileId = getTileIdOrError(storeTileSliceOp);
if (!tileId)
@@ -271,58 +306,12 @@ struct StoreTileSliceConversion
auto maskOp = storeTileSliceOp.getMask();
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
- if (layout == arm_sme::TileSliceLayout::Horizontal) {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 16:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 32:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 64:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 128:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- }
- } else {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 16:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 32:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 64:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 128:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- }
- }
+ rewriter.replaceOp(storeTileSliceOp,
+ createStoreTileSliceIntrinsic(rewriter, loc, tileType,
+ layout, maskOp, ptr,
+ tileId, tileSliceI32));
return success();
}