| author | Alexey Bataev <a.bataev@outlook.com> | 2024-04-05 14:29:26 -0400 |
| --- | --- | --- |
| committer | Alexey Bataev <a.bataev@outlook.com> | 2024-04-08 08:32:35 -0700 |
| commit | 4a1c53f9fabbc18d436dcd4d5b572b82656fbbf9 (patch) | |
| tree | 8cf50d2764bf11e014e20dfa5a6146cf2b26b577 | |
| parent | ed1b24bf8b76871ab00f365d8dc066b3f7f76202 (diff) | |
[SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics.
Alive2 proof for the abs transformations: https://alive2.llvm.org/ce/z/ivPZ26
Reviewers: RKSimon
Reviewed By: RKSimon
Pull Request: https://github.com/llvm/llvm-project/pull/86135
4 files changed, 117 insertions, 27 deletions
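Before the diff itself, an illustrative sketch (not taken from the patch; the value names are invented) of the rewrite this enables for abs, mirroring the store-abs-minbitwidth.ll test below: once the demotion analysis proves the operands fit into i16, the call no longer has to run at the extended width, and its second (int-min-is-poison) argument is forced to false because the signed minimum of the narrower type may now actually occur.

; Before: abs is executed at the extended width and truncated afterwards.
%wide = sext <4 x i8> %v to <4 x i32>
%abs = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %wide, i1 true)
%res = trunc <4 x i32> %abs to <4 x i16>

; After: abs runs at the demoted width; the i1 flag is cleared so an input
; equal to the narrow type's signed minimum yields a defined result.
%narrow = sext <4 x i8> %v to <4 x i16>
%res = call <4 x i16> @llvm.abs.v4i16(<4 x i16> %narrow, i1 false)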
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 6a662b2791bd..b41229e25a34 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7056,19 +7056,16 @@ bool BoUpSLP::areAllUsersVectorized(
 static std::pair<InstructionCost, InstructionCost>
 getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
-                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+                   ArrayRef<Type *> ArgTys) {
   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
   // Calculate the cost of the scalar and vector calls.
-  SmallVector<Type *, 4> VecTys;
-  for (Use &Arg : CI->args())
-    VecTys.push_back(
-        FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
   FastMathFlags FMF;
   if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
     FMF = FPCI->getFastMathFlags();
   SmallVector<const Value *> Arguments(CI->args());
-  IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, VecTys, FMF,
+  IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, ArgTys, FMF,
                                     dyn_cast<IntrinsicInst>(CI));
   auto IntrinsicCost =
       TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
@@ -7081,8 +7078,8 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
   if (!CI->isNoBuiltin() && VecFunc) {
     // Calculate the cost of the vector library call.
     // If the corresponding vector call is cheaper, return its cost.
-    LibCost = TTI->getCallInstrCost(nullptr, VecTy, VecTys,
-                                    TTI::TCK_RecipThroughput);
+    LibCost =
+        TTI->getCallInstrCost(nullptr, VecTy, ArgTys, TTI::TCK_RecipThroughput);
   }
   return {IntrinsicCost, LibCost};
 }
@@ -8508,6 +8505,30 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
   return TTI::CastContextHint::None;
 }
 
+/// Builds the argument types vector for the given call instruction with the
+/// given \p ID for the specified vector factor.
+static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
+                                                  const Intrinsic::ID ID,
+                                                  const unsigned VF,
+                                                  unsigned MinBW) {
+  SmallVector<Type *> ArgTys;
+  for (auto [Idx, Arg] : enumerate(CI->args())) {
+    if (ID != Intrinsic::not_intrinsic) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+        ArgTys.push_back(Arg->getType());
+        continue;
+      }
+      if (MinBW > 0) {
+        ArgTys.push_back(FixedVectorType::get(
+            IntegerType::get(CI->getContext(), MinBW), VF));
+        continue;
+      }
+    }
+    ArgTys.push_back(FixedVectorType::get(Arg->getType(), VF));
+  }
+  return ArgTys;
+}
+
 InstructionCost
 BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                       SmallPtrSetImpl<Value *> &CheckedExtracts) {
@@ -9074,7 +9095,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+      SmallVector<Type *> ArgTys =
+          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
+                                 It != MinBWs.end() ? It->second.first : 0);
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12548,7 +12573,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      SmallVector<Type *> ArgTys =
+          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
+                                 It != MinBWs.end() ? It->second.first : 0);
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
@@ -12557,8 +12585,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       SmallVector<Type *, 2> TysForDecl;
       // Add return type if intrinsic is overloaded on it.
      if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
-        TysForDecl.push_back(
-            FixedVectorType::get(CI->getType(), E->Scalars.size()));
+        TysForDecl.push_back(VecTy);
       auto *CEI = cast<CallInst>(VL0);
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         ValueList OpVL;
@@ -12566,7 +12593,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         // vectorized.
         if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
           ScalarArg = CEI->getArgOperand(I);
-          OpVecs.push_back(CEI->getArgOperand(I));
+          // If reducing the bitwidth of the abs intrinsic, its second argument
+          // must be set to false (do not return poison if the value is signed min).
+          if (ID == Intrinsic::abs && It != MinBWs.end() &&
+              It->second.first < DL->getTypeSizeInBits(CEI->getType()))
+            ScalarArg = Builder.getFalse();
+          OpVecs.push_back(ScalarArg);
           if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
             TysForDecl.push_back(ScalarArg->getType());
           continue;
@@ -12579,10 +12611,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         }
         ScalarArg = CEI->getArgOperand(I);
         if (cast<VectorType>(OpVec->getType())->getElementType() !=
-            ScalarArg->getType()) {
+                ScalarArg->getType() &&
+            It == MinBWs.end()) {
           auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
                                               VecTy->getNumElements());
           OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+        } else if (It != MinBWs.end()) {
+          OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
         }
         LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
         OpVecs.push_back(OpVec);
@@ -14326,6 +14361,62 @@ bool BoUpSLP::collectValuesToDemote(
     return TryProcessInstruction(I, *ITE, BitWidth, Ops);
   }
+  case Instruction::Call: {
+    auto *IC = dyn_cast<IntrinsicInst>(I);
+    if (!IC)
+      break;
+    Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+    if (ID != Intrinsic::abs && ID != Intrinsic::smin &&
+        ID != Intrinsic::smax && ID != Intrinsic::umin && ID != Intrinsic::umax)
+      break;
+    SmallVector<Value *> Operands(1, I->getOperand(0));
+    function_ref<bool(unsigned, unsigned)> CallChecker;
+    auto CompChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
+      assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
+      if (ID == Intrinsic::umin || ID == Intrinsic::umax) {
+        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+        return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) &&
+               MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
+      }
+      assert((ID == Intrinsic::smin || ID == Intrinsic::smax) &&
+             "Expected min/max intrinsics only.");
+      unsigned SignBits = OrigBitWidth - BitWidth;
+      return SignBits <= ComputeNumSignBits(I->getOperand(0), *DL, 0, AC,
+                                            nullptr, DT) &&
+             SignBits <=
+                 ComputeNumSignBits(I->getOperand(1), *DL, 0, AC, nullptr, DT);
+    };
+    End = 1;
+    if (ID != Intrinsic::abs) {
+      Operands.push_back(I->getOperand(1));
+      End = 2;
+      CallChecker = CompChecker;
+    }
+    InstructionCost BestCost =
+        std::numeric_limits<InstructionCost::CostType>::max();
+    unsigned BestBitWidth = BitWidth;
+    unsigned VF = ITE->Scalars.size();
+    // Choose the best bitwidth based on cost estimations.
+    auto Checker = [&](unsigned BitWidth, unsigned) {
+      unsigned MinBW = PowerOf2Ceil(BitWidth);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(IC, ID, VF, MinBW);
+      auto VecCallCosts = getVectorCallCosts(
+          IC,
+          FixedVectorType::get(IntegerType::get(IC->getContext(), MinBW), VF),
+          TTI, TLI, ArgTys);
+      InstructionCost Cost = std::min(VecCallCosts.first, VecCallCosts.second);
+      if (Cost < BestCost) {
+        BestCost = Cost;
+        BestBitWidth = BitWidth;
+      }
+      return false;
+    };
+    [[maybe_unused]] bool NeedToExit;
+    (void)AttemptCheckBitwidth(Checker, NeedToExit);
+    BitWidth = BestBitWidth;
+    return TryProcessInstruction(I, *ITE, BitWidth, Operands, CallChecker);
+  }
+
   // Otherwise, conservatively give up.
   default:
     break;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/call-arg-reduced-by-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/call-arg-reduced-by-minbitwidth.ll
index 27c9655f94d3..82966124d3ba 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/call-arg-reduced-by-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/call-arg-reduced-by-minbitwidth.ll
@@ -11,9 +11,7 @@ define void @test(ptr %0, i8 %1, i1 %cmp12.i) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <8 x i8> [[TMP4]], <8 x i8> poison, <8 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[PRE:%.*]]
 ; CHECK:       pre:
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <8 x i8> [[TMP5]] to <8 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call <8 x i32> @llvm.umax.v8i32(<8 x i32> [[TMP6]], <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>)
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <8 x i32> [[TMP7]] to <8 x i8>
+; CHECK-NEXT:    [[TMP8:%.*]] = call <8 x i8> @llvm.umax.v8i8(<8 x i8> [[TMP5]], <8 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
 ; CHECK-NEXT:    [[TMP9:%.*]] = add <8 x i8> [[TMP8]], <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>
 ; CHECK-NEXT:    [[TMP10:%.*]] = select <8 x i1> [[TMP3]], <8 x i8> [[TMP9]], <8 x i8> [[TMP5]]
 ; CHECK-NEXT:    store <8 x i8> [[TMP10]], ptr [[TMP0]], align 1
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
index a05d4fdd6315..9fa88084aaa0 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
@@ -5,12 +5,14 @@ define void @test() {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
-; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+; CHECK-NEXT:    [[ADD:%.*]] = zext i2 [[TMP3]] to i32
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[ADD]], 0
-; CHECK-NEXT:    [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+; CHECK-NEXT:    [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
 ; CHECK-NEXT:    [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
 ; CHECK-NEXT:    [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
 ; CHECK-NEXT:    [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
index e8b854b7cea6..df7312e3d2b5 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
-; CHECK-NEXT:    [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 false)
 ; CHECK-NEXT:    store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
 ; CHECK-NEXT:    ret i32 undef
 ;
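A closing reading aid, not part of the commit: the CompChecker added in collectValuesToDemote accepts a narrower bitwidth for umin/umax only when MaskedValueIsZero proves the demoted-away high bits of both operands are zero (and, for smin/smax, when ComputeNumSignBits proves enough sign bits). A hedged sketch of the umax case, with invented value names:

; Both operands are zero-extended from i8, so bits 8..31 are known zero
; and the maximum can be computed at the narrow width.
%x = zext <4 x i8> %xs to <4 x i32>
%y = zext <4 x i8> %ys to <4 x i32>
%m = call <4 x i32> @llvm.umax.v4i32(<4 x i32> %x, <4 x i32> %y)

; ...which is equivalent to taking umax at i8 and extending the result:
%m8 = call <4 x i8> @llvm.umax.v4i8(<4 x i8> %xs, <4 x i8> %ys)
%m.ext = zext <4 x i8> %m8 to <4 x i32>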