summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Alexey Bataev <a.bataev@outlook.com> 2024-04-03 09:37:23 -0700
committer Alexey Bataev <a.bataev@outlook.com> 2024-04-03 10:00:03 -0700
commit 07a566793b2f94d0de6b95b7e6d1146b0d7ffe49 (patch)
tree c3a5572f16367d6dbf21ea9912d5bc6f416128e0
parent 33992eabc7834e32094e7187dc10225f1a3773a5 (diff)
[SLP]Fix PR87477: fix alternate node cast cost/codegen.
The actual source and destination type sizes must be compared to pick the proper cast opcode.
-rw-r--r-- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp 65
-rw-r--r-- llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll 34
2 files changed, 74 insertions, 25 deletions
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index cb55992051eb..7928d29d6dfa 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9063,25 +9063,35 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind,
E->getAltOp());
} else {
- Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType();
- Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType();
- auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size());
- auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size());
- if (It != MinBWs.end()) {
- if (!MinBWs.contains(getOperandEntry(E, 0)))
- VecCost =
- TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, Src0Ty,
- TTI::CastContextHint::None, CostKind);
- LLVM_DEBUG({
- dbgs() << "SLP: alternate extension, which should be truncated.\n";
- E->dump();
- });
- return VecCost;
+ Type *SrcSclTy = E->getMainOp()->getOperand(0)->getType();
+ auto *SrcTy = FixedVectorType::get(SrcSclTy, VL.size());
+ if (SrcSclTy->isIntegerTy() && ScalarTy->isIntegerTy()) {
+ auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
+ unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+ unsigned SrcBWSz =
+ DL->getTypeSizeInBits(E->getMainOp()->getOperand(0)->getType());
+ if (SrcIt != MinBWs.end()) {
+ SrcBWSz = SrcIt->second.first;
+ SrcSclTy = IntegerType::get(SrcSclTy->getContext(), SrcBWSz);
+ SrcTy = FixedVectorType::get(SrcSclTy, VL.size());
+ }
+ if (BWSz <= SrcBWSz) {
+ if (BWSz < SrcBWSz)
+ VecCost =
+ TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, SrcTy,
+ TTI::CastContextHint::None, CostKind);
+ LLVM_DEBUG({
+ dbgs()
+ << "SLP: alternate extension, which should be truncated.\n";
+ E->dump();
+ });
+ return VecCost;
+ }
}
- VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, Src0Ty,
+ VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, SrcTy,
TTI::CastContextHint::None, CostKind);
VecCost +=
- TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
+ TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, SrcTy,
TTI::CastContextHint::None, CostKind);
}
SmallVector<int> Mask;
@@ -12591,15 +12601,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
CmpInst::Predicate AltPred = AltCI->getPredicate();
V1 = Builder.CreateCmp(AltPred, LHS, RHS);
} else {
- if (It != MinBWs.end()) {
- if (!MinBWs.contains(getOperandEntry(E, 0)))
- LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
- assert(LHS->getType() == VecTy && "Expected same type as operand.");
- if (auto *I = dyn_cast<Instruction>(LHS))
- LHS = propagateMetadata(I, E->Scalars);
- E->VectorizedValue = LHS;
- ++NumVectorInstructions;
- return LHS;
+ if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) {
+ unsigned SrcBWSz = DL->getTypeSizeInBits(
+ cast<VectorType>(LHS->getType())->getElementType());
+ unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+ if (BWSz <= SrcBWSz) {
+ if (BWSz < SrcBWSz)
+ LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
+ assert(LHS->getType() == VecTy && "Expected same type as operand.");
+ if (auto *I = dyn_cast<Instruction>(LHS))
+ LHS = propagateMetadata(I, E->Scalars);
+ E->VectorizedValue = LHS;
+ ++NumVectorInstructions;
+ return LHS;
+ }
}
V0 = Builder.CreateCast(
static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
diff --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll
new file mode 100644
index 000000000000..979d0ea66bac
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S --passes=slp-vectorizer -mtriple=systemz-unknown -mcpu=z15 < %s -slp-threshold=-10 | FileCheck %s
+
+define i32 @test(ptr %0, ptr %1) {
+; CHECK-LABEL: define i32 @test(
+; CHECK-SAME: ptr [[TMP0:%.*]], ptr [[TMP1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr inttoptr (i64 32 to ptr), align 32
+; CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[TMP1]], align 8
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i64 32
+; CHECK-NEXT: [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 8
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i64> poison, i64 [[TMP6]], i32 0
+; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x i64> [[TMP7]], i64 [[TMP3]], i32 1
+; CHECK-NEXT: [[TMP9:%.*]] = icmp ne <2 x i64> [[TMP14]], zeroinitializer
+; CHECK-NEXT: [[TMP16:%.*]] = sext <2 x i1> [[TMP9]] to <2 x i8>
+; CHECK-NEXT: [[TMP11:%.*]] = zext <2 x i1> [[TMP9]] to <2 x i8>
+; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i8> [[TMP16]], <2 x i8> [[TMP11]], <2 x i32> <i32 0, i32 3>
+; CHECK-NEXT: [[TMP13:%.*]] = extractelement <2 x i8> [[TMP12]], i32 0
+; CHECK-NEXT: [[DOTNEG:%.*]] = sext i8 [[TMP13]] to i32
+; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x i8> [[TMP12]], i32 1
+; CHECK-NEXT: [[TMP8:%.*]] = sext i8 [[TMP15]] to i32
+; CHECK-NEXT: [[TMP10:%.*]] = add nsw i32 [[DOTNEG]], [[TMP8]]
+; CHECK-NEXT: ret i32 [[TMP10]]
+;
+ %3 = load i64, ptr inttoptr (i64 32 to ptr), align 32
+ %4 = load ptr, ptr %1, align 8
+ %5 = getelementptr inbounds i8, ptr %4, i64 32
+ %6 = load i64, ptr %5, align 8
+ %7 = icmp ne i64 %3, 0
+ %8 = zext i1 %7 to i32
+ %9 = icmp ne i64 %6, 0
+ %.neg = sext i1 %9 to i32
+ %10 = add nsw i32 %.neg, %8
+ ret i32 %10
+}