//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// // // This pass reduces the VL where possible at the MI level, before VSETVLI // instructions are inserted. // // The purpose of this optimization is to make the VL argument, for instructions // that have a VL argument, as small as possible. This is implemented by // visiting each instruction in reverse order and checking that if it has a VL // argument, whether the VL can be reduced. // //===---------------------------------------------------------------------===// #include "RISCV.h" #include "RISCVMachineFunctionInfo.h" #include "RISCVSubtarget.h" #include "llvm/ADT/SetVector.h" #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/InitializePasses.h" using namespace llvm; #define DEBUG_TYPE "riscv-vl-optimizer" #define PASS_NAME "RISC-V VL Optimizer" namespace { class RISCVVLOptimizer : public MachineFunctionPass { const MachineRegisterInfo *MRI; const MachineDominatorTree *MDT; public: static char ID; RISCVVLOptimizer() : MachineFunctionPass(ID) {} bool runOnMachineFunction(MachineFunction &MF) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } StringRef getPassName() const override { return PASS_NAME; } private: bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); bool tryReduceVL(MachineInstr &MI); bool isCandidate(const MachineInstr &MI) const; }; } // end anonymous namespace char RISCVVLOptimizer::ID = 0; INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) FunctionPass *llvm::createRISCVVLOptimizerPass() { return new RISCVVLOptimizer(); } /// Return true if R is a physical or virtual vector register, false otherwise. static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { if (R.isPhysical()) return RISCV::VRRegClass.contains(R); const TargetRegisterClass *RC = MRI->getRegClass(R); return RISCVRI::isVRegClass(RC->TSFlags); } /// Represents the EMUL and EEW of a MachineOperand. struct OperandInfo { enum class State { Unknown, Known, } S; // Represent as 1,2,4,8, ... and fractional indicator. This is because // EMUL can take on values that don't map to RISCVII::VLMUL values exactly. // For example, a mask operand can have an EMUL less than MF8. std::optional> EMUL; unsigned Log2EEW; OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW) : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) { } OperandInfo(std::pair EMUL, unsigned Log2EEW) : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} OperandInfo() : S(State::Unknown) {} bool isUnknown() const { return S == State::Unknown; } bool isKnown() const { return S == State::Known; } static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { assert(A.isKnown() && B.isKnown() && "Both operands must be known"); return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first && A.EMUL->second == B.EMUL->second; } void print(raw_ostream &OS) const { if (isUnknown()) { OS << "Unknown"; return; } assert(EMUL && "Expected EMUL to have value"); OS << "EMUL: m"; if (EMUL->second) OS << "f"; OS << EMUL->first; OS << ", EEW: " << (1 << Log2EEW); } }; LLVM_ATTRIBUTE_UNUSED static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { OI.print(OS); return OS; } namespace llvm { namespace RISCVVType { /// Return the RISCVII::VLMUL that is two times VLMul. /// Precondition: VLMul is not LMUL_RESERVED or LMUL_8. static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) { switch (VLMul) { case RISCVII::VLMUL::LMUL_F8: return RISCVII::VLMUL::LMUL_F4; case RISCVII::VLMUL::LMUL_F4: return RISCVII::VLMUL::LMUL_F2; case RISCVII::VLMUL::LMUL_F2: return RISCVII::VLMUL::LMUL_1; case RISCVII::VLMUL::LMUL_1: return RISCVII::VLMUL::LMUL_2; case RISCVII::VLMUL::LMUL_2: return RISCVII::VLMUL::LMUL_4; case RISCVII::VLMUL::LMUL_4: return RISCVII::VLMUL::LMUL_8; case RISCVII::VLMUL::LMUL_8: default: llvm_unreachable("Could not multiply VLMul by 2"); } } /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and /// SEW are from the TSFlags of MI. static std::pair getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { RISCVII::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags); auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL); unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); unsigned MISEW = 1 << MILog2SEW; unsigned EEW = 1 << Log2EEW; // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD // to put fraction in simplest form. unsigned Num = EEW, Denom = MISEW; int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL) : std::gcd(Num * MILMUL, Denom); Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD; Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD; return std::make_pair(Num > Denom ? Num : Denom, Denom > Num); } } // end namespace RISCVVType } // end namespace llvm /// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2). /// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI. static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor, const MachineInstr &MI, const MachineOperand &MO) { RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); if (MO.getOperandNo() == 0) return OperandInfo(MIVLMul, MILog2SEW); unsigned MISEW = 1 << MILog2SEW; unsigned EEW = MISEW / Factor; unsigned Log2EEW = Log2_32(EEW); return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), Log2EEW); } /// Check whether MO is a mask operand of MI. static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, const MachineRegisterInfo *MRI) { if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI)) return false; const MCInstrDesc &Desc = MI.getDesc(); return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; } /// Return the OperandInfo for MO, which is an operand of MI. static OperandInfo getOperandInfo(const MachineInstr &MI, const MachineOperand &MO, const MachineRegisterInfo *MRI) { const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); // MI has a VLMUL and SEW associated with it. The RVV specification defines // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and // MI.SEW. RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()); // We bail out early for instructions that have passthru with non NoRegister, // which means they are using TU policy. We are not interested in these // since they must preserve the entire register content. if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() && (MO.getReg() != RISCV::NoRegister)) return {}; bool IsMODef = MO.getOperandNo() == 0; // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL if (isMaskOperand(MI, MO, MRI)) return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); // switch against BaseInstr to reduce number of cases that need to be // considered. switch (RVV->BaseInstr) { // 6. Configuration-Setting Instructions // Configuration setting instructions do not read or write vector registers case RISCV::VSETIVLI: case RISCV::VSETVL: case RISCV::VSETVLI: llvm_unreachable("Configuration setting instructions do not read or write " "vector registers"); // Vector Integer Arithmetic Instructions // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: case RISCV::VADD_VX: case RISCV::VSUB_VV: case RISCV::VSUB_VX: case RISCV::VRSUB_VI: case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions // EEW=SEW. EMUL=LMUL. case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: case RISCV::VOR_VI: case RISCV::VOR_VV: case RISCV::VOR_VX: case RISCV::VXOR_VI: case RISCV::VXOR_VV: case RISCV::VXOR_VX: case RISCV::VSLL_VI: case RISCV::VSLL_VV: case RISCV::VSLL_VX: case RISCV::VSRL_VI: case RISCV::VSRL_VV: case RISCV::VSRL_VX: case RISCV::VSRA_VI: case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Integer Min/Max Instructions // EEW=SEW. EMUL=LMUL. case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: case RISCV::VMIN_VX: case RISCV::VMAXU_VV: case RISCV::VMAXU_VX: case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions // Source and Dest EEW=SEW and EMUL=LMUL. case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: case RISCV::VMULH_VX: case RISCV::VMULHU_VV: case RISCV::VMULHU_VX: case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions // EEW=SEW. EMUL=LMUL. case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: case RISCV::VDIV_VX: case RISCV::VREMU_VV: case RISCV::VREMU_VX: case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Single-Width Integer Multiply-Add Instructions // EEW=SEW. EMUL=LMUL. case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: case RISCV::VNMSAC_VX: case RISCV::VMADD_VV: case RISCV::VMADD_VX: case RISCV::VNMSUB_VV: case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= // (EEW/SEW)*LMUL. Mask operand is handled before this switch. case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: // Vector Integer Move Instructions // Vector Fixed-Point Arithmetic Instructions // Vector Single-Width Saturating Add and Subtract // Vector Single-Width Averaging Add and Subtract // EEW=SEW. EMUL=LMUL. case RISCV::VMV_V_I: case RISCV::VMV_V_V: case RISCV::VMV_V_X: case RISCV::VSADDU_VI: case RISCV::VSADDU_VV: case RISCV::VSADDU_VX: case RISCV::VSADD_VI: case RISCV::VSADD_VV: case RISCV::VSADD_VX: case RISCV::VSSUBU_VV: case RISCV::VSSUBU_VX: case RISCV::VSSUB_VV: case RISCV::VSSUB_VX: case RISCV::VAADDU_VV: case RISCV::VAADDU_VX: case RISCV::VAADD_VV: case RISCV::VAADD_VX: case RISCV::VASUBU_VV: case RISCV::VASUBU_VX: case RISCV::VASUB_VV: case RISCV::VASUB_VX: // Vector Single-Width Scaling Shift Instructions // EEW=SEW. EMUL=LMUL. case RISCV::VSSRL_VI: case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: case RISCV::VSSRA_VI: case RISCV::VSSRA_VV: case RISCV::VSSRA_VX: // Vector Permutation Instructions // Integer Scalar Move Instructions // Floating-Point Scalar Move Instructions // EMUL=LMUL. EEW=SEW. case RISCV::VMV_X_S: case RISCV::VMV_S_X: case RISCV::VFMV_F_S: case RISCV::VFMV_S_F: // Vector Slide Instructions // EMUL=LMUL. EEW=SEW. case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEDOWN_VI: case RISCV::VSLIDEDOWN_VX: case RISCV::VSLIDE1UP_VX: case RISCV::VFSLIDE1UP_VF: case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // Vector Register Gather Instructions // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: // Vector Compress Instruction // EMUL=LMUL. EEW=SEW. case RISCV::VCOMPRESS_VM: return OperandInfo(MIVLMul, MILog2SEW); // Vector Widening Integer Add/Subtract // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL. case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: case RISCV::VWSUBU_VX: case RISCV::VWADD_VV: case RISCV::VWADD_VX: case RISCV::VWSUB_VV: case RISCV::VWSUB_VX: case RISCV::VWSLL_VI: // Vector Widening Integer Multiply Instructions // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW. case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: case RISCV::VWMULSU_VX: case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: { unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW; RISCVII::VLMUL EMUL = IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; return OperandInfo(EMUL, Log2EEW); } // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: case RISCV::VWSUBU_WX: case RISCV::VWADD_WV: case RISCV::VWADD_WX: case RISCV::VWSUB_WV: case RISCV::VWSUB_WX: // Vector Widening Integer Multiply-Add Instructions // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. // Even though the add is a 2*SEW addition, the operands of the add are the // Dest which is 2*SEW and the result of the multiply which is 2*SEW. case RISCV::VWMACCU_VV: case RISCV::VWMACCU_VX: case RISCV::VWMACC_VV: case RISCV::VWMACC_VX: case RISCV::VWMACCSU_VV: case RISCV::VWMACCSU_VX: case RISCV::VWMACCUS_VX: { bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsMODef || IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; RISCVII::VLMUL EMUL = TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; return OperandInfo(EMUL, Log2EEW); } // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: return getIntegerExtensionOperandInfo(2, MI, MO); case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: return getIntegerExtensionOperandInfo(4, MI, MO); case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: return getIntegerExtensionOperandInfo(8, MI, MO); // Vector Narrowing Integer Right Shift Instructions // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has // EEW=SEW EMUL=LMUL. case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: case RISCV::VNSRA_WI: case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Narrowing Fixed-Point Clip Instructions // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL case RISCV::VNCLIPU_WI: case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: case RISCV::VNCLIP_WI: case RISCV::VNCLIP_WV: case RISCV::VNCLIP_WX: { bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; RISCVII::VLMUL EMUL = TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; return OperandInfo(EMUL, Log2EEW); } default: return {}; } } /// Return true if this optimization should consider MI for VL reduction. This /// white-list approach simplifies this optimization for instructions that may /// have more complex semantics with relation to how it uses VL. static bool isSupportedInstr(const MachineInstr &MI) { const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); if (!RVV) return false; switch (RVV->BaseInstr) { // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: case RISCV::VADD_VX: case RISCV::VSUB_VV: case RISCV::VSUB_VX: case RISCV::VRSUB_VI: case RISCV::VRSUB_VX: // Vector Widening Integer Add/Subtract case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: case RISCV::VWSUBU_VX: case RISCV::VWADD_VV: case RISCV::VWADD_VX: case RISCV::VWSUB_VV: case RISCV::VWSUB_VX: case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: case RISCV::VWSUBU_WX: case RISCV::VWADD_WV: case RISCV::VWADD_WX: case RISCV::VWSUB_WV: case RISCV::VWSUB_WX: // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions // FIXME: Add support // Vector Bitwise Logical Instructions // FIXME: Add support // Vector Single-Width Shift Instructions // FIXME: Add support case RISCV::VSLL_VI: // Vector Narrowing Integer Right Shift Instructions // FIXME: Add support case RISCV::VNSRL_WI: // Vector Integer Compare Instructions // FIXME: Add support // Vector Integer Min/Max Instructions case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: case RISCV::VMIN_VX: case RISCV::VMAXU_VV: case RISCV::VMAXU_VX: case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: case RISCV::VMULH_VX: case RISCV::VMULHU_VV: case RISCV::VMULHU_VX: case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: case RISCV::VDIV_VX: case RISCV::VREMU_VV: case RISCV::VREMU_VX: case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Widening Integer Multiply Instructions case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: case RISCV::VWMULSU_VX: case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Single-Width Integer Multiply-Add Instructions // FIXME: Add support // Vector Widening Integer Multiply-Add Instructions // FIXME: Add support case RISCV::VWMACC_VX: case RISCV::VWMACCU_VX: // Vector Integer Merge Instructions // FIXME: Add support // Vector Integer Move Instructions // FIXME: Add support case RISCV::VMV_V_I: case RISCV::VMV_V_X: // Vector Crypto case RISCV::VWSLL_VI: return true; } return false; } /// Return true if MO is a vector operand but is used as a scalar operand. static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) { MachineInstr *MI = MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI->getOpcode()); if (!RVV) return false; switch (RVV->BaseInstr) { // Reductions only use vs1[0] of vs1 case RISCV::VREDAND_VS: case RISCV::VREDMAX_VS: case RISCV::VREDMAXU_VS: case RISCV::VREDMIN_VS: case RISCV::VREDMINU_VS: case RISCV::VREDOR_VS: case RISCV::VREDSUM_VS: case RISCV::VREDXOR_VS: case RISCV::VWREDSUM_VS: case RISCV::VWREDSUMU_VS: case RISCV::VFREDMAX_VS: case RISCV::VFREDMIN_VS: case RISCV::VFREDOSUM_VS: case RISCV::VFREDUSUM_VS: case RISCV::VFWREDOSUM_VS: case RISCV::VFWREDUSUM_VS: return MO.getOperandNo() == 3; default: return false; } } /// Return true if MI may read elements past VL. static bool mayReadPastVL(const MachineInstr &MI) { const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); if (!RVV) return true; switch (RVV->BaseInstr) { // vslidedown instructions may read elements past VL. They are handled // according to current tail policy. case RISCV::VSLIDEDOWN_VI: case RISCV::VSLIDEDOWN_VX: case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // vrgather instructions may read the source vector at any index < VLMAX, // regardless of VL. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: case RISCV::VRGATHEREI16_VV: return true; default: return false; } } bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { const MCInstrDesc &Desc = MI.getDesc(); if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) return false; if (MI.getNumDefs() != 1) return false; // If we're not using VLMAX, then we need to be careful whether we are using // TA/TU when there is a non-undef Passthru. But when we are using VLMAX, it // does not matter whether we are using TA/TU with a non-undef Passthru, since // there are no tail elements to be perserved. unsigned VLOpNum = RISCVII::getVLOpNum(Desc); const MachineOperand &VLOp = MI.getOperand(VLOpNum); if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) { // If MI has a non-undef passthru, we will not try to optimize it since // that requires us to preserve tail elements according to TA/TU. // Otherwise, The MI has an undef Passthru, so it doesn't matter whether we // are using TA/TU. bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc); unsigned PassthruOpIdx = MI.getNumExplicitDefs(); if (HasPassthru && MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) { LLVM_DEBUG( dbgs() << " Not a candidate because it uses non-undef passthru" " with non-VLMAX VL\n"); return false; } } // If the VL is 1, then there is no need to reduce it. This is an // optimization, not needed to preserve correctness. if (VLOp.isImm() && VLOp.getImm() == 1) { LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n"); return false; } // Some instructions that produce vectors have semantics that make it more // difficult to determine whether the VL can be reduced. For example, some // instructions, such as reductions, may write lanes past VL to a scalar // register. Other instructions, such as some loads or stores, may write // lower lanes using data from higher lanes. There may be other complex // semantics not mentioned here that make it hard to determine whether // the VL can be optimized. As a result, a white-list of supported // instructions is used. Over time, more instructions cam be supported // upon careful examination of their semantics under the logic in this // optimization. // TODO: Use a better approach than a white-list, such as adding // properties to instructions using something like TSFlags. if (!isSupportedInstr(MI)) { LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n"); return false; } LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n"); return true; } bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI) { // FIXME: Avoid visiting each user for each time we visit something on the // worklist, combined with an extra visit from the outer loop. Restructure // along lines of an instcombine style worklist which integrates the outer // pass. bool CanReduceVL = true; for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { const MachineInstr &UserMI = *UserOp.getParent(); LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); // Instructions like reductions may use a vector register as a scalar // register. In this case, we should treat it like a scalar register which // does not impact the decision on whether to optimize VL. if (isVectorOpUsedAsScalarOp(UserOp)) { [[maybe_unused]] Register R = UserOp.getReg(); [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); assert(RISCV::VRRegClass.hasSubClassEq(RC) && "Expect LMUL 1 register class for vector as scalar operands!"); LLVM_DEBUG(dbgs() << " Use this operand as a scalar operand\n"); continue; } if (mayReadPastVL(UserMI)) { LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); CanReduceVL = false; break; } // Tied operands might pass through. if (UserOp.isTied()) { LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); CanReduceVL = false; break; } const MCInstrDesc &Desc = UserMI.getDesc(); if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" " use VLMAX\n"); CanReduceVL = false; break; } unsigned VLOpNum = RISCVII::getVLOpNum(Desc); const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); // Looking for an immediate or a register VL that isn't X0. assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && "Did not expect X0 VL"); if (!CommonVL) { CommonVL = &VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); } else if (!CommonVL->isIdenticalTo(VLOp)) { // FIXME: This check requires all users to have the same VL. We can relax // this and get the largest VL amongst all users. LLVM_DEBUG(dbgs() << " Abort because users have different VL\n"); CanReduceVL = false; break; } // The SEW and LMUL of destination and source registers need to match. // We know that MI DEF is a vector register, because that was the guard // to call this function. assert(isVectorRegClass(UserMI.getOperand(0).getReg(), MRI) && "Expected DEF and USE to be vector registers"); OperandInfo ConsumerInfo = getOperandInfo(UserMI, UserOp, MRI); OperandInfo ProducerInfo = getOperandInfo(MI, MI.getOperand(0), MRI); if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " "information for EMUL or EEW.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); CanReduceVL = false; break; } } return CanReduceVL; } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { SetVector Worklist; Worklist.insert(&OrigMI); bool MadeChange = false; while (!Worklist.empty()) { MachineInstr &MI = *Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); const MachineOperand *CommonVL = nullptr; bool CanReduceVL = true; if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) CanReduceVL = checkUsers(CommonVL, MI); if (!CanReduceVL || !CommonVL) continue; assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && "Expected VL to be an Imm or virtual Reg"); unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); MachineOperand &VLOp = MI.getOperand(VLOpNum); if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); continue; } if (CommonVL->isImm()) { LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " << CommonVL->getImm() << " for " << MI << "\n"); VLOp.ChangeToImmediate(CommonVL->getImm()); } else { const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); if (!MDT->dominates(VLMI, &MI)) continue; LLVM_DEBUG( dbgs() << " Reduce VL from " << VLOp << " to " << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) << " for " << MI << "\n"); // All our checks passed. We can reduce VL. VLOp.ChangeToRegister(CommonVL->getReg(), false); } MadeChange = true; // Now add all inputs to this instruction to the worklist. for (auto &Op : MI.operands()) { if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual()) continue; if (!isVectorRegClass(Op.getReg(), MRI)) continue; MachineInstr *DefMI = MRI->getVRegDef(Op.getReg()); if (!isCandidate(*DefMI)) continue; Worklist.insert(DefMI); } } return MadeChange; } bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { if (skipFunction(MF.getFunction())) return false; MRI = &MF.getRegInfo(); MDT = &getAnalysis().getDomTree(); const RISCVSubtarget &ST = MF.getSubtarget(); if (!ST.hasVInstructions()) return false; bool MadeChange = false; for (MachineBasicBlock &MBB : MF) { // Visit instructions in reverse order. for (auto &MI : make_range(MBB.rbegin(), MBB.rend())) { if (!isCandidate(MI)) continue; MadeChange |= tryReduceVL(MI); } } return MadeChange; }