diff options
author | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 17:50:40 +0900 |
---|---|---|
committer | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 17:50:40 +0900 |
commit | fea7da1b00cc97d742faede2df96c7d327950f49 (patch) | |
tree | 4de1d6b4ddc69f4f32daabb11ad5c71ab0cf895e /llvm/lib/Analysis/ValueTracking.cpp | |
parent | 9b99dde0d47102625d93c5d1cbbc04951025a6c9 (diff) | |
parent | 0aa930a41f2d1ebf1fa90ec42da8f96d15a4dcbb (diff) | |
download | llvm-users/chapuni/cov/single/nextcount.zip llvm-users/chapuni/cov/single/nextcount.tar.gz llvm-users/chapuni/cov/single/nextcount.tar.bz2 |
Merge branch 'users/chapuni/cov/single/nextcount-base' into users/chapuni/cov/single/nextcountusers/chapuni/cov/single/nextcount
Diffstat (limited to 'llvm/lib/Analysis/ValueTracking.cpp')
-rw-r--r-- | llvm/lib/Analysis/ValueTracking.cpp | 183 |
1 files changed, 134 insertions, 49 deletions
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 14d7c2d..0eb43dd 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -1065,6 +1065,64 @@ void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond, Known = CondRes; } +// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow). +// Returns the input and lower/upper bounds. +static bool isSignedMinMaxClamp(const Value *Select, const Value *&In, + const APInt *&CLow, const APInt *&CHigh) { + assert(isa<Operator>(Select) && + cast<Operator>(Select)->getOpcode() == Instruction::Select && + "Input should be a Select!"); + + const Value *LHS = nullptr, *RHS = nullptr; + SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor; + if (SPF != SPF_SMAX && SPF != SPF_SMIN) + return false; + + if (!match(RHS, m_APInt(CLow))) + return false; + + const Value *LHS2 = nullptr, *RHS2 = nullptr; + SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor; + if (getInverseMinMaxFlavor(SPF) != SPF2) + return false; + + if (!match(RHS2, m_APInt(CHigh))) + return false; + + if (SPF == SPF_SMIN) + std::swap(CLow, CHigh); + + In = LHS2; + return CLow->sle(*CHigh); +} + +static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II, + const APInt *&CLow, + const APInt *&CHigh) { + assert((II->getIntrinsicID() == Intrinsic::smin || + II->getIntrinsicID() == Intrinsic::smax) && + "Must be smin/smax"); + + Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); + if (!InnerII || InnerII->getIntrinsicID() != InverseID || + !match(II->getArgOperand(1), m_APInt(CLow)) || + !match(InnerII->getArgOperand(1), m_APInt(CHigh))) + return false; + + if (II->getIntrinsicID() == Intrinsic::smin) + std::swap(CLow, CHigh); + return CLow->sle(*CHigh); +} + +static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II, + KnownBits &Known) { + const APInt *CLow, *CHigh; + if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh)) + Known = Known.unionWith( + ConstantRange::getNonEmpty(*CLow, *CHigh + 1).toKnownBits()); +} + static void computeKnownBitsFromOperator(const Operator *I, const APInt &DemandedElts, KnownBits &Known, unsigned Depth, @@ -1804,11 +1862,13 @@ static void computeKnownBitsFromOperator(const Operator *I, computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q); computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q); Known = KnownBits::smin(Known, Known2); + unionWithMinMaxIntrinsicClamp(II, Known); break; case Intrinsic::smax: computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q); computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q); Known = KnownBits::smax(Known, Known2); + unionWithMinMaxIntrinsicClamp(II, Known); break; case Intrinsic::ptrmask: { computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q); @@ -3751,55 +3811,6 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, return false; } -// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow). -// Returns the input and lower/upper bounds. -static bool isSignedMinMaxClamp(const Value *Select, const Value *&In, - const APInt *&CLow, const APInt *&CHigh) { - assert(isa<Operator>(Select) && - cast<Operator>(Select)->getOpcode() == Instruction::Select && - "Input should be a Select!"); - - const Value *LHS = nullptr, *RHS = nullptr; - SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor; - if (SPF != SPF_SMAX && SPF != SPF_SMIN) - return false; - - if (!match(RHS, m_APInt(CLow))) - return false; - - const Value *LHS2 = nullptr, *RHS2 = nullptr; - SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor; - if (getInverseMinMaxFlavor(SPF) != SPF2) - return false; - - if (!match(RHS2, m_APInt(CHigh))) - return false; - - if (SPF == SPF_SMIN) - std::swap(CLow, CHigh); - - In = LHS2; - return CLow->sle(*CHigh); -} - -static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II, - const APInt *&CLow, - const APInt *&CHigh) { - assert((II->getIntrinsicID() == Intrinsic::smin || - II->getIntrinsicID() == Intrinsic::smax) && "Must be smin/smax"); - - Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); - auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); - if (!InnerII || InnerII->getIntrinsicID() != InverseID || - !match(II->getArgOperand(1), m_APInt(CLow)) || - !match(InnerII->getArgOperand(1), m_APInt(CHigh))) - return false; - - if (II->getIntrinsicID() == Intrinsic::smin) - std::swap(CLow, CHigh); - return CLow->sle(*CHigh); -} - /// For vector constants, loop over the elements and find the constant with the /// minimum number of sign bits. Return 0 if the value is not a vector constant /// or if any element was not analyzed; otherwise, return the count for the @@ -8630,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred, } } +std::optional<std::pair<CmpPredicate, Constant *>> +llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); + if (isa<UndefValue>(C)) + return std::nullopt; + + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; + + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. + auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + Constant *SafeReplacementConstant = nullptr; + if (auto *CI = dyn_cast<ConstantInt>(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return std::nullopt; + } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { + unsigned NumElts = FVTy->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt) + return std::nullopt; + + if (isa<UndefValue>(Elt)) + continue; + + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. + auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || !ConstantIsOk(CI)) + return std::nullopt; + + if (!SafeReplacementConstant) + SafeReplacementConstant = CI; + } + } else if (isa<VectorType>(C->getType())) { + // Handle scalable splat + Value *SplatC = C->getSplatValue(); + auto *CI = dyn_cast_or_null<ConstantInt>(SplatC); + // Bail out if the constant can't be safely incremented/decremented. + if (!CI || !ConstantIsOk(CI)) + return std::nullopt; + } else { + // ConstantExpr? + return std::nullopt; + } + + // It may not be safe to change a compare predicate in the presence of + // undefined elements, so replace those elements with the first safe constant + // that we found. + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { + assert(SafeReplacementConstant && "Replacement constant not set"); + C = Constant::replaceUndefsWith(C, SafeReplacementConstant); + } + + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. + Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, FastMathFlags FMF, Value *CmpLHS, Value *CmpRHS, |