diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 569 |
1 files changed, 389 insertions, 180 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 4e3212c..ad61a77 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -50,7 +50,10 @@ public: StringRef getPassName() const override { return PASS_NAME; } private: - bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); + std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp); + /// Returns the largest common VL MachineOperand that may be used to optimize + /// MI. Returns std::nullopt if it failed to find a suitable VL. + std::optional<MachineOperand> checkUsers(MachineInstr &MI); bool tryReduceVL(MachineInstr &MI); bool isCandidate(const MachineInstr &MI) const; }; @@ -76,11 +79,6 @@ static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { /// Represents the EMUL and EEW of a MachineOperand. struct OperandInfo { - enum class State { - Unknown, - Known, - } S; - // Represent as 1,2,4,8, ... and fractional indicator. This is because // EMUL can take on values that don't map to RISCVII::VLMUL values exactly. // For example, a mask operand can have an EMUL less than MF8. @@ -89,34 +87,32 @@ struct OperandInfo { unsigned Log2EEW; OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW) - : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) { - } + : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {} OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) - : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} + : EMUL(EMUL), Log2EEW(Log2EEW) {} - OperandInfo() : S(State::Unknown) {} + OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {} - bool isUnknown() const { return S == State::Unknown; } - bool isKnown() const { return S == State::Known; } + OperandInfo() = delete; static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { - assert(A.isKnown() && B.isKnown() && "Both operands must be known"); - return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first && A.EMUL->second == B.EMUL->second; } + static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { + return A.Log2EEW == B.Log2EEW; + } + void print(raw_ostream &OS) const { - if (isUnknown()) { - OS << "Unknown"; - return; - } - assert(EMUL && "Expected EMUL to have value"); - OS << "EMUL: m"; - if (EMUL->second) - OS << "f"; - OS << EMUL->first; + if (EMUL) { + OS << "EMUL: m"; + if (EMUL->second) + OS << "f"; + OS << EMUL->first; + } else + OS << "EMUL: unknown\n"; OS << ", EEW: " << (1 << Log2EEW); } }; @@ -127,30 +123,18 @@ static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { return OS; } -namespace llvm { -namespace RISCVVType { -/// Return the RISCVII::VLMUL that is two times VLMul. -/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8. -static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) { - switch (VLMul) { - case RISCVII::VLMUL::LMUL_F8: - return RISCVII::VLMUL::LMUL_F4; - case RISCVII::VLMUL::LMUL_F4: - return RISCVII::VLMUL::LMUL_F2; - case RISCVII::VLMUL::LMUL_F2: - return RISCVII::VLMUL::LMUL_1; - case RISCVII::VLMUL::LMUL_1: - return RISCVII::VLMUL::LMUL_2; - case RISCVII::VLMUL::LMUL_2: - return RISCVII::VLMUL::LMUL_4; - case RISCVII::VLMUL::LMUL_4: - return RISCVII::VLMUL::LMUL_8; - case RISCVII::VLMUL::LMUL_8: - default: - llvm_unreachable("Could not multiply VLMul by 2"); - } +LLVM_ATTRIBUTE_UNUSED +static raw_ostream &operator<<(raw_ostream &OS, + const std::optional<OperandInfo> &OI) { + if (OI) + OI->print(OS); + else + OS << "nullopt"; + return OS; } +namespace llvm { +namespace RISCVVType { /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and /// SEW are from the TSFlags of MI. static std::pair<unsigned, bool> @@ -180,24 +164,22 @@ getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { } // end namespace RISCVVType } // end namespace llvm -/// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2). -/// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI. -static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor, - const MachineInstr &MI, - const MachineOperand &MO) { - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); +/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). +/// SEW comes from TSFlags of MI. +static unsigned getIntegerExtensionOperandEEW(unsigned Factor, + const MachineInstr &MI, + const MachineOperand &MO) { unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); + return MILog2SEW; unsigned MISEW = 1 << MILog2SEW; unsigned EEW = MISEW / Factor; unsigned Log2EEW = Log2_32(EEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } /// Check whether MO is a mask operand of MI. @@ -211,18 +193,15 @@ static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; } -/// Return the OperandInfo for MO. -static OperandInfo getOperandInfo(const MachineOperand &MO, - const MachineRegisterInfo *MRI) { +static std::optional<unsigned> +getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); - // MI has a VLMUL and SEW associated with it. The RVV specification defines - // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and - // MI.SEW. - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); + // MI has a SEW associated with it. The RVV specification defines + // the EEW of each operand and definition in relation to MI.SEW. unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); @@ -233,13 +212,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // since they must preserve the entire register content. if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() && (MO.getReg() != RISCV::NoRegister)) - return {}; + return std::nullopt; bool IsMODef = MO.getOperandNo() == 0; - // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL + // All mask operands have EEW=1 if (isMaskOperand(MI, MO, MRI)) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; // switch against BaseInstr to reduce number of cases that need to be // considered. @@ -256,55 +235,65 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Loads and Stores // Vector Unit-Stride Instructions // Vector Strided Instructions - /// Dest EEW encoded in the instruction and EMUL=(EEW/SEW)*LMUL + /// Dest EEW encoded in the instruction + case RISCV::VLM_V: + case RISCV::VSM_V: + return 0; + case RISCV::VLE8_V: case RISCV::VSE8_V: + case RISCV::VLSE8_V: case RISCV::VSSE8_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return 3; + case RISCV::VLE16_V: case RISCV::VSE16_V: + case RISCV::VLSE16_V: case RISCV::VSSE16_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return 4; + case RISCV::VLE32_V: case RISCV::VSE32_V: + case RISCV::VLSE32_V: case RISCV::VSSE32_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return 5; + case RISCV::VLE64_V: case RISCV::VSE64_V: + case RISCV::VLSE64_V: case RISCV::VSSE64_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return 6; // Vector Indexed Instructions // vs(o|u)xei<eew>.v - // Dest/Data (operand 0) EEW=SEW, EMUL=LMUL. Source EEW=<eew> and - // EMUL=(EEW/SEW)*LMUL. + // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>. case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VSUXEI8_V: case RISCV::VSOXEI8_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return MILog2SEW; + return 3; } case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VSUXEI16_V: case RISCV::VSOXEI16_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return MILog2SEW; + return 4; } case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VSUXEI32_V: case RISCV::VSOXEI32_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return MILog2SEW; + return 5; } case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: case RISCV::VSUXEI64_V: case RISCV::VSOXEI64_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return MILog2SEW; + return 6; } // Vector Integer Arithmetic Instructions @@ -318,7 +307,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: @@ -338,7 +327,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Integer Min/Max Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: @@ -348,7 +337,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions - // Source and Dest EEW=SEW and EMUL=LMUL. + // Source and Dest EEW=SEW. case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: @@ -358,7 +347,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: @@ -368,7 +357,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Single-Width Integer Multiply-Add Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: @@ -379,8 +368,8 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled + // before this switch. case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: @@ -393,7 +382,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Fixed-Point Arithmetic Instructions // Vector Single-Width Saturating Add and Subtract // Vector Single-Width Averaging Add and Subtract - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMV_V_I: case RISCV::VMV_V_V: case RISCV::VMV_V_X: @@ -415,8 +404,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VASUBU_VX: case RISCV::VASUB_VV: case RISCV::VASUB_VX: + // Vector Single-Width Fractional Multiply with Rounding and Saturation + // EEW=SEW. The instruction produces 2*SEW product internally but + // saturates to fit into SEW bits. + case RISCV::VSMUL_VV: + case RISCV::VSMUL_VX: // Vector Single-Width Scaling Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VSSRL_VI: case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: @@ -426,13 +420,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Permutation Instructions // Integer Scalar Move Instructions // Floating-Point Scalar Move Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VMV_X_S: case RISCV::VMV_S_X: case RISCV::VFMV_F_S: case RISCV::VFMV_S_F: // Vector Slide Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEDOWN_VI: @@ -442,19 +436,62 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // Vector Register Gather Instructions - // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1. + // EEW=SEW. For mask operand, EEW=1. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: // Vector Compress Instruction - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VCOMPRESS_VM: // Vector Element Index Instruction case RISCV::VID_V: - return OperandInfo(MIVLMul, MILog2SEW); + // Vector Single-Width Floating-Point Add/Subtract Instructions + case RISCV::VFADD_VF: + case RISCV::VFADD_VV: + case RISCV::VFSUB_VF: + case RISCV::VFSUB_VV: + case RISCV::VFRSUB_VF: + // Vector Single-Width Floating-Point Multiply/Divide Instructions + case RISCV::VFMUL_VF: + case RISCV::VFMUL_VV: + case RISCV::VFDIV_VF: + case RISCV::VFDIV_VV: + case RISCV::VFRDIV_VF: + // Vector Floating-Point Square-Root Instruction + case RISCV::VFSQRT_V: + // Vector Floating-Point Reciprocal Square-Root Estimate Instruction + case RISCV::VFRSQRT7_V: + // Vector Floating-Point Reciprocal Estimate Instruction + case RISCV::VFREC7_V: + // Vector Floating-Point MIN/MAX Instructions + case RISCV::VFMIN_VF: + case RISCV::VFMIN_VV: + case RISCV::VFMAX_VF: + case RISCV::VFMAX_VV: + // Vector Floating-Point Sign-Injection Instructions + case RISCV::VFSGNJ_VF: + case RISCV::VFSGNJ_VV: + case RISCV::VFSGNJN_VV: + case RISCV::VFSGNJN_VF: + case RISCV::VFSGNJX_VF: + case RISCV::VFSGNJX_VV: + // Vector Floating-Point Classify Instruction + case RISCV::VFCLASS_V: + // Vector Floating-Point Move Instruction + case RISCV::VFMV_V_F: + // Single-Width Floating-Point/Integer Type-Convert Instructions + case RISCV::VFCVT_XU_F_V: + case RISCV::VFCVT_X_F_V: + case RISCV::VFCVT_RTZ_XU_F_V: + case RISCV::VFCVT_RTZ_X_F_V: + case RISCV::VFCVT_F_XU_V: + case RISCV::VFCVT_F_X_V: + // Vector Floating-Point Merge Instruction + case RISCV::VFMERGE_VFM: + return MILog2SEW; // Vector Widening Integer Add/Subtract - // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL. + // Def uses EEW=2*SEW . Operands use EEW=SEW. case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: @@ -465,7 +502,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWSUB_VX: case RISCV::VWSLL_VI: // Vector Widening Integer Multiply Instructions - // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW. + // Destination EEW=2*SEW. Source EEW=SEW. case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: @@ -473,7 +510,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Widening Integer Multiply-Add Instructions - // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. + // Destination EEW=2*SEW. Source EEW=SEW. // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which // is then added to the 2*SEW-bit Dest. These instructions never have a // passthru operand. @@ -483,14 +520,38 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWMACC_VX: case RISCV::VWMACCSU_VV: case RISCV::VWMACCSU_VX: - case RISCV::VWMACCUS_VX: { + case RISCV::VWMACCUS_VX: + // Vector Widening Floating-Point Fused Multiply-Add Instructions + case RISCV::VFWMACC_VF: + case RISCV::VFWMACC_VV: + case RISCV::VFWNMACC_VF: + case RISCV::VFWNMACC_VV: + case RISCV::VFWMSAC_VF: + case RISCV::VFWMSAC_VV: + case RISCV::VFWNMSAC_VF: + case RISCV::VFWNMSAC_VV: + // Vector Widening Floating-Point Add/Subtract Instructions + // Dest EEW=2*SEW. Source EEW=SEW. + case RISCV::VFWADD_VV: + case RISCV::VFWADD_VF: + case RISCV::VFWSUB_VV: + case RISCV::VFWSUB_VF: + // Vector Widening Floating-Point Multiply + case RISCV::VFWMUL_VF: + case RISCV::VFWMUL_VV: + // Widening Floating-Point/Integer Type-Convert Instructions + case RISCV::VFWCVT_XU_F_V: + case RISCV::VFWCVT_X_F_V: + case RISCV::VFWCVT_RTZ_XU_F_V: + case RISCV::VFWCVT_RTZ_X_F_V: + case RISCV::VFWCVT_F_XU_V: + case RISCV::VFWCVT_F_X_V: + case RISCV::VFWCVT_F_F_V: { unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } - // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL + // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW. case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: @@ -498,29 +559,31 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWADD_WV: case RISCV::VWADD_WX: case RISCV::VWSUB_WV: - case RISCV::VWSUB_WX: { + case RISCV::VWSUB_WX: + // Vector Widening Floating-Point Add/Subtract Instructions + case RISCV::VFWADD_WF: + case RISCV::VFWADD_WV: + case RISCV::VFWSUB_WF: + case RISCV::VFWSUB_WV: { bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsMODef || IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: - return getIntegerExtensionOperandInfo(2, MI, MO); + return getIntegerExtensionOperandEEW(2, MI, MO); case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: - return getIntegerExtensionOperandInfo(4, MI, MO); + return getIntegerExtensionOperandEEW(4, MI, MO); case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: - return getIntegerExtensionOperandInfo(8, MI, MO); + return getIntegerExtensionOperandEEW(8, MI, MO); // Vector Narrowing Integer Right Shift Instructions - // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has - // EEW=SEW EMUL=LMUL. + // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: @@ -528,19 +591,26 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Narrowing Fixed-Point Clip Instructions - // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL + // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. case RISCV::VNCLIPU_WI: case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: case RISCV::VNCLIP_WI: case RISCV::VNCLIP_WV: - case RISCV::VNCLIP_WX: { + case RISCV::VNCLIP_WX: + // Narrowing Floating-Point/Integer Type-Convert Instructions + case RISCV::VFNCVT_XU_F_W: + case RISCV::VFNCVT_X_F_W: + case RISCV::VFNCVT_RTZ_XU_F_W: + case RISCV::VFNCVT_RTZ_X_F_W: + case RISCV::VFNCVT_F_XU_W: + case RISCV::VFNCVT_F_X_W: + case RISCV::VFNCVT_F_F_W: + case RISCV::VFNCVT_ROD_F_F_W: { bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - RISCVII::VLMUL EMUL = - TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; - return OperandInfo(EMUL, Log2EEW); + return Log2EEW; } // Vector Mask Instructions @@ -548,7 +618,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // vmsbf.m set-before-first mask bit // vmsif.m set-including-first mask bit // vmsof.m set-only-first mask bit - // EEW=1 and EMUL=(EEW/SEW)*LMUL + // EEW=1 // We handle the cases when operand is a v0 mask operand above the switch, // but these instructions may use non-v0 mask operands and need to be handled // specifically. @@ -563,20 +633,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSBF_M: case RISCV::VMSIF_M: case RISCV::VMSOF_M: { - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; } // Vector Iota Instruction - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is not handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled + // before this switch. case RISCV::VIOTA_M: { if (IsMODef || MO.getOperandNo() == 1) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return MILog2SEW; + return 0; } // Vector Integer Compare Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMSEQ_VI: case RISCV::VMSEQ_VV: case RISCV::VMSEQ_VX: @@ -598,29 +668,87 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSGT_VI: case RISCV::VMSGT_VX: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. Mask - // source operand handled above this switch. + // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. case RISCV::VMADC_VIM: case RISCV::VMADC_VVM: case RISCV::VMADC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMADC_VV: case RISCV::VMADC_VI: case RISCV::VMADC_VX: case RISCV::VMSBC_VV: - case RISCV::VMSBC_VX: { + case RISCV::VMSBC_VX: + // 13.13. Vector Floating-Point Compare Instructions + // Dest EEW=1. Source EEW=SEW + case RISCV::VMFEQ_VF: + case RISCV::VMFEQ_VV: + case RISCV::VMFNE_VF: + case RISCV::VMFNE_VV: + case RISCV::VMFLT_VF: + case RISCV::VMFLT_VV: + case RISCV::VMFLE_VF: + case RISCV::VMFLE_VV: + case RISCV::VMFGT_VF: + case RISCV::VMFGE_VF: { if (IsMODef) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); - return OperandInfo(MIVLMul, MILog2SEW); + return 0; + return MILog2SEW; + } + + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: { + return MILog2SEW; } default: - return {}; + return std::nullopt; } } +static std::optional<OperandInfo> +getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { + const MachineInstr &MI = *MO.getParent(); + const RISCVVPseudosTable::PseudoInfo *RVV = + RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); + assert(RVV && "Could not find MI in PseudoTable"); + + std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI); + if (!Log2EEW) + return std::nullopt; + + switch (RVV->BaseInstr) { + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + // The Dest and VS1 only read element 0 of the vector register. Return just + // the EEW for these. + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: + if (MO.getOperandNo() != 2) + return OperandInfo(*Log2EEW); + break; + }; + + // All others have EMUL=EEW/SEW*LMUL + return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), + *Log2EEW); +} + /// Return true if this optimization should consider MI for VL reduction. This /// white-list approach simplifies this optimization for instructions that may /// have more complex semantics with relation to how it uses VL. @@ -632,6 +760,32 @@ static bool isSupportedInstr(const MachineInstr &MI) { return false; switch (RVV->BaseInstr) { + // Vector Unit-Stride Instructions + // Vector Strided Instructions + case RISCV::VLM_V: + case RISCV::VLE8_V: + case RISCV::VLSE8_V: + case RISCV::VLE16_V: + case RISCV::VLSE16_V: + case RISCV::VLE32_V: + case RISCV::VLSE32_V: + case RISCV::VLE64_V: + case RISCV::VLSE64_V: + // Vector Indexed Instructions + case RISCV::VLUXEI8_V: + case RISCV::VLOXEI8_V: + case RISCV::VLUXEI16_V: + case RISCV::VLOXEI16_V: + case RISCV::VLUXEI32_V: + case RISCV::VLOXEI32_V: + case RISCV::VLUXEI64_V: + case RISCV::VLOXEI64_V: { + for (const MachineMemOperand *MMO : MI.memoperands()) + if (MMO->isVolatile()) + return false; + return true; + } + // Vector Single-Width Integer Add and Subtract case RISCV::VADD_VI: case RISCV::VADD_VV: @@ -801,6 +955,30 @@ static bool isSupportedInstr(const MachineInstr &MI) { case RISCV::VMSOF_M: case RISCV::VIOTA_M: case RISCV::VID_V: + // Single-Width Floating-Point/Integer Type-Convert Instructions + case RISCV::VFCVT_XU_F_V: + case RISCV::VFCVT_X_F_V: + case RISCV::VFCVT_RTZ_XU_F_V: + case RISCV::VFCVT_RTZ_X_F_V: + case RISCV::VFCVT_F_XU_V: + case RISCV::VFCVT_F_X_V: + // Widening Floating-Point/Integer Type-Convert Instructions + case RISCV::VFWCVT_XU_F_V: + case RISCV::VFWCVT_X_F_V: + case RISCV::VFWCVT_RTZ_XU_F_V: + case RISCV::VFWCVT_RTZ_X_F_V: + case RISCV::VFWCVT_F_XU_V: + case RISCV::VFWCVT_F_X_V: + case RISCV::VFWCVT_F_F_V: + // Narrowing Floating-Point/Integer Type-Convert Instructions + case RISCV::VFNCVT_XU_F_W: + case RISCV::VFNCVT_X_F_W: + case RISCV::VFNCVT_RTZ_XU_F_W: + case RISCV::VFNCVT_RTZ_X_F_W: + case RISCV::VFNCVT_F_XU_W: + case RISCV::VFNCVT_F_X_W: + case RISCV::VFNCVT_F_F_W: + case RISCV::VFNCVT_ROD_F_F_W: return true; } @@ -835,6 +1013,9 @@ static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) { case RISCV::VFWREDOSUM_VS: case RISCV::VFWREDUSUM_VS: return MO.getOperandNo() == 3; + case RISCV::VMV_X_S: + case RISCV::VFMV_F_S: + return MO.getOperandNo() == 1; default: return false; } @@ -904,6 +1085,11 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return false; } + if (MI.mayRaiseFPException()) { + LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n"); + return false; + } + // Some instructions that produce vectors have semantics that make it more // difficult to determine whether the VL can be reduced. For example, some // instructions, such as reductions, may write lanes past VL to a scalar @@ -925,79 +1111,103 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return true; } -bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, - MachineInstr &MI) { +std::optional<MachineOperand> +RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) { + const MachineInstr &UserMI = *UserOp.getParent(); + const MCInstrDesc &Desc = UserMI.getDesc(); + + if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" + " use VLMAX\n"); + return std::nullopt; + } + + // Instructions like reductions may use a vector register as a scalar + // register. In this case, we should treat it as only reading the first lane. + if (isVectorOpUsedAsScalarOp(UserOp)) { + [[maybe_unused]] Register R = UserOp.getReg(); + [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); + assert(RISCV::VRRegClass.hasSubClassEq(RC) && + "Expect LMUL 1 register class for vector as scalar operands!"); + LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); + + return MachineOperand::CreateImm(1); + } + + unsigned VLOpNum = RISCVII::getVLOpNum(Desc); + const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); + // Looking for an immediate or a register VL that isn't X0. + assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && + "Did not expect X0 VL"); + return VLOp; +} + +std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) { // FIXME: Avoid visiting each user for each time we visit something on the // worklist, combined with an extra visit from the outer loop. Restructure // along lines of an instcombine style worklist which integrates the outer // pass. - bool CanReduceVL = true; + std::optional<MachineOperand> CommonVL; for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { const MachineInstr &UserMI = *UserOp.getParent(); LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); - - // Instructions like reductions may use a vector register as a scalar - // register. In this case, we should treat it like a scalar register which - // does not impact the decision on whether to optimize VL. - // TODO: Treat it like a scalar register instead of bailing out. - if (isVectorOpUsedAsScalarOp(UserOp)) { - CanReduceVL = false; - break; - } - if (mayReadPastVL(UserMI)) { LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); - CanReduceVL = false; - break; + return std::nullopt; } // Tied operands might pass through. if (UserOp.isTied()) { LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); - CanReduceVL = false; - break; - } - - const MCInstrDesc &Desc = UserMI.getDesc(); - if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { - LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" - " use VLMAX\n"); - CanReduceVL = false; - break; + return std::nullopt; } - unsigned VLOpNum = RISCVII::getVLOpNum(Desc); - const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); - - // Looking for an immediate or a register VL that isn't X0. - assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && - "Did not expect X0 VL"); + auto VLOp = getMinimumVLForUser(UserOp); + if (!VLOp) + return std::nullopt; // Use the largest VL among all the users. If we cannot determine this // statically, then we cannot optimize the VL. - if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) { - CommonVL = &VLOp; + if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { + CommonVL = *VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); - } else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) { + } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); - CanReduceVL = false; - break; + return std::nullopt; + } + + if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); + return std::nullopt; + } + + std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI); + std::optional<OperandInfo> ProducerInfo = + getOperandInfo(MI.getOperand(0), MRI); + if (!ConsumerInfo || !ProducerInfo) { + LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); + LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); + LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); + return std::nullopt; } - // The SEW and LMUL of destination and source registers need to match. - OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI); - OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI); - if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || - !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { - LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " - "information for EMUL or EEW.\n"); + // If the operand is used as a scalar operand, then the EEW must be + // compatible. Otherwise, the EMUL *and* EEW must be compatible. + bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); + if ((IsVectorOpUsedAsScalarOp && + !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) || + (!IsVectorOpUsedAsScalarOp && + !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) { + LLVM_DEBUG( + dbgs() + << " Abort due to incompatible information for EMUL or EEW.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); - CanReduceVL = false; - break; + return std::nullopt; } } - return CanReduceVL; + + return CommonVL; } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { @@ -1009,12 +1219,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { MachineInstr &MI = *Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); - const MachineOperand *CommonVL = nullptr; - bool CanReduceVL = true; - if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) - CanReduceVL = checkUsers(CommonVL, MI); + if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) + continue; - if (!CanReduceVL || !CommonVL) + auto CommonVL = checkUsers(MI); + if (!CommonVL) continue; assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && |