aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
-rw-r--r--llvm/lib/Transforms/Vectorize/VectorCombine.cpp143
1 files changed, 102 insertions, 41 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 5b9fe1c..7fa1b433 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1703,9 +1703,44 @@ generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
return NItem;
}
+/// Detect concat of multiple values into a vector
+static bool isFreeConcat(ArrayRef<InstLane> Item,
+ const TargetTransformInfo &TTI) {
+ auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
+ unsigned NumElts = Ty->getNumElements();
+ if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
+ return false;
+
+ // Check that the concat is free, usually meaning that the type will be split
+ // during legalization.
+ SmallVector<int, 16> ConcatMask(NumElts * 2);
+ std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
+ if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask,
+ TTI::TCK_RecipThroughput) != 0)
+ return false;
+
+ unsigned NumSlices = Item.size() / NumElts;
+ // Currently we generate a tree of shuffles for the concats, which limits us
+ // to a power2.
+ if (!isPowerOf2_32(NumSlices))
+ return false;
+ for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
+ Use *SliceV = Item[Slice * NumElts].first;
+ if (!SliceV || SliceV->get()->getType() != Ty)
+ return false;
+ for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
+ auto [V, Lane] = Item[Slice * NumElts + Elt];
+ if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
+ return false;
+ }
+ }
+ return true;
+}
+
static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
const SmallPtrSet<Use *, 4> &IdentityLeafs,
const SmallPtrSet<Use *, 4> &SplatLeafs,
+ const SmallPtrSet<Use *, 4> &ConcatLeafs,
IRBuilder<> &Builder) {
auto [FrontU, FrontLane] = Item.front();
@@ -1713,13 +1748,28 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
return FrontU->get();
}
if (SplatLeafs.contains(FrontU)) {
- if (auto *ILI = dyn_cast<Instruction>(FrontU))
- Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
- else if (auto *Arg = dyn_cast<Argument>(FrontU))
- Builder.SetInsertPointPastAllocas(Arg->getParent());
SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
return Builder.CreateShuffleVector(FrontU->get(), Mask);
}
+ if (ConcatLeafs.contains(FrontU)) {
+ unsigned NumElts =
+ cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
+ SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
+ for (unsigned S = 0; S < Values.size(); ++S)
+ Values[S] = Item[S * NumElts].first->get();
+
+ while (Values.size() > 1) {
+ NumElts *= 2;
+ SmallVector<int, 16> Mask(NumElts, 0);
+ std::iota(Mask.begin(), Mask.end(), 0);
+ SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
+ for (unsigned S = 0; S < NewValues.size(); ++S)
+ NewValues[S] =
+ Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
+ Values = NewValues;
+ }
+ return Values[0];
+ }
auto *I = cast<Instruction>(FrontU->get());
auto *II = dyn_cast<IntrinsicInst>(I);
@@ -1730,8 +1780,9 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
Ops[Idx] = II->getOperand(Idx);
continue;
}
- Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
- Ty, IdentityLeafs, SplatLeafs, Builder);
+ Ops[Idx] =
+ generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty,
+ IdentityLeafs, SplatLeafs, ConcatLeafs, Builder);
}
SmallVector<Value *, 8> ValueList;
@@ -1739,7 +1790,6 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
if (Lane.first)
ValueList.push_back(Lane.first->get());
- Builder.SetInsertPoint(I);
Type *DstTy =
FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
if (auto *BI = dyn_cast<BinaryOperator>(I)) {
@@ -1790,7 +1840,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
SmallVector<SmallVector<InstLane>> Worklist;
Worklist.push_back(Start);
- SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs;
+ SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
unsigned NumVisited = 0;
while (!Worklist.empty()) {
@@ -1839,7 +1889,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
// We need each element to be the same type of value, and check that each
// element has a single use.
- if (!all_of(drop_begin(Item), [Item](InstLane IL) {
+ if (all_of(drop_begin(Item), [Item](InstLane IL) {
Value *FrontV = Item.front().first->get();
if (!IL.first)
return true;
@@ -1860,40 +1910,49 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
return !II || (isa<IntrinsicInst>(FrontV) &&
II->getIntrinsicID() ==
cast<IntrinsicInst>(FrontV)->getIntrinsicID());
- }))
- return false;
-
- // Check the operator is one that we support. We exclude div/rem in case
- // they hit UB from poison lanes.
- if ((isa<BinaryOperator>(FrontU) &&
- !cast<BinaryOperator>(FrontU)->isIntDivRem()) ||
- isa<CmpInst>(FrontU)) {
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
- } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) {
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
- } else if (isa<SelectInst>(FrontU)) {
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
- } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
- II && isTriviallyVectorizable(II->getIntrinsicID())) {
- for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
- if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
- if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
- Value *FrontV = Item.front().first->get();
- Use *U = IL.first;
- return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
- cast<Instruction>(FrontV)->getOperand(Op));
- }))
- return false;
- continue;
+ })) {
+ // Check the operator is one that we support.
+ if (isa<BinaryOperator, CmpInst>(FrontU)) {
+ // We exclude div/rem in case they hit UB from poison lanes.
+ if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
+ BO && BO->isIntDivRem())
+ return false;
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+ continue;
+ } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ continue;
+ } else if (isa<SelectInst>(FrontU)) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
+ continue;
+ } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
+ II && isTriviallyVectorizable(II->getIntrinsicID())) {
+ for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
+ if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
+ if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
+ Value *FrontV = Item.front().first->get();
+ Use *U = IL.first;
+ return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
+ cast<Instruction>(FrontV)->getOperand(Op));
+ }))
+ return false;
+ continue;
+ }
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
}
- Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
+ continue;
}
- } else {
- return false;
}
+
+ if (isFreeConcat(Item, TTI)) {
+ ConcatLeafs.insert(FrontU);
+ continue;
+ }
+
+ return false;
}
if (NumVisited <= 1)
@@ -1901,7 +1960,9 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
// If we got this far, we know the shuffles are superfluous and can be
// removed. Scan through again and generate the new tree of instructions.
- Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
+ Builder.SetInsertPoint(&I);
+ Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
+ ConcatLeafs, Builder);
replaceValue(I, *V);
return true;
}