aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp 178
1 files changed, 121 insertions, 57 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a4c1e26..899baa9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8086,13 +8086,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
}
+// Emit a call to __arm_sme_save (IsSave) or __arm_sme_restore (!IsSave); returns the second result of LowerCallTo, which callers use as the updated chain.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+ SelectionDAG &DAG,
+ AArch64FunctionInfo *Info, SDLoc DL,
+ SDValue Chain, bool IsSave) {
+ MachineFunction &MF = DAG.getMachineFunction();
+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); // NOTE(review): presumably the same object as Info -- confirm against callers.
+ FuncInfo->setSMESaveBufferUsed(); // Flag the SME save buffer as used before emitting the call.
+ TargetLowering::ArgListTy Args;
+ Args.emplace_back(
+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
+ PointerType::getUnqual(*DAG.getContext())); // Single argument: the i64 copied from the save-buffer address register, passed with pointer type.
+
+ RTLIB::Libcall LC =
+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
+ SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
+ TLI.getPointerTy(DAG.getDataLayout()));
+ auto *RetTy = Type::getVoidTy(*DAG.getContext()); // The routine is called as returning void.
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+ TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
+ return TLI.LowerCallTo(CLI).second; // Only the second element of the returned pair is kept.
+}
+
+static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
+ const AArch64TargetLowering &TLI,
+ const AArch64RegisterInfo &TRI,
+ AArch64FunctionInfo &FuncInfo,
+ SelectionDAG &DAG) {
+ // Conditionally restore the lazy save using the RESTORE_ZA pseudo node, then reset TPIDR2_EL0 to 0; returns the resulting chain.
+ RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
+ TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj();
+ SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask(
+ DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC))); // Call-preserved mask for the restore routine's calling convention.
+ SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+ TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
+ SDValue TPIDR2_EL0 = DAG.getNode(
+ ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
+ DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); // Value obtained via the aarch64_sme_get_tpidr2 intrinsic, chained on the incoming Chain.
+ // Copy the address of the TPIDR2 block into X0 before 'calling' the
+ // RESTORE_ZA pseudo.
+ SDValue Glue; // Left default-constructed: no incoming glue is supplied to the copy.
+ SDValue TPIDR2Block = DAG.getFrameIndex(
+ TPIDR2.FrameIndex,
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+ Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue);
+ Chain =
+ DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+ {Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
+ RestoreRoutine, RegMask, Chain.getValue(1)}); // Last operand is result #1 of the CopyToReg above (presumably its glue output -- confirm).
+ // Finally reset the TPIDR2_EL0 register to 0.
+ Chain = DAG.getNode(
+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+ DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+ DAG.getConstant(0, DL, MVT::i64));
+ TPIDR2.Uses++; // Count this restore as one more use of the TPIDR2 object.
+ return Chain;
+}
+
SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
SelectionDAG &DAG) const {
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
SDValue Glue = Chain.getValue(1);
MachineFunction &MF = DAG.getMachineFunction();
- SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
+ auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>();
+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
+
+ SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs();
// The following conditions are true on entry to an exception handler:
// - PSTATE.SM is 0.
@@ -8107,14 +8170,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
// These mode changes are usually optimized away in catch blocks as they
// occur before the __cxa_begin_catch (which is a non-streaming function),
// but are necessary in some cases (such as for cleanups).
+ //
+ // Additionally, if the function has ZA or ZT0 state, we must restore it.
+ // [COND_]SMSTART SM
if (SMEFnAttrs.hasStreamingInterfaceOrBody())
- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
- /*Glue*/ Glue, AArch64SME::Always);
+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
+ /*Glue*/ Glue, AArch64SME::Always);
+ else if (SMEFnAttrs.hasStreamingCompatibleInterface())
+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
+ AArch64SME::IfCallerIsStreaming);
- if (SMEFnAttrs.hasStreamingCompatibleInterface())
- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
- AArch64SME::IfCallerIsStreaming);
+ if (getTM().useNewSMEABILowering())
+ return Chain;
+
+ if (SMEFnAttrs.hasAgnosticZAInterface()) {
+ // Restore full ZA
+ Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain,
+ /*IsSave=*/false);
+ } else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) {
+ // SMSTART ZA
+ Chain = DAG.getNode(
+ AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
+ DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32));
+
+ // Restore ZT0
+ if (SMEFnAttrs.hasZT0State()) {
+ SDValue ZT0FrameIndex =
+ getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG);
+ Chain =
+ DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+ {Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex});
+ }
+
+ // Restore ZA
+ if (SMEFnAttrs.hasZAState())
+ Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG);
+ }
return Chain;
}
@@ -9232,30 +9324,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
}
-// Emit a call to __arm_sme_save or __arm_sme_restore.
-static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
- SelectionDAG &DAG,
- AArch64FunctionInfo *Info, SDLoc DL,
- SDValue Chain, bool IsSave) {
- MachineFunction &MF = DAG.getMachineFunction();
- AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
- FuncInfo->setSMESaveBufferUsed();
- TargetLowering::ArgListTy Args;
- Args.emplace_back(
- DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
- PointerType::getUnqual(*DAG.getContext()));
-
- RTLIB::Libcall LC =
- IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
- SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
- TLI.getPointerTy(DAG.getDataLayout()));
- auto *RetTy = Type::getVoidTy(*DAG.getContext());
- TargetLowering::CallLoweringInfo CLI(DAG);
- CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
- TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
- return TLI.LowerCallTo(CLI).second;
-}
-
static AArch64SME::ToggleCondition
getSMToggleCondition(const SMECallAttrs &CallAttrs) {
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
@@ -10015,33 +10083,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
if (RequiresLazySave) {
- // Conditionally restore the lazy save using a pseudo node.
- RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
- SDValue RegMask = DAG.getRegisterMask(
- TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
- SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
- getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
- SDValue TPIDR2_EL0 = DAG.getNode(
- ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
- // Copy the address of the TPIDR2 block into X0 before 'calling' the
- // RESTORE_ZA pseudo.
- SDValue Glue;
- SDValue TPIDR2Block = DAG.getFrameIndex(
- TPIDR2.FrameIndex,
- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
- Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
- Result =
- DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
- {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
- RestoreRoutine, RegMask, Result.getValue(1)});
- // Finally reset the TPIDR2_EL0 register to 0.
- Result = DAG.getNode(
- ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64));
- TPIDR2.Uses++;
+ Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG);
} else if (RequiresSaveAllZA) {
Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
/*IsSave=*/false);
@@ -11736,6 +11778,28 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(
return DAG.getNode(ISD::AND, DL, VT, LHS, Shift);
}
+ // Check for sign bit test patterns that can use TST optimization.
+ // (SELECT_CC setlt, sign_extend_inreg, 0, tval, fval)
+ // -> TST %operand, sign_bit; CSEL
+ // (SELECT_CC setlt, sign_extend, 0, tval, fval)
+ // -> TST %operand, sign_bit; CSEL
+ if (CC == ISD::SETLT && RHSC && RHSC->isZero() && LHS.hasOneUse() &&
+ (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+ LHS.getOpcode() == ISD::SIGN_EXTEND)) {
+
+ uint64_t SignBitPos;
+ std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
+ EVT TestVT = LHS.getValueType();
+ SDValue SignBitConst = DAG.getConstant(1ULL << SignBitPos, DL, TestVT);
+ SDValue TST =
+ DAG.getNode(AArch64ISD::ANDS, DL, DAG.getVTList(TestVT, MVT::i32),
+ LHS, SignBitConst);
+
+ SDValue Flags = TST.getValue(1);
+ return DAG.getNode(AArch64ISD::CSEL, DL, TVal.getValueType(), TVal, FVal,
+ DAG.getConstant(AArch64CC::NE, DL, MVT::i32), Flags);
+ }
+
// Canonicalise absolute difference patterns:
// select_cc lhs, rhs, sub(lhs, rhs), sub(rhs, lhs), cc ->
// select_cc lhs, rhs, sub(lhs, rhs), neg(sub(lhs, rhs)), cc