aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/ConstantRange.cpp
diff options
context:
space:
mode:
authorYingwei Zheng <dtcxzyw2333@gmail.com>2024-08-07 02:00:33 +0800
committerGitHub <noreply@github.com>2024-08-07 02:00:33 +0800
commit07b29fc808ca0842d02cf4e973381b974bfdf19f (patch)
treee424bcb408c8ab530ca84705747fee69227c0035 /llvm/lib/IR/ConstantRange.cpp
parent4dee6411e0d993fd17099bd7564276474412383e (diff)
downloadllvm-07b29fc808ca0842d02cf4e973381b974bfdf19f.zip
llvm-07b29fc808ca0842d02cf4e973381b974bfdf19f.tar.gz
llvm-07b29fc808ca0842d02cf4e973381b974bfdf19f.tar.bz2
[ConstantRange] Improve `shlWithNoWrap` (#101800)
Closes https://github.com/dtcxzyw/llvm-tools/issues/22.
Diffstat (limited to 'llvm/lib/IR/ConstantRange.cpp')
-rw-r--r--llvm/lib/IR/ConstantRange.cpp104
1 files changed, 95 insertions, 9 deletions
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 50b211a..c389d72 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1617,21 +1617,107 @@ ConstantRange::shl(const ConstantRange &Other) const {
return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
}
+static ConstantRange computeShlNUW(const ConstantRange &LHS,
+ const ConstantRange &RHS) {
+ unsigned BitWidth = LHS.getBitWidth();
+ bool Overflow;
+ APInt LHSMin = LHS.getUnsignedMin();
+ unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
+ APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow);
+ if (Overflow)
+ return ConstantRange::getEmpty(BitWidth);
+ APInt LHSMax = LHS.getUnsignedMax();
+ unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+ APInt MaxShl = MinShl;
+ unsigned MaxShAmt = LHSMax.countLeadingZeros();
+ if (RHSMin <= MaxShAmt)
+ MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+ RHSMin = std::max(RHSMin, MaxShAmt + 1);
+ RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros());
+ if (RHSMin <= RHSMax)
+ MaxShl = APIntOps::umax(MaxShl,
+ APInt::getHighBitsSet(BitWidth, BitWidth - RHSMin));
+ return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin,
+ const APInt &LHSMax,
+ unsigned RHSMin,
+ unsigned RHSMax) {
+ unsigned BitWidth = LHSMin.getBitWidth();
+ bool Overflow;
+ APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow);
+ if (Overflow)
+ return ConstantRange::getEmpty(BitWidth);
+ APInt MaxShl = MinShl;
+ unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1;
+ if (RHSMin <= MaxShAmt)
+ MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+ RHSMin = std::max(RHSMin, MaxShAmt + 1);
+ RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1);
+ if (RHSMin <= RHSMax)
+ MaxShl = APIntOps::umax(MaxShl,
+ APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1));
+ return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin,
+ const APInt &LHSMax,
+ unsigned RHSMin, unsigned RHSMax) {
+ unsigned BitWidth = LHSMin.getBitWidth();
+ bool Overflow;
+ APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow);
+ if (Overflow)
+ return ConstantRange::getEmpty(BitWidth);
+ APInt MinShl = MaxShl;
+ unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1;
+ if (RHSMin <= MaxShAmt)
+ MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt));
+ RHSMin = std::max(RHSMin, MaxShAmt + 1);
+ RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1);
+ if (RHSMin <= RHSMax)
+ MinShl = APInt::getSignMask(BitWidth);
+ return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSW(const ConstantRange &LHS,
+ const ConstantRange &RHS) {
+ unsigned BitWidth = LHS.getBitWidth();
+ unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
+ unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+ APInt LHSMin = LHS.getSignedMin();
+ APInt LHSMax = LHS.getSignedMax();
+ if (LHSMin.isNonNegative())
+ return computeShlNSWWithNNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+ else if (LHSMax.isNegative())
+ return computeShlNSWWithNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+ return computeShlNSWWithNNegLHS(APInt::getZero(BitWidth), LHSMax, RHSMin,
+ RHSMax)
+ .unionWith(computeShlNSWWithNegLHS(LHSMin, APInt::getAllOnes(BitWidth),
+ RHSMin, RHSMax),
+ ConstantRange::Signed);
+}
+
ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
unsigned NoWrapKind,
PreferredRangeType RangeType) const {
if (isEmptySet() || Other.isEmptySet())
return getEmpty();
- ConstantRange Result = shl(Other);
-
- if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
- Result = Result.intersectWith(sshl_sat(Other), RangeType);
-
- if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
- Result = Result.intersectWith(ushl_sat(Other), RangeType);
-
- return Result;
+ switch (NoWrapKind) {
+ case 0:
+ return shl(Other);
+ case OverflowingBinaryOperator::NoSignedWrap:
+ return computeShlNSW(*this, Other);
+ case OverflowingBinaryOperator::NoUnsignedWrap:
+ return computeShlNUW(*this, Other);
+ case OverflowingBinaryOperator::NoSignedWrap |
+ OverflowingBinaryOperator::NoUnsignedWrap:
+ return computeShlNSW(*this, Other)
+ .intersectWith(computeShlNUW(*this, Other), RangeType);
+ default:
+ llvm_unreachable("Invalid NoWrapKind");
+ }
}
ConstantRange