diff options
author | Yingwei Zheng <dtcxzyw2333@gmail.com> | 2024-08-07 02:00:33 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-07 02:00:33 +0800 |
commit | 07b29fc808ca0842d02cf4e973381b974bfdf19f (patch) | |
tree | e424bcb408c8ab530ca84705747fee69227c0035 /llvm/lib/IR/ConstantRange.cpp | |
parent | 4dee6411e0d993fd17099bd7564276474412383e (diff) | |
download | llvm-07b29fc808ca0842d02cf4e973381b974bfdf19f.zip llvm-07b29fc808ca0842d02cf4e973381b974bfdf19f.tar.gz llvm-07b29fc808ca0842d02cf4e973381b974bfdf19f.tar.bz2 |
[ConstantRange] Improve `shlWithNoWrap` (#101800)
Closes https://github.com/dtcxzyw/llvm-tools/issues/22.
Diffstat (limited to 'llvm/lib/IR/ConstantRange.cpp')
-rw-r--r-- | llvm/lib/IR/ConstantRange.cpp | 104 |
1 files changed, 95 insertions, 9 deletions
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index 50b211a..c389d72 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -1617,21 +1617,107 @@ ConstantRange::shl(const ConstantRange &Other) const { return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1); } +static ConstantRange computeShlNUW(const ConstantRange &LHS, + const ConstantRange &RHS) { + unsigned BitWidth = LHS.getBitWidth(); + bool Overflow; + APInt LHSMin = LHS.getUnsignedMin(); + unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth); + APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow); + if (Overflow) + return ConstantRange::getEmpty(BitWidth); + APInt LHSMax = LHS.getUnsignedMax(); + unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth); + APInt MaxShl = MinShl; + unsigned MaxShAmt = LHSMax.countLeadingZeros(); + if (RHSMin <= MaxShAmt) + MaxShl = LHSMax << std::min(RHSMax, MaxShAmt); + RHSMin = std::max(RHSMin, MaxShAmt + 1); + RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros()); + if (RHSMin <= RHSMax) + MaxShl = APIntOps::umax(MaxShl, + APInt::getHighBitsSet(BitWidth, BitWidth - RHSMin)); + return ConstantRange::getNonEmpty(MinShl, MaxShl + 1); +} + +static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin, + const APInt &LHSMax, + unsigned RHSMin, + unsigned RHSMax) { + unsigned BitWidth = LHSMin.getBitWidth(); + bool Overflow; + APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow); + if (Overflow) + return ConstantRange::getEmpty(BitWidth); + APInt MaxShl = MinShl; + unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1; + if (RHSMin <= MaxShAmt) + MaxShl = LHSMax << std::min(RHSMax, MaxShAmt); + RHSMin = std::max(RHSMin, MaxShAmt + 1); + RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1); + if (RHSMin <= RHSMax) + MaxShl = APIntOps::umax(MaxShl, + APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1)); + return ConstantRange::getNonEmpty(MinShl, MaxShl + 1); +} + +static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin, + const APInt &LHSMax, + unsigned RHSMin, unsigned RHSMax) { + unsigned BitWidth = LHSMin.getBitWidth(); + bool Overflow; + APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow); + if (Overflow) + return ConstantRange::getEmpty(BitWidth); + APInt MinShl = MaxShl; + unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1; + if (RHSMin <= MaxShAmt) + MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt)); + RHSMin = std::max(RHSMin, MaxShAmt + 1); + RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1); + if (RHSMin <= RHSMax) + MinShl = APInt::getSignMask(BitWidth); + return ConstantRange::getNonEmpty(MinShl, MaxShl + 1); +} + +static ConstantRange computeShlNSW(const ConstantRange &LHS, + const ConstantRange &RHS) { + unsigned BitWidth = LHS.getBitWidth(); + unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth); + unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth); + APInt LHSMin = LHS.getSignedMin(); + APInt LHSMax = LHS.getSignedMax(); + if (LHSMin.isNonNegative()) + return computeShlNSWWithNNegLHS(LHSMin, LHSMax, RHSMin, RHSMax); + else if (LHSMax.isNegative()) + return computeShlNSWWithNegLHS(LHSMin, LHSMax, RHSMin, RHSMax); + return computeShlNSWWithNNegLHS(APInt::getZero(BitWidth), LHSMax, RHSMin, + RHSMax) + .unionWith(computeShlNSWWithNegLHS(LHSMin, APInt::getAllOnes(BitWidth), + RHSMin, RHSMax), + ConstantRange::Signed); +} + ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other, unsigned NoWrapKind, PreferredRangeType RangeType) const { if (isEmptySet() || Other.isEmptySet()) return getEmpty(); - ConstantRange Result = shl(Other); - - if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) - Result = Result.intersectWith(sshl_sat(Other), RangeType); - - if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) - Result = Result.intersectWith(ushl_sat(Other), RangeType); - - return Result; + switch (NoWrapKind) { + case 0: + return shl(Other); + case OverflowingBinaryOperator::NoSignedWrap: + return computeShlNSW(*this, Other); + case OverflowingBinaryOperator::NoUnsignedWrap: + return computeShlNUW(*this, Other); + case OverflowingBinaryOperator::NoSignedWrap | + OverflowingBinaryOperator::NoUnsignedWrap: + return computeShlNSW(*this, Other) + .intersectWith(computeShlNUW(*this, Other), RangeType); + default: + llvm_unreachable("Invalid NoWrapKind"); + } } ConstantRange |