diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index af5e7c9bc385..3738220b4f81 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -112,6 +112,7 @@ private:
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
+  bool foldShuffleOfCastops(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
   bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
@@ -1432,6 +1433,75 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
   return true;
 }
 
+/// Try to convert "shuffle (castop x), (castop y)" of two one-use casts with
+/// matching source types into "castop (shuffle x, y)" if it is not costlier.
+bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
+  Value *V0, *V1;
+  ArrayRef<int> Mask;
+  if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
+                           m_Mask(Mask))))
+    return false;
+
+  auto *C0 = dyn_cast<CastInst>(V0);
+  auto *C1 = dyn_cast<CastInst>(V1);
+  if (!C0 || !C1)
+    return false;
+
+  Instruction::CastOps Opcode = C0->getOpcode();
+  if (Opcode == Instruction::BitCast || C0->getSrcTy() != C1->getSrcTy())
+    return false;
+
+  // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
+  if (Opcode != C1->getOpcode()) {
+    if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
+      Opcode = Instruction::SExt;
+    else
+      return false;
+  }
+
+  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
+  auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
+  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
+    return false;
+  assert(CastDstTy->getElementCount() == CastSrcTy->getElementCount() &&
+         "Unexpected src/dst element counts");
+
+  auto *NewShuffleDstTy =
+      FixedVectorType::get(CastSrcTy->getScalarType(), Mask.size());
+
+  // Only fold if shuffle + one cast costs no more than two casts + shuffle.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  InstructionCost OldCost =
+      TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
+                           TTI::CastContextHint::None, CostKind) +
+      TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
+                           TTI::CastContextHint::None, CostKind);
+  OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
+                                CastDstTy, Mask, CostKind);
+
+  InstructionCost NewCost = TTI.getShuffleCost(
+      TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, Mask, CostKind);
+  NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
+                                  TTI::CastContextHint::None, CostKind);
+  if (NewCost > OldCost)
+    return false;
+
+  Value *Shuf =
+      Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0), Mask);
+  Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
+
+  // Intersect flags from the old casts.
+  if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
+    NewInst->copyIRFlags(C0);
+    NewInst->andIRFlags(C1);
+  }
+
+  replaceValue(I, *Cast);
+  return true;
+}
+
 /// Given a commutative reduction, the order of the input lanes does not alter
 /// the results. We can use this to remove certain shuffles feeding the
 /// reduction, removing the need to shuffle at all.
@@ -1986,6 +2056,7 @@ bool VectorCombine::run() { break; case Instruction::ShuffleVector: MadeChange |= foldShuffleOfBinops(I); + MadeChange |= foldShuffleOfCastops(I); MadeChange |= foldSelectShuffle(I); break; case Instruction::BitCast: |