diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 383 |
1 files changed, 231 insertions, 152 deletions
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 6e7dcb9..bdd26ac 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -306,10 +306,7 @@ static bool isCommutative(Instruction *I) { return Cmp->isCommutative(); if (auto *BO = dyn_cast<BinaryOperator>(I)) return BO->isCommutative(); - // TODO: This should check for generic Instruction::isCommutative(), but - // we need to confirm that the caller code correctly handles Intrinsics - // for example (does not have 2 operands). - return false; + return I->isCommutative(); } /// \returns inserting index of InsertElement or InsertValue instruction, @@ -658,6 +655,29 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, unsigned AltOpcode = Opcode; unsigned AltIndex = BaseIndex; + bool SwappedPredsCompatible = [&]() { + if (!IsCmpOp) + return false; + SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds; + UniquePreds.insert(BasePred); + UniqueNonSwappedPreds.insert(BasePred); + for (Value *V : VL) { + auto *I = dyn_cast<CmpInst>(V); + if (!I) + return false; + CmpInst::Predicate CurrentPred = I->getPredicate(); + CmpInst::Predicate SwappedCurrentPred = + CmpInst::getSwappedPredicate(CurrentPred); + UniqueNonSwappedPreds.insert(CurrentPred); + if (!UniquePreds.contains(CurrentPred) && + !UniquePreds.contains(SwappedCurrentPred)) + UniquePreds.insert(CurrentPred); + } + // Total number of predicates > 2, but if consider swapped predicates + // compatible only 2, consider swappable predicates as compatible opcodes, + // not alternate. + return UniqueNonSwappedPreds.size() > 2 && UniquePreds.size() == 2; + }(); // Check for one alternate opcode from another BinaryOperator. // TODO - generalize to support all operators (types, calls etc.). auto *IBase = cast<Instruction>(VL[BaseIndex]); @@ -710,7 +730,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, CmpInst::Predicate SwappedCurrentPred = CmpInst::getSwappedPredicate(CurrentPred); - if (E == 2 && + if ((E == 2 || SwappedPredsCompatible) && (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) continue; @@ -1087,7 +1107,7 @@ public: MinBWs.clear(); ReductionBitWidth = 0; CastMaxMinBWSizes.reset(); - TruncNodes.clear(); + ExtraBitWidthNodes.clear(); InstrElementSize.clear(); UserIgnoreList = nullptr; PostponedGathers.clear(); @@ -1952,6 +1972,9 @@ public: "Expected same number of lanes"); assert(isa<Instruction>(VL[0]) && "Expected instruction"); unsigned NumOperands = cast<Instruction>(VL[0])->getNumOperands(); + constexpr unsigned IntrinsicNumOperands = 2; + if (isa<IntrinsicInst>(VL[0])) + NumOperands = IntrinsicNumOperands; OpsVec.resize(NumOperands); unsigned NumLanes = VL.size(); for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { @@ -3397,10 +3420,11 @@ private: // immediates do not affect scheduler behavior this is considered // okay. auto *In = BundleMember->Inst; - assert(In && - (isa<ExtractValueInst, ExtractElementInst>(In) || - In->getNumOperands() == TE->getNumOperands()) && - "Missed TreeEntry operands?"); + assert( + In && + (isa<ExtractValueInst, ExtractElementInst, IntrinsicInst>(In) || + In->getNumOperands() == TE->getNumOperands()) && + "Missed TreeEntry operands?"); (void)In; // fake use to avoid build failure when assertions disabled for (unsigned OpIdx = 0, NumOperands = TE->getNumOperands(); @@ -3659,8 +3683,9 @@ private: /// type sizes, used in the tree. std::optional<std::pair<unsigned, unsigned>> CastMaxMinBWSizes; - /// Indices of the vectorized trunc nodes. - DenseSet<unsigned> TruncNodes; + /// Indices of the vectorized nodes, which supposed to be the roots of the new + /// bitwidth analysis attempt, like trunc, IToFP or ICmp. + DenseSet<unsigned> ExtraBitWidthNodes; }; } // end namespace slpvectorizer @@ -6588,7 +6613,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, PrevMaxBW), std::min<unsigned>(DL->getTypeSizeInBits(VL0->getType()), PrevMinBW)); - TruncNodes.insert(VectorizableTree.size()); + ExtraBitWidthNodes.insert(VectorizableTree.size() + 1); + } else if (ShuffleOrOp == Instruction::SIToFP || + ShuffleOrOp == Instruction::UIToFP) { + unsigned NumSignBits = + ComputeNumSignBits(VL0->getOperand(0), *DL, 0, AC, nullptr, DT); + if (auto *OpI = dyn_cast<Instruction>(VL0->getOperand(0))) { + APInt Mask = DB->getDemandedBits(OpI); + NumSignBits = std::max(NumSignBits, Mask.countl_zero()); + } + if (NumSignBits * 2 >= + DL->getTypeSizeInBits(VL0->getOperand(0)->getType())) + ExtraBitWidthNodes.insert(VectorizableTree.size() + 1); } TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); @@ -6636,6 +6672,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); buildTree_rec(Right, Depth + 1, {TE, 1}); + if (ShuffleOrOp == Instruction::ICmp) { + unsigned NumSignBits0 = + ComputeNumSignBits(VL0->getOperand(0), *DL, 0, AC, nullptr, DT); + if (NumSignBits0 * 2 >= + DL->getTypeSizeInBits(VL0->getOperand(0)->getType())) + ExtraBitWidthNodes.insert(getOperandEntry(TE, 0)->Idx); + unsigned NumSignBits1 = + ComputeNumSignBits(VL0->getOperand(1), *DL, 0, AC, nullptr, DT); + if (NumSignBits1 * 2 >= + DL->getTypeSizeInBits(VL0->getOperand(1)->getType())) + ExtraBitWidthNodes.insert(getOperandEntry(TE, 1)->Idx); + } return; } case Instruction::Select: @@ -6775,6 +6823,33 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); + // Sort operands of the instructions so that each side is more likely to + // have the same opcode. + if (isCommutative(VL0)) { + ValueList Left, Right; + reorderInputsAccordingToOpcode(VL, Left, Right, *this); + TE->setOperand(0, Left); + TE->setOperand(1, Right); + SmallVector<ValueList> Operands; + for (unsigned I : seq<unsigned>(2, CI->arg_size())) { + Operands.emplace_back(); + if (isVectorIntrinsicWithScalarOpAtArg(ID, I)) + continue; + for (Value *V : VL) { + auto *CI2 = cast<CallInst>(V); + Operands.back().push_back(CI2->getArgOperand(I)); + } + TE->setOperand(I, Operands.back()); + } + buildTree_rec(Left, Depth + 1, {TE, 0}); + buildTree_rec(Right, Depth + 1, {TE, 1}); + for (unsigned I : seq<unsigned>(2, CI->arg_size())) { + if (Operands[I - 2].empty()) + continue; + buildTree_rec(Operands[I - 2], Depth + 1, {TE, I}); + } + return; + } TE->setOperandsInOrder(); for (unsigned I : seq<unsigned>(0, CI->arg_size())) { // For scalar operands no need to create an entry since no need to @@ -8447,7 +8522,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, else if (auto *IE = dyn_cast<InsertElementInst>(VL[0])) ScalarTy = IE->getOperand(1)->getType(); } - if (!FixedVectorType::isValidElementType(ScalarTy)) + if (!isValidElementType(ScalarTy)) return InstructionCost::getInvalid(); auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; @@ -9063,25 +9138,35 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind, E->getAltOp()); } else { - Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType(); - Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType(); - auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size()); - auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size()); - if (It != MinBWs.end()) { - if (!MinBWs.contains(getOperandEntry(E, 0))) - VecCost = - TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, Src0Ty, - TTI::CastContextHint::None, CostKind); - LLVM_DEBUG({ - dbgs() << "SLP: alternate extension, which should be truncated.\n"; - E->dump(); - }); - return VecCost; + Type *SrcSclTy = E->getMainOp()->getOperand(0)->getType(); + auto *SrcTy = FixedVectorType::get(SrcSclTy, VL.size()); + if (SrcSclTy->isIntegerTy() && ScalarTy->isIntegerTy()) { + auto SrcIt = MinBWs.find(getOperandEntry(E, 0)); + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + unsigned SrcBWSz = + DL->getTypeSizeInBits(E->getMainOp()->getOperand(0)->getType()); + if (SrcIt != MinBWs.end()) { + SrcBWSz = SrcIt->second.first; + SrcSclTy = IntegerType::get(SrcSclTy->getContext(), SrcBWSz); + SrcTy = FixedVectorType::get(SrcSclTy, VL.size()); + } + if (BWSz <= SrcBWSz) { + if (BWSz < SrcBWSz) + VecCost = + TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, SrcTy, + TTI::CastContextHint::None, CostKind); + LLVM_DEBUG({ + dbgs() + << "SLP: alternate extension, which should be truncated.\n"; + E->dump(); + }); + return VecCost; + } } - VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, Src0Ty, + VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, SrcTy, TTI::CastContextHint::None, CostKind); VecCost += - TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty, + TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, SrcTy, TTI::CastContextHint::None, CostKind); } SmallVector<int> Mask; @@ -12591,15 +12676,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { CmpInst::Predicate AltPred = AltCI->getPredicate(); V1 = Builder.CreateCmp(AltPred, LHS, RHS); } else { - if (It != MinBWs.end()) { - if (!MinBWs.contains(getOperandEntry(E, 0))) - LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first); - assert(LHS->getType() == VecTy && "Expected same type as operand."); - if (auto *I = dyn_cast<Instruction>(LHS)) - LHS = propagateMetadata(I, E->Scalars); - E->VectorizedValue = LHS; - ++NumVectorInstructions; - return LHS; + if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) { + unsigned SrcBWSz = DL->getTypeSizeInBits( + cast<VectorType>(LHS->getType())->getElementType()); + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + if (BWSz <= SrcBWSz) { + if (BWSz < SrcBWSz) + LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first); + assert(LHS->getType() == VecTy && "Expected same type as operand."); + if (auto *I = dyn_cast<Instruction>(LHS)) + LHS = propagateMetadata(I, E->Scalars); + E->VectorizedValue = LHS; + ++NumVectorInstructions; + return LHS; + } } V0 = Builder.CreateCast( static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy); @@ -14051,6 +14141,16 @@ bool BoUpSLP::collectValuesToDemote( })) return FinalAnalysis(); + if (!all_of(I->users(), + [=](User *U) { + return getTreeEntry(U) || + (UserIgnoreList && UserIgnoreList->contains(U)) || + (U->getType()->isSized() && + DL->getTypeSizeInBits(U->getType()) <= BitWidth); + }) && + !IsPotentiallyTruncated(I, BitWidth)) + return false; + unsigned Start = 0; unsigned End = I->getNumOperands(); @@ -14097,25 +14197,52 @@ bool BoUpSLP::collectValuesToDemote( } return false; }; - bool NeedToExit = false; + auto TryProcessInstruction = + [&](Instruction *I, const TreeEntry &ITE, unsigned &BitWidth, + ArrayRef<Value *> Operands = std::nullopt, + function_ref<bool(unsigned, unsigned)> Checker = {}) { + if (Operands.empty()) { + if (!IsTruncRoot) + MaxDepthLevel = 1; + (void)IsPotentiallyTruncated(V, BitWidth); + } else { + // Several vectorized uses? Check if we can truncate it, otherwise - + // exit. + if (ITE.UserTreeIndices.size() > 1 && + !IsPotentiallyTruncated(I, BitWidth)) + return false; + bool NeedToExit = false; + if (Checker && !AttemptCheckBitwidth(Checker, NeedToExit)) + return false; + if (NeedToExit) + return true; + if (!ProcessOperands(Operands, NeedToExit)) + return false; + if (NeedToExit) + return true; + } + + ++MaxDepthLevel; + // Gather demoted constant operands. + for (unsigned Idx : seq<unsigned>(Start, End)) + if (isa<Constant>(I->getOperand(Idx))) + DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx); + // Record the value that we can demote. + ToDemote.push_back(V); + return IsProfitableToDemote; + }; switch (I->getOpcode()) { // We can always demote truncations and extensions. Since truncations can // seed additional demotion, we save the truncated value. case Instruction::Trunc: - if (!IsTruncRoot) - MaxDepthLevel = 1; if (IsProfitableToDemoteRoot) IsProfitableToDemote = true; - (void)IsPotentiallyTruncated(V, BitWidth); - break; + return TryProcessInstruction(I, *ITE, BitWidth); case Instruction::ZExt: case Instruction::SExt: - if (!IsTruncRoot) - MaxDepthLevel = 1; IsProfitableToDemote = true; - (void)IsPotentiallyTruncated(V, BitWidth); - break; + return TryProcessInstruction(I, *ITE, BitWidth); // We can demote certain binary operations if we can demote both of their // operands. @@ -14125,140 +14252,83 @@ bool BoUpSLP::collectValuesToDemote( case Instruction::And: case Instruction::Or: case Instruction::Xor: { - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; - if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit)) - return false; - break; + return TryProcessInstruction(I, *ITE, BitWidth, + {I->getOperand(0), I->getOperand(1)}); } case Instruction::Shl: { - // Several vectorized uses? Check if we can truncate it, otherwise - exit. - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; // If we are truncating the result of this SHL, and if it's a shift of an // inrange amount, we can always perform a SHL in a smaller type. - if (!AttemptCheckBitwidth( - [&](unsigned BitWidth, unsigned) { - KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); - return AmtKnownBits.getMaxValue().ult(BitWidth); - }, - NeedToExit)) - return false; - if (NeedToExit) - return true; - if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit)) - return false; - break; + auto ShlChecker = [&](unsigned BitWidth, unsigned) { + KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); + return AmtKnownBits.getMaxValue().ult(BitWidth); + }; + return TryProcessInstruction( + I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, ShlChecker); } case Instruction::LShr: { - // Several vectorized uses? Check if we can truncate it, otherwise - exit. - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; // If this is a truncate of a logical shr, we can truncate it to a smaller // lshr iff we know that the bits we would otherwise be shifting in are // already zeros. - if (!AttemptCheckBitwidth( - [&](unsigned BitWidth, unsigned OrigBitWidth) { - KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); - APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); - return AmtKnownBits.getMaxValue().ult(BitWidth) && - MaskedValueIsZero(I->getOperand(0), ShiftedBits, - SimplifyQuery(*DL)); - }, - NeedToExit)) - return false; - if (NeedToExit) - return true; - if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit)) - return false; - break; + auto LShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); + APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + return AmtKnownBits.getMaxValue().ult(BitWidth) && + MaskedValueIsZero(I->getOperand(0), ShiftedBits, + SimplifyQuery(*DL)); + }; + return TryProcessInstruction( + I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, LShrChecker); } case Instruction::AShr: { - // Several vectorized uses? Check if we can truncate it, otherwise - exit. - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; // If this is a truncate of an arithmetic shr, we can truncate it to a // smaller ashr iff we know that all the bits from the sign bit of the // original type and the sign bit of the truncate type are similar. - if (!AttemptCheckBitwidth( - [&](unsigned BitWidth, unsigned OrigBitWidth) { - KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); - unsigned ShiftedBits = OrigBitWidth - BitWidth; - return AmtKnownBits.getMaxValue().ult(BitWidth) && - ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0, - AC, nullptr, DT); - }, - NeedToExit)) - return false; - if (NeedToExit) - return true; - if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit)) - return false; - break; + auto AShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); + unsigned ShiftedBits = OrigBitWidth - BitWidth; + return AmtKnownBits.getMaxValue().ult(BitWidth) && + ShiftedBits < + ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT); + }; + return TryProcessInstruction( + I, *ITE, BitWidth, {I->getOperand(0), I->getOperand(1)}, AShrChecker); } case Instruction::UDiv: case Instruction::URem: { - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; // UDiv and URem can be truncated if all the truncated bits are zero. - if (!AttemptCheckBitwidth( - [&](unsigned BitWidth, unsigned OrigBitWidth) { - assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); - APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); - return MaskedValueIsZero(I->getOperand(0), Mask, - SimplifyQuery(*DL)) && - MaskedValueIsZero(I->getOperand(1), Mask, - SimplifyQuery(*DL)); - }, - NeedToExit)) - return false; - if (NeedToExit) - return true; - if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit)) - return false; - break; + auto Checker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) && + MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL)); + }; + return TryProcessInstruction(I, *ITE, BitWidth, + {I->getOperand(0), I->getOperand(1)}, Checker); } // We can demote selects if we can demote their true and false values. case Instruction::Select: { - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; Start = 1; auto *SI = cast<SelectInst>(I); - if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit)) - return false; - break; + return TryProcessInstruction(I, *ITE, BitWidth, + {SI->getTrueValue(), SI->getFalseValue()}); } // We can demote phis if we can demote all their incoming operands. Note that // we don't need to worry about cycles since we ensure single use above. case Instruction::PHI: { PHINode *PN = cast<PHINode>(I); - if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) - return false; SmallVector<Value *> Ops(PN->incoming_values().begin(), PN->incoming_values().end()); - if (!ProcessOperands(Ops, NeedToExit)) - return false; - break; + return TryProcessInstruction(I, *ITE, BitWidth, Ops); } // Otherwise, conservatively give up. default: - MaxDepthLevel = 1; - return FinalAnalysis(); + break; } - if (NeedToExit) - return true; - - ++MaxDepthLevel; - // Gather demoted constant operands. - for (unsigned Idx : seq<unsigned>(Start, End)) - if (isa<Constant>(I->getOperand(Idx))) - DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx); - // Record the value that we can demote. - ToDemote.push_back(V); - return IsProfitableToDemote; + MaxDepthLevel = 1; + return FinalAnalysis(); } void BoUpSLP::computeMinimumValueSizes() { @@ -14266,7 +14336,8 @@ void BoUpSLP::computeMinimumValueSizes() { bool IsStoreOrInsertElt = VectorizableTree.front()->getOpcode() == Instruction::Store || VectorizableTree.front()->getOpcode() == Instruction::InsertElement; - if ((IsStoreOrInsertElt || UserIgnoreList) && TruncNodes.size() <= 1 && + if ((IsStoreOrInsertElt || UserIgnoreList) && + ExtraBitWidthNodes.size() <= 1 && (!CastMaxMinBWSizes || CastMaxMinBWSizes->second == 0 || CastMaxMinBWSizes->first / CastMaxMinBWSizes->second <= 2)) return; @@ -14309,7 +14380,8 @@ void BoUpSLP::computeMinimumValueSizes() { DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts; auto ComputeMaxBitWidth = [&](ArrayRef<Value *> TreeRoot, unsigned VF, bool IsTopRoot, bool IsProfitableToDemoteRoot, - unsigned Opcode, unsigned Limit, bool IsTruncRoot) { + unsigned Opcode, unsigned Limit, + bool IsTruncRoot) { ToDemote.clear(); auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType()); if (!TreeRootIT || !Opcode) @@ -14469,16 +14541,23 @@ void BoUpSLP::computeMinimumValueSizes() { IsTopRoot = false; IsProfitableToDemoteRoot = true; - if (TruncNodes.empty()) { + if (ExtraBitWidthNodes.empty()) { NodeIdx = VectorizableTree.size(); } else { unsigned NewIdx = 0; do { - NewIdx = *TruncNodes.begin() + 1; - TruncNodes.erase(TruncNodes.begin()); - } while (NewIdx <= NodeIdx && !TruncNodes.empty()); + NewIdx = *ExtraBitWidthNodes.begin(); + ExtraBitWidthNodes.erase(ExtraBitWidthNodes.begin()); + } while (NewIdx <= NodeIdx && !ExtraBitWidthNodes.empty()); NodeIdx = NewIdx; - IsTruncRoot = true; + IsTruncRoot = + NodeIdx < VectorizableTree.size() && + any_of(VectorizableTree[NodeIdx]->UserTreeIndices, + [](const EdgeInfo &EI) { + return EI.EdgeIdx == 0 && + EI.UserTE->getOpcode() == Instruction::Trunc && + !EI.UserTE->isAltShuffle(); + }); } // If the maximum bit width we compute is less than the with of the roots' |