about | summary | refs | log | tree | commit | diff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp 16
-rw-r--r-- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp 6
-rw-r--r-- llvm/lib/IR/ConstantFPRange.cpp 70
-rw-r--r-- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp 3
-rw-r--r-- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp 86
-rw-r--r-- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp 19
-rw-r--r-- llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp 12
7 files changed, 135 insertions, 77 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b23b190..b1accdd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17086,11 +17086,6 @@ static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
N->getFlags().hasAllowContract();
}
-// Returns true if `N` can assume no infinities involved in its computation.
-static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
- return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
-}
-
/// Try to perform FMA combining on a given FADD node.
template <class MatchContextClass>
SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
@@ -17666,7 +17661,7 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
// The transforms below are incorrect when x == 0 and y == inf, because the
// intermediate multiplication produces a nan.
SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
- if (!hasNoInfs(Options, FAdd))
+ if (!FAdd->getFlags().hasNoInfs())
return SDValue();
// Floating-point multiply-add without intermediate rounding.
@@ -18343,7 +18338,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
}
- if ((Options.NoNaNsFPMath && Options.NoInfsFPMath) ||
+ if ((Options.NoNaNsFPMath && N->getFlags().hasNoInfs()) ||
(N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs())) {
if (N->getFlags().hasNoSignedZeros() ||
(N2CFP && !N2CFP->isExactlyValue(-0.0))) {
@@ -18533,7 +18528,6 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
SDLoc DL(N);
- const TargetOptions &Options = DAG.getTarget().Options;
SDNodeFlags Flags = N->getFlags();
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
@@ -18644,7 +18638,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
}
// Fold into a reciprocal estimate and multiply instead of a real divide.
- if (Options.NoInfsFPMath || Flags.hasNoInfs())
+ if (Flags.hasNoInfs())
if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
return RV;
}
@@ -18721,12 +18715,10 @@ SDValue DAGCombiner::visitFREM(SDNode *N) {
SDValue DAGCombiner::visitFSQRT(SDNode *N) {
SDNodeFlags Flags = N->getFlags();
- const TargetOptions &Options = DAG.getTarget().Options;
// Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
// sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
- if (!Flags.hasApproximateFuncs() ||
- (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
+ if (!Flags.hasApproximateFuncs() || !Flags.hasNoInfs())
return SDValue();
SDValue N0 = N->getOperand(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 6ea2e27..08af74c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5767,11 +5767,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
// even if the nonan flag is dropped somewhere.
unsigned CCOp = Opcode == ISD::SETCC ? 2 : 4;
ISD::CondCode CCCode = cast<CondCodeSDNode>(Op.getOperand(CCOp))->get();
- if (((unsigned)CCCode & 0x10U))
- return true;
-
- const TargetOptions &Options = getTarget().Options;
- return Options.NoNaNsFPMath || Options.NoInfsFPMath;
+ return (unsigned)CCCode & 0x10U;
}
case ISD::OR:
diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp
index 070e833..51d2e21 100644
--- a/llvm/lib/IR/ConstantFPRange.cpp
+++ b/llvm/lib/IR/ConstantFPRange.cpp
@@ -414,15 +414,31 @@ ConstantFPRange ConstantFPRange::negate() const {
return ConstantFPRange(-Upper, -Lower, MayBeQNaN, MayBeSNaN);
}
+/// Replace an infinite Lower/Upper bound with the "largest" value of the same sign, setting HasNegInf/HasPosInf to true for each infinity removed (the flags are only ever set, never cleared, so callers that read them must initialize them to false). Return true if the finite part is not empty after removing infinities.
+static bool removeInf(APFloat &Lower, APFloat &Upper, bool &HasPosInf,
+ bool &HasNegInf) {
+ assert(strictCompare(Lower, Upper) != APFloat::cmpGreaterThan &&
+ "Non-NaN part is empty."); // precondition: callers must rule out an empty non-NaN part before calling
+ auto &Sem = Lower.getSemantics();
+ if (Lower.isNegInfinity()) {
+ Lower = APFloat::getLargest(Sem, /*Negative=*/true); // -inf lower bound becomes the negative "largest" value
+ HasNegInf = true;
+ }
+ if (Upper.isPosInfinity()) {
+ Upper = APFloat::getLargest(Sem, /*Negative=*/false); // +inf upper bound becomes the positive "largest" value
+ HasPosInf = true;
+ }
+ return strictCompare(Lower, Upper) != APFloat::cmpGreaterThan; // Lower > Upper at this point means no finite value remained, e.g. the input was [+inf, +inf]
+}
+
ConstantFPRange ConstantFPRange::getWithoutInf() const {
if (isNaNOnly())
return *this;
APFloat NewLower = Lower;
APFloat NewUpper = Upper;
- if (Lower.isNegInfinity())
- NewLower = APFloat::getLargest(getSemantics(), /*Negative=*/true);
- if (Upper.isPosInfinity())
- NewUpper = APFloat::getLargest(getSemantics(), /*Negative=*/false);
+ bool UnusedFlag;
+ removeInf(NewLower, NewUpper, /*HasPosInf=*/UnusedFlag,
+ /*HasNegInf=*/UnusedFlag);
canonicalizeRange(NewLower, NewUpper);
return ConstantFPRange(std::move(NewLower), std::move(NewUpper), MayBeQNaN,
MayBeSNaN);
@@ -444,3 +460,49 @@ ConstantFPRange ConstantFPRange::cast(const fltSemantics &DstSem,
/*MayBeQNaNVal=*/MayBeQNaN || MayBeSNaN,
/*MayBeSNaNVal=*/false);
}
+/// Return the range of sums of a value from this range and a value from Other. The result never has MayBeSNaN set; any NaN outcome is reported as a possible QNaN.
+ConstantFPRange ConstantFPRange::add(const ConstantFPRange &Other) const {
+ bool ResMayBeQNaN = ((MayBeQNaN || MayBeSNaN) && !Other.isEmptySet()) || // a possible NaN on one side makes the result possibly QNaN, unless the other side is the empty set
+ ((Other.MayBeQNaN || Other.MayBeSNaN) && !isEmptySet());
+ if (isNaNOnly() || Other.isNaNOnly()) // early exit: the result is a NaN-only range, so no bounds need to be computed
+ return getNaNOnly(getSemantics(), /*MayBeQNaN=*/ResMayBeQNaN,
+ /*MayBeSNaN=*/false);
+ bool LHSHasNegInf = false, LHSHasPosInf = false; // must start false: removeInf only ever sets these to true
+ APFloat LHSLower = Lower, LHSUpper = Upper; // working copies; removeInf modifies its bounds in place
+ bool LHSFiniteIsNonEmpty =
+ removeInf(LHSLower, LHSUpper, LHSHasPosInf, LHSHasNegInf);
+ bool RHSHasNegInf = false, RHSHasPosInf = false; // same treatment for the right-hand operand
+ APFloat RHSLower = Other.Lower, RHSUpper = Other.Upper;
+ bool RHSFiniteIsNonEmpty =
+ removeInf(RHSLower, RHSUpper, RHSHasPosInf, RHSHasNegInf);
+ // Infinities of opposite sign on the two sides can meet: -inf + +inf = QNaN
+ ResMayBeQNaN |=
+ (LHSHasNegInf && RHSHasPosInf) || (LHSHasPosInf && RHSHasNegInf);
+ // An infinity survives when paired with a finite value or an infinity of the same sign: +inf + finite/+inf = +inf, -inf + finite/-inf = -inf
+ bool HasNegInf = (LHSHasNegInf && (RHSFiniteIsNonEmpty || RHSHasNegInf)) ||
+ (RHSHasNegInf && (LHSFiniteIsNonEmpty || LHSHasNegInf));
+ bool HasPosInf = (LHSHasPosInf && (RHSFiniteIsNonEmpty || RHSHasPosInf)) ||
+ (RHSHasPosInf && (LHSFiniteIsNonEmpty || LHSHasPosInf));
+ if (LHSFiniteIsNonEmpty && RHSFiniteIsNonEmpty) { // both sides contain finite values: add the clamped bounds pairwise unless an infinity already fixes that bound
+ APFloat NewLower =
+ HasNegInf ? APFloat::getInf(LHSLower.getSemantics(), /*Negative=*/true)
+ : LHSLower + RHSLower;
+ APFloat NewUpper =
+ HasPosInf ? APFloat::getInf(LHSUpper.getSemantics(), /*Negative=*/false)
+ : LHSUpper + RHSUpper;
+ return ConstantFPRange(NewLower, NewUpper, ResMayBeQNaN,
+ /*MayBeSNaN=*/false);
+ }
+ // Reaching here, at least one side has no finite values, so the bounds are built from infinities only. If both HasNegInf and HasPosInf are false, the non-NaN part is empty.
+ // We just return the canonical form [+inf, -inf] for the empty non-NaN set.
+ return ConstantFPRange(
+ APFloat::getInf(Lower.getSemantics(), /*Negative=*/HasNegInf), // lower bound: -inf if reachable, otherwise +inf
+ APFloat::getInf(Upper.getSemantics(), /*Negative=*/!HasPosInf), // upper bound: +inf if reachable, otherwise -inf
+ ResMayBeQNaN,
+ /*MayBeSNaN=*/false);
+}
+/// Return the range of differences of a value from this range and a value from Other, computed by adding the negated range of Other.
+ConstantFPRange ConstantFPRange::sub(const ConstantFPRange &Other) const {
+ // Subtraction is addition of the negation: fsub X, Y = fadd X, (fneg Y)
+ return add(Other.negate());
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbce3b0..6965116 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19093,7 +19093,8 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
SDValue Ext1 = Op1.getOperand(0);
if (Ext0.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
Ext1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
- Ext0.getOperand(0) != Ext1.getOperand(0))
+ Ext0.getOperand(0) != Ext1.getOperand(0) ||
+ Ext0.getOperand(0).getValueType().isScalableVector())
return SDValue();
// Check that the type is twice the add types, and the extract are from
// upper/lower parts of the same source.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index b8761d97..30dfcf2b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -5064,17 +5064,15 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
bool RenamableSrc) const {
if (AArch64::GPR32spRegClass.contains(DestReg) &&
(AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
-
if (DestReg == AArch64::WSP || SrcReg == AArch64::WSP) {
// If either operand is WSP, expand to ADD #0.
if (Subtarget.hasZeroCycleRegMoveGPR64() &&
!Subtarget.hasZeroCycleRegMoveGPR32()) {
// Cyclone recognizes "ADD Xd, Xn, #0" as a zero-cycle register move.
- MCRegister DestRegX = TRI->getMatchingSuperReg(
- DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
- MCRegister SrcRegX = TRI->getMatchingSuperReg(
- SrcReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
+ MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ MCRegister SrcRegX = RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
// This instruction is reading and writing X registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegX, but a proper
@@ -5097,14 +5095,14 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
} else if (Subtarget.hasZeroCycleRegMoveGPR64() &&
!Subtarget.hasZeroCycleRegMoveGPR32()) {
// Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
- MCRegister DestRegX = TRI->getMatchingSuperReg(DestReg, AArch64::sub_32,
- &AArch64::GPR64spRegClass);
+ MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
assert(DestRegX.isValid() && "Destination super-reg not valid");
MCRegister SrcRegX =
SrcReg == AArch64::WZR
? AArch64::XZR
- : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
- &AArch64::GPR64spRegClass);
+ : RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
assert(SrcRegX.isValid() && "Source super-reg not valid");
// This instruction is reading and writing X registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
@@ -5334,11 +5332,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (Subtarget.hasZeroCycleRegMoveFPR128() &&
!Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::dsub,
- &AArch64::FPR128RegClass);
- MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::dsub,
- &AArch64::FPR128RegClass);
+ MCRegister DestRegQ = RI.getMatchingSuperReg(DestReg, AArch64::dsub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = RI.getMatchingSuperReg(SrcReg, AArch64::dsub,
+ &AArch64::FPR128RegClass);
// This instruction is reading and writing Q registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegQ, but a proper
@@ -5359,11 +5356,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (Subtarget.hasZeroCycleRegMoveFPR128() &&
!Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::ssub,
- &AArch64::FPR128RegClass);
- MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::ssub,
- &AArch64::FPR128RegClass);
+ MCRegister DestRegQ = RI.getMatchingSuperReg(DestReg, AArch64::ssub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = RI.getMatchingSuperReg(SrcReg, AArch64::ssub,
+ &AArch64::FPR128RegClass);
// This instruction is reading and writing Q registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegQ, but a proper
@@ -5374,11 +5370,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
.addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
} else if (Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR32()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::ssub,
- &AArch64::FPR64RegClass);
- MCRegister SrcRegD = TRI->getMatchingSuperReg(SrcReg, AArch64::ssub,
- &AArch64::FPR64RegClass);
+ MCRegister DestRegD = RI.getMatchingSuperReg(DestReg, AArch64::ssub,
+ &AArch64::FPR64RegClass);
+ MCRegister SrcRegD = RI.getMatchingSuperReg(SrcReg, AArch64::ssub,
+ &AArch64::FPR64RegClass);
// This instruction is reading and writing D registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegD, but a proper
@@ -5398,11 +5393,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (Subtarget.hasZeroCycleRegMoveFPR128() &&
!Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::hsub,
- &AArch64::FPR128RegClass);
- MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::hsub,
- &AArch64::FPR128RegClass);
+ MCRegister DestRegQ = RI.getMatchingSuperReg(DestReg, AArch64::hsub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = RI.getMatchingSuperReg(SrcReg, AArch64::hsub,
+ &AArch64::FPR128RegClass);
// This instruction is reading and writing Q registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegQ, but a proper
@@ -5413,11 +5407,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
.addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
} else if (Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR32()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::hsub,
- &AArch64::FPR64RegClass);
- MCRegister SrcRegD = TRI->getMatchingSuperReg(SrcReg, AArch64::hsub,
- &AArch64::FPR64RegClass);
+ MCRegister DestRegD = RI.getMatchingSuperReg(DestReg, AArch64::hsub,
+ &AArch64::FPR64RegClass);
+ MCRegister SrcRegD = RI.getMatchingSuperReg(SrcReg, AArch64::hsub,
+ &AArch64::FPR64RegClass);
// This instruction is reading and writing D registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegD, but a proper
@@ -5441,11 +5434,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (Subtarget.hasZeroCycleRegMoveFPR128() &&
!Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR64() && Subtarget.isNeonAvailable()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::bsub,
- &AArch64::FPR128RegClass);
- MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::bsub,
- &AArch64::FPR128RegClass);
+ MCRegister DestRegQ = RI.getMatchingSuperReg(DestReg, AArch64::bsub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = RI.getMatchingSuperReg(SrcReg, AArch64::bsub,
+ &AArch64::FPR128RegClass);
// This instruction is reading and writing Q registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegQ, but a proper
@@ -5456,11 +5448,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
.addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
} else if (Subtarget.hasZeroCycleRegMoveFPR64() &&
!Subtarget.hasZeroCycleRegMoveFPR32()) {
- const TargetRegisterInfo *TRI = &getRegisterInfo();
- MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::bsub,
- &AArch64::FPR64RegClass);
- MCRegister SrcRegD = TRI->getMatchingSuperReg(SrcReg, AArch64::bsub,
- &AArch64::FPR64RegClass);
+ MCRegister DestRegD = RI.getMatchingSuperReg(DestReg, AArch64::bsub,
+ &AArch64::FPR64RegClass);
+ MCRegister SrcRegD = RI.getMatchingSuperReg(SrcReg, AArch64::bsub,
+ &AArch64::FPR64RegClass);
// This instruction is reading and writing D registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegD, but a proper
@@ -5532,9 +5523,8 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
}
#ifndef NDEBUG
- const TargetRegisterInfo &TRI = getRegisterInfo();
- errs() << TRI.getRegAsmName(DestReg) << " = COPY "
- << TRI.getRegAsmName(SrcReg) << "\n";
+ errs() << RI.getRegAsmName(DestReg) << " = COPY " << RI.getRegAsmName(SrcReg)
+ << "\n";
#endif
llvm_unreachable("unimplemented reg-to-reg copy");
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e62d57e..50136a8 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9348,13 +9348,12 @@ static SmallVector<Instruction *> preparePlanForEpilogueVectorLoop(
VPBasicBlock *Header = VectorLoop->getEntryBasicBlock();
Header->setName("vec.epilog.vector.body");
- // Ensure that the start values for all header phi recipes are updated before
- // vectorizing the epilogue loop.
VPCanonicalIVPHIRecipe *IV = Plan.getCanonicalIV();
- // When vectorizing the epilogue loop, the canonical induction start
- // value needs to be changed from zero to the value after the main
- // vector loop. Find the resume value created during execution of the main
- // VPlan. It must be the first phi in the loop preheader.
+ // When vectorizing the epilogue loop, the canonical induction needs to be
+ // adjusted by the value after the main vector loop. Find the resume value
+ // created during execution of the main VPlan. It must be the first phi in the
+ // loop preheader. Use the value to increment the canonical IV, and update all
+ // users in the loop region to use the adjusted value.
// FIXME: Improve modeling for canonical IV start values in the epilogue
// loop.
using namespace llvm::PatternMatch;
@@ -9389,10 +9388,16 @@ static SmallVector<Instruction *> preparePlanForEpilogueVectorLoop(
}) &&
"the canonical IV should only be used by its increment or "
"ScalarIVSteps when resetting the start value");
- IV->setOperand(0, VPV);
+ VPBuilder Builder(Header, Header->getFirstNonPhi());
+ VPInstruction *Add = Builder.createNaryOp(Instruction::Add, {IV, VPV});
+ IV->replaceAllUsesWith(Add);
+ Add->setOperand(0, IV);
DenseMap<Value *, Value *> ToFrozen;
SmallVector<Instruction *> InstsToMove;
+ // Ensure that the start values for all header phi recipes are updated before
+ // vectorizing the epilogue loop. Skip the canonical IV, which has been
+ // handled above.
for (VPRecipeBase &R : drop_begin(Header->phis())) {
Value *ResumeV = nullptr;
// TODO: Move setting of resume values to prepareToExecute.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c8a2d84..7563cd7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1234,6 +1234,18 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
if (!Plan->isUnrolled())
return;
+ // Hoist an invariant increment Y of a phi X, by having X start at Y.
+ if (match(Def, m_c_Add(m_VPValue(X), m_VPValue(Y))) && Y->isLiveIn() &&
+ isa<VPPhi>(X)) {
+ auto *Phi = cast<VPPhi>(X);
+ if (Phi->getOperand(1) != Def && match(Phi->getOperand(0), m_ZeroInt()) &&
+ Phi->getNumUsers() == 1 && (*Phi->user_begin() == &R)) {
+ Phi->setOperand(0, Y);
+ Def->replaceAllUsesWith(Phi);
+ return;
+ }
+ }
+
// VPVectorPointer for part 0 can be replaced by their start pointer.
if (auto *VecPtr = dyn_cast<VPVectorPointerRecipe>(&R)) {
if (VecPtr->isFirstPart()) {