diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/AsmParser/LLLexer.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/AsmParser/LLParser.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/IR/AsmWriter.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64CallingConvention.td | 16 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 39 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.h | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp | 38 |
7 files changed, 108 insertions, 7 deletions
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp index c9a98269..c020fe7 100644 --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -597,6 +597,8 @@ lltok::Kind LLLexer::LexIdentifier() { KEYWORD(arm_aapcs_vfpcc); KEYWORD(aarch64_vector_pcs); KEYWORD(aarch64_sve_vector_pcs); + KEYWORD(aarch64_sme_preservemost_from_x0); + KEYWORD(aarch64_sme_preservemost_from_x2); KEYWORD(msp430_intrcc); KEYWORD(avr_intrcc); KEYWORD(avr_signalcc); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 10a775f..7475868 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -1875,6 +1875,8 @@ void LLParser::parseOptionalDLLStorageClass(unsigned &Res) { /// ::= 'arm_aapcs_vfpcc' /// ::= 'aarch64_vector_pcs' /// ::= 'aarch64_sve_vector_pcs' +/// ::= 'aarch64_sme_preservemost_from_x0' +/// ::= 'aarch64_sme_preservemost_from_x2' /// ::= 'msp430_intrcc' /// ::= 'avr_intrcc' /// ::= 'avr_signalcc' @@ -1925,6 +1927,12 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) { case lltok::kw_aarch64_sve_vector_pcs: CC = CallingConv::AArch64_SVE_VectorCall; break; + case lltok::kw_aarch64_sme_preservemost_from_x0: + CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0; + break; + case lltok::kw_aarch64_sme_preservemost_from_x2: + CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2; + break; case lltok::kw_msp430_intrcc: CC = CallingConv::MSP430_INTR; break; case lltok::kw_avr_intrcc: CC = CallingConv::AVR_INTR; break; case lltok::kw_avr_signalcc: CC = CallingConv::AVR_SIGNAL; break; diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 0ee559a..d9443f4 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -312,6 +312,12 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) { case CallingConv::AArch64_SVE_VectorCall: Out << "aarch64_sve_vector_pcs"; break; + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + Out << "aarch64_sme_preservemost_from_x0"; + break; + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: + Out << "aarch64_sme_preservemost_from_x2"; + break; case CallingConv::MSP430_INTR: Out << "msp430_intrcc"; break; case CallingConv::AVR_INTR: Out << "avr_intrcc "; break; case CallingConv::AVR_SIGNAL: Out << "avr_signalcc "; break; diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td index 6cf7bf6..0000d26f 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -435,6 +435,22 @@ def CSR_AArch64_SVE_AAPCS : CalleeSavedRegs<(add (sequence "Z%u", 8, 23), X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, LR, FP)>; +// SME ABI support routines such as __arm_tpidr2_save/restore preserve most registers. +def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 + : CalleeSavedRegs<(add (sequence "Z%u", 0, 31), + (sequence "P%u", 0, 15), + (sequence "X%u", 0, 13), + (sequence "X%u",19, 28), + LR, FP)>; + +// SME ABI support routines __arm_sme_state preserves most registers. +def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 + : CalleeSavedRegs<(add (sequence "Z%u", 0, 31), + (sequence "P%u", 0, 15), + (sequence "X%u", 2, 15), + (sequence "X%u",19, 28), + LR, FP)>; + def CSR_AArch64_AAPCS_SwiftTail : CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 198d332..944c8ba 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4490,6 +4490,32 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) { return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask); } +SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain, + SMEAttrs Attrs, SDLoc DL, + EVT VT) const { + if (Attrs.hasStreamingInterfaceOrBody()) + return DAG.getConstant(1, DL, VT); + + if (Attrs.hasNonStreamingInterfaceAndBody()) + return DAG.getConstant(0, DL, VT); + + assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface"); + + SDValue Callee = DAG.getExternalSymbol("__arm_sme_state", + getPointerTy(DAG.getDataLayout())); + Type *Int64Ty = Type::getInt64Ty(*DAG.getContext()); + Type *RetTy = StructType::get(Int64Ty, Int64Ty); + TargetLowering::CallLoweringInfo CLI(DAG); + ArgListTy Args; + CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2, + RetTy, Callee, std::move(Args)); + std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI); + SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64); + return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0), + Mask); +} + SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const { unsigned IntNo = Op.getConstantOperandVal(1); @@ -4521,13 +4547,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL); } case Intrinsic::aarch64_sme_get_pstatesm: { - SDValue Chain = Op.getOperand(0); - SDValue MRS = DAG.getNode( - AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other), - Chain, DAG.getConstant(AArch64SysReg::SVCR, DL, MVT::i64)); - SDValue Mask = DAG.getConstant(/* PSTATE.SM */ 1, DL, MVT::i64); - SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, MRS, Mask); - return DAG.getMergeValues({And, Chain}, DL); + SDValue Chain = Op->getOperand(0); + SMEAttrs Attrs(DAG.getMachineFunction().getFunction()); + SDValue PStateSM = getPStateSM(DAG, Chain, Attrs, DL, Op.getValueType()); + return DAG.getMergeValues({PStateSM, Chain}, DL); } } } @@ -5834,6 +5857,8 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, return CC_AArch64_Win64_CFGuard_Check; case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: return CC_AArch64_AAPCS; } } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 82e0579..a5552ca 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -15,6 +15,7 @@ #define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H #include "AArch64.h" +#include "Utils/AArch64SMEAttributes.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/SelectionDAG.h" @@ -1158,6 +1159,11 @@ private: // This function does not handle predicate bitcasts. SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const; + // Returns the runtime value for PSTATE.SM. When the function is streaming- + // compatible, this generates a call to __arm_sme_state. + SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs, + SDLoc DL, EVT VT) const; + bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1, LLT Ty2) const override; }; diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp index f92fcca..91b6d18 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -91,6 +91,18 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { return CSR_AArch64_AAVPCS_SaveList; if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall) return CSR_AArch64_SVE_AAPCS_SaveList; + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "only supported to improve calls to SME ACLE save/restore/disable-za " + "functions, and is not intended to be used beyond that scope."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "only supported to improve calls to SME ACLE __arm_sme_state " + "and is not intended to be used beyond that scope."); if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering() ->supportSwiftError() && MF->getFunction().getAttributes().hasAttrSomewhere( @@ -123,6 +135,18 @@ AArch64RegisterInfo::getDarwinCalleeSavedRegs(const MachineFunction *MF) const { if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall) report_fatal_error( "Calling convention SVE_VectorCall is unsupported on Darwin."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "only supported to improve calls to SME ACLE save/restore/disable-za " + "functions, and is not intended to be used beyond that scope."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "only supported to improve calls to SME ACLE __arm_sme_state " + "and is not intended to be used beyond that scope."); if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS) return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR() ? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList @@ -193,6 +217,14 @@ AArch64RegisterInfo::getDarwinCallPreservedMask(const MachineFunction &MF, if (CC == CallingConv::AArch64_SVE_VectorCall) report_fatal_error( "Calling convention SVE_VectorCall is unsupported on Darwin."); + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "unsupported on Darwin."); + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "unsupported on Darwin."); if (CC == CallingConv::CFGuard_Check) report_fatal_error( "Calling convention CFGuard_Check is unsupported on Darwin."); @@ -230,6 +262,10 @@ AArch64RegisterInfo::getCallPreservedMask(const MachineFunction &MF, if (CC == CallingConv::AArch64_SVE_VectorCall) return SCS ? CSR_AArch64_SVE_AAPCS_SCS_RegMask : CSR_AArch64_SVE_AAPCS_RegMask; + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask; + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask; if (CC == CallingConv::CFGuard_Check) return CSR_Win_AArch64_CFGuard_Check_RegMask; if (MF.getSubtarget<AArch64Subtarget>().getTargetLowering() @@ -539,6 +575,8 @@ bool AArch64RegisterInfo::isArgumentRegister(const MachineFunction &MF, return HasReg(CC_AArch64_Win64_CFGuard_Check_ArgRegs, Reg); case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: return HasReg(CC_AArch64_AAPCS_ArgRegs, Reg); } } |