diff options
Diffstat (limited to 'llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp')
-rw-r--r-- | llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp index 1286af8..974fc40 100644 --- a/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp +++ b/llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp @@ -1884,6 +1884,14 @@ unsigned GISelValueTracking::computeNumSignBits(Register R, } break; } + case TargetOpcode::G_ASHR: { + Register Src1 = MI.getOperand(1).getReg(); + Register Src2 = MI.getOperand(2).getReg(); + FirstAnswer = computeNumSignBits(Src1, DemandedElts, Depth + 1); + if (auto C = getValidMinimumShiftAmount(Src2, DemandedElts, Depth + 1)) + FirstAnswer = std::min<uint64_t>(FirstAnswer + *C, TyBits); + break; + } case TargetOpcode::G_TRUNC: { Register Src = MI.getOperand(1).getReg(); LLT SrcTy = MRI.getType(Src); @@ -2053,6 +2061,64 @@ unsigned GISelValueTracking::computeNumSignBits(Register R, unsigned Depth) { return computeNumSignBits(R, DemandedElts, Depth); } +std::optional<ConstantRange> GISelValueTracking::getValidShiftAmountRange( + Register R, const APInt &DemandedElts, unsigned Depth) { + // Shifting more than the bitwidth is not valid. + MachineInstr &MI = *MRI.getVRegDef(R); + unsigned Opcode = MI.getOpcode(); + + LLT Ty = MRI.getType(R); + unsigned BitWidth = Ty.getScalarSizeInBits(); + + if (Opcode == TargetOpcode::G_CONSTANT) { + const APInt &ShAmt = MI.getOperand(1).getCImm()->getValue(); + if (ShAmt.uge(BitWidth)) + return std::nullopt; + return ConstantRange(ShAmt); + } + + if (Opcode == TargetOpcode::G_BUILD_VECTOR) { + const APInt *MinAmt = nullptr, *MaxAmt = nullptr; + for (unsigned I = 0, E = MI.getNumOperands() - 1; I != E; ++I) { + if (!DemandedElts[I]) + continue; + MachineInstr *Op = MRI.getVRegDef(MI.getOperand(I + 1).getReg()); + if (Op->getOpcode() != TargetOpcode::G_CONSTANT) { + MinAmt = MaxAmt = nullptr; + break; + } + + const APInt &ShAmt = Op->getOperand(1).getCImm()->getValue(); + if (ShAmt.uge(BitWidth)) + return std::nullopt; + if (!MinAmt || MinAmt->ugt(ShAmt)) + MinAmt = &ShAmt; + if (!MaxAmt || MaxAmt->ult(ShAmt)) + MaxAmt = &ShAmt; + } + assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) && + "Failed to find matching min/max shift amounts"); + if (MinAmt && MaxAmt) + return ConstantRange(*MinAmt, *MaxAmt + 1); + } + + // Use computeKnownBits to find a hidden constant/knownbits (usually type + // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc. + KnownBits KnownAmt = getKnownBits(R, DemandedElts, Depth); + if (KnownAmt.getMaxValue().ult(BitWidth)) + return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false); + + return std::nullopt; +} + +std::optional<uint64_t> GISelValueTracking::getValidMinimumShiftAmount( + Register R, const APInt &DemandedElts, unsigned Depth) { + if (std::optional<ConstantRange> AmtRange = + getValidShiftAmountRange(R, DemandedElts, Depth)) + return AmtRange->getUnsignedMin().getZExtValue(); + return std::nullopt; +} + void GISelValueTrackingAnalysisLegacy::getAnalysisUsage( AnalysisUsage &AU) const { AU.setPreservesAll(); |