diff options
Diffstat (limited to 'llvm/lib/Analysis')
| -rw-r--r-- | llvm/lib/Analysis/DependenceAnalysis.cpp | 19 | ||||
| -rw-r--r-- | llvm/lib/Analysis/RegionPrinter.cpp | 11 | ||||
| -rw-r--r-- | llvm/lib/Analysis/ValueTracking.cpp | 137 |
3 files changed, 159 insertions, 8 deletions
diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp index 11d8294..e45d1f7 100644 --- a/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -1587,6 +1587,15 @@ static const SCEV *minusSCEVNoSignedOverflow(const SCEV *A, const SCEV *B, return nullptr; } +/// Returns \p A * \p B if it guaranteed not to signed wrap. Otherwise returns +/// nullptr. \p A and \p B must have the same integer type. +static const SCEV *mulSCEVNoSignedOverflow(const SCEV *A, const SCEV *B, + ScalarEvolution &SE) { + if (SE.willNotOverflow(Instruction::Mul, /*Signed=*/true, A, B)) + return SE.getMulExpr(A, B); + return nullptr; +} + /// Returns the absolute value of \p A. In the context of dependence analysis, /// we need an absolute value in a mathematical sense. If \p A is the signed /// minimum value, we cannot represent it unless extending the original type. @@ -1686,7 +1695,11 @@ bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, assert(0 < Level && Level <= CommonLevels && "level out of range"); Level--; - const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst); + const SCEV *Delta = minusSCEVNoSignedOverflow(SrcConst, DstConst, *SE); + if (!Delta) { + Result.Consistent = false; + return false; + } LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta); LLVM_DEBUG(dbgs() << ", " << *Delta->getType() << "\n"); @@ -1702,7 +1715,9 @@ bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, const SCEV *AbsCoeff = absSCEVNoSignedOverflow(Coeff, *SE); if (!AbsDelta || !AbsCoeff) return false; - const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff); + const SCEV *Product = mulSCEVNoSignedOverflow(UpperBound, AbsCoeff, *SE); + if (!Product) + return false; return isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product); }(); if (IsDeltaLarge) { diff --git a/llvm/lib/Analysis/RegionPrinter.cpp b/llvm/lib/Analysis/RegionPrinter.cpp index a83af4e..33e073b 100644 --- a/llvm/lib/Analysis/RegionPrinter.cpp +++ b/llvm/lib/Analysis/RegionPrinter.cpp @@ -29,10 +29,9 @@ onlySimpleRegions("only-simple-regions", cl::Hidden, cl::init(false)); -namespace llvm { - -std::string DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node, - RegionNode *Graph) { +std::string +llvm::DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node, + RegionNode *Graph) { if (!Node->isSubRegion()) { BasicBlock *BB = Node->getNodeAs<BasicBlock>(); @@ -46,7 +45,8 @@ std::string DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node, } template <> -struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> { +struct llvm::DOTGraphTraits<RegionInfo *> + : public llvm::DOTGraphTraits<RegionNode *> { DOTGraphTraits (bool isSimple = false) : DOTGraphTraits<RegionNode*>(isSimple) {} @@ -125,7 +125,6 @@ struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> { printRegionCluster(*G->getTopLevelRegion(), GW, 4); } }; -} // end namespace llvm namespace { diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 789a983..41ff816 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -350,6 +350,139 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL, return V->getType()->getScalarSizeInBits() - SignBits + 1; } +/// Try to detect the lerp pattern: a * (b - c) + c * d +/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c. +/// +/// In that particular case, we can use the following chain of reasoning: +/// +/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d) +/// +/// Since that is true for arbitrary a, b, c and d within our constraints, we +/// can conclude that: +/// +/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U +/// +/// Considering that any result of the lerp would be less or equal to U, it +/// would have at least the number of leading 0s as in U. +/// +/// While being quite a specific situation, it is fairly common in computer +/// graphics in the shape of alpha blending. +/// +/// Modifies given KnownOut in-place with the inferred information. +static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1, + const APInt &DemandedElts, + KnownBits &KnownOut, + const SimplifyQuery &Q, + unsigned Depth) { + + Type *Ty = Op0->getType(); + const unsigned BitWidth = Ty->getScalarSizeInBits(); + + // Only handle scalar types for now + if (Ty->isVectorTy()) + return; + + // Try to match: a * (b - c) + c * d. + // When a == 1 => A == nullptr, the same applies to d/D as well. + const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; + const Instruction *SubBC = nullptr; + + const auto MatchSubBC = [&]() { + // (b - c) can have two forms that interest us: + // + // 1. sub nuw %b, %c + // 2. xor %c, %b + // + // For the first case, nuw flag guarantees our requirement b >= c. + // + // The second case might happen when the analysis can infer that b is a mask + // for c and we can transform sub operation into xor (that is usually true + // for constant b's). Even though xor is symmetrical, canonicalization + // ensures that the constant will be the RHS. We have additional checks + // later on to ensure that this xor operation is equivalent to subtraction. + return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)), + m_Xor(m_Value(C), m_Value(B)))); + }; + + const auto MatchASubBC = [&]() { + // Cases: + // - a * (b - c) + // - (b - c) * a + // - (b - c) <- a implicitly equals 1 + return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC()); + }; + + const auto MatchCD = [&]() { + // Cases: + // - d * c + // - c * d + // - c <- d implicitly equals 1 + return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C)); + }; + + const auto Match = [&](const Value *LHS, const Value *RHS) { + // We do use m_Specific(C) in MatchCD, so we have to make sure that + // it's bound to anything and match(LHS, MatchASubBC()) absolutely + // has to evaluate first and return true. + // + // If Match returns true, it is guaranteed that B != nullptr, C != nullptr. + return match(LHS, MatchASubBC()) && match(RHS, MatchCD()); + }; + + if (!Match(Op0, Op1) && !Match(Op1, Op0)) + return; + + const auto ComputeKnownBitsOrOne = [&](const Value *V) { + // For some of the values we use the convention of leaving + // it nullptr to signify an implicit constant 1. + return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1) + : KnownBits::makeConstant(APInt(BitWidth, 1)); + }; + + // Check that all operands are non-negative + const KnownBits KnownA = ComputeKnownBitsOrOne(A); + if (!KnownA.isNonNegative()) + return; + + const KnownBits KnownD = ComputeKnownBitsOrOne(D); + if (!KnownD.isNonNegative()) + return; + + const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1); + if (!KnownB.isNonNegative()) + return; + + const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1); + if (!KnownC.isNonNegative()) + return; + + // If we matched subtraction as xor, we need to actually check that xor + // is semantically equivalent to subtraction. + // + // For that to be true, b has to be a mask for c or that b's known + // ones cover all known and possible ones of c. + if (SubBC->getOpcode() == Instruction::Xor && + !KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue())) + return; + + const APInt MaxA = KnownA.getMaxValue(); + const APInt MaxD = KnownD.getMaxValue(); + const APInt MaxAD = APIntOps::umax(MaxA, MaxD); + const APInt MaxB = KnownB.getMaxValue(); + + // We can't infer leading zeros info if the upper-bound estimate wraps. + bool Overflow; + const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow); + + if (Overflow) + return; + + // If we know that x <= y and both are positive than x has at least the same + // number of leading zeros as y. + const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero(); + KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros); +} + static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, bool NSW, bool NUW, const APInt &DemandedElts, @@ -369,6 +502,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL) .value_or(false)) KnownOut.makeNonNegative(); + + if (Add) + // Try to match lerp pattern and combine results + computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth); } static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, |
