aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/InstructionSimplify.cpp
diff options
context:
space:
mode:
authorNikita Popov <npopov@redhat.com>2024-12-02 09:53:10 +0100
committerGitHub <noreply@github.com>2024-12-02 09:53:10 +0100
commit8201926ec0a61ea182e3b25c23e3dbaae6036dbf (patch)
treed8cfa274451e96d599402361179ca0c3bb5dcc56 /llvm/lib/Analysis/InstructionSimplify.cpp
parenta545cf5c6da6decbde95287f95e1ffce40116d23 (diff)
downloadllvm-8201926ec0a61ea182e3b25c23e3dbaae6036dbf.zip
llvm-8201926ec0a61ea182e3b25c23e3dbaae6036dbf.tar.gz
llvm-8201926ec0a61ea182e3b25c23e3dbaae6036dbf.tar.bz2
[InstSimplify] Generalize simplification of icmps with monotonic operands (#69471)
InstSimplify currently folds patterns like `(x | y) uge x` and `(x & y) ule x` to true. However, it cannot handle combinations of such situations, such as `(x | y) uge (x & z)` etc. To support this, recursively collect operands of monotonic instructions (that preserve either a greater-or-equal or less-or-equal relationship) and then check whether any of them match. Fixes https://github.com/llvm/llvm-project/issues/69333.
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r--llvm/lib/Analysis/InstructionSimplify.cpp123
1 files changed, 71 insertions, 52 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 01b0a08..1a5bbbc 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3070,6 +3070,69 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
return nullptr;
}
+enum class MonotonicType { GreaterEq, LowerEq };
+
+/// Get values V_i such that V uge V_i (GreaterEq) or V ule V_i (LowerEq).
+static void getUnsignedMonotonicValues(SmallPtrSetImpl<Value *> &Res, Value *V,
+ MonotonicType Type, unsigned Depth = 0) {
+ if (!Res.insert(V).second)
+ return;
+
+ // Can be increased if useful.
+ if (++Depth > 1)
+ return;
+
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return;
+
+ Value *X, *Y;
+ if (Type == MonotonicType::GreaterEq) {
+ if (match(I, m_Or(m_Value(X), m_Value(Y))) ||
+ match(I, m_Intrinsic<Intrinsic::uadd_sat>(m_Value(X), m_Value(Y)))) {
+ getUnsignedMonotonicValues(Res, X, Type, Depth);
+ getUnsignedMonotonicValues(Res, Y, Type, Depth);
+ }
+ } else {
+ assert(Type == MonotonicType::LowerEq);
+ switch (I->getOpcode()) {
+ case Instruction::And:
+ getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth);
+ getUnsignedMonotonicValues(Res, I->getOperand(1), Type, Depth);
+ break;
+ case Instruction::URem:
+ case Instruction::UDiv:
+ case Instruction::LShr:
+ getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth);
+ break;
+ case Instruction::Call:
+ if (match(I, m_Intrinsic<Intrinsic::usub_sat>(m_Value(X))))
+ getUnsignedMonotonicValues(Res, X, Type, Depth);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred,
+ Value *LHS, Value *RHS) {
+ if (Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_ULT)
+ return nullptr;
+
+ // We have LHS uge GreaterValues and LowerValues uge RHS. If any of the
+ // GreaterValues and LowerValues are the same, it follows that LHS uge RHS.
+ SmallPtrSet<Value *, 4> GreaterValues;
+ SmallPtrSet<Value *, 4> LowerValues;
+ getUnsignedMonotonicValues(GreaterValues, LHS, MonotonicType::GreaterEq);
+ getUnsignedMonotonicValues(LowerValues, RHS, MonotonicType::LowerEq);
+ for (Value *GV : GreaterValues)
+ if (LowerValues.contains(GV))
+ return ConstantInt::getBool(getCompareTy(LHS),
+ Pred == ICmpInst::ICMP_UGE);
+ return nullptr;
+}
+
static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
BinaryOperator *LBO, Value *RHS,
const SimplifyQuery &Q,
@@ -3079,11 +3142,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
Value *Y = nullptr;
// icmp pred (or X, Y), X
if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
- if (Pred == ICmpInst::ICMP_ULT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_UGE)
- return getTrue(ITy);
-
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
@@ -3094,14 +3152,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
}
}
- // icmp pred (and X, Y), X
- if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- }
-
// icmp pred (urem X, Y), Y
if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
switch (Pred) {
@@ -3132,27 +3182,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
}
}
- // icmp pred (urem X, Y), X
- if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) {
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- }
-
- // x >>u y <=u x --> true.
- // x >>u y >u x --> false.
- // x udiv y <=u x --> true.
- // x udiv y >u x --> false.
- if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
- match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) {
- // icmp pred (X op Y), X
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- }
-
// If x is nonzero:
// x >>u C <u x --> true for C != 0.
// x >>u C != x --> true for C != 0.
@@ -3172,14 +3201,12 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
break;
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_UGE:
+ case ICmpInst::ICMP_UGT:
return getFalse(ITy);
case ICmpInst::ICMP_NE:
case ICmpInst::ICMP_ULT:
- return getTrue(ITy);
- case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_ULE:
- // UGT/ULE are handled by the more general case just above
- llvm_unreachable("Unexpected UGT/ULE, should have been handled");
+ return getTrue(ITy);
}
}
}
@@ -3702,13 +3729,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
switch (II->getIntrinsicID()) {
case Intrinsic::uadd_sat:
- // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y
- if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) {
- if (Pred == ICmpInst::ICMP_UGE)
- return ConstantInt::getTrue(getCompareTy(II));
- if (Pred == ICmpInst::ICMP_ULT)
- return ConstantInt::getFalse(getCompareTy(II));
- }
// uadd.sat(X, Y) uge X + Y
if (match(RHS, m_c_Add(m_Specific(II->getArgOperand(0)),
m_Specific(II->getArgOperand(1))))) {
@@ -3719,13 +3739,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
}
return nullptr;
case Intrinsic::usub_sat:
- // usub.sat(X, Y) ule X
- if (II->getArgOperand(0) == RHS) {
- if (Pred == ICmpInst::ICMP_ULE)
- return ConstantInt::getTrue(getCompareTy(II));
- if (Pred == ICmpInst::ICMP_UGT)
- return ConstantInt::getFalse(getCompareTy(II));
- }
// usub.sat(X, Y) ule X - Y
if (match(RHS, m_Sub(m_Specific(II->getArgOperand(0)),
m_Specific(II->getArgOperand(1))))) {
@@ -4030,6 +4043,12 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
return V;
+ if (Value *V = simplifyICmpUsingMonotonicValues(Pred, LHS, RHS))
+ return V;
+ if (Value *V = simplifyICmpUsingMonotonicValues(
+ ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
+ return V;
+
if (Value *V = simplifyICmpWithDominatingAssume(Pred, LHS, RHS, Q))
return V;