diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 193 |
1 files changed, 96 insertions, 97 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index a64b93d..425420f 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>. // - if (SM->getNumOperands() == 2) - if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0))) - if (MulLHS->getAPInt().isPowerOf2()) - if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) { - int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) - - MulLHS->getAPInt().logBase2(); - Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); - return getMulExpr( - getZeroExtendExpr(MulLHS, Ty), - getZeroExtendExpr( - getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty), - SCEV::FlagNUW, Depth + 1); - } + const APInt *C; + const SCEV *TruncRHS; + if (match(SM, + m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) && + C->isPowerOf2()) { + int NewTruncBits = + getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2(); + Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); + return getMulExpr( + getZeroExtendExpr(SM->getOperand(0), Ty), + getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty), + SCEV::FlagNUW, Depth + 1); + } } // zext(umin(x, y)) -> umin(zext(x), zext(y)) @@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { if (Ops.size() == 2) { // C1*(C2+V) -> C1*C2 + C1*V - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) - // If any of Add's ops are Adds or Muls with a constant, apply this - // transformation as well. - // - // TODO: There are some cases where this transformation is not - // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of - // this transformation should be narrowed down. - if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) { - const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0), - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1), - SCEV::FlagAnyWrap, Depth + 1); - return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); - } + // If any of Add's ops are Adds or Muls with a constant, apply this + // transformation as well. + // + // TODO: There are some cases where this transformation is not + // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of + // this transformation should be narrowed down. + const SCEV *Op0, *Op1; + if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) && + containsConstantInAddMulChain(Ops[1])) { + const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1); + return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); + } if (Ops[0]->isAllOnesValue()) { // If we have a mul by -1 of an add, try distributing the -1 among the @@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C. - if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS); - AE && AE->getNumOperands() == 2) { - if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) { - const APInt &NegC = VC->getAPInt(); - if (NegC.isNegative() && !NegC.isMinSignedValue()) { - const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1)); - if (MME && MME->getNumOperands() == 2 && - isa<SCEVConstant>(MME->getOperand(0)) && - cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC && - MME->getOperand(1) == RHS) - return getZero(LHS->getType()); - } - } - } + const APInt *NegC, *C; + if (match(LHS, + m_scev_Add(m_scev_APInt(NegC), + m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) && + NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC) + return getZero(LHS->getType()); // TODO: Generalize to handle any common factors. // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b @@ -4623,17 +4614,11 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { - const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); - if (!Add || Add->getNumOperands() != 2 || - !Add->getOperand(0)->isAllOnesValue()) - return nullptr; - - const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); - if (!AddRHS || AddRHS->getNumOperands() != 2 || - !AddRHS->getOperand(0)->isAllOnesValue()) - return nullptr; - - return AddRHS->getOperand(1); + const SCEV *MulOp; + if (match(Expr, m_scev_Add(m_scev_AllOnes(), + m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp))))) + return MulOp; + return nullptr; } /// Return a SCEV corresponding to ~V = -1-V @@ -10797,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { } static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) { - const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S); - if (!Add || Add->getNumOperands() != 2) + const SCEV *Op0, *Op1; + if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1)))) return false; - if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0)); - ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { - LHS = Add->getOperand(1); - RHS = ME->getOperand(1); + if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) { + LHS = Op1; return true; } - if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); - ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { - LHS = Add->getOperand(0); - RHS = ME->getOperand(1); + if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) { + LHS = Op0; return true; } return false; @@ -12172,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags) { - const auto *AE = dyn_cast<SCEVAddExpr>(Expr); - if (!AE || AE->getNumOperands() != 2) + if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R)))) return false; - L = AE->getOperand(0); - R = AE->getOperand(1); - Flags = AE->getNoWrapFlags(); + Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags(); return true; } @@ -12220,12 +12198,11 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { // Try to match a common constant multiply. auto MatchConstMul = [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> { - auto *M = dyn_cast<SCEVMulExpr>(S); - if (!M || M->getNumOperands() != 2 || - !isa<SCEVConstant>(M->getOperand(0))) - return std::nullopt; - return { - {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}}; + const APInt *C; + const SCEV *Op; + if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op)))) + return {{Op, *C}}; + return std::nullopt; }; if (auto MatchedMore = MatchConstMul(More)) { if (auto MatchedLess = MatchConstMul(Less)) { @@ -15557,19 +15534,10 @@ void ScalarEvolution::LoopGuards::collectFromBlock( auto IsMinMaxSCEVWithNonNegativeConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, const SCEV *&RHS) { - if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) { - if (MinMax->getNumOperands() != 2) - return false; - if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) { - if (C->getAPInt().isNegative()) - return false; - SCTy = MinMax->getSCEVType(); - LHS = MinMax->getOperand(0); - RHS = MinMax->getOperand(1); - return true; - } - } - return false; + const APInt *C; + SCTy = Expr->getSCEVType(); + return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) && + match(LHS, m_scev_APInt(C)) && C->isNonNegative(); }; // Return a new SCEV that modifies \p Expr to the closest number divides by @@ -15772,19 +15740,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock( GetNextSCEVDividesByDivisor(One, DividesBy); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } else { + // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS), + // but creating the subtraction eagerly is expensive. Track the + // inequalities in a separate map, and materialize the rewrite lazily + // when encountering a suitable subtraction while re-writing. if (LHS->getType()->isPointerTy()) { LHS = SE.getLosslessPtrToIntExpr(LHS); RHS = SE.getLosslessPtrToIntExpr(RHS); if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS)) break; } - auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) { - const SCEV *Sub = SE.getMinusSCEV(A, B); - AddRewrite(Sub, Sub, - SE.getUMaxExpr(Sub, SE.getOne(From->getType()))); - }; - AddSubRewrite(LHS, RHS); - AddSubRewrite(RHS, LHS); + const SCEVConstant *C; + const SCEV *A, *B; + if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) && + match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) { + RHS = A; + LHS = B; + } + if (LHS > RHS) + std::swap(LHS, RHS); + Guards.NotEqual.insert({LHS, RHS}); continue; } break; @@ -15918,13 +15893,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { const DenseMap<const SCEV *, const SCEV *> ⤅ + const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> ≠ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap; public: SCEVLoopGuardRewriter(ScalarEvolution &SE, const ScalarEvolution::LoopGuards &Guards) - : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) { + : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap), + NotEqual(Guards.NotEqual) { if (Guards.PreserveNUW) FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW); if (Guards.PreserveNSW) @@ -15979,14 +15956,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + // Helper to check if S is a subtraction (A - B) where A != B, and if so, + // return UMax(S, 1). + auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * { + const SCEV *LHS, *RHS; + if (MatchBinarySub(S, LHS, RHS)) { + if (LHS > RHS) + std::swap(LHS, RHS); + if (NotEqual.contains({LHS, RHS})) + return SE.getUMaxExpr(S, SE.getOne(S->getType())); + } + return nullptr; + }; + + // Check if Expr itself is a subtraction pattern with guard info. + if (const SCEV *Rewritten = RewriteSubtraction(Expr)) + return Rewritten; + // Trip count expressions sometimes consist of adding 3 operands, i.e. // (Const + A + B). There may be guard info for A + B, and if so, apply // it. // TODO: Could more generally apply guards to Add sub-expressions. if (isa<SCEVConstant>(Expr->getOperand(0)) && Expr->getNumOperands() == 3) { - if (const SCEV *S = Map.lookup( - SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)))) + const SCEV *Add = + SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)); + if (const SCEV *Rewritten = RewriteSubtraction(Add)) + return SE.getAddExpr( + Expr->getOperand(0), Rewritten, + ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask)); + if (const SCEV *S = Map.lookup(Add)) return SE.getAddExpr(Expr->getOperand(0), S); } SmallVector<const SCEV *, 2> Operands; @@ -16021,7 +16020,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } }; - if (RewriteMap.empty()) + if (RewriteMap.empty() && NotEqual.empty()) return Expr; SCEVLoopGuardRewriter Rewriter(SE, *this); |