summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
-rw-r--r--llvm/lib/Transforms/Vectorize/VectorCombine.cpp71
1 files changed, 71 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index af5e7c9bc385..3738220b4f81 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -112,6 +112,7 @@ private:
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
+ bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldTruncFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
@@ -1432,6 +1433,75 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
return true;
}
+/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
+/// into "castop (shuffle)".
+bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
+ Value *V0, *V1;
+ ArrayRef<int> Mask;
+ if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
+ m_Mask(Mask))))
+ return false;
+
+ auto *C0 = dyn_cast<CastInst>(V0);
+ auto *C1 = dyn_cast<CastInst>(V1);
+ if (!C0 || !C1)
+ return false;
+
+ Instruction::CastOps Opcode = C0->getOpcode();
+ if (Opcode == Instruction::BitCast || C0->getSrcTy() != C1->getSrcTy())
+ return false;
+
+ // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
+ if (Opcode != C1->getOpcode()) {
+ if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
+ Opcode = Instruction::SExt;
+ else
+ return false;
+ }
+
+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+ auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
+ auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
+ if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
+ return false;
+ assert(CastDstTy->getElementCount() == CastSrcTy->getElementCount() &&
+ "Unexpected src/dst element counts");
+
+ auto *NewShuffleDstTy =
+ FixedVectorType::get(CastSrcTy->getScalarType(), Mask.size());
+
+ // Try to replace a castop with a shuffle if the shuffle is not costly.
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+ InstructionCost OldCost =
+ TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
+ TTI::CastContextHint::None, CostKind) +
+ TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
+ TTI::CastContextHint::None, CostKind);
+ OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
+ CastDstTy, Mask, CostKind);
+
+ InstructionCost NewCost = TTI.getShuffleCost(
+ TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, Mask, CostKind);
+ NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
+ TTI::CastContextHint::None, CostKind);
+ if (NewCost > OldCost)
+ return false;
+
+ Value *Shuf =
+ Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0), Mask);
+ Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
+
+ // Intersect flags from the old casts.
+ if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
+ NewInst->copyIRFlags(C0);
+ NewInst->andIRFlags(C1);
+ }
+
+ replaceValue(I, *Cast);
+ return true;
+}
+
/// Given a commutative reduction, the order of the input lanes does not alter
/// the results. We can use this to remove certain shuffles feeding the
/// reduction, removing the need to shuffle at all.
@@ -1986,6 +2056,7 @@ bool VectorCombine::run() {
break;
case Instruction::ShuffleVector:
MadeChange |= foldShuffleOfBinops(I);
+ MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldSelectShuffle(I);
break;
case Instruction::BitCast: