diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 178 |
1 files changed, 121 insertions, 57 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a4c1e26..899baa9 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -8086,13 +8086,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); } +// Emit a call to __arm_sme_save or __arm_sme_restore. +static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, + SelectionDAG &DAG, + AArch64FunctionInfo *Info, SDLoc DL, + SDValue Chain, bool IsSave) { + MachineFunction &MF = DAG.getMachineFunction(); + AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); + FuncInfo->setSMESaveBufferUsed(); + TargetLowering::ArgListTy Args; + Args.emplace_back( + DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64), + PointerType::getUnqual(*DAG.getContext())); + + RTLIB::Libcall LC = + IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE; + SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC), + TLI.getPointerTy(DAG.getDataLayout())); + auto *RetTy = Type::getVoidTy(*DAG.getContext()); + TargetLowering::CallLoweringInfo CLI(DAG); + CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( + TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); + return TLI.LowerCallTo(CLI).second; +} + +static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL, + const AArch64TargetLowering &TLI, + const AArch64RegisterInfo &TRI, + AArch64FunctionInfo &FuncInfo, + SelectionDAG &DAG) { + // Conditionally restore the lazy save using a pseudo node. + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE; + TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj(); + SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask( + DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC))); + SDValue RestoreRoutine = DAG.getTargetExternalSymbol( + TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout())); + SDValue TPIDR2_EL0 = DAG.getNode( + ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain, + DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); + // Copy the address of the TPIDR2 block into X0 before 'calling' the + // RESTORE_ZA pseudo. + SDValue Glue; + SDValue TPIDR2Block = DAG.getFrameIndex( + TPIDR2.FrameIndex, + DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); + Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue); + Chain = + DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other, + {Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64), + RestoreRoutine, RegMask, Chain.getValue(1)}); + // Finally reset the TPIDR2_EL0 register to 0. + Chain = DAG.getNode( + ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, + DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i64)); + TPIDR2.Uses++; + return Chain; +} + SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL, SelectionDAG &DAG) const { assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value"); SDValue Glue = Chain.getValue(1); MachineFunction &MF = DAG.getMachineFunction(); - SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs(); + auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>(); + auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); + const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo(); + + SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs(); // The following conditions are true on entry to an exception handler: // - PSTATE.SM is 0. @@ -8107,14 +8170,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL, // These mode changes are usually optimized away in catch blocks as they // occur before the __cxa_begin_catch (which is a non-streaming function), // but are necessary in some cases (such as for cleanups). + // + // Additionally, if the function has ZA or ZT0 state, we must restore it. + // [COND_]SMSTART SM if (SMEFnAttrs.hasStreamingInterfaceOrBody()) - return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, - /*Glue*/ Glue, AArch64SME::Always); + Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, + /*Glue*/ Glue, AArch64SME::Always); + else if (SMEFnAttrs.hasStreamingCompatibleInterface()) + Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue, + AArch64SME::IfCallerIsStreaming); - if (SMEFnAttrs.hasStreamingCompatibleInterface()) - return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue, - AArch64SME::IfCallerIsStreaming); + if (getTM().useNewSMEABILowering()) + return Chain; + + if (SMEFnAttrs.hasAgnosticZAInterface()) { + // Restore full ZA + Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain, + /*IsSave=*/false); + } else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) { + // SMSTART ZA + Chain = DAG.getNode( + AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain, + DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32)); + + // Restore ZT0 + if (SMEFnAttrs.hasZT0State()) { + SDValue ZT0FrameIndex = + getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG); + Chain = + DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other), + {Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex}); + } + + // Restore ZA + if (SMEFnAttrs.hasZAState()) + Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG); + } return Chain; } @@ -9232,30 +9324,6 @@ SDValue AArch64TargetLowering::changeStreamingMode( return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1)); } -// Emit a call to __arm_sme_save or __arm_sme_restore. -static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, - SelectionDAG &DAG, - AArch64FunctionInfo *Info, SDLoc DL, - SDValue Chain, bool IsSave) { - MachineFunction &MF = DAG.getMachineFunction(); - AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); - FuncInfo->setSMESaveBufferUsed(); - TargetLowering::ArgListTy Args; - Args.emplace_back( - DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64), - PointerType::getUnqual(*DAG.getContext())); - - RTLIB::Libcall LC = - IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE; - SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC), - TLI.getPointerTy(DAG.getDataLayout())); - auto *RetTy = Type::getVoidTy(*DAG.getContext()); - TargetLowering::CallLoweringInfo CLI(DAG); - CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); - return TLI.LowerCallTo(CLI).second; -} - static AArch64SME::ToggleCondition getSMToggleCondition(const SMECallAttrs &CallAttrs) { if (!CallAttrs.caller().hasStreamingCompatibleInterface() || @@ -10015,33 +10083,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx}); if (RequiresLazySave) { - // Conditionally restore the lazy save using a pseudo node. - RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE; - TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); - SDValue RegMask = DAG.getRegisterMask( - TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC))); - SDValue RestoreRoutine = DAG.getTargetExternalSymbol( - getLibcallName(LC), getPointerTy(DAG.getDataLayout())); - SDValue TPIDR2_EL0 = DAG.getNode( - ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result, - DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); - // Copy the address of the TPIDR2 block into X0 before 'calling' the - // RESTORE_ZA pseudo. - SDValue Glue; - SDValue TPIDR2Block = DAG.getFrameIndex( - TPIDR2.FrameIndex, - DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); - Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue); - Result = - DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other, - {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64), - RestoreRoutine, RegMask, Result.getValue(1)}); - // Finally reset the TPIDR2_EL0 register to 0. - Result = DAG.getNode( - ISD::INTRINSIC_VOID, DL, MVT::Other, Result, - DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), - DAG.getConstant(0, DL, MVT::i64)); - TPIDR2.Uses++; + Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG); } else if (RequiresSaveAllZA) { Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result, /*IsSave=*/false); @@ -11736,6 +11778,28 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( return DAG.getNode(ISD::AND, DL, VT, LHS, Shift); } + // Check for sign bit test patterns that can use TST optimization. + // (SELECT_CC setlt, sign_extend_inreg, 0, tval, fval) + // -> TST %operand, sign_bit; CSEL + // (SELECT_CC setlt, sign_extend, 0, tval, fval) + // -> TST %operand, sign_bit; CSEL + if (CC == ISD::SETLT && RHSC && RHSC->isZero() && LHS.hasOneUse() && + (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG || + LHS.getOpcode() == ISD::SIGN_EXTEND)) { + + uint64_t SignBitPos; + std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS); + EVT TestVT = LHS.getValueType(); + SDValue SignBitConst = DAG.getConstant(1ULL << SignBitPos, DL, TestVT); + SDValue TST = + DAG.getNode(AArch64ISD::ANDS, DL, DAG.getVTList(TestVT, MVT::i32), + LHS, SignBitConst); + + SDValue Flags = TST.getValue(1); + return DAG.getNode(AArch64ISD::CSEL, DL, TVal.getValueType(), TVal, FVal, + DAG.getConstant(AArch64CC::NE, DL, MVT::i32), Flags); + } + // Canonicalise absolute difference patterns: // select_cc lhs, rhs, sub(lhs, rhs), sub(rhs, lhs), cc -> // select_cc lhs, rhs, sub(lhs, rhs), neg(sub(lhs, rhs)), cc |