diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVFrameLowering.cpp')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVFrameLowering.cpp | 53 |
1 files changed, 43 insertions, 10 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp index 23b4554..b1ab76a 100644 --- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp @@ -1544,10 +1544,53 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI, return Offset; } +static MCRegister getRVVBaseRegister(const RISCVRegisterInfo &TRI, + const Register &Reg) { + MCRegister BaseReg = TRI.getSubReg(Reg, RISCV::sub_vrm1_0); + // If it's not a grouped vector register, it doesn't have subregister, so + // the base register is just itself. + if (BaseReg == RISCV::NoRegister) + BaseReg = Reg; + return BaseReg; +} + void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs, RegScavenger *RS) const { TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS); + + // In TargetFrameLowering::determineCalleeSaves, any vector register is marked + // as saved if any of its subregister is clobbered, this is not correct in + // vector registers. We only want the vector register to be marked as saved + // if all of its subregisters are clobbered. + // For example: + // Original behavior: If v24 is marked, v24m2, v24m4, v24m8 are also marked. + // Correct behavior: v24m2 is marked only if v24 and v25 are marked. + const MachineRegisterInfo &MRI = MF.getRegInfo(); + const MCPhysReg *CSRegs = MRI.getCalleeSavedRegs(); + const RISCVRegisterInfo &TRI = *STI.getRegisterInfo(); + for (unsigned i = 0; CSRegs[i]; ++i) { + unsigned CSReg = CSRegs[i]; + // Only vector registers need special care. + if (!RISCV::VRRegClass.contains(getRVVBaseRegister(TRI, CSReg))) + continue; + + SavedRegs.reset(CSReg); + + auto SubRegs = TRI.subregs(CSReg); + // Set the register and all its subregisters. + if (!MRI.def_empty(CSReg) || MRI.getUsedPhysRegsMask().test(CSReg)) { + SavedRegs.set(CSReg); + llvm::for_each(SubRegs, [&](unsigned Reg) { return SavedRegs.set(Reg); }); + } + + // Combine to super register if all of its subregisters are marked. + if (!SubRegs.empty() && llvm::all_of(SubRegs, [&](unsigned Reg) { + return SavedRegs.test(Reg); + })) + SavedRegs.set(CSReg); + } + // Unconditionally spill RA and FP only if the function uses a frame // pointer. if (hasFP(MF)) { @@ -2137,16 +2180,6 @@ static unsigned getCalleeSavedRVVNumRegs(const Register &BaseReg) { : 8; } -static MCRegister getRVVBaseRegister(const RISCVRegisterInfo &TRI, - const Register &Reg) { - MCRegister BaseReg = TRI.getSubReg(Reg, RISCV::sub_vrm1_0); - // If it's not a grouped vector register, it doesn't have subregister, so - // the base register is just itself. - if (BaseReg == RISCV::NoRegister) - BaseReg = Reg; - return BaseReg; -} - void RISCVFrameLowering::emitCalleeSavedRVVPrologCFI( MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, bool HasFP) const { MachineFunction *MF = MBB.getParent(); |