diff options
author | Craig Topper <craig.topper@sifive.com> | 2023-12-21 14:34:49 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-21 14:34:49 -0800 |
commit | e64f5d6305c447b1ec3bc31128753b28f4e87f32 (patch) | |
tree | 38e7694f0ba45b1958d0a6bbd7354c432e8a5c0f | |
parent | 3ca9bcc6ccd0de4e05c7b7749c24f94e5f184b45 (diff) |
[RISCV] Replace RISCVISD::VP_MERGE_VL with a new node that has a separate passthru operand. (#75682)
ISD::VP_MERGE treats the false operand as the source for elements past
VL. The vmerge instruction encodes 3 registers and treats the vd
register as the source for the tail.
This patch adds a new ISD opcode that models the tail source explicitly.
During lowering we copy the false operand into this new passthru operand.
I think we can merge RISCVISD::VSELECT_VL with this new opcode by using
an UNDEF passthru, but I'll save that for another patch.
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 27 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.h | 6 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td | 125 |
3 files changed, 91 insertions, 67 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d6dedd669ffd..40518097fcce 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -5530,7 +5530,7 @@ static unsigned getRISCVVLOp(SDValue Op) { case ISD::VP_SELECT: return RISCVISD::VSELECT_VL; case ISD::VP_MERGE: - return RISCVISD::VP_MERGE_VL; + return RISCVISD::VMERGE_VL; case ISD::VP_ASHR: return RISCVISD::SRA_VL; case ISD::VP_LSHR: @@ -5578,6 +5578,8 @@ static bool hasMergeOp(unsigned Opcode) { return true; if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL) return true; + if (Opcode == RISCVISD::VMERGE_VL) + return true; return false; } @@ -8242,8 +8244,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG, AVL); // TUMA or TUMU: Currently we always emit tumu policy regardless of tuma. // It's fine because vmerge does not care mask policy. - return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff, - AVL); + return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec, MaskedOff, + MaskedOff, AVL); } } @@ -10316,9 +10318,20 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const { for (const auto &OpIdx : enumerate(Op->ops())) { SDValue V = OpIdx.value(); assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!"); - // Add dummy merge value before the mask. - if (HasMergeOp && *ISD::getVPMaskIdx(Op.getOpcode()) == OpIdx.index()) - Ops.push_back(DAG.getUNDEF(ContainerVT)); + // Add dummy merge value before the mask. Or if there isn't a mask, before + // EVL. + if (HasMergeOp) { + auto MaskIdx = ISD::getVPMaskIdx(Op.getOpcode()); + if (MaskIdx) { + if (*MaskIdx == OpIdx.index()) + Ops.push_back(DAG.getUNDEF(ContainerVT)); + } else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) == + OpIdx.index()) { + // For VP_MERGE, copy the false operand instead of an undef value. 
+ assert(Op.getOpcode() == ISD::VP_MERGE); + Ops.push_back(Ops.back()); + } + } // Pass through operands which aren't fixed-length vectors. if (!V.getValueType().isFixedLengthVector()) { Ops.push_back(V); @@ -18658,7 +18671,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VNSRL_VL) NODE_NAME_CASE(SETCC_VL) NODE_NAME_CASE(VSELECT_VL) - NODE_NAME_CASE(VP_MERGE_VL) + NODE_NAME_CASE(VMERGE_VL) NODE_NAME_CASE(VMAND_VL) NODE_NAME_CASE(VMOR_VL) NODE_NAME_CASE(VMXOR_VL) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 2d9f716cdf9a..58ed611efc83 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -332,10 +332,8 @@ enum NodeType : unsigned { // Vector select with an additional VL operand. This operation is unmasked. VSELECT_VL, - // Vector select with operand #2 (the value when the condition is false) tied - // to the destination and an additional VL operand. This operation is - // unmasked. - VP_MERGE_VL, + // General vmerge node with mask, true, false, passthru, and vl operands. + VMERGE_VL, // Mask binary operators. 
VMAND_VL, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index dc6b57fad321..33bdc3366aa3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -344,7 +344,14 @@ def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [ ]>; def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>; -def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>; + +def SDT_RISCVVMERGE_VL : SDTypeProfile<1, 5, [ + SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>, + SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameAs<0, 4>, + SDTCisVT<5, XLenVT> +]>; + +def riscv_vmerge_vl : SDNode<"RISCVISD::VMERGE_VL", SDT_RISCVVMERGE_VL>; def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>, SDTCisVT<1, XLenVT>]>; @@ -675,14 +682,14 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop, op2_reg_class:$rs2, GPR:$vl, sew, TAIL_AGNOSTIC)>; // Tail undisturbed - def : Pat<(riscv_vp_merge_vl true_mask, + def : Pat<(riscv_vmerge_vl true_mask, (result_type (vop result_reg_class:$rs1, (op2_type op2_reg_class:$rs2), srcvalue, true_mask, VLOpFrag)), - result_reg_class:$rs1, VLOpFrag), + result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag), (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED") result_reg_class:$rs1, op2_reg_class:$rs2, @@ -712,14 +719,14 @@ multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop, FRM_DYN, GPR:$vl, sew, TAIL_AGNOSTIC)>; // Tail undisturbed - def : Pat<(riscv_vp_merge_vl true_mask, + def : Pat<(riscv_vmerge_vl true_mask, (result_type (vop result_reg_class:$rs1, (op2_type op2_reg_class:$rs2), srcvalue, true_mask, VLOpFrag)), - result_reg_class:$rs1, VLOpFrag), + result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag), (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED") result_reg_class:$rs1, op2_reg_class:$rs2, @@ -1697,21 +1704,21 @@ multiclass 
VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> { foreach vti = AllIntegerVectors in { defvar suffix = vti.LMul.MX; let Predicates = GetVTypePredicates<vti>.Predicates in { - def : Pat<(riscv_vp_merge_vl (vti.Mask V0), + def : Pat<(riscv_vmerge_vl (vti.Mask V0), (vti.Vector (op vti.RegClass:$rd, (riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rs2, srcvalue, (vti.Mask true_mask), VLOpFrag), srcvalue, (vti.Mask true_mask), VLOpFrag)), - vti.RegClass:$rd, VLOpFrag), + vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag), (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK") vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>; - def : Pat<(riscv_vp_merge_vl (vti.Mask V0), + def : Pat<(riscv_vmerge_vl (vti.Mask V0), (vti.Vector (op vti.RegClass:$rd, (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rs2, srcvalue, (vti.Mask true_mask), VLOpFrag), srcvalue, (vti.Mask true_mask), VLOpFrag)), - vti.RegClass:$rd, VLOpFrag), + vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag), (!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK") vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>; @@ -1840,17 +1847,17 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> { foreach vti = AllFloatVectors in { defvar suffix = vti.LMul.MX; let Predicates = GetVTypePredicates<vti>.Predicates in { - def : Pat<(riscv_vp_merge_vl (vti.Mask V0), + def : Pat<(riscv_vmerge_vl (vti.Mask V0), (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2, vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)), - vti.RegClass:$rd, VLOpFrag), + vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag), (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK") vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>; - def : Pat<(riscv_vp_merge_vl (vti.Mask V0), + def : Pat<(riscv_vmerge_vl (vti.Mask V0), (vti.Vector (vop 
(SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2, vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)), - vti.RegClass:$rd, VLOpFrag), + vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag), (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK") vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>; @@ -1876,10 +1883,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> { foreach vti = AllFloatVectors in { defvar suffix = vti.LMul.MX; let Predicates = GetVTypePredicates<vti>.Predicates in { - def : Pat<(riscv_vp_merge_vl (vti.Mask V0), + def : Pat<(riscv_vmerge_vl (vti.Mask V0), (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2, vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)), - vti.RegClass:$rd, VLOpFrag), + vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag), (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK") vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, (vti.Mask V0), @@ -1887,10 +1894,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> { // RISCVInsertReadWriteCSR FRM_DYN, GPR:$vl, vti.Log2SEW, TU_MU)>; - def : Pat<(riscv_vp_merge_vl (vti.Mask V0), + def : Pat<(riscv_vmerge_vl (vti.Mask V0), (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2, vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)), - vti.RegClass:$rd, VLOpFrag), + vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag), (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK") vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, (vti.Mask V0), @@ -2273,29 +2280,32 @@ foreach vti = AllIntegerVectors in { (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0), - vti.RegClass:$rs1, - vti.RegClass:$rs2, - VLOpFrag)), + def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0), + vti.RegClass:$rs1, + 
vti.RegClass:$rs2, + vti.RegClass:$merge, + VLOpFrag)), (!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX) - vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1, - (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; + vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1, + (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0), - (SplatPat XLenVT:$rs1), - vti.RegClass:$rs2, - VLOpFrag)), + def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0), + (SplatPat XLenVT:$rs1), + vti.RegClass:$rs2, + vti.RegClass:$merge, + VLOpFrag)), (!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX) - vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1, - (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; - - def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0), - (SplatPat_simm5 simm5:$rs1), - vti.RegClass:$rs2, - VLOpFrag)), + vti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1, + (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; + + def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0), + (SplatPat_simm5 simm5:$rs1), + vti.RegClass:$rs2, + vti.RegClass:$merge, + VLOpFrag)), (!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX) - vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1, - (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; + vti.RegClass:$merge, vti.RegClass:$rs2, simm5:$rs1, + (vti.Mask V0), GPR:$vl, vti.Log2SEW)>; } } @@ -2493,21 +2503,23 @@ foreach fvti = AllFloatVectors in { (fvti.Vector (IMPLICIT_DEF)), fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>; - def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0), - fvti.RegClass:$rs1, - fvti.RegClass:$rs2, - VLOpFrag)), - (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX) - fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0), - GPR:$vl, fvti.Log2SEW)>; - - def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0), - (SplatFPOp (fvti.Scalar fpimm0)), - fvti.RegClass:$rs2, - VLOpFrag)), - (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX) - fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, 
(fvti.Mask V0), - GPR:$vl, fvti.Log2SEW)>; + def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0), + fvti.RegClass:$rs1, + fvti.RegClass:$rs2, + fvti.RegClass:$merge, + VLOpFrag)), + (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX) + fvti.RegClass:$merge, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0), + GPR:$vl, fvti.Log2SEW)>; + + def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0), + (SplatFPOp (fvti.Scalar fpimm0)), + fvti.RegClass:$rs2, + fvti.RegClass:$merge, + VLOpFrag)), + (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX) + fvti.RegClass:$merge, fvti.RegClass:$rs2, 0, (fvti.Mask V0), + GPR:$vl, fvti.Log2SEW)>; } let Predicates = GetVTypePredicates<fvti>.Predicates in { @@ -2521,12 +2533,13 @@ foreach fvti = AllFloatVectors in { (fvti.Scalar fvti.ScalarRegClass:$rs1), (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>; - def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0), - (SplatFPOp fvti.ScalarRegClass:$rs1), - fvti.RegClass:$rs2, - VLOpFrag)), + def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0), + (SplatFPOp fvti.ScalarRegClass:$rs1), + fvti.RegClass:$rs2, + fvti.RegClass:$merge, + VLOpFrag)), (!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX) - fvti.RegClass:$rs2, fvti.RegClass:$rs2, + fvti.RegClass:$merge, fvti.RegClass:$rs2, (fvti.Scalar fvti.ScalarRegClass:$rs1), (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>; |