aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Support/KnownBits.cpp
diff options
context:
space:
mode:
authorNoah Goldstein <goldstein.w.n@gmail.com>2024-03-05 22:03:44 -0600
committerNoah Goldstein <goldstein.w.n@gmail.com>2024-03-11 15:51:07 -0500
commitd81db0e5f5b1404ff4813af3050d671528ad45cc (patch)
tree4fa74264813641b7cde900b7a803b7fb64bfa538 /llvm/lib/Support/KnownBits.cpp
parenta9d913ebcd567ad14ffdc8c8684c4f0611e1e2da (diff)
downloadllvm-d81db0e5f5b1404ff4813af3050d671528ad45cc.zip
llvm-d81db0e5f5b1404ff4813af3050d671528ad45cc.tar.gz
llvm-d81db0e5f5b1404ff4813af3050d671528ad45cc.tar.bz2
[KnownBits] Implement knownbits `lshr`/`ashr` with exact flag
The exact flag basically allows us to set an upper bound on shift amount when we have a known 1 in `LHS`. Typically we deduce exact using knownbits (on non-exact incoming shifts), so this is particularly impactful, but may be useful in some circumstances. Closes #84254
Diffstat (limited to 'llvm/lib/Support/KnownBits.cpp')
-rw-r--r--llvm/lib/Support/KnownBits.cpp28
1 files changed, 26 insertions, 2 deletions
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52..c33c368 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();