Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/IR/Value.cpp                            |  46
-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.cpp      | 160
-rw-r--r--  llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp  | 164
3 files changed, 288 insertions(+), 82 deletions(-)
diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp
index 129ca4a..5928c89 100644
--- a/llvm/lib/IR/Value.cpp
+++ b/llvm/lib/IR/Value.cpp
@@ -747,34 +747,28 @@ const Value *Value::stripAndAccumulateConstantOffsets(
// means when we construct GEPOffset, we need to use the size
// of GEP's pointer type rather than the size of the original
// pointer type.
- unsigned CurBitWidth = DL.getIndexTypeSizeInBits(V->getType());
- if (CurBitWidth == BitWidth) {
- if (!GEP->accumulateConstantOffset(DL, Offset, ExternalAnalysis))
- return V;
- } else {
- APInt GEPOffset(CurBitWidth, 0);
- if (!GEP->accumulateConstantOffset(DL, GEPOffset, ExternalAnalysis))
- return V;
+ APInt GEPOffset(DL.getIndexTypeSizeInBits(V->getType()), 0);
+ if (!GEP->accumulateConstantOffset(DL, GEPOffset, ExternalAnalysis))
+ return V;
- // Stop traversal if the pointer offset wouldn't fit in the bit-width
- // provided by the Offset argument. This can happen due to AddrSpaceCast
- // stripping.
- if (GEPOffset.getSignificantBits() > BitWidth)
- return V;
+ // Stop traversal if the pointer offset wouldn't fit in the bit-width
+ // provided by the Offset argument. This can happen due to AddrSpaceCast
+ // stripping.
+ if (GEPOffset.getSignificantBits() > BitWidth)
+ return V;
- // External Analysis can return a result higher/lower than the value
- // represents. We need to detect overflow/underflow.
- APInt GEPOffsetST = GEPOffset.sextOrTrunc(BitWidth);
- if (!ExternalAnalysis) {
- Offset += GEPOffsetST;
- } else {
- bool Overflow = false;
- APInt OldOffset = Offset;
- Offset = Offset.sadd_ov(GEPOffsetST, Overflow);
- if (Overflow) {
- Offset = OldOffset;
- return V;
- }
+ // External Analysis can return a result higher/lower than the value
+ // represents. We need to detect overflow/underflow.
+ APInt GEPOffsetST = GEPOffset.sextOrTrunc(BitWidth);
+ if (!ExternalAnalysis) {
+ Offset += GEPOffsetST;
+ } else {
+ bool Overflow = false;
+ APInt OldOffset = Offset;
+ Offset = Offset.sadd_ov(GEPOffsetST, Overflow);
+ if (Overflow) {
+ Offset = OldOffset;
+ return V;
}
}
V = GEP->getPointerOperand();
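
For reference, the simplified block above always accumulates into a GEPOffset sized by the GEP's index type and then folds it into the caller-provided accumulator. A minimal standalone sketch of that fold (illustrative only; the patch performs the signed-overflow check only when an ExternalAnalysis callback is supplied, whereas this sketch always checks):

// Illustrative sketch, not part of the patch.
#include "llvm/ADT/APInt.h"
#include <optional>

// Returns the updated accumulator, or std::nullopt when traversal must stop
// (the offset does not fit in the accumulator width, or the addition
// overflows).
static std::optional<llvm::APInt> foldGEPOffset(const llvm::APInt &Offset,
                                                const llvm::APInt &GEPOffset) {
  unsigned BitWidth = Offset.getBitWidth();
  if (GEPOffset.getSignificantBits() > BitWidth)
    return std::nullopt;
  bool Overflow = false;
  llvm::APInt Sum = Offset.sadd_ov(GEPOffset.sextOrTrunc(BitWidth), Overflow);
  if (Overflow)
    return std::nullopt;
  return Sum;
}
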
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e4aa8b8..e63b937 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1844,6 +1844,17 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 3,
/*IsStore*/ true,
/*IsUnitStrided*/ false, /*UsePtrVal*/ true);
+ case Intrinsic::riscv_sseg2_store_mask:
+ case Intrinsic::riscv_sseg3_store_mask:
+ case Intrinsic::riscv_sseg4_store_mask:
+ case Intrinsic::riscv_sseg5_store_mask:
+ case Intrinsic::riscv_sseg6_store_mask:
+ case Intrinsic::riscv_sseg7_store_mask:
+ case Intrinsic::riscv_sseg8_store_mask:
+ // Operands are (vec, ..., vec, ptr, stride, mask, vl)
+ return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 4,
+ /*IsStore*/ true,
+ /*IsUnitStrided*/ false, /*UsePtrVal*/ true);
case Intrinsic::riscv_vlm:
return SetRVVLoadStoreInfo(/*PtrOp*/ 0,
/*IsStore*/ false,
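
The llvm.riscv.sseg<N>.store.mask intrinsics handled above are masked strided segment stores. As a rough scalar reference model (illustrative only; it ignores the intrinsic's exact type signature and memory-ordering details), each active segment writes its NF fields contiguously, and consecutive segments are separated by the byte stride:

// Illustrative scalar model only, not the intrinsic definition: field f of
// segment i comes from Seg[f][i] and is written at
// Base + i*StrideBytes + f*sizeof(ElemT).
#include <cstdint>
#include <cstring>

template <typename ElemT, unsigned NF>
void ssegStoreMaskRef(std::uint8_t *Base, std::int64_t StrideBytes,
                      const ElemT *const (&Seg)[NF], const bool *Mask,
                      unsigned VL) {
  for (unsigned i = 0; i < VL; ++i) {
    if (!Mask[i])
      continue;
    std::uint8_t *SegBase = Base + std::int64_t(i) * StrideBytes;
    for (unsigned f = 0; f < NF; ++f)
      std::memcpy(SegBase + f * sizeof(ElemT), &Seg[f][i], sizeof(ElemT));
  }
}
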
@@ -11084,69 +11095,118 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
}
-SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
- SelectionDAG &DAG) const {
- unsigned IntNo = Op.getConstantOperandVal(1);
+static SDValue
+lowerFixedVectorSegStoreIntrinsics(unsigned IntNo, SDValue Op,
+ const RISCVSubtarget &Subtarget,
+ SelectionDAG &DAG) {
+ bool IsStrided;
switch (IntNo) {
- default:
- break;
case Intrinsic::riscv_seg2_store_mask:
case Intrinsic::riscv_seg3_store_mask:
case Intrinsic::riscv_seg4_store_mask:
case Intrinsic::riscv_seg5_store_mask:
case Intrinsic::riscv_seg6_store_mask:
case Intrinsic::riscv_seg7_store_mask:
- case Intrinsic::riscv_seg8_store_mask: {
- SDLoc DL(Op);
- static const Intrinsic::ID VssegInts[] = {
- Intrinsic::riscv_vsseg2_mask, Intrinsic::riscv_vsseg3_mask,
- Intrinsic::riscv_vsseg4_mask, Intrinsic::riscv_vsseg5_mask,
- Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
- Intrinsic::riscv_vsseg8_mask};
+ case Intrinsic::riscv_seg8_store_mask:
+ IsStrided = false;
+ break;
+ case Intrinsic::riscv_sseg2_store_mask:
+ case Intrinsic::riscv_sseg3_store_mask:
+ case Intrinsic::riscv_sseg4_store_mask:
+ case Intrinsic::riscv_sseg5_store_mask:
+ case Intrinsic::riscv_sseg6_store_mask:
+ case Intrinsic::riscv_sseg7_store_mask:
+ case Intrinsic::riscv_sseg8_store_mask:
+ IsStrided = true;
+ break;
+ default:
+ llvm_unreachable("unexpected intrinsic ID");
+ }
- // Operands: (chain, int_id, vec*, ptr, mask, vl)
- unsigned NF = Op->getNumOperands() - 5;
- assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
- MVT XLenVT = Subtarget.getXLenVT();
- MVT VT = Op->getOperand(2).getSimpleValueType();
- MVT ContainerVT = getContainerForFixedLengthVector(VT);
- unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
- ContainerVT.getScalarSizeInBits();
- EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
+ SDLoc DL(Op);
+ static const Intrinsic::ID VssegInts[] = {
+ Intrinsic::riscv_vsseg2_mask, Intrinsic::riscv_vsseg3_mask,
+ Intrinsic::riscv_vsseg4_mask, Intrinsic::riscv_vsseg5_mask,
+ Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
+ Intrinsic::riscv_vsseg8_mask};
+ static const Intrinsic::ID VsssegInts[] = {
+ Intrinsic::riscv_vssseg2_mask, Intrinsic::riscv_vssseg3_mask,
+ Intrinsic::riscv_vssseg4_mask, Intrinsic::riscv_vssseg5_mask,
+ Intrinsic::riscv_vssseg6_mask, Intrinsic::riscv_vssseg7_mask,
+ Intrinsic::riscv_vssseg8_mask};
+
+ // Operands: (chain, int_id, vec*, ptr, mask, vl) or
+ // (chain, int_id, vec*, ptr, stride, mask, vl)
+ unsigned NF = Op->getNumOperands() - (IsStrided ? 6 : 5);
+ assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
+ MVT XLenVT = Subtarget.getXLenVT();
+ MVT VT = Op->getOperand(2).getSimpleValueType();
+ MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
+ unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
+ ContainerVT.getScalarSizeInBits();
+ EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
- SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
- SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
- MVT MaskVT = Mask.getSimpleValueType();
- MVT MaskContainerVT =
- ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
- Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
+ SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
+ SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
+ MVT MaskVT = Mask.getSimpleValueType();
+ MVT MaskContainerVT =
+ ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
+ Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
- SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
- SDValue Ptr = Op->getOperand(NF + 2);
+ SDValue IntID = DAG.getTargetConstant(
+ IsStrided ? VsssegInts[NF - 2] : VssegInts[NF - 2], DL, XLenVT);
+ SDValue Ptr = Op->getOperand(NF + 2);
- auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
+ auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
- SDValue StoredVal = DAG.getUNDEF(VecTupTy);
- for (unsigned i = 0; i < NF; i++)
- StoredVal = DAG.getNode(
- RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
- convertToScalableVector(
- ContainerVT, FixedIntrinsic->getOperand(2 + i), DAG, Subtarget),
- DAG.getTargetConstant(i, DL, MVT::i32));
+ SDValue StoredVal = DAG.getUNDEF(VecTupTy);
+ for (unsigned i = 0; i < NF; i++)
+ StoredVal = DAG.getNode(
+ RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
+ convertToScalableVector(ContainerVT, FixedIntrinsic->getOperand(2 + i),
+ DAG, Subtarget),
+ DAG.getTargetConstant(i, DL, MVT::i32));
+
+ SmallVector<SDValue, 10> Ops = {
+ FixedIntrinsic->getChain(),
+ IntID,
+ StoredVal,
+ Ptr,
+ Mask,
+ VL,
+ DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
+ // Insert the stride operand.
+ if (IsStrided)
+ Ops.insert(std::next(Ops.begin(), 4),
+ Op.getOperand(Op.getNumOperands() - 3));
+
+ return DAG.getMemIntrinsicNode(
+ ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
+ FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
+}
+
+SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
+ SelectionDAG &DAG) const {
+ unsigned IntNo = Op.getConstantOperandVal(1);
+ switch (IntNo) {
+ default:
+ break;
+ case Intrinsic::riscv_seg2_store_mask:
+ case Intrinsic::riscv_seg3_store_mask:
+ case Intrinsic::riscv_seg4_store_mask:
+ case Intrinsic::riscv_seg5_store_mask:
+ case Intrinsic::riscv_seg6_store_mask:
+ case Intrinsic::riscv_seg7_store_mask:
+ case Intrinsic::riscv_seg8_store_mask:
+ case Intrinsic::riscv_sseg2_store_mask:
+ case Intrinsic::riscv_sseg3_store_mask:
+ case Intrinsic::riscv_sseg4_store_mask:
+ case Intrinsic::riscv_sseg5_store_mask:
+ case Intrinsic::riscv_sseg6_store_mask:
+ case Intrinsic::riscv_sseg7_store_mask:
+ case Intrinsic::riscv_sseg8_store_mask:
+ return lowerFixedVectorSegStoreIntrinsics(IntNo, Op, Subtarget, DAG);
- SDValue Ops[] = {
- FixedIntrinsic->getChain(),
- IntID,
- StoredVal,
- Ptr,
- Mask,
- VL,
- DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
-
- return DAG.getMemIntrinsicNode(
- ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
- FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
- }
case Intrinsic::riscv_sf_vc_xv_se:
return getVCIXISDNodeVOID(Op, DAG, RISCVISD::SF_VC_XV_SE);
case Intrinsic::riscv_sf_vc_iv_se:
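
As a cross-check on the operand bookkeeping in lowerFixedVectorSegStoreIntrinsics above: the fixed-vector segment-store nodes carry (chain, int_id, vec x NF, ptr[, stride], mask, vl), so the indices used in the code fall out as in this small helper (hypothetical, for illustration only):

// Hypothetical helper mirroring the index math used in the lowering above.
// Operands: (chain, int_id, vec * NF, ptr [, stride], mask, vl).
struct SegStoreOperandIndices {
  unsigned NF, Ptr, Stride, Mask, VL; // Stride is only valid when strided.
};

static SegStoreOperandIndices
getSegStoreOperandIndices(unsigned NumOperands, bool IsStrided) {
  SegStoreOperandIndices Idx;
  Idx.NF = NumOperands - (IsStrided ? 6u : 5u); // as in the lowering code
  Idx.Ptr = Idx.NF + 2;                         // after chain, int_id, vecs
  Idx.VL = NumOperands - 1;
  Idx.Mask = NumOperands - 2;
  Idx.Stride = IsStrided ? NumOperands - 3 : ~0u;
  return Idx;
}
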
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 5d0e2f9..ec06a21 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3883,6 +3883,7 @@ private:
enum CombinedOpcode {
NotCombinedOp = -1,
MinMax = Instruction::OtherOpsEnd + 1,
+ FMulAdd,
};
CombinedOpcode CombinedOp = NotCombinedOp;
@@ -4033,6 +4034,9 @@ private:
/// Returns true if any scalar in the list is a copyable element.
bool hasCopyableElements() const { return !CopyableElements.empty(); }
+ /// Returns the state of the operations.
+ const InstructionsState &getOperations() const { return S; }
+
/// When ReuseReorderShuffleIndices is empty it just returns position of \p
/// V within vector of Scalars. Otherwise, try to remap on its reuse index.
unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,82 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
}
}
+static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
+ const InstructionsState &S,
+ DominatorTree &DT, const DataLayout &DL,
+ TargetTransformInfo &TTI,
+ const TargetLibraryInfo &TLI) {
+ assert(all_of(VL,
+ [](Value *V) {
+ return V->getType()->getScalarType()->isFloatingPointTy();
+ }) &&
+ "Can only convert to FMA for floating point types");
+ assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
+
+ auto CheckForContractable = [&](ArrayRef<Value *> VL) {
+ FastMathFlags FMF;
+ FMF.set();
+ for (Value *V : VL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ continue;
+ // TODO: support for copyable elements.
+ Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
+ if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
+ continue;
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+ FMF &= FPCI->getFastMathFlags();
+ }
+ return FMF.allowContract();
+ };
+ if (!CheckForContractable(VL))
+ return InstructionCost::getInvalid();
+ // The fmul feeding the add/sub must also be contractable.
+ InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
+ SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
+
+ InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
+ if (!OpS.valid())
+ return InstructionCost::getInvalid();
+ if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
+ return InstructionCost::getInvalid();
+ if (!CheckForContractable(Operands.front()))
+ return InstructionCost::getInvalid();
+ // Compare the costs.
+ InstructionCost FMulPlusFAddCost = 0;
+ InstructionCost FMACost = 0;
+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ FastMathFlags FMF;
+ FMF.set();
+ for (Value *V : VL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ continue;
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+ FMF &= FPCI->getFastMathFlags();
+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
+ }
+ unsigned NumOps = 0;
+ for (auto [V, Op] : zip(VL, Operands.front())) {
+ auto *I = dyn_cast<Instruction>(Op);
+ if (!I || !I->hasOneUse()) {
+ if (auto *OpI = dyn_cast<Instruction>(V))
+ FMACost += TTI.getInstructionCost(OpI, CostKind);
+ if (I)
+ FMACost += TTI.getInstructionCost(I, CostKind);
+ continue;
+ }
+ ++NumOps;
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+ FMF &= FPCI->getFastMathFlags();
+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
+ }
+ Type *Ty = VL.front()->getType();
+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
+ FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
+ return FMACost < FMulPlusFAddCost ? FMACost : InstructionCost::getInvalid();
+}
+
void BoUpSLP::transformNodes() {
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
BaseGraphSize = VectorizableTree.size();
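
canConvertToFMA above only accepts an FAdd/FSub bundle when both the add/sub and its single-use FMul operand allow contraction and the FMA is no more expensive than the separate fmul + fadd. At the level of a single scalar instruction, the shape being recognized is roughly the following (a simplified sketch; the patch works on whole bundles via InstructionsCompatibilityAnalysis rather than per-instruction pattern matching):

// Simplified, illustrative predicate; not the patch's bundle-level logic.
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"

static bool isContractableFMulAdd(llvm::Instruction *I) {
  using namespace llvm::PatternMatch;
  llvm::Instruction *Mul = nullptr;
  // Match (a * b) + c (either operand order) or (a * b) - c, with a
  // single-use multiply.
  if (!match(I, m_c_FAdd(m_OneUse(m_CombineAnd(m_FMul(m_Value(), m_Value()),
                                               m_Instruction(Mul))),
                         m_Value())) &&
      !match(I, m_FSub(m_OneUse(m_CombineAnd(m_FMul(m_Value(), m_Value()),
                                             m_Instruction(Mul))),
                       m_Value())))
    return false;
  // Both the add/sub and the multiply must permit contraction into an FMA.
  return llvm::cast<llvm::FPMathOperator>(I)->getFastMathFlags()
             .allowContract() &&
         llvm::cast<llvm::FPMathOperator>(Mul)->getFastMathFlags()
             .allowContract();
}
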
@@ -12355,6 +12435,25 @@ void BoUpSLP::transformNodes() {
}
break;
}
+ case Instruction::FSub:
+ case Instruction::FAdd: {
+ // Check if possible to convert (a*b)+c to fma.
+ if (E.State != TreeEntry::Vectorize ||
+ !E.getOperations().isAddSubLikeOp())
+ break;
+ if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
+ .isValid())
+ break;
+ // This node is a fmuladd node.
+ E.CombinedOp = TreeEntry::FMulAdd;
+ TreeEntry *FMulEntry = getOperandEntry(&E, 0);
+ if (FMulEntry->UserTreeIndex &&
+ FMulEntry->State == TreeEntry::Vectorize) {
+ // The FMul node is part of the combined fmuladd node.
+ FMulEntry->State = TreeEntry::CombinedVectorize;
+ }
+ break;
+ }
default:
break;
}
@@ -13587,6 +13686,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
return IntrinsicCost;
};
+ auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
+ Instruction *VI) {
+ InstructionCost Cost = canConvertToFMA(VI, S, *DT, *DL, TTI, *TLI);
+ return Cost;
+ };
switch (ShuffleOrOp) {
case Instruction::PHI: {
// Count reused scalars.
@@ -13927,6 +14031,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
+ case TreeEntry::FMulAdd: {
+ auto GetScalarCost = [&](unsigned Idx) {
+ if (isa<PoisonValue>(UniqueValues[Idx]))
+ return InstructionCost(TTI::TCC_Free);
+ return GetFMulAddCost(E->getOperations(),
+ cast<Instruction>(UniqueValues[Idx]));
+ };
+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
+ FastMathFlags FMF;
+ FMF.set();
+ for (Value *V : E->Scalars) {
+ if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
+ FMF &= FPCI->getFastMathFlags();
+ if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
+ FMF &= FPCIOp->getFastMathFlags();
+ }
+ }
+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
+ {VecTy, VecTy, VecTy}, FMF);
+ InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
+ return VecCost + CommonCost;
+ };
+ return GetCostDiff(GetScalarCost, GetVectorCost);
+ }
case Instruction::FNeg:
case Instruction::Add:
case Instruction::FAdd:
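
The FMulAdd entry above prices the whole bundle as a single vector llvm.fmuladd call. In isolation, that TTI query reduces to the following (illustrative sketch; the vector type and fast-math flags are assumed to be computed as in the code above):

// Illustrative sketch of the vector cost query used for the combined node:
// price one llvm.fmuladd over VecTy under the given fast-math flags.
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/Intrinsics.h"

static llvm::InstructionCost
getVectorFMulAddCost(const llvm::TargetTransformInfo &TTI, llvm::Type *VecTy,
                     llvm::FastMathFlags FMF) {
  llvm::IntrinsicCostAttributes ICA(llvm::Intrinsic::fmuladd, VecTy,
                                    {VecTy, VecTy, VecTy}, FMF);
  return TTI.getIntrinsicInstrCost(
      ICA, llvm::TargetTransformInfo::TCK_RecipThroughput);
}
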
@@ -13964,8 +14092,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
- return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
- Op1Info, Op2Info, Operands);
+ InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
+ ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
+ if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
+ I && (ShuffleOrOp == Instruction::FAdd ||
+ ShuffleOrOp == Instruction::FSub)) {
+ InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
+ if (IntrinsicCost.isValid())
+ ScalarCost = IntrinsicCost;
+ }
+ return ScalarCost;
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -22594,11 +22730,21 @@ public:
/// Try to find a reduction tree.
bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
ScalarEvolution &SE, const DataLayout &DL,
- const TargetLibraryInfo &TLI) {
+ const TargetLibraryInfo &TLI,
+ DominatorTree &DT, TargetTransformInfo &TTI) {
RdxKind = HorizontalReduction::getRdxKind(Root);
if (!isVectorizable(RdxKind, Root))
return false;
+ // FMA reduction root - skip.
+ auto CheckForFMA = [&](Instruction *I) {
+ return RdxKind == RecurKind::FAdd &&
+ canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
+ .isValid();
+ };
+ if (CheckForFMA(Root))
+ return false;
+
// Analyze "regular" integer/FP types for reductions - no target-specific
// types or pointers.
Type *Ty = Root->getType();
@@ -22636,7 +22782,7 @@ public:
// Also, do not try to reduce const values, if the operation is not
// foldable.
if (!EdgeInst || Level > RecursionMaxDepth ||
- getRdxKind(EdgeInst) != RdxKind ||
+ getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
!isVectorizable(RdxKind, EdgeInst) ||
@@ -24205,13 +24351,13 @@ bool SLPVectorizerPass::vectorizeHorReduction(
Stack.emplace(SelectRoot(), 0);
SmallPtrSet<Value *, 8> VisitedInstrs;
bool Res = false;
- auto &&TryToReduce = [this, &R](Instruction *Inst) -> Value * {
+ auto TryToReduce = [this, &R, TTI = TTI](Instruction *Inst) -> Value * {
if (R.isAnalyzedReductionRoot(Inst))
return nullptr;
if (!isReductionCandidate(Inst))
return nullptr;
HorizontalReduction HorRdx;
- if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
+ if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI))
return nullptr;
return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
};
@@ -24277,6 +24423,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
return false;
+ // Skip potential FMA candidates.
+ if ((I->getOpcode() == Instruction::FAdd ||
+ I->getOpcode() == Instruction::FSub) &&
+ canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
+ .isValid())
+ return false;
Value *P = I->getParent();