From 678f32ab66508aea2068a5e4e07d53b71ce5cf31 Mon Sep 17 00:00:00 2001 From: Noah Goldstein Date: Wed, 20 Mar 2024 22:08:22 -0500 Subject: [ValueTracking] Add more conditions in to `isTruePredicate` There is one notable "regression". This patch replaces the bespoke `or disjoint` logic we a direct match. This means we fail some simplification during `instsimplify`. All the cases we fail in `instsimplify` we do handle in `instcombine` as we add `disjoint` flags. Other than that, just some basic cases. See proofs: https://alive2.llvm.org/ce/z/_-g7C8 Closes #86083 --- llvm/lib/Analysis/ValueTracking.cpp | 89 +++++++++++++++++----------- llvm/test/Transforms/InstCombine/implies.ll | 36 ++++------- llvm/test/Transforms/InstSimplify/implies.ll | 16 ++++- 3 files changed, 77 insertions(+), 64 deletions(-) (limited to 'llvm') diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 33a6986..5ad4da4 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -8393,8 +8393,7 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P, /// Return true if "icmp Pred LHS RHS" is always true. static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, - const Value *RHS, const DataLayout &DL, - unsigned Depth) { + const Value *RHS) { if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS) return true; @@ -8406,8 +8405,26 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, const APInt *C; // LHS s<= LHS +_{nsw} C if C >= 0 - if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C)))) + // LHS s<= LHS | C if C >= 0 + if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))) || + match(RHS, m_Or(m_Specific(LHS), m_APInt(C)))) return !C->isNegative(); + + // LHS s<= smax(LHS, V) for any V + if (match(RHS, m_c_SMax(m_Specific(LHS), m_Value()))) + return true; + + // smin(RHS, V) s<= RHS for any V + if (match(LHS, m_c_SMin(m_Specific(RHS), m_Value()))) + return true; + + // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB) + const Value *X; + const APInt *CLHS, *CRHS; + if (match(LHS, m_NSWAddLike(m_Value(X), m_APInt(CLHS))) && + match(RHS, m_NSWAddLike(m_Specific(X), m_APInt(CRHS)))) + return CLHS->sle(*CRHS); + return false; } @@ -8417,34 +8434,36 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, cast(RHS)->hasNoUnsignedWrap()) return true; + // LHS u<= LHS | V for any V + if (match(RHS, m_c_Or(m_Specific(LHS), m_Value()))) + return true; + + // LHS u<= umax(LHS, V) for any V + if (match(RHS, m_c_UMax(m_Specific(LHS), m_Value()))) + return true; + // RHS >> V u<= RHS for any V if (match(LHS, m_LShr(m_Specific(RHS), m_Value()))) return true; - // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB) - auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B, - const Value *&X, - const APInt *&CA, const APInt *&CB) { - if (match(A, m_NUWAdd(m_Value(X), m_APInt(CA))) && - match(B, m_NUWAdd(m_Specific(X), m_APInt(CB)))) - return true; + // RHS u/ C_ugt_1 u<= RHS + const APInt *C; + if (match(LHS, m_UDiv(m_Specific(RHS), m_APInt(C))) && C->ugt(1)) + return true; - // If X & C == 0 then (X | C) == X +_{nuw} C - if (match(A, m_Or(m_Value(X), m_APInt(CA))) && - match(B, m_Or(m_Specific(X), m_APInt(CB)))) { - KnownBits Known(CA->getBitWidth()); - computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr, - /*CxtI*/ nullptr, /*DT*/ nullptr); - if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero)) - return true; - } + // RHS & V u<= RHS for any V + if (match(LHS, m_c_And(m_Specific(RHS), m_Value()))) + return true; - return false; - }; + // umin(RHS, V) u<= RHS for any V + if (match(LHS, m_c_UMin(m_Specific(RHS), m_Value()))) + return true; + // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB) const Value *X; const APInt *CLHS, *CRHS; - if (MatchNUWAddsToSameValue(LHS, RHS, X, CLHS, CRHS)) + if (match(LHS, m_NUWAddLike(m_Value(X), m_APInt(CLHS))) && + match(RHS, m_NUWAddLike(m_Specific(X), m_APInt(CRHS)))) return CLHS->ule(*CRHS); return false; @@ -8456,37 +8475,36 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, /// ALHS ARHS" is true. Otherwise, return std::nullopt. static std::optional isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, - const Value *ARHS, const Value *BLHS, const Value *BRHS, - const DataLayout &DL, unsigned Depth) { + const Value *ARHS, const Value *BLHS, const Value *BRHS) { switch (Pred) { default: return std::nullopt; case CmpInst::ICMP_SLT: case CmpInst::ICMP_SLE: - if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) && - isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth)) + if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS) && + isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS)) return true; return std::nullopt; case CmpInst::ICMP_SGT: case CmpInst::ICMP_SGE: - if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS, DL, Depth) && - isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS, DL, Depth)) + if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS) && + isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS)) return true; return std::nullopt; case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE: - if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) && - isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth)) + if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS) && + isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS)) return true; return std::nullopt; case CmpInst::ICMP_UGT: case CmpInst::ICMP_UGE: - if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS, DL, Depth) && - isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS, DL, Depth)) + if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS) && + isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS)) return true; return std::nullopt; } @@ -8530,7 +8548,7 @@ static std::optional isImpliedCondICmps(const ICmpInst *LHS, CmpInst::Predicate RPred, const Value *R0, const Value *R1, const DataLayout &DL, - bool LHSIsTrue, unsigned Depth) { + bool LHSIsTrue) { Value *L0 = LHS->getOperand(0); Value *L1 = LHS->getOperand(1); @@ -8577,7 +8595,7 @@ static std::optional isImpliedCondICmps(const ICmpInst *LHS, return LPred == RPred; if (LPred == RPred) - return isImpliedCondOperands(LPred, L0, L1, R0, R1, DL, Depth); + return isImpliedCondOperands(LPred, L0, L1, R0, R1); return std::nullopt; } @@ -8639,8 +8657,7 @@ llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred, // Both LHS and RHS are icmps. const ICmpInst *LHSCmp = dyn_cast(LHS); if (LHSCmp) - return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, - Depth); + return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue); /// The LHS should be an 'or', 'and', or a 'select' instruction. We expect /// the RHS to be an icmp. diff --git a/llvm/test/Transforms/InstCombine/implies.ll b/llvm/test/Transforms/InstCombine/implies.ll index 6741d59..c02d84d 100644 --- a/llvm/test/Transforms/InstCombine/implies.ll +++ b/llvm/test/Transforms/InstCombine/implies.ll @@ -7,8 +7,7 @@ define i1 @or_implies_sle(i8 %x, i8 %y, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp sgt i8 [[OR]], [[Y:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[R:%.*]] = icmp sle i8 [[X]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -49,9 +48,7 @@ define i1 @or_distjoint_implies_ule(i8 %x, i8 %y, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp ugt i8 [[X2]], [[Y:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[X1:%.*]] = or disjoint i8 [[X]], 23 -; CHECK-NEXT: [[R:%.*]] = icmp ule i8 [[X1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -121,9 +118,7 @@ define i1 @src_or_distjoint_implies_sle(i8 %x, i8 %y, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp sgt i8 [[X2]], [[Y:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[X1:%.*]] = or disjoint i8 [[X]], 23 -; CHECK-NEXT: [[R:%.*]] = icmp sle i8 [[X1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -169,9 +164,7 @@ define i1 @src_addnsw_implies_sle(i8 %x, i8 %y, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp sgt i8 [[X2]], [[Y:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[X1:%.*]] = add nsw i8 [[X]], 23 -; CHECK-NEXT: [[R:%.*]] = icmp sle i8 [[X1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -216,9 +209,7 @@ define i1 @src_and_implies_ult(i8 %x, i8 %y, i8 %z, i1 %other) { ; CHECK-NEXT: [[COND:%.*]] = icmp ult i8 [[X:%.*]], [[Z:%.*]] ; CHECK-NEXT: br i1 [[COND]], label [[T:%.*]], label [[F:%.*]] ; CHECK: T: -; CHECK-NEXT: [[AND:%.*]] = and i8 [[Z]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[AND]], [[Z]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -280,8 +271,7 @@ define i1 @src_or_implies_ule(i8 %x, i8 %y, i8 %z, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp ugt i8 [[OR]], [[Z:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[R:%.*]] = icmp ule i8 [[X]], [[Z]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -322,9 +312,7 @@ define i1 @src_udiv_implies_ult(i8 %x, i8 %z, i1 %other) { ; CHECK-NEXT: [[COND:%.*]] = icmp ugt i8 [[Z:%.*]], [[X:%.*]] ; CHECK-NEXT: br i1 [[COND]], label [[T:%.*]], label [[F:%.*]] ; CHECK: T: -; CHECK-NEXT: [[AND:%.*]] = udiv i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[AND]], [[Z]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -345,9 +333,7 @@ define i1 @src_udiv_implies_ult2(i8 %x, i8 %z, i1 %other) { ; CHECK: T: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; CHECK: F: -; CHECK-NEXT: [[AND:%.*]] = udiv i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[AND]], [[Z]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; %cond = icmp ule i8 %z, %x br i1 %cond, label %T, label %F @@ -403,8 +389,7 @@ define i1 @src_umax_implies_ule(i8 %x, i8 %y, i8 %z, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp ugt i8 [[UM]], [[Z:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[R:%.*]] = icmp ule i8 [[X]], [[Z]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; @@ -424,8 +409,7 @@ define i1 @src_smax_implies_sle(i8 %x, i8 %y, i8 %z, i1 %other) { ; CHECK-NEXT: [[COND_NOT:%.*]] = icmp sgt i8 [[UM]], [[Z:%.*]] ; CHECK-NEXT: br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]] ; CHECK: T: -; CHECK-NEXT: [[R:%.*]] = icmp sle i8 [[X]], [[Z]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; CHECK: F: ; CHECK-NEXT: ret i1 [[OTHER:%.*]] ; diff --git a/llvm/test/Transforms/InstSimplify/implies.ll b/llvm/test/Transforms/InstSimplify/implies.ll index 8a01190..7e3cb65 100644 --- a/llvm/test/Transforms/InstSimplify/implies.ll +++ b/llvm/test/Transforms/InstSimplify/implies.ll @@ -155,7 +155,13 @@ define i1 @test9(i32 %length.i, i32 %i) { define i1 @test10(i32 %length.i, i32 %x.full) { ; CHECK-LABEL: @test10( -; CHECK-NEXT: ret i1 true +; CHECK-NEXT: [[X:%.*]] = and i32 [[X_FULL:%.*]], -65536 +; CHECK-NEXT: [[LARGE:%.*]] = or i32 [[X]], 100 +; CHECK-NEXT: [[SMALL:%.*]] = or i32 [[X]], 90 +; CHECK-NEXT: [[KNOWN:%.*]] = icmp ult i32 [[LARGE]], [[LENGTH_I:%.*]] +; CHECK-NEXT: [[TO_PROVE:%.*]] = icmp ult i32 [[SMALL]], [[LENGTH_I]] +; CHECK-NEXT: [[RES:%.*]] = icmp ule i1 [[KNOWN]], [[TO_PROVE]] +; CHECK-NEXT: ret i1 [[RES]] ; %x = and i32 %x.full, 4294901760 ;; 4294901760 == 0xffff0000 %large = or i32 %x, 100 @@ -229,7 +235,13 @@ define i1 @test13(i32 %length.i, i32 %x) { define i1 @test14(i32 %length.i, i32 %x.full) { ; CHECK-LABEL: @test14( -; CHECK-NEXT: ret i1 true +; CHECK-NEXT: [[X:%.*]] = and i32 [[X_FULL:%.*]], -61681 +; CHECK-NEXT: [[LARGE:%.*]] = or i32 [[X]], 8224 +; CHECK-NEXT: [[SMALL:%.*]] = or i32 [[X]], 4112 +; CHECK-NEXT: [[KNOWN:%.*]] = icmp ult i32 [[LARGE]], [[LENGTH_I:%.*]] +; CHECK-NEXT: [[TO_PROVE:%.*]] = icmp ult i32 [[SMALL]], [[LENGTH_I]] +; CHECK-NEXT: [[RES:%.*]] = icmp ule i1 [[KNOWN]], [[TO_PROVE]] +; CHECK-NEXT: ret i1 [[RES]] ; %x = and i32 %x.full, 4294905615 ;; 4294905615 == 0xffff0f0f %large = or i32 %x, 8224 ;; == 0x2020 -- cgit v1.1