diff options
| author | Alexey Bataev <a.bataev@outlook.com> | 2026-02-10 08:59:44 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-10 08:59:44 -0500 |
| commit | 70aebae2a13114f4e3d5e2460c052d8f3de295be (patch) | |
| tree | 7e40120d8552cd782a37aa65561adf283f424167 /llvm/lib/Transforms/Vectorize | |
| parent | f8d5a003faa7567c2bac0b064ea6616f2e892467 (diff) | |
| download | llvm-70aebae2a13114f4e3d5e2460c052d8f3de295be.tar.gz llvm-70aebae2a13114f4e3d5e2460c052d8f3de295be.tar.bz2 llvm-70aebae2a13114f4e3d5e2460c052d8f3de295be.zip | |
[SLP]Support for zext i1 %x modeling as select %x, 1, 0
Model zext i1 %x to in as select i1 %x, in 1, in 0 in case, if there are
other select instructions, which can be combined into a bundle.
Fixes #178403
Reviewers: hiraditya, RKSimon
Pull Request: https://github.com/llvm/llvm-project/pull/180635
Diffstat (limited to 'llvm/lib/Transforms/Vectorize')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 89 |
1 files changed, 82 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index f89c22fafcf0..f18b6fc4dd95 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -5761,12 +5761,15 @@ private: // moment are extracts where their second (immediate) operand is // not added. Since immediates do not affect scheduler behavior // this is considered okay. - assert(In && - (isa<ExtractValueInst, ExtractElementInst, CallBase>(In) || - In->getNumOperands() == - Bundle->getTreeEntry()->getNumOperands() || - Bundle->getTreeEntry()->isCopyableElement(In)) && - "Missed TreeEntry operands?"); + assert( + In && + (isa<ExtractValueInst, ExtractElementInst, CallBase>(In) || + In->getNumOperands() == + Bundle->getTreeEntry()->getNumOperands() || + (isa<ZExtInst>(In) && Bundle->getTreeEntry()->getOpcode() == + Instruction::Select) || + Bundle->getTreeEntry()->isCopyableElement(In)) && + "Missed TreeEntry operands?"); // Count the number of unique phi nodes, which are the parent for // parent entry, and exit, if all the unique phis are processed. @@ -11345,7 +11348,6 @@ class InstructionsCompatibilityAnalysis { case Instruction::BitCast: case Instruction::ICmp: case Instruction::FCmp: - case Instruction::Select: case Instruction::FNeg: case Instruction::Add: case Instruction::FAdd: @@ -11381,6 +11383,30 @@ class InstructionsCompatibilityAnalysis { Ops[Idx] = ConvertedOps[OpIdx]; } return; + case Instruction::Select: + Operands.assign(VL0->getNumOperands(), {VL.size(), nullptr}); + for (auto [Idx, V] : enumerate(VL)) { + auto *I = dyn_cast<Instruction>(V); + if (!I) { + for (auto [OpIdx, Ops] : enumerate(Operands)) + Ops[Idx] = PoisonValue::get(VL0->getOperand(OpIdx)->getType()); + continue; + } + if (isa<ZExtInst>(I)) { + // Special case for select + zext i1 to avoid explosion of different + // types. We want to keep the condition as i1 to be able to match + // different selects together and reuse the vectorized condition + // rather than trying to gather it. + Operands[0][Idx] = I->getOperand(0); + Operands[1][Idx] = ConstantInt::get(I->getType(), 1); + Operands[2][Idx] = ConstantInt::getNullValue(I->getType()); + continue; + } + auto [Op, ConvertedOps] = convertTo(I, S); + for (auto [OpIdx, Ops] : enumerate(Operands)) + Ops[Idx] = ConvertedOps[OpIdx]; + } + return; case Instruction::GetElementPtr: { Operands.assign(2, {VL.size(), nullptr}); // Need to cast all indices to the same type before vectorization to @@ -11453,6 +11479,22 @@ public: : getSameOpcode(VL, TLI); if (S) return S; + // Check if series of selects + zext i1 %x to in can be combined into + // selects + select %x, i32 1, i32 0. + Instruction *SelectOp = nullptr; + if (allSameBlock(VL) && all_of(VL, [&](Value *V) { + if (match(V, m_Select(m_Value(), m_Value(), m_Value()))) { + if (!SelectOp) + SelectOp = cast<Instruction>(V); + return true; + } + auto *ZExt = dyn_cast<ZExtInst>(V); + return (ZExt && ZExt->getSrcTy()->isIntegerTy(1)) || + isa<PoisonValue>(V); + })) { + if (SelectOp) + return InstructionsState(SelectOp, SelectOp); + } if (!VectorizeCopyableElements || !TryCopyableElementsVectorization) return S; findAndSetMainInstruction(VL, R); @@ -15481,6 +15523,10 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, if (isa<PoisonValue>(UniqueValues[Idx])) return InstructionCost(TTI::TCC_Free); + if (!isa<SelectInst>(UniqueValues[Idx])) + return TTI->getInstructionCost(cast<Instruction>(UniqueValues[Idx]), + CostKind); + auto *VI = cast<Instruction>(UniqueValues[Idx]); CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy() ? CmpInst::BAD_FCMP_PREDICATE @@ -25243,6 +25289,32 @@ private: (I && !isa<LoadInst>(I) && isValidForAlternation(I->getOpcode())); } + /// Optimizes original placement of the reduced values for the reduction tree. + /// For example, if there is a zext i1 + selects, we can merge select + /// into zext and improve emission of the reductions. + void optimizeReducedVals() { + SmallDenseMap<unsigned, unsigned> UsedReductionOpIds; + for (const auto [Idx, Vals] : enumerate(ReducedVals)) { + if (auto *I = dyn_cast<Instruction>(Vals.front())) + UsedReductionOpIds.try_emplace(I->getOpcode(), Idx); + } + // Check if zext i1 can be merged with select. + auto ZExtIt = UsedReductionOpIds.find(Instruction::ZExt); + auto SelectIt = UsedReductionOpIds.find(Instruction::Select); + if (ZExtIt != UsedReductionOpIds.end() && + SelectIt != UsedReductionOpIds.end()) { + unsigned ZExtIdx = ZExtIt->second; + unsigned SelectIdx = SelectIt->second; + auto *ZExt = cast<ZExtInst>(ReducedVals[ZExtIdx].front()); + // ZExt is compatible with Select? Merge select to zext, if so. + if (ZExt->getSrcTy()->isIntegerTy(1) && + ZExt->getType() == ReducedVals[SelectIdx].front()->getType()) { + ReducedVals[ZExtIdx].append(ReducedVals[SelectIdx]); + ReducedVals.erase(std::next(ReducedVals.begin(), SelectIdx)); + } + } + } + public: HorizontalReduction() = default; HorizontalReduction(Instruction *I, ArrayRef<Value *> Ops) @@ -25418,6 +25490,9 @@ public: ReducedVals.back().append(Data.rbegin(), Data.rend()); } } + // Post optimize reduced values to get better reduction sequences and sort + // them by size. + optimizeReducedVals(); // Sort the reduced values by number of same/alternate opcode and/or pointer // operand. stable_sort(ReducedVals, [](ArrayRef<Value *> P1, ArrayRef<Value *> P2) { |
