summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLuke Lau <luke@igalia.com>2024-03-29 19:45:24 +0800
committerGitHub <noreply@github.com>2024-03-29 19:45:24 +0800
commit2a315d800bb352fe459a012006a42ac7cd63834e (patch)
tree398a7993caa76041bbc4ecb558dff5ad67a1619b
parent1403cf67a628712bddbe0055161ec68c7ebb468d (diff)
[RISCV] Combine (or disjoint ext, ext) -> vwadd (#86929)
DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd. This teaches combineBinOp_VLToVWBinOp_VL to recover it by treating it as an add.
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp26
-rw-r--r--llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll18
2 files changed, 27 insertions, 17 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3cd9ecb9dd68..e48ca4a905ce 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13530,7 +13530,7 @@ struct CombineResult;
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
/// fadd -> vfwadd | vfwadd_w
@@ -13678,6 +13678,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case ISD::OR:
return RISCVISD::VWADD_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -13700,6 +13701,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case ISD::OR:
return RISCVISD::VWADDU_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -13745,6 +13747,7 @@ struct NodeExtensionHelper {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
+ case ISD::OR:
return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
: RISCVISD::VWADDU_W_VL;
case ISD::SUB:
@@ -13865,6 +13868,10 @@ struct NodeExtensionHelper {
case ISD::MUL: {
return Root->getValueType(0).isScalableVector();
}
+ case ISD::OR: {
+ return Root->getValueType(0).isScalableVector() &&
+ Root->getFlags().hasDisjoint();
+ }
// Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
@@ -13945,7 +13952,8 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::OR: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13968,6 +13976,7 @@ struct NodeExtensionHelper {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::MUL:
+ case ISD::OR:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -14034,6 +14043,7 @@ struct CombineResult {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
+ case ISD::OR:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
@@ -14184,6 +14194,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
+ case ISD::OR:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
case RISCVISD::FADD_VL:
@@ -14227,9 +14238,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// Combine a binary operation to its equivalent VW or VW_W form.
/// The supported combines are:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
/// fadd_vl -> vfwadd | vfwadd_w
/// fsub_vl -> vfwsub | vfwsub_w
/// fmul_vl -> vfwmul
@@ -15889,8 +15900,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::AND:
return performANDCombine(N, DCI, Subtarget);
- case ISD::OR:
+ case ISD::OR: {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ return V;
return performORCombine(N, DCI, Subtarget);
+ }
case ISD::XOR:
return performXORCombine(N, DAG, Subtarget);
case ISD::MUL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index ed12afdd9595..66e6883dd1d3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1401,11 +1401,9 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vsll.vi v8, v10, 8
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vzext.vf4 v8, v9
-; CHECK-NEXT: vor.vv v8, v10, v8
+; CHECK-NEXT: vsll.vi v10, v10, 8
+; CHECK-NEXT: vzext.vf2 v11, v9
+; CHECK-NEXT: vwaddu.vv v8, v10, v11
; CHECK-NEXT: ret
%x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
%x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)
@@ -1450,9 +1448,8 @@ define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsca
define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwaddu_wv_disjoint_or:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v10, v9
-; CHECK-NEXT: vor.vv v8, v8, v10
+; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwaddu.wv v8, v8, v9
; CHECK-NEXT: ret
%y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
@@ -1462,9 +1459,8 @@ define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsc
define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwadd_wv_disjoint_or:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v9
-; CHECK-NEXT: vor.vv v8, v8, v10
+; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.wv v8, v8, v9
; CHECK-NEXT: ret
%y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32