//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// // // This pass reduces the VL where possible at the MI level, before VSETVLI // instructions are inserted. // // The purpose of this optimization is to make the VL argument, for instructions // that have a VL argument, as small as possible. This is implemented by // visiting each instruction in reverse order and checking that if it has a VL // argument, whether the VL can be reduced. // //===---------------------------------------------------------------------===// #include "RISCV.h" #include "RISCVSubtarget.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/InitializePasses.h" using namespace llvm; #define DEBUG_TYPE "riscv-vl-optimizer" #define PASS_NAME "RISC-V VL Optimizer" namespace { class RISCVVLOptimizer : public MachineFunctionPass { const MachineRegisterInfo *MRI; const MachineDominatorTree *MDT; public: static char ID; RISCVVLOptimizer() : MachineFunctionPass(ID) {} bool runOnMachineFunction(MachineFunction &MF) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } StringRef getPassName() const override { return PASS_NAME; } private: std::optional getMinimumVLForUser(const MachineOperand &UserOp) const; /// Returns the largest common VL MachineOperand that may be used to optimize /// MI. Returns std::nullopt if it failed to find a suitable VL. std::optional checkUsers(const MachineInstr &MI) const; bool tryReduceVL(MachineInstr &MI) const; bool isCandidate(const MachineInstr &MI) const; /// For a given instruction, records what elements of it are demanded by /// downstream users. DenseMap> DemandedVLs; }; /// Represents the EMUL and EEW of a MachineOperand. struct OperandInfo { // Represent as 1,2,4,8, ... and fractional indicator. This is because // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly. // For example, a mask operand can have an EMUL less than MF8. std::optional> EMUL; unsigned Log2EEW; OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW) : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {} OperandInfo(std::pair EMUL, unsigned Log2EEW) : EMUL(EMUL), Log2EEW(Log2EEW) {} OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {} OperandInfo() = delete; static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { return A.Log2EEW == B.Log2EEW && A.EMUL == B.EMUL; } static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { return A.Log2EEW == B.Log2EEW; } void print(raw_ostream &OS) const { if (EMUL) { OS << "EMUL: m"; if (EMUL->second) OS << "f"; OS << EMUL->first; } else OS << "EMUL: unknown\n"; OS << ", EEW: " << (1 << Log2EEW); } }; } // end anonymous namespace char RISCVVLOptimizer::ID = 0; INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) FunctionPass *llvm::createRISCVVLOptimizerPass() { return new RISCVVLOptimizer(); } /// Return true if R is a physical or virtual vector register, false otherwise. static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { if (R.isPhysical()) return RISCV::VRRegClass.contains(R); const TargetRegisterClass *RC = MRI->getRegClass(R); return RISCVRI::isVRegClass(RC->TSFlags); } LLVM_ATTRIBUTE_UNUSED static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { OI.print(OS); return OS; } LLVM_ATTRIBUTE_UNUSED static raw_ostream &operator<<(raw_ostream &OS, const std::optional &OI) { if (OI) OI->print(OS); else OS << "nullopt"; return OS; } /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and /// SEW are from the TSFlags of MI. static std::pair getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags); auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL); unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); // Mask instructions will have 0 as the SEW operand. But the LMUL of these // instructions is calculated is as if the SEW operand was 3 (e8). if (MILog2SEW == 0) MILog2SEW = 3; unsigned MISEW = 1 << MILog2SEW; unsigned EEW = 1 << Log2EEW; // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD // to put fraction in simplest form. unsigned Num = EEW, Denom = MISEW; int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL) : std::gcd(Num * MILMUL, Denom); Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD; Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD; return std::make_pair(Num > Denom ? Num : Denom, Denom > Num); } /// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). /// SEW comes from TSFlags of MI. static unsigned getIntegerExtensionOperandEEW(unsigned Factor, const MachineInstr &MI, const MachineOperand &MO) { unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); if (MO.getOperandNo() == 0) return MILog2SEW; unsigned MISEW = 1 << MILog2SEW; unsigned EEW = MISEW / Factor; unsigned Log2EEW = Log2_32(EEW); return Log2EEW; } /// Check whether MO is a mask operand of MI. static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, const MachineRegisterInfo *MRI) { if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI)) return false; const MCInstrDesc &Desc = MI.getDesc(); return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; } static std::optional getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); // MI has a SEW associated with it. The RVV specification defines // the EEW of each operand and definition in relation to MI.SEW. unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()); const bool IsTied = RISCVII::isTiedPseudo(MI.getDesc().TSFlags); bool IsMODef = MO.getOperandNo() == 0 || (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs()); // All mask operands have EEW=1 if (isMaskOperand(MI, MO, MRI)) return 0; // switch against BaseInstr to reduce number of cases that need to be // considered. switch (RVV->BaseInstr) { // 6. Configuration-Setting Instructions // Configuration setting instructions do not read or write vector registers case RISCV::VSETIVLI: case RISCV::VSETVL: case RISCV::VSETVLI: llvm_unreachable("Configuration setting instructions do not read or write " "vector registers"); // Vector Loads and Stores // Vector Unit-Stride Instructions // Vector Strided Instructions /// Dest EEW encoded in the instruction case RISCV::VLM_V: case RISCV::VSM_V: return 0; case RISCV::VLE8_V: case RISCV::VSE8_V: case RISCV::VLSE8_V: case RISCV::VSSE8_V: return 3; case RISCV::VLE16_V: case RISCV::VSE16_V: case RISCV::VLSE16_V: case RISCV::VSSE16_V: return 4; case RISCV::VLE32_V: case RISCV::VSE32_V: case RISCV::VLSE32_V: case RISCV::VSSE32_V: return 5; case RISCV::VLE64_V: case RISCV::VSE64_V: case RISCV::VLSE64_V: case RISCV::VSSE64_V: return 6; // Vector Indexed Instructions // vs(o|u)xei.v // Dest/Data (operand 0) EEW=SEW. Source EEW=. case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VSUXEI8_V: case RISCV::VSOXEI8_V: { if (MO.getOperandNo() == 0) return MILog2SEW; return 3; } case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VSUXEI16_V: case RISCV::VSOXEI16_V: { if (MO.getOperandNo() == 0) return MILog2SEW; return 4; } case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VSUXEI32_V: case RISCV::VSOXEI32_V: { if (MO.getOperandNo() == 0) return MILog2SEW; return 5; } case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: case RISCV::VSUXEI64_V: case RISCV::VSOXEI64_V: { if (MO.getOperandNo() == 0) return MILog2SEW; return 6; } // Vector Integer Arithmetic Instructions // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: case RISCV::VADD_VX: case RISCV::VSUB_VV: case RISCV::VSUB_VX: case RISCV::VRSUB_VI: case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions // EEW=SEW. case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: case RISCV::VOR_VI: case RISCV::VOR_VV: case RISCV::VOR_VX: case RISCV::VXOR_VI: case RISCV::VXOR_VV: case RISCV::VXOR_VX: case RISCV::VSLL_VI: case RISCV::VSLL_VV: case RISCV::VSLL_VX: case RISCV::VSRL_VI: case RISCV::VSRL_VV: case RISCV::VSRL_VX: case RISCV::VSRA_VI: case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Integer Min/Max Instructions // EEW=SEW. case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: case RISCV::VMIN_VX: case RISCV::VMAXU_VV: case RISCV::VMAXU_VX: case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions // Source and Dest EEW=SEW. case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: case RISCV::VMULH_VX: case RISCV::VMULHU_VV: case RISCV::VMULHU_VX: case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions // EEW=SEW. case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: case RISCV::VDIV_VX: case RISCV::VREMU_VV: case RISCV::VREMU_VX: case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Single-Width Integer Multiply-Add Instructions // EEW=SEW. case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: case RISCV::VNMSAC_VX: case RISCV::VMADD_VV: case RISCV::VMADD_VX: case RISCV::VNMSUB_VV: case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled // before this switch. case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: case RISCV::VADC_VIM: case RISCV::VADC_VVM: case RISCV::VADC_VXM: case RISCV::VSBC_VVM: case RISCV::VSBC_VXM: // Vector Integer Move Instructions // Vector Fixed-Point Arithmetic Instructions // Vector Single-Width Saturating Add and Subtract // Vector Single-Width Averaging Add and Subtract // EEW=SEW. case RISCV::VMV_V_I: case RISCV::VMV_V_V: case RISCV::VMV_V_X: case RISCV::VSADDU_VI: case RISCV::VSADDU_VV: case RISCV::VSADDU_VX: case RISCV::VSADD_VI: case RISCV::VSADD_VV: case RISCV::VSADD_VX: case RISCV::VSSUBU_VV: case RISCV::VSSUBU_VX: case RISCV::VSSUB_VV: case RISCV::VSSUB_VX: case RISCV::VAADDU_VV: case RISCV::VAADDU_VX: case RISCV::VAADD_VV: case RISCV::VAADD_VX: case RISCV::VASUBU_VV: case RISCV::VASUBU_VX: case RISCV::VASUB_VV: case RISCV::VASUB_VX: // Vector Single-Width Fractional Multiply with Rounding and Saturation // EEW=SEW. The instruction produces 2*SEW product internally but // saturates to fit into SEW bits. case RISCV::VSMUL_VV: case RISCV::VSMUL_VX: // Vector Single-Width Scaling Shift Instructions // EEW=SEW. case RISCV::VSSRL_VI: case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: case RISCV::VSSRA_VI: case RISCV::VSSRA_VV: case RISCV::VSSRA_VX: // Vector Permutation Instructions // Integer Scalar Move Instructions // Floating-Point Scalar Move Instructions // EEW=SEW. case RISCV::VMV_X_S: case RISCV::VMV_S_X: case RISCV::VFMV_F_S: case RISCV::VFMV_S_F: // Vector Slide Instructions // EEW=SEW. case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEDOWN_VI: case RISCV::VSLIDEDOWN_VX: case RISCV::VSLIDE1UP_VX: case RISCV::VFSLIDE1UP_VF: case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // Vector Register Gather Instructions // EEW=SEW. For mask operand, EEW=1. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: // Vector Compress Instruction // EEW=SEW. case RISCV::VCOMPRESS_VM: // Vector Element Index Instruction case RISCV::VID_V: // Vector Single-Width Floating-Point Add/Subtract Instructions case RISCV::VFADD_VF: case RISCV::VFADD_VV: case RISCV::VFSUB_VF: case RISCV::VFSUB_VV: case RISCV::VFRSUB_VF: // Vector Single-Width Floating-Point Multiply/Divide Instructions case RISCV::VFMUL_VF: case RISCV::VFMUL_VV: case RISCV::VFDIV_VF: case RISCV::VFDIV_VV: case RISCV::VFRDIV_VF: // Vector Single-Width Floating-Point Fused Multiply-Add Instructions case RISCV::VFMACC_VV: case RISCV::VFMACC_VF: case RISCV::VFNMACC_VV: case RISCV::VFNMACC_VF: case RISCV::VFMSAC_VV: case RISCV::VFMSAC_VF: case RISCV::VFNMSAC_VV: case RISCV::VFNMSAC_VF: case RISCV::VFMADD_VV: case RISCV::VFMADD_VF: case RISCV::VFNMADD_VV: case RISCV::VFNMADD_VF: case RISCV::VFMSUB_VV: case RISCV::VFMSUB_VF: case RISCV::VFNMSUB_VV: case RISCV::VFNMSUB_VF: // Vector Floating-Point Square-Root Instruction case RISCV::VFSQRT_V: // Vector Floating-Point Reciprocal Square-Root Estimate Instruction case RISCV::VFRSQRT7_V: // Vector Floating-Point Reciprocal Estimate Instruction case RISCV::VFREC7_V: // Vector Floating-Point MIN/MAX Instructions case RISCV::VFMIN_VF: case RISCV::VFMIN_VV: case RISCV::VFMAX_VF: case RISCV::VFMAX_VV: // Vector Floating-Point Sign-Injection Instructions case RISCV::VFSGNJ_VF: case RISCV::VFSGNJ_VV: case RISCV::VFSGNJN_VV: case RISCV::VFSGNJN_VF: case RISCV::VFSGNJX_VF: case RISCV::VFSGNJX_VV: // Vector Floating-Point Classify Instruction case RISCV::VFCLASS_V: // Vector Floating-Point Move Instruction case RISCV::VFMV_V_F: // Single-Width Floating-Point/Integer Type-Convert Instructions case RISCV::VFCVT_XU_F_V: case RISCV::VFCVT_X_F_V: case RISCV::VFCVT_RTZ_XU_F_V: case RISCV::VFCVT_RTZ_X_F_V: case RISCV::VFCVT_F_XU_V: case RISCV::VFCVT_F_X_V: // Vector Floating-Point Merge Instruction case RISCV::VFMERGE_VFM: // Vector count population in mask vcpop.m // vfirst find-first-set mask bit case RISCV::VCPOP_M: case RISCV::VFIRST_M: return MILog2SEW; // Vector Widening Integer Add/Subtract // Def uses EEW=2*SEW . Operands use EEW=SEW. case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: case RISCV::VWSUBU_VX: case RISCV::VWADD_VV: case RISCV::VWADD_VX: case RISCV::VWSUB_VV: case RISCV::VWSUB_VX: case RISCV::VWSLL_VI: case RISCV::VWSLL_VX: case RISCV::VWSLL_VV: // Vector Widening Integer Multiply Instructions // Destination EEW=2*SEW. Source EEW=SEW. case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: case RISCV::VWMULSU_VX: case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Widening Integer Multiply-Add Instructions // Destination EEW=2*SEW. Source EEW=SEW. // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which // is then added to the 2*SEW-bit Dest. These instructions never have a // passthru operand. case RISCV::VWMACCU_VV: case RISCV::VWMACCU_VX: case RISCV::VWMACC_VV: case RISCV::VWMACC_VX: case RISCV::VWMACCSU_VV: case RISCV::VWMACCSU_VX: case RISCV::VWMACCUS_VX: // Vector Widening Floating-Point Fused Multiply-Add Instructions case RISCV::VFWMACC_VF: case RISCV::VFWMACC_VV: case RISCV::VFWNMACC_VF: case RISCV::VFWNMACC_VV: case RISCV::VFWMSAC_VF: case RISCV::VFWMSAC_VV: case RISCV::VFWNMSAC_VF: case RISCV::VFWNMSAC_VV: case RISCV::VFWMACCBF16_VV: case RISCV::VFWMACCBF16_VF: // Vector Widening Floating-Point Add/Subtract Instructions // Dest EEW=2*SEW. Source EEW=SEW. case RISCV::VFWADD_VV: case RISCV::VFWADD_VF: case RISCV::VFWSUB_VV: case RISCV::VFWSUB_VF: // Vector Widening Floating-Point Multiply case RISCV::VFWMUL_VF: case RISCV::VFWMUL_VV: // Widening Floating-Point/Integer Type-Convert Instructions case RISCV::VFWCVT_XU_F_V: case RISCV::VFWCVT_X_F_V: case RISCV::VFWCVT_RTZ_XU_F_V: case RISCV::VFWCVT_RTZ_X_F_V: case RISCV::VFWCVT_F_XU_V: case RISCV::VFWCVT_F_X_V: case RISCV::VFWCVT_F_F_V: case RISCV::VFWCVTBF16_F_F_V: return IsMODef ? MILog2SEW + 1 : MILog2SEW; // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW. case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: case RISCV::VWSUBU_WX: case RISCV::VWADD_WV: case RISCV::VWADD_WX: case RISCV::VWSUB_WV: case RISCV::VWSUB_WX: // Vector Widening Floating-Point Add/Subtract Instructions case RISCV::VFWADD_WF: case RISCV::VFWADD_WV: case RISCV::VFWSUB_WF: case RISCV::VFWSUB_WV: { bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsMODef || IsOp1; return TwoTimes ? MILog2SEW + 1 : MILog2SEW; } // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: return getIntegerExtensionOperandEEW(2, MI, MO); case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: return getIntegerExtensionOperandEEW(4, MI, MO); case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: return getIntegerExtensionOperandEEW(8, MI, MO); // Vector Narrowing Integer Right Shift Instructions // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: case RISCV::VNSRA_WI: case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Narrowing Fixed-Point Clip Instructions // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. case RISCV::VNCLIPU_WI: case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: case RISCV::VNCLIP_WI: case RISCV::VNCLIP_WV: case RISCV::VNCLIP_WX: // Narrowing Floating-Point/Integer Type-Convert Instructions case RISCV::VFNCVT_XU_F_W: case RISCV::VFNCVT_X_F_W: case RISCV::VFNCVT_RTZ_XU_F_W: case RISCV::VFNCVT_RTZ_X_F_W: case RISCV::VFNCVT_F_XU_W: case RISCV::VFNCVT_F_X_W: case RISCV::VFNCVT_F_F_W: case RISCV::VFNCVT_ROD_F_F_W: case RISCV::VFNCVTBF16_F_F_W: { assert(!IsTied); bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsOp1; return TwoTimes ? MILog2SEW + 1 : MILog2SEW; } // Vector Mask Instructions // Vector Mask-Register Logical Instructions // vmsbf.m set-before-first mask bit // vmsif.m set-including-first mask bit // vmsof.m set-only-first mask bit // EEW=1 // We handle the cases when operand is a v0 mask operand above the switch, // but these instructions may use non-v0 mask operands and need to be handled // specifically. case RISCV::VMAND_MM: case RISCV::VMNAND_MM: case RISCV::VMANDN_MM: case RISCV::VMXOR_MM: case RISCV::VMOR_MM: case RISCV::VMNOR_MM: case RISCV::VMORN_MM: case RISCV::VMXNOR_MM: case RISCV::VMSBF_M: case RISCV::VMSIF_M: case RISCV::VMSOF_M: { return MILog2SEW; } // Vector Iota Instruction // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled // before this switch. case RISCV::VIOTA_M: { if (IsMODef || MO.getOperandNo() == 1) return MILog2SEW; return 0; } // Vector Integer Compare Instructions // Dest EEW=1. Source EEW=SEW. case RISCV::VMSEQ_VI: case RISCV::VMSEQ_VV: case RISCV::VMSEQ_VX: case RISCV::VMSNE_VI: case RISCV::VMSNE_VV: case RISCV::VMSNE_VX: case RISCV::VMSLTU_VV: case RISCV::VMSLTU_VX: case RISCV::VMSLT_VV: case RISCV::VMSLT_VX: case RISCV::VMSLEU_VV: case RISCV::VMSLEU_VI: case RISCV::VMSLEU_VX: case RISCV::VMSLE_VV: case RISCV::VMSLE_VI: case RISCV::VMSLE_VX: case RISCV::VMSGTU_VI: case RISCV::VMSGTU_VX: case RISCV::VMSGT_VI: case RISCV::VMSGT_VX: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. case RISCV::VMADC_VIM: case RISCV::VMADC_VVM: case RISCV::VMADC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: // Dest EEW=1. Source EEW=SEW. case RISCV::VMADC_VV: case RISCV::VMADC_VI: case RISCV::VMADC_VX: case RISCV::VMSBC_VV: case RISCV::VMSBC_VX: // 13.13. Vector Floating-Point Compare Instructions // Dest EEW=1. Source EEW=SEW case RISCV::VMFEQ_VF: case RISCV::VMFEQ_VV: case RISCV::VMFNE_VF: case RISCV::VMFNE_VV: case RISCV::VMFLT_VF: case RISCV::VMFLT_VV: case RISCV::VMFLE_VF: case RISCV::VMFLE_VV: case RISCV::VMFGT_VF: case RISCV::VMFGE_VF: { if (IsMODef) return 0; return MILog2SEW; } // Vector Reduction Operations // Vector Single-Width Integer Reduction Instructions case RISCV::VREDAND_VS: case RISCV::VREDMAX_VS: case RISCV::VREDMAXU_VS: case RISCV::VREDMIN_VS: case RISCV::VREDMINU_VS: case RISCV::VREDOR_VS: case RISCV::VREDSUM_VS: case RISCV::VREDXOR_VS: // Vector Single-Width Floating-Point Reduction Instructions case RISCV::VFREDMAX_VS: case RISCV::VFREDMIN_VS: case RISCV::VFREDOSUM_VS: case RISCV::VFREDUSUM_VS: { return MILog2SEW; } // Vector Widening Integer Reduction Instructions // The Dest and VS1 read only element 0 for the vector register. Return // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL. case RISCV::VWREDSUM_VS: case RISCV::VWREDSUMU_VS: // Vector Widening Floating-Point Reduction Instructions case RISCV::VFWREDOSUM_VS: case RISCV::VFWREDUSUM_VS: { bool TwoTimes = IsMODef || MO.getOperandNo() == 3; return TwoTimes ? MILog2SEW + 1 : MILog2SEW; } default: return std::nullopt; } } static std::optional getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); std::optional Log2EEW = getOperandLog2EEW(MO, MRI); if (!Log2EEW) return std::nullopt; switch (RVV->BaseInstr) { // Vector Reduction Operations // Vector Single-Width Integer Reduction Instructions // Vector Widening Integer Reduction Instructions // Vector Widening Floating-Point Reduction Instructions // The Dest and VS1 only read element 0 of the vector register. Return just // the EEW for these. case RISCV::VREDAND_VS: case RISCV::VREDMAX_VS: case RISCV::VREDMAXU_VS: case RISCV::VREDMIN_VS: case RISCV::VREDMINU_VS: case RISCV::VREDOR_VS: case RISCV::VREDSUM_VS: case RISCV::VREDXOR_VS: case RISCV::VWREDSUM_VS: case RISCV::VWREDSUMU_VS: case RISCV::VFWREDOSUM_VS: case RISCV::VFWREDUSUM_VS: if (MO.getOperandNo() != 2) return OperandInfo(*Log2EEW); break; }; // All others have EMUL=EEW/SEW*LMUL return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW); } /// Return true if this optimization should consider MI for VL reduction. This /// white-list approach simplifies this optimization for instructions that may /// have more complex semantics with relation to how it uses VL. static bool isSupportedInstr(const MachineInstr &MI) { const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); if (!RVV) return false; switch (RVV->BaseInstr) { // Vector Unit-Stride Instructions // Vector Strided Instructions case RISCV::VLM_V: case RISCV::VLE8_V: case RISCV::VLSE8_V: case RISCV::VLE16_V: case RISCV::VLSE16_V: case RISCV::VLE32_V: case RISCV::VLSE32_V: case RISCV::VLE64_V: case RISCV::VLSE64_V: // Vector Indexed Instructions case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: { for (const MachineMemOperand *MMO : MI.memoperands()) if (MMO->isVolatile()) return false; return true; } // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: case RISCV::VADD_VX: case RISCV::VSUB_VV: case RISCV::VSUB_VX: case RISCV::VRSUB_VI: case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: case RISCV::VOR_VI: case RISCV::VOR_VV: case RISCV::VOR_VX: case RISCV::VXOR_VI: case RISCV::VXOR_VV: case RISCV::VXOR_VX: case RISCV::VSLL_VI: case RISCV::VSLL_VV: case RISCV::VSLL_VX: case RISCV::VSRL_VI: case RISCV::VSRL_VV: case RISCV::VSRL_VX: case RISCV::VSRA_VI: case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Widening Integer Add/Subtract case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: case RISCV::VWSUBU_VX: case RISCV::VWADD_VV: case RISCV::VWADD_VX: case RISCV::VWSUB_VV: case RISCV::VWSUB_VX: case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: case RISCV::VWSUBU_WX: case RISCV::VWADD_WV: case RISCV::VWADD_WX: case RISCV::VWSUB_WV: case RISCV::VWSUB_WX: // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions // FIXME: Add support case RISCV::VMADC_VV: case RISCV::VMADC_VI: case RISCV::VMADC_VX: case RISCV::VMSBC_VV: case RISCV::VMSBC_VX: // Vector Narrowing Integer Right Shift Instructions case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: case RISCV::VNSRA_WI: case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Integer Compare Instructions case RISCV::VMSEQ_VI: case RISCV::VMSEQ_VV: case RISCV::VMSEQ_VX: case RISCV::VMSNE_VI: case RISCV::VMSNE_VV: case RISCV::VMSNE_VX: case RISCV::VMSLTU_VV: case RISCV::VMSLTU_VX: case RISCV::VMSLT_VV: case RISCV::VMSLT_VX: case RISCV::VMSLEU_VV: case RISCV::VMSLEU_VI: case RISCV::VMSLEU_VX: case RISCV::VMSLE_VV: case RISCV::VMSLE_VI: case RISCV::VMSLE_VX: case RISCV::VMSGTU_VI: case RISCV::VMSGTU_VX: case RISCV::VMSGT_VI: case RISCV::VMSGT_VX: // Vector Integer Min/Max Instructions case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: case RISCV::VMIN_VX: case RISCV::VMAXU_VV: case RISCV::VMAXU_VX: case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: case RISCV::VMULH_VX: case RISCV::VMULHU_VV: case RISCV::VMULHU_VX: case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: case RISCV::VDIV_VX: case RISCV::VREMU_VV: case RISCV::VREMU_VX: case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Widening Integer Multiply Instructions case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: case RISCV::VWMULSU_VX: case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Single-Width Integer Multiply-Add Instructions case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: case RISCV::VNMSAC_VX: case RISCV::VMADD_VV: case RISCV::VMADD_VX: case RISCV::VNMSUB_VV: case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions case RISCV::VADC_VIM: case RISCV::VADC_VVM: case RISCV::VADC_VXM: case RISCV::VMADC_VIM: case RISCV::VMADC_VVM: case RISCV::VMADC_VXM: case RISCV::VSBC_VVM: case RISCV::VSBC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: // Vector Widening Integer Multiply-Add Instructions case RISCV::VWMACCU_VV: case RISCV::VWMACCU_VX: case RISCV::VWMACC_VV: case RISCV::VWMACC_VX: case RISCV::VWMACCSU_VV: case RISCV::VWMACCSU_VX: case RISCV::VWMACCUS_VX: // Vector Integer Merge Instructions // FIXME: Add support // Vector Integer Move Instructions // FIXME: Add support case RISCV::VMV_V_I: case RISCV::VMV_V_X: case RISCV::VMV_V_V: // Vector Single-Width Saturating Add and Subtract case RISCV::VSADDU_VV: case RISCV::VSADDU_VX: case RISCV::VSADDU_VI: case RISCV::VSADD_VV: case RISCV::VSADD_VX: case RISCV::VSADD_VI: case RISCV::VSSUBU_VV: case RISCV::VSSUBU_VX: case RISCV::VSSUB_VV: case RISCV::VSSUB_VX: // Vector Single-Width Averaging Add and Subtract case RISCV::VAADDU_VV: case RISCV::VAADDU_VX: case RISCV::VAADD_VV: case RISCV::VAADD_VX: case RISCV::VASUBU_VV: case RISCV::VASUBU_VX: case RISCV::VASUB_VV: case RISCV::VASUB_VX: // Vector Single-Width Fractional Multiply with Rounding and Saturation case RISCV::VSMUL_VV: case RISCV::VSMUL_VX: // Vector Single-Width Scaling Shift Instructions case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: case RISCV::VSSRL_VI: case RISCV::VSSRA_VV: case RISCV::VSSRA_VX: case RISCV::VSSRA_VI: // Vector Narrowing Fixed-Point Clip Instructions case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: case RISCV::VNCLIPU_WI: case RISCV::VNCLIP_WV: case RISCV::VNCLIP_WX: case RISCV::VNCLIP_WI: // Vector Crypto case RISCV::VWSLL_VI: case RISCV::VWSLL_VX: case RISCV::VWSLL_VV: // Vector Mask Instructions // Vector Mask-Register Logical Instructions // vmsbf.m set-before-first mask bit // vmsif.m set-including-first mask bit // vmsof.m set-only-first mask bit // Vector Iota Instruction // Vector Element Index Instruction case RISCV::VMAND_MM: case RISCV::VMNAND_MM: case RISCV::VMANDN_MM: case RISCV::VMXOR_MM: case RISCV::VMOR_MM: case RISCV::VMNOR_MM: case RISCV::VMORN_MM: case RISCV::VMXNOR_MM: case RISCV::VMSBF_M: case RISCV::VMSIF_M: case RISCV::VMSOF_M: case RISCV::VIOTA_M: case RISCV::VID_V: // Vector Slide Instructions case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEDOWN_VX: case RISCV::VSLIDEDOWN_VI: case RISCV::VSLIDE1UP_VX: case RISCV::VFSLIDE1UP_VF: // Vector Single-Width Floating-Point Add/Subtract Instructions case RISCV::VFADD_VF: case RISCV::VFADD_VV: case RISCV::VFSUB_VF: case RISCV::VFSUB_VV: case RISCV::VFRSUB_VF: // Vector Widening Floating-Point Add/Subtract Instructions case RISCV::VFWADD_VV: case RISCV::VFWADD_VF: case RISCV::VFWSUB_VV: case RISCV::VFWSUB_VF: case RISCV::VFWADD_WF: case RISCV::VFWADD_WV: case RISCV::VFWSUB_WF: case RISCV::VFWSUB_WV: // Vector Single-Width Floating-Point Multiply/Divide Instructions case RISCV::VFMUL_VF: case RISCV::VFMUL_VV: case RISCV::VFDIV_VF: case RISCV::VFDIV_VV: case RISCV::VFRDIV_VF: // Vector Widening Floating-Point Multiply case RISCV::VFWMUL_VF: case RISCV::VFWMUL_VV: // Vector Single-Width Floating-Point Fused Multiply-Add Instructions case RISCV::VFMACC_VV: case RISCV::VFMACC_VF: case RISCV::VFNMACC_VV: case RISCV::VFNMACC_VF: case RISCV::VFMSAC_VV: case RISCV::VFMSAC_VF: case RISCV::VFNMSAC_VV: case RISCV::VFNMSAC_VF: case RISCV::VFMADD_VV: case RISCV::VFMADD_VF: case RISCV::VFNMADD_VV: case RISCV::VFNMADD_VF: case RISCV::VFMSUB_VV: case RISCV::VFMSUB_VF: case RISCV::VFNMSUB_VV: case RISCV::VFNMSUB_VF: // Vector Widening Floating-Point Fused Multiply-Add Instructions case RISCV::VFWMACC_VV: case RISCV::VFWMACC_VF: case RISCV::VFWNMACC_VV: case RISCV::VFWNMACC_VF: case RISCV::VFWMSAC_VV: case RISCV::VFWMSAC_VF: case RISCV::VFWNMSAC_VV: case RISCV::VFWNMSAC_VF: case RISCV::VFWMACCBF16_VV: case RISCV::VFWMACCBF16_VF: // Vector Floating-Point Square-Root Instruction case RISCV::VFSQRT_V: // Vector Floating-Point Reciprocal Square-Root Estimate Instruction case RISCV::VFRSQRT7_V: // Vector Floating-Point Reciprocal Estimate Instruction case RISCV::VFREC7_V: // Vector Floating-Point MIN/MAX Instructions case RISCV::VFMIN_VF: case RISCV::VFMIN_VV: case RISCV::VFMAX_VF: case RISCV::VFMAX_VV: // Vector Floating-Point Sign-Injection Instructions case RISCV::VFSGNJ_VF: case RISCV::VFSGNJ_VV: case RISCV::VFSGNJN_VV: case RISCV::VFSGNJN_VF: case RISCV::VFSGNJX_VF: case RISCV::VFSGNJX_VV: // Vector Floating-Point Compare Instructions case RISCV::VMFEQ_VF: case RISCV::VMFEQ_VV: case RISCV::VMFNE_VF: case RISCV::VMFNE_VV: case RISCV::VMFLT_VF: case RISCV::VMFLT_VV: case RISCV::VMFLE_VF: case RISCV::VMFLE_VV: case RISCV::VMFGT_VF: case RISCV::VMFGE_VF: // Vector Floating-Point Classify Instruction case RISCV::VFCLASS_V: // Vector Floating-Point Merge Instruction case RISCV::VFMERGE_VFM: // Vector Floating-Point Move Instruction case RISCV::VFMV_V_F: // Single-Width Floating-Point/Integer Type-Convert Instructions case RISCV::VFCVT_XU_F_V: case RISCV::VFCVT_X_F_V: case RISCV::VFCVT_RTZ_XU_F_V: case RISCV::VFCVT_RTZ_X_F_V: case RISCV::VFCVT_F_XU_V: case RISCV::VFCVT_F_X_V: // Widening Floating-Point/Integer Type-Convert Instructions case RISCV::VFWCVT_XU_F_V: case RISCV::VFWCVT_X_F_V: case RISCV::VFWCVT_RTZ_XU_F_V: case RISCV::VFWCVT_RTZ_X_F_V: case RISCV::VFWCVT_F_XU_V: case RISCV::VFWCVT_F_X_V: case RISCV::VFWCVT_F_F_V: case RISCV::VFWCVTBF16_F_F_V: // Narrowing Floating-Point/Integer Type-Convert Instructions case RISCV::VFNCVT_XU_F_W: case RISCV::VFNCVT_X_F_W: case RISCV::VFNCVT_RTZ_XU_F_W: case RISCV::VFNCVT_RTZ_X_F_W: case RISCV::VFNCVT_F_XU_W: case RISCV::VFNCVT_F_X_W: case RISCV::VFNCVT_F_F_W: case RISCV::VFNCVT_ROD_F_F_W: case RISCV::VFNCVTBF16_F_F_W: return true; } return false; } /// Return true if MO is a vector operand but is used as a scalar operand. static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) { const MachineInstr *MI = MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI->getOpcode()); if (!RVV) return false; switch (RVV->BaseInstr) { // Reductions only use vs1[0] of vs1 case RISCV::VREDAND_VS: case RISCV::VREDMAX_VS: case RISCV::VREDMAXU_VS: case RISCV::VREDMIN_VS: case RISCV::VREDMINU_VS: case RISCV::VREDOR_VS: case RISCV::VREDSUM_VS: case RISCV::VREDXOR_VS: case RISCV::VWREDSUM_VS: case RISCV::VWREDSUMU_VS: case RISCV::VFREDMAX_VS: case RISCV::VFREDMIN_VS: case RISCV::VFREDOSUM_VS: case RISCV::VFREDUSUM_VS: case RISCV::VFWREDOSUM_VS: case RISCV::VFWREDUSUM_VS: return MO.getOperandNo() == 3; case RISCV::VMV_X_S: case RISCV::VFMV_F_S: return MO.getOperandNo() == 1; default: return false; } } /// Return true if MI may read elements past VL. static bool mayReadPastVL(const MachineInstr &MI) { const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); if (!RVV) return true; switch (RVV->BaseInstr) { // vslidedown instructions may read elements past VL. They are handled // according to current tail policy. case RISCV::VSLIDEDOWN_VI: case RISCV::VSLIDEDOWN_VX: case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // vrgather instructions may read the source vector at any index < VLMAX, // regardless of VL. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: case RISCV::VRGATHEREI16_VV: return true; default: return false; } } bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { const MCInstrDesc &Desc = MI.getDesc(); if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) return false; if (MI.getNumExplicitDefs() != 1) return false; // Some instructions have implicit defs e.g. $vxsat. If they might be read // later then we can't reduce VL. if (!MI.allImplicitDefsAreDead()) { LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n"); return false; } if (MI.mayRaiseFPException()) { LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n"); return false; } // Some instructions that produce vectors have semantics that make it more // difficult to determine whether the VL can be reduced. For example, some // instructions, such as reductions, may write lanes past VL to a scalar // register. Other instructions, such as some loads or stores, may write // lower lanes using data from higher lanes. There may be other complex // semantics not mentioned here that make it hard to determine whether // the VL can be optimized. As a result, a white-list of supported // instructions is used. Over time, more instructions can be supported // upon careful examination of their semantics under the logic in this // optimization. // TODO: Use a better approach than a white-list, such as adding // properties to instructions using something like TSFlags. if (!isSupportedInstr(MI)) { LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n"); return false; } assert(!RISCVII::elementsDependOnVL(RISCV::getRVVMCOpcode(MI.getOpcode())) && "Instruction shouldn't be supported if elements depend on VL"); assert(MI.getOperand(0).isReg() && isVectorRegClass(MI.getOperand(0).getReg(), MRI) && "All supported instructions produce a vector register result"); LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n"); return true; } std::optional RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { const MachineInstr &UserMI = *UserOp.getParent(); const MCInstrDesc &Desc = UserMI.getDesc(); if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" " use VLMAX\n"); return std::nullopt; } if (mayReadPastVL(UserMI)) { LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); return std::nullopt; } unsigned VLOpNum = RISCVII::getVLOpNum(Desc); const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); // Looking for an immediate or a register VL that isn't X0. assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && "Did not expect X0 VL"); // If the user is a passthru it will read the elements past VL, so // abort if any of the elements past VL are demanded. if (UserOp.isTied()) { assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() && RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc())); auto DemandedVL = DemandedVLs.lookup(&UserMI); if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) { LLVM_DEBUG(dbgs() << " Abort because user is passthru in " "instruction with demanded tail\n"); return std::nullopt; } } // Instructions like reductions may use a vector register as a scalar // register. In this case, we should treat it as only reading the first lane. if (isVectorOpUsedAsScalarOp(UserOp)) { LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); return MachineOperand::CreateImm(1); } // If we know the demanded VL of UserMI, then we can reduce the VL it // requires. if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) { assert(isCandidate(UserMI)); if (RISCV::isVLKnownLE(*DemandedVL, VLOp)) return DemandedVL; } return VLOp; } std::optional RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { std::optional CommonVL; SmallSetVector Worklist; SmallPtrSet PHISeen; for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) Worklist.insert(&UserOp); while (!Worklist.empty()) { MachineOperand &UserOp = *Worklist.pop_back_val(); const MachineInstr &UserMI = *UserOp.getParent(); LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); if (UserMI.isFullCopy() && UserMI.getOperand(0).getReg().isVirtual()) { LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n"); Worklist.insert_range(llvm::make_pointer_range( MRI->use_operands(UserMI.getOperand(0).getReg()))); continue; } if (UserMI.isPHI()) { // Don't follow PHI cycles if (!PHISeen.insert(&UserMI).second) continue; LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n"); Worklist.insert_range(llvm::make_pointer_range( MRI->use_operands(UserMI.getOperand(0).getReg()))); continue; } auto VLOp = getMinimumVLForUser(UserOp); if (!VLOp) return std::nullopt; // Use the largest VL among all the users. If we cannot determine this // statically, then we cannot optimize the VL. if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { CommonVL = *VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); return std::nullopt; } if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); return std::nullopt; } std::optional ConsumerInfo = getOperandInfo(UserOp, MRI); std::optional ProducerInfo = getOperandInfo(MI.getOperand(0), MRI); if (!ConsumerInfo || !ProducerInfo) { LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); return std::nullopt; } // If the operand is used as a scalar operand, then the EEW must be // compatible. Otherwise, the EMUL *and* EEW must be compatible. bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); if ((IsVectorOpUsedAsScalarOp && !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) || (!IsVectorOpUsedAsScalarOp && !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) { LLVM_DEBUG( dbgs() << " Abort due to incompatible information for EMUL or EEW.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); return std::nullopt; } } return CommonVL; } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); MachineOperand &VLOp = MI.getOperand(VLOpNum); // If the VL is 1, then there is no need to reduce it. This is an // optimization, not needed to preserve correctness. if (VLOp.isImm() && VLOp.getImm() == 1) { LLVM_DEBUG(dbgs() << " Abort due to VL == 1, no point in reducing.\n"); return false; } auto CommonVL = DemandedVLs.lookup(&MI); if (!CommonVL) return false; assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && "Expected VL to be an Imm or virtual Reg"); if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); return false; } if (CommonVL->isIdenticalTo(VLOp)) { LLVM_DEBUG( dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n"); return false; } if (CommonVL->isImm()) { LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " << CommonVL->getImm() << " for " << MI << "\n"); VLOp.ChangeToImmediate(CommonVL->getImm()); return true; } const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); if (!MDT->dominates(VLMI, &MI)) return false; LLVM_DEBUG( dbgs() << " Reduce VL from " << VLOp << " to " << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) << " for " << MI << "\n"); // All our checks passed. We can reduce VL. VLOp.ChangeToRegister(CommonVL->getReg(), false); return true; } bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { assert(DemandedVLs.size() == 0); if (skipFunction(MF.getFunction())) return false; MRI = &MF.getRegInfo(); MDT = &getAnalysis().getDomTree(); const RISCVSubtarget &ST = MF.getSubtarget(); if (!ST.hasVInstructions()) return false; // For each instruction that defines a vector, compute what VL its // downstream users demand. for (MachineBasicBlock *MBB : post_order(&MF)) { assert(MDT->isReachableFromEntry(MBB)); for (MachineInstr &MI : reverse(*MBB)) { if (!isCandidate(MI)) continue; DemandedVLs.insert({&MI, checkUsers(MI)}); } } // Then go through and see if we can reduce the VL of any instructions to // only what's demanded. bool MadeChange = false; for (MachineBasicBlock &MBB : MF) { // Avoid unreachable blocks as they have degenerate dominance if (!MDT->isReachableFromEntry(&MBB)) continue; for (auto &MI : reverse(MBB)) { if (!isCandidate(MI)) continue; if (!tryReduceVL(MI)) continue; MadeChange = true; } } DemandedVLs.clear(); return MadeChange; }