Diffstat (limited to 'llvm/lib/Analysis/ValueTracking.cpp')
| -rw-r--r-- | llvm/lib/Analysis/ValueTracking.cpp | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 789a983..41ff816 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,6 +350,139 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
   return V->getType()->getScalarSizeInBits() - SignBits + 1;
 }
 
+/// Try to detect the lerp pattern: a * (b - c) + c * d
+/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+///
+/// In that particular case, we can use the following chain of reasoning:
+///
+///   a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b, where a' = max(a, d)
+///
+/// Since that is true for arbitrary a, b, c and d within our constraints, we
+/// can conclude that:
+///
+///   max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+///
+/// Since any result of the lerp is less than or equal to U, it has at least
+/// as many leading 0s as U.
+///
+/// While this is quite a specific situation, it is fairly common in computer
+/// graphics in the shape of alpha blending.
+///
+/// Modifies the given KnownOut in-place with the inferred information.
+static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+                                            const APInt &DemandedElts,
+                                            KnownBits &KnownOut,
+                                            const SimplifyQuery &Q,
+                                            unsigned Depth) {
+  Type *Ty = Op0->getType();
+  const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+  // Only handle scalar types for now.
+  if (Ty->isVectorTy())
+    return;
+
+  // Try to match: a * (b - c) + c * d.
+  // When a == 1, A stays nullptr; the same applies to d/D as well.
+  const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+  const Instruction *SubBC = nullptr;
+
+  const auto MatchSubBC = [&]() {
+    // (b - c) can have two forms that interest us:
+    //
+    //   1. sub nuw %b, %c
+    //   2. xor %c, %b
+    //
+    // For the first case, the nuw flag guarantees our requirement b >= c.
+    //
+    // The second case might happen when the analysis can infer that b is a
+    // mask for c and we can transform the sub operation into xor (that is
+    // usually true for constant b's). Even though xor is symmetrical,
+    // canonicalization ensures that the constant will be the RHS. We have
+    // additional checks later on to ensure that this xor operation is
+    // equivalent to subtraction.
+    return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+                                            m_Xor(m_Value(C), m_Value(B))));
+  };
+
+  const auto MatchASubBC = [&]() {
+    // Cases:
+    //   - a * (b - c)
+    //   - (b - c) * a
+    //   - (b - c)       <- a implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
+  };
+
+  const auto MatchCD = [&]() {
+    // Cases:
+    //   - d * c
+    //   - c * d
+    //   - c             <- d implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
+  };
+
+  const auto Match = [&](const Value *LHS, const Value *RHS) {
+    // We use m_Specific(C) in MatchCD, so we have to make sure that C is
+    // actually bound to something, which is why match(LHS, MatchASubBC())
+    // absolutely has to evaluate first and return true.
+    //
+    // If Match returns true, it is guaranteed that B != nullptr and
+    // C != nullptr.
+    return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+  };
+
+  if (!Match(Op0, Op1) && !Match(Op1, Op0))
+    return;
+
+  const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+    // For some of the values we use the convention of leaving them nullptr
+    // to signify an implicit constant 1.
+    return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+             : KnownBits::makeConstant(APInt(BitWidth, 1));
+  };
+
+  // Check that all operands are non-negative.
+  const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+  if (!KnownA.isNonNegative())
+    return;
+
+  const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+  if (!KnownD.isNonNegative())
+    return;
+
+  const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+  if (!KnownB.isNonNegative())
+    return;
+
+  const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+  if (!KnownC.isNonNegative())
+    return;
+
+  // If we matched the subtraction as xor, we need to actually check that the
+  // xor is semantically equivalent to subtraction.
+  //
+  // For that to be true, b has to be a mask for c, i.e. b's known ones have
+  // to cover all known and possible ones of c.
+  if (SubBC->getOpcode() == Instruction::Xor &&
+      !KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()))
+    return;
+
+  const APInt MaxA = KnownA.getMaxValue();
+  const APInt MaxD = KnownD.getMaxValue();
+  const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
+  const APInt MaxB = KnownB.getMaxValue();
+
+  // We can't infer leading zeros info if the upper-bound estimate wraps.
+  bool Overflow;
+  const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+  if (Overflow)
+    return;
+
+  // If we know that x <= y and both are positive, then x has at least as many
+  // leading zeros as y.
+  const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+  KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+}
+
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                    bool NSW, bool NUW,
                                    const APInt &DemandedElts,
@@ -369,6 +502,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
       isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
           .value_or(false))
     KnownOut.makeNonNegative();
+
+  if (Add)
+    // Try to match the lerp pattern and combine the results.
+    computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
 }
 
 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
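
For readers who want to sanity-check the bound the new function relies on, here is a small standalone sketch, not part of the patch: it exhaustively verifies the max(a, d) * b upper bound and the resulting leading-zero guarantee for the 8-bit alpha-blending instance of the pattern, where b is the constant 255 and (b - c) becomes (255 - t), the form instcombine may canonicalize to an xor. The variable names and the use of C++20's std::countl_zero are illustrative assumptions, not taken from the patch.

// Standalone sanity check (illustrative, not part of the patch), C++20.
// Pattern variables: a -> Alpha, b -> 255 (constant), c -> T, d -> Color.
#include <algorithm>
#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (std::uint32_t Alpha = 0; Alpha <= 255; ++Alpha)
    for (std::uint32_t Color = 0; Color <= 255; ++Color)
      for (std::uint32_t T = 0; T <= 255; ++T) {
        // The lerp/alpha-blend shape: a * (b - c) + c * d with b = 255.
        const std::uint32_t Blend = Alpha * (255 - T) + T * Color;

        // Upper bound from the comment: max(max(a), max(d)) * max(b).
        const std::uint32_t UpperBound = std::max(Alpha, Color) * 255;
        assert(Blend <= UpperBound);

        // 255 * 255 = 65025 < 2^16, so a 32-bit result of this pattern has
        // at least 16 leading zero bits; this matches what
        // computeKnownBitsFromLerpPattern would record in KnownOut.Zero.
        assert(std::countl_zero(Blend) >= 16);
      }
  return 0;
}

Under these assumptions the asserts hold for all 2^24 input combinations; the analysis in the patch derives the same 16 leading zero bits symbolically from the operands' KnownBits rather than by enumeration.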
