From b8e1d4dbea8905e48d51a70bf75cb8fababa4a60 Mon Sep 17 00:00:00 2001 From: choikwa <5455710+choikwa@users.noreply.github.com> Date: Wed, 20 Nov 2024 11:22:09 -0500 Subject: [AMDGPU] prevent shrinking udiv/urem if either operand is in (SignedMax,UnsignedMax] (#116733) Do this by using ComputeKnownBits and checking for !isNonNegative and isUnsigned. This rejects shrinking unsigned div/rem if operands exceed smax_bitwidth since we know NumSignBits will be always 0. --- llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp | 42 +++++++++++++++++-------- 1 file changed, 29 insertions(+), 13 deletions(-) (limited to 'llvm/lib') diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp index c49aab8..a6cef52 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp @@ -1193,19 +1193,35 @@ int AMDGPUCodeGenPrepareImpl::getDivNumBits(BinaryOperator &I, Value *Num, Value *Den, unsigned AtLeast, bool IsSigned) const { const DataLayout &DL = Mod->getDataLayout(); - unsigned LHSSignBits = ComputeNumSignBits(Num, DL, 0, AC, &I); - if (LHSSignBits < AtLeast) - return -1; - - unsigned RHSSignBits = ComputeNumSignBits(Den, DL, 0, AC, &I); - if (RHSSignBits < AtLeast) - return -1; - - unsigned SignBits = std::min(LHSSignBits, RHSSignBits); - unsigned DivBits = Num->getType()->getScalarSizeInBits() - SignBits; - if (IsSigned) - ++DivBits; - return DivBits; + if (IsSigned) { + unsigned LHSSignBits = ComputeNumSignBits(Num, DL, 0, AC, &I); + if (LHSSignBits < AtLeast) + return -1; + + unsigned RHSSignBits = ComputeNumSignBits(Den, DL, 0, AC, &I); + if (RHSSignBits < AtLeast) + return -1; + + unsigned SignBits = std::min(LHSSignBits, RHSSignBits); + unsigned DivBits = Num->getType()->getScalarSizeInBits() - SignBits; + return DivBits + 1; + } else { + KnownBits Known = computeKnownBits(Num, DL, 0, AC, &I); + // We know all bits are used for division for Num or Den in range + // (SignedMax, UnsignedMax] + if (Known.isNegative() || !Known.isNonNegative()) + return -1; + unsigned LHSSignBits = Known.countMinLeadingZeros(); + + Known = computeKnownBits(Den, DL, 0, AC, &I); + if (Known.isNegative() || !Known.isNonNegative()) + return -1; + unsigned RHSSignBits = Known.countMinLeadingZeros(); + + unsigned SignBits = std::min(LHSSignBits, RHSSignBits); + unsigned DivBits = Num->getType()->getScalarSizeInBits() - SignBits; + return DivBits; + } } // The fractional part of a float is enough to accurately represent up to -- cgit v1.1