summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMin-Yih Hsu <min.hsu@sifive.com>2024-04-23 11:10:37 -0700
committerGitHub <noreply@github.com>2024-04-23 11:10:37 -0700
commit5fe93b0a4d91f1beb801e3a7588e1fa43955af15 (patch)
tree47f4473c85f4e259645a38fc5f9be71bb2df03b2
parent033453a9ad2a62914358747b5beb347482d3fdbd (diff)
[CodeGen][TII] Allow reassociation on custom operand indices (#88306)
This opens up a door for reusing reassociation optimizations on target-specific binary operations with non-standard operand list. This is effectively a NFC.
-rw-r--r--llvm/include/llvm/CodeGen/TargetInstrInfo.h10
-rw-r--r--llvm/lib/CodeGen/TargetInstrInfo.cpp143
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfo.cpp8
3 files changed, 114 insertions, 47 deletions
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index d4a83e3753d9..d5b1df2114e9 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -31,6 +31,7 @@
#include "llvm/MC/MCInstrInfo.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/ErrorHandling.h"
+#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -1271,11 +1272,20 @@ public:
return true;
}
+ /// The returned array encodes the operand index for each parameter because
+ /// the operands may be commuted; the operand indices for associative
+ /// operations might also be target-specific. Each element specifies the index
+ /// of {Prev, A, B, X, Y}.
+ virtual void
+ getReassociateOperandIndices(const MachineInstr &Root, unsigned Pattern,
+ std::array<unsigned, 5> &OperandIndices) const;
+
/// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
/// reduce critical path length.
void reassociateOps(MachineInstr &Root, MachineInstr &Prev, unsigned Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
+ ArrayRef<unsigned> OperandIndices,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
/// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 14b2e4268eb0..e01e7b388891 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1055,12 +1055,34 @@ static std::pair<bool, bool> mustSwapOperands(unsigned Pattern) {
}
}
+void TargetInstrInfo::getReassociateOperandIndices(
+ const MachineInstr &Root, unsigned Pattern,
+ std::array<unsigned, 5> &OperandIndices) const {
+ switch (Pattern) {
+ case MachineCombinerPattern::REASSOC_AX_BY:
+ OperandIndices = {1, 1, 1, 2, 2};
+ break;
+ case MachineCombinerPattern::REASSOC_AX_YB:
+ OperandIndices = {2, 1, 2, 2, 1};
+ break;
+ case MachineCombinerPattern::REASSOC_XA_BY:
+ OperandIndices = {1, 2, 1, 1, 2};
+ break;
+ case MachineCombinerPattern::REASSOC_XA_YB:
+ OperandIndices = {2, 2, 2, 1, 1};
+ break;
+ default:
+ llvm_unreachable("unexpected MachineCombinerPattern");
+ }
+}
+
/// Attempt the reassociation transformation to reduce critical path length.
/// See the above comments before getMachineCombinerPatterns().
void TargetInstrInfo::reassociateOps(
MachineInstr &Root, MachineInstr &Prev, unsigned Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
+ ArrayRef<unsigned> OperandIndices,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
MachineFunction *MF = Root.getMF();
MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1068,29 +1090,10 @@ void TargetInstrInfo::reassociateOps(
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
- // This array encodes the operand index for each parameter because the
- // operands may be commuted. Each row corresponds to a pattern value,
- // and each column specifies the index of A, B, X, Y.
- unsigned OpIdx[4][4] = {
- { 1, 1, 2, 2 },
- { 1, 2, 2, 1 },
- { 2, 1, 1, 2 },
- { 2, 2, 1, 1 }
- };
-
- int Row;
- switch (Pattern) {
- case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
- case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
- case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
- case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
- default: llvm_unreachable("unexpected MachineCombinerPattern");
- }
-
- MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
- MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
- MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
- MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+ MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+ MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+ MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+ MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
MachineOperand &OpC = Root.getOperand(0);
Register RegA = OpA.getReg();
@@ -1129,11 +1132,62 @@ void TargetInstrInfo::reassociateOps(
std::swap(KillX, KillY);
}
+ unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+ unsigned RootFirstOpIdx, RootSecondOpIdx;
+ switch (Pattern) {
+ case MachineCombinerPattern::REASSOC_AX_BY:
+ PrevFirstOpIdx = OperandIndices[1];
+ PrevSecondOpIdx = OperandIndices[3];
+ RootFirstOpIdx = OperandIndices[2];
+ RootSecondOpIdx = OperandIndices[4];
+ break;
+ case MachineCombinerPattern::REASSOC_AX_YB:
+ PrevFirstOpIdx = OperandIndices[1];
+ PrevSecondOpIdx = OperandIndices[3];
+ RootFirstOpIdx = OperandIndices[4];
+ RootSecondOpIdx = OperandIndices[2];
+ break;
+ case MachineCombinerPattern::REASSOC_XA_BY:
+ PrevFirstOpIdx = OperandIndices[3];
+ PrevSecondOpIdx = OperandIndices[1];
+ RootFirstOpIdx = OperandIndices[2];
+ RootSecondOpIdx = OperandIndices[4];
+ break;
+ case MachineCombinerPattern::REASSOC_XA_YB:
+ PrevFirstOpIdx = OperandIndices[3];
+ PrevSecondOpIdx = OperandIndices[1];
+ RootFirstOpIdx = OperandIndices[4];
+ RootSecondOpIdx = OperandIndices[2];
+ break;
+ default:
+ llvm_unreachable("unexpected MachineCombinerPattern");
+ }
+
+ // Basically BuildMI but doesn't add implicit operands by default.
+ auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+ const MCInstrDesc &MCID, Register DestReg) {
+ return MachineInstrBuilder(
+ MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+ .setPCSections(MIMD.getPCSections())
+ .addReg(DestReg, RegState::Define);
+ };
+
// Create new instructions for insertion.
MachineInstrBuilder MIB1 =
- BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
- .addReg(RegX, getKillRegState(KillX))
- .addReg(RegY, getKillRegState(KillY));
+ buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+ for (const auto &MO : Prev.explicit_operands()) {
+ unsigned Idx = MO.getOperandNo();
+ // Skip the result operand we'd already added.
+ if (Idx == 0)
+ continue;
+ if (Idx == PrevFirstOpIdx)
+ MIB1.addReg(RegX, getKillRegState(KillX));
+ else if (Idx == PrevSecondOpIdx)
+ MIB1.addReg(RegY, getKillRegState(KillY));
+ else
+ MIB1.add(MO);
+ }
+ MIB1.copyImplicitOps(Prev);
if (SwapRootOperands) {
std::swap(RegA, NewVR);
@@ -1141,9 +1195,20 @@ void TargetInstrInfo::reassociateOps(
}
MachineInstrBuilder MIB2 =
- BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
- .addReg(RegA, getKillRegState(KillA))
- .addReg(NewVR, getKillRegState(KillNewVR));
+ buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+ for (const auto &MO : Root.explicit_operands()) {
+ unsigned Idx = MO.getOperandNo();
+ // Skip the result operand.
+ if (Idx == 0)
+ continue;
+ if (Idx == RootFirstOpIdx)
+ MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+ else if (Idx == RootSecondOpIdx)
+ MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+ else
+ MIB2 = MIB2.add(MO);
+ }
+ MIB2.copyImplicitOps(Root);
// Propagate FP flags from the original instructions.
// But clear poison-generating flags because those may not be valid now.
@@ -1187,25 +1252,17 @@ void TargetInstrInfo::genAlternativeCodeSequence(
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
// Select the previous instruction in the sequence based on the input pattern.
- MachineInstr *Prev = nullptr;
- switch (Pattern) {
- case MachineCombinerPattern::REASSOC_AX_BY:
- case MachineCombinerPattern::REASSOC_XA_BY:
- Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
- break;
- case MachineCombinerPattern::REASSOC_AX_YB:
- case MachineCombinerPattern::REASSOC_XA_YB:
- Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
- break;
- default:
- llvm_unreachable("Unknown pattern for machine combiner");
- }
+ std::array<unsigned, 5> OperandIndices;
+ getReassociateOperandIndices(Root, Pattern, OperandIndices);
+ MachineInstr *Prev =
+ MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
// Don't reassociate if Prev and Root are in different blocks.
if (Prev->getParent() != Root.getParent())
return;
- reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+ reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
+ InstIdxForVirtReg);
}
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 8331fc0b8c30..70ac1f8a592e 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1582,10 +1582,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
MachineFunction &MF = *Root.getMF();
for (auto *NewMI : InsInstrs) {
- assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
- NewMI->getOpcode(), RISCV::OpName::frm)) ==
- NewMI->getNumOperands() &&
- "Instruction has unexpected number of operands");
+ // We'd already added the FRM operand.
+ if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+ NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+ continue;
MachineInstrBuilder MIB(MF, NewMI);
MIB.add(FRM);
if (FRM.getImm() == RISCVFPRndMode::DYN)