diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2023-12-21 17:46:12 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-21 17:46:12 +0000 |
commit | a4e15416b41459b6f69086a22088520ee826f244 (patch) | |
tree | e51ebd4fa6784fb91f00996e855e88cbe71264c6 | |
parent | 88151dd4285cdd9feeb24ebb1be9cf5252ab0883 (diff) |
[mlir][ArmSME] Move creation of load/store intrinsics to helpers (NFC) (#76168)
Also, for consistency make the ZeroOp lowering switch on the ArmSMETileType,
rather than the element bit width.
-rw-r--r-- | mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 227 |
1 files changed, 108 insertions, 119 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index f9d6f04a811f..0c6e2e80b88a 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -32,6 +32,95 @@ using namespace mlir; namespace { +/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. +static Operation *createLoadTileSliceIntrinsic( + RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, + arm_sme::TileSliceLayout layout, Value maskOp, Value ptr, + IntegerAttr tileId, Value tileSliceI32) { + if (layout == arm_sme::TileSliceLayout::Horizontal) { + switch (type) { + case arm_sme::ArmSMETileType::ZAB: + return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAH: + return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAS: + return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAD: + return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAQ: + return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + } + } else { + switch (type) { + case arm_sme::ArmSMETileType::ZAB: + return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAH: + return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAS: + return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAD: + return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAQ: + return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + break; + } + } +} + +/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic. +static Operation *createStoreTileSliceIntrinsic( + RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, + arm_sme::TileSliceLayout layout, Value maskOp, Value ptr, + IntegerAttr tileId, Value tileSliceI32) { + if (layout == arm_sme::TileSliceLayout::Horizontal) { + switch (type) { + case arm_sme::ArmSMETileType::ZAB: + return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAH: + return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAS: + return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAD: + return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAQ: + return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>( + loc, maskOp, ptr, tileId, tileSliceI32); + } + } else { + switch (type) { + case arm_sme::ArmSMETileType::ZAB: + return rewriter.create<arm_sme::aarch64_sme_st1b_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAH: + return rewriter.create<arm_sme::aarch64_sme_st1h_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAS: + return rewriter.create<arm_sme::aarch64_sme_st1w_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAD: + return rewriter.create<arm_sme::aarch64_sme_st1d_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + case arm_sme::ArmSMETileType::ZAQ: + return rewriter.create<arm_sme::aarch64_sme_st1q_vert>( + loc, maskOp, ptr, tileId, tileSliceI32); + } + } +} + IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) { auto tileId = op.getTileId(); if (!tileId) @@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> { ConversionPatternRewriter &rewriter) const override { auto loc = zero.getLoc(); - unsigned tileElementWidth = - zero.getVectorType().getElementType().getIntOrFloatBitWidth(); - auto tileId = getTileIdOrError(zero); if (!tileId) return failure(); @@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> { // The base mask is just the mask to zero the first tile (of a size). // These masks are derived from: // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles- + arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType(); auto baseMaskForSize = [&] { - switch (tileElementWidth) { - case 8: + switch (tileType) { + case arm_sme::ArmSMETileType::ZAB: // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight // 64-bit element tiles named ZA0.D to ZA7.D. return 0b1111'1111; - case 16: - // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element - // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. - // Shift this left once for ZA1.H. + case arm_sme::ArmSMETileType::ZAH: + // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit + // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left + // once for ZA1.H. return 0b0101'0101; - case 32: + case arm_sme::ArmSMETileType::ZAS: // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit // element tiles named ZA0.D and ZA4.D. // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S. return 0b0001'0001; - case 64: + case arm_sme::ArmSMETileType::ZAD: // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires // setting the bit for that tile. return 0b0000'0001; @@ -172,63 +259,13 @@ struct LoadTileSliceConversion // Create all active predicate mask. auto maskOp = loadTileSliceOp.getMask(); - auto tileType = loadTileSliceOp.getVectorType(); - auto tileElementType = tileType.getElementType(); - unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + auto tileVectorType = loadTileSliceOp.getVectorType(); + arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType); arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout(); // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice. - if (layout == arm_sme::TileSliceLayout::Horizontal) { - switch (tileElementWidth) { - default: - llvm_unreachable("unexpected element type!"); - case 8: - rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 16: - rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 32: - rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 64: - rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 128: - rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - } - } else { - switch (tileElementWidth) { - default: - llvm_unreachable("unexpected element type!"); - case 8: - rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 16: - rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 32: - rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 64: - rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - case 128: - rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr, - tileId, tileSliceI32); - break; - } - } + createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr, + tileId, tileSliceI32); // The load intrinsics have no result, replace 'arm_sme.tile_load' with // the input tile to preserve dataflow. @@ -249,9 +286,7 @@ struct StoreTileSliceConversion arm_sme::StoreTileSliceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = storeTileSliceOp.getLoc(); - auto tileType = storeTileSliceOp.getVectorType(); - auto tileElementType = tileType.getElementType(); - unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + auto tileVectorType = storeTileSliceOp.getVectorType(); auto tileId = getTileIdOrError(storeTileSliceOp); if (!tileId) @@ -271,58 +306,12 @@ struct StoreTileSliceConversion auto maskOp = storeTileSliceOp.getMask(); arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout(); + arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType); - if (layout == arm_sme::TileSliceLayout::Horizontal) { - switch (tileElementWidth) { - default: - llvm_unreachable("unexpected element type!"); - case 8: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 16: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 32: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 64: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 128: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - } - } else { - switch (tileElementWidth) { - default: - llvm_unreachable("unexpected element type!"); - case 8: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 16: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 32: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 64: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - case 128: - rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>( - storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32); - break; - } - } + rewriter.replaceOp(storeTileSliceOp, + createStoreTileSliceIntrinsic(rewriter, loc, tileType, + layout, maskOp, ptr, + tileId, tileSliceI32)); return success(); } |