From dc2ed0043295a397d680db091c2033a51d21e32e Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Fri, 5 Sep 2025 08:25:48 +0800 Subject: [RISCV] Handle non uimm5 VL constants in isVLKnownLE (#156639) If a VL operand is > 31 then it will be materialized into an ADDI $x0, imm. We can reason about it by peeking at the virtual register definition which allows RISCVVectorPeephole and RISCVVLOptimizer to catch more cases. There's a separate issue with RISCVVLOptimizer where the materialized immediate may not always dominate the instruction we want to reduce the VL of, but this is left to another patch. --- llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 24 ++++++++++++++++++++---- llvm/lib/Target/RISCV/RISCVInstrInfo.h | 3 ++- llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 10 +++++----- llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp | 16 ++++++++-------- 4 files changed, 35 insertions(+), 18 deletions(-) (limited to 'llvm/lib') diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 7b4a1de..872f2cf 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -4796,8 +4796,22 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) { return Scaled; } -/// Given two VL operands, do we know that LHS <= RHS? -bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) { +static std::optional<int64_t> getEffectiveImm(const MachineOperand &MO, + const MachineRegisterInfo *MRI) { + assert(MO.isImm() || MO.getReg().isVirtual()); + if (MO.isImm()) + return MO.getImm(); + const MachineInstr *Def = MRI->getVRegDef(MO.getReg()); + int64_t Imm; + if (isLoadImm(Def, Imm)) + return Imm; + return std::nullopt; +} + +/// Given two VL operands, do we know that LHS <= RHS? Must be used in SSA form. 
+bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS, + const MachineRegisterInfo *MRI) { + assert(MRI->isSSA()); if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() && LHS.getReg() == RHS.getReg()) return true; @@ -4807,9 +4821,11 @@ bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) { return true; if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel) return false; - if (!LHS.isImm() || !RHS.isImm()) + std::optional<int64_t> LHSImm = getEffectiveImm(LHS, MRI), + RHSImm = getEffectiveImm(RHS, MRI); + if (!LHSImm || !RHSImm) return false; - return LHS.getImm() <= RHS.getImm(); + return LHSImm <= RHSImm; } namespace { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 785c835..0defb18 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -365,7 +365,8 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW); static constexpr int64_t VLMaxSentinel = -1LL; /// Given two VL operands, do we know that LHS <= RHS? 
-bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS); +bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS, + const MachineRegisterInfo *MRI); // Mask assignments for floating-point static constexpr unsigned FPMASK_Negative_Infinity = 0x001; diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 4d4f1db..dca86d7 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -1379,7 +1379,7 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() && RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc())); auto DemandedVL = DemandedVLs.lookup(&UserMI); - if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) { + if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp, MRI)) { LLVM_DEBUG(dbgs() << " Abort because user is passthru in " "instruction with demanded tail\n"); return std::nullopt; @@ -1397,7 +1397,7 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { // requires. if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) { assert(isCandidate(UserMI)); - if (RISCV::isVLKnownLE(*DemandedVL, VLOp)) + if (RISCV::isVLKnownLE(*DemandedVL, VLOp, MRI)) return DemandedVL; } @@ -1505,10 +1505,10 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { // Use the largest VL among all the users. If we cannot determine this // statically, then we cannot optimize the VL. 
- if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { + if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp, MRI)) { CommonVL = *VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); - } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { + } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL, MRI)) { LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); return std::nullopt; } @@ -1570,7 +1570,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { CommonVL = VLMI->getOperand(RISCVII::getVLOpNum(VLMI->getDesc())); } - if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { + if (!RISCV::isVLKnownLE(*CommonVL, VLOp, MRI)) { LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); return false; } diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp index 6265118..6ea010e 100644 --- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp +++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp @@ -177,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const { MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc())); - if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL)) + if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL, MRI)) continue; if (!ensureDominates(VL, *Src)) @@ -440,7 +440,7 @@ bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) { const MachineOperand &MIVL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc())); const MachineOperand &TrueVL = True->getOperand(RISCVII::getVLOpNum(True->getDesc())); - if (!RISCV::isVLKnownLE(MIVL, TrueVL)) + if (!RISCV::isVLKnownLE(MIVL, TrueVL, MRI)) return false; // True's passthru needs to be equivalent to False @@ -611,7 +611,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) { MachineOperand &SrcPolicy = Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())); - if (RISCV::isVLKnownLE(MIVL, SrcVL)) + if (RISCV::isVLKnownLE(MIVL, SrcVL, MRI)) 
SrcPolicy.setImm(SrcPolicy.getImm() | RISCVVType::TAIL_AGNOSTIC); } @@ -663,7 +663,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) { // so we don't need to handle a smaller source VL here. However, the // user's VL may be larger MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc())); - if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3))) + if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3), MRI)) return false; // If the new passthru doesn't dominate Src, try to move Src so it does. @@ -684,7 +684,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) { // If MI was tail agnostic and the VL didn't increase, preserve it. int64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED; if ((MI.getOperand(5).getImm() & RISCVVType::TAIL_AGNOSTIC) && - RISCV::isVLKnownLE(MI.getOperand(3), SrcVL)) + RISCV::isVLKnownLE(MI.getOperand(3), SrcVL, MRI)) Policy |= RISCVVType::TAIL_AGNOSTIC; Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy); } @@ -775,9 +775,9 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const { True.getOperand(RISCVII::getVLOpNum(True.getDesc())); MachineOperand MinVL = MachineOperand::CreateImm(0); - if (RISCV::isVLKnownLE(TrueVL, VMergeVL)) + if (RISCV::isVLKnownLE(TrueVL, VMergeVL, MRI)) MinVL = TrueVL; - else if (RISCV::isVLKnownLE(VMergeVL, TrueVL)) + else if (RISCV::isVLKnownLE(VMergeVL, TrueVL, MRI)) MinVL = VMergeVL; else return false; @@ -797,7 +797,7 @@ bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const { // to the tail. In that case we always need to use tail undisturbed to // preserve them. uint64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED; - if (!PassthruReg && RISCV::isVLKnownLE(VMergeVL, MinVL)) + if (!PassthruReg && RISCV::isVLKnownLE(VMergeVL, MinVL, MRI)) Policy |= RISCVVType::TAIL_AGNOSTIC; assert(RISCVII::hasVecPolicyOp(True.getDesc().TSFlags) && -- cgit v1.1