aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/AsmParser/LLLexer.cpp2
-rw-r--r--llvm/lib/AsmParser/LLParser.cpp8
-rw-r--r--llvm/lib/IR/AsmWriter.cpp6
-rw-r--r--llvm/lib/Target/AArch64/AArch64CallingConvention.td16
-rw-r--r--llvm/lib/Target/AArch64/AArch64ISelLowering.cpp39
-rw-r--r--llvm/lib/Target/AArch64/AArch64ISelLowering.h6
-rw-r--r--llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp38
7 files changed, 108 insertions, 7 deletions
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index c9a98269..c020fe7 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -597,6 +597,8 @@ lltok::Kind LLLexer::LexIdentifier() {
KEYWORD(arm_aapcs_vfpcc);
KEYWORD(aarch64_vector_pcs);
KEYWORD(aarch64_sve_vector_pcs);
+ KEYWORD(aarch64_sme_preservemost_from_x0);
+ KEYWORD(aarch64_sme_preservemost_from_x2);
KEYWORD(msp430_intrcc);
KEYWORD(avr_intrcc);
KEYWORD(avr_signalcc);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 10a775f..7475868 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -1875,6 +1875,8 @@ void LLParser::parseOptionalDLLStorageClass(unsigned &Res) {
/// ::= 'arm_aapcs_vfpcc'
/// ::= 'aarch64_vector_pcs'
/// ::= 'aarch64_sve_vector_pcs'
+/// ::= 'aarch64_sme_preservemost_from_x0'
+/// ::= 'aarch64_sme_preservemost_from_x2'
/// ::= 'msp430_intrcc'
/// ::= 'avr_intrcc'
/// ::= 'avr_signalcc'
@@ -1925,6 +1927,12 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) {
case lltok::kw_aarch64_sve_vector_pcs:
CC = CallingConv::AArch64_SVE_VectorCall;
break;
+ case lltok::kw_aarch64_sme_preservemost_from_x0:
+ CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0;
+ break;
+ case lltok::kw_aarch64_sme_preservemost_from_x2:
+ CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2;
+ break;
case lltok::kw_msp430_intrcc: CC = CallingConv::MSP430_INTR; break;
case lltok::kw_avr_intrcc: CC = CallingConv::AVR_INTR; break;
case lltok::kw_avr_signalcc: CC = CallingConv::AVR_SIGNAL; break;
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 0ee559a..d9443f4 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -312,6 +312,12 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) {
case CallingConv::AArch64_SVE_VectorCall:
Out << "aarch64_sve_vector_pcs";
break;
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+ Out << "aarch64_sme_preservemost_from_x0";
+ break;
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
+ Out << "aarch64_sme_preservemost_from_x2";
+ break;
case CallingConv::MSP430_INTR: Out << "msp430_intrcc"; break;
case CallingConv::AVR_INTR: Out << "avr_intrcc "; break;
case CallingConv::AVR_SIGNAL: Out << "avr_signalcc "; break;
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 6cf7bf6..0000d26f 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -435,6 +435,22 @@ def CSR_AArch64_SVE_AAPCS : CalleeSavedRegs<(add (sequence "Z%u", 8, 23),
X19, X20, X21, X22, X23, X24,
X25, X26, X27, X28, LR, FP)>;
+// SME ABI support routines such as __arm_tpidr2_save/restore preserve most registers.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0
+ : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+ (sequence "P%u", 0, 15),
+ (sequence "X%u", 0, 13),
+ (sequence "X%u",19, 28),
+ LR, FP)>;
+
+// SME ABI support routines __arm_sme_state preserves most registers.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2
+ : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+ (sequence "P%u", 0, 15),
+ (sequence "X%u", 2, 15),
+ (sequence "X%u",19, 28),
+ LR, FP)>;
+
def CSR_AArch64_AAPCS_SwiftTail
: CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 198d332..944c8ba 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4490,6 +4490,32 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
}
+SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
+ SMEAttrs Attrs, SDLoc DL,
+ EVT VT) const {
+ if (Attrs.hasStreamingInterfaceOrBody())
+ return DAG.getConstant(1, DL, VT);
+
+ if (Attrs.hasNonStreamingInterfaceAndBody())
+ return DAG.getConstant(0, DL, VT);
+
+ assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
+
+ SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
+ getPointerTy(DAG.getDataLayout()));
+ Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
+ Type *RetTy = StructType::get(Int64Ty, Int64Ty);
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ ArgListTy Args;
+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
+ RetTy, Callee, std::move(Args));
+ std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
+ SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
+ return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
+ Mask);
+}
+
SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
SelectionDAG &DAG) const {
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -4521,13 +4547,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL);
}
case Intrinsic::aarch64_sme_get_pstatesm: {
- SDValue Chain = Op.getOperand(0);
- SDValue MRS = DAG.getNode(
- AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
- Chain, DAG.getConstant(AArch64SysReg::SVCR, DL, MVT::i64));
- SDValue Mask = DAG.getConstant(/* PSTATE.SM */ 1, DL, MVT::i64);
- SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, MRS, Mask);
- return DAG.getMergeValues({And, Chain}, DL);
+ SDValue Chain = Op->getOperand(0);
+ SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ SDValue PStateSM = getPStateSM(DAG, Chain, Attrs, DL, Op.getValueType());
+ return DAG.getMergeValues({PStateSM, Chain}, DL);
}
}
}
@@ -5834,6 +5857,8 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
return CC_AArch64_Win64_CFGuard_Check;
case CallingConv::AArch64_VectorCall:
case CallingConv::AArch64_SVE_VectorCall:
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
return CC_AArch64_AAPCS;
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 82e0579..a5552ca 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -15,6 +15,7 @@
#define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H
#include "AArch64.h"
+#include "Utils/AArch64SMEAttributes.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/SelectionDAG.h"
@@ -1158,6 +1159,11 @@ private:
// This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
+ // Returns the runtime value for PSTATE.SM. When the function is streaming-
+ // compatible, this generates a call to __arm_sme_state.
+ SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
+ SDLoc DL, EVT VT) const;
+
bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
LLT Ty2) const override;
};
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index f92fcca..91b6d18 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -91,6 +91,18 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
return CSR_AArch64_AAVPCS_SaveList;
if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
return CSR_AArch64_SVE_AAPCS_SaveList;
+ if (MF->getFunction().getCallingConv() ==
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+ report_fatal_error(
+ "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+ "only supported to improve calls to SME ACLE save/restore/disable-za "
+ "functions, and is not intended to be used beyond that scope.");
+ if (MF->getFunction().getCallingConv() ==
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+ report_fatal_error(
+ "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+ "only supported to improve calls to SME ACLE __arm_sme_state "
+ "and is not intended to be used beyond that scope.");
if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering()
->supportSwiftError() &&
MF->getFunction().getAttributes().hasAttrSomewhere(
@@ -123,6 +135,18 @@ AArch64RegisterInfo::getDarwinCalleeSavedRegs(const MachineFunction *MF) const {
if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
report_fatal_error(
"Calling convention SVE_VectorCall is unsupported on Darwin.");
+ if (MF->getFunction().getCallingConv() ==
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+ report_fatal_error(
+ "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+ "only supported to improve calls to SME ACLE save/restore/disable-za "
+ "functions, and is not intended to be used beyond that scope.");
+ if (MF->getFunction().getCallingConv() ==
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+ report_fatal_error(
+ "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+ "only supported to improve calls to SME ACLE __arm_sme_state "
+ "and is not intended to be used beyond that scope.");
if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS)
return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR()
? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList
@@ -193,6 +217,14 @@ AArch64RegisterInfo::getDarwinCallPreservedMask(const MachineFunction &MF,
if (CC == CallingConv::AArch64_SVE_VectorCall)
report_fatal_error(
"Calling convention SVE_VectorCall is unsupported on Darwin.");
+ if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+ report_fatal_error(
+ "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+ "unsupported on Darwin.");
+ if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+ report_fatal_error(
+ "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+ "unsupported on Darwin.");
if (CC == CallingConv::CFGuard_Check)
report_fatal_error(
"Calling convention CFGuard_Check is unsupported on Darwin.");
@@ -230,6 +262,10 @@ AArch64RegisterInfo::getCallPreservedMask(const MachineFunction &MF,
if (CC == CallingConv::AArch64_SVE_VectorCall)
return SCS ? CSR_AArch64_SVE_AAPCS_SCS_RegMask
: CSR_AArch64_SVE_AAPCS_RegMask;
+ if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+ return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask;
+ if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+ return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask;
if (CC == CallingConv::CFGuard_Check)
return CSR_Win_AArch64_CFGuard_Check_RegMask;
if (MF.getSubtarget<AArch64Subtarget>().getTargetLowering()
@@ -539,6 +575,8 @@ bool AArch64RegisterInfo::isArgumentRegister(const MachineFunction &MF,
return HasReg(CC_AArch64_Win64_CFGuard_Check_ArgRegs, Reg);
case CallingConv::AArch64_VectorCall:
case CallingConv::AArch64_SVE_VectorCall:
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
return HasReg(CC_AArch64_AAPCS_ArgRegs, Reg);
}
}