diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 143 |
1 file changed, 102 insertions, 41 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 5b9fe1c..7fa1b433 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1703,9 +1703,44 @@ generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) { return NItem; } +/// Detect concat of multiple values into a vector +static bool isFreeConcat(ArrayRef<InstLane> Item, + const TargetTransformInfo &TTI) { + auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType()); + unsigned NumElts = Ty->getNumElements(); + if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0) + return false; + + // Check that the concat is free, usually meaning that the type will be split + // during legalization. + SmallVector<int, 16> ConcatMask(NumElts * 2); + std::iota(ConcatMask.begin(), ConcatMask.end(), 0); + if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, + TTI::TCK_RecipThroughput) != 0) + return false; + + unsigned NumSlices = Item.size() / NumElts; + // Currently we generate a tree of shuffles for the concats, which limits us + // to a power2. 
+ if (!isPowerOf2_32(NumSlices)) + return false; + for (unsigned Slice = 0; Slice < NumSlices; ++Slice) { + Use *SliceV = Item[Slice * NumElts].first; + if (!SliceV || SliceV->get()->getType() != Ty) + return false; + for (unsigned Elt = 0; Elt < NumElts; ++Elt) { + auto [V, Lane] = Item[Slice * NumElts + Elt]; + if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get()) + return false; + } + } + return true; +} + static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, const SmallPtrSet<Use *, 4> &IdentityLeafs, const SmallPtrSet<Use *, 4> &SplatLeafs, + const SmallPtrSet<Use *, 4> &ConcatLeafs, IRBuilder<> &Builder) { auto [FrontU, FrontLane] = Item.front(); @@ -1713,13 +1748,28 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, return FrontU->get(); } if (SplatLeafs.contains(FrontU)) { - if (auto *ILI = dyn_cast<Instruction>(FrontU)) - Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef()); - else if (auto *Arg = dyn_cast<Argument>(FrontU)) - Builder.SetInsertPointPastAllocas(Arg->getParent()); SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane); return Builder.CreateShuffleVector(FrontU->get(), Mask); } + if (ConcatLeafs.contains(FrontU)) { + unsigned NumElts = + cast<FixedVectorType>(FrontU->get()->getType())->getNumElements(); + SmallVector<Value *> Values(Item.size() / NumElts, nullptr); + for (unsigned S = 0; S < Values.size(); ++S) + Values[S] = Item[S * NumElts].first->get(); + + while (Values.size() > 1) { + NumElts *= 2; + SmallVector<int, 16> Mask(NumElts, 0); + std::iota(Mask.begin(), Mask.end(), 0); + SmallVector<Value *> NewValues(Values.size() / 2, nullptr); + for (unsigned S = 0; S < NewValues.size(); ++S) + NewValues[S] = + Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask); + Values = NewValues; + } + return Values[0]; + } auto *I = cast<Instruction>(FrontU->get()); auto *II = dyn_cast<IntrinsicInst>(I); @@ -1730,8 +1780,9 @@ static Value 
*generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, Ops[Idx] = II->getOperand(Idx); continue; } - Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), - Ty, IdentityLeafs, SplatLeafs, Builder); + Ops[Idx] = + generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty, + IdentityLeafs, SplatLeafs, ConcatLeafs, Builder); } SmallVector<Value *, 8> ValueList; @@ -1739,7 +1790,6 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, if (Lane.first) ValueList.push_back(Lane.first->get()); - Builder.SetInsertPoint(I); Type *DstTy = FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements()); if (auto *BI = dyn_cast<BinaryOperator>(I)) { @@ -1790,7 +1840,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { SmallVector<SmallVector<InstLane>> Worklist; Worklist.push_back(Start); - SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs; + SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs; unsigned NumVisited = 0; while (!Worklist.empty()) { @@ -1839,7 +1889,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { // We need each element to be the same type of value, and check that each // element has a single use. - if (!all_of(drop_begin(Item), [Item](InstLane IL) { + if (all_of(drop_begin(Item), [Item](InstLane IL) { Value *FrontV = Item.front().first->get(); if (!IL.first) return true; @@ -1860,40 +1910,49 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { return !II || (isa<IntrinsicInst>(FrontV) && II->getIntrinsicID() == cast<IntrinsicInst>(FrontV)->getIntrinsicID()); - })) - return false; - - // Check the operator is one that we support. We exclude div/rem in case - // they hit UB from poison lanes. 
- if ((isa<BinaryOperator>(FrontU) && - !cast<BinaryOperator>(FrontU)->isIntDivRem()) || - isa<CmpInst>(FrontU)) { - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); - } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) { - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); - } else if (isa<SelectInst>(FrontU)) { - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); - } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU); - II && isTriviallyVectorizable(II->getIntrinsicID())) { - for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { - if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) { - if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { - Value *FrontV = Item.front().first->get(); - Use *U = IL.first; - return !U || (cast<Instruction>(U->get())->getOperand(Op) == - cast<Instruction>(FrontV)->getOperand(Op)); - })) - return false; - continue; + })) { + // Check the operator is one that we support. + if (isa<BinaryOperator, CmpInst>(FrontU)) { + // We exclude div/rem in case they hit UB from poison lanes. 
+ if (auto *BO = dyn_cast<BinaryOperator>(FrontU); + BO && BO->isIntDivRem()) + return false; + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); + continue; + } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) { + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); + continue; + } else if (isa<SelectInst>(FrontU)) { + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); + continue; + } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU); + II && isTriviallyVectorizable(II->getIntrinsicID())) { + for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { + if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) { + if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { + Value *FrontV = Item.front().first->get(); + Use *U = IL.first; + return !U || (cast<Instruction>(U->get())->getOperand(Op) == + cast<Instruction>(FrontV)->getOperand(Op)); + })) + return false; + continue; + } + Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); } - Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); + continue; } - } else { - return false; } + + if (isFreeConcat(Item, TTI)) { + ConcatLeafs.insert(FrontU); + continue; + } + + return false; } if (NumVisited <= 1) @@ -1901,7 +1960,9 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { // If we got this far, we know the shuffles are superfluous and can be // removed. Scan through again and generate the new tree of instructions. - Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder); + Builder.SetInsertPoint(&I); + Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, + ConcatLeafs, Builder); replaceValue(I, *V); return true; } |