diff options
author | Luke Lau <luke@igalia.com> | 2024-03-29 19:45:24 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-29 19:45:24 +0800 |
commit | 2a315d800bb352fe459a012006a42ac7cd63834e (patch) | |
tree | 398a7993caa76041bbc4ecb558dff5ad67a1619b | |
parent | 1403cf67a628712bddbe0055161ec68c7ebb468d (diff) |
[RISCV] Combine (or disjoint ext, ext) -> vwadd (#86929)
DAGCombiner (or InstCombine) will convert an add to an or if the bits
are disjoint, which can prevent what was originally an (add {s,z}ext,
{s,z}ext) from being selected as a vwadd.
This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as
an add.
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 26 | ||||
-rw-r--r-- | llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll | 18 |
2 files changed, 27 insertions, 17 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 3cd9ecb9dd68..e48ca4a905ce 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13530,7 +13530,7 @@ struct CombineResult; enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 }; /// Helper class for folding sign/zero extensions. /// In particular, this class is used for the following combines: -/// add | add_vl -> vwadd(u) | vwadd(u)_w +/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w /// sub | sub_vl -> vwsub(u) | vwsub(u)_w /// mul | mul_vl -> vwmul(u) | vwmul_su /// fadd -> vfwadd | vfwadd_w @@ -13678,6 +13678,7 @@ struct NodeExtensionHelper { case RISCVISD::ADD_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: + case ISD::OR: return RISCVISD::VWADD_VL; case ISD::SUB: case RISCVISD::SUB_VL: @@ -13700,6 +13701,7 @@ struct NodeExtensionHelper { case RISCVISD::ADD_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: + case ISD::OR: return RISCVISD::VWADDU_VL; case ISD::SUB: case RISCVISD::SUB_VL: @@ -13745,6 +13747,7 @@ struct NodeExtensionHelper { switch (Opcode) { case ISD::ADD: case RISCVISD::ADD_VL: + case ISD::OR: return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL; case ISD::SUB: @@ -13865,6 +13868,10 @@ struct NodeExtensionHelper { case ISD::MUL: { return Root->getValueType(0).isScalableVector(); } + case ISD::OR: { + return Root->getValueType(0).isScalableVector() && + Root->getFlags().hasDisjoint(); + } // Vector Widening Integer Add/Sub/Mul Instructions case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: @@ -13945,7 +13952,8 @@ struct NodeExtensionHelper { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: - case ISD::MUL: { + case ISD::MUL: + case ISD::OR: { SDLoc DL(Root); MVT VT = Root->getSimpleValueType(0); return getDefaultScalableVLOps(VT, DL, DAG, Subtarget); @@ -13968,6 +13976,7 @@ struct NodeExtensionHelper { switch (N->getOpcode()) { case ISD::ADD: case ISD::MUL: + case ISD::OR: case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::VWADD_W_VL: @@ -14034,6 +14043,7 @@ struct CombineResult { case ISD::ADD: case ISD::SUB: case ISD::MUL: + case ISD::OR: Merge = DAG.getUNDEF(Root->getValueType(0)); break; } @@ -14184,6 +14194,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: + case ISD::OR: case RISCVISD::ADD_VL: case RISCVISD::SUB_VL: case RISCVISD::FADD_VL: @@ -14227,9 +14238,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { /// Combine a binary operation to its equivalent VW or VW_W form. /// The supported combines are: -/// add_vl -> vwadd(u) | vwadd(u)_w -/// sub_vl -> vwsub(u) | vwsub(u)_w -/// mul_vl -> vwmul(u) | vwmul_su +/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w +/// sub | sub_vl -> vwsub(u) | vwsub(u)_w +/// mul | mul_vl -> vwmul(u) | vwmul_su /// fadd_vl -> vfwadd | vfwadd_w /// fsub_vl -> vfwsub | vfwsub_w /// fmul_vl -> vfwmul @@ -15889,8 +15900,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } case ISD::AND: return performANDCombine(N, DCI, Subtarget); - case ISD::OR: + case ISD::OR: { + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) + return V; return performORCombine(N, DCI, Subtarget); + } case ISD::XOR: return performXORCombine(N, DAG, Subtarget); case ISD::MUL: diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll index ed12afdd9595..66e6883dd1d3 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll @@ -1401,11 +1401,9 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v ; CHECK: # %bb.0: ; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma ; CHECK-NEXT: vzext.vf2 v10, v8 -; CHECK-NEXT: vsll.vi v8, v10, 8 -; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma -; CHECK-NEXT: vzext.vf2 v10, v8 -; CHECK-NEXT: vzext.vf4 v8, v9 -; CHECK-NEXT: vor.vv v8, v10, v8 +; CHECK-NEXT: vsll.vi v10, v10, 8 +; CHECK-NEXT: vzext.vf2 v11, v9 +; CHECK-NEXT: vwaddu.vv v8, v10, v11 ; CHECK-NEXT: ret %x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16> %x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer) @@ -1450,9 +1448,8 @@ define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsca define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) { ; CHECK-LABEL: vwaddu_wv_disjoint_or: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma -; CHECK-NEXT: vzext.vf2 v10, v9 -; CHECK-NEXT: vor.vv v8, v8, v10 +; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma +; CHECK-NEXT: vwaddu.wv v8, v8, v9 ; CHECK-NEXT: ret %y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32> %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32 @@ -1462,9 +1459,8 @@ define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsc define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) { ; CHECK-LABEL: vwadd_wv_disjoint_or: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma -; CHECK-NEXT: vsext.vf2 v10, v9 -; CHECK-NEXT: vor.vv v8, v8, v10 +; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma +; CHECK-NEXT: vwadd.wv v8, v8, v9 ; CHECK-NEXT: ret %y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32> %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32 |