aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Vectorize
diff options
context:
space:
mode:
author: Florian Hahn <flo@fhahn.com>, 2026-02-10 21:16:21 +0000
committer: GitHub <noreply@github.com>, 2026-02-10 21:16:21 +0000
commit: a1fc5b4a48db440e37acd06972a65dd67a3b2b7c (patch)
tree: 56ef37ecd8d2cf2f4650a4221108364db96d90ec /llvm/lib/Transforms/Vectorize
parent: 7f9965c73de2a3290f78f12f9d26360407dd2fd6 (diff)
download: llvm-a1fc5b4a48db440e37acd06972a65dd67a3b2b7c.tar.gz
llvm-a1fc5b4a48db440e37acd06972a65dd67a3b2b7c.tar.bz2
llvm-a1fc5b4a48db440e37acd06972a65dd67a3b2b7c.zip
[VPlan] Reject partial reductions with invalid costs in getScaledReds. (#180438)
Check if costs for partial reductions are valid up-front in getScaledReductions instead of when transforming each link in the chain in transformToPartialReduction. This ensures that we either transform all entries in the chain together, or none via the existing invalidation logic. This fixes a crash when a link in the chain would have an invalid cost, as in the added test cases. Fixes https://github.com/llvm/llvm-project/issues/180340. PR: https://github.com/llvm/llvm-project/pull/180438
Diffstat (limited to 'llvm/lib/Transforms/Vectorize')
-rw-r--r-- llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp 168
1 files changed, 82 insertions, 86 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a99641c472b9..19be84e9db09 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5547,15 +5547,12 @@ struct VPPartialReductionChain {
};
// Helper to transform a partial reduction chain into a partial reduction
-// recipe. Returns true if transformation succeeded. Checks profitability and
-// clamps VF range.
-static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
- VFRange &Range, VPCostContext &CostCtx,
- VPlan &Plan,
+// recipe. Assumes profitability has been checked.
+static void transformToPartialReduction(const VPPartialReductionChain &Chain,
+ VPTypeAnalysis &TypeInfo, VPlan &Plan,
VPReductionPHIRecipe *RdxPhi,
RecurKind RK) {
VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp;
- unsigned ScaleFactor = Chain.ScaleFactor;
assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation");
VPValue *BinOp = WidenRecipe->getOperand(0);
@@ -5565,67 +5562,6 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
if (isa_and_present<VPReductionPHIRecipe, VPReductionRecipe>(BinOp))
std::swap(BinOp, Accumulator);
- // For chained reductions, only transform if accumulator is already a PHI or
- // partial reduction. Otherwise, it needs to be transformed first.
- auto *AccumRecipe = Accumulator->getDefiningRecipe();
- if (!isa_and_present<VPReductionPHIRecipe, VPReductionRecipe>(AccumRecipe))
- return false;
-
- // Check if the partial reduction is profitable for the VF range.
- Type *PhiType = CostCtx.Types.inferScalarType(Accumulator);
-
- // Derive extend info from the stored extends.
- auto GetExtInfo = [&CostCtx](VPWidenCastRecipe *Ext)
- -> std::pair<Type *, TargetTransformInfo::PartialReductionExtendKind> {
- if (!Ext)
- return {nullptr, TargetTransformInfo::PR_None};
- Type *ExtOpType = CostCtx.Types.inferScalarType(Ext->getOperand(0));
- auto ExtKind = TargetTransformInfo::getPartialReductionExtendKind(
- static_cast<Instruction::CastOps>(Ext->getOpcode()));
- return {ExtOpType, ExtKind};
- };
- auto ExtInfoA = GetExtInfo(Chain.ExtendA);
- auto ExtInfoB = GetExtInfo(Chain.ExtendB);
- Type *ExtOpTypeA = ExtInfoA.first;
- Type *ExtOpTypeB = ExtInfoB.first;
- auto ExtKindA = ExtInfoA.second;
- auto ExtKindB = ExtInfoB.second;
- // If ExtendB is nullptr but there's a separate BinOp, the second operand
- // was a constant that can use the same extend kind as the first.
- if (!Chain.ExtendB && Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp) {
- // Validate that the constant can be extended to the narrow type.
- const APInt *Const = nullptr;
- for (VPValue *Op : Chain.BinOp->operands()) {
- if (match(Op, m_APInt(Const)))
- break;
- }
- if (!Const || !canConstantBeExtended(Const, ExtOpTypeA, ExtKindA))
- return false;
- ExtOpTypeB = ExtOpTypeA;
- ExtKindB = ExtKindA;
- }
-
- // BinOpc is only set when there's a separate binary op (not when BinOp is
- // the reduction itself).
- std::optional<unsigned> BinOpc =
- (Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp)
- ? std::make_optional(Chain.BinOp->getOpcode())
- : std::nullopt;
-
- if (!LoopVectorizationPlanner::getDecisionAndClampRange(
- [&](ElementCount VF) {
- return CostCtx.TTI
- .getPartialReductionCost(
- WidenRecipe->getOpcode(), ExtOpTypeA, ExtOpTypeB, PhiType,
- VF, ExtKindA, ExtKindB, BinOpc, CostCtx.CostKind,
- PhiType->isFloatingPointTy()
- ? std::optional{WidenRecipe->getFastMathFlags()}
- : std::nullopt)
- .isValid();
- },
- Range))
- return false;
-
// Sub-reductions can be implemented in two ways:
// (1) negate the operand in the vector loop (the default way).
// (2) subtract the reduced value from the init value in the middle block.
@@ -5642,7 +5578,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
// subtract in the middle block.
if (WidenRecipe->getOpcode() == Instruction::Sub && RK != RecurKind::Sub) {
VPBuilder Builder(WidenRecipe);
- Type *ElemTy = CostCtx.Types.inferScalarType(BinOp);
+ Type *ElemTy = TypeInfo.inferScalarType(BinOp);
auto *Zero = Plan.getConstantInt(ElemTy, 0);
VPIRFlags Flags = WidenRecipe->getUnderlyingInstr()
? VPIRFlags(*WidenRecipe->getUnderlyingInstr())
@@ -5664,6 +5600,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
assert((!ExitValue || IsLastInChain) &&
"if we found ExitValue, it must match RdxPhi's backedge value");
+ Type *PhiType = TypeInfo.inferScalarType(RdxPhi);
RecurKind RdxKind =
PhiType->isFloatingPointTy() ? RecurKind::FAdd : RecurKind::Add;
auto *PartialRed = new VPReductionRecipe(
@@ -5671,7 +5608,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
RdxKind == RecurKind::FAdd ? WidenRecipe->getFastMathFlags()
: FastMathFlags(),
WidenRecipe->getUnderlyingInstr(), Accumulator, BinOp, Cond,
- RdxUnordered{/*VFScaleFactor=*/ScaleFactor});
+ RdxUnordered{/*VFScaleFactor=*/Chain.ScaleFactor});
PartialRed->insertBefore(WidenRecipe);
if (Cond)
@@ -5681,22 +5618,22 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
// We only need to update the PHI node once, which is when we find the
// last reduction in the chain.
if (!IsLastInChain)
- return true;
+ return;
// Scale the PHI and ReductionStartVector by the VFScaleFactor
assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
- RdxPhi->setVFScaleFactor(ScaleFactor);
+ RdxPhi->setVFScaleFactor(Chain.ScaleFactor);
auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
- auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
+ auto *NewScaleFactor = Plan.getConstantInt(32, Chain.ScaleFactor);
StartInst->setOperand(2, NewScaleFactor);
// If this is the last value in a sub-reduction chain, then update the PHI
// node to start at `0` and update the reduction-result to subtract from
// the PHI's start value.
if (RK != RecurKind::Sub)
- return true;
+ return;
VPValue *OldStartValue = StartInst->getOperand(0);
StartInst->setOperand(0, StartInst->getOperand(1));
@@ -5713,8 +5650,61 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
RdxResult->replaceUsesWithIf(
NewResult,
[&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
+}
- return true;
+/// Check if a partial reduction chain is supported by the target (i.e. does
+/// not have an invalid cost) for the given VF range. Clamps the range and
+/// returns true if the partial reduction cost is valid for the clamped range.
+static bool isValidPartialReduction(const VPPartialReductionChain &Chain,
+ Type *PhiType, VPCostContext &CostCtx,
+ VFRange &Range) {
+ auto GetExtInfo = [&CostCtx](VPWidenCastRecipe *Ext)
+ -> std::pair<Type *, TargetTransformInfo::PartialReductionExtendKind> {
+ if (!Ext)
+ return {nullptr, TargetTransformInfo::PR_None};
+ Type *ExtOpType = CostCtx.Types.inferScalarType(Ext->getOperand(0));
+ auto ExtKind = TargetTransformInfo::getPartialReductionExtendKind(
+ static_cast<Instruction::CastOps>(Ext->getOpcode()));
+ return {ExtOpType, ExtKind};
+ };
+ auto ExtInfoA = GetExtInfo(Chain.ExtendA);
+ auto ExtInfoB = GetExtInfo(Chain.ExtendB);
+ Type *ExtOpTypeA = ExtInfoA.first;
+ Type *ExtOpTypeB = ExtInfoB.first;
+ auto ExtKindA = ExtInfoA.second;
+ auto ExtKindB = ExtInfoB.second;
+
+ // If ExtendB is nullptr but there's a separate BinOp, the second operand
+ // was a constant that can use the same extend kind as the first.
+ if (!Chain.ExtendB && Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp) {
+ const APInt *Const = nullptr;
+ for (VPValue *Op : Chain.BinOp->operands()) {
+ if (match(Op, m_APInt(Const)))
+ break;
+ }
+ if (!Const || !canConstantBeExtended(Const, ExtOpTypeA, ExtKindA))
+ return false;
+ ExtOpTypeB = ExtOpTypeA;
+ ExtKindB = ExtKindA;
+ }
+
+ std::optional<unsigned> BinOpc =
+ (Chain.BinOp && Chain.BinOp != Chain.ReductionBinOp)
+ ? std::make_optional(Chain.BinOp->getOpcode())
+ : std::nullopt;
+ VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp;
+ return LoopVectorizationPlanner::getDecisionAndClampRange(
+ [&](ElementCount VF) {
+ return CostCtx.TTI
+ .getPartialReductionCost(
+ WidenRecipe->getOpcode(), ExtOpTypeA, ExtOpTypeB, PhiType, VF,
+ ExtKindA, ExtKindB, BinOpc, CostCtx.CostKind,
+ PhiType->isFloatingPointTy()
+ ? std::optional{WidenRecipe->getFastMathFlags()}
+ : std::nullopt)
+ .isValid();
+ },
+ Range);
}
/// Examines reduction operations to see if the target can use a cheaper
@@ -5724,7 +5714,7 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
static bool
getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue,
SmallVectorImpl<VPPartialReductionChain> &Chains,
- VPTypeAnalysis &TypeInfo) {
+ VPCostContext &CostCtx, VFRange &Range) {
auto *UpdateR = dyn_cast<VPWidenRecipe>(PrevValue);
if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
return false;
@@ -5753,7 +5743,7 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue,
// Try and get a scaled reduction from the first non-phi operand.
// If one is found, we use the discovered reduction instruction in
// place of the accumulator for costing.
- if (getScaledReductions(RedPhiR, Op, Chains, TypeInfo)) {
+ if (getScaledReductions(RedPhiR, Op, Chains, CostCtx, Range)) {
RedPhiR = Chains.rbegin()->ReductionBinOp;
Op = UpdateR->getOperand(0);
PhiOp = UpdateR->getOperand(1);
@@ -5774,7 +5764,8 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue,
assert(Operands.size() <= 2 && "expected at most 2 operands");
for (const auto &[I, OpVal] : enumerate(Operands)) {
- // Allow constant as second operand - validation happens in transform.
+ // Allow constant as second operand - validation happens in
+ // isValidPartialReduction.
const APInt *Unused;
if (I > 0 && CastRecipes[0] && match(OpVal, m_APInt(Unused)))
continue;
@@ -5824,16 +5815,21 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue,
return false;
}
- Type *PhiType = TypeInfo.inferScalarType(RedPhiR);
+ Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
- Type *ExtOpType = TypeInfo.inferScalarType(CastRecipes[0]->getOperand(0));
+ Type *ExtOpType =
+ CostCtx.Types.inferScalarType(CastRecipes[0]->getOperand(0));
TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
if (!PHISize.hasKnownScalarFactor(ASize))
return false;
- Chains.push_back(
+ VPPartialReductionChain Chain(
{UpdateR, CastRecipes[0], CastRecipes[1], BinOp,
static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize))});
+ if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
+ return false;
+
+ Chains.push_back(Chain);
return true;
}
} // namespace
@@ -5841,9 +5837,9 @@ getScaledReductions(VPSingleDefRecipe *RedPhiR, VPValue *PrevValue,
void VPlanTransforms::createPartialReductions(VPlan &Plan,
VPCostContext &CostCtx,
VFRange &Range) {
- // Find all possible partial reductions, grouping chains by their PHI. This
- // grouping allows invalidating the whole chain, if any link is not a valid
- // partial reduction.
+ // Find all possible valid partial reductions, grouping chains by their PHI.
+ // This grouping allows invalidating the whole chain, if any link is not a
+ // valid partial reduction.
MapVector<VPReductionPHIRecipe *, SmallVector<VPPartialReductionChain>>
ChainsByPhi;
VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
@@ -5859,8 +5855,8 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan,
VPValue *ExitValue = RdxResult->getOperand(0);
match(ExitValue,
m_Select(m_VPValue(), m_VPValue(ExitValue), m_VPValue()));
- getScaledReductions(RedPhiR, ExitValue, ChainsByPhi[RedPhiR],
- CostCtx.Types);
+ getScaledReductions(RedPhiR, ExitValue, ChainsByPhi[RedPhiR], CostCtx,
+ Range);
}
}
@@ -5930,6 +5926,6 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan,
for (auto &[Phi, Chains] : ChainsByPhi) {
RecurKind RK = cast<VPReductionPHIRecipe>(Phi)->getRecurrenceKind();
for (const VPPartialReductionChain &Chain : Chains)
- transformToPartialReduction(Chain, Range, CostCtx, Plan, Phi, RK);
+ transformToPartialReduction(Chain, CostCtx.Types, Plan, Phi, RK);
}
}