diff options
author | Sander de Smalen <sander.desmalen@arm.com> | 2022-09-14 15:53:13 +0000 |
---|---|---|
committer | Sander de Smalen <sander.desmalen@arm.com> | 2022-09-15 15:14:13 +0000 |
commit | 45d28779c5dc6c8afa6feb24d68606f01b9800f4 (patch) | |
tree | ea5fcee24539871a03e415c17a2904cede77a162 /llvm/lib | |
parent | b0eea8f440af48dfe52ada04a58198460694fb56 (diff) | |
download | llvm-45d28779c5dc6c8afa6feb24d68606f01b9800f4.zip llvm-45d28779c5dc6c8afa6feb24d68606f01b9800f4.tar.gz llvm-45d28779c5dc6c8afa6feb24d68606f01b9800f4.tar.bz2 |
[AArch64][SME] Fix lowering of llvm.aarch64.sme.get.pstatesm()
A thread may not have access to SME or TPIDR2_EL0, so in order to
safely query PSTATE.SM in a streaming-compatible function, the
code should call `__arm_sme_state()`, as described in the ABI:
https://github.com/ARM-software/abi-aa/pull/123/commits/c2bb09c4d4ee60a5787baf1ccc7e92e67e4240b7
This means that the value of PSTATE.SM is:
* 0 if the function is non-streaming.
* 1 if the function has `arm_streaming` or `arm_locally_streaming`.
* evaluated at runtime by a call to `__arm_sme_state()` otherwise (i.e. when the function is streaming-compatible).
This patch also adds two calling conventions for calls to SME ABI support routines: `aarch64_sme_preservemost_from_x0` and `aarch64_sme_preservemost_from_x2`.
At some point we can remove the need for the llvm.aarch64.sme.get.pstatesm() intrinsic
and use function calls (with the corresponding calling convention) directly instead.
Reviewed By: aemerson
Differential Revision: https://reviews.llvm.org/D131571
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/AsmParser/LLLexer.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/AsmParser/LLParser.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/IR/AsmWriter.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64CallingConvention.td | 16 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 39 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.h | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp | 38 |
7 files changed, 108 insertions, 7 deletions
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp index c9a98269..c020fe7 100644 --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -597,6 +597,8 @@ lltok::Kind LLLexer::LexIdentifier() { KEYWORD(arm_aapcs_vfpcc); KEYWORD(aarch64_vector_pcs); KEYWORD(aarch64_sve_vector_pcs); + KEYWORD(aarch64_sme_preservemost_from_x0); + KEYWORD(aarch64_sme_preservemost_from_x2); KEYWORD(msp430_intrcc); KEYWORD(avr_intrcc); KEYWORD(avr_signalcc); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 10a775f..7475868 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -1875,6 +1875,8 @@ void LLParser::parseOptionalDLLStorageClass(unsigned &Res) { /// ::= 'arm_aapcs_vfpcc' /// ::= 'aarch64_vector_pcs' /// ::= 'aarch64_sve_vector_pcs' +/// ::= 'aarch64_sme_preservemost_from_x0' +/// ::= 'aarch64_sme_preservemost_from_x2' /// ::= 'msp430_intrcc' /// ::= 'avr_intrcc' /// ::= 'avr_signalcc' @@ -1925,6 +1927,12 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) { case lltok::kw_aarch64_sve_vector_pcs: CC = CallingConv::AArch64_SVE_VectorCall; break; + case lltok::kw_aarch64_sme_preservemost_from_x0: + CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0; + break; + case lltok::kw_aarch64_sme_preservemost_from_x2: + CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2; + break; case lltok::kw_msp430_intrcc: CC = CallingConv::MSP430_INTR; break; case lltok::kw_avr_intrcc: CC = CallingConv::AVR_INTR; break; case lltok::kw_avr_signalcc: CC = CallingConv::AVR_SIGNAL; break; diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 0ee559a..d9443f4 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -312,6 +312,12 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) { case CallingConv::AArch64_SVE_VectorCall: Out << "aarch64_sve_vector_pcs"; break; + case 
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + Out << "aarch64_sme_preservemost_from_x0"; + break; + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: + Out << "aarch64_sme_preservemost_from_x2"; + break; case CallingConv::MSP430_INTR: Out << "msp430_intrcc"; break; case CallingConv::AVR_INTR: Out << "avr_intrcc "; break; case CallingConv::AVR_SIGNAL: Out << "avr_signalcc "; break; diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td index 6cf7bf6..0000d26f 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -435,6 +435,22 @@ def CSR_AArch64_SVE_AAPCS : CalleeSavedRegs<(add (sequence "Z%u", 8, 23), X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, LR, FP)>; +// SME ABI support routines such as __arm_tpidr2_save/restore preserve most registers. +def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 + : CalleeSavedRegs<(add (sequence "Z%u", 0, 31), + (sequence "P%u", 0, 15), + (sequence "X%u", 0, 13), + (sequence "X%u",19, 28), + LR, FP)>; + +// SME ABI support routines __arm_sme_state preserves most registers. 
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 + : CalleeSavedRegs<(add (sequence "Z%u", 0, 31), + (sequence "P%u", 0, 15), + (sequence "X%u", 2, 15), + (sequence "X%u",19, 28), + LR, FP)>; + def CSR_AArch64_AAPCS_SwiftTail : CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 198d332..944c8ba 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4490,6 +4490,32 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) { return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask); } +SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain, + SMEAttrs Attrs, SDLoc DL, + EVT VT) const { + if (Attrs.hasStreamingInterfaceOrBody()) + return DAG.getConstant(1, DL, VT); + + if (Attrs.hasNonStreamingInterfaceAndBody()) + return DAG.getConstant(0, DL, VT); + + assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface"); + + SDValue Callee = DAG.getExternalSymbol("__arm_sme_state", + getPointerTy(DAG.getDataLayout())); + Type *Int64Ty = Type::getInt64Ty(*DAG.getContext()); + Type *RetTy = StructType::get(Int64Ty, Int64Ty); + TargetLowering::CallLoweringInfo CLI(DAG); + ArgListTy Args; + CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2, + RetTy, Callee, std::move(Args)); + std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI); + SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64); + return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0), + Mask); +} + SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const { unsigned IntNo = Op.getConstantOperandVal(1); @@ -4521,13 +4547,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, return DAG.getMergeValues({MS.getValue(0), 
MS.getValue(2)}, DL); } case Intrinsic::aarch64_sme_get_pstatesm: { - SDValue Chain = Op.getOperand(0); - SDValue MRS = DAG.getNode( - AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other), - Chain, DAG.getConstant(AArch64SysReg::SVCR, DL, MVT::i64)); - SDValue Mask = DAG.getConstant(/* PSTATE.SM */ 1, DL, MVT::i64); - SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, MRS, Mask); - return DAG.getMergeValues({And, Chain}, DL); + SDValue Chain = Op->getOperand(0); + SMEAttrs Attrs(DAG.getMachineFunction().getFunction()); + SDValue PStateSM = getPStateSM(DAG, Chain, Attrs, DL, Op.getValueType()); + return DAG.getMergeValues({PStateSM, Chain}, DL); } } } @@ -5834,6 +5857,8 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, return CC_AArch64_Win64_CFGuard_Check; case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: return CC_AArch64_AAPCS; } } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 82e0579..a5552ca 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -15,6 +15,7 @@ #define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H #include "AArch64.h" +#include "Utils/AArch64SMEAttributes.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/SelectionDAG.h" @@ -1158,6 +1159,11 @@ private: // This function does not handle predicate bitcasts. SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const; + // Returns the runtime value for PSTATE.SM. When the function is streaming- + // compatible, this generates a call to __arm_sme_state. 
+ SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs, + SDLoc DL, EVT VT) const; + bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1, LLT Ty2) const override; }; diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp index f92fcca..91b6d18 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -91,6 +91,18 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { return CSR_AArch64_AAVPCS_SaveList; if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall) return CSR_AArch64_SVE_AAPCS_SaveList; + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "only supported to improve calls to SME ACLE save/restore/disable-za " + "functions, and is not intended to be used beyond that scope."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "only supported to improve calls to SME ACLE __arm_sme_state " + "and is not intended to be used beyond that scope."); if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering() ->supportSwiftError() && MF->getFunction().getAttributes().hasAttrSomewhere( @@ -123,6 +135,18 @@ AArch64RegisterInfo::getDarwinCalleeSavedRegs(const MachineFunction *MF) const { if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall) report_fatal_error( "Calling convention SVE_VectorCall is unsupported on Darwin."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + report_fatal_error( + "Calling convention 
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "only supported to improve calls to SME ACLE save/restore/disable-za " + "functions, and is not intended to be used beyond that scope."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "only supported to improve calls to SME ACLE __arm_sme_state " + "and is not intended to be used beyond that scope."); if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS) return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR() ? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList @@ -193,6 +217,14 @@ AArch64RegisterInfo::getDarwinCallPreservedMask(const MachineFunction &MF, if (CC == CallingConv::AArch64_SVE_VectorCall) report_fatal_error( "Calling convention SVE_VectorCall is unsupported on Darwin."); + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "unsupported on Darwin."); + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + report_fatal_error( + "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "unsupported on Darwin."); if (CC == CallingConv::CFGuard_Check) report_fatal_error( "Calling convention CFGuard_Check is unsupported on Darwin."); @@ -230,6 +262,10 @@ AArch64RegisterInfo::getCallPreservedMask(const MachineFunction &MF, if (CC == CallingConv::AArch64_SVE_VectorCall) return SCS ? 
CSR_AArch64_SVE_AAPCS_SCS_RegMask : CSR_AArch64_SVE_AAPCS_RegMask; + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask; + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask; if (CC == CallingConv::CFGuard_Check) return CSR_Win_AArch64_CFGuard_Check_RegMask; if (MF.getSubtarget<AArch64Subtarget>().getTargetLowering() @@ -539,6 +575,8 @@ bool AArch64RegisterInfo::isArgumentRegister(const MachineFunction &MF, return HasReg(CC_AArch64_Win64_CFGuard_Check_ArgRegs, Reg); case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: return HasReg(CC_AArch64_AAPCS_ArgRegs, Reg); } } |