diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 906fa2f..b7224a3 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7933,6 +7933,26 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) { (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))) ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second); } + + // Check that all partial reductions in a chain are only used by other + // partial reductions with the same scale factor. Otherwise we end up creating + // users of scaled reductions where the types of the other operands don't + // match. + for (const auto &[Chain, Scale] : PartialReductionChains) { + auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) { + auto *UI = cast<Instruction>(U); + if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) { + return all_of(UI->users(), [ScaleVal, this](const User *U) { + auto *UI = cast<Instruction>(U); + return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal; + }); + } + return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal || + !OrigLoop->contains(UI->getParent()); + }; + if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) + ScaledReductionMap.erase(Chain.Reduction); + } } bool VPRecipeBuilder::getScaledReductions( @@ -8116,11 +8136,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R, if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) return tryToWidenMemory(Instr, Operands, Range); - if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr)) { - if (auto PartialRed = - tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value())) - return PartialRed; - } + if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr)) + return tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value()); if (!shouldWiden(Instr, Range)) return nullptr; @@ -8154,9 +8171,9 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction, isa<VPPartialReductionRecipe>(BinOpRecipe)) std::swap(BinOp, Accumulator); - if (ScaleFactor != - vputils::getVFScaleFactor(Accumulator->getDefiningRecipe())) - return nullptr; + assert(ScaleFactor == + vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) && + "all accumulators in chain must have same scale factor"); unsigned ReductionOpcode = Reduction->getOpcode(); if (ReductionOpcode == Instruction::Sub) { |
