diff options
| author | Florian Hahn <flo@fhahn.com> | 2026-02-10 21:16:21 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-10 21:16:21 +0000 |
| commit | a1fc5b4a48db440e37acd06972a65dd67a3b2b7c (patch) | |
| tree | 56ef37ecd8d2cf2f4650a4221108364db96d90ec /llvm/lib/Transforms/Vectorize | |
| parent | 7f9965c73de2a3290f78f12f9d26360407dd2fd6 (diff) | |
| download | llvm-a1fc5b4a48db440e37acd06972a65dd67a3b2b7c.tar.gz llvm-a1fc5b4a48db440e37acd06972a65dd67a3b2b7c.tar.bz2 llvm-a1fc5b4a48db440e37acd06972a65dd67a3b2b7c.zip | |
[VPlan] Reject partial reductions with invalid costs in getScaledReds. (#180438)
Check if costs for partial reductions are valid up-front in
getScaledReductions instead of when transforming each link in the chain in
transformToPartialReduction. This ensures that we either transform all
entries in the chain together, or none, via the existing invalidation
logic.
This fixes a crash when a link in the chain would have an invalid cost, as
in the added test cases.
Fixes https://github.com/llvm/llvm-project/issues/180340.
PR: https://github.com/llvm/llvm-project/pull/180438
Diffstat (limited to 'llvm/lib/Transforms/Vectorize')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 168 |
1 files changed, 82 insertions, 86 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index a99641c472b9..19be84e9db09 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -5547,15 +5547,12 @@ struct VPPartialReductionChain { }; // Helper to transform a partial reduction chain into a partial reduction -// recipe. Returns true if transformation succeeded. Checks profitability and -// clamps VF range. -static bool transformToPartialReduction(const VPPartialReductionChain &Chain, - VFRange &Range, VPCostContext &CostCtx, - VPlan &Plan, +// recipe. Assumes profitability has been checked. +static void transformToPartialReduction(const VPPartialReductionChain &Chain, + VPTypeAnalysis &TypeInfo, VPlan &Plan, VPReductionPHIRecipe *RdxPhi, RecurKind RK) { VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp; - unsigned ScaleFactor = Chain.ScaleFactor; assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation"); VPValue *BinOp = WidenRecipe->getOperand(0); @@ -5565,67 +5562,6 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, if (isa_and_present<VPReductionPHIRecipe, VPReductionRecipe>(BinOp)) std::swap(BinOp, Accumulator); - // For chained reductions, only transform if accumulator is already a PHI or - // partial reduction. Otherwise, it needs to be transformed first. - auto *AccumRecipe = Accumulator->getDefiningRecipe(); - if (!isa_and_present<VPReductionPHIRecipe, VPReductionRecipe>(AccumRecipe)) - return false; - - // Check if the partial reduction is profitable for the VF range. - Type *PhiType = CostCtx.Types.inferScalarType(Accumulator); - - // Derive extend info from the stored extends. 
- auto GetExtInfo = [&CostCtx](VPWidenCastRecipe *Ext) - -> std::pair<Type *, TargetTransformInfo::PartialReductionExtendKind> { - if (!Ext) - return {nullptr, TargetTransformInfo::PR_None}; - Type *ExtOpType = CostCtx.Types.inferScalarType(Ext->getOperand(0)); - auto ExtKind = TargetTransformInfo::getPartialReductionExtendKind( - static_cast<Instruction::CastOps>(Ext->getOpcode())); - return {ExtOpType, ExtKind}; - }; - auto ExtInfoA = GetExtInfo(Chain.ExtendA); - auto ExtInfoB = GetExtInfo(Chain.ExtendB); - Type *ExtOpTypeA = ExtInfoA.first; - Type *ExtOpTypeB = ExtInfoB.first; - auto ExtKindA = ExtInfoA.second; - auto ExtKindB = ExtInfoB.second; - // If ExtendB is nullptr but there's a separate BinOp, the second operand - // was a constant that can use the same extend kind as the first. - if (!Chain.ExtendB && Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp) { - // Validate that the constant can be extended to the narrow type. - const APInt *Const = nullptr; - for (VPValue *Op : Chain.BinOp->operands()) { - if (match(Op, m_APInt(Const))) - break; - } - if (!Const || !canConstantBeExtended(Const, ExtOpTypeA, ExtKindA)) - return false; - ExtOpTypeB = ExtOpTypeA; - ExtKindB = ExtKindA; - } - - // BinOpc is only set when there's a separate binary op (not when BinOp is - // the reduction itself). - std::optional<unsigned> BinOpc = - (Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp) - ? std::make_optional(Chain.BinOp->getOpcode()) - : std::nullopt; - - if (!LoopVectorizationPlanner::getDecisionAndClampRange( - [&](ElementCount VF) { - return CostCtx.TTI - .getPartialReductionCost( - WidenRecipe->getOpcode(), ExtOpTypeA, ExtOpTypeB, PhiType, - VF, ExtKindA, ExtKindB, BinOpc, CostCtx.CostKind, - PhiType->isFloatingPointTy() - ? std::optional{WidenRecipe->getFastMathFlags()} - : std::nullopt) - .isValid(); - }, - Range)) - return false; - // Sub-reductions can be implemented in two ways: // (1) negate the operand in the vector loop (the default way). 
// (2) subtract the reduced value from the init value in the middle block. @@ -5642,7 +5578,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, // subtract in the middle block. if (WidenRecipe->getOpcode() == Instruction::Sub && RK != RecurKind::Sub) { VPBuilder Builder(WidenRecipe); - Type *ElemTy = CostCtx.Types.inferScalarType(BinOp); + Type *ElemTy = TypeInfo.inferScalarType(BinOp); auto *Zero = Plan.getConstantInt(ElemTy, 0); VPIRFlags Flags = WidenRecipe->getUnderlyingInstr() ? VPIRFlags(*WidenRecipe->getUnderlyingInstr()) @@ -5664,6 +5600,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, assert((!ExitValue || IsLastInChain) && "if we found ExitValue, it must match RdxPhi's backedge value"); + Type *PhiType = TypeInfo.inferScalarType(RdxPhi); RecurKind RdxKind = PhiType->isFloatingPointTy() ? RecurKind::FAdd : RecurKind::Add; auto *PartialRed = new VPReductionRecipe( @@ -5671,7 +5608,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, RdxKind == RecurKind::FAdd ? WidenRecipe->getFastMathFlags() : FastMathFlags(), WidenRecipe->getUnderlyingInstr(), Accumulator, BinOp, Cond, - RdxUnordered{/*VFScaleFactor=*/ScaleFactor}); + RdxUnordered{/*VFScaleFactor=*/Chain.ScaleFactor}); PartialRed->insertBefore(WidenRecipe); if (Cond) @@ -5681,22 +5618,22 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, // We only need to update the PHI node once, which is when we find the // last reduction in the chain. 
if (!IsLastInChain) - return true; + return; // Scale the PHI and ReductionStartVector by the VFScaleFactor assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set"); - RdxPhi->setVFScaleFactor(ScaleFactor); + RdxPhi->setVFScaleFactor(Chain.ScaleFactor); auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue()); assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector); - auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor); + auto *NewScaleFactor = Plan.getConstantInt(32, Chain.ScaleFactor); StartInst->setOperand(2, NewScaleFactor); // If this is the last value in a sub-reduction chain, then update the PHI // node to start at `0` and update the reduction-result to subtract from // the PHI's start value. if (RK != RecurKind::Sub) - return true; + return; VPValue *OldStartValue = StartInst->getOperand(0); StartInst->setOperand(0, StartInst->getOperand(1)); @@ -5713,8 +5650,61 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, RdxResult->replaceUsesWithIf( NewResult, [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; }); +} - return true; +/// Check if a partial reduction chain is supported by the target (i.e. does +/// not have an invalid cost) for the given VF range. Clamps the range and +/// returns true if profitable for any VF. 
+static bool isValidPartialReduction(const VPPartialReductionChain &Chain, + Type *PhiType, VPCostContext &CostCtx, + VFRange &Range) { + auto GetExtInfo = [&CostCtx](VPWidenCastRecipe *Ext) + -> std::pair<Type *, TargetTransformInfo::PartialReductionExtendKind> { + if (!Ext) + return {nullptr, TargetTransformInfo::PR_None}; + Type *ExtOpType = CostCtx.Types.inferScalarType(Ext->getOperand(0)); + auto ExtKind = TargetTransformInfo::getPartialReductionExtendKind( + static_cast<Instruction::CastOps>(Ext->getOpcode())); + return {ExtOpType, ExtKind}; + }; + auto ExtInfoA = GetExtInfo(Chain.ExtendA); + auto ExtInfoB = GetExtInfo(Chain.ExtendB); + Type *ExtOpTypeA = ExtInfoA.first; + Type *ExtOpTypeB = ExtInfoB.first; + auto ExtKindA = ExtInfoA.second; + auto ExtKindB = ExtInfoB.second; + + // If ExtendB is nullptr but there's a separate BinOp, the second operand + // was a constant that can use the same extend kind as the first. + if (!Chain.ExtendB && Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp) { + const APInt *Const = nullptr; + for (VPValue *Op : Chain.BinOp->operands()) { + if (match(Op, m_APInt(Const))) + break; + } + if (!Const || !canConstantBeExtended(Const, ExtOpTypeA, ExtKindA)) + return false; + ExtOpTypeB = ExtOpTypeA; + ExtKindB = ExtKindA; + } + + std::optional<unsigned> BinOpc = + (Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp) + ? std::make_optional(Chain.BinOp->getOpcode()) + : std::nullopt; + VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp; + return LoopVectorizationPlanner::getDecisionAndClampRange( + [&](ElementCount VF) { + return CostCtx.TTI + .getPartialReductionCost( + WidenRecipe->getOpcode(), ExtOpTypeA, ExtOpTypeB, PhiType, VF, + ExtKindA, ExtKindB, BinOpc, CostCtx.CostKind, + PhiType->isFloatingPointTy() + ? 
std::optional{WidenRecipe->getFastMathFlags()} + : std::nullopt) + .isValid(); + }, + Range); } /// Examines reduction operations to see if the target can use a cheaper @@ -5724,7 +5714,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain, static bool getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue, SmallVectorImpl<VPPartialReductionChain> &Chains, - VPTypeAnalysis &TypeInfo) { + VPCostContext &CostCtx, VFRange &Range) { auto *UpdateR = dyn_cast<VPWidenRecipe>(PrevValue); if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode())) return false; @@ -5753,7 +5743,7 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue, // Try and get a scaled reduction from the first non-phi operand. // If one is found, we use the discovered reduction instruction in // place of the accumulator for costing. - if (getScaledReductions(RedPhiR, Op, Chains, TypeInfo)) { + if (getScaledReductions(RedPhiR, Op, Chains, CostCtx, Range)) { RedPhiR = Chains.rbegin()->ReductionBinOp; Op = UpdateR->getOperand(0); PhiOp = UpdateR->getOperand(1); @@ -5774,7 +5764,8 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue, assert(Operands.size() <= 2 && "expected at most 2 operands"); for (const auto &[I, OpVal] : enumerate(Operands)) { - // Allow constant as second operand - validation happens in transform. + // Allow constant as second operand - validation happens in + // isValidPartialReduction. 
const APInt *Unused; if (I > 0 && CastRecipes[0] && match(OpVal, m_APInt(Unused))) continue; @@ -5824,16 +5815,21 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue, return false; } - Type *PhiType = TypeInfo.inferScalarType(RedPhiR); + Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR); TypeSize PHISize = PhiType->getPrimitiveSizeInBits(); - Type *ExtOpType = TypeInfo.inferScalarType(CastRecipes[0]->getOperand(0)); + Type *ExtOpType = + CostCtx.Types.inferScalarType(CastRecipes[0]->getOperand(0)); TypeSize ASize = ExtOpType->getPrimitiveSizeInBits(); if (!PHISize.hasKnownScalarFactor(ASize)) return false; - Chains.push_back( + VPPartialReductionChain Chain( {UpdateR, CastRecipes[0], CastRecipes[1], BinOp, static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize))}); + if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range)) + return false; + + Chains.push_back(Chain); return true; } } // namespace @@ -5841,9 +5837,9 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue, void VPlanTransforms::createPartialReductions(VPlan &Plan, VPCostContext &CostCtx, VFRange &Range) { - // Find all possible partial reductions, grouping chains by their PHI. This - // grouping allows invalidating the whole chain, if any link is not a valid - // partial reduction. + // Find all possible valid partial reductions, grouping chains by their PHI. + // This grouping allows invalidating the whole chain, if any link is not a + // valid partial reduction. 
MapVector<VPReductionPHIRecipe *, SmallVector<VPPartialReductionChain>> ChainsByPhi; VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); @@ -5859,8 +5855,8 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan, VPValue *ExitValue = RdxResult->getOperand(0); match(ExitValue, m_Select(m_VPValue(), m_VPValue(ExitValue), m_VPValue())); - getScaledReductions(RedPhiR, ExitValue, ChainsByPhi[RedPhiR], - CostCtx.Types); + getScaledReductions(RedPhiR, ExitValue, ChainsByPhi[RedPhiR], CostCtx, + Range); } } @@ -5930,6 +5926,6 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan, for (auto &[Phi, Chains] : ChainsByPhi) { RecurKind RK = cast<VPReductionPHIRecipe>(Phi)->getRecurrenceKind(); for (const VPPartialReductionChain &Chain : Chains) - transformToPartialReduction(Chain, Range, CostCtx, Plan, Phi, RK); + transformToPartialReduction(Chain, CostCtx.Types, Plan, Phi, RK); } } |
