aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
commite2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch)
treeae0b02a8491b969a1cee94ea16ffe42c559143c5 /llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
parentfa04eb4af95c1ca7377279728cb004bcd2324d01 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
downloadllvm-users/chapuni/cov/single/switch.zip
llvm-users/chapuni/cov/single/switch.tar.gz
llvm-users/chapuni/cov/single/switch.tar.bz2
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switchusers/chapuni/cov/single/switch
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp569
1 files changed, 389 insertions, 180 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 4e3212c..ad61a77 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -50,7 +50,10 @@ public:
StringRef getPassName() const override { return PASS_NAME; }
private:
- bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);
+ std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
+ /// Returns the largest common VL MachineOperand that may be used to optimize
+ /// MI. Returns std::nullopt if it failed to find a suitable VL.
+ std::optional<MachineOperand> checkUsers(MachineInstr &MI);
bool tryReduceVL(MachineInstr &MI);
bool isCandidate(const MachineInstr &MI) const;
};
@@ -76,11 +79,6 @@ static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
/// Represents the EMUL and EEW of a MachineOperand.
struct OperandInfo {
- enum class State {
- Unknown,
- Known,
- } S;
-
// Represent as 1,2,4,8, ... and fractional indicator. This is because
// EMUL can take on values that don't map to RISCVII::VLMUL values exactly.
// For example, a mask operand can have an EMUL less than MF8.
@@ -89,34 +87,32 @@ struct OperandInfo {
unsigned Log2EEW;
OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW)
- : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {
- }
+ : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {}
OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
- : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {}
+ : EMUL(EMUL), Log2EEW(Log2EEW) {}
- OperandInfo() : S(State::Unknown) {}
+ OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}
- bool isUnknown() const { return S == State::Unknown; }
- bool isKnown() const { return S == State::Known; }
+ OperandInfo() = delete;
static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
- assert(A.isKnown() && B.isKnown() && "Both operands must be known");
-
return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first &&
A.EMUL->second == B.EMUL->second;
}
+ static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
+ return A.Log2EEW == B.Log2EEW;
+ }
+
void print(raw_ostream &OS) const {
- if (isUnknown()) {
- OS << "Unknown";
- return;
- }
- assert(EMUL && "Expected EMUL to have value");
- OS << "EMUL: m";
- if (EMUL->second)
- OS << "f";
- OS << EMUL->first;
+ if (EMUL) {
+ OS << "EMUL: m";
+ if (EMUL->second)
+ OS << "f";
+ OS << EMUL->first;
+ } else
+ OS << "EMUL: unknown\n";
OS << ", EEW: " << (1 << Log2EEW);
}
};
@@ -127,30 +123,18 @@ static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
return OS;
}
-namespace llvm {
-namespace RISCVVType {
-/// Return the RISCVII::VLMUL that is two times VLMul.
-/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8.
-static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) {
- switch (VLMul) {
- case RISCVII::VLMUL::LMUL_F8:
- return RISCVII::VLMUL::LMUL_F4;
- case RISCVII::VLMUL::LMUL_F4:
- return RISCVII::VLMUL::LMUL_F2;
- case RISCVII::VLMUL::LMUL_F2:
- return RISCVII::VLMUL::LMUL_1;
- case RISCVII::VLMUL::LMUL_1:
- return RISCVII::VLMUL::LMUL_2;
- case RISCVII::VLMUL::LMUL_2:
- return RISCVII::VLMUL::LMUL_4;
- case RISCVII::VLMUL::LMUL_4:
- return RISCVII::VLMUL::LMUL_8;
- case RISCVII::VLMUL::LMUL_8:
- default:
- llvm_unreachable("Could not multiply VLMul by 2");
- }
+LLVM_ATTRIBUTE_UNUSED
+static raw_ostream &operator<<(raw_ostream &OS,
+ const std::optional<OperandInfo> &OI) {
+ if (OI)
+ OI->print(OS);
+ else
+ OS << "nullopt";
+ return OS;
}
+namespace llvm {
+namespace RISCVVType {
/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
/// SEW are from the TSFlags of MI.
static std::pair<unsigned, bool>
@@ -180,24 +164,22 @@ getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
} // end namespace RISCVVType
} // end namespace llvm
-/// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
-/// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI.
-static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor,
- const MachineInstr &MI,
- const MachineOperand &MO) {
- RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
+/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
+/// SEW comes from TSFlags of MI.
+static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
+ const MachineInstr &MI,
+ const MachineOperand &MO) {
unsigned MILog2SEW =
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
if (MO.getOperandNo() == 0)
- return OperandInfo(MIVLMul, MILog2SEW);
+ return MILog2SEW;
unsigned MISEW = 1 << MILog2SEW;
unsigned EEW = MISEW / Factor;
unsigned Log2EEW = Log2_32(EEW);
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI),
- Log2EEW);
+ return Log2EEW;
}
/// Check whether MO is a mask operand of MI.
@@ -211,18 +193,15 @@ static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
}
-/// Return the OperandInfo for MO.
-static OperandInfo getOperandInfo(const MachineOperand &MO,
- const MachineRegisterInfo *MRI) {
+static std::optional<unsigned>
+getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
const MachineInstr &MI = *MO.getParent();
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
assert(RVV && "Could not find MI in PseudoTable");
- // MI has a VLMUL and SEW associated with it. The RVV specification defines
- // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and
- // MI.SEW.
- RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
+ // MI has a SEW associated with it. The RVV specification defines
+ // the EEW of each operand and definition in relation to MI.SEW.
unsigned MILog2SEW =
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
@@ -233,13 +212,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
// since they must preserve the entire register content.
if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() &&
(MO.getReg() != RISCV::NoRegister))
- return {};
+ return std::nullopt;
bool IsMODef = MO.getOperandNo() == 0;
- // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL
+ // All mask operands have EEW=1
if (isMaskOperand(MI, MO, MRI))
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+ return 0;
// switch against BaseInstr to reduce number of cases that need to be
// considered.
@@ -256,55 +235,65 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
// Vector Loads and Stores
// Vector Unit-Stride Instructions
// Vector Strided Instructions
- /// Dest EEW encoded in the instruction and EMUL=(EEW/SEW)*LMUL
+ /// Dest EEW encoded in the instruction
+ case RISCV::VLM_V:
+ case RISCV::VSM_V:
+ return 0;
+ case RISCV::VLE8_V:
case RISCV::VSE8_V:
+ case RISCV::VLSE8_V:
case RISCV::VSSE8_V:
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3);
+ return 3;
+ case RISCV::VLE16_V:
case RISCV::VSE16_V:
+ case RISCV::VLSE16_V:
case RISCV::VSSE16_V:
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4);
+ return 4;
+ case RISCV::VLE32_V:
case RISCV::VSE32_V:
+ case RISCV::VLSE32_V:
case RISCV::VSSE32_V:
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5);
+ return 5;
+ case RISCV::VLE64_V:
case RISCV::VSE64_V:
+ case RISCV::VLSE64_V:
case RISCV::VSSE64_V:
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6);
+ return 6;
// Vector Indexed Instructions
// vs(o|u)xei<eew>.v
- // Dest/Data (operand 0) EEW=SEW, EMUL=LMUL. Source EEW=<eew> and
- // EMUL=(EEW/SEW)*LMUL.
+ // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>.
case RISCV::VLUXEI8_V:
case RISCV::VLOXEI8_V:
case RISCV::VSUXEI8_V:
case RISCV::VSOXEI8_V: {
if (MO.getOperandNo() == 0)
- return OperandInfo(MIVLMul, MILog2SEW);
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3);
+ return MILog2SEW;
+ return 3;
}
case RISCV::VLUXEI16_V:
case RISCV::VLOXEI16_V:
case RISCV::VSUXEI16_V:
case RISCV::VSOXEI16_V: {
if (MO.getOperandNo() == 0)
- return OperandInfo(MIVLMul, MILog2SEW);
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4);
+ return MILog2SEW;
+ return 4;
}
case RISCV::VLUXEI32_V:
case RISCV::VLOXEI32_V:
case RISCV::VSUXEI32_V:
case RISCV::VSOXEI32_V: {
if (MO.getOperandNo() == 0)
- return OperandInfo(MIVLMul, MILog2SEW);
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5);
+ return MILog2SEW;
+ return 5;
}
case RISCV::VLUXEI64_V:
case RISCV::VLOXEI64_V:
case RISCV::VSUXEI64_V:
case RISCV::VSOXEI64_V: {
if (MO.getOperandNo() == 0)
- return OperandInfo(MIVLMul, MILog2SEW);
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6);
+ return MILog2SEW;
+ return 6;
}
// Vector Integer Arithmetic Instructions
@@ -318,7 +307,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VRSUB_VX:
// Vector Bitwise Logical Instructions
// Vector Single-Width Shift Instructions
- // EEW=SEW. EMUL=LMUL.
+ // EEW=SEW.
case RISCV::VAND_VI:
case RISCV::VAND_VV:
case RISCV::VAND_VX:
@@ -338,7 +327,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VSRA_VV:
case RISCV::VSRA_VX:
// Vector Integer Min/Max Instructions
- // EEW=SEW. EMUL=LMUL.
+ // EEW=SEW.
case RISCV::VMINU_VV:
case RISCV::VMINU_VX:
case RISCV::VMIN_VV:
@@ -348,7 +337,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VMAX_VV:
case RISCV::VMAX_VX:
// Vector Single-Width Integer Multiply Instructions
- // Source and Dest EEW=SEW and EMUL=LMUL.
+ // Source and Dest EEW=SEW.
case RISCV::VMUL_VV:
case RISCV::VMUL_VX:
case RISCV::VMULH_VV:
@@ -358,7 +347,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VMULHSU_VV:
case RISCV::VMULHSU_VX:
// Vector Integer Divide Instructions
- // EEW=SEW. EMUL=LMUL.
+ // EEW=SEW.
case RISCV::VDIVU_VV:
case RISCV::VDIVU_VX:
case RISCV::VDIV_VV:
@@ -368,7 +357,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VREM_VV:
case RISCV::VREM_VX:
// Vector Single-Width Integer Multiply-Add Instructions
- // EEW=SEW. EMUL=LMUL.
+ // EEW=SEW.
case RISCV::VMACC_VV:
case RISCV::VMACC_VX:
case RISCV::VNMSAC_VV:
@@ -379,8 +368,8 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VNMSUB_VX:
// Vector Integer Merge Instructions
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
- // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
- // (EEW/SEW)*LMUL. Mask operand is handled before this switch.
+ // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
+ // before this switch.
case RISCV::VMERGE_VIM:
case RISCV::VMERGE_VVM:
case RISCV::VMERGE_VXM:
@@ -393,7 +382,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
// Vector Fixed-Point Arithmetic Instructions
// Vector Single-Width Saturating Add and Subtract
// Vector Single-Width Averaging Add and Subtract
- // EEW=SEW. EMUL=LMUL.
+ // EEW=SEW.
case RISCV::VMV_V_I:
case RISCV::VMV_V_V:
case RISCV::VMV_V_X:
@@ -415,8 +404,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VASUBU_VX:
case RISCV::VASUB_VV:
case RISCV::VASUB_VX:
+ // Vector Single-Width Fractional Multiply with Rounding and Saturation
+ // EEW=SEW. The instruction produces 2*SEW product internally but
+ // saturates to fit into SEW bits.
+ case RISCV::VSMUL_VV:
+ case RISCV::VSMUL_VX:
// Vector Single-Width Scaling Shift Instructions
- // EEW=SEW. EMUL=LMUL.
+ // EEW=SEW.
case RISCV::VSSRL_VI:
case RISCV::VSSRL_VV:
case RISCV::VSSRL_VX:
@@ -426,13 +420,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
// Vector Permutation Instructions
// Integer Scalar Move Instructions
// Floating-Point Scalar Move Instructions
- // EMUL=LMUL. EEW=SEW.
+ // EEW=SEW.
case RISCV::VMV_X_S:
case RISCV::VMV_S_X:
case RISCV::VFMV_F_S:
case RISCV::VFMV_S_F:
// Vector Slide Instructions
- // EMUL=LMUL. EEW=SEW.
+ // EEW=SEW.
case RISCV::VSLIDEUP_VI:
case RISCV::VSLIDEUP_VX:
case RISCV::VSLIDEDOWN_VI:
@@ -442,19 +436,62 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VSLIDE1DOWN_VX:
case RISCV::VFSLIDE1DOWN_VF:
// Vector Register Gather Instructions
- // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1.
+ // EEW=SEW. For mask operand, EEW=1.
case RISCV::VRGATHER_VI:
case RISCV::VRGATHER_VV:
case RISCV::VRGATHER_VX:
// Vector Compress Instruction
- // EMUL=LMUL. EEW=SEW.
+ // EEW=SEW.
case RISCV::VCOMPRESS_VM:
// Vector Element Index Instruction
case RISCV::VID_V:
- return OperandInfo(MIVLMul, MILog2SEW);
+ // Vector Single-Width Floating-Point Add/Subtract Instructions
+ case RISCV::VFADD_VF:
+ case RISCV::VFADD_VV:
+ case RISCV::VFSUB_VF:
+ case RISCV::VFSUB_VV:
+ case RISCV::VFRSUB_VF:
+ // Vector Single-Width Floating-Point Multiply/Divide Instructions
+ case RISCV::VFMUL_VF:
+ case RISCV::VFMUL_VV:
+ case RISCV::VFDIV_VF:
+ case RISCV::VFDIV_VV:
+ case RISCV::VFRDIV_VF:
+ // Vector Floating-Point Square-Root Instruction
+ case RISCV::VFSQRT_V:
+ // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
+ case RISCV::VFRSQRT7_V:
+ // Vector Floating-Point Reciprocal Estimate Instruction
+ case RISCV::VFREC7_V:
+ // Vector Floating-Point MIN/MAX Instructions
+ case RISCV::VFMIN_VF:
+ case RISCV::VFMIN_VV:
+ case RISCV::VFMAX_VF:
+ case RISCV::VFMAX_VV:
+ // Vector Floating-Point Sign-Injection Instructions
+ case RISCV::VFSGNJ_VF:
+ case RISCV::VFSGNJ_VV:
+ case RISCV::VFSGNJN_VV:
+ case RISCV::VFSGNJN_VF:
+ case RISCV::VFSGNJX_VF:
+ case RISCV::VFSGNJX_VV:
+ // Vector Floating-Point Classify Instruction
+ case RISCV::VFCLASS_V:
+ // Vector Floating-Point Move Instruction
+ case RISCV::VFMV_V_F:
+ // Single-Width Floating-Point/Integer Type-Convert Instructions
+ case RISCV::VFCVT_XU_F_V:
+ case RISCV::VFCVT_X_F_V:
+ case RISCV::VFCVT_RTZ_XU_F_V:
+ case RISCV::VFCVT_RTZ_X_F_V:
+ case RISCV::VFCVT_F_XU_V:
+ case RISCV::VFCVT_F_X_V:
+ // Vector Floating-Point Merge Instruction
+ case RISCV::VFMERGE_VFM:
+ return MILog2SEW;
// Vector Widening Integer Add/Subtract
- // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL.
+ // Def uses EEW=2*SEW . Operands use EEW=SEW.
case RISCV::VWADDU_VV:
case RISCV::VWADDU_VX:
case RISCV::VWSUBU_VV:
@@ -465,7 +502,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VWSUB_VX:
case RISCV::VWSLL_VI:
// Vector Widening Integer Multiply Instructions
- // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW.
+ // Destination EEW=2*SEW. Source EEW=SEW.
case RISCV::VWMUL_VV:
case RISCV::VWMUL_VX:
case RISCV::VWMULSU_VV:
@@ -473,7 +510,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VWMULU_VV:
case RISCV::VWMULU_VX:
// Vector Widening Integer Multiply-Add Instructions
- // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL.
+ // Destination EEW=2*SEW. Source EEW=SEW.
// A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
// is then added to the 2*SEW-bit Dest. These instructions never have a
// passthru operand.
@@ -483,14 +520,38 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VWMACC_VX:
case RISCV::VWMACCSU_VV:
case RISCV::VWMACCSU_VX:
- case RISCV::VWMACCUS_VX: {
+ case RISCV::VWMACCUS_VX:
+ // Vector Widening Floating-Point Fused Multiply-Add Instructions
+ case RISCV::VFWMACC_VF:
+ case RISCV::VFWMACC_VV:
+ case RISCV::VFWNMACC_VF:
+ case RISCV::VFWNMACC_VV:
+ case RISCV::VFWMSAC_VF:
+ case RISCV::VFWMSAC_VV:
+ case RISCV::VFWNMSAC_VF:
+ case RISCV::VFWNMSAC_VV:
+ // Vector Widening Floating-Point Add/Subtract Instructions
+ // Dest EEW=2*SEW. Source EEW=SEW.
+ case RISCV::VFWADD_VV:
+ case RISCV::VFWADD_VF:
+ case RISCV::VFWSUB_VV:
+ case RISCV::VFWSUB_VF:
+ // Vector Widening Floating-Point Multiply
+ case RISCV::VFWMUL_VF:
+ case RISCV::VFWMUL_VV:
+ // Widening Floating-Point/Integer Type-Convert Instructions
+ case RISCV::VFWCVT_XU_F_V:
+ case RISCV::VFWCVT_X_F_V:
+ case RISCV::VFWCVT_RTZ_XU_F_V:
+ case RISCV::VFWCVT_RTZ_X_F_V:
+ case RISCV::VFWCVT_F_XU_V:
+ case RISCV::VFWCVT_F_X_V:
+ case RISCV::VFWCVT_F_F_V: {
unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW;
- RISCVII::VLMUL EMUL =
- IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
- return OperandInfo(EMUL, Log2EEW);
+ return Log2EEW;
}
- // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL
+ // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
case RISCV::VWADDU_WV:
case RISCV::VWADDU_WX:
case RISCV::VWSUBU_WV:
@@ -498,29 +559,31 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VWADD_WV:
case RISCV::VWADD_WX:
case RISCV::VWSUB_WV:
- case RISCV::VWSUB_WX: {
+ case RISCV::VWSUB_WX:
+ // Vector Widening Floating-Point Add/Subtract Instructions
+ case RISCV::VFWADD_WF:
+ case RISCV::VFWADD_WV:
+ case RISCV::VFWSUB_WF:
+ case RISCV::VFWSUB_WV: {
bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
bool TwoTimes = IsMODef || IsOp1;
unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
- RISCVII::VLMUL EMUL =
- TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
- return OperandInfo(EMUL, Log2EEW);
+ return Log2EEW;
}
// Vector Integer Extension
case RISCV::VZEXT_VF2:
case RISCV::VSEXT_VF2:
- return getIntegerExtensionOperandInfo(2, MI, MO);
+ return getIntegerExtensionOperandEEW(2, MI, MO);
case RISCV::VZEXT_VF4:
case RISCV::VSEXT_VF4:
- return getIntegerExtensionOperandInfo(4, MI, MO);
+ return getIntegerExtensionOperandEEW(4, MI, MO);
case RISCV::VZEXT_VF8:
case RISCV::VSEXT_VF8:
- return getIntegerExtensionOperandInfo(8, MI, MO);
+ return getIntegerExtensionOperandEEW(8, MI, MO);
// Vector Narrowing Integer Right Shift Instructions
- // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has
- // EEW=SEW EMUL=LMUL.
+ // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
case RISCV::VNSRL_WX:
case RISCV::VNSRL_WI:
case RISCV::VNSRL_WV:
@@ -528,19 +591,26 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VNSRA_WV:
case RISCV::VNSRA_WX:
// Vector Narrowing Fixed-Point Clip Instructions
- // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL
+ // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
case RISCV::VNCLIPU_WI:
case RISCV::VNCLIPU_WV:
case RISCV::VNCLIPU_WX:
case RISCV::VNCLIP_WI:
case RISCV::VNCLIP_WV:
- case RISCV::VNCLIP_WX: {
+ case RISCV::VNCLIP_WX:
+ // Narrowing Floating-Point/Integer Type-Convert Instructions
+ case RISCV::VFNCVT_XU_F_W:
+ case RISCV::VFNCVT_X_F_W:
+ case RISCV::VFNCVT_RTZ_XU_F_W:
+ case RISCV::VFNCVT_RTZ_X_F_W:
+ case RISCV::VFNCVT_F_XU_W:
+ case RISCV::VFNCVT_F_X_W:
+ case RISCV::VFNCVT_F_F_W:
+ case RISCV::VFNCVT_ROD_F_F_W: {
bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
bool TwoTimes = IsOp1;
unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
- RISCVII::VLMUL EMUL =
- TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
- return OperandInfo(EMUL, Log2EEW);
+ return Log2EEW;
}
// Vector Mask Instructions
@@ -548,7 +618,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
// vmsbf.m set-before-first mask bit
// vmsif.m set-including-first mask bit
// vmsof.m set-only-first mask bit
- // EEW=1 and EMUL=(EEW/SEW)*LMUL
+ // EEW=1
// We handle the cases when operand is a v0 mask operand above the switch,
// but these instructions may use non-v0 mask operands and need to be handled
// specifically.
@@ -563,20 +633,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VMSBF_M:
case RISCV::VMSIF_M:
case RISCV::VMSOF_M: {
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+ return 0;
}
// Vector Iota Instruction
- // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
- // (EEW/SEW)*LMUL. Mask operand is not handled before this switch.
+ // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
+ // before this switch.
case RISCV::VIOTA_M: {
if (IsMODef || MO.getOperandNo() == 1)
- return OperandInfo(MIVLMul, MILog2SEW);
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+ return MILog2SEW;
+ return 0;
}
// Vector Integer Compare Instructions
- // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL.
+ // Dest EEW=1. Source EEW=SEW.
case RISCV::VMSEQ_VI:
case RISCV::VMSEQ_VV:
case RISCV::VMSEQ_VX:
@@ -598,29 +668,87 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
case RISCV::VMSGT_VI:
case RISCV::VMSGT_VX:
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
- // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. Mask
- // source operand handled above this switch.
+ // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
case RISCV::VMADC_VIM:
case RISCV::VMADC_VVM:
case RISCV::VMADC_VXM:
case RISCV::VMSBC_VVM:
case RISCV::VMSBC_VXM:
- // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL.
+ // Dest EEW=1. Source EEW=SEW.
case RISCV::VMADC_VV:
case RISCV::VMADC_VI:
case RISCV::VMADC_VX:
case RISCV::VMSBC_VV:
- case RISCV::VMSBC_VX: {
+ case RISCV::VMSBC_VX:
+ // 13.13. Vector Floating-Point Compare Instructions
+ // Dest EEW=1. Source EEW=SEW
+ case RISCV::VMFEQ_VF:
+ case RISCV::VMFEQ_VV:
+ case RISCV::VMFNE_VF:
+ case RISCV::VMFNE_VV:
+ case RISCV::VMFLT_VF:
+ case RISCV::VMFLT_VV:
+ case RISCV::VMFLE_VF:
+ case RISCV::VMFLE_VV:
+ case RISCV::VMFGT_VF:
+ case RISCV::VMFGE_VF: {
if (IsMODef)
- return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
- return OperandInfo(MIVLMul, MILog2SEW);
+ return 0;
+ return MILog2SEW;
+ }
+
+ // Vector Reduction Operations
+ // Vector Single-Width Integer Reduction Instructions
+ case RISCV::VREDAND_VS:
+ case RISCV::VREDMAX_VS:
+ case RISCV::VREDMAXU_VS:
+ case RISCV::VREDMIN_VS:
+ case RISCV::VREDMINU_VS:
+ case RISCV::VREDOR_VS:
+ case RISCV::VREDSUM_VS:
+ case RISCV::VREDXOR_VS: {
+ return MILog2SEW;
}
default:
- return {};
+ return std::nullopt;
}
}
+static std::optional<OperandInfo>
+getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
+ const MachineInstr &MI = *MO.getParent();
+ const RISCVVPseudosTable::PseudoInfo *RVV =
+ RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
+ assert(RVV && "Could not find MI in PseudoTable");
+
+ std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
+ if (!Log2EEW)
+ return std::nullopt;
+
+ switch (RVV->BaseInstr) {
+ // Vector Reduction Operations
+ // Vector Single-Width Integer Reduction Instructions
+ // The Dest and VS1 only read element 0 of the vector register. Return just
+ // the EEW for these.
+ case RISCV::VREDAND_VS:
+ case RISCV::VREDMAX_VS:
+ case RISCV::VREDMAXU_VS:
+ case RISCV::VREDMIN_VS:
+ case RISCV::VREDMINU_VS:
+ case RISCV::VREDOR_VS:
+ case RISCV::VREDSUM_VS:
+ case RISCV::VREDXOR_VS:
+ if (MO.getOperandNo() != 2)
+ return OperandInfo(*Log2EEW);
+ break;
+ };
+
+ // All others have EMUL=EEW/SEW*LMUL
+ return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI),
+ *Log2EEW);
+}
+
/// Return true if this optimization should consider MI for VL reduction. This
/// white-list approach simplifies this optimization for instructions that may
/// have more complex semantics with relation to how it uses VL.
@@ -632,6 +760,32 @@ static bool isSupportedInstr(const MachineInstr &MI) {
return false;
switch (RVV->BaseInstr) {
+ // Vector Unit-Stride Instructions
+ // Vector Strided Instructions
+ case RISCV::VLM_V:
+ case RISCV::VLE8_V:
+ case RISCV::VLSE8_V:
+ case RISCV::VLE16_V:
+ case RISCV::VLSE16_V:
+ case RISCV::VLE32_V:
+ case RISCV::VLSE32_V:
+ case RISCV::VLE64_V:
+ case RISCV::VLSE64_V:
+ // Vector Indexed Instructions
+ case RISCV::VLUXEI8_V:
+ case RISCV::VLOXEI8_V:
+ case RISCV::VLUXEI16_V:
+ case RISCV::VLOXEI16_V:
+ case RISCV::VLUXEI32_V:
+ case RISCV::VLOXEI32_V:
+ case RISCV::VLUXEI64_V:
+ case RISCV::VLOXEI64_V: {
+ for (const MachineMemOperand *MMO : MI.memoperands())
+ if (MMO->isVolatile())
+ return false;
+ return true;
+ }
+
// Vector Single-Width Integer Add and Subtract
case RISCV::VADD_VI:
case RISCV::VADD_VV:
@@ -801,6 +955,30 @@ static bool isSupportedInstr(const MachineInstr &MI) {
case RISCV::VMSOF_M:
case RISCV::VIOTA_M:
case RISCV::VID_V:
+ // Single-Width Floating-Point/Integer Type-Convert Instructions
+ case RISCV::VFCVT_XU_F_V:
+ case RISCV::VFCVT_X_F_V:
+ case RISCV::VFCVT_RTZ_XU_F_V:
+ case RISCV::VFCVT_RTZ_X_F_V:
+ case RISCV::VFCVT_F_XU_V:
+ case RISCV::VFCVT_F_X_V:
+ // Widening Floating-Point/Integer Type-Convert Instructions
+ case RISCV::VFWCVT_XU_F_V:
+ case RISCV::VFWCVT_X_F_V:
+ case RISCV::VFWCVT_RTZ_XU_F_V:
+ case RISCV::VFWCVT_RTZ_X_F_V:
+ case RISCV::VFWCVT_F_XU_V:
+ case RISCV::VFWCVT_F_X_V:
+ case RISCV::VFWCVT_F_F_V:
+ // Narrowing Floating-Point/Integer Type-Convert Instructions
+ case RISCV::VFNCVT_XU_F_W:
+ case RISCV::VFNCVT_X_F_W:
+ case RISCV::VFNCVT_RTZ_XU_F_W:
+ case RISCV::VFNCVT_RTZ_X_F_W:
+ case RISCV::VFNCVT_F_XU_W:
+ case RISCV::VFNCVT_F_X_W:
+ case RISCV::VFNCVT_F_F_W:
+ case RISCV::VFNCVT_ROD_F_F_W:
return true;
}
@@ -835,6 +1013,9 @@ static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) {
case RISCV::VFWREDOSUM_VS:
case RISCV::VFWREDUSUM_VS:
return MO.getOperandNo() == 3;
+ case RISCV::VMV_X_S:
+ case RISCV::VFMV_F_S:
+ return MO.getOperandNo() == 1;
default:
return false;
}
@@ -904,6 +1085,11 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
return false;
}
+ if (MI.mayRaiseFPException()) {
+ LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
+ return false;
+ }
+
// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
// instructions, such as reductions, may write lanes past VL to a scalar
@@ -925,79 +1111,103 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
return true;
}
-bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
- MachineInstr &MI) {
+std::optional<MachineOperand>
+RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
+ const MachineInstr &UserMI = *UserOp.getParent();
+ const MCInstrDesc &Desc = UserMI.getDesc();
+
+ if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
+ LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
+ " use VLMAX\n");
+ return std::nullopt;
+ }
+
+ // Instructions like reductions may use a vector register as a scalar
+ // register. In this case, we should treat it as only reading the first lane.
+ if (isVectorOpUsedAsScalarOp(UserOp)) {
+ [[maybe_unused]] Register R = UserOp.getReg();
+ [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R);
+ assert(RISCV::VRRegClass.hasSubClassEq(RC) &&
+ "Expect LMUL 1 register class for vector as scalar operands!");
+ LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
+
+ return MachineOperand::CreateImm(1);
+ }
+
+ unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
+ const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
+ // Looking for an immediate or a register VL that isn't X0.
+ assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
+ "Did not expect X0 VL");
+ return VLOp;
+}
+
+std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
// FIXME: Avoid visiting each user for each time we visit something on the
// worklist, combined with an extra visit from the outer loop. Restructure
// along lines of an instcombine style worklist which integrates the outer
// pass.
- bool CanReduceVL = true;
+ std::optional<MachineOperand> CommonVL;
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
const MachineInstr &UserMI = *UserOp.getParent();
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
-
- // Instructions like reductions may use a vector register as a scalar
- // register. In this case, we should treat it like a scalar register which
- // does not impact the decision on whether to optimize VL.
- // TODO: Treat it like a scalar register instead of bailing out.
- if (isVectorOpUsedAsScalarOp(UserOp)) {
- CanReduceVL = false;
- break;
- }
-
if (mayReadPastVL(UserMI)) {
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
- CanReduceVL = false;
- break;
+ return std::nullopt;
}
// Tied operands might pass through.
if (UserOp.isTied()) {
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
- CanReduceVL = false;
- break;
- }
-
- const MCInstrDesc &Desc = UserMI.getDesc();
- if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
- LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that"
- " use VLMAX\n");
- CanReduceVL = false;
- break;
+ return std::nullopt;
}
- unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
- const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
-
- // Looking for an immediate or a register VL that isn't X0.
- assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
- "Did not expect X0 VL");
+ auto VLOp = getMinimumVLForUser(UserOp);
+ if (!VLOp)
+ return std::nullopt;
// Use the largest VL among all the users. If we cannot determine this
// statically, then we cannot optimize the VL.
- if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) {
- CommonVL = &VLOp;
+ if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
+ CommonVL = *VLOp;
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
- } else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) {
+ } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
- CanReduceVL = false;
- break;
+ return std::nullopt;
+ }
+
+ if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
+ LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
+ return std::nullopt;
+ }
+
+ std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
+ std::optional<OperandInfo> ProducerInfo =
+ getOperandInfo(MI.getOperand(0), MRI);
+ if (!ConsumerInfo || !ProducerInfo) {
+ LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
+ LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
+ LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
+ return std::nullopt;
}
- // The SEW and LMUL of destination and source registers need to match.
- OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI);
- OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI);
- if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() ||
- !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) {
- LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown "
- "information for EMUL or EEW.\n");
+ // If the operand is used as a scalar operand, then the EEW must be
+ // compatible. Otherwise, the EMUL *and* EEW must be compatible.
+ bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp);
+ if ((IsVectorOpUsedAsScalarOp &&
+ !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) ||
+ (!IsVectorOpUsedAsScalarOp &&
+ !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) {
+ LLVM_DEBUG(
+ dbgs()
+ << " Abort due to incompatible information for EMUL or EEW.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
- CanReduceVL = false;
- break;
+ return std::nullopt;
}
}
- return CanReduceVL;
+
+ return CommonVL;
}
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
@@ -1009,12 +1219,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
MachineInstr &MI = *Worklist.pop_back_val();
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
- const MachineOperand *CommonVL = nullptr;
- bool CanReduceVL = true;
- if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
- CanReduceVL = checkUsers(CommonVL, MI);
+ if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI))
+ continue;
- if (!CanReduceVL || !CommonVL)
+ auto CommonVL = checkUsers(MI);
+ if (!CommonVL)
continue;
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&