diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 24adfa3..477e477 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -2682,6 +2682,20 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, return getAddExpr(NewOps, PreservedFlags); } } + + // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext + // (B), if trunc (A) + -A + B does not unsigned-wrap. + const SCEVAddExpr *InnerAdd; + if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) { + const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType()); + if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) && + getZeroExtendExpr(NarrowA, B->getType()) == A && + hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd}, + SCEV::FlagAnyWrap), + SCEV::FlagNUW)) { + return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType()); + } + } } // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y) @@ -11418,8 +11432,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred, XNonConstOp = X; XFlagsPresent = ExpectedFlags; } - if (!isa<SCEVConstant>(XConstOp) || - (XFlagsPresent & ExpectedFlags) != ExpectedFlags) + if (!isa<SCEVConstant>(XConstOp)) return false; if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) { @@ -11428,13 +11441,21 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred, YFlagsPresent = ExpectedFlags; } - if (!isa<SCEVConstant>(YConstOp) || - (YFlagsPresent & ExpectedFlags) != ExpectedFlags) + if (YNonConstOp != XNonConstOp) return false; - if (YNonConstOp != XNonConstOp) + if (!isa<SCEVConstant>(YConstOp)) return false; + // When matching ADDs with NUW flags (and unsigned predicates), only the + // second ADD (with the larger constant) requires NUW. + if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags) + return false; + if (ExpectedFlags != SCEV::FlagNUW && + (XFlagsPresent & ExpectedFlags) != ExpectedFlags) { + return false; + } + OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt(); OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt(); @@ -11472,7 +11493,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred, std::swap(LHS, RHS); [[fallthrough]]; case ICmpInst::ICMP_ULE: - // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2. + // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2. if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2)) return true; @@ -11482,7 +11503,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred, std::swap(LHS, RHS); [[fallthrough]]; case ICmpInst::ICMP_ULT: - // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2. + // (X + C1) u< (X + C2)<nuw> if C1 u< C2. if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2)) return true; break; |