Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64FrameLowering.cpp       | 246
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp        |  12
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.h          |   3
-rw-r--r--  llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp |   8
-rw-r--r--  llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h   |  11
-rw-r--r--  llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td         |  15
6 files changed, 282 insertions, 13 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index cf617c7..a991813 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -321,7 +321,7 @@ bool AArch64FrameLowering::homogeneousPrologEpilog(
return false;
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
- if (AFI->hasSwiftAsyncContext())
+ if (AFI->hasSwiftAsyncContext() || AFI->hasStreamingModeChanges())
return false;
// If there are an odd number of GPRs before LR and FP in the CSRs list,
@@ -558,6 +558,10 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
MachineFunction &MF = *MBB.getParent();
MachineFrameInfo &MFI = MF.getFrameInfo();
+ AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs(MF.getFunction());
+ bool LocallyStreaming =
+ Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();
const std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo();
if (CSI.empty())
@@ -569,14 +573,22 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
DebugLoc DL = MBB.findDebugLoc(MBBI);
for (const auto &Info : CSI) {
- if (MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector)
+ unsigned FrameIdx = Info.getFrameIdx();
+ if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector)
continue;
assert(!Info.isSpilledToReg() && "Spilling to registers not implemented");
- unsigned DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
+ int64_t DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
+ int64_t Offset = MFI.getObjectOffset(FrameIdx) - getOffsetOfLocalArea();
+
+ // The location of VG will be emitted before each streaming-mode change in
+ // the function. Only locally-streaming functions require emitting the
+ // non-streaming VG location here.
+ if ((LocallyStreaming && FrameIdx == AFI->getStreamingVGIdx()) ||
+ (!LocallyStreaming &&
+ DwarfReg == TRI.getDwarfRegNum(AArch64::VG, true)))
+ continue;
- int64_t Offset =
- MFI.getObjectOffset(Info.getFrameIdx()) - getOffsetOfLocalArea();
unsigned CFIIndex = MF.addFrameInst(
MCCFIInstruction::createOffset(nullptr, DwarfReg, Offset));
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
@@ -699,6 +711,9 @@ static void emitCalleeSavedRestores(MachineBasicBlock &MBB,
!static_cast<const AArch64RegisterInfo &>(TRI).regNeedsCFI(Reg, Reg))
continue;
+ if (!Info.isRestored())
+ continue;
+
unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRestore(
nullptr, TRI.getDwarfRegNum(Info.getReg(), true)));
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
@@ -1342,6 +1357,32 @@ static void fixupSEHOpcode(MachineBasicBlock::iterator MBBI,
ImmOpnd->setImm(ImmOpnd->getImm() + LocalStackSize);
}
+bool requiresGetVGCall(MachineFunction &MF) {
+ AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+ return AFI->hasStreamingModeChanges() &&
+ !MF.getSubtarget<AArch64Subtarget>().hasSVE();
+}
+
+bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
+ unsigned Opc = MBBI->getOpcode();
+ if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI ||
+ Opc == AArch64::UBFMXri)
+ return true;
+
+ if (requiresGetVGCall(*MBBI->getMF())) {
+ if (Opc == AArch64::ORRXrr)
+ return true;
+
+ if (Opc == AArch64::BL) {
+ auto Op1 = MBBI->getOperand(0);
+ return Op1.isSymbol() &&
+ (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
+ }
+ }
+
+ return false;
+}
+
// Convert callee-save register save/restore instruction to do stack pointer
// decrement/increment to allocate/deallocate the callee-save stack area by
// converting store/load to use pre/post increment version.
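For reference, a minimal standalone sketch (illustrative only; vgMaterialization is a made-up helper, not LLVM API) of which prologue sequence the requiresGetVGCall/isVGInstruction helpers above are written to recognise:

#include <string>

// Assumed summary of the sequences skipped by isVGInstruction(): with SVE
// available, VG is read with CNTD; the streaming VG of a locally-streaming
// function comes from RDSVL #1 followed by a UBFM (an LSR #3); without SVE,
// VG is obtained by calling __arm_get_current_vg, with ORR moves around the
// call to preserve X0 when it is live.
std::string vgMaterialization(bool HasSVE, bool WantStreamingVG) {
  if (WantStreamingVG)
    return "rdsvl; ubfm";
  return HasSVE ? "cntd" : "orr; bl __arm_get_current_vg; orr";
}
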
@@ -1352,6 +1393,17 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(
MachineInstr::MIFlag FrameFlag = MachineInstr::FrameSetup,
int CFAOffset = 0) {
unsigned NewOpc;
+
+ // If the function contains streaming mode changes, we expect instructions
+ // to calculate the value of VG before spilling. For locally-streaming
+ // functions, we need to do this for both the streaming and non-streaming
+ // vector length. Move past these instructions if necessary.
+ MachineFunction &MF = *MBB.getParent();
+ AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+ if (AFI->hasStreamingModeChanges())
+ while (isVGInstruction(MBBI))
+ ++MBBI;
+
switch (MBBI->getOpcode()) {
default:
llvm_unreachable("Unexpected callee-save save/restore opcode!");
@@ -1408,7 +1460,6 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(
// If the first store isn't right where we want SP then we can't fold the
// update in so create a normal arithmetic instruction instead.
- MachineFunction &MF = *MBB.getParent();
if (MBBI->getOperand(MBBI->getNumOperands() - 1).getImm() != 0 ||
CSStackSizeInc < MinOffset || CSStackSizeInc > MaxOffset) {
emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP,
@@ -1660,6 +1711,12 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
LiveRegs.removeReg(AArch64::X19);
LiveRegs.removeReg(AArch64::FP);
LiveRegs.removeReg(AArch64::LR);
+
+ // X0 will be clobbered by a call to __arm_get_current_vg in the prologue.
+ // This is necessary to spill VG if required where SVE is unavailable, but
+ // X0 is preserved around this call.
+ if (requiresGetVGCall(MF))
+ LiveRegs.removeReg(AArch64::X0);
}
auto VerifyClobberOnExit = make_scope_exit([&]() {
@@ -1846,6 +1903,11 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
// pointer bump above.
while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) &&
!IsSVECalleeSave(MBBI)) {
+ // Move past instructions generated to calculate VG
+ if (AFI->hasStreamingModeChanges())
+ while (isVGInstruction(MBBI))
+ ++MBBI;
+
if (CombineSPBump)
fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(),
NeedsWinCFI, &HasWinCFI);
@@ -2768,7 +2830,7 @@ struct RegPairInfo {
unsigned Reg2 = AArch64::NoRegister;
int FrameIdx;
int Offset;
- enum RegType { GPR, FPR64, FPR128, PPR, ZPR } Type;
+ enum RegType { GPR, FPR64, FPR128, PPR, ZPR, VG } Type;
RegPairInfo() = default;
@@ -2780,6 +2842,7 @@ struct RegPairInfo {
return 2;
case GPR:
case FPR64:
+ case VG:
return 8;
case ZPR:
case FPR128:
@@ -2855,6 +2918,8 @@ static void computeCalleeSaveRegisterPairs(
RPI.Type = RegPairInfo::ZPR;
else if (AArch64::PPRRegClass.contains(RPI.Reg1))
RPI.Type = RegPairInfo::PPR;
+ else if (RPI.Reg1 == AArch64::VG)
+ RPI.Type = RegPairInfo::VG;
else
llvm_unreachable("Unsupported register class.");
@@ -2887,6 +2952,8 @@ static void computeCalleeSaveRegisterPairs(
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
RPI.Reg2 = NextReg;
break;
+ case RegPairInfo::VG:
+ break;
}
}
@@ -3003,6 +3070,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
MachineFunction &MF = *MBB.getParent();
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
+ AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
bool NeedsWinCFI = needsWinCFI(MF);
DebugLoc DL;
SmallVector<RegPairInfo, 8> RegPairs;
@@ -3070,7 +3138,70 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
Size = 2;
Alignment = Align(2);
break;
+ case RegPairInfo::VG:
+ StrOpc = AArch64::STRXui;
+ Size = 8;
+ Alignment = Align(8);
+ break;
}
+
+ unsigned X0Scratch = AArch64::NoRegister;
+ if (Reg1 == AArch64::VG) {
+ // Find an available register to store value of VG to.
+ Reg1 = findScratchNonCalleeSaveRegister(&MBB);
+ assert(Reg1 != AArch64::NoRegister);
+ SMEAttrs Attrs(MF.getFunction());
+
+ if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
+ AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
+ // For locally-streaming functions, we need to store both the streaming
+ // & non-streaming VG. Spill the streaming value first.
+ BuildMI(MBB, MI, DL, TII.get(AArch64::RDSVLI_XI), Reg1)
+ .addImm(1)
+ .setMIFlag(MachineInstr::FrameSetup);
+ BuildMI(MBB, MI, DL, TII.get(AArch64::UBFMXri), Reg1)
+ .addReg(Reg1)
+ .addImm(3)
+ .addImm(63)
+ .setMIFlag(MachineInstr::FrameSetup);
+
+ AFI->setStreamingVGIdx(RPI.FrameIdx);
+ } else if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) {
+ BuildMI(MBB, MI, DL, TII.get(AArch64::CNTD_XPiI), Reg1)
+ .addImm(31)
+ .addImm(1)
+ .setMIFlag(MachineInstr::FrameSetup);
+ AFI->setVGIdx(RPI.FrameIdx);
+ } else {
+ const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+ if (llvm::any_of(
+ MBB.liveins(),
+ [&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) {
+ return STI.getRegisterInfo()->isSuperOrSubRegisterEq(
+ AArch64::X0, LiveIn.PhysReg);
+ }))
+ X0Scratch = Reg1;
+
+ if (X0Scratch != AArch64::NoRegister)
+ BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), Reg1)
+ .addReg(AArch64::XZR)
+ .addReg(AArch64::X0, RegState::Undef)
+ .addReg(AArch64::X0, RegState::Implicit)
+ .setMIFlag(MachineInstr::FrameSetup);
+
+ const uint32_t *RegMask = TRI->getCallPreservedMask(
+ MF,
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
+ BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
+ .addExternalSymbol("__arm_get_current_vg")
+ .addRegMask(RegMask)
+ .addReg(AArch64::X0, RegState::ImplicitDefine)
+ .setMIFlag(MachineInstr::FrameSetup);
+ Reg1 = AArch64::X0;
+ AFI->setVGIdx(RPI.FrameIdx);
+ }
+ }
+
LLVM_DEBUG(dbgs() << "CSR spill: (" << printReg(Reg1, TRI);
if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI);
dbgs() << ") -> fi#(" << RPI.FrameIdx;
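The UBFMXri emitted above with immr=3, imms=63 encodes an LSR #3. A small worked sketch of that conversion (streamingVGFromSVLBytes is a hypothetical helper, not part of the patch):

// RDSVL #1 returns the streaming vector length in bytes; shifting right by
// three converts bytes to 64-bit granules, which is the value of VG in
// streaming mode. For example, a 512-bit streaming vector length gives
// 64 bytes and VG == 8.
constexpr unsigned streamingVGFromSVLBytes(unsigned SVLBytes) {
  return SVLBytes >> 3;
}
static_assert(streamingVGFromSVLBytes(64) == 8, "SVL 512 bits -> VG 8");
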
@@ -3162,6 +3293,13 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
if (RPI.isPaired())
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
}
+
+ if (X0Scratch != AArch64::NoRegister)
+ BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), AArch64::X0)
+ .addReg(AArch64::XZR)
+ .addReg(X0Scratch, RegState::Undef)
+ .addReg(X0Scratch, RegState::Implicit)
+ .setMIFlag(MachineInstr::FrameSetup);
}
return true;
}
@@ -3241,6 +3379,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
Size = 2;
Alignment = Align(2);
break;
+ case RegPairInfo::VG:
+ continue;
}
LLVM_DEBUG(dbgs() << "CSR restore: (" << printReg(Reg1, TRI);
if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI);
@@ -3440,6 +3580,19 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
CSStackSize += RegSize;
}
+ // Increase the callee-saved stack size if the function has streaming mode
+ // changes, as we will need to spill the value of the VG register.
+ // For locally streaming functions, we spill both the streaming and
+ // non-streaming VG value.
+ const Function &F = MF.getFunction();
+ SMEAttrs Attrs(F);
+ if (AFI->hasStreamingModeChanges()) {
+ if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
+ CSStackSize += 16;
+ else
+ CSStackSize += 8;
+ }
+
// Save number of saved regs, so we can easily update CSStackSize later.
unsigned NumSavedRegs = SavedRegs.count();
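A minimal standalone sketch of the size accounting above (extraVGSpillBytes is a hypothetical helper, not LLVM API):

// One 8-byte slot is reserved for VG when the function has streaming-mode
// changes; locally-streaming functions (streaming body, non-streaming
// interface) reserve two slots, for the streaming and non-streaming VG.
unsigned extraVGSpillBytes(bool HasStreamingModeChanges, bool StreamingBody,
                           bool StreamingInterface) {
  if (!HasStreamingModeChanges)
    return 0;
  return (StreamingBody && !StreamingInterface) ? 16 : 8;
}
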
@@ -3576,6 +3729,33 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
if ((unsigned)FrameIdx > MaxCSFrameIndex) MaxCSFrameIndex = FrameIdx;
}
+ // Insert VG into the list of CSRs, immediately before LR if saved.
+ if (AFI->hasStreamingModeChanges()) {
+ std::vector<CalleeSavedInfo> VGSaves;
+ SMEAttrs Attrs(MF.getFunction());
+
+ auto VGInfo = CalleeSavedInfo(AArch64::VG);
+ VGInfo.setRestored(false);
+ VGSaves.push_back(VGInfo);
+
+ // Add VG again if the function is locally-streaming, as we will spill two
+ // values.
+ if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
+ VGSaves.push_back(VGInfo);
+
+ bool InsertBeforeLR = false;
+
+ for (unsigned I = 0; I < CSI.size(); I++)
+ if (CSI[I].getReg() == AArch64::LR) {
+ InsertBeforeLR = true;
+ CSI.insert(CSI.begin() + I, VGSaves.begin(), VGSaves.end());
+ break;
+ }
+
+ if (!InsertBeforeLR)
+ CSI.insert(CSI.end(), VGSaves.begin(), VGSaves.end());
+ }
+
for (auto &CS : CSI) {
Register Reg = CS.getReg();
const TargetRegisterClass *RC = RegInfo->getMinimalPhysRegClass(Reg);
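The insertion order used above can be pictured with a small self-contained sketch (hypothetical integer register ids rather than the CalleeSavedInfo API):

#include <algorithm>
#include <vector>

// Splice the VG save entries in immediately before LR when LR is in the
// callee-saved list; if LR is not saved, append them at the end.
std::vector<int> insertVGSaves(std::vector<int> CSRegs, int VG, int LR,
                               unsigned NumVGSaves) {
  auto It = std::find(CSRegs.begin(), CSRegs.end(), LR); // end() if LR absent
  CSRegs.insert(It, NumVGSaves, VG);
  return CSRegs;
}
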
@@ -4191,12 +4371,58 @@ MachineBasicBlock::iterator tryMergeAdjacentSTG(MachineBasicBlock::iterator II,
}
} // namespace
+MachineBasicBlock::iterator emitVGSaveRestore(MachineBasicBlock::iterator II,
+ const AArch64FrameLowering *TFI) {
+ MachineInstr &MI = *II;
+ MachineBasicBlock *MBB = MI.getParent();
+ MachineFunction *MF = MBB->getParent();
+
+ if (MI.getOpcode() != AArch64::VGSavePseudo &&
+ MI.getOpcode() != AArch64::VGRestorePseudo)
+ return II;
+
+ SMEAttrs FuncAttrs(MF->getFunction());
+ bool LocallyStreaming =
+ FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
+ const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();
+ const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+ const AArch64InstrInfo *TII =
+ MF->getSubtarget<AArch64Subtarget>().getInstrInfo();
+
+ int64_t VGFrameIdx =
+ LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
+ assert(VGFrameIdx != std::numeric_limits<int>::max() &&
+ "Expected FrameIdx for VG");
+
+ unsigned CFIIndex;
+ if (MI.getOpcode() == AArch64::VGSavePseudo) {
+ const MachineFrameInfo &MFI = MF->getFrameInfo();
+ int64_t Offset =
+ MFI.getObjectOffset(VGFrameIdx) - TFI->getOffsetOfLocalArea();
+ CFIIndex = MF->addFrameInst(MCCFIInstruction::createOffset(
+ nullptr, TRI->getDwarfRegNum(AArch64::VG, true), Offset));
+ } else
+ CFIIndex = MF->addFrameInst(MCCFIInstruction::createRestore(
+ nullptr, TRI->getDwarfRegNum(AArch64::VG, true)));
+
+ MachineInstr *UnwindInst = BuildMI(*MBB, II, II->getDebugLoc(),
+ TII->get(TargetOpcode::CFI_INSTRUCTION))
+ .addCFIIndex(CFIIndex);
+
+ MI.eraseFromParent();
+ return UnwindInst->getIterator();
+}
+
void AArch64FrameLowering::processFunctionBeforeFrameIndicesReplaced(
MachineFunction &MF, RegScavenger *RS = nullptr) const {
- if (StackTaggingMergeSetTag)
- for (auto &BB : MF)
- for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();)
+ AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+ for (auto &BB : MF)
+ for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();) {
+ if (AFI->hasStreamingModeChanges())
+ II = emitVGSaveRestore(II, this);
+ if (StackTaggingMergeSetTag)
II = tryMergeAdjacentSTG(II, this, RS);
+ }
}
/// For Win64 AArch64 EH, the offset to the Unwind object is from the SP
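For reference, an assumed textual rendering of the unwind directives that emitVGSaveRestore above produces (a sketch with a hypothetical helper, not output generated by this code):

#include <string>

// VGSavePseudo lowers to a cfi_offset describing VG's spill slot (the VG or,
// for locally-streaming functions, streaming-VG frame object), and
// VGRestorePseudo lowers to a cfi_restore of VG.
std::string cfiForVGPseudo(bool IsSave, long SlotOffset) {
  return IsSave ? ".cfi_offset vg, " + std::to_string(SlotOffset)
                : ".cfi_restore vg";
}
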
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c4f819f..af8b9d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2493,6 +2493,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
case AArch64ISD::FIRST_NUMBER:
break;
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
+ MAKE_CASE(AArch64ISD::VG_SAVE)
+ MAKE_CASE(AArch64ISD::VG_RESTORE)
MAKE_CASE(AArch64ISD::SMSTART)
MAKE_CASE(AArch64ISD::SMSTOP)
MAKE_CASE(AArch64ISD::RESTORE_ZA)
@@ -8514,6 +8516,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue InGlue;
if (RequiresSMChange) {
+
+ Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
+ DAG.getVTList(MVT::Other, MVT::Glue), Chain);
+ InGlue = Chain.getValue(1);
+
SDValue NewChain = changeStreamingMode(
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
@@ -8691,6 +8698,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Result = changeStreamingMode(
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
+ InGlue = Result.getValue(1);
+
+ Result =
+ DAG.getNode(AArch64ISD::VG_RESTORE, DL,
+ DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
}
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
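The resulting ordering at a call site that changes streaming mode can be summarised with this illustrative, non-DAG sketch (the sequence is an assumption drawn from the node placement above):

#include <array>
#include <string_view>

// Assumed shape of the lowered sequence: the VG save marker precedes the
// switch into the callee's streaming mode, and the restore marker follows
// the switch back, so the VG spill-slot CFI covers the region in which VG
// may differ from its value on entry.
constexpr std::array<std::string_view, 5> StreamingCallSequence = {
    "VG_SAVE", "smstart/smstop sm", "bl callee", "smstart/smstop sm",
    "VG_RESTORE"};
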
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 48a4ea9..b57ba09 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -70,6 +70,9 @@ enum NodeType : unsigned {
COALESCER_BARRIER,
+ VG_SAVE,
+ VG_RESTORE,
+
SMSTART,
SMSTOP,
RESTORE_ZA,
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
index c3d64f5..957d7bc 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
@@ -196,12 +196,14 @@ bool AArch64FunctionInfo::needsAsyncDwarfUnwindInfo(
const MachineFunction &MF) const {
if (!NeedsAsyncDwarfUnwindInfo) {
const Function &F = MF.getFunction();
+ const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
// The check for "minsize" is because epilogue unwind info is not emitted
// (yet) for homogeneous epilogues, outlined functions, and functions
// outlined from.
- NeedsAsyncDwarfUnwindInfo = needsDwarfUnwindInfo(MF) &&
- F.getUWTableKind() == UWTableKind::Async &&
- !F.hasMinSize();
+ NeedsAsyncDwarfUnwindInfo =
+ needsDwarfUnwindInfo(MF) &&
+ ((F.getUWTableKind() == UWTableKind::Async && !F.hasMinSize()) ||
+ AFI->hasStreamingModeChanges());
}
return *NeedsAsyncDwarfUnwindInfo;
}
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index df09fc5..839a3a3 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -13,6 +13,7 @@
#ifndef LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H
#define LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H
+#include "AArch64Subtarget.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -216,6 +217,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;
+ // The stack slots where VG values are stored to.
+ int64_t VGIdx = std::numeric_limits<int>::max();
+ int64_t StreamingVGIdx = std::numeric_limits<int>::max();
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -234,6 +239,12 @@ public:
Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
+ int64_t getVGIdx() const { return VGIdx; };
+ void setVGIdx(unsigned Idx) { VGIdx = Idx; };
+
+ int64_t getStreamingVGIdx() const { return StreamingVGIdx; };
+ void setStreamingVGIdx(unsigned FrameIdx) { StreamingVGIdx = FrameIdx; };
+
bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 2b70c47..fea70b7 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -31,6 +31,12 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
def AArch64CoalescerBarrier
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64VGSave : SDNode<"AArch64ISD::VG_SAVE", SDTypeProfile<0, 0, []>,
+ [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;
+
+def AArch64VGRestore : SDNode<"AArch64ISD::VG_RESTORE", SDTypeProfile<0, 0, []>,
+ [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;
+
//===----------------------------------------------------------------------===//
// Instruction naming conventions.
//===----------------------------------------------------------------------===//
@@ -221,6 +227,15 @@ def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
(MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+// Pseudo to insert cfi_offset/cfi_restore instructions. Used to save or restore
+// the streaming value of VG around streaming-mode changes in locally-streaming
+// functions.
+def VGSavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
+def : Pat<(AArch64VGSave), (VGSavePseudo)>;
+
+def VGRestorePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
+def : Pat<(AArch64VGRestore), (VGRestorePseudo)>;
+
//===----------------------------------------------------------------------===//
// SME2 Instructions
//===----------------------------------------------------------------------===//