aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorchoikwa <5455710+choikwa@users.noreply.github.com>2024-11-20 11:22:09 -0500
committerGitHub <noreply@github.com>2024-11-20 11:22:09 -0500
commitb8e1d4dbea8905e48d51a70bf75cb8fababa4a60 (patch)
treef0d93ab996dfb4a9718c7f1877982a9b9cfbaffc /llvm/lib
parent9d5b3c80175da59728d13c779051eaf5311c64f7 (diff)
downloadllvm-b8e1d4dbea8905e48d51a70bf75cb8fababa4a60.zip
llvm-b8e1d4dbea8905e48d51a70bf75cb8fababa4a60.tar.gz
llvm-b8e1d4dbea8905e48d51a70bf75cb8fababa4a60.tar.bz2
[AMDGPU] prevent shrinking udiv/urem if either operand is in (SignedMax,UnsignedMax] (#116733)
Do this by using ComputeKnownBits and checking for !isNonNegative and isUnsigned. This rejects shrinking unsigned div/rem if operands exceed smax_bitwidth since we know NumSignBits will be always 0.
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp42
1 files changed, 29 insertions, 13 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
index c49aab8..a6cef52 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
@@ -1193,19 +1193,35 @@ int AMDGPUCodeGenPrepareImpl::getDivNumBits(BinaryOperator &I, Value *Num,
Value *Den, unsigned AtLeast,
bool IsSigned) const {
const DataLayout &DL = Mod->getDataLayout();
- unsigned LHSSignBits = ComputeNumSignBits(Num, DL, 0, AC, &I);
- if (LHSSignBits < AtLeast)
- return -1;
-
- unsigned RHSSignBits = ComputeNumSignBits(Den, DL, 0, AC, &I);
- if (RHSSignBits < AtLeast)
- return -1;
-
- unsigned SignBits = std::min(LHSSignBits, RHSSignBits);
- unsigned DivBits = Num->getType()->getScalarSizeInBits() - SignBits;
- if (IsSigned)
- ++DivBits;
- return DivBits;
+ if (IsSigned) {
+ unsigned LHSSignBits = ComputeNumSignBits(Num, DL, 0, AC, &I);
+ if (LHSSignBits < AtLeast)
+ return -1;
+
+ unsigned RHSSignBits = ComputeNumSignBits(Den, DL, 0, AC, &I);
+ if (RHSSignBits < AtLeast)
+ return -1;
+
+ unsigned SignBits = std::min(LHSSignBits, RHSSignBits);
+ unsigned DivBits = Num->getType()->getScalarSizeInBits() - SignBits;
+ return DivBits + 1;
+ } else {
+ KnownBits Known = computeKnownBits(Num, DL, 0, AC, &I);
+ // We know all bits are used for division for Num or Den in range
+ // (SignedMax, UnsignedMax]
+ if (Known.isNegative() || !Known.isNonNegative())
+ return -1;
+ unsigned LHSSignBits = Known.countMinLeadingZeros();
+
+ Known = computeKnownBits(Den, DL, 0, AC, &I);
+ if (Known.isNegative() || !Known.isNonNegative())
+ return -1;
+ unsigned RHSSignBits = Known.countMinLeadingZeros();
+
+ unsigned SignBits = std::min(LHSSignBits, RHSSignBits);
+ unsigned DivBits = Num->getType()->getScalarSizeInBits() - SignBits;
+ return DivBits;
+ }
}
// The fractional part of a float is enough to accurately represent up to