aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r--llvm/lib/Analysis/DependenceAnalysis.cpp19
-rw-r--r--llvm/lib/Analysis/RegionPrinter.cpp11
-rw-r--r--llvm/lib/Analysis/ValueTracking.cpp137
3 files changed, 159 insertions, 8 deletions
diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp
index 11d8294..e45d1f7 100644
--- a/llvm/lib/Analysis/DependenceAnalysis.cpp
+++ b/llvm/lib/Analysis/DependenceAnalysis.cpp
@@ -1587,6 +1587,15 @@ static const SCEV *minusSCEVNoSignedOverflow(const SCEV *A, const SCEV *B,
return nullptr;
}
+/// Returns \p A * \p B if it guaranteed not to signed wrap. Otherwise returns
+/// nullptr. \p A and \p B must have the same integer type.
+static const SCEV *mulSCEVNoSignedOverflow(const SCEV *A, const SCEV *B,
+ ScalarEvolution &SE) {
+ if (SE.willNotOverflow(Instruction::Mul, /*Signed=*/true, A, B))
+ return SE.getMulExpr(A, B);
+ return nullptr;
+}
+
/// Returns the absolute value of \p A. In the context of dependence analysis,
/// we need an absolute value in a mathematical sense. If \p A is the signed
/// minimum value, we cannot represent it unless extending the original type.
@@ -1686,7 +1695,11 @@ bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst,
assert(0 < Level && Level <= CommonLevels && "level out of range");
Level--;
- const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst);
+ const SCEV *Delta = minusSCEVNoSignedOverflow(SrcConst, DstConst, *SE);
+ if (!Delta) {
+ Result.Consistent = false;
+ return false;
+ }
LLVM_DEBUG(dbgs() << "\t Delta = " << *Delta);
LLVM_DEBUG(dbgs() << ", " << *Delta->getType() << "\n");
@@ -1702,7 +1715,9 @@ bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst,
const SCEV *AbsCoeff = absSCEVNoSignedOverflow(Coeff, *SE);
if (!AbsDelta || !AbsCoeff)
return false;
- const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff);
+ const SCEV *Product = mulSCEVNoSignedOverflow(UpperBound, AbsCoeff, *SE);
+ if (!Product)
+ return false;
return isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product);
}();
if (IsDeltaLarge) {
diff --git a/llvm/lib/Analysis/RegionPrinter.cpp b/llvm/lib/Analysis/RegionPrinter.cpp
index a83af4e..33e073b 100644
--- a/llvm/lib/Analysis/RegionPrinter.cpp
+++ b/llvm/lib/Analysis/RegionPrinter.cpp
@@ -29,10 +29,9 @@ onlySimpleRegions("only-simple-regions",
cl::Hidden,
cl::init(false));
-namespace llvm {
-
-std::string DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node,
- RegionNode *Graph) {
+std::string
+llvm::DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node,
+ RegionNode *Graph) {
if (!Node->isSubRegion()) {
BasicBlock *BB = Node->getNodeAs<BasicBlock>();
@@ -46,7 +45,8 @@ std::string DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node,
}
template <>
-struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> {
+struct llvm::DOTGraphTraits<RegionInfo *>
+ : public llvm::DOTGraphTraits<RegionNode *> {
DOTGraphTraits (bool isSimple = false)
: DOTGraphTraits<RegionNode*>(isSimple) {}
@@ -125,7 +125,6 @@ struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> {
printRegionCluster(*G->getTopLevelRegion(), GW, 4);
}
};
-} // end namespace llvm
namespace {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 789a983..41ff816 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,6 +350,139 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
return V->getType()->getScalarSizeInBits() - SignBits + 1;
}
+/// Try to detect the lerp pattern: a * (b - c) + c * d
+/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+///
+/// In that particular case, we can use the following chain of reasoning:
+///
+/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
+///
+/// Since that is true for arbitrary a, b, c and d within our constraints, we
+/// can conclude that:
+///
+/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+///
+/// Considering that any result of the lerp would be less or equal to U, it
+/// would have at least the number of leading 0s as in U.
+///
+/// While being quite a specific situation, it is fairly common in computer
+/// graphics in the shape of alpha blending.
+///
+/// Modifies given KnownOut in-place with the inferred information.
+static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+ const APInt &DemandedElts,
+ KnownBits &KnownOut,
+ const SimplifyQuery &Q,
+ unsigned Depth) {
+
+ Type *Ty = Op0->getType();
+ const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // Only handle scalar types for now
+ if (Ty->isVectorTy())
+ return;
+
+ // Try to match: a * (b - c) + c * d.
+ // When a == 1 => A == nullptr, the same applies to d/D as well.
+ const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+ const Instruction *SubBC = nullptr;
+
+ const auto MatchSubBC = [&]() {
+ // (b - c) can have two forms that interest us:
+ //
+ // 1. sub nuw %b, %c
+ // 2. xor %c, %b
+ //
+ // For the first case, nuw flag guarantees our requirement b >= c.
+ //
+ // The second case might happen when the analysis can infer that b is a mask
+ // for c and we can transform sub operation into xor (that is usually true
+ // for constant b's). Even though xor is symmetrical, canonicalization
+ // ensures that the constant will be the RHS. We have additional checks
+ // later on to ensure that this xor operation is equivalent to subtraction.
+ return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+ m_Xor(m_Value(C), m_Value(B))));
+ };
+
+ const auto MatchASubBC = [&]() {
+ // Cases:
+ // - a * (b - c)
+ // - (b - c) * a
+ // - (b - c) <- a implicitly equals 1
+ return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
+ };
+
+ const auto MatchCD = [&]() {
+ // Cases:
+ // - d * c
+ // - c * d
+ // - c <- d implicitly equals 1
+ return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
+ };
+
+ const auto Match = [&](const Value *LHS, const Value *RHS) {
+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
+ // it's bound to anything and match(LHS, MatchASubBC()) absolutely
+ // has to evaluate first and return true.
+ //
+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
+ return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+ };
+
+ if (!Match(Op0, Op1) && !Match(Op1, Op0))
+ return;
+
+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+ // For some of the values we use the convention of leaving
+ // it nullptr to signify an implicit constant 1.
+ return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+ : KnownBits::makeConstant(APInt(BitWidth, 1));
+ };
+
+ // Check that all operands are non-negative
+ const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+ if (!KnownA.isNonNegative())
+ return;
+
+ const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+ if (!KnownD.isNonNegative())
+ return;
+
+ const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+ if (!KnownB.isNonNegative())
+ return;
+
+ const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+ if (!KnownC.isNonNegative())
+ return;
+
+ // If we matched subtraction as xor, we need to actually check that xor
+ // is semantically equivalent to subtraction.
+ //
+ // For that to be true, b has to be a mask for c or that b's known
+ // ones cover all known and possible ones of c.
+ if (SubBC->getOpcode() == Instruction::Xor &&
+ !KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()))
+ return;
+
+ const APInt MaxA = KnownA.getMaxValue();
+ const APInt MaxD = KnownD.getMaxValue();
+ const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
+ const APInt MaxB = KnownB.getMaxValue();
+
+ // We can't infer leading zeros info if the upper-bound estimate wraps.
+ bool Overflow;
+ const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+ if (Overflow)
+ return;
+
+ // If we know that x <= y and both are positive than x has at least the same
+ // number of leading zeros as y.
+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+ KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+}
+
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
bool NSW, bool NUW,
const APInt &DemandedElts,
@@ -369,6 +502,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
.value_or(false))
KnownOut.makeNonNegative();
+
+ if (Add)
+ // Try to match lerp pattern and combine results
+ computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
}
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,