Diffstat (limited to 'llvm/lib/Target/AArch64')
59 files changed, 3250 insertions, 1638 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h index 5496ebd..8d0ff41 100644 --- a/llvm/lib/Target/AArch64/AArch64.h +++ b/llvm/lib/Target/AArch64/AArch64.h @@ -60,6 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass(); FunctionPass *createAArch64CollectLOHPass(); FunctionPass *createSMEABIPass(); FunctionPass *createSMEPeepholeOptPass(); +FunctionPass *createMachineSMEABIPass(); ModulePass *createSVEIntrinsicOptsPass(); InstructionSelector * createAArch64InstructionSelector(const AArch64TargetMachine &, @@ -111,6 +112,7 @@ void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&); void initializeLDTLSCleanupPass(PassRegistry&); void initializeSMEABIPass(PassRegistry &); void initializeSMEPeepholeOptPass(PassRegistry &); +void initializeMachineSMEABIPass(PassRegistry &); void initializeSVEIntrinsicOptsPass(PassRegistry &); void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &); } // end namespace llvm diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp index e8d3161..1169f26 100644 --- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp @@ -316,6 +316,12 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType( ThunkArgTranslation::PointerIndirection}; }; + if (T->isHalfTy()) { + // Prefix with `llvm` since MSVC doesn't specify `_Float16` + Out << "__llvm_h__"; + return direct(T); + } + if (T->isFloatTy()) { Out << "f"; return direct(T); @@ -327,8 +333,8 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType( } if (T->isFloatingPointTy()) { - report_fatal_error( - "Only 32 and 64 bit floating points are supported for ARM64EC thunks"); + report_fatal_error("Only 16, 32, and 64 bit floating points are supported " + "for ARM64EC thunks"); } auto &DL = M->getDataLayout(); @@ -342,8 +348,16 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType( uint64_t ElementCnt = T->getArrayNumElements(); uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8; uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes; - if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) { - Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes; + if (ElementTy->isHalfTy() || ElementTy->isFloatTy() || + ElementTy->isDoubleTy()) { + if (ElementTy->isHalfTy()) + // Prefix with `llvm` since MSVC doesn't specify `_Float16` + Out << "__llvm_H__"; + else if (ElementTy->isFloatTy()) + Out << "F"; + else if (ElementTy->isDoubleTy()) + Out << "D"; + Out << TotalSizeBytes; if (Alignment.value() >= 16 && !Ret) Out << "a" << Alignment.value(); if (TotalSizeBytes <= 8) { @@ -355,8 +369,9 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType( return pointerIndirection(T); } } else if (T->isFloatingPointTy()) { - report_fatal_error("Only 32 and 64 bit floating points are supported for " - "ARM64EC thunks"); + report_fatal_error( + "Only 16, 32, and 64 bit floating points are supported " + "for ARM64EC thunks"); } } @@ -597,6 +612,14 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) { return Thunk; } +std::optional<std::string> getArm64ECMangledFunctionName(GlobalValue &GV) { + if (!GV.hasName()) { + GV.setName("__unnamed"); + } + + return llvm::getArm64ECMangledFunctionName(GV.getName()); +} + // Builds the "guest exit thunk", a helper to call a function which may or may // not be an exit thunk. 
(We optimistically assume non-dllimport function // declarations refer to functions defined in AArch64 code; if the linker @@ -608,7 +631,7 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) { getThunkType(F->getFunctionType(), F->getAttributes(), Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty, ArgTranslations); - auto MangledName = getArm64ECMangledFunctionName(F->getName().str()); + auto MangledName = getArm64ECMangledFunctionName(*F); assert(MangledName && "Can't guest exit to function that's already native"); std::string ThunkName = *MangledName; if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) { @@ -727,9 +750,6 @@ AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias, // Lower an indirect call with inline code. void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) { - assert(CB->getModule()->getTargetTriple().isOSWindows() && - "Only applicable for Windows targets"); - IRBuilder<> B(CB); Value *CalledOperand = CB->getCalledOperand(); @@ -790,7 +810,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { if (!F) continue; if (std::optional<std::string> MangledName = - getArm64ECMangledFunctionName(A.getName().str())) { + getArm64ECMangledFunctionName(A)) { F->addMetadata("arm64ec_unmangled_name", *MDNode::get(M->getContext(), MDString::get(M->getContext(), A.getName()))); @@ -807,7 +827,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { cast<GlobalValue>(F.getPersonalityFn()->stripPointerCasts()); if (PersFn->getValueType() && PersFn->getValueType()->isFunctionTy()) { if (std::optional<std::string> MangledName = - getArm64ECMangledFunctionName(PersFn->getName().str())) { + getArm64ECMangledFunctionName(*PersFn)) { PersFn->setName(MangledName.value()); } } @@ -821,7 +841,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { // Rename hybrid patchable functions and change callers to use a global // alias instead. if (std::optional<std::string> MangledName = - getArm64ECMangledFunctionName(F.getName().str())) { + getArm64ECMangledFunctionName(F)) { std::string OrigName(F.getName()); F.setName(MangledName.value() + HybridPatchableTargetSuffix); @@ -927,7 +947,7 @@ bool AArch64Arm64ECCallLowering::processFunction( // FIXME: Handle functions with weak linkage? 
if (!F.hasLocalLinkage() || F.hasAddressTaken()) { if (std::optional<std::string> MangledName = - getArm64ECMangledFunctionName(F.getName().str())) { + getArm64ECMangledFunctionName(F)) { F.addMetadata("arm64ec_unmangled_name", *MDNode::get(M->getContext(), MDString::get(M->getContext(), F.getName()))); diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index c52487a..da34430 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -1829,8 +1829,8 @@ void AArch64AsmPrinter::emitMOVK(Register Dest, uint64_t Imm, unsigned Shift) { void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) { Register DestReg = MI.getOperand(0).getReg(); - if (STI->hasZeroCycleZeroingFP() && !STI->hasZeroCycleZeroingFPWorkaround() && - STI->isNeonAvailable()) { + if (STI->hasZeroCycleZeroingFPR64() && + !STI->hasZeroCycleZeroingFPWorkaround() && STI->isNeonAvailable()) { // Convert H/S register to corresponding D register if (AArch64::H0 <= DestReg && DestReg <= AArch64::H31) DestReg = AArch64::D0 + (DestReg - AArch64::H0); @@ -2229,13 +2229,24 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) { if (BrTarget == AddrDisc) report_fatal_error("Branch target is signed with its own value"); - // If we are printing BLRA pseudo instruction, then x16 and x17 are - // implicit-def'ed by the MI and AddrDisc is not used as any other input, so - // try to save one MOV by setting MayUseAddrAsScratch. + // If we are printing BLRA pseudo, try to save one MOV by making use of the + // fact that x16 and x17 are described as clobbered by the MI instruction and + // AddrDisc is not used as any other input. + // + // Back in the day, emitPtrauthDiscriminator was restricted to only returning + // either x16 or x17, meaning the returned register is always among the + // implicit-def'ed registers of BLRA pseudo. Now this property can be violated + // if isX16X17Safer predicate is false, thus manually check if AddrDisc is + // among x16 and x17 to prevent clobbering unexpected registers. + // // Unlike BLRA, BRA pseudo is used to perform computed goto, and thus not // declared as clobbering x16/x17. + // + // FIXME: Make use of `killed` flags and register masks instead. + bool AddrDiscIsImplicitDef = + IsCall && (AddrDisc == AArch64::X16 || AddrDisc == AArch64::X17); Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17, - /*MayUseAddrAsScratch=*/IsCall); + AddrDiscIsImplicitDef); bool IsZeroDisc = DiscReg == AArch64::XZR; unsigned Opc; @@ -2862,7 +2873,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { MCInst TmpInst; TmpInst.setOpcode(AArch64::MOVIv16b_ns); TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg())); - TmpInst.addOperand(MCOperand::createImm(MI->getOperand(1).getImm())); + TmpInst.addOperand(MCOperand::createImm(0)); EmitToStreamer(*OutStreamer, TmpInst); return; } @@ -2968,8 +2979,15 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { // See the comments in emitPtrauthBranch. if (Callee == AddrDisc) report_fatal_error("Call target is signed with its own value"); + + // After isX16X17Safer predicate was introduced, emitPtrauthDiscriminator is + // no longer restricted to only reusing AddrDisc when it is X16 or X17 + // (which are implicit-def'ed by AUTH_TCRETURN pseudos), thus impose this + // restriction manually not to clobber an unexpected register. 
+ bool AddrDiscIsImplicitDef = + AddrDisc == AArch64::X16 || AddrDisc == AArch64::X17; Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, ScratchReg, - /*MayUseAddrAsScratch=*/true); + AddrDiscIsImplicitDef); const bool IsZero = DiscReg == AArch64::XZR; const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ}, diff --git a/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp b/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp index 3436dc9..137ff89 100644 --- a/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp +++ b/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp @@ -30,6 +30,14 @@ using namespace llvm; #define AARCH64_BRANCH_TARGETS_NAME "AArch64 Branch Targets" namespace { +// BTI HINT encoding: base (32) plus 'c' (2) and/or 'j' (4). +enum : unsigned { + BTIBase = 32, // Base immediate for BTI HINT + BTIC = 1u << 1, // 2 + BTIJ = 1u << 2, // 4 + BTIMask = BTIC | BTIJ, +}; + class AArch64BranchTargets : public MachineFunctionPass { public: static char ID; @@ -42,6 +50,7 @@ private: void addBTI(MachineBasicBlock &MBB, bool CouldCall, bool CouldJump, bool NeedsWinCFI); }; + } // end anonymous namespace char AArch64BranchTargets::ID = 0; @@ -62,9 +71,8 @@ bool AArch64BranchTargets::runOnMachineFunction(MachineFunction &MF) { if (!MF.getInfo<AArch64FunctionInfo>()->branchTargetEnforcement()) return false; - LLVM_DEBUG( - dbgs() << "********** AArch64 Branch Targets **********\n" - << "********** Function: " << MF.getName() << '\n'); + LLVM_DEBUG(dbgs() << "********** AArch64 Branch Targets **********\n" + << "********** Function: " << MF.getName() << '\n'); const Function &F = MF.getFunction(); // LLVM does not consider basic blocks which are the targets of jump tables @@ -103,6 +111,12 @@ bool AArch64BranchTargets::runOnMachineFunction(MachineFunction &MF) { JumpTableTargets.count(&MBB)) CouldJump = true; + if (MBB.isEHPad()) { + if (HasWinCFI && (MBB.isEHFuncletEntry() || MBB.isCleanupFuncletEntry())) + CouldCall = true; + else + CouldJump = true; + } if (CouldCall || CouldJump) { addBTI(MBB, CouldCall, CouldJump, HasWinCFI); MadeChange = true; @@ -130,7 +144,12 @@ void AArch64BranchTargets::addBTI(MachineBasicBlock &MBB, bool CouldCall, auto MBBI = MBB.begin(); - // Skip the meta instructions, those will be removed anyway. + // If the block starts with EH_LABEL(s), skip them first. + while (MBBI != MBB.end() && MBBI->isEHLabel()) { + ++MBBI; + } + + // Skip meta/CFI/etc. (and EMITBKEY) to reach the first executable insn. for (; MBBI != MBB.end() && (MBBI->isMetaInstruction() || MBBI->getOpcode() == AArch64::EMITBKEY); ++MBBI) @@ -138,16 +157,21 @@ void AArch64BranchTargets::addBTI(MachineBasicBlock &MBB, bool CouldCall, // SCTLR_EL1.BT[01] is set to 0 by default which means // PACI[AB]SP are implicitly BTI C so no BTI C instruction is needed there. - if (MBBI != MBB.end() && HintNum == 34 && + if (MBBI != MBB.end() && ((HintNum & BTIMask) == BTIC) && (MBBI->getOpcode() == AArch64::PACIASP || MBBI->getOpcode() == AArch64::PACIBSP)) return; - if (HasWinCFI && MBBI->getFlag(MachineInstr::FrameSetup)) { - BuildMI(MBB, MBB.begin(), MBB.findDebugLoc(MBB.begin()), - TII->get(AArch64::SEH_Nop)); + // Insert BTI exactly at the first executable instruction. + const DebugLoc DL = MBB.findDebugLoc(MBBI); + MachineInstr *BTI = BuildMI(MBB, MBBI, DL, TII->get(AArch64::HINT)) + .addImm(HintNum) + .getInstr(); + + // WinEH: put .seh_nop after BTI when the first real insn is FrameSetup. 
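For reference, the HINT immediate assembled from the BTIBase/BTIC/BTIJ enum added at the top of this file combines as sketched below; this is an illustration of the standard BTI encoding, not part of the change itself:

  // HintNum = BTIBase | (CouldCall ? BTIC : 0) | (CouldJump ? BTIJ : 0)
  //   32 -> BTI     (neither bit set)
  //   34 -> BTI c   (BTIBase | BTIC) -- the case the old "HintNum == 34" check matched
  //   36 -> BTI j   (BTIBase | BTIJ)
  //   38 -> BTI jc  (BTIBase | BTIC | BTIJ)
  // So "(HintNum & BTIMask) == BTIC" asks "is this a call-only (BTI c) landing pad?",
  // which is the only case PACI[AB]SP already covers.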
+ if (HasWinCFI && MBBI != MBB.end() && + MBBI->getFlag(MachineInstr::FrameSetup)) { + auto AfterBTI = std::next(MachineBasicBlock::iterator(BTI)); + BuildMI(MBB, AfterBTI, DL, TII->get(AArch64::SEH_Nop)); } - BuildMI(MBB, MBB.begin(), MBB.findDebugLoc(MBB.begin()), - TII->get(AArch64::HINT)) - .addImm(HintNum); } diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp b/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp index 787a1a8..cc46159 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp @@ -75,8 +75,10 @@ static bool finishStackBlock(SmallVectorImpl<CCValAssign> &PendingMembers, auto &It = PendingMembers[0]; CCAssignFn *AssignFn = TLI->CCAssignFnForCall(State.getCallingConv(), /*IsVarArg=*/false); + // FIXME: Get the correct original type. + Type *OrigTy = EVT(It.getValVT()).getTypeForEVT(State.getContext()); if (AssignFn(It.getValNo(), It.getValVT(), It.getValVT(), CCValAssign::Full, - ArgFlags, State)) + ArgFlags, OrigTy, State)) llvm_unreachable("Call operand has unhandled type"); // Return the flags to how they were before. diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.h b/llvm/lib/Target/AArch64/AArch64CallingConvention.h index 63185a9..7105fa6 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.h +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.h @@ -18,52 +18,63 @@ namespace llvm { bool CC_AArch64_AAPCS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, - CCState &State); + Type *OrigTy, CCState &State); bool CC_AArch64_Arm64EC_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_Arm64EC_Thunk(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_Arm64EC_Thunk_Native(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_DarwinPCS_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_DarwinPCS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_DarwinPCS_ILP32_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT, - CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + CCValAssign::LocInfo LocInfo, + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_Win64PCS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, - CCState &State); + Type *OrigTy, CCState &State); bool CC_AArch64_Win64_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_Win64_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_Arm64EC_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT, 
CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool CC_AArch64_GHC(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, - CCState &State); + Type *OrigTy, CCState &State); bool CC_AArch64_Preserve_None(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool RetCC_AArch64_AAPCS(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags, - CCState &State); + Type *OrigTy, CCState &State); bool RetCC_AArch64_Arm64EC_Thunk(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, CCState &State); + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, + CCState &State); bool RetCC_AArch64_Arm64EC_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo, - ISD::ArgFlagsTy ArgFlags, + ISD::ArgFlagsTy ArgFlags, Type *OrigTy, CCState &State); } // namespace llvm diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td index 99f0af5..5f499e5 100644 --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -351,9 +351,10 @@ def AArch64PostLegalizerLowering // Post-legalization combines which are primarily optimizations. def AArch64PostLegalizerCombiner : GICombiner<"AArch64PostLegalizerCombinerImpl", - [copy_prop, cast_of_cast_combines, buildvector_of_truncate, - integer_of_truncate, mutate_anyext_to_zext, - combines_for_extload, combine_indexed_load_store, sext_trunc_sextload, + [copy_prop, cast_of_cast_combines, + buildvector_of_truncate, integer_of_truncate, + mutate_anyext_to_zext, combines_for_extload, + combine_indexed_load_store, sext_trunc_sextload, hoist_logic_op_with_same_opcode_hands, redundant_and, xor_of_and_with_same_reg, extractvecelt_pairwise_add, redundant_or, @@ -367,5 +368,6 @@ def AArch64PostLegalizerCombiner select_to_minmax, or_to_bsp, combine_concat_vector, commute_constant_to_rhs, extract_vec_elt_combines, push_freeze_to_prevent_poison_from_propagating, - combine_mul_cmlt, combine_use_vector_truncate, extmultomull]> { + combine_mul_cmlt, combine_use_vector_truncate, + extmultomull, truncsat_combines]> { } diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp index 201bfe0..57dcd68 100644 --- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp +++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp @@ -92,8 +92,9 @@ private: bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); - MachineBasicBlock *expandRestoreZA(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI); + MachineBasicBlock * + expandCommitOrRestoreZASave(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); }; @@ -528,6 +529,11 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( UseRev = true; } break; + case AArch64::Destructive2xRegImmUnpred: + // EXT_ZZI_CONSTRUCTIVE Zd, Zs, Imm + // ==> MOVPRFX Zd Zs; EXT_ZZI Zd, Zd, Zs, Imm + std::tie(DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 1, 2); + break; default: llvm_unreachable("Unsupported Destructive Operand type"); } @@ -548,6 +554,7 @@ bool 
AArch64ExpandPseudo::expand_DestructiveOp( break; case AArch64::DestructiveUnaryPassthru: case AArch64::DestructiveBinaryImm: + case AArch64::Destructive2xRegImmUnpred: DOPRegIsUnique = true; break; case AArch64::DestructiveTernaryCommWithRev: @@ -674,6 +681,11 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( .add(MI.getOperand(SrcIdx)) .add(MI.getOperand(Src2Idx)); break; + case AArch64::Destructive2xRegImmUnpred: + DOP.addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState) + .add(MI.getOperand(SrcIdx)) + .add(MI.getOperand(Src2Idx)); + break; } if (PRFX) { @@ -979,10 +991,15 @@ bool AArch64ExpandPseudo::expandStoreSwiftAsyncContext( return true; } -MachineBasicBlock * -AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI) { +static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111; + +MachineBasicBlock *AArch64ExpandPseudo::expandCommitOrRestoreZASave( + MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) { MachineInstr &MI = *MBBI; + bool IsRestoreZA = MI.getOpcode() == AArch64::RestoreZAPseudo; + assert((MI.getOpcode() == AArch64::RestoreZAPseudo || + MI.getOpcode() == AArch64::CommitZASavePseudo) && + "Expected ZA commit or restore"); assert((std::next(MBBI) != MBB.end() || MI.getParent()->successors().begin() != MI.getParent()->successors().end()) && @@ -990,21 +1007,23 @@ AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB, // Compare TPIDR2_EL0 value against 0. DebugLoc DL = MI.getDebugLoc(); - MachineInstrBuilder Cbz = BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX)) - .add(MI.getOperand(0)); + MachineInstrBuilder Branch = + BuildMI(MBB, MBBI, DL, + TII->get(IsRestoreZA ? AArch64::CBZX : AArch64::CBNZX)) + .add(MI.getOperand(0)); // Split MBB and create two new blocks: // - MBB now contains all instructions before RestoreZAPseudo. - // - SMBB contains the RestoreZAPseudo instruction only. - // - EndBB contains all instructions after RestoreZAPseudo. + // - SMBB contains the [Commit|RestoreZA]Pseudo instruction only. + // - EndBB contains all instructions after [Commit|RestoreZA]Pseudo. MachineInstr &PrevMI = *std::prev(MBBI); MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true); MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end() ? *SMBB->successors().begin() : SMBB->splitAt(MI, /*UpdateLiveIns*/ true); - // Add the SMBB label to the TB[N]Z instruction & create a branch to EndBB. - Cbz.addMBB(SMBB); + // Add the SMBB label to the CB[N]Z instruction & create a branch to EndBB. + Branch.addMBB(SMBB); BuildMI(&MBB, DL, TII->get(AArch64::B)) .addMBB(EndBB); MBB.addSuccessor(EndBB); @@ -1012,11 +1031,30 @@ AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB, // Replace the pseudo with a call (BL). MachineInstrBuilder MIB = BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL)); - MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit); + // Copy operands (mainly the regmask) from the pseudo. for (unsigned I = 2; I < MI.getNumOperands(); ++I) MIB.add(MI.getOperand(I)); - BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB); + if (IsRestoreZA) { + // Mark the TPIDR2 block pointer (X0) as an implicit use. + MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit); + } else /*CommitZA*/ { + [[maybe_unused]] auto *TRI = + MBB.getParent()->getSubtarget().getRegisterInfo(); + // Clear TPIDR2_EL0. 
+ BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::MSR)) + .addImm(AArch64SysReg::TPIDR2_EL0) + .addReg(AArch64::XZR); + bool ZeroZA = MI.getOperand(1).getImm() != 0; + if (ZeroZA) { + assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!"); + BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::ZERO_M)) + .addImm(ZERO_ALL_ZA_MASK) + .addDef(AArch64::ZAB0, RegState::ImplicitDefine); + } + } + + BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB); MI.eraseFromParent(); return EndBB; } @@ -1236,14 +1274,20 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, .add(MI.getOperand(3)); transferImpOps(MI, I, I); } else { + unsigned RegState = + getRenamableRegState(MI.getOperand(1).isRenamable()) | + getKillRegState( + MI.getOperand(1).isKill() && + MI.getOperand(1).getReg() != MI.getOperand(2).getReg() && + MI.getOperand(1).getReg() != MI.getOperand(3).getReg()); BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::ORRv8i8 : AArch64::ORRv16i8)) .addReg(DstReg, RegState::Define | getRenamableRegState(MI.getOperand(0).isRenamable())) - .add(MI.getOperand(1)) - .add(MI.getOperand(1)); + .addReg(MI.getOperand(1).getReg(), RegState) + .addReg(MI.getOperand(1).getReg(), RegState); auto I2 = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BSLv8i8 @@ -1629,8 +1673,9 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, return expandCALL_BTI(MBB, MBBI); case AArch64::StoreSwiftAsyncContext: return expandStoreSwiftAsyncContext(MBB, MBBI); + case AArch64::CommitZASavePseudo: case AArch64::RestoreZAPseudo: { - auto *NewMBB = expandRestoreZA(MBB, MBBI); + auto *NewMBB = expandCommitOrRestoreZASave(MBB, MBBI); if (NewMBB != &MBB) NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated. return true; @@ -1641,6 +1686,8 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated. return true; } + case AArch64::InOutZAUsePseudo: + case AArch64::RequiresZASavePseudo: case AArch64::COALESCER_BARRIER_FPR16: case AArch64::COALESCER_BARRIER_FPR32: case AArch64::COALESCER_BARRIER_FPR64: diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp index 9d74bb5..cf34498 100644 --- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp +++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp @@ -267,7 +267,7 @@ private: private: CCAssignFn *CCAssignFnForCall(CallingConv::ID CC) const; bool processCallArgs(CallLoweringInfo &CLI, SmallVectorImpl<MVT> &ArgVTs, - unsigned &NumBytes); + SmallVectorImpl<Type *> &OrigTys, unsigned &NumBytes); bool finishCall(CallLoweringInfo &CLI, unsigned NumBytes); public: @@ -3011,11 +3011,13 @@ bool AArch64FastISel::fastLowerArguments() { bool AArch64FastISel::processCallArgs(CallLoweringInfo &CLI, SmallVectorImpl<MVT> &OutVTs, + SmallVectorImpl<Type *> &OrigTys, unsigned &NumBytes) { CallingConv::ID CC = CLI.CallConv; SmallVector<CCValAssign, 16> ArgLocs; CCState CCInfo(CC, false, *FuncInfo.MF, ArgLocs, *Context); - CCInfo.AnalyzeCallOperands(OutVTs, CLI.OutFlags, CCAssignFnForCall(CC)); + CCInfo.AnalyzeCallOperands(OutVTs, CLI.OutFlags, OrigTys, + CCAssignFnForCall(CC)); // Get a count of how many bytes are to be pushed on the stack. NumBytes = CCInfo.getStackSize(); @@ -3194,6 +3196,7 @@ bool AArch64FastISel::fastLowerCall(CallLoweringInfo &CLI) { // Set up the argument vectors. 
SmallVector<MVT, 16> OutVTs; + SmallVector<Type *, 16> OrigTys; OutVTs.reserve(CLI.OutVals.size()); for (auto *Val : CLI.OutVals) { @@ -3207,6 +3210,7 @@ bool AArch64FastISel::fastLowerCall(CallLoweringInfo &CLI) { return false; OutVTs.push_back(VT); + OrigTys.push_back(Val->getType()); } Address Addr; @@ -3222,7 +3226,7 @@ bool AArch64FastISel::fastLowerCall(CallLoweringInfo &CLI) { // Handle the arguments now that we've gotten them. unsigned NumBytes; - if (!processCallArgs(CLI, OutVTs, NumBytes)) + if (!processCallArgs(CLI, OutVTs, OrigTys, NumBytes)) return false; const AArch64RegisterInfo *RegInfo = Subtarget->getRegisterInfo(); @@ -3574,12 +3578,8 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) { Args.reserve(II->arg_size()); // Populate the argument list. - for (auto &Arg : II->args()) { - ArgListEntry Entry; - Entry.Val = Arg; - Entry.Ty = Arg->getType(); - Args.push_back(Entry); - } + for (auto &Arg : II->args()) + Args.emplace_back(Arg); CallLoweringInfo CLI; MCContext &Ctx = MF->getContext(); @@ -4870,12 +4870,8 @@ bool AArch64FastISel::selectFRem(const Instruction *I) { Args.reserve(I->getNumOperands()); // Populate the argument list. - for (auto &Arg : I->operands()) { - ArgListEntry Entry; - Entry.Val = Arg; - Entry.Ty = Arg->getType(); - Args.push_back(Entry); - } + for (auto &Arg : I->operands()) + Args.emplace_back(Arg); CallLoweringInfo CLI; MCContext &Ctx = MF->getContext(); diff --git a/llvm/lib/Target/AArch64/AArch64Features.td b/llvm/lib/Target/AArch64/AArch64Features.td index c1c1f0a..6904e09 100644 --- a/llvm/lib/Target/AArch64/AArch64Features.td +++ b/llvm/lib/Target/AArch64/AArch64Features.td @@ -621,25 +621,27 @@ def FeatureZCRegMoveGPR64 : SubtargetFeature<"zcm-gpr64", "HasZeroCycleRegMoveGP def FeatureZCRegMoveGPR32 : SubtargetFeature<"zcm-gpr32", "HasZeroCycleRegMoveGPR32", "true", "Has zero-cycle register moves for GPR32 registers">; +def FeatureZCRegMoveFPR128 : SubtargetFeature<"zcm-fpr128", "HasZeroCycleRegMoveFPR128", "true", + "Has zero-cycle register moves for FPR128 registers">; + def FeatureZCRegMoveFPR64 : SubtargetFeature<"zcm-fpr64", "HasZeroCycleRegMoveFPR64", "true", "Has zero-cycle register moves for FPR64 registers">; def FeatureZCRegMoveFPR32 : SubtargetFeature<"zcm-fpr32", "HasZeroCycleRegMoveFPR32", "true", "Has zero-cycle register moves for FPR32 registers">; -def FeatureZCZeroingGP : SubtargetFeature<"zcz-gp", "HasZeroCycleZeroingGP", "true", - "Has zero-cycle zeroing instructions for generic registers">; +def FeatureZCZeroingGPR64 : SubtargetFeature<"zcz-gpr64", "HasZeroCycleZeroingGPR64", "true", + "Has zero-cycle zeroing instructions for GPR64 registers">; + +def FeatureZCZeroingGPR32 : SubtargetFeature<"zcz-gpr32", "HasZeroCycleZeroingGPR32", "true", + "Has zero-cycle zeroing instructions for GPR32 registers">; // It is generally beneficial to rewrite "fmov s0, wzr" to "movi d0, #0". // as movi is more efficient across all cores. Newer cores can eliminate // fmovs early and there is no difference with movi, but this not true for // all implementations. 
-def FeatureNoZCZeroingFP : SubtargetFeature<"no-zcz-fp", "HasZeroCycleZeroingFP", "false", - "Has no zero-cycle zeroing instructions for FP registers">; - -def FeatureZCZeroing : SubtargetFeature<"zcz", "HasZeroCycleZeroing", "true", - "Has zero-cycle zeroing instructions", - [FeatureZCZeroingGP]>; +def FeatureNoZCZeroingFPR64 : SubtargetFeature<"no-zcz-fpr64", "HasZeroCycleZeroingFPR64", "false", + "Has no zero-cycle zeroing instructions for FPR64 registers">; /// ... but the floating-point version doesn't quite work in rare cases on older /// CPUs. @@ -730,9 +732,13 @@ def FeatureFuseArithmeticLogic : SubtargetFeature< "fuse-arith-logic", "HasFuseArithmeticLogic", "true", "CPU fuses arithmetic and logic operations">; -def FeatureFuseCCSelect : SubtargetFeature< - "fuse-csel", "HasFuseCCSelect", "true", - "CPU fuses conditional select operations">; +def FeatureFuseCmpCSel : SubtargetFeature< + "fuse-csel", "HasFuseCmpCSel", "true", + "CPU can fuse CMP and CSEL operations">; + +def FeatureFuseCmpCSet : SubtargetFeature< + "fuse-cset", "HasFuseCmpCSet", "true", + "CPU can fuse CMP and CSET operations">; def FeatureFuseCryptoEOR : SubtargetFeature< "fuse-crypto-eor", "HasFuseCryptoEOR", "true", diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index 885f2a9..7725fa4 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -338,7 +338,7 @@ static bool requiresSaveVG(const MachineFunction &MF); // Conservatively, returns true if the function is likely to have an SVE vectors // on the stack. This function is safe to be called before callee-saves or // object offsets have been determined. -static bool isLikelyToHaveSVEStack(MachineFunction &MF) { +static bool isLikelyToHaveSVEStack(const MachineFunction &MF) { auto *AFI = MF.getInfo<AArch64FunctionInfo>(); if (AFI->isSVECC()) return true; @@ -532,6 +532,7 @@ bool AArch64FrameLowering::canUseRedZone(const MachineFunction &MF) const { bool AArch64FrameLowering::hasFPImpl(const MachineFunction &MF) const { const MachineFrameInfo &MFI = MF.getFrameInfo(); const TargetRegisterInfo *RegInfo = MF.getSubtarget().getRegisterInfo(); + const AArch64FunctionInfo &AFI = *MF.getInfo<AArch64FunctionInfo>(); // Win64 EH requires a frame pointer if funclets are present, as the locals // are accessed off the frame pointer in both the parent function and the @@ -545,6 +546,29 @@ bool AArch64FrameLowering::hasFPImpl(const MachineFunction &MF) const { MFI.hasStackMap() || MFI.hasPatchPoint() || RegInfo->hasStackRealignment(MF)) return true; + + // If we: + // + // 1. Have streaming mode changes + // OR: + // 2. Have a streaming body with SVE stack objects + // + // Then the value of VG restored when unwinding to this function may not match + // the value of VG used to set up the stack. + // + // This is a problem as the CFA can be described with an expression of the + // form: CFA = SP + NumBytes + VG * NumScalableBytes. + // + // If the value of VG used in that expression does not match the value used to + // set up the stack, an incorrect address for the CFA will be computed, and + // unwinding will fail. + // + // We work around this issue by ensuring the frame-pointer can describe the + // CFA in either of these cases. 
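A worked example of the mismatch described above (the numbers are illustrative):

  // The prologue runs in streaming mode with VG = 4 (SVL = 256 bits) and describes the CFA as
  //   CFA = SP + 64 + VG * 32  =  SP + 64 + 4 * 32  =  SP + 192
  // If unwinding happens after a mode change has switched back to VG = 2, the unwinder evaluates
  //   SP + 64 + 2 * 32  =  SP + 128
  // which is 64 bytes away from the real CFA. Describing the CFA via the frame pointer
  // removes the VG term from the expression and so is immune to the mismatch.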
+ if (AFI.needsDwarfUnwindInfo(MF) && + ((requiresSaveVG(MF) || AFI.getSMEFnAttrs().hasStreamingBody()) && + (!AFI.hasCalculatedStackSizeSVE() || AFI.getStackSizeSVE() > 0))) + return true; // With large callframes around we may need to use FP to access the scavenging // emergency spillslot. // @@ -663,10 +687,6 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const { MachineFunction &MF = *MBB.getParent(); MachineFrameInfo &MFI = MF.getFrameInfo(); - AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); - SMEAttrs Attrs = AFI->getSMEFnAttrs(); - bool LocallyStreaming = - Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface(); const std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo(); if (CSI.empty()) @@ -680,14 +700,6 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations( assert(!Info.isSpilledToReg() && "Spilling to registers not implemented"); int64_t Offset = MFI.getObjectOffset(FrameIdx) - getOffsetOfLocalArea(); - - // The location of VG will be emitted before each streaming-mode change in - // the function. Only locally-streaming functions require emitting the - // non-streaming VG location here. - if ((LocallyStreaming && FrameIdx == AFI->getStreamingVGIdx()) || - (!LocallyStreaming && Info.getReg() == AArch64::VG)) - continue; - CFIBuilder.buildOffset(Info.getReg(), Offset); } } @@ -707,8 +719,16 @@ void AArch64FrameLowering::emitCalleeSavedSVELocations( AArch64FunctionInfo &AFI = *MF.getInfo<AArch64FunctionInfo>(); CFIInstBuilder CFIBuilder(MBB, MBBI, MachineInstr::FrameSetup); + std::optional<int64_t> IncomingVGOffsetFromDefCFA; + if (requiresSaveVG(MF)) { + auto IncomingVG = *find_if( + reverse(CSI), [](auto &Info) { return Info.getReg() == AArch64::VG; }); + IncomingVGOffsetFromDefCFA = + MFI.getObjectOffset(IncomingVG.getFrameIdx()) - getOffsetOfLocalArea(); + } + for (const auto &Info : CSI) { - if (!(MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector)) + if (MFI.getStackID(Info.getFrameIdx()) != TargetStackID::ScalableVector) continue; // Not all unwinders may know about SVE registers, so assume the lowest @@ -722,7 +742,8 @@ void AArch64FrameLowering::emitCalleeSavedSVELocations( StackOffset::getScalable(MFI.getObjectOffset(Info.getFrameIdx())) - StackOffset::getFixed(AFI.getCalleeSavedStackSize(MFI)); - CFIBuilder.insertCFIInst(createCFAOffset(TRI, Reg, Offset)); + CFIBuilder.insertCFIInst( + createCFAOffset(TRI, Reg, Offset, IncomingVGOffsetFromDefCFA)); } } @@ -783,9 +804,6 @@ static void emitCalleeSavedRestores(MachineBasicBlock &MBB, !static_cast<const AArch64RegisterInfo &>(TRI).regNeedsCFI(Reg, Reg)) continue; - if (!Info.isRestored()) - continue; - CFIBuilder.buildRestore(Info.getReg()); } } @@ -1465,34 +1483,35 @@ bool requiresGetVGCall(MachineFunction &MF) { static bool requiresSaveVG(const MachineFunction &MF) { const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); + if (!AFI->needsDwarfUnwindInfo(MF) || !AFI->hasStreamingModeChanges()) + return false; // For Darwin platforms we don't save VG for non-SVE functions, even if SME // is enabled with streaming mode changes. 
- if (!AFI->hasStreamingModeChanges()) - return false; auto &ST = MF.getSubtarget<AArch64Subtarget>(); if (ST.isTargetDarwin()) return ST.hasSVE(); return true; } -bool isVGInstruction(MachineBasicBlock::iterator MBBI) { +static bool matchLibcall(const TargetLowering &TLI, const MachineOperand &MO, + RTLIB::Libcall LC) { + return MO.isSymbol() && + StringRef(TLI.getLibcallName(LC)) == MO.getSymbolName(); +} + +bool isVGInstruction(MachineBasicBlock::iterator MBBI, + const TargetLowering &TLI) { unsigned Opc = MBBI->getOpcode(); - if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI || - Opc == AArch64::UBFMXri) + if (Opc == AArch64::CNTD_XPiI) return true; - if (requiresGetVGCall(*MBBI->getMF())) { - if (Opc == AArch64::ORRXrr) - return true; + if (!requiresGetVGCall(*MBBI->getMF())) + return false; - if (Opc == AArch64::BL) { - auto Op1 = MBBI->getOperand(0); - return Op1.isSymbol() && - (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg"); - } - } + if (Opc == AArch64::BL) + return matchLibcall(TLI, MBBI->getOperand(0), RTLIB::SMEABI_GET_CURRENT_VG); - return false; + return Opc == TargetOpcode::COPY; } // Convert callee-save register save/restore instruction to do stack pointer @@ -1507,13 +1526,14 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec( unsigned NewOpc; // If the function contains streaming mode changes, we expect instructions - // to calculate the value of VG before spilling. For locally-streaming - // functions, we need to do this for both the streaming and non-streaming - // vector length. Move past these instructions if necessary. + // to calculate the value of VG before spilling. Move past these instructions + // if necessary. MachineFunction &MF = *MBB.getParent(); - if (requiresSaveVG(MF)) - while (isVGInstruction(MBBI)) + if (requiresSaveVG(MF)) { + auto &TLI = *MF.getSubtarget().getTargetLowering(); + while (isVGInstruction(MBBI, TLI)) ++MBBI; + } switch (MBBI->getOpcode()) { default: @@ -2097,11 +2117,12 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF, // Move past the saves of the callee-saved registers, fixing up the offsets // and pre-inc if we decided to combine the callee-save and local stack // pointer bump above. + auto &TLI = *MF.getSubtarget().getTargetLowering(); while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) && !IsSVECalleeSave(MBBI)) { if (CombineSPBump && // Only fix-up frame-setup load/store instructions. 
- (!requiresSaveVG(MF) || !isVGInstruction(MBBI))) + (!requiresSaveVG(MF) || !isVGInstruction(MBBI, TLI))) fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(), NeedsWinCFI, &HasWinCFI); ++MBBI; @@ -3468,8 +3489,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const { MachineFunction &MF = *MBB.getParent(); + auto &TLI = *MF.getSubtarget<AArch64Subtarget>().getTargetLowering(); const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); - AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); bool NeedsWinCFI = needsWinCFI(MF); DebugLoc DL; SmallVector<RegPairInfo, 8> RegPairs; @@ -3538,59 +3559,44 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( } unsigned X0Scratch = AArch64::NoRegister; + auto RestoreX0 = make_scope_exit([&] { + if (X0Scratch != AArch64::NoRegister) + BuildMI(MBB, MI, DL, TII.get(TargetOpcode::COPY), AArch64::X0) + .addReg(X0Scratch) + .setMIFlag(MachineInstr::FrameSetup); + }); + if (Reg1 == AArch64::VG) { // Find an available register to store value of VG to. Reg1 = findScratchNonCalleeSaveRegister(&MBB, true); assert(Reg1 != AArch64::NoRegister); - SMEAttrs Attrs = AFI->getSMEFnAttrs(); - - if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() && - AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) { - // For locally-streaming functions, we need to store both the streaming - // & non-streaming VG. Spill the streaming value first. - BuildMI(MBB, MI, DL, TII.get(AArch64::RDSVLI_XI), Reg1) - .addImm(1) - .setMIFlag(MachineInstr::FrameSetup); - BuildMI(MBB, MI, DL, TII.get(AArch64::UBFMXri), Reg1) - .addReg(Reg1) - .addImm(3) - .addImm(63) - .setMIFlag(MachineInstr::FrameSetup); - - AFI->setStreamingVGIdx(RPI.FrameIdx); - } else if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) { + if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) { BuildMI(MBB, MI, DL, TII.get(AArch64::CNTD_XPiI), Reg1) .addImm(31) .addImm(1) .setMIFlag(MachineInstr::FrameSetup); - AFI->setVGIdx(RPI.FrameIdx); } else { const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>(); - if (llvm::any_of( - MBB.liveins(), - [&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) { - return STI.getRegisterInfo()->isSuperOrSubRegisterEq( - AArch64::X0, LiveIn.PhysReg); - })) + if (any_of(MBB.liveins(), + [&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) { + return STI.getRegisterInfo()->isSuperOrSubRegisterEq( + AArch64::X0, LiveIn.PhysReg); + })) { X0Scratch = Reg1; - - if (X0Scratch != AArch64::NoRegister) - BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), Reg1) - .addReg(AArch64::XZR) - .addReg(AArch64::X0, RegState::Undef) - .addReg(AArch64::X0, RegState::Implicit) + BuildMI(MBB, MI, DL, TII.get(TargetOpcode::COPY), X0Scratch) + .addReg(AArch64::X0) .setMIFlag(MachineInstr::FrameSetup); + } - const uint32_t *RegMask = TRI->getCallPreservedMask( - MF, - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1); + RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG; + const uint32_t *RegMask = + TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC)); BuildMI(MBB, MI, DL, TII.get(AArch64::BL)) - .addExternalSymbol("__arm_get_current_vg") + .addExternalSymbol(TLI.getLibcallName(LC)) .addRegMask(RegMask) .addReg(AArch64::X0, RegState::ImplicitDefine) .setMIFlag(MachineInstr::FrameSetup); Reg1 = AArch64::X0; - AFI->setVGIdx(RPI.FrameIdx); } } @@ -3685,13 +3691,6 @@ bool 
AArch64FrameLowering::spillCalleeSavedRegisters( if (RPI.isPaired()) MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector); } - - if (X0Scratch != AArch64::NoRegister) - BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), AArch64::X0) - .addReg(AArch64::XZR) - .addReg(X0Scratch, RegState::Undef) - .addReg(X0Scratch, RegState::Implicit) - .setMIFlag(MachineInstr::FrameSetup); } return true; } @@ -4070,15 +4069,8 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, // Increase the callee-saved stack size if the function has streaming mode // changes, as we will need to spill the value of the VG register. - // For locally streaming functions, we spill both the streaming and - // non-streaming VG value. - SMEAttrs Attrs = AFI->getSMEFnAttrs(); - if (requiresSaveVG(MF)) { - if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface()) - CSStackSize += 16; - else - CSStackSize += 8; - } + if (requiresSaveVG(MF)) + CSStackSize += 8; // Determine if a Hazard slot should be used, and increase the CSStackSize by // StackHazardSize if so. @@ -4229,29 +4221,13 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots( // Insert VG into the list of CSRs, immediately before LR if saved. if (requiresSaveVG(MF)) { - std::vector<CalleeSavedInfo> VGSaves; - SMEAttrs Attrs = AFI->getSMEFnAttrs(); - - auto VGInfo = CalleeSavedInfo(AArch64::VG); - VGInfo.setRestored(false); - VGSaves.push_back(VGInfo); - - // Add VG again if the function is locally-streaming, as we will spill two - // values. - if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface()) - VGSaves.push_back(VGInfo); - - bool InsertBeforeLR = false; - - for (unsigned I = 0; I < CSI.size(); I++) - if (CSI[I].getReg() == AArch64::LR) { - InsertBeforeLR = true; - CSI.insert(CSI.begin() + I, VGSaves.begin(), VGSaves.end()); - break; - } - - if (!InsertBeforeLR) - llvm::append_range(CSI, VGSaves); + CalleeSavedInfo VGInfo(AArch64::VG); + auto It = + find_if(CSI, [](auto &Info) { return Info.getReg() == AArch64::LR; }); + if (It != CSI.end()) + CSI.insert(It, VGInfo); + else + CSI.push_back(VGInfo); } Register LastReg = 0; @@ -5254,46 +5230,11 @@ MachineBasicBlock::iterator tryMergeAdjacentSTG(MachineBasicBlock::iterator II, } } // namespace -static void emitVGSaveRestore(MachineBasicBlock::iterator II, - const AArch64FrameLowering *TFI) { - MachineInstr &MI = *II; - MachineBasicBlock *MBB = MI.getParent(); - MachineFunction *MF = MBB->getParent(); - - if (MI.getOpcode() != AArch64::VGSavePseudo && - MI.getOpcode() != AArch64::VGRestorePseudo) - return; - - auto *AFI = MF->getInfo<AArch64FunctionInfo>(); - SMEAttrs FuncAttrs = AFI->getSMEFnAttrs(); - bool LocallyStreaming = - FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface(); - - int64_t VGFrameIdx = - LocallyStreaming ? 
AFI->getStreamingVGIdx() : AFI->getVGIdx(); - assert(VGFrameIdx != std::numeric_limits<int>::max() && - "Expected FrameIdx for VG"); - - CFIInstBuilder CFIBuilder(*MBB, II, MachineInstr::NoFlags); - if (MI.getOpcode() == AArch64::VGSavePseudo) { - const MachineFrameInfo &MFI = MF->getFrameInfo(); - int64_t Offset = - MFI.getObjectOffset(VGFrameIdx) - TFI->getOffsetOfLocalArea(); - CFIBuilder.buildOffset(AArch64::VG, Offset); - } else { - CFIBuilder.buildRestore(AArch64::VG); - } - - MI.eraseFromParent(); -} - void AArch64FrameLowering::processFunctionBeforeFrameIndicesReplaced( MachineFunction &MF, RegScavenger *RS = nullptr) const { for (auto &BB : MF) for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();) { - if (requiresSaveVG(MF)) - emitVGSaveRestore(II++, this); - else if (StackTaggingMergeSetTag) + if (StackTaggingMergeSetTag) II = tryMergeAdjacentSTG(II, this, RS); } diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index ad42f4b..6fdc981 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -246,9 +246,9 @@ public: return false; } - template<MVT::SimpleValueType VT> + template <MVT::SimpleValueType VT, bool Negate> bool SelectSVEAddSubImm(SDValue N, SDValue &Imm, SDValue &Shift) { - return SelectSVEAddSubImm(N, VT, Imm, Shift); + return SelectSVEAddSubImm(N, VT, Imm, Shift, Negate); } template <MVT::SimpleValueType VT, bool Negate> @@ -489,7 +489,8 @@ private: bool SelectCMP_SWAP(SDNode *N); - bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift); + bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift, + bool Negate); bool SelectSVEAddSubSSatImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift, bool Negate); bool SelectSVECpyDupImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift); @@ -4227,35 +4228,36 @@ bool AArch64DAGToDAGISel::SelectCMP_SWAP(SDNode *N) { } bool AArch64DAGToDAGISel::SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, - SDValue &Shift) { + SDValue &Shift, bool Negate) { if (!isa<ConstantSDNode>(N)) return false; SDLoc DL(N); - uint64_t Val = cast<ConstantSDNode>(N) - ->getAPIntValue() - .trunc(VT.getFixedSizeInBits()) - .getZExtValue(); + APInt Val = + cast<ConstantSDNode>(N)->getAPIntValue().trunc(VT.getFixedSizeInBits()); + + if (Negate) + Val = -Val; switch (VT.SimpleTy) { case MVT::i8: // All immediates are supported. Shift = CurDAG->getTargetConstant(0, DL, MVT::i32); - Imm = CurDAG->getTargetConstant(Val, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(Val.getZExtValue(), DL, MVT::i32); return true; case MVT::i16: case MVT::i32: case MVT::i64: // Support 8bit unsigned immediates. - if (Val <= 255) { + if ((Val & ~0xff) == 0) { Shift = CurDAG->getTargetConstant(0, DL, MVT::i32); - Imm = CurDAG->getTargetConstant(Val, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(Val.getZExtValue(), DL, MVT::i32); return true; } // Support 16bit unsigned immediates that are a multiple of 256. 
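A few example values for the immediate forms accepted above and below (illustrative only):

  // Val = 0x004b  ->  (Val & ~0xff)   == 0  : encoded as shift 0, imm 0x4b
  // Val = 0x1200  ->  (Val & ~0xff00) == 0  : encoded as shift 8, imm 0x12
  // Val = 0x1234  ->  fails both masks, so no add/sub immediate form is selected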
- if (Val <= 65280 && Val % 256 == 0) { + if ((Val & ~0xff00) == 0) { Shift = CurDAG->getTargetConstant(8, DL, MVT::i32); - Imm = CurDAG->getTargetConstant(Val >> 8, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(Val.lshr(8).getZExtValue(), DL, MVT::i32); return true; } break; @@ -7617,16 +7619,29 @@ bool AArch64DAGToDAGISel::SelectAnyPredicate(SDValue N) { bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize, SDValue &Base, SDValue &Offset, unsigned Scale) { - // Try to untangle an ADD node into a 'reg + offset' - if (CurDAG->isBaseWithConstantOffset(N)) - if (auto C = dyn_cast<ConstantSDNode>(N.getOperand(1))) { + auto MatchConstantOffset = [&](SDValue CN) -> SDValue { + if (auto *C = dyn_cast<ConstantSDNode>(CN)) { int64_t ImmOff = C->getSExtValue(); - if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0))) { - Base = N.getOperand(0); - Offset = CurDAG->getTargetConstant(ImmOff / Scale, SDLoc(N), MVT::i64); - return true; - } + if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0))) + return CurDAG->getTargetConstant(ImmOff / Scale, SDLoc(N), MVT::i64); + } + return SDValue(); + }; + + if (SDValue C = MatchConstantOffset(N)) { + Base = CurDAG->getConstant(0, SDLoc(N), MVT::i32); + Offset = C; + return true; + } + + // Try to untangle an ADD node into a 'reg + offset' + if (CurDAG->isBaseWithConstantOffset(N)) { + if (SDValue C = MatchConstantOffset(N.getOperand(1))) { + Base = N.getOperand(0); + Offset = C; + return true; } + } // By default, just match reg + 0. Base = N; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 2b6ea86..b7011e0 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -17,6 +17,7 @@ #include "AArch64PerfectShuffle.h" #include "AArch64RegisterInfo.h" #include "AArch64Subtarget.h" +#include "AArch64TargetMachine.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "Utils/AArch64BaseInfo.h" #include "Utils/AArch64SMEAttributes.h" @@ -1120,7 +1121,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal); setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal); - // We combine OR nodes for bitfield operations. + // We combine OR nodes for ccmp operations. setTargetDAGCombine(ISD::OR); // Try to create BICs for vector ANDs. 
setTargetDAGCombine(ISD::AND); @@ -1769,7 +1770,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); - if (Subtarget->hasSVEB16B16()) { + if (Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available()) { setOperationAction(ISD::FADD, VT, Legal); setOperationAction(ISD::FMA, VT, Custom); setOperationAction(ISD::FMAXIMUM, VT, Custom); @@ -1791,7 +1793,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32); } - if (!Subtarget->hasSVEB16B16()) { + if (!Subtarget->hasSVEB16B16() || + !Subtarget->isNonStreamingSVEorSME2Available()) { for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM, ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) { setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32); @@ -1998,6 +2001,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(Op, MVT::f16, Promote); } +const AArch64TargetMachine &AArch64TargetLowering::getTM() const { + return static_cast<const AArch64TargetMachine &>(getTargetMachine()); +} + void AArch64TargetLowering::addTypeForNEON(MVT VT) { assert(VT.isVector() && "VT should be a vector type"); @@ -2578,6 +2585,30 @@ void AArch64TargetLowering::computeKnownBitsForTargetNode( Known = Known.intersectWith(Known2); break; } + case AArch64ISD::CSNEG: + case AArch64ISD::CSINC: + case AArch64ISD::CSINV: { + KnownBits KnownOp0 = DAG.computeKnownBits(Op->getOperand(0), Depth + 1); + KnownBits KnownOp1 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1); + + // The result is either: + // CSINC: KnownOp0 or KnownOp1 + 1 + // CSINV: KnownOp0 or ~KnownOp1 + // CSNEG: KnownOp0 or KnownOp1 * -1 + if (Op.getOpcode() == AArch64ISD::CSINC) + KnownOp1 = KnownBits::add( + KnownOp1, + KnownBits::makeConstant(APInt(Op.getScalarValueSizeInBits(), 1))); + else if (Op.getOpcode() == AArch64ISD::CSINV) + std::swap(KnownOp1.Zero, KnownOp1.One); + else if (Op.getOpcode() == AArch64ISD::CSNEG) + KnownOp1 = + KnownBits::mul(KnownOp1, KnownBits::makeConstant(APInt::getAllOnes( + Op.getScalarValueSizeInBits()))); + + Known = KnownOp0.intersectWith(KnownOp1); + break; + } case AArch64ISD::BICi: { // Compute the bit cleared value. APInt Mask = @@ -2977,21 +3008,20 @@ AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI, AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>(); TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); if (TPIDR2.Uses > 0) { + // Note: This case just needs to do `SVL << 48`. It is not implemented as we + // generally don't support big-endian SVE/SME. + if (!Subtarget->isLittleEndian()) + reportFatalInternalError( + "TPIDR2 block initialization is not supported on big-endian targets"); + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); - // Store the buffer pointer to the TPIDR2 stack object. - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRXui)) + // Store buffer pointer and num_za_save_slices. + // Bytes 10-15 are implicitly zeroed. 
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STPXi)) .addReg(MI.getOperand(0).getReg()) + .addReg(MI.getOperand(1).getReg()) .addFrameIndex(TPIDR2.FrameIndex) .addImm(0); - // Set the reserved bytes (10-15) to zero - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRHHui)) - .addReg(AArch64::WZR) - .addFrameIndex(TPIDR2.FrameIndex) - .addImm(5); - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRWui)) - .addReg(AArch64::WZR) - .addFrameIndex(TPIDR2.FrameIndex) - .addImm(3); } else MFI.RemoveStackObject(TPIDR2.FrameIndex); @@ -3083,13 +3113,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>(); const TargetInstrInfo *TII = Subtarget->getInstrInfo(); if (FuncInfo->isSMESaveBufferUsed()) { + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) - .addExternalSymbol("__arm_sme_state_size") + .addExternalSymbol(getLibcallName(LC)) .addReg(AArch64::X0, RegState::ImplicitDefine) - .addRegMask(TRI->getCallPreservedMask( - *MF, CallingConv:: - AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1)); + .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC))); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg()) .addReg(AArch64::X0); @@ -3101,6 +3130,30 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, return BB; } +MachineBasicBlock * +AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI, + MachineBasicBlock *BB) const { + MachineFunction *MF = BB->getParent(); + AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>(); + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + Register ResultReg = MI.getOperand(0).getReg(); + if (FuncInfo->isPStateSMRegUsed()) { + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE; + const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) + .addExternalSymbol(getLibcallName(LC)) + .addReg(AArch64::X0, RegState::ImplicitDefine) + .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC))); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg) + .addReg(AArch64::X0); + } else { + assert(MI.getMF()->getRegInfo().use_empty(ResultReg) && + "Expected no users of the entry pstate.sm!"); + } + MI.eraseFromParent(); + return BB; +} + // Helper function to find the instruction that defined a virtual register. // If unable to find such instruction, returns nullptr. 
static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI, @@ -3216,6 +3269,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( return EmitAllocateSMESaveBuffer(MI, BB); case AArch64::GetSMESaveSize: return EmitGetSMESaveSize(MI, BB); + case AArch64::EntryPStateSM: + return EmitEntryPStateSM(MI, BB); case AArch64::F128CSEL: return EmitF128CSEL(MI, BB); case TargetOpcode::STATEPOINT: @@ -3320,7 +3375,8 @@ static bool isZerosVector(const SDNode *N) { /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64 /// CC -static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) { +static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC, + SDValue RHS = {}) { switch (CC) { default: llvm_unreachable("Unknown condition code!"); @@ -3331,9 +3387,9 @@ static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) { case ISD::SETGT: return AArch64CC::GT; case ISD::SETGE: - return AArch64CC::GE; + return (RHS && isNullConstant(RHS)) ? AArch64CC::PL : AArch64CC::GE; case ISD::SETLT: - return AArch64CC::LT; + return (RHS && isNullConstant(RHS)) ? AArch64CC::MI : AArch64CC::LT; case ISD::SETLE: return AArch64CC::LE; case ISD::SETUGT: @@ -3492,6 +3548,13 @@ bool isLegalCmpImmed(APInt C) { return isLegalArithImmed(C.abs().getZExtValue()); } +unsigned numberOfInstrToLoadImm(APInt C) { + uint64_t Imm = C.getZExtValue(); + SmallVector<AArch64_IMM::ImmInsnModel> Insn; + AArch64_IMM::expandMOVImm(Imm, 32, Insn); + return Insn.size(); +} + static bool isSafeSignedCMN(SDValue Op, SelectionDAG &DAG) { // 0 - INT_MIN sign wraps, so no signed wrap means cmn is safe. if (Op->getFlags().hasNoSignedWrap()) @@ -3782,7 +3845,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val, SDLoc DL(Val); // Determine OutCC and handle FP special case. if (isInteger) { - OutCC = changeIntCCToAArch64CC(CC); + OutCC = changeIntCCToAArch64CC(CC, RHS); } else { assert(LHS.getValueType().isFloatingPoint()); AArch64CC::CondCode ExtraCC; @@ -3961,6 +4024,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, // CC has already been adjusted. RHS = DAG.getConstant(0, DL, VT); } else if (!isLegalCmpImmed(C)) { + unsigned NumImmForC = numberOfInstrToLoadImm(C); // Constant does not fit, try adjusting it by one? switch (CC) { default: @@ -3969,43 +4033,49 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, case ISD::SETGE: if (!C.isMinSignedValue()) { APInt CMinusOne = C - 1; - if (isLegalCmpImmed(CMinusOne)) { + if (isLegalCmpImmed(CMinusOne) || + (NumImmForC > numberOfInstrToLoadImm(CMinusOne))) { CC = (CC == ISD::SETLT) ? ISD::SETLE : ISD::SETGT; RHS = DAG.getConstant(CMinusOne, DL, VT); } } break; case ISD::SETULT: - case ISD::SETUGE: - if (!C.isZero()) { - APInt CMinusOne = C - 1; - if (isLegalCmpImmed(CMinusOne)) { - CC = (CC == ISD::SETULT) ? ISD::SETULE : ISD::SETUGT; - RHS = DAG.getConstant(CMinusOne, DL, VT); - } + case ISD::SETUGE: { + // C is not 0 because it is a legal immediate. + assert(!C.isZero() && "C should not be zero here"); + APInt CMinusOne = C - 1; + if (isLegalCmpImmed(CMinusOne) || + (NumImmForC > numberOfInstrToLoadImm(CMinusOne))) { + CC = (CC == ISD::SETULT) ? 
ISD::SETULE : ISD::SETUGT; + RHS = DAG.getConstant(CMinusOne, DL, VT); } break; + } case ISD::SETLE: case ISD::SETGT: if (!C.isMaxSignedValue()) { APInt CPlusOne = C + 1; - if (isLegalCmpImmed(CPlusOne)) { + if (isLegalCmpImmed(CPlusOne) || + (NumImmForC > numberOfInstrToLoadImm(CPlusOne))) { CC = (CC == ISD::SETLE) ? ISD::SETLT : ISD::SETGE; RHS = DAG.getConstant(CPlusOne, DL, VT); } } break; case ISD::SETULE: - case ISD::SETUGT: + case ISD::SETUGT: { if (!C.isAllOnes()) { APInt CPlusOne = C + 1; - if (isLegalCmpImmed(CPlusOne)) { + if (isLegalCmpImmed(CPlusOne) || + (NumImmForC > numberOfInstrToLoadImm(CPlusOne))) { CC = (CC == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE; RHS = DAG.getConstant(CPlusOne, DL, VT); } } break; } + } } } @@ -4079,7 +4149,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, if (!Cmp) { Cmp = emitComparison(LHS, RHS, CC, DL, DAG); - AArch64CC = changeIntCCToAArch64CC(CC); + AArch64CC = changeIntCCToAArch64CC(CC, RHS); } AArch64cc = getCondCode(DAG, AArch64CC); return Cmp; @@ -4865,6 +4935,18 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op, if (DstWidth < SatWidth) return SDValue(); + if (SrcVT == MVT::f16 && SatVT == MVT::i16 && DstVT == MVT::i32) { + if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) { + SDValue CVTf32 = + DAG.getNode(AArch64ISD::FCVTZS_HALF, DL, MVT::f32, SrcVal); + SDValue Bitcast = DAG.getBitcast(DstVT, CVTf32); + return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, DstVT, Bitcast, + DAG.getValueType(SatVT)); + } + SDValue CVTf32 = DAG.getNode(AArch64ISD::FCVTZU_HALF, DL, MVT::f32, SrcVal); + return DAG.getBitcast(DstVT, CVTf32); + } + SDValue NativeCvt = DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal, DAG.getValueType(DstVT)); SDValue Sat; @@ -5174,13 +5256,7 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op, Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext()); ArgListTy Args; - ArgListEntry Entry; - - Entry.Node = Arg; - Entry.Ty = ArgTy; - Entry.IsSExt = false; - Entry.IsZExt = false; - Args.push_back(Entry); + Args.emplace_back(Arg, ArgTy); RTLIB::Libcall LC = ArgVT == MVT::f64 ? 
RTLIB::SINCOS_STRET_F64 : RTLIB::SINCOS_STRET_F32; @@ -5711,15 +5787,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) { SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL, EVT VT) const { - SDValue Callee = DAG.getExternalSymbol("__arm_sme_state", + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE; + SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC), getPointerTy(DAG.getDataLayout())); Type *Int64Ty = Type::getInt64Ty(*DAG.getContext()); Type *RetTy = StructType::get(Int64Ty, Int64Ty); TargetLowering::CallLoweringInfo CLI(DAG); ArgListTy Args; CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2, - RetTy, Callee, std::move(Args)); + getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI); SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64); return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0), @@ -7886,8 +7962,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments( else if (ActualMVT == MVT::i16) ValVT = MVT::i16; } - bool Res = - AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo); + bool Res = AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, + Ins[i].OrigTy, CCInfo); assert(!Res && "Call operand has unhandled type"); (void)Res; } @@ -8132,19 +8208,26 @@ SDValue AArch64TargetLowering::LowerFormalArguments( } assert((ArgLocs.size() + ExtraArgLocs) == Ins.size()); + if (Attrs.hasStreamingCompatibleInterface()) { + SDValue EntryPStateSM = + DAG.getNode(AArch64ISD::ENTRY_PSTATE_SM, DL, + DAG.getVTList(MVT::i64, MVT::Other), {Chain}); + + // Copy the value to a virtual register, and save that in FuncInfo. + Register EntryPStateSMReg = + MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); + Chain = DAG.getCopyToReg(EntryPStateSM.getValue(1), DL, EntryPStateSMReg, + EntryPStateSM); + FuncInfo->setPStateSMReg(EntryPStateSMReg); + } + // Insert the SMSTART if this is a locally streaming function and // make sure it is Glued to the last CopyFromReg value. if (IsLocallyStreaming) { - SDValue PStateSM; - if (Attrs.hasStreamingCompatibleInterface()) { - PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64); - Register Reg = MF.getRegInfo().createVirtualRegister( - getRegClassFor(PStateSM.getValueType().getSimpleVT())); - FuncInfo->setPStateSMReg(Reg); - Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM); + if (Attrs.hasStreamingCompatibleInterface()) Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, - AArch64SME::IfCallerIsNonStreaming, PStateSM); - } else + AArch64SME::IfCallerIsNonStreaming); + else Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, AArch64SME::Always); @@ -8244,53 +8327,57 @@ SDValue AArch64TargetLowering::LowerFormalArguments( if (Subtarget->hasCustomCallingConv()) Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF); - // Create a 16 Byte TPIDR2 object. The dynamic buffer - // will be expanded and stored in the static object later using a pseudonode. 
- if (Attrs.hasZAState()) { - TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); - TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false); - SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64, - DAG.getConstant(1, DL, MVT::i32)); - - SDValue Buffer; - if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) { - Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL, - DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL}); - } else { - SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL); - Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL, - DAG.getVTList(MVT::i64, MVT::Other), - {Chain, Size, DAG.getConstant(1, DL, MVT::i64)}); - MFI.CreateVariableSizedObject(Align(16), nullptr); - } - Chain = DAG.getNode( - AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other), - {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)}); - } else if (Attrs.hasAgnosticZAInterface()) { - // Call __arm_sme_state_size(). - SDValue BufferSize = - DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL, - DAG.getVTList(MVT::i64, MVT::Other), Chain); - Chain = BufferSize.getValue(1); - - SDValue Buffer; - if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) { - Buffer = - DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL, - DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize}); - } else { - // Allocate space dynamically. - Buffer = DAG.getNode( - ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other), - {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)}); - MFI.CreateVariableSizedObject(Align(16), nullptr); + if (!getTM().useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) { + // Old SME ABI lowering (deprecated): + // Create a 16 Byte TPIDR2 object. The dynamic buffer + // will be expanded and stored in the static object later using a + // pseudonode. + if (Attrs.hasZAState()) { + TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); + TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false); + SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64, + DAG.getConstant(1, DL, MVT::i32)); + SDValue Buffer; + if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) { + Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL, + DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL}); + } else { + SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL); + Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL, + DAG.getVTList(MVT::i64, MVT::Other), + {Chain, Size, DAG.getConstant(1, DL, MVT::i64)}); + MFI.CreateVariableSizedObject(Align(16), nullptr); + } + SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64, + DAG.getConstant(1, DL, MVT::i32)); + Chain = DAG.getNode( + AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other), + {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0), + /*Num save slices*/ NumZaSaveSlices}); + } else if (Attrs.hasAgnosticZAInterface()) { + // Call __arm_sme_state_size(). + SDValue BufferSize = + DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL, + DAG.getVTList(MVT::i64, MVT::Other), Chain); + Chain = BufferSize.getValue(1); + SDValue Buffer; + if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) { + Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL, + DAG.getVTList(MVT::i64, MVT::Other), + {Chain, BufferSize}); + } else { + // Allocate space dynamically. 
+ Buffer = DAG.getNode( + ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other), + {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)}); + MFI.CreateVariableSizedObject(Align(16), nullptr); + } + // Copy the value to a virtual register, and save that in FuncInfo. + Register BufferPtr = + MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); + FuncInfo->setSMESaveBufferAddr(BufferPtr); + Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer); } - - // Copy the value to a virtual register, and save that in FuncInfo. - Register BufferPtr = - MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); - FuncInfo->setSMESaveBufferAddr(BufferPtr); - Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer); } if (CallConv == CallingConv::PreserveNone) { @@ -8307,6 +8394,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments( } } + if (getTM().useNewSMEABILowering()) { + // Clear new ZT0 state. TODO: Move this to the SME ABI pass. + if (Attrs.isNewZT0()) + Chain = DAG.getNode( + ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, + DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32), + DAG.getTargetConstant(0, DL, MVT::i32)); + } + return Chain; } @@ -8537,7 +8633,7 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, if (IsCalleeWin64) { UseVarArgCC = true; } else { - UseVarArgCC = !Outs[i].IsFixed; + UseVarArgCC = ArgFlags.isVarArg(); } } @@ -8557,19 +8653,20 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, // FIXME: CCAssignFnForCall should be called once, for the call and not per // argument. This logic should exactly mirror LowerFormalArguments. CCAssignFn *AssignFn = TLI.CCAssignFnForCall(CalleeCC, UseVarArgCC); - bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo); + bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, + Outs[i].OrigTy, CCInfo); assert(!Res && "Call operand has unhandled type"); (void)Res; } } static SMECallAttrs -getSMECallAttrs(const Function &Caller, +getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI, const TargetLowering::CallLoweringInfo &CLI) { if (CLI.CB) - return SMECallAttrs(*CLI.CB); + return SMECallAttrs(*CLI.CB, &TLI); if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) - return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol())); + return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI)); return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal)); } @@ -8591,7 +8688,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // SME Streaming functions are not eligible for TCO as they may require // the streaming mode or ZA to be restored after returning from the call. 
- SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI); + SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState() || CallAttrs.caller().hasStreamingBody()) @@ -8834,8 +8931,7 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue, - unsigned Condition, - SDValue PStateSM) const { + unsigned Condition) const { MachineFunction &MF = DAG.getMachineFunction(); AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); FuncInfo->setHasStreamingModeChanges(true); @@ -8847,9 +8943,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL, SmallVector<SDValue> Ops = {Chain, MSROp}; unsigned Opcode; if (Condition != AArch64SME::Always) { + FuncInfo->setPStateSMRegUsed(true); + Register PStateReg = FuncInfo->getPStateSMReg(); + assert(PStateReg.isValid() && "PStateSM Register is invalid"); + SDValue PStateSM = + DAG.getCopyFromReg(Chain, DL, PStateReg, MVT::i64, InGlue); + // Use chain and glue from the CopyFromReg. + Ops[0] = PStateSM.getValue(1); + InGlue = PStateSM.getValue(2); SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64); Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP; - assert(PStateSM && "PStateSM should be defined"); Ops.push_back(ConditionOp); Ops.push_back(PStateSM); } else { @@ -8871,22 +8974,19 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, MachineFunction &MF = DAG.getMachineFunction(); AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); FuncInfo->setSMESaveBufferUsed(); - TargetLowering::ArgListTy Args; - TargetLowering::ArgListEntry Entry; - Entry.Ty = PointerType::getUnqual(*DAG.getContext()); - Entry.Node = - DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64); - Args.push_back(Entry); - - SDValue Callee = - DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore", - TLI.getPointerTy(DAG.getDataLayout())); + Args.emplace_back( + DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64), + PointerType::getUnqual(*DAG.getContext())); + + RTLIB::Libcall LC = + IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE; + SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC), + TLI.getPointerTy(DAG.getDataLayout())); auto *RetTy = Type::getVoidTy(*DAG.getContext()); TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy, - Callee, std::move(Args)); + TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); return TLI.LowerCallTo(CLI).second; } @@ -8982,7 +9082,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, unsigned NumArgs = Outs.size(); for (unsigned i = 0; i != NumArgs; ++i) { - if (!Outs[i].IsFixed && Outs[i].VT.isScalableVector()) + if (Outs[i].Flags.isVarArg() && Outs[i].VT.isScalableVector()) report_fatal_error("Passing SVE types to variadic functions is " "currently not supported"); } @@ -9014,14 +9114,28 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, CallConv = CallingConv::AArch64_SVE_VectorCall; } + // Determine whether we need any streaming mode changes. 
+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI); + bool UseNewSMEABILowering = getTM().useNewSMEABILowering(); + bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface(); + auto ZAMarkerNode = [&]() -> std::optional<unsigned> { + // TODO: Handle agnostic ZA functions. + if (!UseNewSMEABILowering || IsAgnosticZAFunction) + return std::nullopt; + if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State()) + return std::nullopt; + return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE + : AArch64ISD::INOUT_ZA_USE; + }(); + if (IsTailCall) { // Check if it's really possible to do a tail call. IsTailCall = isEligibleForTailCallOptimization(CLI); // A sibling call is one where we're under the usual C ABI and not planning // to change that but can still do a tail call: - if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail && - CallConv != CallingConv::SwiftTail) + if (!ZAMarkerNode && !TailCallOpt && IsTailCall && + CallConv != CallingConv::Tail && CallConv != CallingConv::SwiftTail) IsSibCall = true; if (IsTailCall) @@ -9073,9 +9187,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, assert(FPDiff % 16 == 0 && "unaligned stack on tail call"); } - // Determine whether we need any streaming mode changes. - SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI); - auto DescribeCallsite = [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & { R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '"; @@ -9089,22 +9200,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, return R; }; - bool RequiresLazySave = CallAttrs.requiresLazySave(); + bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave(); bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState(); if (RequiresLazySave) { - const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); - MachinePointerInfo MPI = - MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex); + TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); SDValue TPIDR2ObjAddr = DAG.getFrameIndex( TPIDR2.FrameIndex, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); - SDValue NumZaSaveSlicesAddr = - DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr, - DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType())); - SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64, - DAG.getConstant(1, DL, MVT::i32)); - Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr, - MPI, MVT::i16); Chain = DAG.getNode( ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), @@ -9124,15 +9226,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, /*IsSave=*/true); } - SDValue PStateSM; bool RequiresSMChange = CallAttrs.requiresSMChange(); if (RequiresSMChange) { - if (CallAttrs.caller().hasStreamingInterfaceOrBody()) - PStateSM = DAG.getConstant(1, DL, MVT::i64); - else if (CallAttrs.caller().hasNonStreamingInterface()) - PStateSM = DAG.getConstant(0, DL, MVT::i64); - else - PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64); OptimizationRemarkEmitter ORE(&MF.getFunction()); ORE.emit([&]() { auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition", @@ -9171,10 +9266,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain, DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32)); - // Adjust the stack pointer for the new arguments... 
+ // Adjust the stack pointer for the new arguments... and mark ZA uses. // These operations are automatically eliminated by the prolog/epilog pass - if (!IsSibCall) + assert((!IsSibCall || !ZAMarkerNode) && "ZA markers require CALLSEQ_START"); + if (!IsSibCall) { Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL); + if (ZAMarkerNode) { + // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to, simply + // using a chain can result in incorrect scheduling. The markers refer to + // the position just before the CALLSEQ_START (though occur after as + // CALLSEQ_START lacks in-glue). + Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other), + {Chain, Chain.getValue(1)}); + } + } SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP, getPointerTy(DAG.getDataLayout())); @@ -9441,17 +9546,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, SDValue InGlue; if (RequiresSMChange) { - if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) { - Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL, - DAG.getVTList(MVT::Other, MVT::Glue), Chain); - InGlue = Chain.getValue(1); - } - - SDValue NewChain = changeStreamingMode( - DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue, - getSMToggleCondition(CallAttrs), PStateSM); - Chain = NewChain.getValue(0); - InGlue = NewChain.getValue(1); + Chain = + changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(), + Chain, InGlue, getSMToggleCondition(CallAttrs)); + InGlue = Chain.getValue(1); } // Build a sequence of copy-to-reg nodes chained together with token chain @@ -9633,20 +9731,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, InGlue = Result.getValue(Result->getNumValues() - 1); if (RequiresSMChange) { - assert(PStateSM && "Expected a PStateSM to be set"); Result = changeStreamingMode( DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue, - getSMToggleCondition(CallAttrs), PStateSM); - - if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) { - InGlue = Result.getValue(1); - Result = - DAG.getNode(AArch64ISD::VG_RESTORE, DL, - DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue}); - } + getSMToggleCondition(CallAttrs)); } - if (CallAttrs.requiresEnablingZAAfterCall()) + if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()) // Unconditionally resume ZA. Result = DAG.getNode( AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result, @@ -9659,15 +9749,15 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (RequiresLazySave) { // Conditionally restore the lazy save using a pseudo node. + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE; TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); SDValue RegMask = DAG.getRegisterMask( - TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC))); SDValue RestoreRoutine = DAG.getTargetExternalSymbol( - "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout())); + getLibcallName(LC), getPointerTy(DAG.getDataLayout())); SDValue TPIDR2_EL0 = DAG.getNode( ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result, DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); - // Copy the address of the TPIDR2 block into X0 before 'calling' the // RESTORE_ZA pseudo. 
SDValue Glue; @@ -9679,7 +9769,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other, {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64), RestoreRoutine, RegMask, Result.getValue(1)}); - // Finally reset the TPIDR2_EL0 register to 0. Result = DAG.getNode( ISD::INTRINSIC_VOID, DL, MVT::Other, Result, @@ -9802,14 +9891,11 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, // Emit SMSTOP before returning from a locally streaming function SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs(); if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) { - if (FuncAttrs.hasStreamingCompatibleInterface()) { - Register Reg = FuncInfo->getPStateSMReg(); - assert(Reg.isValid() && "PStateSM Register is invalid"); - SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64); + if (FuncAttrs.hasStreamingCompatibleInterface()) Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(), - AArch64SME::IfCallerIsNonStreaming, PStateSM); - } else + AArch64SME::IfCallerIsNonStreaming); + else Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(), AArch64SME::Always); Glue = Chain.getValue(1); @@ -11390,13 +11476,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( // select_cc lhs, rhs, sub(rhs, lhs), sub(lhs, rhs), cc -> // select_cc lhs, rhs, neg(sub(lhs, rhs)), sub(lhs, rhs), cc // The second forms can be matched into subs+cneg. + // NOTE: Drop poison generating flags from the negated operand to avoid + // inadvertently propagating poison after the canonicalisation. if (TVal.getOpcode() == ISD::SUB && FVal.getOpcode() == ISD::SUB) { if (TVal.getOperand(0) == LHS && TVal.getOperand(1) == RHS && - FVal.getOperand(0) == RHS && FVal.getOperand(1) == LHS) + FVal.getOperand(0) == RHS && FVal.getOperand(1) == LHS) { + TVal->dropFlags(SDNodeFlags::PoisonGeneratingFlags); FVal = DAG.getNegative(TVal, DL, TVal.getValueType()); - else if (TVal.getOperand(0) == RHS && TVal.getOperand(1) == LHS && - FVal.getOperand(0) == LHS && FVal.getOperand(1) == RHS) + } else if (TVal.getOperand(0) == RHS && TVal.getOperand(1) == LHS && + FVal.getOperand(0) == LHS && FVal.getOperand(1) == RHS) { + FVal->dropFlags(SDNodeFlags::PoisonGeneratingFlags); TVal = DAG.getNegative(FVal, DL, FVal.getValueType()); + } } unsigned Opcode = AArch64ISD::CSEL; @@ -13477,7 +13568,7 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT, // Look for the first non-undef element. const int *FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; }); - // Benefit form APInt to handle overflow when calculating expected element. + // Benefit from APInt to handle overflow when calculating expected element. unsigned NumElts = VT.getVectorNumElements(); unsigned MaskBits = APInt(32, NumElts * 2).logBase2(); APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1, /*isSigned=*/false, @@ -13485,7 +13576,7 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT, // The following shuffle indices must be the successive elements after the // first real element. 
bool FoundWrongElt = std::any_of(FirstRealElt + 1, M.end(), [&](int Elt) { - return Elt != ExpectedElt++ && Elt != -1; + return Elt != ExpectedElt++ && Elt >= 0; }); if (FoundWrongElt) return false; @@ -14737,12 +14828,107 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) { return ResultSLI; } +static SDValue tryLowerToBSL(SDValue N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + assert(VT.isVector() && "Expected vector type in tryLowerToBSL\n"); + SDLoc DL(N); + const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); + + if (VT.isScalableVector() && !Subtarget.hasSVE2()) + return SDValue(); + + SDValue N0 = N->getOperand(0); + if (N0.getOpcode() != ISD::AND) + return SDValue(); + + SDValue N1 = N->getOperand(1); + if (N1.getOpcode() != ISD::AND) + return SDValue(); + + // InstCombine does (not (neg a)) => (add a -1). + // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c) + // Loop over all combinations of AND operands. + for (int i = 1; i >= 0; --i) { + for (int j = 1; j >= 0; --j) { + SDValue O0 = N0->getOperand(i); + SDValue O1 = N1->getOperand(j); + SDValue Sub, Add, SubSibling, AddSibling; + + // Find a SUB and an ADD operand, one from each AND. + if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) { + Sub = O0; + Add = O1; + SubSibling = N0->getOperand(1 - i); + AddSibling = N1->getOperand(1 - j); + } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) { + Add = O0; + Sub = O1; + AddSibling = N0->getOperand(1 - i); + SubSibling = N1->getOperand(1 - j); + } else + continue; + + if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode())) + continue; + + // Constant ones is always righthand operand of the Add. + if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode())) + continue; + + if (Sub.getOperand(1) != Add.getOperand(0)) + continue; + + return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling); + } + } + + // (or (and a b) (and (not a) c)) => (bsl a b c) + // We only have to look for constant vectors here since the general, variable + // case can be handled in TableGen. 
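  // For example (illustrative constants, not taken from this patch): with
  // VT = v4i32, N0 = (and a, splat(0x00FF00FF)) and
  // N1 = (and c, splat(0xFF00FF00)), the two splats are bitwise complements
  // after truncation to the element width, so (or N0, N1) picks each result
  // bit from either a or c and can be emitted as BSP(splat(0x00FF00FF), a, c).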
+ unsigned Bits = VT.getScalarSizeInBits(); + for (int i = 1; i >= 0; --i) + for (int j = 1; j >= 0; --j) { + APInt Val1, Val2; + + if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) && + ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) && + ~Val1.trunc(Bits) == Val2.trunc(Bits)) { + return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i), + N0->getOperand(1 - i), N1->getOperand(1 - j)); + } + BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i)); + BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j)); + if (!BVN0 || !BVN1) + continue; + + bool FoundMatch = true; + for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) { + ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k)); + ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k)); + if (!CN0 || !CN1 || + CN0->getAPIntValue().trunc(Bits) != + ~CN1->getAsAPIntVal().trunc(Bits)) { + FoundMatch = false; + break; + } + } + if (FoundMatch) + return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i), + N0->getOperand(1 - i), N1->getOperand(1 - j)); + } + + return SDValue(); +} + SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SelectionDAG &DAG) const { if (useSVEForFixedLengthVectorVT(Op.getValueType(), !Subtarget->isNeonAvailable())) return LowerToScalableOp(Op, DAG); + if (SDValue Res = tryLowerToBSL(Op, DAG)) + return Res; + // Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2)) if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG)) return Res; @@ -15772,6 +15958,7 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { isREVMask(M, EltSize, NumElts, 32) || isREVMask(M, EltSize, NumElts, 16) || isEXTMask(M, VT, DummyBool, DummyUnsigned) || + isSingletonEXTMask(M, VT, DummyUnsigned) || isTRNMask(M, NumElts, DummyUnsigned) || isUZPMask(M, NumElts, DummyUnsigned) || isZIPMask(M, NumElts, DummyUnsigned) || @@ -16284,9 +16471,8 @@ AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, Chain = SP.getValue(1); SP = DAG.getNode(ISD::SUB, DL, MVT::i64, SP, Size); if (Align) - SP = - DAG.getNode(ISD::AND, DL, VT, SP.getValue(0), - DAG.getSignedConstant(-(uint64_t)Align->value(), DL, VT)); + SP = DAG.getNode(ISD::AND, DL, VT, SP.getValue(0), + DAG.getSignedConstant(-Align->value(), DL, VT)); Chain = DAG.getCopyToReg(Chain, DL, AArch64::SP, SP); SDValue Ops[2] = {SP, Chain}; return DAG.getMergeValues(Ops, DL); @@ -16323,7 +16509,7 @@ AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SP = DAG.getNode(ISD::SUB, DL, MVT::i64, SP, Size); if (Align) SP = DAG.getNode(ISD::AND, DL, VT, SP.getValue(0), - DAG.getSignedConstant(-(uint64_t)Align->value(), DL, VT)); + DAG.getSignedConstant(-Align->value(), DL, VT)); Chain = DAG.getCopyToReg(Chain, DL, AArch64::SP, SP); Chain = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), DL); @@ -16351,7 +16537,7 @@ AArch64TargetLowering::LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SP = DAG.getNode(ISD::SUB, DL, MVT::i64, SP, Size); if (Align) SP = DAG.getNode(ISD::AND, DL, VT, SP.getValue(0), - DAG.getSignedConstant(-(uint64_t)Align->value(), DL, VT)); + DAG.getSignedConstant(-Align->value(), DL, VT)); // Set the real SP to the new value with a probing loop. 
Chain = DAG.getNode(AArch64ISD::PROBED_ALLOCA, DL, MVT::Other, Chain, SP); @@ -17254,7 +17440,7 @@ static Function *getStructuredStoreFunction(Module *M, unsigned Factor, /// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1 bool AArch64TargetLowering::lowerInterleavedLoad( Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles, - ArrayRef<unsigned> Indices, unsigned Factor) const { + ArrayRef<unsigned> Indices, unsigned Factor, const APInt &GapMask) const { assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && "Invalid interleave factor"); assert(!Shuffles.empty() && "Empty shufflevector input"); @@ -17264,7 +17450,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad( auto *LI = dyn_cast<LoadInst>(Load); if (!LI) return false; - assert(!Mask && "Unexpected mask on a load"); + assert(!Mask && GapMask.popcount() == Factor && "Unexpected mask on a load"); const DataLayout &DL = LI->getDataLayout(); @@ -17442,14 +17628,16 @@ bool hasNearbyPairedStore(Iter It, Iter End, Value *Ptr, const DataLayout &DL) { bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store, Value *LaneMask, ShuffleVectorInst *SVI, - unsigned Factor) const { + unsigned Factor, + const APInt &GapMask) const { assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && "Invalid interleave factor"); auto *SI = dyn_cast<StoreInst>(Store); if (!SI) return false; - assert(!LaneMask && "Unexpected mask on store"); + assert(!LaneMask && GapMask.popcount() == Factor && + "Unexpected mask on store"); auto *VecTy = cast<FixedVectorType>(SVI->getType()); assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store"); @@ -17987,7 +18175,8 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd( case MVT::f64: return true; case MVT::bf16: - return VT.isScalableVector() && Subtarget->hasSVEB16B16(); + return VT.isScalableVector() && Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available(); default: break; } @@ -18151,7 +18340,7 @@ bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, BitSize)) return true; - if ((int64_t)Val < 0) + if (Val < 0) Val = ~Val; if (BitSize == 32) Val &= (1LL << 32) - 1; @@ -19414,106 +19603,6 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG, return FixConv; } -static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, - const AArch64TargetLowering &TLI) { - EVT VT = N->getValueType(0); - SelectionDAG &DAG = DCI.DAG; - SDLoc DL(N); - const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); - - if (!VT.isVector()) - return SDValue(); - - if (VT.isScalableVector() && !Subtarget.hasSVE2()) - return SDValue(); - - if (VT.isFixedLengthVector() && - (!Subtarget.isNeonAvailable() || TLI.useSVEForFixedLengthVectorVT(VT))) - return SDValue(); - - SDValue N0 = N->getOperand(0); - if (N0.getOpcode() != ISD::AND) - return SDValue(); - - SDValue N1 = N->getOperand(1); - if (N1.getOpcode() != ISD::AND) - return SDValue(); - - // InstCombine does (not (neg a)) => (add a -1). - // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c) - // Loop over all combinations of AND operands. - for (int i = 1; i >= 0; --i) { - for (int j = 1; j >= 0; --j) { - SDValue O0 = N0->getOperand(i); - SDValue O1 = N1->getOperand(j); - SDValue Sub, Add, SubSibling, AddSibling; - - // Find a SUB and an ADD operand, one from each AND. 
- if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) { - Sub = O0; - Add = O1; - SubSibling = N0->getOperand(1 - i); - AddSibling = N1->getOperand(1 - j); - } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) { - Add = O0; - Sub = O1; - AddSibling = N0->getOperand(1 - i); - SubSibling = N1->getOperand(1 - j); - } else - continue; - - if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode())) - continue; - - // Constant ones is always righthand operand of the Add. - if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode())) - continue; - - if (Sub.getOperand(1) != Add.getOperand(0)) - continue; - - return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling); - } - } - - // (or (and a b) (and (not a) c)) => (bsl a b c) - // We only have to look for constant vectors here since the general, variable - // case can be handled in TableGen. - unsigned Bits = VT.getScalarSizeInBits(); - uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1); - for (int i = 1; i >= 0; --i) - for (int j = 1; j >= 0; --j) { - APInt Val1, Val2; - - if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) && - ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) && - (BitMask & ~Val1.getZExtValue()) == Val2.getZExtValue()) { - return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i), - N0->getOperand(1 - i), N1->getOperand(1 - j)); - } - BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i)); - BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j)); - if (!BVN0 || !BVN1) - continue; - - bool FoundMatch = true; - for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) { - ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k)); - ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k)); - if (!CN0 || !CN1 || - CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) { - FoundMatch = false; - break; - } - } - if (FoundMatch) - return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i), - N0->getOperand(1 - i), N1->getOperand(1 - j)); - } - - return SDValue(); -} - // Given a tree of and/or(csel(0, 1, cc0), csel(0, 1, cc1)), we may be able to // convert to csel(ccmp(.., cc0)), depending on cc1: @@ -19595,17 +19684,10 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget, const AArch64TargetLowering &TLI) { SelectionDAG &DAG = DCI.DAG; - EVT VT = N->getValueType(0); if (SDValue R = performANDORCSELCombine(N, DAG)) return R; - if (!DAG.getTargetLoweringInfo().isTypeLegal(VT)) - return SDValue(); - - if (SDValue Res = tryCombineToBSL(N, DCI, TLI)) - return Res; - return SDValue(); } @@ -22107,6 +22189,17 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc, Zero); } +static SDValue tryCombineNeonFcvtFP16ToI16(SDNode *N, unsigned Opcode, + SelectionDAG &DAG) { + if (N->getValueType(0) != MVT::i16) + return SDValue(); + + SDLoc DL(N); + SDValue CVT = DAG.getNode(Opcode, DL, MVT::f32, N->getOperand(1)); + SDValue Bitcast = DAG.getBitcast(MVT::i32, CVT); + return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Bitcast); +} + // If a merged operation has no inactive lanes we can relax it to a predicated // or unpredicated operation, which potentially allows better isel (perhaps // using immediate forms) or relaxing register reuse requirements. 
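A minimal source-level sketch of where the i16-typed conversion intrinsics handled by tryCombineNeonFcvtFP16ToI16 can come from; the ACLE intrinsic name, the float16_t type, and the assumption that the front end lowers it to an i16-returning aarch64.neon.fcvtzs call are illustrative, not details taken from this patch (an FP16-capable target such as armv8.2-a+fp16 is assumed):

#include <arm_neon.h>

// Scalar half-precision -> int16_t conversion. Assuming the front end emits
// an i16-typed aarch64.neon.fcvtzs call here, the combine above rewrites it
// into an FCVTZS_HALF node producing f32, a bitcast to i32, and a truncate
// to i16, so a single FP-register fcvtzs is selected rather than a GPR
// round trip.
static inline int16_t half_to_i16(float16_t x) { return vcvth_s16_f16(x); }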
@@ -22360,6 +22453,26 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::aarch64_neon_uabd: return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_neon_fcvtzs: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTZS_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtzu: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTZU_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtas: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTAS_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtau: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTAU_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtms: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTMS_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtmu: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTMU_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtns: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTNS_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtnu: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTNU_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtps: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTPS_HALF, DAG); + case Intrinsic::aarch64_neon_fcvtpu: + return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTPU_HALF, DAG); case Intrinsic::aarch64_crc32b: case Intrinsic::aarch64_crc32cb: return tryCombineCRC32(0xff, N, DAG); @@ -25450,6 +25563,29 @@ static SDValue performCSELCombine(SDNode *N, } } + // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't + // use overflow flags, to avoid the comparison with zero. In case of success, + // this also replaces the original SUB(x,y) with the newly created SUBS(x,y). + // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB + // nodes with their SUBS equivalent as is already done for other flag-setting + // operators, in which case doing the replacement here becomes redundant. + if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) && + isNullConstant(Cond.getOperand(1))) { + SDValue Sub = Cond.getOperand(0); + AArch64CC::CondCode CC = + static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2)); + if (Sub.getOpcode() == ISD::SUB && + (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI || + CC == AArch64CC::PL)) { + SDLoc DL(N); + SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(), + Sub.getOperand(0), Sub.getOperand(1)); + DCI.CombineTo(Sub.getNode(), Subs); + DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1)); + return SDValue(N, 0); + } + } + // CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z if (SDValue CondLast = foldCSELofLASTB(N, DAG)) return CondLast; @@ -28166,6 +28302,7 @@ void AArch64TargetLowering::ReplaceNodeResults( case Intrinsic::aarch64_sme_in_streaming_mode: { SDLoc DL(N); SDValue Chain = DAG.getEntryNode(); + SDValue RuntimePStateSM = getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0)); Results.push_back( @@ -28609,14 +28746,20 @@ Value *AArch64TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const { void AArch64TargetLowering::insertSSPDeclarations(Module &M) const { // MSVC CRT provides functionalities for stack protection. 
- if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) { + RTLIB::LibcallImpl SecurityCheckCookieLibcall = + getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE); + + RTLIB::LibcallImpl SecurityCookieVar = + getLibcallImpl(RTLIB::STACK_CHECK_GUARD); + if (SecurityCheckCookieLibcall != RTLIB::Unsupported && + SecurityCookieVar != RTLIB::Unsupported) { // MSVC CRT has a global variable holding security cookie. - M.getOrInsertGlobal("__security_cookie", + M.getOrInsertGlobal(getLibcallImplName(SecurityCookieVar), PointerType::getUnqual(M.getContext())); // MSVC CRT has a function to validate security cookie. FunctionCallee SecurityCheckCookie = - M.getOrInsertFunction(Subtarget->getSecurityCheckCookieName(), + M.getOrInsertFunction(getLibcallImplName(SecurityCheckCookieLibcall), Type::getVoidTy(M.getContext()), PointerType::getUnqual(M.getContext())); if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) { @@ -28628,17 +28771,12 @@ void AArch64TargetLowering::insertSSPDeclarations(Module &M) const { TargetLowering::insertSSPDeclarations(M); } -Value *AArch64TargetLowering::getSDagStackGuard(const Module &M) const { - // MSVC CRT has a global variable holding security cookie. - if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) - return M.getGlobalVariable("__security_cookie"); - return TargetLowering::getSDagStackGuard(M); -} - Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const { // MSVC CRT has a function to validate security cookie. - if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) - return M.getFunction(Subtarget->getSecurityCheckCookieName()); + RTLIB::LibcallImpl SecurityCheckCookieLibcall = + getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE); + if (SecurityCheckCookieLibcall != RTLIB::Unsupported) + return M.getFunction(getLibcallImplName(SecurityCheckCookieLibcall)); return TargetLowering::getSSPStackGuardCheck(M); } @@ -28972,7 +29110,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { // Checks to allow the use of SME instructions if (auto *Base = dyn_cast<CallBase>(&Inst)) { - auto CallAttrs = SMECallAttrs(*Base); + auto CallAttrs = SMECallAttrs(*Base, this); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index ea63edd8..4673836 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -23,6 +23,8 @@ namespace llvm { +class AArch64TargetMachine; + namespace AArch64 { /// Possible values of current rounding mode, which is specified in bits /// 23:22 of FPCR. @@ -64,6 +66,8 @@ public: explicit AArch64TargetLowering(const TargetMachine &TM, const AArch64Subtarget &STI); + const AArch64TargetMachine &getTM() const; + /// Control the following reassociation of operands: (op (op x, c1), y) -> (op /// (op x, y), c1) where N0 is (op x, c1) and N1 is y. bool isReassocProfitable(SelectionDAG &DAG, SDValue N0, @@ -173,6 +177,10 @@ public: MachineBasicBlock *EmitZTInstr(MachineInstr &MI, MachineBasicBlock *BB, unsigned Opcode, bool Op0IsDef) const; MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const; + + // Note: The following group of functions are only used as part of the old SME + // ABI lowering. They will be removed once -aarch64-new-sme-abi=true is the + // default. 
MachineBasicBlock *EmitInitTPIDR2Object(MachineInstr &MI, MachineBasicBlock *BB) const; MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI, @@ -181,6 +189,8 @@ public: MachineBasicBlock *BB) const; MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI, MachineBasicBlock *BB) const; + MachineBasicBlock *EmitEntryPStateSM(MachineInstr &MI, + MachineBasicBlock *BB) const; /// Replace (0, vreg) discriminator components with the operands of blend /// or with (immediate, NoRegister) when possible. @@ -220,11 +230,11 @@ public: bool lowerInterleavedLoad(Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles, - ArrayRef<unsigned> Indices, - unsigned Factor) const override; + ArrayRef<unsigned> Indices, unsigned Factor, + const APInt &GapMask) const override; bool lowerInterleavedStore(Instruction *Store, Value *Mask, - ShuffleVectorInst *SVI, - unsigned Factor) const override; + ShuffleVectorInst *SVI, unsigned Factor, + const APInt &GapMask) const override; bool lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask, IntrinsicInst *DI) const override; @@ -344,7 +354,6 @@ public: Value *getIRStackGuard(IRBuilderBase &IRB) const override; void insertSSPDeclarations(Module &M) const override; - Value *getSDagStackGuard(const Module &M) const override; Function *getSSPStackGuardCheck(const Module &M) const override; /// If the target has a standard location for the unsafe stack pointer, @@ -523,8 +532,8 @@ public: /// node. \p Condition should be one of the enum values from /// AArch64SME::ToggleCondition. SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable, - SDValue Chain, SDValue InGlue, unsigned Condition, - SDValue PStateSM = SDValue()) const; + SDValue Chain, SDValue InGlue, + unsigned Condition) const; bool isVScaleKnownToBeAPowerOfTwo() const override { return true; } @@ -887,6 +896,10 @@ private: bool shouldScalarizeBinop(SDValue VecOp) const override { return VecOp.getOpcode() == ISD::SETCC; } + + bool hasMultipleConditionRegisters(EVT VT) const override { + return VT.isScalableVector(); + } }; namespace AArch64 { diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index ba7cbcc..feff590 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -36,7 +36,12 @@ def DestructiveBinary : DestructiveInstTypeEnum<5>; def DestructiveBinaryComm : DestructiveInstTypeEnum<6>; def DestructiveBinaryCommWithRev : DestructiveInstTypeEnum<7>; def DestructiveTernaryCommWithRev : DestructiveInstTypeEnum<8>; -def DestructiveUnaryPassthru : DestructiveInstTypeEnum<9>; + +// 3 inputs unpredicated (reg1, reg2, imm). +// Can be MOVPRFX'd iff reg1 == reg2. +def Destructive2xRegImmUnpred : DestructiveInstTypeEnum<9>; + +def DestructiveUnaryPassthru : DestructiveInstTypeEnum<10>; class FalseLanesEnum<bits<2> val> { bits<2> Value = val; @@ -3027,8 +3032,12 @@ class BaseAddSubEReg64<bit isSub, bit setFlags, RegisterClass dstRegtype, // Aliases for register+register add/subtract. 
class AddSubRegAlias<string asm, Instruction inst, RegisterClass dstRegtype, - RegisterClass src1Regtype, RegisterClass src2Regtype, - int shiftExt> + RegisterClass src1Regtype, dag src2> + : InstAlias<asm#"\t$dst, $src1, $src2", + (inst dstRegtype:$dst, src1Regtype:$src1, src2)>; +class AddSubRegAlias64<string asm, Instruction inst, RegisterClass dstRegtype, + RegisterClass src1Regtype, RegisterClass src2Regtype, + int shiftExt> : InstAlias<asm#"\t$dst, $src1, $src2", (inst dstRegtype:$dst, src1Regtype:$src1, src2Regtype:$src2, shiftExt)>; @@ -3096,22 +3105,22 @@ multiclass AddSub<bit isSub, string mnemonic, string alias, // Register/register aliases with no shift when SP is not used. def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"), - GPR32, GPR32, GPR32, 0>; + GPR32, GPR32, (arith_shifted_reg32 GPR32:$src2, 0)>; def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"), - GPR64, GPR64, GPR64, 0>; + GPR64, GPR64, (arith_shifted_reg64 GPR64:$src2, 0)>; // Register/register aliases with no shift when either the destination or // first source register is SP. def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrx"), - GPR32sponly, GPR32sp, GPR32, 16>; // UXTW #0 + GPR32sponly, GPR32sp, + (arith_extended_reg32_i32 GPR32:$src2, 16)>; // UXTW #0 def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrx"), - GPR32sp, GPR32sponly, GPR32, 16>; // UXTW #0 - def : AddSubRegAlias<mnemonic, - !cast<Instruction>(NAME#"Xrx64"), - GPR64sponly, GPR64sp, GPR64, 24>; // UXTX #0 - def : AddSubRegAlias<mnemonic, - !cast<Instruction>(NAME#"Xrx64"), - GPR64sp, GPR64sponly, GPR64, 24>; // UXTX #0 + GPR32sp, GPR32sponly, + (arith_extended_reg32_i32 GPR32:$src2, 16)>; // UXTW #0 + def : AddSubRegAlias64<mnemonic, !cast<Instruction>(NAME#"Xrx64"), + GPR64sponly, GPR64sp, GPR64, 24>; // UXTX #0 + def : AddSubRegAlias64<mnemonic, !cast<Instruction>(NAME#"Xrx64"), + GPR64sp, GPR64sponly, GPR64, 24>; // UXTX #0 } multiclass AddSubS<bit isSub, string mnemonic, SDNode OpNode, string cmp, @@ -3175,15 +3184,19 @@ multiclass AddSubS<bit isSub, string mnemonic, SDNode OpNode, string cmp, def : InstAlias<cmp#"\t$src, $imm", (!cast<Instruction>(NAME#"Xri") XZR, GPR64sp:$src, addsub_shifted_imm64:$imm), 5>; def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Wrx") - WZR, GPR32sp:$src1, GPR32:$src2, arith_extend:$sh), 4>; + WZR, GPR32sp:$src1, + (arith_extended_reg32_i32 GPR32:$src2, arith_extend:$sh)), 4>; def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Xrx") - XZR, GPR64sp:$src1, GPR32:$src2, arith_extend:$sh), 4>; + XZR, GPR64sp:$src1, + (arith_extended_reg32_i64 GPR32:$src2, arith_extend:$sh)), 4>; def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Xrx64") XZR, GPR64sp:$src1, GPR64:$src2, arith_extendlsl64:$sh), 4>; def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Wrs") - WZR, GPR32:$src1, GPR32:$src2, arith_shift32:$sh), 4>; + WZR, GPR32:$src1, + (arith_shifted_reg32 GPR32:$src2, arith_shift32:$sh)), 4>; def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Xrs") - XZR, GPR64:$src1, GPR64:$src2, arith_shift64:$sh), 4>; + XZR, GPR64:$src1, + (arith_shifted_reg64 GPR64:$src2, arith_shift64:$sh)), 4>; // Support negative immediates, e.g. 
cmp Rn, -imm -> cmn Rn, imm def : InstSubst<cmpAlias#"\t$src, $imm", (!cast<Instruction>(NAME#"Wri") @@ -3193,27 +3206,28 @@ multiclass AddSubS<bit isSub, string mnemonic, SDNode OpNode, string cmp, // Compare shorthands def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Wrs") - WZR, GPR32:$src1, GPR32:$src2, 0), 5>; + WZR, GPR32:$src1, (arith_shifted_reg32 GPR32:$src2, 0)), 5>; def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Xrs") - XZR, GPR64:$src1, GPR64:$src2, 0), 5>; + XZR, GPR64:$src1, (arith_shifted_reg64 GPR64:$src2, 0)), 5>; def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Wrx") - WZR, GPR32sponly:$src1, GPR32:$src2, 16), 5>; + WZR, GPR32sponly:$src1, + (arith_extended_reg32_i32 GPR32:$src2, 16)), 5>; def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Xrx64") XZR, GPR64sponly:$src1, GPR64:$src2, 24), 5>; // Register/register aliases with no shift when SP is not used. def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"), - GPR32, GPR32, GPR32, 0>; + GPR32, GPR32, (arith_shifted_reg32 GPR32:$src2, 0)>; def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"), - GPR64, GPR64, GPR64, 0>; + GPR64, GPR64, (arith_shifted_reg64 GPR64:$src2, 0)>; // Register/register aliases with no shift when the first source register // is SP. def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrx"), - GPR32, GPR32sponly, GPR32, 16>; // UXTW #0 - def : AddSubRegAlias<mnemonic, - !cast<Instruction>(NAME#"Xrx64"), - GPR64, GPR64sponly, GPR64, 24>; // UXTX #0 + GPR32, GPR32sponly, + (arith_extended_reg32_i32 GPR32:$src2, 16)>; // UXTW #0 + def : AddSubRegAlias64<mnemonic, !cast<Instruction>(NAME#"Xrx64"), + GPR64, GPR64sponly, GPR64, 24>; // UXTX #0 } class AddSubG<bit isSub, string asm_inst, SDPatternOperator OpNode> @@ -3398,9 +3412,10 @@ class BaseLogicalSReg<bits<2> opc, bit N, RegisterClass regtype, } // Aliases for register+register logical instructions. 
-class LogicalRegAlias<string asm, Instruction inst, RegisterClass regtype> +class LogicalRegAlias<string asm, Instruction inst, RegisterClass regtype, + dag op2> : InstAlias<asm#"\t$dst, $src1, $src2", - (inst regtype:$dst, regtype:$src1, regtype:$src2, 0)>; + (inst regtype:$dst, regtype:$src1, op2)>; multiclass LogicalImm<bits<2> opc, string mnemonic, SDNode OpNode, string Alias> { @@ -3472,10 +3487,10 @@ multiclass LogicalReg<bits<2> opc, bit N, string mnemonic, let Inst{31} = 1; } - def : LogicalRegAlias<mnemonic, - !cast<Instruction>(NAME#"Wrs"), GPR32>; - def : LogicalRegAlias<mnemonic, - !cast<Instruction>(NAME#"Xrs"), GPR64>; + def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"), + GPR32, (logical_shifted_reg32 GPR32:$src2, 0)>; + def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"), + GPR64, (logical_shifted_reg64 GPR64:$src2, 0)>; } // Split from LogicalReg to allow setting NZCV Defs @@ -3495,10 +3510,10 @@ multiclass LogicalRegS<bits<2> opc, bit N, string mnemonic, } } // Defs = [NZCV] - def : LogicalRegAlias<mnemonic, - !cast<Instruction>(NAME#"Wrs"), GPR32>; - def : LogicalRegAlias<mnemonic, - !cast<Instruction>(NAME#"Xrs"), GPR64>; + def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"), + GPR32, (logical_shifted_reg32 GPR32:$src2, 0)>; + def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"), + GPR64, (logical_shifted_reg64 GPR64:$src2, 0)>; } //--- @@ -3986,9 +4001,10 @@ class LoadStore8RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins, let Inst{4-0} = Rt; } -class ROInstAlias<string asm, DAGOperand regtype, Instruction INST> +class ROInstAlias<string asm, DAGOperand regtype, Instruction INST, + ro_extend ext> : InstAlias<asm # "\t$Rt, [$Rn, $Rm]", - (INST regtype:$Rt, GPR64sp:$Rn, GPR64:$Rm, 0, 0)>; + (INST regtype:$Rt, GPR64sp:$Rn, GPR64:$Rm, (ext 0, 0))>; multiclass Load8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, string asm, ValueType Ty, SDPatternOperator loadop> { @@ -4014,7 +4030,7 @@ multiclass Load8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend8>; } multiclass Store8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, @@ -4039,7 +4055,7 @@ multiclass Store8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend8>; } class LoadStore16RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins, @@ -4086,7 +4102,7 @@ multiclass Load16RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend16>; } multiclass Store16RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, @@ -4111,7 +4127,7 @@ multiclass Store16RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend16>; } class LoadStore32RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins, @@ -4158,7 +4174,7 @@ multiclass Load32RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, 
regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend32>; } multiclass Store32RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, @@ -4183,7 +4199,7 @@ multiclass Store32RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend32>; } class LoadStore64RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins, @@ -4230,7 +4246,7 @@ multiclass Load64RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend64>; } multiclass Store64RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, @@ -4255,7 +4271,7 @@ multiclass Store64RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend64>; } class LoadStore128RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins, @@ -4302,7 +4318,7 @@ multiclass Load128RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend128>; } multiclass Store128RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, @@ -4323,7 +4339,7 @@ multiclass Store128RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype, let Inst{13} = 0b1; } - def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>; + def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend128>; } let mayLoad = 0, mayStore = 0, hasSideEffects = 1 in @@ -4372,9 +4388,7 @@ multiclass PrefetchRO<bits<2> sz, bit V, bits<2> opc, string asm> { let Inst{13} = 0b1; } - def : InstAlias<"prfm $Rt, [$Rn, $Rm]", - (!cast<Instruction>(NAME # "roX") prfop:$Rt, - GPR64sp:$Rn, GPR64:$Rm, 0, 0)>; + def : ROInstAlias<"prfm", prfop, !cast<Instruction>(NAME # "roX"), ro_Xextend64>; } //--- @@ -6484,7 +6498,9 @@ class BaseSIMDThreeSameVectorDot<bit Q, bit U, bits<2> sz, bits<4> opc, string a (OpNode (AccumType RegType:$Rd), (InputType RegType:$Rn), (InputType RegType:$Rm)))]> { - let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # "}"); + + let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # + "|" # kind1 # "\t$Rd, $Rn, $Rm}"); } multiclass SIMDThreeSameVectorDot<bit U, bit Mixed, string asm, SDPatternOperator OpNode> { @@ -6507,7 +6523,8 @@ class BaseSIMDThreeSameVectorFML<bit Q, bit U, bit b13, bits<3> size, string asm (OpNode (AccumType RegType:$Rd), (InputType RegType:$Rn), (InputType RegType:$Rm)))]> { - let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # "}"); + let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # + "|" # kind1 # "\t$Rd, $Rn, $Rm}"); let Inst{13} = b13; } @@ -7359,7 +7376,9 @@ multiclass SIMDDifferentThreeVectorBD<bit U, bits<4> opc, string asm, [(set (v8i16 V128:$Rd), (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm)))]>; def v16i8 : BaseSIMDDifferentThreeVector<U, 0b001, opc, V128, V128, V128, - asm#"2", ".8h", ".16b", ".16b", []>; + asm#"2", ".8h", ".16b", ".16b", + [(set (v8i16 V128:$Rd), (OpNode (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn))), + (v8i8 (extract_high_v16i8 (v16i8 V128:$Rm)))))]>; let 
Predicates = [HasAES] in { def v1i64 : BaseSIMDDifferentThreeVector<U, 0b110, opc, V128, V64, V64, @@ -7371,10 +7390,6 @@ multiclass SIMDDifferentThreeVectorBD<bit U, bits<4> opc, string asm, [(set (v16i8 V128:$Rd), (OpNode (extract_high_v2i64 (v2i64 V128:$Rn)), (extract_high_v2i64 (v2i64 V128:$Rm))))]>; } - - def : Pat<(v8i16 (OpNode (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn))), - (v8i8 (extract_high_v16i8 (v16i8 V128:$Rm))))), - (!cast<Instruction>(NAME#"v16i8") V128:$Rn, V128:$Rm)>; } multiclass SIMDLongThreeVectorHS<bit U, bits<4> opc, string asm, @@ -7399,87 +7414,7 @@ multiclass SIMDLongThreeVectorHS<bit U, bits<4> opc, string asm, (extract_high_v4i32 (v4i32 V128:$Rm))))]>; } -multiclass SIMDLongThreeVectorBHSabdl<bit U, bits<4> opc, string asm, - SDPatternOperator OpNode = null_frag> { - def v8i8_v8i16 : BaseSIMDDifferentThreeVector<U, 0b000, opc, - V128, V64, V64, - asm, ".8h", ".8b", ".8b", - [(set (v8i16 V128:$Rd), - (zext (v8i8 (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm)))))]>; - def v16i8_v8i16 : BaseSIMDDifferentThreeVector<U, 0b001, opc, - V128, V128, V128, - asm#"2", ".8h", ".16b", ".16b", - [(set (v8i16 V128:$Rd), - (zext (v8i8 (OpNode (extract_high_v16i8 (v16i8 V128:$Rn)), - (extract_high_v16i8 (v16i8 V128:$Rm))))))]>; - def v4i16_v4i32 : BaseSIMDDifferentThreeVector<U, 0b010, opc, - V128, V64, V64, - asm, ".4s", ".4h", ".4h", - [(set (v4i32 V128:$Rd), - (zext (v4i16 (OpNode (v4i16 V64:$Rn), (v4i16 V64:$Rm)))))]>; - def v8i16_v4i32 : BaseSIMDDifferentThreeVector<U, 0b011, opc, - V128, V128, V128, - asm#"2", ".4s", ".8h", ".8h", - [(set (v4i32 V128:$Rd), - (zext (v4i16 (OpNode (extract_high_v8i16 (v8i16 V128:$Rn)), - (extract_high_v8i16 (v8i16 V128:$Rm))))))]>; - def v2i32_v2i64 : BaseSIMDDifferentThreeVector<U, 0b100, opc, - V128, V64, V64, - asm, ".2d", ".2s", ".2s", - [(set (v2i64 V128:$Rd), - (zext (v2i32 (OpNode (v2i32 V64:$Rn), (v2i32 V64:$Rm)))))]>; - def v4i32_v2i64 : BaseSIMDDifferentThreeVector<U, 0b101, opc, - V128, V128, V128, - asm#"2", ".2d", ".4s", ".4s", - [(set (v2i64 V128:$Rd), - (zext (v2i32 (OpNode (extract_high_v4i32 (v4i32 V128:$Rn)), - (extract_high_v4i32 (v4i32 V128:$Rm))))))]>; -} - -multiclass SIMDLongThreeVectorTiedBHSabal<bit U, bits<4> opc, - string asm, - SDPatternOperator OpNode> { - def v8i8_v8i16 : BaseSIMDDifferentThreeVectorTied<U, 0b000, opc, - V128, V64, V64, - asm, ".8h", ".8b", ".8b", - [(set (v8i16 V128:$dst), - (add (v8i16 V128:$Rd), - (zext (v8i8 (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm))))))]>; - def v16i8_v8i16 : BaseSIMDDifferentThreeVectorTied<U, 0b001, opc, - V128, V128, V128, - asm#"2", ".8h", ".16b", ".16b", - [(set (v8i16 V128:$dst), - (add (v8i16 V128:$Rd), - (zext (v8i8 (OpNode (extract_high_v16i8 (v16i8 V128:$Rn)), - (extract_high_v16i8 (v16i8 V128:$Rm)))))))]>; - def v4i16_v4i32 : BaseSIMDDifferentThreeVectorTied<U, 0b010, opc, - V128, V64, V64, - asm, ".4s", ".4h", ".4h", - [(set (v4i32 V128:$dst), - (add (v4i32 V128:$Rd), - (zext (v4i16 (OpNode (v4i16 V64:$Rn), (v4i16 V64:$Rm))))))]>; - def v8i16_v4i32 : BaseSIMDDifferentThreeVectorTied<U, 0b011, opc, - V128, V128, V128, - asm#"2", ".4s", ".8h", ".8h", - [(set (v4i32 V128:$dst), - (add (v4i32 V128:$Rd), - (zext (v4i16 (OpNode (extract_high_v8i16 (v8i16 V128:$Rn)), - (extract_high_v8i16 (v8i16 V128:$Rm)))))))]>; - def v2i32_v2i64 : BaseSIMDDifferentThreeVectorTied<U, 0b100, opc, - V128, V64, V64, - asm, ".2d", ".2s", ".2s", - [(set (v2i64 V128:$dst), - (add (v2i64 V128:$Rd), - (zext (v2i32 (OpNode (v2i32 V64:$Rn), (v2i32 V64:$Rm))))))]>; - def v4i32_v2i64 : 
BaseSIMDDifferentThreeVectorTied<U, 0b101, opc, - V128, V128, V128, - asm#"2", ".2d", ".4s", ".4s", - [(set (v2i64 V128:$dst), - (add (v2i64 V128:$Rd), - (zext (v2i32 (OpNode (extract_high_v4i32 (v4i32 V128:$Rn)), - (extract_high_v4i32 (v4i32 V128:$Rm)))))))]>; -} - +let isCommutable = 1 in multiclass SIMDLongThreeVectorBHS<bit U, bits<4> opc, string asm, SDPatternOperator OpNode = null_frag> { def v8i8_v8i16 : BaseSIMDDifferentThreeVector<U, 0b000, opc, @@ -8986,7 +8921,8 @@ class BaseSIMDThreeSameVectorBFDot<bit Q, bit U, string asm, string kind1, (InputType RegType:$Rm)))]> { let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # - ", $Rm" # kind2 # "}"); + ", $Rm" # kind2 # + "|" # kind1 # "\t$Rd, $Rn, $Rm}"); } multiclass SIMDThreeSameVectorBFDot<bit U, string asm> { @@ -9032,7 +8968,7 @@ class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode> [(set (v4f32 V128:$dst), (OpNode (v4f32 V128:$Rd), (v8bf16 V128:$Rn), (v8bf16 V128:$Rm)))]> { - let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h}"); + let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h|.4s\t$Rd, $Rn, $Rm}"); } let mayRaiseFPException = 1, Uses = [FPCR] in @@ -9071,8 +9007,7 @@ class SIMDThreeSameVectorBF16MatrixMul<string asm> (int_aarch64_neon_bfmmla (v4f32 V128:$Rd), (v8bf16 V128:$Rn), (v8bf16 V128:$Rm)))]> { - let AsmString = !strconcat(asm, "{\t$Rd", ".4s", ", $Rn", ".8h", - ", $Rm", ".8h", "}"); + let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h|.4s\t$Rd, $Rn, $Rm}"); } let mayRaiseFPException = 1, Uses = [FPCR] in @@ -9143,7 +9078,7 @@ class SIMDThreeSameVectorMatMul<bit B, bit U, string asm, SDPatternOperator OpNo [(set (v4i32 V128:$dst), (OpNode (v4i32 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128:$Rm)))]> { - let AsmString = asm # "{\t$Rd.4s, $Rn.16b, $Rm.16b}"; + let AsmString = asm # "{\t$Rd.4s, $Rn.16b, $Rm.16b|.4s\t$Rd, $Rn, $Rm}"; } //---------------------------------------------------------------------------- @@ -12561,7 +12496,7 @@ multiclass STOPregister<string asm, string instr> { let Predicates = [HasLSUI] in class BaseSTOPregisterLSUI<string asm, RegisterClass OP, Register Reg, Instruction inst> : - InstAlias<asm # "\t$Rs, [$Rn]", (inst Reg, OP:$Rs, GPR64sp:$Rn), 0>; + InstAlias<asm # "\t$Rs, [$Rn]", (inst Reg, OP:$Rs, GPR64sp:$Rn)>; multiclass STOPregisterLSUI<string asm, string instr> { def : BaseSTOPregisterLSUI<asm # "l", GPR32, WZR, @@ -13344,8 +13279,8 @@ multiclass AtomicFPStore<bit R, bits<3> op0, string asm> { class BaseSIMDThreeSameVectorFP8MatrixMul<string asm, bits<2> size, string kind> : BaseSIMDThreeSameVectorTied<1, 1, {size, 0}, 0b11101, V128, asm, ".16b", []> { - let AsmString = !strconcat(asm, "{\t$Rd", kind, ", $Rn", ".16b", - ", $Rm", ".16b", "}"); + let AsmString = !strconcat(asm, "{\t$Rd", kind, ", $Rn.16b, $Rm.16b", + "|", kind, "\t$Rd, $Rn, $Rm}"); } multiclass SIMDThreeSameVectorFP8MatrixMul<string asm>{ diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 59d4fd2..3ce7829 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -20,7 +20,9 @@ #include "Utils/AArch64BaseInfo.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/CodeGen/CFIInstBuilder.h" #include "llvm/CodeGen/LivePhysRegs.h" #include "llvm/CodeGen/MachineBasicBlock.h" @@ -83,6 +85,11 @@ static cl::opt<unsigned> 
BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26), cl::desc("Restrict range of B instructions (DEBUG)")); +static cl::opt<unsigned> GatherOptSearchLimit( + "aarch64-search-limit", cl::Hidden, cl::init(2048), + cl::desc("Restrict range of instructions to search for the " + "machine-combiner gather pattern optimization")); + AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI) : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP, AArch64::CATCHRET), @@ -5068,7 +5075,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); } - } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGP()) { + } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) { BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); @@ -5078,8 +5085,13 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move. MCRegister DestRegX = TRI->getMatchingSuperReg( DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass); - MCRegister SrcRegX = TRI->getMatchingSuperReg( - SrcReg, AArch64::sub_32, &AArch64::GPR64spRegClass); + assert(DestRegX.isValid() && "Destination super-reg not valid"); + MCRegister SrcRegX = + SrcReg == AArch64::WZR + ? AArch64::XZR + : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); + assert(SrcRegX.isValid() && "Source super-reg not valid"); // This instruction is reading and writing X registers. This may upset // the register scavenger and machine verifier, so we need to indicate // that we are reading an undefined value from SrcRegX, but a proper @@ -5190,7 +5202,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, .addReg(SrcReg, getKillRegState(KillSrc)) .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGP()) { + } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) { BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg) .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); @@ -5306,15 +5318,49 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, if (AArch64::FPR64RegClass.contains(DestReg) && AArch64::FPR64RegClass.contains(SrcReg)) { - BuildMI(MBB, I, DL, get(AArch64::FMOVDr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); + if (Subtarget.hasZeroCycleRegMoveFPR128() && + !Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) { + const TargetRegisterInfo *TRI = &getRegisterInfo(); + MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::dsub, + &AArch64::FPR128RegClass); + MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::dsub, + &AArch64::FPR128RegClass); + // This instruction is reading and writing Q registers. This may upset + // the register scavenger and machine verifier, so we need to indicate + // that we are reading an undefined value from SrcRegQ, but a proper + // value from SrcReg. 
+ BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc)); + } else { + BuildMI(MBB, I, DL, get(AArch64::FMOVDr), DestReg) + .addReg(SrcReg, getKillRegState(KillSrc)); + } return; } if (AArch64::FPR32RegClass.contains(DestReg) && AArch64::FPR32RegClass.contains(SrcReg)) { - if (Subtarget.hasZeroCycleRegMoveFPR64() && - !Subtarget.hasZeroCycleRegMoveFPR32()) { + if (Subtarget.hasZeroCycleRegMoveFPR128() && + !Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) { + const TargetRegisterInfo *TRI = &getRegisterInfo(); + MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::ssub, + &AArch64::FPR128RegClass); + MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::ssub, + &AArch64::FPR128RegClass); + // This instruction is reading and writing Q registers. This may upset + // the register scavenger and machine verifier, so we need to indicate + // that we are reading an undefined value from SrcRegQ, but a proper + // value from SrcReg. + BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc)); + } else if (Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR32()) { const TargetRegisterInfo *TRI = &getRegisterInfo(); MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::ssub, &AArch64::FPR64RegClass); @@ -5336,8 +5382,24 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, if (AArch64::FPR16RegClass.contains(DestReg) && AArch64::FPR16RegClass.contains(SrcReg)) { - if (Subtarget.hasZeroCycleRegMoveFPR64() && - !Subtarget.hasZeroCycleRegMoveFPR32()) { + if (Subtarget.hasZeroCycleRegMoveFPR128() && + !Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) { + const TargetRegisterInfo *TRI = &getRegisterInfo(); + MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::hsub, + &AArch64::FPR128RegClass); + MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::hsub, + &AArch64::FPR128RegClass); + // This instruction is reading and writing Q registers. This may upset + // the register scavenger and machine verifier, so we need to indicate + // that we are reading an undefined value from SrcRegQ, but a proper + // value from SrcReg. 
+ BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc)); + } else if (Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR32()) { const TargetRegisterInfo *TRI = &getRegisterInfo(); MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::hsub, &AArch64::FPR64RegClass); @@ -5363,8 +5425,24 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, if (AArch64::FPR8RegClass.contains(DestReg) && AArch64::FPR8RegClass.contains(SrcReg)) { - if (Subtarget.hasZeroCycleRegMoveFPR64() && - !Subtarget.hasZeroCycleRegMoveFPR32()) { + if (Subtarget.hasZeroCycleRegMoveFPR128() && + !Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR64() && Subtarget.isNeonAvailable()) { + const TargetRegisterInfo *TRI = &getRegisterInfo(); + MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::bsub, + &AArch64::FPR128RegClass); + MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::bsub, + &AArch64::FPR128RegClass); + // This instruction is reading and writing Q registers. This may upset + // the register scavenger and machine verifier, so we need to indicate + // that we are reading an undefined value from SrcRegQ, but a proper + // value from SrcReg. + BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcRegQ, RegState::Undef) + .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc)); + } else if (Subtarget.hasZeroCycleRegMoveFPR64() && + !Subtarget.hasZeroCycleRegMoveFPR32()) { const TargetRegisterInfo *TRI = &getRegisterInfo(); MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::bsub, &AArch64::FPR64RegClass); @@ -5861,33 +5939,53 @@ void AArch64InstrInfo::decomposeStackOffsetForFrameOffsets( } } -// Convenience function to create a DWARF expression for -// Expr + NumBytes + NumVGScaledBytes * AArch64::VG -static void appendVGScaledOffsetExpr(SmallVectorImpl<char> &Expr, int NumBytes, - int NumVGScaledBytes, unsigned VG, - llvm::raw_string_ostream &Comment) { - uint8_t buffer[16]; - - if (NumBytes) { +// Convenience function to create a DWARF expression for: Constant `Operation`. +// This helper emits compact sequences for common cases. For example, for`-15 +// DW_OP_plus`, this helper would create DW_OP_lit15 DW_OP_minus. +static void appendConstantExpr(SmallVectorImpl<char> &Expr, int64_t Constant, + dwarf::LocationAtom Operation) { + if (Operation == dwarf::DW_OP_plus && Constant < 0 && -Constant <= 31) { + // -Constant (1 to 31) + Expr.push_back(dwarf::DW_OP_lit0 - Constant); + Operation = dwarf::DW_OP_minus; + } else if (Constant >= 0 && Constant <= 31) { + // Literal value 0 to 31 + Expr.push_back(dwarf::DW_OP_lit0 + Constant); + } else { + // Signed constant Expr.push_back(dwarf::DW_OP_consts); - Expr.append(buffer, buffer + encodeSLEB128(NumBytes, buffer)); - Expr.push_back((uint8_t)dwarf::DW_OP_plus); - Comment << (NumBytes < 0 ? " - " : " + ") << std::abs(NumBytes); + appendLEB128<LEB128Sign::Signed>(Expr, Constant); } + return Expr.push_back(Operation); +} - if (NumVGScaledBytes) { - Expr.push_back((uint8_t)dwarf::DW_OP_consts); - Expr.append(buffer, buffer + encodeSLEB128(NumVGScaledBytes, buffer)); - - Expr.push_back((uint8_t)dwarf::DW_OP_bregx); - Expr.append(buffer, buffer + encodeULEB128(VG, buffer)); - Expr.push_back(0); +// Convenience function to create a DWARF expression for a register. 
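Aside: appendConstantExpr above, and the register helpers that follow, push their variable-length payloads through appendLEB128<LEB128Sign::...>, replacing the hand-managed encodeSLEB128/encodeULEB128 buffers of the old code. As a reminder of what that encoding produces, here is a minimal standalone sketch of standard LEB128 in plain C++; it makes no claim about LLVM's actual appendLEB128 interface.

    #include <cstdint>
    #include <vector>

    // Unsigned LEB128: 7 payload bits per byte, MSB set while more bytes follow.
    static void sketchULEB128(uint64_t Value, std::vector<uint8_t> &Out) {
      do {
        uint8_t Byte = Value & 0x7f;
        Value >>= 7;
        if (Value != 0)
          Byte |= 0x80;
        Out.push_back(Byte);
      } while (Value != 0);
    }

    // Signed LEB128: stop once the rest of the value is pure sign extension.
    static void sketchSLEB128(int64_t Value, std::vector<uint8_t> &Out) {
      bool More;
      do {
        uint8_t Byte = Value & 0x7f;
        Value >>= 7; // arithmetic shift, keeps the sign
        More = !((Value == 0 && (Byte & 0x40) == 0) ||
                 (Value == -1 && (Byte & 0x40) != 0));
        if (More)
          Byte |= 0x80;
        Out.push_back(Byte);
      } while (More);
    }

    // sketchSLEB128(-15, Out) emits the single byte 0x71;
    // sketchULEB128(300, Out) emits 0xac 0x02.

For small constants the DW_OP_lit0-based shortcut in appendConstantExpr sidesteps this multi-byte path entirely, which is what keeps the emitted CFI escapes compact.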
+static void appendReadRegExpr(SmallVectorImpl<char> &Expr, unsigned RegNum) { + Expr.push_back((char)dwarf::DW_OP_bregx); + appendLEB128<LEB128Sign::Unsigned>(Expr, RegNum); + Expr.push_back(0); +} - Expr.push_back((uint8_t)dwarf::DW_OP_mul); - Expr.push_back((uint8_t)dwarf::DW_OP_plus); +// Convenience function to create a DWARF expression for loading a register from +// a CFA offset. +static void appendLoadRegExpr(SmallVectorImpl<char> &Expr, + int64_t OffsetFromDefCFA) { + // This assumes the top of the DWARF stack contains the CFA. + Expr.push_back(dwarf::DW_OP_dup); + // Add the offset to the register. + appendConstantExpr(Expr, OffsetFromDefCFA, dwarf::DW_OP_plus); + // Dereference the address (loads a 64 bit value).. + Expr.push_back(dwarf::DW_OP_deref); +} - Comment << (NumVGScaledBytes < 0 ? " - " : " + ") - << std::abs(NumVGScaledBytes) << " * VG"; +// Convenience function to create a comment for +// (+/-) NumBytes (* RegScale)? +static void appendOffsetComment(int NumBytes, llvm::raw_string_ostream &Comment, + StringRef RegScale = {}) { + if (NumBytes) { + Comment << (NumBytes < 0 ? " - " : " + ") << std::abs(NumBytes); + if (!RegScale.empty()) + Comment << ' ' << RegScale; } } @@ -5909,19 +6007,26 @@ static MCCFIInstruction createDefCFAExpression(const TargetRegisterInfo &TRI, else Comment << printReg(Reg, &TRI); - // Build up the expression (Reg + NumBytes + NumVGScaledBytes * AArch64::VG) + // Build up the expression (Reg + NumBytes + VG * NumVGScaledBytes) SmallString<64> Expr; unsigned DwarfReg = TRI.getDwarfRegNum(Reg, true); - Expr.push_back((uint8_t)(dwarf::DW_OP_breg0 + DwarfReg)); - Expr.push_back(0); - appendVGScaledOffsetExpr(Expr, NumBytes, NumVGScaledBytes, - TRI.getDwarfRegNum(AArch64::VG, true), Comment); + assert(DwarfReg <= 31 && "DwarfReg out of bounds (0..31)"); + // Reg + NumBytes + Expr.push_back(dwarf::DW_OP_breg0 + DwarfReg); + appendLEB128<LEB128Sign::Signed>(Expr, NumBytes); + appendOffsetComment(NumBytes, Comment); + if (NumVGScaledBytes) { + // + VG * NumVGScaledBytes + appendOffsetComment(NumVGScaledBytes, Comment, "* VG"); + appendReadRegExpr(Expr, TRI.getDwarfRegNum(AArch64::VG, true)); + appendConstantExpr(Expr, NumVGScaledBytes, dwarf::DW_OP_mul); + Expr.push_back(dwarf::DW_OP_plus); + } // Wrap this into DW_CFA_def_cfa. 
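To make the construction above concrete, here is a hypothetical worked example for a CFA of x29 + 16 + VG * 8, written with raw DWARF opcode byte values outside of any LLVM types, and assuming VG carries DWARF register number 46:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // Built in the same order as the code above:
      // DW_OP_breg29 SLEB(16), DW_OP_bregx ULEB(46) SLEB(0),
      // DW_OP_lit8, DW_OP_mul, DW_OP_plus
      std::vector<uint8_t> Expr = {
          0x70 + 29, 0x10,  // DW_OP_breg29 16
          0x92, 0x2e, 0x00, // DW_OP_bregx 46 0   (read VG)
          0x30 + 8,         // DW_OP_lit8
          0x1e,             // DW_OP_mul
          0x22,             // DW_OP_plus
      };
      for (uint8_t B : Expr)
        std::printf("%02x ", B); // 8d 10 92 2e 00 38 1e 22
      std::printf("\n");
    }

Only the inner expression is shown here; the lines that follow wrap it in DW_CFA_def_cfa_expression with a ULEB128 length prefix.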
SmallString<64> DefCfaExpr; DefCfaExpr.push_back(dwarf::DW_CFA_def_cfa_expression); - uint8_t buffer[16]; - DefCfaExpr.append(buffer, buffer + encodeULEB128(Expr.size(), buffer)); + appendLEB128<LEB128Sign::Unsigned>(DefCfaExpr, Expr.size()); DefCfaExpr.append(Expr.str()); return MCCFIInstruction::createEscape(nullptr, DefCfaExpr.str(), SMLoc(), Comment.str()); @@ -5941,9 +6046,10 @@ MCCFIInstruction llvm::createDefCFA(const TargetRegisterInfo &TRI, return MCCFIInstruction::cfiDefCfa(nullptr, DwarfReg, (int)Offset.getFixed()); } -MCCFIInstruction llvm::createCFAOffset(const TargetRegisterInfo &TRI, - unsigned Reg, - const StackOffset &OffsetFromDefCFA) { +MCCFIInstruction +llvm::createCFAOffset(const TargetRegisterInfo &TRI, unsigned Reg, + const StackOffset &OffsetFromDefCFA, + std::optional<int64_t> IncomingVGOffsetFromDefCFA) { int64_t NumBytes, NumVGScaledBytes; AArch64InstrInfo::decomposeStackOffsetForDwarfOffsets( OffsetFromDefCFA, NumBytes, NumVGScaledBytes); @@ -5958,17 +6064,32 @@ MCCFIInstruction llvm::createCFAOffset(const TargetRegisterInfo &TRI, llvm::raw_string_ostream Comment(CommentBuffer); Comment << printReg(Reg, &TRI) << " @ cfa"; - // Build up expression (NumBytes + NumVGScaledBytes * AArch64::VG) + // Build up expression (CFA + VG * NumVGScaledBytes + NumBytes) + assert(NumVGScaledBytes && "Expected scalable offset"); SmallString<64> OffsetExpr; - appendVGScaledOffsetExpr(OffsetExpr, NumBytes, NumVGScaledBytes, - TRI.getDwarfRegNum(AArch64::VG, true), Comment); + // + VG * NumVGScaledBytes + StringRef VGRegScale; + if (IncomingVGOffsetFromDefCFA) { + appendLoadRegExpr(OffsetExpr, *IncomingVGOffsetFromDefCFA); + VGRegScale = "* IncomingVG"; + } else { + appendReadRegExpr(OffsetExpr, TRI.getDwarfRegNum(AArch64::VG, true)); + VGRegScale = "* VG"; + } + appendConstantExpr(OffsetExpr, NumVGScaledBytes, dwarf::DW_OP_mul); + appendOffsetComment(NumVGScaledBytes, Comment, VGRegScale); + OffsetExpr.push_back(dwarf::DW_OP_plus); + if (NumBytes) { + // + NumBytes + appendOffsetComment(NumBytes, Comment); + appendConstantExpr(OffsetExpr, NumBytes, dwarf::DW_OP_plus); + } // Wrap this into DW_CFA_expression SmallString<64> CfaExpr; CfaExpr.push_back(dwarf::DW_CFA_expression); - uint8_t buffer[16]; - CfaExpr.append(buffer, buffer + encodeULEB128(DwarfReg, buffer)); - CfaExpr.append(buffer, buffer + encodeULEB128(OffsetExpr.size(), buffer)); + appendLEB128<LEB128Sign::Unsigned>(CfaExpr, DwarfReg); + appendLEB128<LEB128Sign::Unsigned>(CfaExpr, OffsetExpr.size()); CfaExpr.append(OffsetExpr.str()); return MCCFIInstruction::createEscape(nullptr, CfaExpr.str(), SMLoc(), @@ -6597,7 +6718,7 @@ static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO, if (MO.isReg() && MO.getReg().isVirtual()) MI = MRI.getUniqueVRegDef(MO.getReg()); // And it needs to be in the trace (otherwise, it won't have a depth). - if (!MI || MI->getParent() != &MBB || (unsigned)MI->getOpcode() != CombineOpc) + if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc) return false; // Must only used by the user we combine with. if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) @@ -7389,11 +7510,319 @@ static bool getMiscPatterns(MachineInstr &Root, return false; } +/// Check if the given instruction forms a gather load pattern that can be +/// optimized for better Memory-Level Parallelism (MLP). 
This function +/// identifies chains of NEON lane load instructions that load data from +/// different memory addresses into individual lanes of a 128-bit vector +/// register, then attempts to split the pattern into parallel loads to break +/// the serial dependency between instructions. +/// +/// Pattern Matched: +/// Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) -> +/// LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root) +/// +/// Transformed Into: +/// Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64 +/// to combine the results, enabling better memory-level parallelism. +/// +/// Supported Element Types: +/// - 32-bit elements (LD1i32, 4 lanes total) +/// - 16-bit elements (LD1i16, 8 lanes total) +/// - 8-bit elements (LD1i8, 16 lanes total) +static bool getGatherLanePattern(MachineInstr &Root, + SmallVectorImpl<unsigned> &Patterns, + unsigned LoadLaneOpCode, unsigned NumLanes) { + const MachineFunction *MF = Root.getMF(); + + // Early exit if optimizing for size. + if (MF->getFunction().hasMinSize()) + return false; + + const MachineRegisterInfo &MRI = MF->getRegInfo(); + const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo(); + + // The root of the pattern must load into the last lane of the vector. + if (Root.getOperand(2).getImm() != NumLanes - 1) + return false; + + // Check that we have load into all lanes except lane 0. + // For each load we also want to check that: + // 1. It has a single non-debug use (since we will be replacing the virtual + // register) + // 2. That the addressing mode only uses a single pointer operand + auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg()); + auto Range = llvm::seq<unsigned>(1, NumLanes - 1); + SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end()); + SmallVector<const MachineInstr *, 16> LoadInstrs; + while (!RemainingLanes.empty() && CurrInstr && + CurrInstr->getOpcode() == LoadLaneOpCode && + MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) && + CurrInstr->getNumOperands() == 4) { + RemainingLanes.erase(CurrInstr->getOperand(2).getImm()); + LoadInstrs.push_back(CurrInstr); + CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg()); + } + + // Check that we have found a match for lanes N-1.. 1. + if (!RemainingLanes.empty()) + return false; + + // Match the SUBREG_TO_REG sequence. + if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG) + return false; + + // Verify that the subreg to reg loads an integer into the first lane. + auto Lane0LoadReg = CurrInstr->getOperand(2).getReg(); + unsigned SingleLaneSizeInBits = 128 / NumLanes; + if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits) + return false; + + // Verify that it also has a single non debug use. + if (!MRI.hasOneNonDBGUse(Lane0LoadReg)) + return false; + + LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg)); + + // If there is any chance of aliasing, do not apply the pattern. + // Walk backward through the MBB starting from Root. + // Exit early if we've encountered all load instructions or hit the search + // limit. 
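(Aside: at the source level, the rewrite this pattern enables can be pictured with ACLE NEON intrinsics. This is only an analogy of its shape, written against arm_neon.h for the 4 x 32-bit case; the pass itself rewrites MIR and uses a scalar LDR plus SUBREG_TO_REG for lane 0 of each half so that the upper lanes are zeroed.)

    #include <arm_neon.h>

    // Serial form: each lane insert depends on the previous value of V.
    float32x4_t gather_serial(const float *P0, const float *P1,
                              const float *P2, const float *P3) {
      float32x4_t V = vdupq_n_f32(0.0f);
      V = vld1q_lane_f32(P0, V, 0);
      V = vld1q_lane_f32(P1, V, 1);
      V = vld1q_lane_f32(P2, V, 2);
      V = vld1q_lane_f32(P3, V, 3); // corresponds to the Root (last lane)
      return V;
    }

    // Shape of the rewrite: two independent half-gathers combined with ZIP1.
    float32x4_t gather_split(const float *P0, const float *P1,
                             const float *P2, const float *P3) {
      float32x4_t Lo =
          vld1q_lane_f32(P1, vld1q_lane_f32(P0, vdupq_n_f32(0.0f), 0), 1);
      float32x4_t Hi =
          vld1q_lane_f32(P3, vld1q_lane_f32(P2, vdupq_n_f32(0.0f), 0), 1);
      return vreinterpretq_f32_u64(vzip1q_u64(vreinterpretq_u64_f32(Lo),
                                              vreinterpretq_u64_f32(Hi)));
    }

The two chains in gather_split carry no dependency on each other, so their loads can issue in parallel, which is the memory-level parallelism the pattern is after.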
+ auto MBBItr = Root.getIterator(); + unsigned RemainingSteps = GatherOptSearchLimit; + SmallPtrSet<const MachineInstr *, 16> RemainingLoadInstrs; + RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end()); + const MachineBasicBlock *MBB = Root.getParent(); + + for (; MBBItr != MBB->begin() && RemainingSteps > 0 && + !RemainingLoadInstrs.empty(); + --MBBItr, --RemainingSteps) { + const MachineInstr &CurrInstr = *MBBItr; + + // Remove this instruction from remaining loads if it's one we're tracking. + RemainingLoadInstrs.erase(&CurrInstr); + + // Check for potential aliasing with any of the load instructions to + // optimize. + if (CurrInstr.isLoadFoldBarrier()) + return false; + } + + // If we hit the search limit without finding all load instructions, + // don't match the pattern. + if (RemainingSteps == 0 && !RemainingLoadInstrs.empty()) + return false; + + switch (NumLanes) { + case 4: + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32); + break; + case 8: + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16); + break; + case 16: + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8); + break; + default: + llvm_unreachable("Got bad number of lanes for gather pattern."); + } + + return true; +} + +/// Search for patterns of LD instructions we can optimize. +static bool getLoadPatterns(MachineInstr &Root, + SmallVectorImpl<unsigned> &Patterns) { + + // The pattern searches for loads into single lanes. + switch (Root.getOpcode()) { + case AArch64::LD1i32: + return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4); + case AArch64::LD1i16: + return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8); + case AArch64::LD1i8: + return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16); + default: + return false; + } +} + +/// Generate optimized instruction sequence for gather load patterns to improve +/// Memory-Level Parallelism (MLP). This function transforms a chain of +/// sequential NEON lane loads into parallel vector loads that can execute +/// concurrently. +static void +generateGatherLanePattern(MachineInstr &Root, + SmallVectorImpl<MachineInstr *> &InsInstrs, + SmallVectorImpl<MachineInstr *> &DelInstrs, + DenseMap<Register, unsigned> &InstrIdxForVirtReg, + unsigned Pattern, unsigned NumLanes) { + MachineFunction &MF = *Root.getParent()->getParent(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + + // Gather the initial load instructions to build the pattern. + SmallVector<MachineInstr *, 16> LoadToLaneInstrs; + MachineInstr *CurrInstr = &Root; + for (unsigned i = 0; i < NumLanes - 1; ++i) { + LoadToLaneInstrs.push_back(CurrInstr); + CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg()); + } + + // Sort the load instructions according to the lane. + llvm::sort(LoadToLaneInstrs, + [](const MachineInstr *A, const MachineInstr *B) { + return A->getOperand(2).getImm() > B->getOperand(2).getImm(); + }); + + MachineInstr *SubregToReg = CurrInstr; + LoadToLaneInstrs.push_back( + MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg())); + auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs); + + const TargetRegisterClass *FPR128RegClass = + MRI.getRegClass(Root.getOperand(0).getReg()); + + // Helper lambda to create a LD1 instruction. 
+ auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr, + Register SrcRegister, unsigned Lane, + Register OffsetRegister, + bool OffsetRegisterKillState) { + auto NewRegister = MRI.createVirtualRegister(FPR128RegClass); + MachineInstrBuilder LoadIndexIntoRegister = + BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()), + NewRegister) + .addReg(SrcRegister) + .addImm(Lane) + .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState)); + InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size())); + InsInstrs.push_back(LoadIndexIntoRegister); + return NewRegister; + }; + + // Helper to create load instruction based on the NumLanes in the NEON + // register we are rewriting. + auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg, + Register OffsetReg, + bool KillState) -> MachineInstrBuilder { + unsigned Opcode; + switch (NumLanes) { + case 4: + Opcode = AArch64::LDRSui; + break; + case 8: + Opcode = AArch64::LDRHui; + break; + case 16: + Opcode = AArch64::LDRBui; + break; + default: + llvm_unreachable( + "Got unsupported number of lanes in machine-combiner gather pattern"); + } + // Immediate offset load + return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg) + .addReg(OffsetReg) + .addImm(0); + }; + + // Load the remaining lanes into register 0. + auto LanesToLoadToReg0 = + llvm::make_range(LoadToLaneInstrsAscending.begin() + 1, + LoadToLaneInstrsAscending.begin() + NumLanes / 2); + Register PrevReg = SubregToReg->getOperand(0).getReg(); + for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) { + const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3); + PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1, + OffsetRegOperand.getReg(), + OffsetRegOperand.isKill()); + DelInstrs.push_back(LoadInstr); + } + Register LastLoadReg0 = PrevReg; + + // First load into register 1. Perform an integer load to zero out the upper + // lanes in a single instruction. + MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin(); + MachineInstr *OriginalSplitLoad = + *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2); + Register DestRegForMiddleIndex = MRI.createVirtualRegister( + MRI.getRegClass(Lane0Load->getOperand(0).getReg())); + + const MachineOperand &OriginalSplitToLoadOffsetOperand = + OriginalSplitLoad->getOperand(3); + MachineInstrBuilder MiddleIndexLoadInstr = + CreateLDRInstruction(NumLanes, DestRegForMiddleIndex, + OriginalSplitToLoadOffsetOperand.getReg(), + OriginalSplitToLoadOffsetOperand.isKill()); + + InstrIdxForVirtReg.insert( + std::make_pair(DestRegForMiddleIndex, InsInstrs.size())); + InsInstrs.push_back(MiddleIndexLoadInstr); + DelInstrs.push_back(OriginalSplitLoad); + + // Subreg To Reg instruction for register 1. + Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass); + unsigned SubregType; + switch (NumLanes) { + case 4: + SubregType = AArch64::ssub; + break; + case 8: + SubregType = AArch64::hsub; + break; + case 16: + SubregType = AArch64::bsub; + break; + default: + llvm_unreachable( + "Got invalid NumLanes for machine-combiner gather pattern"); + } + + auto SubRegToRegInstr = + BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()), + DestRegForSubregToReg) + .addImm(0) + .addReg(DestRegForMiddleIndex, getKillRegState(true)) + .addImm(SubregType); + InstrIdxForVirtReg.insert( + std::make_pair(DestRegForSubregToReg, InsInstrs.size())); + InsInstrs.push_back(SubRegToRegInstr); + + // Load remaining lanes into register 1. 
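(The final combine emitted below is a ZIP1v2i64 of the two half-populated registers. For 2 x 64-bit elements, ZIP1 simply takes the low element of each source, which is why each half only ever needs its low NumLanes/2 lanes filled. A sketch of the semantics, not LLVM code:)

    #include <array>
    #include <cstdint>

    // ZIP1 Vd.2D, Vn.2D, Vm.2D: result[0] = Vn[0], result[1] = Vm[0].
    static std::array<uint64_t, 2> zip1_v2i64(const std::array<uint64_t, 2> &Vn,
                                              const std::array<uint64_t, 2> &Vm) {
      return {Vn[0], Vm[0]};
    }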
+ auto LanesToLoadToReg1 = + llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1, + LoadToLaneInstrsAscending.end()); + PrevReg = SubRegToRegInstr->getOperand(0).getReg(); + for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) { + const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3); + PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1, + OffsetRegOperand.getReg(), + OffsetRegOperand.isKill()); + + // Do not add the last reg to DelInstrs - it will be removed later. + if (Index == NumLanes / 2 - 2) { + break; + } + DelInstrs.push_back(LoadInstr); + } + Register LastLoadReg1 = PrevReg; + + // Create the final zip instruction to combine the results. + MachineInstrBuilder ZipInstr = + BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64), + Root.getOperand(0).getReg()) + .addReg(LastLoadReg0) + .addReg(LastLoadReg1); + InsInstrs.push_back(ZipInstr); +} + CombinerObjective AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const { switch (Pattern) { case AArch64MachineCombinerPattern::SUBADD_OP1: case AArch64MachineCombinerPattern::SUBADD_OP2: + case AArch64MachineCombinerPattern::GATHER_LANE_i32: + case AArch64MachineCombinerPattern::GATHER_LANE_i16: + case AArch64MachineCombinerPattern::GATHER_LANE_i8: return CombinerObjective::MustReduceDepth; default: return TargetInstrInfo::getCombinerObjective(Pattern); @@ -7423,6 +7852,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns( if (getMiscPatterns(Root, Patterns)) return true; + // Load patterns + if (getLoadPatterns(Root, Patterns)) + return true; + return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns, DoRegPressureReduce); } @@ -8678,6 +9111,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence( MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs); break; } + case AArch64MachineCombinerPattern::GATHER_LANE_i32: { + generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, + Pattern, 4); + break; + } + case AArch64MachineCombinerPattern::GATHER_LANE_i16: { + generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, + Pattern, 8); + break; + } + case AArch64MachineCombinerPattern::GATHER_LANE_i8: { + generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, + Pattern, 16); + break; + } } // end switch (Pattern) // Record MUL and ADD/SUB for deletion diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h index 7c255da..179574a 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned { FMULv8i16_indexed_OP2, FNMADD, + + GATHER_LANE_i32, + GATHER_LANE_i16, + GATHER_LANE_i8 }; class AArch64InstrInfo final : public AArch64GenInstrInfo { const AArch64RegisterInfo RI; @@ -642,8 +646,10 @@ bool isNZCVTouchedInInstructionRange(const MachineInstr &DefMI, MCCFIInstruction createDefCFA(const TargetRegisterInfo &TRI, unsigned FrameReg, unsigned Reg, const StackOffset &Offset, bool LastAdjustmentWasScalable = true); -MCCFIInstruction createCFAOffset(const TargetRegisterInfo &MRI, unsigned Reg, - const StackOffset &OffsetFromDefCFA); +MCCFIInstruction +createCFAOffset(const TargetRegisterInfo &MRI, unsigned Reg, + const StackOffset &OffsetFromDefCFA, + std::optional<int64_t> IncomingVGOffsetFromDefCFA); /// emitFrameOffset - Emit instructions as needed to set DestReg to SrcReg /// plus Offset. 
This is intended to be used from within the prolog/epilog @@ -820,7 +826,8 @@ enum DestructiveInstType { DestructiveBinaryComm = TSFLAG_DESTRUCTIVE_INST_TYPE(0x6), DestructiveBinaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x7), DestructiveTernaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x8), - DestructiveUnaryPassthru = TSFLAG_DESTRUCTIVE_INST_TYPE(0x9), + Destructive2xRegImmUnpred = TSFLAG_DESTRUCTIVE_INST_TYPE(0x9), + DestructiveUnaryPassthru = TSFLAG_DESTRUCTIVE_INST_TYPE(0xa), }; enum FalseLaneType { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index ac31236..ce40e20 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -143,7 +143,7 @@ def HasFuseAES : Predicate<"Subtarget->hasFuseAES()">, "fuse-aes">; def HasSVE : Predicate<"Subtarget->isSVEAvailable()">, AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">; -def HasSVEB16B16 : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEB16B16()">, +def HasSVEB16B16 : Predicate<"Subtarget->hasSVEB16B16()">, AssemblerPredicateWithAll<(all_of FeatureSVEB16B16), "sve-b16b16">; def HasSVE2 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2()">, AssemblerPredicateWithAll<(all_of FeatureSVE2), "sve2">; @@ -248,6 +248,10 @@ def HasSVE_or_SME : Predicate<"Subtarget->isSVEorStreamingSVEAvailable()">, AssemblerPredicateWithAll<(any_of FeatureSVE, FeatureSME), "sve or sme">; +def HasNonStreamingSVE_or_SME2 + : Predicate<"Subtarget->isNonStreamingSVEorSME2Available()">, + AssemblerPredicateWithAll<(any_of FeatureSVE, FeatureSME2), + "sve or sme2">; def HasNonStreamingSVE_or_SME2p1 : Predicate<"Subtarget->isSVEAvailable() ||" "(Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSME2p1())">, @@ -985,6 +989,17 @@ def AArch64fcvtxnv: PatFrags<(ops node:$Rn), [(int_aarch64_neon_fcvtxn node:$Rn), (AArch64fcvtxn_n node:$Rn)]>; +def AArch64fcvtzs_half : SDNode<"AArch64ISD::FCVTZS_HALF", SDTFPExtendOp>; +def AArch64fcvtzu_half : SDNode<"AArch64ISD::FCVTZU_HALF", SDTFPExtendOp>; +def AArch64fcvtas_half : SDNode<"AArch64ISD::FCVTAS_HALF", SDTFPExtendOp>; +def AArch64fcvtau_half : SDNode<"AArch64ISD::FCVTAU_HALF", SDTFPExtendOp>; +def AArch64fcvtms_half : SDNode<"AArch64ISD::FCVTMS_HALF", SDTFPExtendOp>; +def AArch64fcvtmu_half : SDNode<"AArch64ISD::FCVTMU_HALF", SDTFPExtendOp>; +def AArch64fcvtns_half : SDNode<"AArch64ISD::FCVTNS_HALF", SDTFPExtendOp>; +def AArch64fcvtnu_half : SDNode<"AArch64ISD::FCVTNU_HALF", SDTFPExtendOp>; +def AArch64fcvtps_half : SDNode<"AArch64ISD::FCVTPS_HALF", SDTFPExtendOp>; +def AArch64fcvtpu_half : SDNode<"AArch64ISD::FCVTPU_HALF", SDTFPExtendOp>; + //def Aarch64softf32tobf16v8: SDNode<"AArch64ISD::", SDTFPRoundOp>; // Vector immediate ops @@ -2151,7 +2166,7 @@ let Predicates = [HasPAuth] in { i64imm:$Disc, GPR64:$AddrDisc), [], "$AuthVal = $Val">, Sched<[WriteI, ReadI]> { let isCodeGenOnly = 1; - let hasSideEffects = 0; + let hasSideEffects = 1; let mayStore = 0; let mayLoad = 0; let Size = 32; @@ -2656,13 +2671,17 @@ defm ADD : AddSub<0, "add", "sub", add>; defm SUB : AddSub<1, "sub", "add">; def : InstAlias<"mov $dst, $src", - (ADDWri GPR32sponly:$dst, GPR32sp:$src, 0, 0)>; + (ADDWri GPR32sponly:$dst, GPR32sp:$src, + (addsub_shifted_imm32 0, 0))>; def : InstAlias<"mov $dst, $src", - (ADDWri GPR32sp:$dst, GPR32sponly:$src, 0, 0)>; + (ADDWri GPR32sp:$dst, GPR32sponly:$src, + (addsub_shifted_imm32 0, 0))>; def : InstAlias<"mov $dst, $src", - (ADDXri GPR64sponly:$dst, 
GPR64sp:$src, 0, 0)>; + (ADDXri GPR64sponly:$dst, GPR64sp:$src, + (addsub_shifted_imm64 0, 0))>; def : InstAlias<"mov $dst, $src", - (ADDXri GPR64sp:$dst, GPR64sponly:$src, 0, 0)>; + (ADDXri GPR64sp:$dst, GPR64sponly:$src, + (addsub_shifted_imm64 0, 0))>; defm ADDS : AddSubS<0, "adds", AArch64add_flag, "cmn", "subs", "cmp">; defm SUBS : AddSubS<1, "subs", AArch64sub_flag, "cmp", "adds", "cmn">; @@ -2722,19 +2741,31 @@ def : Pat<(AArch64sub_flag GPR64:$Rn, neg_addsub_shifted_imm64:$imm), (ADDSXri GPR64:$Rn, neg_addsub_shifted_imm64:$imm)>; } -def : InstAlias<"neg $dst, $src", (SUBWrs GPR32:$dst, WZR, GPR32:$src, 0), 3>; -def : InstAlias<"neg $dst, $src", (SUBXrs GPR64:$dst, XZR, GPR64:$src, 0), 3>; +def : InstAlias<"neg $dst, $src", + (SUBWrs GPR32:$dst, WZR, + (arith_shifted_reg32 GPR32:$src, 0)), 3>; +def : InstAlias<"neg $dst, $src", + (SUBXrs GPR64:$dst, XZR, + (arith_shifted_reg64 GPR64:$src, 0)), 3>; def : InstAlias<"neg $dst, $src$shift", - (SUBWrs GPR32:$dst, WZR, GPR32:$src, arith_shift32:$shift), 2>; + (SUBWrs GPR32:$dst, WZR, + (arith_shifted_reg32 GPR32:$src, arith_shift32:$shift)), 2>; def : InstAlias<"neg $dst, $src$shift", - (SUBXrs GPR64:$dst, XZR, GPR64:$src, arith_shift64:$shift), 2>; - -def : InstAlias<"negs $dst, $src", (SUBSWrs GPR32:$dst, WZR, GPR32:$src, 0), 3>; -def : InstAlias<"negs $dst, $src", (SUBSXrs GPR64:$dst, XZR, GPR64:$src, 0), 3>; + (SUBXrs GPR64:$dst, XZR, + (arith_shifted_reg64 GPR64:$src, arith_shift64:$shift)), 2>; + +def : InstAlias<"negs $dst, $src", + (SUBSWrs GPR32:$dst, WZR, + (arith_shifted_reg32 GPR32:$src, 0)), 3>; +def : InstAlias<"negs $dst, $src", + (SUBSXrs GPR64:$dst, XZR, + (arith_shifted_reg64 GPR64:$src, 0)), 3>; def : InstAlias<"negs $dst, $src$shift", - (SUBSWrs GPR32:$dst, WZR, GPR32:$src, arith_shift32:$shift), 2>; + (SUBSWrs GPR32:$dst, WZR, + (arith_shifted_reg32 GPR32:$src, arith_shift32:$shift)), 2>; def : InstAlias<"negs $dst, $src$shift", - (SUBSXrs GPR64:$dst, XZR, GPR64:$src, arith_shift64:$shift), 2>; + (SUBSXrs GPR64:$dst, XZR, + (arith_shifted_reg64 GPR64:$src, arith_shift64:$shift)), 2>; // Unsigned/Signed divide @@ -3161,16 +3192,26 @@ defm ORN : LogicalReg<0b01, 1, "orn", BinOpFrag<(or node:$LHS, (not node:$RHS))>>; defm ORR : LogicalReg<0b01, 0, "orr", or>; -def : InstAlias<"mov $dst, $src", (ORRWrs GPR32:$dst, WZR, GPR32:$src, 0), 2>; -def : InstAlias<"mov $dst, $src", (ORRXrs GPR64:$dst, XZR, GPR64:$src, 0), 2>; - -def : InstAlias<"mvn $Wd, $Wm", (ORNWrs GPR32:$Wd, WZR, GPR32:$Wm, 0), 3>; -def : InstAlias<"mvn $Xd, $Xm", (ORNXrs GPR64:$Xd, XZR, GPR64:$Xm, 0), 3>; +def : InstAlias<"mov $dst, $src", + (ORRWrs GPR32:$dst, WZR, + (logical_shifted_reg32 GPR32:$src, 0)), 2>; +def : InstAlias<"mov $dst, $src", + (ORRXrs GPR64:$dst, XZR, + (logical_shifted_reg64 GPR64:$src, 0)), 2>; + +def : InstAlias<"mvn $Wd, $Wm", + (ORNWrs GPR32:$Wd, WZR, + (logical_shifted_reg32 GPR32:$Wm, 0)), 3>; +def : InstAlias<"mvn $Xd, $Xm", + (ORNXrs GPR64:$Xd, XZR, + (logical_shifted_reg64 GPR64:$Xm, 0)), 3>; def : InstAlias<"mvn $Wd, $Wm$sh", - (ORNWrs GPR32:$Wd, WZR, GPR32:$Wm, logical_shift32:$sh), 2>; + (ORNWrs GPR32:$Wd, WZR, + (logical_shifted_reg32 GPR32:$Wm, logical_shift32:$sh)), 2>; def : InstAlias<"mvn $Xd, $Xm$sh", - (ORNXrs GPR64:$Xd, XZR, GPR64:$Xm, logical_shift64:$sh), 2>; + (ORNXrs GPR64:$Xd, XZR, + (logical_shifted_reg64 GPR64:$Xm, logical_shift64:$sh)), 2>; def : InstAlias<"tst $src1, $src2", (ANDSWri WZR, GPR32:$src1, logical_imm32:$src2), 2>; @@ -3178,14 +3219,18 @@ def : InstAlias<"tst $src1, $src2", (ANDSXri XZR, 
GPR64:$src1, logical_imm64:$src2), 2>; def : InstAlias<"tst $src1, $src2", - (ANDSWrs WZR, GPR32:$src1, GPR32:$src2, 0), 3>; + (ANDSWrs WZR, GPR32:$src1, + (logical_shifted_reg32 GPR32:$src2, 0)), 3>; def : InstAlias<"tst $src1, $src2", - (ANDSXrs XZR, GPR64:$src1, GPR64:$src2, 0), 3>; + (ANDSXrs XZR, GPR64:$src1, + (logical_shifted_reg64 GPR64:$src2, 0)), 3>; def : InstAlias<"tst $src1, $src2$sh", - (ANDSWrs WZR, GPR32:$src1, GPR32:$src2, logical_shift32:$sh), 2>; + (ANDSWrs WZR, GPR32:$src1, + (logical_shifted_reg32 GPR32:$src2, logical_shift32:$sh)), 2>; def : InstAlias<"tst $src1, $src2$sh", - (ANDSXrs XZR, GPR64:$src1, GPR64:$src2, logical_shift64:$sh), 2>; + (ANDSXrs XZR, GPR64:$src1, + (logical_shifted_reg64 GPR64:$src2, logical_shift64:$sh)), 2>; def : Pat<(not GPR32:$Wm), (ORNWrr WZR, GPR32:$Wm)>; @@ -5707,27 +5752,6 @@ let Predicates = [HasFullFP16] in { // Advanced SIMD two vector instructions. //===----------------------------------------------------------------------===// -defm UABDL : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl", abdu>; -// Match UABDL in log2-shuffle patterns. -def : Pat<(abs (v8i16 (sub (zext (v8i8 V64:$opA)), - (zext (v8i8 V64:$opB))))), - (UABDLv8i8_v8i16 V64:$opA, V64:$opB)>; -def : Pat<(abs (v8i16 (sub (zext (extract_high_v16i8 (v16i8 V128:$opA))), - (zext (extract_high_v16i8 (v16i8 V128:$opB)))))), - (UABDLv16i8_v8i16 V128:$opA, V128:$opB)>; -def : Pat<(abs (v4i32 (sub (zext (v4i16 V64:$opA)), - (zext (v4i16 V64:$opB))))), - (UABDLv4i16_v4i32 V64:$opA, V64:$opB)>; -def : Pat<(abs (v4i32 (sub (zext (extract_high_v8i16 (v8i16 V128:$opA))), - (zext (extract_high_v8i16 (v8i16 V128:$opB)))))), - (UABDLv8i16_v4i32 V128:$opA, V128:$opB)>; -def : Pat<(abs (v2i64 (sub (zext (v2i32 V64:$opA)), - (zext (v2i32 V64:$opB))))), - (UABDLv2i32_v2i64 V64:$opA, V64:$opB)>; -def : Pat<(abs (v2i64 (sub (zext (extract_high_v4i32 (v4i32 V128:$opA))), - (zext (extract_high_v4i32 (v4i32 V128:$opB)))))), - (UABDLv4i32_v2i64 V128:$opA, V128:$opB)>; - defm ABS : SIMDTwoVectorBHSD<0, 0b01011, "abs", abs>; defm CLS : SIMDTwoVectorBHS<0, 0b00100, "cls", int_aarch64_neon_cls>; defm CLZ : SIMDTwoVectorBHS<1, 0b00100, "clz", ctlz>; @@ -6055,6 +6079,7 @@ defm MLA : SIMDThreeSameVectorBHSTied<0, 0b10010, "mla", null_frag>; defm MLS : SIMDThreeSameVectorBHSTied<1, 0b10010, "mls", null_frag>; defm MUL : SIMDThreeSameVectorBHS<0, 0b10011, "mul", mul>; +let isCommutable = 1 in defm PMUL : SIMDThreeSameVectorB<1, 0b10011, "pmul", int_aarch64_neon_pmul>; defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba", TriOpFrag<(add node:$LHS, (abds node:$MHS, node:$RHS))> >; @@ -6552,9 +6577,33 @@ defm UQXTN : SIMDTwoScalarMixedBHS<1, 0b10100, "uqxtn", int_aarch64_neon_scalar defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd", int_aarch64_neon_usqadd>; +// f16 -> s16 conversions +let Predicates = [HasFullFP16] in { + def : Pat<(i16(fp_to_sint_sat_gi f16:$Rn)), (FCVTZSv1f16 f16:$Rn)>; + def : Pat<(i16(fp_to_uint_sat_gi f16:$Rn)), (FCVTZUv1f16 f16:$Rn)>; +} + def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))), (CMLTv1i64rz V64:$Rn)>; +// f16 -> i16 conversions leave the bit pattern in a f32 +class F16ToI16ScalarPat<SDNode cvt_isd, BaseSIMDTwoScalar instr> + : Pat<(f32 (cvt_isd (f16 FPR16:$Rn))), + (f32 (SUBREG_TO_REG (i64 0), (instr FPR16:$Rn), hsub))>; + +let Predicates = [HasFullFP16] in { +def : F16ToI16ScalarPat<AArch64fcvtzs_half, FCVTZSv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtzu_half, FCVTZUv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtas_half, FCVTASv1f16>; +def : 
F16ToI16ScalarPat<AArch64fcvtau_half, FCVTAUv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtms_half, FCVTMSv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtmu_half, FCVTMUv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtns_half, FCVTNSv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtnu_half, FCVTNUv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtps_half, FCVTPSv1f16>; +def : F16ToI16ScalarPat<AArch64fcvtpu_half, FCVTPUv1f16>; +} + // Round FP64 to BF16. let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))), @@ -6802,40 +6851,47 @@ def : Pat <(f64 (uint_to_fp (i32 // Advanced SIMD three different-sized vector instructions. //===----------------------------------------------------------------------===// -defm ADDHN : SIMDNarrowThreeVectorBHS<0,0b0100,"addhn", int_aarch64_neon_addhn>; -defm SUBHN : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn>; -defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>; -defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>; -defm PMULL : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>; -defm SABAL : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal", abds>; -defm SABDL : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl", abds>; +defm ADDHN : SIMDNarrowThreeVectorBHS<0,0b0100,"addhn", int_aarch64_neon_addhn>; +defm SUBHN : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn>; +defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>; +defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>; +let isCommutable = 1 in +defm PMULL : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>; +defm SABAL : SIMDLongThreeVectorTiedBHS<0,0b0101,"sabal", + TriOpFrag<(add node:$LHS, (zext (abds node:$MHS, node:$RHS)))>>; +defm SABDL : SIMDLongThreeVectorBHS<0, 0b0111, "sabdl", + BinOpFrag<(zext (abds node:$LHS, node:$RHS))>>; defm SADDL : SIMDLongThreeVectorBHS< 0, 0b0000, "saddl", - BinOpFrag<(add (sext node:$LHS), (sext node:$RHS))>>; + BinOpFrag<(add (sext node:$LHS), (sext node:$RHS))>>; defm SADDW : SIMDWideThreeVectorBHS< 0, 0b0001, "saddw", BinOpFrag<(add node:$LHS, (sext node:$RHS))>>; defm SMLAL : SIMDLongThreeVectorTiedBHS<0, 0b1000, "smlal", - TriOpFrag<(add node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>; + TriOpFrag<(add node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>; defm SMLSL : SIMDLongThreeVectorTiedBHS<0, 0b1010, "smlsl", - TriOpFrag<(sub node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>; + TriOpFrag<(sub node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>; defm SMULL : SIMDLongThreeVectorBHS<0, 0b1100, "smull", AArch64smull>; defm SQDMLAL : SIMDLongThreeVectorSQDMLXTiedHS<0, 0b1001, "sqdmlal", saddsat>; defm SQDMLSL : SIMDLongThreeVectorSQDMLXTiedHS<0, 0b1011, "sqdmlsl", ssubsat>; -defm SQDMULL : SIMDLongThreeVectorHS<0, 0b1101, "sqdmull", - int_aarch64_neon_sqdmull>; +defm SQDMULL : SIMDLongThreeVectorHS<0, 0b1101, "sqdmull", int_aarch64_neon_sqdmull>; +let isCommutable = 0 in defm SSUBL : SIMDLongThreeVectorBHS<0, 0b0010, "ssubl", BinOpFrag<(sub (sext node:$LHS), (sext node:$RHS))>>; defm SSUBW : SIMDWideThreeVectorBHS<0, 0b0011, "ssubw", BinOpFrag<(sub node:$LHS, (sext node:$RHS))>>; -defm UABAL : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal", abdu>; +defm UABAL : SIMDLongThreeVectorTiedBHS<1, 0b0101, "uabal", + TriOpFrag<(add node:$LHS, (zext (abdu node:$MHS, node:$RHS)))>>; +defm UABDL : SIMDLongThreeVectorBHS<1, 0b0111, "uabdl", + BinOpFrag<(zext (abdu node:$LHS, 
node:$RHS))>>; defm UADDL : SIMDLongThreeVectorBHS<1, 0b0000, "uaddl", BinOpFrag<(add (zanyext node:$LHS), (zanyext node:$RHS))>>; defm UADDW : SIMDWideThreeVectorBHS<1, 0b0001, "uaddw", BinOpFrag<(add node:$LHS, (zanyext node:$RHS))>>; defm UMLAL : SIMDLongThreeVectorTiedBHS<1, 0b1000, "umlal", - TriOpFrag<(add node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>; + TriOpFrag<(add node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>; defm UMLSL : SIMDLongThreeVectorTiedBHS<1, 0b1010, "umlsl", - TriOpFrag<(sub node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>; + TriOpFrag<(sub node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>; defm UMULL : SIMDLongThreeVectorBHS<1, 0b1100, "umull", AArch64umull>; +let isCommutable = 0 in defm USUBL : SIMDLongThreeVectorBHS<1, 0b0010, "usubl", BinOpFrag<(sub (zanyext node:$LHS), (zanyext node:$RHS))>>; defm USUBW : SIMDWideThreeVectorBHS< 1, 0b0011, "usubw", diff --git a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp index 782d62a7..e69fa32 100644 --- a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp +++ b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp @@ -1193,7 +1193,8 @@ AArch64LoadStoreOpt::mergePairedInsns(MachineBasicBlock::iterator I, // USE kill %w1 ; need to clear kill flag when moving STRWui downwards // STRW %w0 Register Reg = getLdStRegOp(*I).getReg(); - for (MachineInstr &MI : make_range(std::next(I), Paired)) + for (MachineInstr &MI : + make_range(std::next(I->getIterator()), Paired->getIterator())) MI.clearRegisterKills(Reg, TRI); } } diff --git a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp index b97d622..fd4ef2a 100644 --- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp +++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp @@ -8,8 +8,8 @@ // // This pass performs below peephole optimizations on MIR level. // -// 1. MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri -// MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri +// 1. MOVi32imm + (ANDS?|EOR|ORR)Wrr ==> (AND|EOR|ORR)Wri + (ANDS?|EOR|ORR)Wri +// MOVi64imm + (ANDS?|EOR|ORR)Xrr ==> (AND|EOR|ORR)Xri + (ANDS?|EOR|ORR)Xri // // 2. MOVi32imm + ADDWrr ==> ADDWRi + ADDWRi // MOVi64imm + ADDXrr ==> ADDXri + ADDXri @@ -128,6 +128,7 @@ struct AArch64MIPeepholeOpt : public MachineFunctionPass { // Strategy used to split logical immediate bitmasks. enum class SplitStrategy { Intersect, + Disjoint, }; template <typename T> bool trySplitLogicalImm(unsigned Opc, MachineInstr &MI, @@ -163,6 +164,7 @@ INITIALIZE_PASS(AArch64MIPeepholeOpt, "aarch64-mi-peephole-opt", template <typename T> static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) { T UImm = static_cast<T>(Imm); + assert(UImm && (UImm != ~static_cast<T>(0)) && "Invalid immediate!"); // The bitmask immediate consists of consecutive ones. Let's say there is // constant 0b00000000001000000000010000000000 which does not consist of @@ -191,18 +193,47 @@ static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) { } template <typename T> +static bool splitDisjointBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, + T &Imm2Enc) { + assert(Imm && (Imm != ~static_cast<T>(0)) && "Invalid immediate!"); + + // Try to split a bitmask of the form 0b00000000011000000000011110000000 into + // two disjoint masks such as 0b00000000011000000000000000000000 and + // 0b00000000000000000000011110000000 where the inclusive/exclusive OR of the + // new masks match the original mask. 
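A standalone sketch of the split just described, on a plain uint64_t with C++20 <bit>; the check that the remainder is itself encodable stays with AArch64_AM::isLogicalImmediate in the real pass, and the constant 0x300780 below is only an illustrative value.

    #include <bit>
    #include <cstdint>
    #include <optional>
    #include <utility>

    // Split Imm into (lowest run of contiguous ones, remaining ones).
    static std::optional<std::pair<uint64_t, uint64_t>>
    splitDisjoint(uint64_t Imm) {
      if (Imm == 0 || Imm == ~uint64_t{0})
        return std::nullopt; // mirrors the assert in the real code
      unsigned LowestBitSet = std::countr_zero(Imm);
      unsigned RunEnd = LowestBitSet + std::countr_one(Imm >> LowestBitSet);
      // Mask covering the least significant run of ones (unsigned wrap-around
      // handles the RunEnd == 64 case).
      uint64_t Mask1 = (RunEnd < 64 ? (uint64_t{1} << RunEnd) : uint64_t{0}) -
                       (uint64_t{1} << LowestBitSet);
      uint64_t Mask2 = Imm & ~Mask1; // ones disjoint from Mask1
      return std::make_pair(Mask1, Mask2);
    }

    // For example, 0x300780 is not a valid logical immediate, but it splits
    // into 0x780 and 0x300000, both encodable, so
    //   mov w8, #0x300780 (a movz+movk pair once expanded); orr w0, w1, w8
    // can become
    //   orr w0, w1, #0x780; orr w0, w0, #0x300000
    static_assert((0x300780u ^ 0x780u ^ 0x300000u) == 0);

This is the transformation the new EOR/ORR cases in runOnMachineFunction below feed into trySplitLogicalImm with SplitStrategy::Disjoint.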
+ unsigned LowestBitSet = llvm::countr_zero(Imm); + unsigned LowestGapBitUnset = + LowestBitSet + llvm::countr_one(Imm >> LowestBitSet); + + // Create a mask for the least significant group of consecutive ones. + assert(LowestGapBitUnset < sizeof(T) * CHAR_BIT && "Undefined behaviour!"); + T NewImm1 = (static_cast<T>(1) << LowestGapBitUnset) - + (static_cast<T>(1) << LowestBitSet); + // Create a disjoint mask for the remaining ones. + T NewImm2 = Imm & ~NewImm1; + + // Do not split if NewImm2 is not a valid bitmask immediate. + if (!AArch64_AM::isLogicalImmediate(NewImm2, RegSize)) + return false; + + Imm1Enc = AArch64_AM::encodeLogicalImmediate(NewImm1, RegSize); + Imm2Enc = AArch64_AM::encodeLogicalImmediate(NewImm2, RegSize); + return true; +} + +template <typename T> bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI, SplitStrategy Strategy, unsigned OtherOpc) { - // Try below transformation. + // Try below transformations. // - // MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri - // MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri + // MOVi32imm + (ANDS?|EOR|ORR)Wrr ==> (AND|EOR|ORR)Wri + (ANDS?|EOR|ORR)Wri + // MOVi64imm + (ANDS?|EOR|ORR)Xrr ==> (AND|EOR|ORR)Xri + (ANDS?|EOR|ORR)Xri // // The mov pseudo instruction could be expanded to multiple mov instructions // later. Let's try to split the constant operand of mov instruction into two - // bitmask immediates. It makes only two AND instructions instead of multiple - // mov + and instructions. + // bitmask immediates based on the given split strategy. It makes only two + // logical instructions instead of multiple mov + logic instructions. return splitTwoPartImm<T>( MI, @@ -224,6 +255,9 @@ bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI, case SplitStrategy::Intersect: SplitSucc = splitBitmaskImm(Imm, RegSize, Imm0, Imm1); break; + case SplitStrategy::Disjoint: + SplitSucc = splitDisjointBitmaskImm(Imm, RegSize, Imm0, Imm1); + break; } if (SplitSucc) return std::make_pair(Opc, !OtherOpc ? Opc : OtherOpc); @@ -889,6 +923,22 @@ bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) { Changed |= trySplitLogicalImm<uint64_t>( AArch64::ANDXri, MI, SplitStrategy::Intersect, AArch64::ANDSXri); break; + case AArch64::EORWrr: + Changed |= trySplitLogicalImm<uint32_t>(AArch64::EORWri, MI, + SplitStrategy::Disjoint); + break; + case AArch64::EORXrr: + Changed |= trySplitLogicalImm<uint64_t>(AArch64::EORXri, MI, + SplitStrategy::Disjoint); + break; + case AArch64::ORRWrr: + Changed |= trySplitLogicalImm<uint32_t>(AArch64::ORRWri, MI, + SplitStrategy::Disjoint); + break; + case AArch64::ORRXrr: + Changed |= trySplitLogicalImm<uint64_t>(AArch64::ORRXri, MI, + SplitStrategy::Disjoint); + break; case AArch64::ORRWrs: Changed |= visitORR(MI); break; diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index 800787c..1fde87e 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -213,9 +213,6 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { /// or return type bool IsSVECC = false; - /// The frame-index for the TPIDR2 object used for lazy saves. - TPIDR2Object TPIDR2; - /// Whether this function changes streaming mode within the function. bool HasStreamingModeChanges = false; @@ -231,25 +228,26 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { // on function entry to record the initial pstate of a function. 
Register PStateSMReg = MCRegister::NoRegister; - // Holds a pointer to a buffer that is large enough to represent - // all SME ZA state and any additional state required by the - // __arm_sme_save/restore support routines. - Register SMESaveBufferAddr = MCRegister::NoRegister; - - // true if SMESaveBufferAddr is used. - bool SMESaveBufferUsed = false; + // true if PStateSMReg is used. + bool PStateSMRegUsed = false; // Has the PNReg used to build PTRUE instruction. // The PTRUE is used for the LD/ST of ZReg pairs in save and restore. unsigned PredicateRegForFillSpill = 0; - // The stack slots where VG values are stored to. - int64_t VGIdx = std::numeric_limits<int>::max(); - int64_t StreamingVGIdx = std::numeric_limits<int>::max(); - // Holds the SME function attributes (streaming mode, ZA/ZT0 state). SMEAttrs SMEFnAttrs; + // Note: The following properties are only used for the old SME ABI lowering: + /// The frame-index for the TPIDR2 object used for lazy saves. + TPIDR2Object TPIDR2; + // Holds a pointer to a buffer that is large enough to represent + // all SME ZA state and any additional state required by the + // __arm_sme_save/restore support routines. + Register SMESaveBufferAddr = MCRegister::NoRegister; + // true if SMESaveBufferAddr is used. + bool SMESaveBufferUsed = false; + public: AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI); @@ -258,6 +256,13 @@ public: const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB) const override; + // Old SME ABI lowering state getters/setters: + Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; }; + void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; }; + unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; }; + void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; }; + TPIDR2Object &getTPIDR2Obj() { return TPIDR2; } + void setPredicateRegForFillSpill(unsigned Reg) { PredicateRegForFillSpill = Reg; } @@ -265,26 +270,15 @@ public: return PredicateRegForFillSpill; } - Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; }; - void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; }; - - unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; }; - void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; }; - Register getPStateSMReg() const { return PStateSMReg; }; void setPStateSMReg(Register Reg) { PStateSMReg = Reg; }; - int64_t getVGIdx() const { return VGIdx; }; - void setVGIdx(unsigned Idx) { VGIdx = Idx; }; - - int64_t getStreamingVGIdx() const { return StreamingVGIdx; }; - void setStreamingVGIdx(unsigned FrameIdx) { StreamingVGIdx = FrameIdx; }; + unsigned isPStateSMRegUsed() const { return PStateSMRegUsed; }; + void setPStateSMRegUsed(bool Used = true) { PStateSMRegUsed = Used; }; bool isSVECC() const { return IsSVECC; }; void setIsSVECC(bool s) { IsSVECC = s; }; - TPIDR2Object &getTPIDR2Obj() { return TPIDR2; } - void initializeBaseYamlFields(const yaml::AArch64FunctionInfo &YamlMFI); unsigned getBytesInStackArgArea() const { return BytesInStackArgArea; } diff --git a/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp b/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp index ff7a0d1..f4a7f77 100644 --- a/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp +++ b/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp @@ -237,8 +237,8 @@ static bool isAddressLdStPair(const MachineInstr *FirstMI, } /// Compare and conditional select. 
-static bool isCCSelectPair(const MachineInstr *FirstMI, - const MachineInstr &SecondMI) { +static bool isCmpCSelPair(const MachineInstr *FirstMI, + const MachineInstr &SecondMI) { // 32 bits if (SecondMI.getOpcode() == AArch64::CSELWr) { // Assume the 1st instr to be a wildcard if it is unspecified. @@ -279,6 +279,40 @@ static bool isCCSelectPair(const MachineInstr *FirstMI, return false; } +/// Compare and cset. +static bool isCmpCSetPair(const MachineInstr *FirstMI, + const MachineInstr &SecondMI) { + if ((SecondMI.getOpcode() == AArch64::CSINCWr && + SecondMI.getOperand(1).getReg() == AArch64::WZR && + SecondMI.getOperand(2).getReg() == AArch64::WZR) || + (SecondMI.getOpcode() == AArch64::CSINCXr && + SecondMI.getOperand(1).getReg() == AArch64::XZR && + SecondMI.getOperand(2).getReg() == AArch64::XZR)) { + // Assume the 1st instr to be a wildcard if it is unspecified. + if (FirstMI == nullptr) + return true; + + if (FirstMI->definesRegister(AArch64::WZR, /*TRI=*/nullptr) || + FirstMI->definesRegister(AArch64::XZR, /*TRI=*/nullptr)) + switch (FirstMI->getOpcode()) { + case AArch64::SUBSWrs: + case AArch64::SUBSXrs: + return !AArch64InstrInfo::hasShiftedReg(*FirstMI); + case AArch64::SUBSWrx: + case AArch64::SUBSXrx: + case AArch64::SUBSXrx64: + return !AArch64InstrInfo::hasExtendedReg(*FirstMI); + case AArch64::SUBSWri: + case AArch64::SUBSWrr: + case AArch64::SUBSXri: + case AArch64::SUBSXrr: + return true; + } + } + + return false; +} + // Arithmetic and logic. static bool isArithmeticLogicPair(const MachineInstr *FirstMI, const MachineInstr &SecondMI) { @@ -465,7 +499,9 @@ static bool shouldScheduleAdjacent(const TargetInstrInfo &TII, return true; if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI)) return true; - if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI)) + if (ST.hasFuseCmpCSel() && isCmpCSelPair(FirstMI, SecondMI)) + return true; + if (ST.hasFuseCmpCSet() && isCmpCSetPair(FirstMI, SecondMI)) return true; if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI)) return true; diff --git a/llvm/lib/Target/AArch64/AArch64Processors.td b/llvm/lib/Target/AArch64/AArch64Processors.td index adc984a..d5f4e91 100644 --- a/llvm/lib/Target/AArch64/AArch64Processors.td +++ b/llvm/lib/Target/AArch64/AArch64Processors.td @@ -22,7 +22,8 @@ def TuneA320 : SubtargetFeature<"a320", "ARMProcFamily", "CortexA320", FeatureFuseAES, FeatureFuseAdrpAdd, FeaturePostRAScheduler, - FeatureUseWzrToVecMove]>; + FeatureUseWzrToVecMove, + FeatureUseFixedOverScalableIfEqualCost]>; def TuneA53 : SubtargetFeature<"a53", "ARMProcFamily", "CortexA53", "Cortex-A53 ARM processors", [ @@ -45,7 +46,8 @@ def TuneA510 : SubtargetFeature<"a510", "ARMProcFamily", "CortexA510", FeatureFuseAES, FeatureFuseAdrpAdd, FeaturePostRAScheduler, - FeatureUseWzrToVecMove + FeatureUseWzrToVecMove, + FeatureUseFixedOverScalableIfEqualCost ]>; def TuneA520 : SubtargetFeature<"a520", "ARMProcFamily", "CortexA520", @@ -53,7 +55,8 @@ def TuneA520 : SubtargetFeature<"a520", "ARMProcFamily", "CortexA520", FeatureFuseAES, FeatureFuseAdrpAdd, FeaturePostRAScheduler, - FeatureUseWzrToVecMove]>; + FeatureUseWzrToVecMove, + FeatureUseFixedOverScalableIfEqualCost]>; def TuneA520AE : SubtargetFeature<"a520ae", "ARMProcFamily", "CortexA520", "Cortex-A520AE ARM processors", [ @@ -131,6 +134,8 @@ def TuneA78 : SubtargetFeature<"a78", "ARMProcFamily", "CortexA78", FeatureCmpBccFusion, FeatureFuseAES, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureAddrLSLSlow14, FeatureALULSLFast, 
FeaturePostRAScheduler, @@ -143,6 +148,8 @@ def TuneA78AE : SubtargetFeature<"a78ae", "ARMProcFamily", FeatureCmpBccFusion, FeatureFuseAES, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureAddrLSLSlow14, FeatureALULSLFast, FeaturePostRAScheduler, @@ -155,6 +162,8 @@ def TuneA78C : SubtargetFeature<"a78c", "ARMProcFamily", FeatureCmpBccFusion, FeatureFuseAES, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureAddrLSLSlow14, FeatureALULSLFast, FeaturePostRAScheduler, @@ -166,6 +175,8 @@ def TuneA710 : SubtargetFeature<"a710", "ARMProcFamily", "CortexA710", FeatureCmpBccFusion, FeatureFuseAES, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureALULSLFast, FeaturePostRAScheduler, FeatureEnableSelectOptimize, @@ -178,6 +189,8 @@ def TuneA715 : SubtargetFeature<"a715", "ARMProcFamily", "CortexA715", FeatureCmpBccFusion, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureEnableSelectOptimize, FeaturePredictableSelectIsExpensive]>; @@ -188,6 +201,8 @@ def TuneA720 : SubtargetFeature<"a720", "ARMProcFamily", "CortexA720", FeatureCmpBccFusion, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureEnableSelectOptimize, FeaturePredictableSelectIsExpensive]>; @@ -198,6 +213,8 @@ def TuneA720AE : SubtargetFeature<"a720ae", "ARMProcFamily", "CortexA720", FeatureCmpBccFusion, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureEnableSelectOptimize, FeaturePredictableSelectIsExpensive]>; @@ -209,6 +226,8 @@ def TuneA725 : SubtargetFeature<"cortex-a725", "ARMProcFamily", FeatureCmpBccFusion, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureEnableSelectOptimize, FeaturePredictableSelectIsExpensive]>; @@ -259,6 +278,8 @@ def TuneX4 : SubtargetFeature<"cortex-x4", "ARMProcFamily", "CortexX4", "Cortex-X4 ARM processors", [ FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureFuseAES, FeaturePostRAScheduler, FeatureEnableSelectOptimize, @@ -270,6 +291,8 @@ def TuneX925 : SubtargetFeature<"cortex-x925", "ARMProcFamily", "CortexX925", "Cortex-X925 ARM processors",[ FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureFuseAES, FeaturePostRAScheduler, FeatureEnableSelectOptimize, @@ -318,8 +341,9 @@ def TuneAppleA7 : SubtargetFeature<"apple-a7", "ARMProcFamily", "AppleA7", FeatureFuseAES, FeatureFuseCryptoEOR, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing, + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64, FeatureZCZeroingFPWorkaround]>; def TuneAppleA10 : SubtargetFeature<"apple-a10", "ARMProcFamily", "AppleA10", @@ -332,8 +356,9 @@ def TuneAppleA10 : SubtargetFeature<"apple-a10", "ARMProcFamily", "AppleA10", FeatureFuseCryptoEOR, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA11 : SubtargetFeature<"apple-a11", "ARMProcFamily", "AppleA11", "Apple A11", [ @@ -345,8 +370,9 @@ def TuneAppleA11 : SubtargetFeature<"apple-a11", "ARMProcFamily", "AppleA11", FeatureFuseCryptoEOR, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA12 : SubtargetFeature<"apple-a12", "ARMProcFamily", "AppleA12", 
"Apple A12", [ @@ -358,8 +384,9 @@ def TuneAppleA12 : SubtargetFeature<"apple-a12", "ARMProcFamily", "AppleA12", FeatureFuseCryptoEOR, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA13 : SubtargetFeature<"apple-a13", "ARMProcFamily", "AppleA13", "Apple A13", [ @@ -371,8 +398,9 @@ def TuneAppleA13 : SubtargetFeature<"apple-a13", "ARMProcFamily", "AppleA13", FeatureFuseCryptoEOR, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA14 : SubtargetFeature<"apple-a14", "ARMProcFamily", "AppleA14", "Apple A14", [ @@ -384,13 +412,14 @@ def TuneAppleA14 : SubtargetFeature<"apple-a14", "ARMProcFamily", "AppleA14", FeatureFuseAddress, FeatureFuseAES, FeatureFuseArithmeticLogic, - FeatureFuseCCSelect, + FeatureFuseCmpCSel, FeatureFuseCryptoEOR, FeatureFuseLiterals, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA15 : SubtargetFeature<"apple-a15", "ARMProcFamily", "AppleA15", "Apple A15", [ @@ -402,13 +431,14 @@ def TuneAppleA15 : SubtargetFeature<"apple-a15", "ARMProcFamily", "AppleA15", FeatureFuseAdrpAdd, FeatureFuseAES, FeatureFuseArithmeticLogic, - FeatureFuseCCSelect, + FeatureFuseCmpCSel, FeatureFuseCryptoEOR, FeatureFuseLiterals, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA16 : SubtargetFeature<"apple-a16", "ARMProcFamily", "AppleA16", "Apple A16", [ @@ -420,13 +450,14 @@ def TuneAppleA16 : SubtargetFeature<"apple-a16", "ARMProcFamily", "AppleA16", FeatureFuseAdrpAdd, FeatureFuseAES, FeatureFuseArithmeticLogic, - FeatureFuseCCSelect, + FeatureFuseCmpCSel, FeatureFuseCryptoEOR, FeatureFuseLiterals, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleA17 : SubtargetFeature<"apple-a17", "ARMProcFamily", "AppleA17", "Apple A17", [ @@ -438,13 +469,14 @@ def TuneAppleA17 : SubtargetFeature<"apple-a17", "ARMProcFamily", "AppleA17", FeatureFuseAdrpAdd, FeatureFuseAES, FeatureFuseArithmeticLogic, - FeatureFuseCCSelect, + FeatureFuseCmpCSel, FeatureFuseCryptoEOR, FeatureFuseLiterals, FeatureStorePairSuppress, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneAppleM4 : SubtargetFeature<"apple-m4", "ARMProcFamily", "AppleM4", "Apple M4", [ @@ -456,13 +488,13 @@ def TuneAppleM4 : SubtargetFeature<"apple-m4", "ARMProcFamily", "AppleM4", FeatureFuseAdrpAdd, FeatureFuseAES, FeatureFuseArithmeticLogic, - FeatureFuseCCSelect, + FeatureFuseCmpCSel, FeatureFuseCryptoEOR, FeatureFuseLiterals, FeatureZCRegMoveGPR64, - FeatureZCRegMoveFPR64, - FeatureZCZeroing - ]>; + FeatureZCRegMoveFPR128, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneExynosM3 : SubtargetFeature<"exynosm3", "ARMProcFamily", "ExynosM3", "Samsung Exynos-M3 processors", @@ -470,7 +502,7 @@ def TuneExynosM3 : SubtargetFeature<"exynosm3", "ARMProcFamily", "ExynosM3", FeatureForce32BitJumpTables, FeatureFuseAddress, FeatureFuseAES, - FeatureFuseCCSelect, 
+ FeatureFuseCmpCSel, FeatureFuseAdrpAdd, FeatureFuseLiterals, FeatureStorePairSuppress, @@ -488,19 +520,21 @@ def TuneExynosM4 : SubtargetFeature<"exynosm4", "ARMProcFamily", "ExynosM3", FeatureFuseAddress, FeatureFuseAES, FeatureFuseArithmeticLogic, - FeatureFuseCCSelect, + FeatureFuseCmpCSel, FeatureFuseAdrpAdd, FeatureFuseLiterals, FeatureStorePairSuppress, FeatureALULSLFast, FeaturePostRAScheduler, - FeatureZCZeroing]>; + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64]>; def TuneKryo : SubtargetFeature<"kryo", "ARMProcFamily", "Kryo", "Qualcomm Kryo processors", [ FeaturePostRAScheduler, FeaturePredictableSelectIsExpensive, - FeatureZCZeroing, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64, FeatureALULSLFast, FeatureStorePairSuppress]>; @@ -508,7 +542,8 @@ def TuneFalkor : SubtargetFeature<"falkor", "ARMProcFamily", "Falkor", "Qualcomm Falkor processors", [ FeaturePostRAScheduler, FeaturePredictableSelectIsExpensive, - FeatureZCZeroing, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64, FeatureStorePairSuppress, FeatureALULSLFast, FeatureSlowSTRQro]>; @@ -533,6 +568,8 @@ def TuneNeoverseN2 : SubtargetFeature<"neoversen2", "ARMProcFamily", "NeoverseN2 "Neoverse N2 ARM processors", [ FeatureFuseAES, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureALULSLFast, FeaturePostRAScheduler, FeatureEnableSelectOptimize, @@ -544,6 +581,8 @@ def TuneNeoverseN3 : SubtargetFeature<"neoversen3", "ARMProcFamily", "NeoverseN3 FeaturePostRAScheduler, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureEnableSelectOptimize, FeaturePredictableSelectIsExpensive]>; @@ -560,6 +599,8 @@ def TuneNeoverseV1 : SubtargetFeature<"neoversev1", "ARMProcFamily", "NeoverseV1 "Neoverse V1 ARM processors", [ FeatureFuseAES, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureAddrLSLSlow14, FeatureALULSLFast, FeaturePostRAScheduler, @@ -572,6 +613,8 @@ def TuneNeoverseV2 : SubtargetFeature<"neoversev2", "ARMProcFamily", "NeoverseV2 FeatureFuseAES, FeatureCmpBccFusion, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeatureALULSLFast, FeaturePostRAScheduler, FeatureEnableSelectOptimize, @@ -585,6 +628,8 @@ def TuneNeoverseV3 : SubtargetFeature<"neoversev3", "ARMProcFamily", "NeoverseV3 FeatureFuseAES, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeaturePostRAScheduler, FeatureEnableSelectOptimize, FeatureAvoidLDAPUR, @@ -595,6 +640,8 @@ def TuneNeoverseV3AE : SubtargetFeature<"neoversev3AE", "ARMProcFamily", "Neover FeatureFuseAES, FeatureALULSLFast, FeatureFuseAdrpAdd, + FeatureFuseCmpCSel, + FeatureFuseCmpCSet, FeaturePostRAScheduler, FeatureEnableSelectOptimize, FeatureAvoidLDAPUR, @@ -604,7 +651,8 @@ def TuneSaphira : SubtargetFeature<"saphira", "ARMProcFamily", "Saphira", "Qualcomm Saphira processors", [ FeaturePostRAScheduler, FeaturePredictableSelectIsExpensive, - FeatureZCZeroing, + FeatureZCZeroingGPR32, + FeatureZCZeroingGPR64, FeatureStorePairSuppress, FeatureALULSLFast]>; @@ -756,7 +804,6 @@ def ProcessorFeatures { FeatureSB, FeaturePAuth, FeatureSSBS, FeatureSVE, FeatureSVE2, FeatureComplxNum, FeatureCRC, FeatureDotProd, FeatureFPARMv8,FeatureFullFP16, FeatureJS, FeatureLSE, - FeatureUseFixedOverScalableIfEqualCost, FeatureRAS, FeatureRCPC, FeatureRDM, FeatureFPAC]; list<SubtargetFeature> A520 = [HasV9_2aOps, FeaturePerfMon, FeatureAM, FeatureMTE, FeatureETE, FeatureSVEBitPerm, @@ -766,7 +813,6 @@ def ProcessorFeatures { FeatureSVE, FeatureSVE2, FeatureBF16, FeatureComplxNum, 
FeatureCRC, FeatureFPARMv8, FeatureFullFP16, FeatureMatMulInt8, FeatureJS, FeatureNEON, FeatureLSE, FeatureRAS, FeatureRCPC, FeatureRDM, - FeatureUseFixedOverScalableIfEqualCost, FeatureDotProd, FeatureFPAC]; list<SubtargetFeature> A520AE = [HasV9_2aOps, FeaturePerfMon, FeatureAM, FeatureMTE, FeatureETE, FeatureSVEBitPerm, diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index db27ca9..0d8cb3a 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -39,11 +39,18 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2, def AArch64CoalescerBarrier : SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>; -def AArch64VGSave : SDNode<"AArch64ISD::VG_SAVE", SDTypeProfile<0, 0, []>, - [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>; +def AArch64EntryPStateSM + : SDNode<"AArch64ISD::ENTRY_PSTATE_SM", SDTypeProfile<1, 0, + [SDTCisInt<0>]>, [SDNPHasChain, SDNPSideEffect]>; -def AArch64VGRestore : SDNode<"AArch64ISD::VG_RESTORE", SDTypeProfile<0, 0, []>, - [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>; +let usesCustomInserter = 1 in { + def EntryPStateSM : Pseudo<(outs GPR64:$is_streaming), (ins), []>, Sched<[]> {} +} +def : Pat<(i64 (AArch64EntryPStateSM)), (EntryPStateSM)>; + +//===----------------------------------------------------------------------===// +// Old SME ABI lowering ISD nodes/pseudos (deprecated) +//===----------------------------------------------------------------------===// def AArch64AllocateZABuffer : SDNode<"AArch64ISD::ALLOCATE_ZA_BUFFER", SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>, @@ -54,10 +61,10 @@ let usesCustomInserter = 1, Defs = [SP], Uses = [SP] in { def : Pat<(i64 (AArch64AllocateZABuffer GPR64:$size)), (AllocateZABuffer $size)>; -def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 1, - [SDTCisInt<0>]>, [SDNPHasChain, SDNPMayStore]>; +def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 2, + [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain, SDNPMayStore]>; let usesCustomInserter = 1 in { - def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {} + def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer, GPR64:$save_slices), [(AArch64InitTPIDR2Obj GPR64:$buffer, GPR64:$save_slices)]>, Sched<[WriteI]> {} } // Nodes to allocate a save buffer for SME. 
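For context on when this lazy-save machinery is needed: the SME ABI requires a caller with live ZA state to arrange a lazy save (the TPIDR2 block handled above) around any call to a function with a private-ZA interface. Below is a source-level sketch, assuming a toolchain that implements the SME ACLE keyword attributes; the function names are hypothetical and this is not code from the patch. With the default pipeline this is expanded by the existing SME ABI lowering; when the -aarch64-new-sme-abi option introduced later in this patch is enabled, the new MachineSMEABI pass performs the equivalent lowering on MIR instead.

// Hypothetical example: the caller shares ZA, the callee is private-ZA, so
// the call site needs the lazy-save sequence that these nodes model.
void private_za_callee(void);                   // ordinary function, no ZA interface

void shared_za_caller(void) __arm_inout("za") { // shared-ZA interface
  private_za_callee();                          // ZA is lazily saved around this call
}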
@@ -78,6 +85,30 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)), (AllocateSMESaveBuffer $size)>; //===----------------------------------------------------------------------===// +// New SME ABI lowering ISD nodes/pseudos (-aarch64-new-sme-abi) +//===----------------------------------------------------------------------===// + +let hasSideEffects = 1, isMeta = 1 in { + def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; + def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; +} + +def CommitZASavePseudo + : Pseudo<(outs), + (ins GPR64:$tpidr2_el0, i1imm:$zero_za, i64imm:$commit_routine, variable_ops), []>, + Sched<[]>; + +def AArch64_inout_za_use + : SDNode<"AArch64ISD::INOUT_ZA_USE", SDTypeProfile<0, 0,[]>, + [SDNPHasChain, SDNPInGlue]>; +def : Pat<(AArch64_inout_za_use), (InOutZAUsePseudo)>; + +def AArch64_requires_za_save + : SDNode<"AArch64ISD::REQUIRES_ZA_SAVE", SDTypeProfile<0, 0,[]>, + [SDNPHasChain, SDNPInGlue]>; +def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>; + +//===----------------------------------------------------------------------===// // Instruction naming conventions. //===----------------------------------------------------------------------===// @@ -325,16 +356,6 @@ def : Pat<(AArch64_smstart (i32 svcr_op:$pstate)), def : Pat<(AArch64_smstop (i32 svcr_op:$pstate)), (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>; - -// Pseudo to insert cfi_offset/cfi_restore instructions. Used to save or restore -// the streaming value of VG around streaming-mode changes in locally-streaming -// functions. -def VGSavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; -def : Pat<(AArch64VGSave), (VGSavePseudo)>; - -def VGRestorePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; -def : Pat<(AArch64VGRestore), (VGRestorePseudo)>; - //===----------------------------------------------------------------------===// // SME2 Instructions //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 0c4b4f4..bc65af2 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -650,7 +650,7 @@ let Predicates = [HasSVE_or_SME, UseExperimentalZeroingPseudos] in { let Predicates = [HasSVE_or_SME] in { defm ADD_ZI : sve_int_arith_imm0<0b000, "add", add>; - defm SUB_ZI : sve_int_arith_imm0<0b001, "sub", sub>; + defm SUB_ZI : sve_int_arith_imm0<0b001, "sub", sub, add>; defm SUBR_ZI : sve_int_arith_imm0<0b011, "subr", AArch64subr>; defm SQADD_ZI : sve_int_arith_imm0_ssat<0b100, "sqadd", saddsat, ssubsat>; defm UQADD_ZI : sve_int_arith_imm0<0b101, "uqadd", uaddsat>; @@ -1021,7 +1021,9 @@ let Predicates = [HasNonStreamingSVE_or_SME2p2] in { let Predicates = [HasSVE_or_SME] in { defm INSR_ZR : sve_int_perm_insrs<"insr", AArch64insr>; defm INSR_ZV : sve_int_perm_insrv<"insr", AArch64insr>; - defm EXT_ZZI : sve_int_perm_extract_i<"ext", AArch64ext>; + defm EXT_ZZI : sve_int_perm_extract_i<"ext", AArch64ext, "EXT_ZZI_CONSTRUCTIVE">; + + def EXT_ZZI_CONSTRUCTIVE : UnpredRegImmPseudo<ZPR8, imm0_255>; defm RBIT_ZPmZ : sve_int_perm_rev_rbit<"rbit", AArch64rbit_mt>; defm REVB_ZPmZ : sve_int_perm_rev_revb<"revb", AArch64revb_mt>; @@ -2131,21 +2133,37 @@ let Predicates = [HasSVE_or_SME] in { (LASTB_VPZ_D (PTRUE_D 31), ZPR:$Z1), dsub))>; // Splice with lane bigger or equal to 0 - foreach VT = [nxv16i8] in + foreach VT = [nxv16i8] in { def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 
(sve_ext_imm_0_255 i32:$index)))), (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + let AddedComplexity = 1 in + def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_255 i32:$index)))), + (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>; + } - foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in + foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in { def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_127 i32:$index)))), (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + let AddedComplexity = 1 in + def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_127 i32:$index)))), + (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>; + } - foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in + foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in { def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_63 i32:$index)))), (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + let AddedComplexity = 1 in + def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_63 i32:$index)))), + (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>; + } - foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in + foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in { def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_31 i32:$index)))), (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + let AddedComplexity = 1 in + def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_31 i32:$index)))), + (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>; + } defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>; defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>; @@ -4390,7 +4408,7 @@ def : InstAlias<"pfalse\t$Pd", (PFALSE PPRorPNR8:$Pd), 0>; // Non-widening BFloat16 to BFloat16 instructions //===----------------------------------------------------------------------===// -let Predicates = [HasSVEB16B16] in { +let Predicates = [HasSVEB16B16, HasNonStreamingSVE_or_SME2] in { defm BFADD_ZZZ : sve_fp_3op_u_zd_bfloat<0b000, "bfadd", AArch64fadd>; defm BFSUB_ZZZ : sve_fp_3op_u_zd_bfloat<0b001, "bfsub", AArch64fsub>; defm BFMUL_ZZZ : sve_fp_3op_u_zd_bfloat<0b010, "bfmul", AArch64fmul>; @@ -4423,9 +4441,9 @@ defm BFMLS_ZZZI : sve_fp_fma_by_indexed_elem_bfloat<"bfmls", 0b11, AArch64fmlsid defm BFMUL_ZZZI : sve_fp_fmul_by_indexed_elem_bfloat<"bfmul", AArch64fmulidx>; defm BFCLAMP_ZZZ : sve_fp_clamp_bfloat<"bfclamp", AArch64fclamp>; -} // End HasSVEB16B16 +} // End HasSVEB16B16, HasNonStreamingSVE_or_SME2 -let Predicates = [HasSVEB16B16, UseExperimentalZeroingPseudos] in { +let Predicates = [HasSVEB16B16, HasNonStreamingSVE_or_SME2, UseExperimentalZeroingPseudos] in { defm BFADD_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fadd>; defm BFSUB_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fsub>; defm BFMUL_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmul>; @@ -4433,7 +4451,7 @@ defm BFMAXNM_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmaxnm>; defm BFMINNM_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fminnm>; defm BFMIN_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmin>; defm BFMAX_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmax>; -} // HasSVEB16B16, UseExperimentalZeroingPseudos +} // HasSVEB16B16, HasNonStreamingSVE_or_SME2, UseExperimentalZeroingPseudos let Predicates = [HasSVEBFSCALE] in { def BFSCALE_ZPZZ : sve_fp_2op_p_zds_bfscale<0b1001, "bfscale", DestructiveBinary>; diff --git a/llvm/lib/Target/AArch64/AArch64SchedA320.td b/llvm/lib/Target/AArch64/AArch64SchedA320.td index 
89ed1338..5ec95c7 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedA320.td +++ b/llvm/lib/Target/AArch64/AArch64SchedA320.td @@ -847,7 +847,7 @@ def : InstRW<[CortexA320Write<3, CortexA320UnitVALU>], (instregex "^[SU]XTB_ZPmZ "^[SU]XTW_ZPmZ_[D]")>; // Extract -def : InstRW<[CortexA320Write<3, CortexA320UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_B)>; +def : InstRW<[CortexA320Write<3, CortexA320UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; // Extract narrow saturating def : InstRW<[CortexA320Write<4, CortexA320UnitVALU>], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]", diff --git a/llvm/lib/Target/AArch64/AArch64SchedA510.td b/llvm/lib/Target/AArch64/AArch64SchedA510.td index 9456878..356e3fa 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedA510.td +++ b/llvm/lib/Target/AArch64/AArch64SchedA510.td @@ -825,7 +825,7 @@ def : InstRW<[CortexA510Write<3, CortexA510UnitVALU>], (instregex "^[SU]XTB_ZPmZ "^[SU]XTW_ZPmZ_[D]")>; // Extract -def : InstRW<[CortexA510Write<3, CortexA510UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_B)>; +def : InstRW<[CortexA510Write<3, CortexA510UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; // Extract narrow saturating def : InstRW<[CortexA510Write<4, CortexA510UnitVALU>], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]", @@ -1016,7 +1016,7 @@ def : InstRW<[CortexA510MCWrite<16, 13, CortexA510UnitVALU>], (instrs FADDA_VPZ_ def : InstRW<[CortexA510MCWrite<8, 5, CortexA510UnitVALU>], (instrs FADDA_VPZ_D)>; // Floating point compare -def : InstRW<[CortexA510Write<4, CortexA510UnitVALU>], (instregex "^FACG[ET]_PPzZZ_[HSD]", +def : InstRW<[CortexA510MCWrite<4, 2, CortexA510UnitVALU>], (instregex "^FACG[ET]_PPzZZ_[HSD]", "^FCM(EQ|GE|GT|NE)_PPzZ[0Z]_[HSD]", "^FCM(LE|LT)_PPzZ0_[HSD]", "^FCMUO_PPzZZ_[HSD]")>; diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td index 91a7079..e798222 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td @@ -1785,7 +1785,7 @@ def : InstRW<[N2Write_2c_1V1], (instregex "^[SU]XTB_ZPmZ_[HSD]", "^[SU]XTW_ZPmZ_[D]")>; // Extract -def : InstRW<[N2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_B)>; +def : InstRW<[N2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; // Extract narrow saturating def : InstRW<[N2Write_4c_1V1], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]$", diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td index ecfb124..e44d40f 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td @@ -1757,7 +1757,7 @@ def : InstRW<[N3Write_2c_1V], (instregex "^[SU]XTB_ZPmZ_[HSD]", "^[SU]XTW_ZPmZ_[D]")>; // Extract -def : InstRW<[N3Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_B)>; +def : InstRW<[N3Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; // Extract narrow saturating def : InstRW<[N3Write_4c_1V1], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]$", diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td index 3686654..44625a2 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td @@ -1575,7 +1575,7 @@ def : InstRW<[V1Write_2c_1V1], (instregex "^[SU]XTB_ZPmZ_[HSD]", "^[SU]XTW_ZPmZ_[D]")>; // Extract -def : InstRW<[V1Write_2c_1V01], (instrs EXT_ZZI)>; +def : InstRW<[V1Write_2c_1V01], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE)>; // Extract/insert operation, SIMD and FP scalar form def : 
InstRW<[V1Write_3c_1V1], (instregex "^LAST[AB]_VPZ_[BHSD]$", diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td index b2c3da0..6261220 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td @@ -2272,7 +2272,7 @@ def : InstRW<[V2Write_2c_1V13], (instregex "^[SU]XTB_ZPmZ_[HSD]", "^[SU]XTW_ZPmZ_[D]")>; // Extract -def : InstRW<[V2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_B)>; +def : InstRW<[V2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; // Extract narrow saturating def : InstRW<[V2Write_4c_1V13], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]", diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp index 8a5b5ba..d3b1aa6 100644 --- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp @@ -182,37 +182,25 @@ SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall( const AArch64Subtarget &STI = DAG.getMachineFunction().getSubtarget<AArch64Subtarget>(); const AArch64TargetLowering *TLI = STI.getTargetLowering(); - TargetLowering::ArgListEntry DstEntry; - DstEntry.Ty = PointerType::getUnqual(*DAG.getContext()); - DstEntry.Node = Dst; TargetLowering::ArgListTy Args; - Args.push_back(DstEntry); + Args.emplace_back(Dst, PointerType::getUnqual(*DAG.getContext())); RTLIB::Libcall NewLC; switch (LC) { case RTLIB::MEMCPY: { NewLC = RTLIB::SC_MEMCPY; - TargetLowering::ArgListEntry Entry; - Entry.Ty = PointerType::getUnqual(*DAG.getContext()); - Entry.Node = Src; - Args.push_back(Entry); + Args.emplace_back(Src, PointerType::getUnqual(*DAG.getContext())); break; } case RTLIB::MEMMOVE: { NewLC = RTLIB::SC_MEMMOVE; - TargetLowering::ArgListEntry Entry; - Entry.Ty = PointerType::getUnqual(*DAG.getContext()); - Entry.Node = Src; - Args.push_back(Entry); + Args.emplace_back(Src, PointerType::getUnqual(*DAG.getContext())); break; } case RTLIB::MEMSET: { NewLC = RTLIB::SC_MEMSET; - TargetLowering::ArgListEntry Entry; - Entry.Ty = Type::getInt32Ty(*DAG.getContext()); - Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32); - Entry.Node = Src; - Args.push_back(Entry); + Args.emplace_back(DAG.getZExtOrTrunc(Src, DL, MVT::i32), + Type::getInt32Ty(*DAG.getContext())); break; } default: @@ -221,10 +209,7 @@ SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall( EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout()); SDValue Symbol = DAG.getExternalSymbol(TLI->getLibcallName(NewLC), PointerVT); - TargetLowering::ArgListEntry SizeEntry; - SizeEntry.Node = Size; - SizeEntry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext()); - Args.push_back(SizeEntry); + Args.emplace_back(Size, DAG.getDataLayout().getIntPtrType(*DAG.getContext())); TargetLowering::CallLoweringInfo CLI(DAG); PointerType *RetTy = PointerType::getUnqual(*DAG.getContext()); diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp index f136a184..a67bd42 100644 --- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp +++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp @@ -585,8 +585,7 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) { ClMaxLifetimes); if (StandardLifetime) { IntrinsicInst *Start = Info.LifetimeStart[0]; - uint64_t Size = - cast<ConstantInt>(Start->getArgOperand(0))->getZExtValue(); + uint64_t Size = *Info.AI->getAllocationSize(*DL); Size = alignTo(Size, kTagGranuleSize); tagAlloca(AI, Start->getNextNode(), 
TagPCall, Size); diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h index 061ed61..671df35 100644 --- a/llvm/lib/Target/AArch64/AArch64Subtarget.h +++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h @@ -212,6 +212,13 @@ public: return hasSVE() || isStreamingSVEAvailable(); } + /// Returns true if the target has access to either the full range of SVE + /// instructions, or the streaming-compatible subset of SVE instructions + /// available to SME2. + bool isNonStreamingSVEorSME2Available() const { + return isSVEAvailable() || (isSVEorStreamingSVEAvailable() && hasSME2()); + } + unsigned getMinVectorRegisterBitWidth() const { // Don't assume any minimum vector size when PSTATE.SM may not be 0, because // we don't yet support streaming-compatible codegen support that we trust @@ -239,8 +246,8 @@ public: /// Return true if the CPU supports any kind of instruction fusion. bool hasFusion() const { return hasArithmeticBccFusion() || hasArithmeticCbzFusion() || - hasFuseAES() || hasFuseArithmeticLogic() || hasFuseCCSelect() || - hasFuseAdrpAdd() || hasFuseLiterals(); + hasFuseAES() || hasFuseArithmeticLogic() || hasFuseCmpCSel() || + hasFuseCmpCSet() || hasFuseAdrpAdd() || hasFuseLiterals(); } unsigned getEpilogueVectorizationMinVF() const { @@ -451,12 +458,6 @@ public: return "__chkstk"; } - const char* getSecurityCheckCookieName() const { - if (isWindowsArm64EC()) - return "#__security_check_cookie_arm64ec"; - return "__security_check_cookie"; - } - /// Choose a method of checking LR before performing a tail call. AArch64PAuth::AuthCheckMethod getAuthenticatedLRCheckMethod(const MachineFunction &MF) const; diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp index 95eab16..e67bd58 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -224,6 +224,11 @@ static cl::opt<bool> cl::desc("Enable Machine Pipeliner for AArch64"), cl::init(false), cl::Hidden); +static cl::opt<bool> + EnableNewSMEABILowering("aarch64-new-sme-abi", + cl::desc("Enable new lowering for the SME ABI"), + cl::init(false), cl::Hidden); + extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() { // Register the target. @@ -263,6 +268,7 @@ LLVMInitializeAArch64Target() { initializeLDTLSCleanupPass(PR); initializeKCFIPass(PR); initializeSMEABIPass(PR); + initializeMachineSMEABIPass(PR); initializeSMEPeepholeOptPass(PR); initializeSVEIntrinsicOptsPass(PR); initializeAArch64SpeculationHardeningPass(PR); @@ -367,7 +373,8 @@ AArch64TargetMachine::AArch64TargetMachine(const Target &T, const Triple &TT, computeDefaultCPU(TT, CPU), FS, Options, getEffectiveRelocModel(TT, RM), getEffectiveAArch64CodeModel(TT, CM, JIT), OL), - TLOF(createTLOF(getTargetTriple())), isLittle(LittleEndian) { + TLOF(createTLOF(getTargetTriple())), isLittle(LittleEndian), + UseNewSMEABILowering(EnableNewSMEABILowering) { initAsmInfo(); if (TT.isOSBinFormatMachO()) { @@ -668,10 +675,12 @@ void AArch64PassConfig::addIRPasses() { addPass(createInterleavedAccessPass()); } - // Expand any functions marked with SME attributes which require special - // changes for the calling convention or that require the lazy-saving - // mechanism specified in the SME ABI. 
- addPass(createSMEABIPass()); + if (!EnableNewSMEABILowering) { + // Expand any functions marked with SME attributes which require special + // changes for the calling convention or that require the lazy-saving + // mechanism specified in the SME ABI. + addPass(createSMEABIPass()); + } // Add Control Flow Guard checks. if (TM->getTargetTriple().isOSWindows()) { @@ -782,6 +791,9 @@ bool AArch64PassConfig::addGlobalInstructionSelect() { } void AArch64PassConfig::addMachineSSAOptimization() { + if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None) + addPass(createMachineSMEABIPass()); + if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt) addPass(createSMEPeepholeOptPass()); @@ -812,6 +824,9 @@ bool AArch64PassConfig::addILPOpts() { } void AArch64PassConfig::addPreRegAlloc() { + if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering) + addPass(createMachineSMEABIPass()); + // Change dead register definitions to refer to the zero register. if (TM->getOptLevel() != CodeGenOptLevel::None && EnableDeadRegisterElimination) diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.h b/llvm/lib/Target/AArch64/AArch64TargetMachine.h index b9e522d..0dd5d95 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.h +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.h @@ -79,8 +79,12 @@ public: size_t clearLinkerOptimizationHints( const SmallPtrSetImpl<MachineInstr *> &MIs) const override; + /// Returns true if the new SME ABI lowering should be used. + bool useNewSMEABILowering() const { return UseNewSMEABILowering; } + private: bool isLittle; + bool UseNewSMEABILowering; }; // AArch64 little endian target machine. diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index e1adc0b..922da10 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -220,20 +220,17 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode( static cl::opt<bool> EnableScalableAutovecInStreamingMode( "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden); -static bool isSMEABIRoutineCall(const CallInst &CI) { +static bool isSMEABIRoutineCall(const CallInst &CI, + const AArch64TargetLowering &TLI) { const auto *F = CI.getCalledFunction(); - return F && StringSwitch<bool>(F->getName()) - .Case("__arm_sme_state", true) - .Case("__arm_tpidr2_save", true) - .Case("__arm_tpidr2_restore", true) - .Case("__arm_za_disable", true) - .Default(false); + return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine(); } /// Returns true if the function has explicit operations that can only be /// lowered using incompatible instructions for the selected mode. This also /// returns true if the function F may use or modify ZA state. -static bool hasPossibleIncompatibleOps(const Function *F) { +static bool hasPossibleIncompatibleOps(const Function *F, + const AArch64TargetLowering &TLI) { for (const BasicBlock &BB : *F) { for (const Instruction &I : BB) { // Be conservative for now and assume that any call to inline asm or to @@ -242,7 +239,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) { // all native LLVM instructions can be lowered to compatible instructions. 
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() && (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) || - isSMEABIRoutineCall(cast<CallInst>(I)))) + isSMEABIRoutineCall(cast<CallInst>(I), TLI))) return true; } } @@ -290,7 +287,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) { - if (hasPossibleIncompatibleOps(Callee)) + if (hasPossibleIncompatibleOps(Callee, *getTLI())) return false; } @@ -357,7 +354,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, // change only once and avoid inlining of G into F. SMEAttrs FAttrs(*F); - SMECallAttrs CallAttrs(Call); + SMECallAttrs CallAttrs(Call, getTLI()); if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) { if (F == Call.getCaller()) // (1) @@ -554,7 +551,17 @@ static bool isUnpackedVectorVT(EVT VecVT) { VecVT.getSizeInBits().getKnownMinValue() < AArch64::SVEBitsPerBlock; } -static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) { +static InstructionCost getHistogramCost(const AArch64Subtarget *ST, + const IntrinsicCostAttributes &ICA) { + // We need to know at least the number of elements in the vector of buckets + // and the size of each element to update. + if (ICA.getArgTypes().size() < 2) + return InstructionCost::getInvalid(); + + // Only interested in costing for the hardware instruction from SVE2. + if (!ST->hasSVE2()) + return InstructionCost::getInvalid(); + Type *BucketPtrsTy = ICA.getArgTypes()[0]; // Type of vector of pointers Type *EltTy = ICA.getArgTypes()[1]; // Type of bucket elements unsigned TotalHistCnts = 1; @@ -579,9 +586,11 @@ static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) { unsigned NaturalVectorWidth = AArch64::SVEBitsPerBlock / LegalEltSize; TotalHistCnts = EC / NaturalVectorWidth; + + return InstructionCost(BaseHistCntCost * TotalHistCnts); } - return InstructionCost(BaseHistCntCost * TotalHistCnts); + return InstructionCost::getInvalid(); } InstructionCost @@ -597,10 +606,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, return InstructionCost::getInvalid(); switch (ICA.getID()) { - case Intrinsic::experimental_vector_histogram_add: - if (!ST->hasSVE2()) - return InstructionCost::getInvalid(); - return getHistogramCost(ICA); + case Intrinsic::experimental_vector_histogram_add: { + InstructionCost HistCost = getHistogramCost(ST, ICA); + // If the cost isn't valid, we may still be able to scalarize + if (HistCost.isValid()) + return HistCost; + break; + } case Intrinsic::umin: case Intrinsic::umax: case Intrinsic::smin: @@ -631,6 +643,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4; if (any_of(ValidSatTys, [<](MVT M) { return M == LT.second; })) return LT.first * Instrs; + + TypeSize TS = getDataLayout().getTypeSizeInBits(RetTy); + uint64_t VectorSize = TS.getKnownMinValue(); + + if (ST->isSVEAvailable() && VectorSize >= 128 && isPowerOf2_64(VectorSize)) + return LT.first * Instrs; + break; } case Intrinsic::abs: { @@ -651,6 +670,16 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, return LT.first; break; } + case Intrinsic::fma: + case Intrinsic::fmuladd: { + // Given a fma or fmuladd, cost it the same as a fmul instruction which are + // usually the same for costs. 
TODO: Add fp16 and bf16 expansion costs. + Type *EltTy = RetTy->getScalarType(); + if (EltTy->isFloatTy() || EltTy->isDoubleTy() || + (EltTy->isHalfTy() && ST->hasFullFP16())) + return getArithmeticInstrCost(Instruction::FMul, RetTy, CostKind); + break; + } case Intrinsic::stepvector: { InstructionCost Cost = 1; // Cost of the `index' instruction auto LT = getTypeLegalizationCost(RetTy); @@ -2072,6 +2101,20 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) { : std::nullopt; } +static std::optional<Instruction *> +instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts, + const AArch64Subtarget *ST) { + if (!ST->isStreaming()) + return std::nullopt; + + // In streaming-mode, aarch64_sme_cnts is equivalent to aarch64_sve_cnt + // with SVEPredPattern::all + Value *Cnt = IC.Builder.CreateElementCount( + II.getType(), ElementCount::getScalable(NumElts)); + Cnt->takeName(&II); + return IC.replaceInstUsesWith(II, Cnt); +} + static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC, IntrinsicInst &II) { Value *PgVal = II.getArgOperand(0); @@ -2781,6 +2824,14 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC, return instCombineSVECntElts(IC, II, 8); case Intrinsic::aarch64_sve_cntb: return instCombineSVECntElts(IC, II, 16); + case Intrinsic::aarch64_sme_cntsd: + return instCombineSMECntsElts(IC, II, 2, ST); + case Intrinsic::aarch64_sme_cntsw: + return instCombineSMECntsElts(IC, II, 4, ST); + case Intrinsic::aarch64_sme_cntsh: + return instCombineSMECntsElts(IC, II, 8, ST); + case Intrinsic::aarch64_sme_cntsb: + return instCombineSMECntsElts(IC, II, 16, ST); case Intrinsic::aarch64_sve_ptest_any: case Intrinsic::aarch64_sve_ptest_first: case Intrinsic::aarch64_sve_ptest_last: @@ -3092,6 +3143,13 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, return AdjustCost( BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); + // For the moment we do not have lowering for SVE1-only fptrunc f64->bf16 as + // we use fcvtx under SVE2. Give them invalid costs. + if (!ST->hasSVE2() && !ST->isStreamingSVEAvailable() && + ISD == ISD::FP_ROUND && SrcTy.isScalableVector() && + DstTy.getScalarType() == MVT::bf16 && SrcTy.getScalarType() == MVT::f64) + return InstructionCost::getInvalid(); + static const TypeConversionCostTblEntry BF16Tbl[] = { {ISD::FP_ROUND, MVT::bf16, MVT::f32, 1}, // bfcvt {ISD::FP_ROUND, MVT::bf16, MVT::f64, 1}, // bfcvt @@ -3100,6 +3158,12 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2}, // bfcvtn+fcvtn {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3}, // fcvtn+fcvtl2+bfcvtn {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+bfcvtn + {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f32, 1}, // bfcvt + {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f32, 1}, // bfcvt + {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f32, 3}, // bfcvt+bfcvt+uzp1 + {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f64, 2}, // fcvtx+bfcvt + {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f64, 5}, // 2*fcvtx+2*bfcvt+uzp1 + {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f64, 11}, // 4*fcvt+4*bfcvt+3*uzp }; if (ST->hasBF16()) @@ -3508,11 +3572,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1}, {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3}, + // Truncate from nxvmf32 to nxvmbf16. 
+ {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f32, 8}, + {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f32, 8}, + {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f32, 17}, + // Truncate from nxvmf64 to nxvmf16. {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1}, {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3}, {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7}, + // Truncate from nxvmf64 to nxvmbf16. + {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f64, 9}, + {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f64, 19}, + {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f64, 39}, + // Truncate from nxvmf64 to nxvmf32. {ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1}, {ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3}, @@ -3523,11 +3597,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, {ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1}, {ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2}, + // Extend from nxvmbf16 to nxvmf32. + {ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2bf16, 1}, // lsl + {ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4bf16, 1}, // lsl + {ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8bf16, 4}, // unpck+unpck+lsl+lsl + // Extend from nxvmf16 to nxvmf64. {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1}, {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2}, {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4}, + // Extend from nxvmbf16 to nxvmf64. + {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2bf16, 2}, // lsl+fcvt + {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4bf16, 6}, // 2*unpck+2*lsl+2*fcvt + {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8bf16, 14}, // 6*unpck+4*lsl+4*fcvt + // Extend from nxvmf32 to nxvmf64. {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1}, {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2}, @@ -3928,6 +4012,24 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I, return getVectorInstrCostHelper(I.getOpcode(), Val, CostKind, Index, &I); } +InstructionCost +AArch64TTIImpl::getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val, + TTI::TargetCostKind CostKind, + unsigned Index) const { + if (isa<FixedVectorType>(Val)) + return BaseT::getIndexedVectorInstrCostFromEnd(Opcode, Val, CostKind, + Index); + + // This typically requires both while and lastb instructions in order + // to extract the last element. If this is in a loop the while + // instruction can at least be hoisted out, although it will consume a + // predicate register. The cost should be more expensive than the base + // extract cost, which is 2 for most CPUs. + return CostKind == TTI::TCK_CodeSize + ? 
2 + : ST->getVectorInsertExtractBaseCost() + 1; +} + InstructionCost AArch64TTIImpl::getScalarizationOverhead( VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, TTI::TargetCostKind CostKind, bool ForPoisonSrc, @@ -3942,6 +4044,27 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead( return DemandedElts.popcount() * (Insert + Extract) * VecInstCost; } +std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost( + Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info, + TTI::OperandValueInfo Op2Info, bool IncludeTrunc, + std::function<InstructionCost(Type *)> InstCost) const { + if (!Ty->getScalarType()->isHalfTy() && !Ty->getScalarType()->isBFloatTy()) + return std::nullopt; + if (Ty->getScalarType()->isHalfTy() && ST->hasFullFP16()) + return std::nullopt; + + Type *PromotedTy = Ty->getWithNewType(Type::getFloatTy(Ty->getContext())); + InstructionCost Cost = getCastInstrCost(Instruction::FPExt, PromotedTy, Ty, + TTI::CastContextHint::None, CostKind); + if (!Op1Info.isConstant() && !Op2Info.isConstant()) + Cost *= 2; + Cost += InstCost(PromotedTy); + if (IncludeTrunc) + Cost += getCastInstrCost(Instruction::FPTrunc, Ty, PromotedTy, + TTI::CastContextHint::None, CostKind); + return Cost; +} + InstructionCost AArch64TTIImpl::getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info, @@ -3964,6 +4087,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost( std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty); int ISD = TLI->InstructionOpcodeToISD(Opcode); + // Increase the cost for half and bfloat types if not architecturally + // supported. + if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL || + ISD == ISD::FDIV || ISD == ISD::FREM) + if (auto PromotedCost = getFP16BF16PromoteCost( + Ty, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/true, + [&](Type *PromotedTy) { + return getArithmeticInstrCost(Opcode, PromotedTy, CostKind, + Op1Info, Op2Info); + })) + return *PromotedCost; + switch (ISD) { default: return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, @@ -4232,11 +4367,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost( [[fallthrough]]; case ISD::FADD: case ISD::FSUB: - // Increase the cost for half and bfloat types if not architecturally - // supported. - if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) || - (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16())) - return 2 * LT.first; if (!Ty->getScalarType()->isFP128Ty()) return LT.first; [[fallthrough]]; @@ -4260,8 +4390,9 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost( } InstructionCost -AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE, - const SCEV *Ptr) const { +AArch64TTIImpl::getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE, + const SCEV *Ptr, + TTI::TargetCostKind CostKind) const { // Address computations in vectorized code with non-consecutive addresses will // likely result in more instructions compared to scalar code where the // computation can more often be merged into the index mode. 
The resulting @@ -4269,7 +4400,7 @@ AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE, unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead; int MaxMergeDistance = 64; - if (Ty->isVectorTy() && SE && + if (PtrTy->isVectorTy() && SE && !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1)) return NumVectorInstToHideOverhead; @@ -4282,10 +4413,9 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost( unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info, const Instruction *I) const { - int ISD = TLI->InstructionOpcodeToISD(Opcode); // We don't lower some vector selects well that are wider than the register // width. TODO: Improve this with different cost kinds. - if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) { + if (isa<FixedVectorType>(ValTy) && Opcode == Instruction::Select) { // We would need this many instructions to hide the scalarization happening. const int AmortizationCost = 20; @@ -4315,55 +4445,68 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost( return LT.first; } - static const TypeConversionCostTblEntry - VectorSelectTbl[] = { - { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 }, - { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 }, - { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 }, - { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 }, - { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 }, - { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 }, - { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 }, - { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 }, - { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost }, - { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost }, - { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost } - }; + static const TypeConversionCostTblEntry VectorSelectTbl[] = { + {Instruction::Select, MVT::v2i1, MVT::v2f32, 2}, + {Instruction::Select, MVT::v2i1, MVT::v2f64, 2}, + {Instruction::Select, MVT::v4i1, MVT::v4f32, 2}, + {Instruction::Select, MVT::v4i1, MVT::v4f16, 2}, + {Instruction::Select, MVT::v8i1, MVT::v8f16, 2}, + {Instruction::Select, MVT::v16i1, MVT::v16i16, 16}, + {Instruction::Select, MVT::v8i1, MVT::v8i32, 8}, + {Instruction::Select, MVT::v16i1, MVT::v16i32, 16}, + {Instruction::Select, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost}, + {Instruction::Select, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost}, + {Instruction::Select, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost}}; EVT SelCondTy = TLI->getValueType(DL, CondTy); EVT SelValTy = TLI->getValueType(DL, ValTy); if (SelCondTy.isSimple() && SelValTy.isSimple()) { - if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD, + if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, Opcode, SelCondTy.getSimpleVT(), SelValTy.getSimpleVT())) return Entry->Cost; } } - if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) { - Type *ValScalarTy = ValTy->getScalarType(); - if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) || - ValScalarTy->isBFloatTy()) { - auto *ValVTy = cast<FixedVectorType>(ValTy); - - // Without dedicated instructions we promote [b]f16 compares to f32. - auto *PromotedTy = - VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy); - - InstructionCost Cost = 0; - // Promote operands to float vectors. - Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy, - TTI::CastContextHint::None, CostKind); - // Compare float vectors. 
- Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind, - Op1Info, Op2Info); - // During codegen we'll truncate the vector result from i32 to i16. - Cost += - getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy), - VectorType::getInteger(PromotedTy), - TTI::CastContextHint::None, CostKind); - return Cost; - } + if (Opcode == Instruction::FCmp) { + if (auto PromotedCost = getFP16BF16PromoteCost( + ValTy, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/false, + [&](Type *PromotedTy) { + InstructionCost Cost = + getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, + CostKind, Op1Info, Op2Info); + if (isa<VectorType>(PromotedTy)) + Cost += getCastInstrCost( + Instruction::Trunc, + VectorType::getInteger(cast<VectorType>(ValTy)), + VectorType::getInteger(cast<VectorType>(PromotedTy)), + TTI::CastContextHint::None, CostKind); + return Cost; + })) + return *PromotedCost; + + auto LT = getTypeLegalizationCost(ValTy); + // Model unknown fp compares as a libcall. + if (LT.second.getScalarType() != MVT::f64 && + LT.second.getScalarType() != MVT::f32 && + LT.second.getScalarType() != MVT::f16) + return LT.first * getCallInstrCost(/*Function*/ nullptr, ValTy, + {ValTy, ValTy}, CostKind); + + // Some comparison operators require expanding to multiple compares + or. + unsigned Factor = 1; + if (!CondTy->isVectorTy() && + (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ)) + Factor = 2; // fcmp with 2 selects + else if (isa<FixedVectorType>(ValTy) && + (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ || + VecPred == FCmpInst::FCMP_ORD || VecPred == FCmpInst::FCMP_UNO)) + Factor = 3; // fcmxx+fcmyy+or + else if (isa<ScalableVectorType>(ValTy) && + (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ)) + Factor = 3; // fcmxx+fcmyy+or + + return Factor * (CostKind == TTI::TCK_Latency ? 2 : LT.first); } // Treat the icmp in icmp(and, 0) or icmp(and, -1/1) when it can be folded to @@ -4371,7 +4514,7 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost( // comparison is not unsigned. FIXME: Enable for non-throughput cost kinds // providing it will not cause performance regressions. if (CostKind == TTI::TCK_RecipThroughput && ValTy->isIntegerTy() && - ISD == ISD::SETCC && I && !CmpInst::isUnsigned(VecPred) && + Opcode == Instruction::ICmp && I && !CmpInst::isUnsigned(VecPred) && TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) && match(I->getOperand(0), m_And(m_Value(), m_Value()))) { if (match(I->getOperand(1), m_Zero())) @@ -4809,32 +4952,18 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE, // Limit to loops with trip counts that are cheap to expand. UP.SCEVExpansionBudget = 1; - // Try to unroll small, single block loops, if they have load/store - // dependencies, to expose more parallel memory access streams. + // Try to unroll small loops, of few-blocks with low budget, if they have + // load/store dependencies, to expose more parallel memory access streams, + // or if they do little work inside a block (i.e. load -> X -> store pattern). BasicBlock *Header = L->getHeader(); - if (Header == L->getLoopLatch()) { + BasicBlock *Latch = L->getLoopLatch(); + if (Header == Latch) { // Estimate the size of the loop. 
unsigned Size; - if (!isLoopSizeWithinBudget(L, TTI, 8, &Size)) + unsigned Width = 10; + if (!isLoopSizeWithinBudget(L, TTI, Width, &Size)) return; - SmallPtrSet<Value *, 8> LoadedValues; - SmallVector<StoreInst *> Stores; - for (auto *BB : L->blocks()) { - for (auto &I : *BB) { - Value *Ptr = getLoadStorePointerOperand(&I); - if (!Ptr) - continue; - const SCEV *PtrSCEV = SE.getSCEV(Ptr); - if (SE.isLoopInvariant(PtrSCEV, L)) - continue; - if (isa<LoadInst>(&I)) - LoadedValues.insert(&I); - else - Stores.push_back(cast<StoreInst>(&I)); - } - } - // Try to find an unroll count that maximizes the use of the instruction // window, i.e. trying to fetch as many instructions per cycle as possible. unsigned MaxInstsPerLine = 16; @@ -4853,8 +4982,32 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE, UC++; } - if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) { - return LoadedValues.contains(SI->getOperand(0)); + if (BestUC == 1) + return; + + SmallPtrSet<Value *, 8> LoadedValuesPlus; + SmallVector<StoreInst *> Stores; + for (auto *BB : L->blocks()) { + for (auto &I : *BB) { + Value *Ptr = getLoadStorePointerOperand(&I); + if (!Ptr) + continue; + const SCEV *PtrSCEV = SE.getSCEV(Ptr); + if (SE.isLoopInvariant(PtrSCEV, L)) + continue; + if (isa<LoadInst>(&I)) { + LoadedValuesPlus.insert(&I); + // Include in-loop 1st users of loaded values. + for (auto *U : I.users()) + if (L->contains(cast<Instruction>(U))) + LoadedValuesPlus.insert(U); + } else + Stores.push_back(cast<StoreInst>(&I)); + } + } + + if (none_of(Stores, [&LoadedValuesPlus](StoreInst *SI) { + return LoadedValuesPlus.contains(SI->getOperand(0)); })) return; @@ -4866,7 +5019,6 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE, // Try to runtime-unroll loops with early-continues depending on loop-varying // loads; this helps with branch-prediction for the early-continues. 
auto *Term = dyn_cast<BranchInst>(Header->getTerminator()); - auto *Latch = L->getLoopLatch(); SmallVector<BasicBlock *> Preds(predecessors(Latch)); if (!Term || !Term->isConditional() || Preds.size() == 1 || !llvm::is_contained(Preds, Header) || @@ -5102,6 +5254,8 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction( return false; switch (RdxDesc.getRecurrenceKind()) { + case RecurKind::Sub: + case RecurKind::AddChainWithSubs: case RecurKind::Add: case RecurKind::FAdd: case RecurKind::And: @@ -5332,13 +5486,14 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost( } InstructionCost -AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy, - VectorType *VecTy, +AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, + Type *ResTy, VectorType *VecTy, TTI::TargetCostKind CostKind) const { EVT VecVT = TLI->getValueType(DL, VecTy); EVT ResVT = TLI->getValueType(DL, ResTy); - if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) { + if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple() && + RedOpcode == Instruction::Add) { std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy); // The legal cases with dotprod are @@ -5349,7 +5504,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy, return LT.first + 2; } - return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind); + return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, VecTy, + CostKind); } InstructionCost @@ -6235,10 +6391,17 @@ bool AArch64TTIImpl::isProfitableToSinkOperands( } } - auto ShouldSinkCondition = [](Value *Cond) -> bool { + auto ShouldSinkCondition = [](Value *Cond, + SmallVectorImpl<Use *> &Ops) -> bool { + if (!isa<IntrinsicInst>(Cond)) + return false; auto *II = dyn_cast<IntrinsicInst>(Cond); - return II && II->getIntrinsicID() == Intrinsic::vector_reduce_or && - isa<ScalableVectorType>(II->getOperand(0)->getType()); + if (II->getIntrinsicID() != Intrinsic::vector_reduce_or || + !isa<ScalableVectorType>(II->getOperand(0)->getType())) + return false; + if (isa<CmpInst>(II->getOperand(0))) + Ops.push_back(&II->getOperandUse(0)); + return true; }; switch (I->getOpcode()) { @@ -6254,7 +6417,7 @@ bool AArch64TTIImpl::isProfitableToSinkOperands( } break; case Instruction::Select: { - if (!ShouldSinkCondition(I->getOperand(0))) + if (!ShouldSinkCondition(I->getOperand(0), Ops)) return false; Ops.push_back(&I->getOperandUse(0)); @@ -6264,7 +6427,7 @@ bool AArch64TTIImpl::isProfitableToSinkOperands( if (cast<BranchInst>(I)->isUnconditional()) return false; - if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition())) + if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition(), Ops)) return false; Ops.push_back(&I->getOperandUse(0)); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 7f45177..b994ca7 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -221,6 +221,11 @@ public: unsigned Index) const override; InstructionCost + getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val, + TTI::TargetCostKind CostKind, + unsigned Index) const override; + + InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF, TTI::TargetCostKind CostKind) const override; @@ -238,8 +243,9 @@ public: ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr) const override; - InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution 
*SE, - const SCEV *Ptr) const override; + InstructionCost + getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE, const SCEV *Ptr, + TTI::TargetCostKind CostKind) const override; InstructionCost getCmpSelInstrCost( unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred, @@ -435,6 +441,14 @@ public: bool preferPredicatedReductionSelect() const override { return ST->hasSVE(); } + /// FP16 and BF16 operations are lowered to fptrunc(op(fpext, fpext) if the + /// architecture features are not present. + std::optional<InstructionCost> + getFP16BF16PromoteCost(Type *Ty, TTI::TargetCostKind CostKind, + TTI::OperandValueInfo Op1Info, + TTI::OperandValueInfo Op2Info, bool IncludeTrunc, + std::function<InstructionCost(Type *)> InstCost) const; + InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional<FastMathFlags> FMF, @@ -446,7 +460,7 @@ public: TTI::TargetCostKind CostKind) const override; InstructionCost getMulAccReductionCost( - bool IsUnsigned, Type *ResTy, VectorType *Ty, + bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override; InstructionCost diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp index 1ca61f5..3641e22 100644 --- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp +++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp @@ -7909,9 +7909,11 @@ bool AArch64AsmParser::parseDirectiveSEHSavePReg(SMLoc L) { } bool AArch64AsmParser::parseDirectiveAeabiSubSectionHeader(SMLoc L) { - // Expecting 3 AsmToken::Identifier after '.aeabi_subsection', a name and 2 - // parameters, e.g.: .aeabi_subsection (1)aeabi_feature_and_bits, (2)optional, - // (3)uleb128 separated by 2 commas. + // Handle parsing of .aeabi_subsection directives + // - On first declaration of a subsection, expect exactly three identifiers + // after `.aeabi_subsection`: the subsection name and two parameters. + // - When switching to an existing subsection, it is valid to provide only + // the subsection name, or the name together with the two parameters. MCAsmParser &Parser = getParser(); // Consume the name (subsection name) @@ -7925,16 +7927,38 @@ bool AArch64AsmParser::parseDirectiveAeabiSubSectionHeader(SMLoc L) { return true; } Parser.Lex(); - // consume a comma + + std::unique_ptr<MCELFStreamer::AttributeSubSection> SubsectionExists = + getTargetStreamer().getAttributesSubsectionByName(SubsectionName); + // Check whether only the subsection name was provided. + // If so, the user is trying to switch to a subsection that should have been + // declared before. + if (Parser.getTok().is(llvm::AsmToken::EndOfStatement)) { + if (SubsectionExists) { + getTargetStreamer().emitAttributesSubsection( + SubsectionName, + static_cast<AArch64BuildAttributes::SubsectionOptional>( + SubsectionExists->IsOptional), + static_cast<AArch64BuildAttributes::SubsectionType>( + SubsectionExists->ParameterType)); + return false; + } + // If subsection does not exists, report error. + else { + Error(Parser.getTok().getLoc(), + "Could not switch to subsection '" + SubsectionName + + "' using subsection name, subsection has not been defined"); + return true; + } + } + + // Otherwise, expecting 2 more parameters: consume a comma // parseComma() return *false* on success, and call Lex(), no need to call // Lex() again. 
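On the getFP16BF16PromoteCost hook declared in the header hunk above: when the target lacks native fp16/bf16 arithmetic, a half-precision operation is performed at single precision and rounded back, and that ext/op/trunc shape is what the helper has to price. A hedged C-level illustration, assuming a compiler that accepts _Float16 (the example is not part of the patch):

// Without native half-precision arithmetic this becomes fcvt (fpext) of both
// operands, an f32 fadd, and a final fcvt (fptrunc). For fcmp the trailing
// fptrunc is unnecessary (IncludeTrunc=false), since the result is a
// predicate rather than a floating-point value.
_Float16 add_f16(_Float16 a, _Float16 b) {
  return (_Float16)((float)a + (float)b);
}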
if (Parser.parseComma()) { return true; } - std::unique_ptr<MCELFStreamer::AttributeSubSection> SubsectionExists = - getTargetStreamer().getAttributesSubsectionByName(SubsectionName); - // Consume the first parameter (optionality parameter) AArch64BuildAttributes::SubsectionOptional IsOptional; // options: optional/required diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt index 66136a4..803943f 100644 --- a/llvm/lib/Target/AArch64/CMakeLists.txt +++ b/llvm/lib/Target/AArch64/CMakeLists.txt @@ -89,6 +89,7 @@ add_llvm_target(AArch64CodeGen SMEABIPass.cpp SMEPeepholeOpt.cpp SVEIntrinsicOpts.cpp + MachineSMEABIPass.cpp AArch64SIMDInstrOpt.cpp DEPENDS diff --git a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp index ae984be..23e46b8 100644 --- a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp +++ b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp @@ -15,6 +15,7 @@ #include "MCTargetDesc/AArch64MCTargetDesc.h" #include "TargetInfo/AArch64TargetInfo.h" #include "Utils/AArch64BaseInfo.h" +#include "llvm/MC/MCDecoder.h" #include "llvm/MC/MCDecoderOps.h" #include "llvm/MC/MCDisassembler/MCRelocationInfo.h" #include "llvm/MC/MCInst.h" @@ -27,314 +28,21 @@ #include <memory> using namespace llvm; +using namespace llvm::MCD; #define DEBUG_TYPE "aarch64-disassembler" // Pull DecodeStatus and its enum values into the global namespace. using DecodeStatus = MCDisassembler::DecodeStatus; -// Forward declare these because the autogenerated code will reference them. -// Definitions are further down. -template <unsigned RegClassID, unsigned FirstReg, unsigned NumRegsInClass> -static DecodeStatus DecodeSimpleRegisterClass(MCInst &Inst, unsigned RegNo, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeGPR64x8ClassRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address, - const MCDisassembler *Decoder); -template <unsigned Min, unsigned Max> -static DecodeStatus DecodeZPRMul2_MinMax(MCInst &Inst, unsigned RegNo, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeZK(MCInst &Inst, unsigned RegNo, uint64_t Address, - const MCDisassembler *Decoder); -template <unsigned Min, unsigned Max> -static DecodeStatus DecodeZPR2Mul2RegisterClass(MCInst &Inst, unsigned RegNo, - uint64_t Address, - const void *Decoder); -static DecodeStatus DecodeZPR4Mul4RegisterClass(MCInst &Inst, unsigned RegNo, - uint64_t Address, - const void *Decoder); -template <unsigned NumBitsForTile> -static DecodeStatus DecodeMatrixTile(MCInst &Inst, unsigned RegNo, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeMatrixTileListRegisterClass(MCInst &Inst, unsigned RegMask, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodePPR2Mul2RegisterClass(MCInst &Inst, unsigned RegNo, - uint64_t Address, - const void *Decoder); - -static DecodeStatus DecodeFixedPointScaleImm32(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeFixedPointScaleImm64(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodePCRelLabel16(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodePCRelLabel19(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodePCRelLabel9(MCInst &Inst, unsigned Imm, - 
uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeMemExtend(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeMRSSystemRegister(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeMSRSystemRegister(MCInst &Inst, unsigned Imm, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeThreeAddrSRegInstruction(MCInst &Inst, uint32_t insn, uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeMoveImmInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeUnsignedLdStInstruction(MCInst &Inst, uint32_t insn, uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeSignedLdStInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeExclusiveLdStInstruction(MCInst &Inst, uint32_t insn, uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodePairLdStInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeAuthLoadInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeAddSubERegInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeLogicalImmInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeModImmInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeModImmTiedInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeAdrInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeAddSubImmShift(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeUnconditionalBranch(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeSystemPStateImm0_15Instruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeSystemPStateImm0_1Instruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeTestAndBranch(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); - -static DecodeStatus DecodeFMOVLaneInstruction(MCInst &Inst, unsigned Insn, - uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftR64Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftR64ImmNarrow(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftR32Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftR32ImmNarrow(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftR16Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftR16ImmNarrow(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus 
DecodeVecShiftR8Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftL64Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftL32Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftL16Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeVecShiftL8Imm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeWSeqPairsClassRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeXSeqPairsClassRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeSyspXzrInstruction(MCInst &Inst, uint32_t insn, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus -DecodeSVELogicalImmInstruction(MCInst &Inst, uint32_t insn, uint64_t Address, - const MCDisassembler *Decoder); template <int Bits> static DecodeStatus DecodeSImm(MCInst &Inst, uint64_t Imm, uint64_t Address, const MCDisassembler *Decoder); -template <int ElementWidth> -static DecodeStatus DecodeImm8OptLsl(MCInst &Inst, unsigned Imm, uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeSVEIncDecImm(MCInst &Inst, unsigned Imm, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeSVCROp(MCInst &Inst, unsigned Imm, uint64_t Address, - const MCDisassembler *Decoder); -static DecodeStatus DecodeCPYMemOpInstruction(MCInst &Inst, uint32_t insn, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodeSETMemOpInstruction(MCInst &Inst, uint32_t insn, - uint64_t Addr, - const MCDisassembler *Decoder); -static DecodeStatus DecodePRFMRegInstruction(MCInst &Inst, uint32_t insn, - uint64_t Address, - const MCDisassembler *Decoder); - -#include "AArch64GenDisassemblerTables.inc" -#include "AArch64GenInstrInfo.inc" #define Success MCDisassembler::Success #define Fail MCDisassembler::Fail #define SoftFail MCDisassembler::SoftFail -static MCDisassembler *createAArch64Disassembler(const Target &T, - const MCSubtargetInfo &STI, - MCContext &Ctx) { - - return new AArch64Disassembler(STI, Ctx, T.createMCInstrInfo()); -} - -DecodeStatus AArch64Disassembler::getInstruction(MCInst &MI, uint64_t &Size, - ArrayRef<uint8_t> Bytes, - uint64_t Address, - raw_ostream &CS) const { - CommentStream = &CS; - - Size = 0; - // We want to read exactly 4 bytes of data. - if (Bytes.size() < 4) - return Fail; - Size = 4; - - // Encoded as a small-endian 32-bit word in the stream. - uint32_t Insn = - (Bytes[3] << 24) | (Bytes[2] << 16) | (Bytes[1] << 8) | (Bytes[0] << 0); - - const uint8_t *Tables[] = {DecoderTable32, DecoderTableFallback32}; - - for (const auto *Table : Tables) { - DecodeStatus Result = - decodeInstruction(Table, MI, Insn, Address, this, STI); - - const MCInstrDesc &Desc = MCII->get(MI.getOpcode()); - - // For Scalable Matrix Extension (SME) instructions that have an implicit - // operand for the accumulator (ZA) or implicit immediate zero which isn't - // encoded, manually insert operand. 
- for (unsigned i = 0; i < Desc.getNumOperands(); i++) { - if (Desc.operands()[i].OperandType == MCOI::OPERAND_REGISTER) { - switch (Desc.operands()[i].RegClass) { - default: - break; - case AArch64::MPRRegClassID: - MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZA)); - break; - case AArch64::MPR8RegClassID: - MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZAB0)); - break; - case AArch64::ZTRRegClassID: - MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZT0)); - break; - } - } else if (Desc.operands()[i].OperandType == - AArch64::OPERAND_IMPLICIT_IMM_0) { - MI.insert(MI.begin() + i, MCOperand::createImm(0)); - } - } - - if (MI.getOpcode() == AArch64::LDR_ZA || - MI.getOpcode() == AArch64::STR_ZA) { - // Spill and fill instructions have a single immediate used for both - // the vector select offset and optional memory offset. Replicate - // the decoded immediate. - const MCOperand &Imm4Op = MI.getOperand(2); - assert(Imm4Op.isImm() && "Unexpected operand type!"); - MI.addOperand(Imm4Op); - } - - if (Result != MCDisassembler::Fail) - return Result; - } - - return MCDisassembler::Fail; -} - -uint64_t AArch64Disassembler::suggestBytesToSkip(ArrayRef<uint8_t> Bytes, - uint64_t Address) const { - // AArch64 instructions are always 4 bytes wide, so there's no point - // in skipping any smaller number of bytes if an instruction can't - // be decoded. - return 4; -} - -static MCSymbolizer * -createAArch64ExternalSymbolizer(const Triple &TT, LLVMOpInfoCallback GetOpInfo, - LLVMSymbolLookupCallback SymbolLookUp, - void *DisInfo, MCContext *Ctx, - std::unique_ptr<MCRelocationInfo> &&RelInfo) { - return new AArch64ExternalSymbolizer(*Ctx, std::move(RelInfo), GetOpInfo, - SymbolLookUp, DisInfo); -} - -extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void -LLVMInitializeAArch64Disassembler() { - TargetRegistry::RegisterMCDisassembler(getTheAArch64leTarget(), - createAArch64Disassembler); - TargetRegistry::RegisterMCDisassembler(getTheAArch64beTarget(), - createAArch64Disassembler); - TargetRegistry::RegisterMCSymbolizer(getTheAArch64leTarget(), - createAArch64ExternalSymbolizer); - TargetRegistry::RegisterMCSymbolizer(getTheAArch64beTarget(), - createAArch64ExternalSymbolizer); - TargetRegistry::RegisterMCDisassembler(getTheAArch64_32Target(), - createAArch64Disassembler); - TargetRegistry::RegisterMCSymbolizer(getTheAArch64_32Target(), - createAArch64ExternalSymbolizer); - - TargetRegistry::RegisterMCDisassembler(getTheARM64Target(), - createAArch64Disassembler); - TargetRegistry::RegisterMCSymbolizer(getTheARM64Target(), - createAArch64ExternalSymbolizer); - TargetRegistry::RegisterMCDisassembler(getTheARM64_32Target(), - createAArch64Disassembler); - TargetRegistry::RegisterMCSymbolizer(getTheARM64_32Target(), - createAArch64ExternalSymbolizer); -} - template <unsigned RegClassID, unsigned FirstReg, unsigned NumRegsInClass> static DecodeStatus DecodeSimpleRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address, @@ -1854,3 +1562,120 @@ static DecodeStatus DecodePRFMRegInstruction(MCInst &Inst, uint32_t insn, return Success; } + +#include "AArch64GenDisassemblerTables.inc" +#include "AArch64GenInstrInfo.inc" + +static MCDisassembler *createAArch64Disassembler(const Target &T, + const MCSubtargetInfo &STI, + MCContext &Ctx) { + + return new AArch64Disassembler(STI, Ctx, T.createMCInstrInfo()); +} + +DecodeStatus AArch64Disassembler::getInstruction(MCInst &MI, uint64_t &Size, + ArrayRef<uint8_t> Bytes, + uint64_t Address, + raw_ostream &CS) const { + CommentStream = &CS; + + 
Size = 0; + // We want to read exactly 4 bytes of data. + if (Bytes.size() < 4) + return Fail; + Size = 4; + + // Encoded as a small-endian 32-bit word in the stream. + uint32_t Insn = + (Bytes[3] << 24) | (Bytes[2] << 16) | (Bytes[1] << 8) | (Bytes[0] << 0); + + const uint8_t *Tables[] = {DecoderTable32, DecoderTableFallback32}; + + for (const auto *Table : Tables) { + DecodeStatus Result = + decodeInstruction(Table, MI, Insn, Address, this, STI); + + const MCInstrDesc &Desc = MCII->get(MI.getOpcode()); + + // For Scalable Matrix Extension (SME) instructions that have an implicit + // operand for the accumulator (ZA) or implicit immediate zero which isn't + // encoded, manually insert operand. + for (unsigned i = 0; i < Desc.getNumOperands(); i++) { + if (Desc.operands()[i].OperandType == MCOI::OPERAND_REGISTER) { + switch (Desc.operands()[i].RegClass) { + default: + break; + case AArch64::MPRRegClassID: + MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZA)); + break; + case AArch64::MPR8RegClassID: + MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZAB0)); + break; + case AArch64::ZTRRegClassID: + MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZT0)); + break; + } + } else if (Desc.operands()[i].OperandType == + AArch64::OPERAND_IMPLICIT_IMM_0) { + MI.insert(MI.begin() + i, MCOperand::createImm(0)); + } + } + + if (MI.getOpcode() == AArch64::LDR_ZA || + MI.getOpcode() == AArch64::STR_ZA) { + // Spill and fill instructions have a single immediate used for both + // the vector select offset and optional memory offset. Replicate + // the decoded immediate. + const MCOperand &Imm4Op = MI.getOperand(2); + assert(Imm4Op.isImm() && "Unexpected operand type!"); + MI.addOperand(Imm4Op); + } + + if (Result != MCDisassembler::Fail) + return Result; + } + + return MCDisassembler::Fail; +} + +uint64_t AArch64Disassembler::suggestBytesToSkip(ArrayRef<uint8_t> Bytes, + uint64_t Address) const { + // AArch64 instructions are always 4 bytes wide, so there's no point + // in skipping any smaller number of bytes if an instruction can't + // be decoded. 
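For context on the byte assembly in getInstruction above: A64 instruction words are always stored least-significant byte first, independent of the data endianness in use, so the four stream bytes are combined LSB-first. A standalone sketch with an illustrative helper name:

#include <cstdint>
#include <cstdio>

// Combine four stream bytes, least-significant first, into an instruction word.
uint32_t readInsnWord(const uint8_t B[4]) {
  return (uint32_t)B[0] | ((uint32_t)B[1] << 8) | ((uint32_t)B[2] << 16) |
         ((uint32_t)B[3] << 24);
}

int main() {
  const uint8_t Bytes[4] = {0xC0, 0x03, 0x5F, 0xD6}; // encoding of "ret"
  printf("0x%08X\n", (unsigned)readInsnWord(Bytes)); // prints 0xD65F03C0
  return 0;
}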
+ return 4; +} + +static MCSymbolizer * +createAArch64ExternalSymbolizer(const Triple &TT, LLVMOpInfoCallback GetOpInfo, + LLVMSymbolLookupCallback SymbolLookUp, + void *DisInfo, MCContext *Ctx, + std::unique_ptr<MCRelocationInfo> &&RelInfo) { + return new AArch64ExternalSymbolizer(*Ctx, std::move(RelInfo), GetOpInfo, + SymbolLookUp, DisInfo); +} + +extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void +LLVMInitializeAArch64Disassembler() { + TargetRegistry::RegisterMCDisassembler(getTheAArch64leTarget(), + createAArch64Disassembler); + TargetRegistry::RegisterMCDisassembler(getTheAArch64beTarget(), + createAArch64Disassembler); + TargetRegistry::RegisterMCSymbolizer(getTheAArch64leTarget(), + createAArch64ExternalSymbolizer); + TargetRegistry::RegisterMCSymbolizer(getTheAArch64beTarget(), + createAArch64ExternalSymbolizer); + TargetRegistry::RegisterMCDisassembler(getTheAArch64_32Target(), + createAArch64Disassembler); + TargetRegistry::RegisterMCSymbolizer(getTheAArch64_32Target(), + createAArch64ExternalSymbolizer); + + TargetRegistry::RegisterMCDisassembler(getTheARM64Target(), + createAArch64Disassembler); + TargetRegistry::RegisterMCSymbolizer(getTheARM64Target(), + createAArch64ExternalSymbolizer); + TargetRegistry::RegisterMCDisassembler(getTheARM64_32Target(), + createAArch64Disassembler); + TargetRegistry::RegisterMCSymbolizer(getTheARM64_32Target(), + createAArch64ExternalSymbolizer); +} diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 010d0aaa..79bef76 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -125,12 +125,12 @@ struct AArch64OutgoingValueAssigner bool UseVarArgsCCForFixed = IsCalleeWin && State.isVarArg(); bool Res; - if (Info.IsFixed && !UseVarArgsCCForFixed) { + if (!Flags.isVarArg() && !UseVarArgsCCForFixed) { if (!IsReturn) applyStackPassedSmallTypeDAGHack(OrigVT, ValVT, LocVT); - Res = AssignFn(ValNo, ValVT, LocVT, LocInfo, Flags, State); + Res = AssignFn(ValNo, ValVT, LocVT, LocInfo, Flags, Info.Ty, State); } else - Res = AssignFnVarArg(ValNo, ValVT, LocVT, LocInfo, Flags, State); + Res = AssignFnVarArg(ValNo, ValVT, LocVT, LocInfo, Flags, Info.Ty, State); StackSize = State.getStackSize(); return Res; @@ -361,7 +361,7 @@ struct OutgoingArgHandler : public CallLowering::OutgoingValueHandler { unsigned MaxSize = MemTy.getSizeInBytes() * 8; // For varargs, we always want to extend them to 8 bytes, in which case // we disable setting a max. 
- if (!Arg.IsFixed) + if (Arg.Flags[0].isVarArg()) MaxSize = 0; Register ValVReg = Arg.Regs[RegIndex]; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp index d905692..0bceb32 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp @@ -1102,82 +1102,6 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII, return true; } -static unsigned selectFPConvOpc(unsigned GenericOpc, LLT DstTy, LLT SrcTy) { - if (!DstTy.isScalar() || !SrcTy.isScalar()) - return GenericOpc; - - const unsigned DstSize = DstTy.getSizeInBits(); - const unsigned SrcSize = SrcTy.getSizeInBits(); - - switch (DstSize) { - case 32: - switch (SrcSize) { - case 32: - switch (GenericOpc) { - case TargetOpcode::G_SITOFP: - return AArch64::SCVTFUWSri; - case TargetOpcode::G_UITOFP: - return AArch64::UCVTFUWSri; - case TargetOpcode::G_FPTOSI: - return AArch64::FCVTZSUWSr; - case TargetOpcode::G_FPTOUI: - return AArch64::FCVTZUUWSr; - default: - return GenericOpc; - } - case 64: - switch (GenericOpc) { - case TargetOpcode::G_SITOFP: - return AArch64::SCVTFUXSri; - case TargetOpcode::G_UITOFP: - return AArch64::UCVTFUXSri; - case TargetOpcode::G_FPTOSI: - return AArch64::FCVTZSUWDr; - case TargetOpcode::G_FPTOUI: - return AArch64::FCVTZUUWDr; - default: - return GenericOpc; - } - default: - return GenericOpc; - } - case 64: - switch (SrcSize) { - case 32: - switch (GenericOpc) { - case TargetOpcode::G_SITOFP: - return AArch64::SCVTFUWDri; - case TargetOpcode::G_UITOFP: - return AArch64::UCVTFUWDri; - case TargetOpcode::G_FPTOSI: - return AArch64::FCVTZSUXSr; - case TargetOpcode::G_FPTOUI: - return AArch64::FCVTZUUXSr; - default: - return GenericOpc; - } - case 64: - switch (GenericOpc) { - case TargetOpcode::G_SITOFP: - return AArch64::SCVTFUXDri; - case TargetOpcode::G_UITOFP: - return AArch64::UCVTFUXDri; - case TargetOpcode::G_FPTOSI: - return AArch64::FCVTZSUXDr; - case TargetOpcode::G_FPTOUI: - return AArch64::FCVTZUUXDr; - default: - return GenericOpc; - } - default: - return GenericOpc; - } - default: - return GenericOpc; - }; - return GenericOpc; -} - MachineInstr * AArch64InstructionSelector::emitSelect(Register Dst, Register True, Register False, AArch64CC::CondCode CC, @@ -1349,7 +1273,9 @@ AArch64InstructionSelector::emitSelect(Register Dst, Register True, return &*SelectInst; } -static AArch64CC::CondCode changeICMPPredToAArch64CC(CmpInst::Predicate P) { +static AArch64CC::CondCode +changeICMPPredToAArch64CC(CmpInst::Predicate P, Register RHS = {}, + MachineRegisterInfo *MRI = nullptr) { switch (P) { default: llvm_unreachable("Unknown condition code!"); @@ -1360,8 +1286,18 @@ static AArch64CC::CondCode changeICMPPredToAArch64CC(CmpInst::Predicate P) { case CmpInst::ICMP_SGT: return AArch64CC::GT; case CmpInst::ICMP_SGE: + if (RHS && MRI) { + auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, *MRI); + if (ValAndVReg && ValAndVReg->Value == 0) + return AArch64CC::PL; + } return AArch64CC::GE; case CmpInst::ICMP_SLT: + if (RHS && MRI) { + auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, *MRI); + if (ValAndVReg && ValAndVReg->Value == 0) + return AArch64CC::MI; + } return AArch64CC::LT; case CmpInst::ICMP_SLE: return AArch64CC::LE; @@ -1697,7 +1633,7 @@ bool AArch64InstructionSelector::selectCompareBranchFedByFCmp( emitFPCompare(FCmp.getOperand(2).getReg(), FCmp.getOperand(3).getReg(), MIB, Pred); AArch64CC::CondCode CC1, 
CC2; - changeFCMPPredToAArch64CC(static_cast<CmpInst::Predicate>(Pred), CC1, CC2); + changeFCMPPredToAArch64CC(Pred, CC1, CC2); MachineBasicBlock *DestMBB = I.getOperand(1).getMBB(); MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC1).addMBB(DestMBB); if (CC2 != AArch64CC::AL) @@ -1813,7 +1749,8 @@ bool AArch64InstructionSelector::selectCompareBranchFedByICmp( auto &PredOp = ICmp.getOperand(1); emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB); const AArch64CC::CondCode CC = changeICMPPredToAArch64CC( - static_cast<CmpInst::Predicate>(PredOp.getPredicate())); + static_cast<CmpInst::Predicate>(PredOp.getPredicate()), + ICmp.getOperand(3).getReg(), MIB.getMRI()); MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC).addMBB(DestMBB); I.eraseFromParent(); return true; @@ -2510,8 +2447,8 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) { emitIntegerCompare(/*LHS=*/Cmp->getOperand(2), /*RHS=*/Cmp->getOperand(3), PredOp, MIB); auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); - const AArch64CC::CondCode InvCC = - changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred)); + const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC( + CmpInst::getInversePredicate(Pred), Cmp->getOperand(3).getReg(), &MRI); emitCSINC(/*Dst=*/AddDst, /*Src =*/AddLHS, /*Src2=*/AddLHS, InvCC, MIB); I.eraseFromParent(); return true; @@ -3511,23 +3448,6 @@ bool AArch64InstructionSelector::select(MachineInstr &I) { return true; } - case TargetOpcode::G_SITOFP: - case TargetOpcode::G_UITOFP: - case TargetOpcode::G_FPTOSI: - case TargetOpcode::G_FPTOUI: { - const LLT DstTy = MRI.getType(I.getOperand(0).getReg()), - SrcTy = MRI.getType(I.getOperand(1).getReg()); - const unsigned NewOpc = selectFPConvOpc(Opcode, DstTy, SrcTy); - if (NewOpc == Opcode) - return false; - - I.setDesc(TII.get(NewOpc)); - constrainSelectedInstRegOperands(I, TII, TRI, RBI); - I.setFlags(MachineInstr::NoFPExcept); - - return true; - } - case TargetOpcode::G_FREEZE: return selectCopy(I, TII, MRI, TRI, RBI); @@ -3577,8 +3497,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) { auto &PredOp = I.getOperand(1); emitIntegerCompare(I.getOperand(2), I.getOperand(3), PredOp, MIB); auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); - const AArch64CC::CondCode InvCC = - changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred)); + const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC( + CmpInst::getInversePredicate(Pred), I.getOperand(3).getReg(), &MRI); emitCSINC(/*Dst=*/I.getOperand(0).getReg(), /*Src1=*/AArch64::WZR, /*Src2=*/AArch64::WZR, InvCC, MIB); I.eraseFromParent(); @@ -4931,7 +4851,7 @@ MachineInstr *AArch64InstructionSelector::emitConjunctionRec( if (Negate) CC = CmpInst::getInversePredicate(CC); if (isa<GICmp>(Cmp)) { - OutCC = changeICMPPredToAArch64CC(CC); + OutCC = changeICMPPredToAArch64CC(CC, RHS, MIB.getMRI()); } else { // Handle special FP cases. AArch64CC::CondCode ExtraCC; @@ -5101,7 +5021,8 @@ bool AArch64InstructionSelector::tryOptSelect(GSelect &I) { emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3), PredOp, MIB); auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); - CondCode = changeICMPPredToAArch64CC(Pred); + CondCode = + changeICMPPredToAArch64CC(Pred, CondDef->getOperand(3).getReg(), &MRI); } else { // Get the condition code for the select. 
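A note on the RHS == 0 special case threaded through changeICMPPredToAArch64CC in the hunks above: a SUBS-based compare against zero cannot produce signed overflow, so V is always clear and GE/LT (N == V / N != V) collapse to PL/MI, which read only the N flag. A minimal sketch of the mapping; the enum and function name are illustrative rather than LLVM's:

enum CondCode { GE, LT, PL, MI };

// Signed compare against an immediate: use the N-flag-only condition codes
// when the immediate is zero, otherwise keep the generic signed codes.
CondCode lowerSignedCmp(bool IsGE, long long RhsImm) {
  if (RhsImm == 0)
    return IsGE ? PL : MI; // x >= 0 iff the sign bit is clear; x < 0 iff set
  return IsGE ? GE : LT;
}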
auto Pred = diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp index e0e1af7..82391f13 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -797,6 +797,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .clampMinNumElements(0, s16, 4) .alwaysLegal(); + getActionDefinitionsBuilder({G_TRUNC_SSAT_S, G_TRUNC_SSAT_U, G_TRUNC_USAT_U}) + .legalFor({{v8s8, v8s16}, {v4s16, v4s32}, {v2s32, v2s64}}); + getActionDefinitionsBuilder(G_SEXT_INREG) .legalFor({s32, s64}) .legalFor(PackedVectorAllTypeList) @@ -876,8 +879,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) {v2s32, v2s32}, {v4s32, v4s32}, {v2s64, v2s64}}) - .legalFor(HasFP16, - {{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}}) + .legalFor( + HasFP16, + {{s16, s16}, {s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}}) // Handle types larger than i64 by scalarizing/lowering. .scalarizeIf(scalarOrEltWiderThan(0, 64), 0) .scalarizeIf(scalarOrEltWiderThan(1, 64), 1) @@ -965,6 +969,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) {s128, s64}}); // Control-flow + getActionDefinitionsBuilder(G_BR).alwaysLegal(); getActionDefinitionsBuilder(G_BRCOND) .legalFor({s32}) .clampScalar(0, s32, s32); @@ -1146,7 +1151,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .clampMaxNumElements(1, s32, 4) .clampMaxNumElements(1, s16, 8) .clampMaxNumElements(1, s8, 16) - .clampMaxNumElements(1, p0, 2); + .clampMaxNumElements(1, p0, 2) + .scalarizeIf(scalarOrEltWiderThan(1, 64), 1); getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT) .legalIf( @@ -1161,7 +1167,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .clampNumElements(0, v4s16, v8s16) .clampNumElements(0, v2s32, v4s32) .clampMaxNumElements(0, s64, 2) - .clampMaxNumElements(0, p0, 2); + .clampMaxNumElements(0, p0, 2) + .scalarizeIf(scalarOrEltWiderThan(0, 64), 0); getActionDefinitionsBuilder(G_BUILD_VECTOR) .legalFor({{v8s8, s8}, @@ -1255,6 +1262,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) getActionDefinitionsBuilder(G_BRJT).legalFor({{p0, s64}}); + getActionDefinitionsBuilder({G_TRAP, G_DEBUGTRAP, G_UBSANTRAP}).alwaysLegal(); + getActionDefinitionsBuilder(G_DYN_STACKALLOC).custom(); getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).lower(); @@ -1644,6 +1653,11 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, MachineIRBuilder &MIB = Helper.MIRBuilder; MachineRegisterInfo &MRI = *MIB.getMRI(); + auto LowerUnaryOp = [&MI, &MIB](unsigned Opcode) { + MIB.buildInstr(Opcode, {MI.getOperand(0)}, {MI.getOperand(2)}); + MI.eraseFromParent(); + return true; + }; auto LowerBinOp = [&MI, &MIB](unsigned Opcode) { MIB.buildInstr(Opcode, {MI.getOperand(0)}, {MI.getOperand(2), MI.getOperand(3)}); @@ -1838,6 +1852,12 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, return LowerTriOp(AArch64::G_UDOT); case Intrinsic::aarch64_neon_sdot: return LowerTriOp(AArch64::G_SDOT); + case Intrinsic::aarch64_neon_sqxtn: + return LowerUnaryOp(TargetOpcode::G_TRUNC_SSAT_S); + case Intrinsic::aarch64_neon_sqxtun: + return LowerUnaryOp(TargetOpcode::G_TRUNC_SSAT_U); + case Intrinsic::aarch64_neon_uqxtn: + return LowerUnaryOp(TargetOpcode::G_TRUNC_USAT_U); case Intrinsic::vector_reverse: // TODO: Add support for vector_reverse diff --git 
a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp index 3ba08c8..6025f1c 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp @@ -614,8 +614,7 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P, // x uge c => x ugt c - 1 // // When c is not zero. - if (C == 0) - return std::nullopt; + assert(C != 0 && "C should not be zero here!"); P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; C -= 1; break; @@ -656,14 +655,13 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P, if (isLegalArithImmed(C)) return {{C, P}}; - auto IsMaterializableInSingleInstruction = [=](uint64_t Imm) { + auto NumberOfInstrToLoadImm = [=](uint64_t Imm) { SmallVector<AArch64_IMM::ImmInsnModel> Insn; AArch64_IMM::expandMOVImm(Imm, 32, Insn); - return Insn.size() == 1; + return Insn.size(); }; - if (!IsMaterializableInSingleInstruction(OriginalC) && - IsMaterializableInSingleInstruction(C)) + if (NumberOfInstrToLoadImm(OriginalC) > NumberOfInstrToLoadImm(C)) return {{C, P}}; return std::nullopt; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp index 31954e7..cf391c4 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp @@ -13,6 +13,7 @@ #include "AArch64RegisterBankInfo.h" #include "AArch64RegisterInfo.h" +#include "AArch64Subtarget.h" #include "MCTargetDesc/AArch64MCTargetDesc.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -492,7 +493,7 @@ static bool isFPIntrinsic(const MachineRegisterInfo &MRI, bool AArch64RegisterBankInfo::isPHIWithFPConstraints( const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, const unsigned Depth) const { + const AArch64RegisterInfo &TRI, const unsigned Depth) const { if (!MI.isPHI() || Depth > MaxFPRSearchDepth) return false; @@ -506,7 +507,7 @@ bool AArch64RegisterBankInfo::isPHIWithFPConstraints( bool AArch64RegisterBankInfo::hasFPConstraints(const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, + const AArch64RegisterInfo &TRI, unsigned Depth) const { unsigned Op = MI.getOpcode(); if (Op == TargetOpcode::G_INTRINSIC && isFPIntrinsic(MRI, MI)) @@ -544,7 +545,7 @@ bool AArch64RegisterBankInfo::hasFPConstraints(const MachineInstr &MI, bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, + const AArch64RegisterInfo &TRI, unsigned Depth) const { switch (MI.getOpcode()) { case TargetOpcode::G_FPTOSI: @@ -582,7 +583,7 @@ bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI, bool AArch64RegisterBankInfo::onlyDefinesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, + const AArch64RegisterInfo &TRI, unsigned Depth) const { switch (MI.getOpcode()) { case AArch64::G_DUP: @@ -618,6 +619,19 @@ bool AArch64RegisterBankInfo::onlyDefinesFP(const MachineInstr &MI, return hasFPConstraints(MI, MRI, TRI, Depth); } +bool AArch64RegisterBankInfo::prefersFPUse(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + const AArch64RegisterInfo &TRI, + unsigned Depth) const { + switch (MI.getOpcode()) { + case TargetOpcode::G_SITOFP: + case TargetOpcode::G_UITOFP: + return MRI.getType(MI.getOperand(0).getReg()).getSizeInBits() 
== + MRI.getType(MI.getOperand(1).getReg()).getSizeInBits(); + } + return onlyDefinesFP(MI, MRI, TRI, Depth); +} + bool AArch64RegisterBankInfo::isLoadFromFPType(const MachineInstr &MI) const { // GMemOperation because we also want to match indexed loads. auto *MemOp = cast<GMemOperation>(&MI); @@ -671,8 +685,8 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { const MachineFunction &MF = *MI.getParent()->getParent(); const MachineRegisterInfo &MRI = MF.getRegInfo(); - const TargetSubtargetInfo &STI = MF.getSubtarget(); - const TargetRegisterInfo &TRI = *STI.getRegisterInfo(); + const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>(); + const AArch64RegisterInfo &TRI = *STI.getRegisterInfo(); switch (Opc) { // G_{F|S|U}REM are not listed because they are not legal. @@ -826,16 +840,28 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { // Integer to FP conversions don't necessarily happen between GPR -> FPR // regbanks. They can also be done within an FPR register. Register SrcReg = MI.getOperand(1).getReg(); - if (getRegBank(SrcReg, MRI, TRI) == &AArch64::FPRRegBank) + if (getRegBank(SrcReg, MRI, TRI) == &AArch64::FPRRegBank && + MRI.getType(SrcReg).getSizeInBits() == + MRI.getType(MI.getOperand(0).getReg()).getSizeInBits()) OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR}; else OpRegBankIdx = {PMI_FirstFPR, PMI_FirstGPR}; break; } + case TargetOpcode::G_FPTOSI_SAT: + case TargetOpcode::G_FPTOUI_SAT: { + LLT DstType = MRI.getType(MI.getOperand(0).getReg()); + if (DstType.isVector()) + break; + if (DstType == LLT::scalar(16)) { + OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR}; + break; + } + OpRegBankIdx = {PMI_FirstGPR, PMI_FirstFPR}; + break; + } case TargetOpcode::G_FPTOSI: case TargetOpcode::G_FPTOUI: - case TargetOpcode::G_FPTOSI_SAT: - case TargetOpcode::G_FPTOUI_SAT: case TargetOpcode::G_INTRINSIC_LRINT: case TargetOpcode::G_INTRINSIC_LLRINT: if (MRI.getType(MI.getOperand(0).getReg()).isVector()) @@ -895,13 +921,13 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { // instruction. // // Int->FP conversion operations are also captured in - // onlyDefinesFP(). + // prefersFPUse(). if (isPHIWithFPConstraints(UseMI, MRI, TRI)) return true; return onlyUsesFP(UseMI, MRI, TRI) || - onlyDefinesFP(UseMI, MRI, TRI); + prefersFPUse(UseMI, MRI, TRI); })) OpRegBankIdx[0] = PMI_FirstFPR; break; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h index 3abbc1b..b2de031 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h +++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h @@ -22,6 +22,7 @@ namespace llvm { class TargetRegisterInfo; +class AArch64RegisterInfo; class AArch64GenRegisterBankInfo : public RegisterBankInfo { protected: @@ -123,21 +124,26 @@ class AArch64RegisterBankInfo final : public AArch64GenRegisterBankInfo { /// \returns true if \p MI is a PHI that its def is used by /// any instruction that onlyUsesFP. bool isPHIWithFPConstraints(const MachineInstr &MI, - const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, - unsigned Depth = 0) const; + const MachineRegisterInfo &MRI, + const AArch64RegisterInfo &TRI, + unsigned Depth = 0) const; /// \returns true if \p MI only uses and defines FPRs. 
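Regarding the size-equality check added to prefersFPUse and getInstrMapping above: AArch64 only provides FPR-to-FPR scvtf/ucvtf forms when the integer and floating-point widths match, so only those int-to-fp conversions benefit from keeping the source on the FPR bank. A C-level illustration with hypothetical function names:

// Same widths: if the i64 already lives in a SIMD&FP register this can be a
// single "scvtf d0, d0", so an FPR source assignment is attractive.
double i64_to_f64(long long x) { return (double)x; }

// Different widths: there is no 64-bit-int-to-f32 FPR-to-FPR form, so the
// value is converted from a GPR ("scvtf s0, x0") regardless of where it
// started out.
float i64_to_f32(long long x) { return (float)x; }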
bool hasFPConstraints(const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, unsigned Depth = 0) const; + const AArch64RegisterInfo &TRI, + unsigned Depth = 0) const; /// \returns true if \p MI only uses FPRs. bool onlyUsesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, unsigned Depth = 0) const; + const AArch64RegisterInfo &TRI, unsigned Depth = 0) const; /// \returns true if \p MI only defines FPRs. bool onlyDefinesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI, - const TargetRegisterInfo &TRI, unsigned Depth = 0) const; + const AArch64RegisterInfo &TRI, unsigned Depth = 0) const; + + /// \returns true if \p MI can take both fpr and gpr uses, but prefers fp. + bool prefersFPUse(const MachineInstr &MI, const MachineRegisterInfo &MRI, + const AArch64RegisterInfo &TRI, unsigned Depth = 0) const; /// \returns true if the load \p MI is likely loading from a floating-point /// type. diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp index 7618a57..a388216 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp @@ -40,6 +40,7 @@ protected: bool IsPCRel) const override; bool needsRelocateWithSymbol(const MCValue &, unsigned Type) const override; bool isNonILP32reloc(const MCFixup &Fixup, AArch64::Specifier RefKind) const; + void sortRelocs(std::vector<ELFRelocationEntry> &Relocs) override; bool IsILP32; }; @@ -96,8 +97,8 @@ unsigned AArch64ELFObjectWriter::getRelocType(const MCFixup &Fixup, case AArch64::S_TPREL: case AArch64::S_TLSDESC: case AArch64::S_TLSDESC_AUTH: - if (auto *SA = Target.getAddSym()) - cast<MCSymbolELF>(SA)->setType(ELF::STT_TLS); + if (auto *SA = const_cast<MCSymbol *>(Target.getAddSym())) + static_cast<MCSymbolELF *>(SA)->setType(ELF::STT_TLS); break; default: break; @@ -488,7 +489,8 @@ bool AArch64ELFObjectWriter::needsRelocateWithSymbol(const MCValue &Val, // this global needs to be tagged. In addition, the linker needs to know // whether to emit a special addend when relocating `end` symbols, and this // can only be determined by the attributes of the symbol itself. - if (Val.getAddSym() && cast<MCSymbolELF>(Val.getAddSym())->isMemtag()) + if (Val.getAddSym() && + static_cast<const MCSymbolELF *>(Val.getAddSym())->isMemtag()) return true; if ((Val.getSpecifier() & AArch64::S_GOT) == AArch64::S_GOT) @@ -497,6 +499,17 @@ bool AArch64ELFObjectWriter::needsRelocateWithSymbol(const MCValue &Val, Val.getSpecifier()); } +void AArch64ELFObjectWriter::sortRelocs( + std::vector<ELFRelocationEntry> &Relocs) { + // PATCHINST relocations should be applied last because they may overwrite the + // whole instruction and so should take precedence over other relocations that + // modify operands of the original instruction. 
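The sortRelocs override that follows implements this with std::stable_partition; as a reminder, elements satisfying the predicate are moved to the front while relative order inside each group is preserved, so PATCHINST entries sink to the end without otherwise perturbing the relocation order. A small standalone demonstration (the integer stand-ins are illustrative):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  enum RelocType { R_OTHER = 1, R_PATCHINST = 2 };
  std::vector<int> Relocs = {R_PATCHINST, R_OTHER, R_PATCHINST, R_OTHER};
  // Keep everything that is not PATCHINST in front; PATCHINST is applied last
  // and may overwrite the whole instruction.
  std::stable_partition(Relocs.begin(), Relocs.end(),
                        [](int R) { return R != R_PATCHINST; });
  for (int R : Relocs)
    printf("%d ", R); // prints: 1 1 2 2
  return 0;
}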
+ std::stable_partition(Relocs.begin(), Relocs.end(), + [](const ELFRelocationEntry &R) { + return R.Type != ELF::R_AARCH64_PATCHINST; + }); +} + std::unique_ptr<MCObjectTargetWriter> llvm::createAArch64ELFObjectWriter(uint8_t OSABI, bool IsILP32) { return std::make_unique<AArch64ELFObjectWriter>(OSABI, IsILP32); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp index 6257e99..917dbdf 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp @@ -35,7 +35,6 @@ #include "llvm/MC/MCTargetOptions.h" #include "llvm/MC/MCWinCOFFStreamer.h" #include "llvm/Support/AArch64BuildAttributes.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/FormattedStream.h" #include "llvm/Support/raw_ostream.h" @@ -418,7 +417,8 @@ private: } MCSymbol *emitMappingSymbol(StringRef Name) { - auto *Symbol = cast<MCSymbolELF>(getContext().createLocalSymbol(Name)); + auto *Symbol = + static_cast<MCSymbolELF *>(getContext().createLocalSymbol(Name)); emitLabel(Symbol); return Symbol; } @@ -455,7 +455,7 @@ void AArch64TargetELFStreamer::emitInst(uint32_t Inst) { void AArch64TargetELFStreamer::emitDirectiveVariantPCS(MCSymbol *Symbol) { getStreamer().getAssembler().registerSymbol(*Symbol); - cast<MCSymbolELF>(Symbol)->setOther(ELF::STO_AARCH64_VARIANT_PCS); + static_cast<MCSymbolELF *>(Symbol)->setOther(ELF::STO_AARCH64_VARIANT_PCS); } void AArch64TargetELFStreamer::finish() { @@ -541,7 +541,7 @@ void AArch64TargetELFStreamer::finish() { MCSectionELF *MemtagSec = nullptr; for (const MCSymbol &Symbol : Asm.symbols()) { - const auto &Sym = cast<MCSymbolELF>(Symbol); + auto &Sym = static_cast<const MCSymbolELF &>(Symbol); if (Sym.isMemtag()) { MemtagSec = Ctx.getELFSection(".memtag.globals.static", ELF::SHT_AARCH64_MEMTAG_GLOBALS_STATIC, 0); @@ -556,7 +556,7 @@ void AArch64TargetELFStreamer::finish() { S.switchSection(MemtagSec); const auto *Zero = MCConstantExpr::create(0, Ctx); for (const MCSymbol &Symbol : Asm.symbols()) { - const auto &Sym = cast<MCSymbolELF>(Symbol); + auto &Sym = static_cast<const MCSymbolELF &>(Symbol); if (!Sym.isMemtag()) continue; auto *SRE = MCSymbolRefExpr::create(&Sym, Ctx); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp index 3c8b571..54b58e9 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp @@ -1017,14 +1017,22 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI, else return false; + StringRef Reg = getRegisterName(MI->getOperand(4).getReg()); + bool NotXZR = Reg != "xzr"; + + // If a mandatory is not specified in the TableGen + // (i.e. no register operand should be present), and the register value + // is not xzr/x31, then disassemble to a SYS alias instead. 
+ if (NotXZR && !NeedsReg) + return false; + std::string Str = Ins + Name; llvm::transform(Str, Str.begin(), ::tolower); O << '\t' << Str; - if (NeedsReg) { - O << ", "; - printRegName(O, MI->getOperand(4).getReg()); - } + + if (NeedsReg) + O << ", " << Reg; return true; } diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp index 828c5c5..2b5cf34 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp @@ -53,9 +53,9 @@ const MCAsmInfo::AtSpecifier MachOAtSpecifiers[] = { {AArch64::S_MACHO_TLVPPAGEOFF, "TLVPPAGEOFF"}, }; -StringRef AArch64::getSpecifierName(const MCSpecifierExpr &Expr) { +StringRef AArch64::getSpecifierName(AArch64::Specifier S) { // clang-format off - switch (static_cast<uint32_t>(Expr.getSpecifier())) { + switch (static_cast<uint32_t>(S)) { case AArch64::S_CALL: return ""; case AArch64::S_LO12: return ":lo12:"; case AArch64::S_ABS_G3: return ":abs_g3:"; @@ -124,7 +124,7 @@ static bool evaluate(const MCSpecifierExpr &Expr, MCValue &Res, if (!Expr.getSubExpr()->evaluateAsRelocatable(Res, Asm)) return false; Res.setSpecifier(Expr.getSpecifier()); - return true; + return !Res.getSubSym(); } AArch64MCAsmInfoDarwin::AArch64MCAsmInfoDarwin(bool IsILP32) { @@ -183,7 +183,7 @@ void AArch64MCAsmInfoDarwin::printSpecifierExpr( raw_ostream &OS, const MCSpecifierExpr &Expr) const { if (auto *AE = dyn_cast<AArch64AuthMCExpr>(&Expr)) return AE->print(OS, this); - OS << AArch64::getSpecifierName(Expr); + OS << AArch64::getSpecifierName(Expr.getSpecifier()); printExpr(OS, *Expr.getSubExpr()); } @@ -232,7 +232,7 @@ void AArch64MCAsmInfoELF::printSpecifierExpr( raw_ostream &OS, const MCSpecifierExpr &Expr) const { if (auto *AE = dyn_cast<AArch64AuthMCExpr>(&Expr)) return AE->print(OS, this); - OS << AArch64::getSpecifierName(Expr); + OS << AArch64::getSpecifierName(Expr.getSpecifier()); printExpr(OS, *Expr.getSubExpr()); } @@ -262,7 +262,7 @@ AArch64MCAsmInfoMicrosoftCOFF::AArch64MCAsmInfoMicrosoftCOFF() { void AArch64MCAsmInfoMicrosoftCOFF::printSpecifierExpr( raw_ostream &OS, const MCSpecifierExpr &Expr) const { - OS << AArch64::getSpecifierName(Expr); + OS << AArch64::getSpecifierName(Expr.getSpecifier()); printExpr(OS, *Expr.getSubExpr()); } @@ -292,7 +292,7 @@ AArch64MCAsmInfoGNUCOFF::AArch64MCAsmInfoGNUCOFF() { void AArch64MCAsmInfoGNUCOFF::printSpecifierExpr( raw_ostream &OS, const MCSpecifierExpr &Expr) const { - OS << AArch64::getSpecifierName(Expr); + OS << AArch64::getSpecifierName(Expr.getSpecifier()); printExpr(OS, *Expr.getSubExpr()); } diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h index c28e925..0dfa61b 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h @@ -181,7 +181,7 @@ enum { /// Return the string representation of the ELF relocation specifier /// (e.g. ":got:", ":lo12:"). 
-StringRef getSpecifierName(const MCSpecifierExpr &Expr); +StringRef getSpecifierName(Specifier S); inline Specifier getSymbolLoc(Specifier S) { return static_cast<Specifier>(S & AArch64::S_SymLocBits); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp index a53b676..5fe9993 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp @@ -73,9 +73,10 @@ unsigned AArch64WinCOFFObjectWriter::getRelocType( // Supported break; default: - Ctx.reportError(Fixup.getLoc(), "relocation specifier " + - AArch64::getSpecifierName(*A64E) + - " unsupported on COFF targets"); + Ctx.reportError(Fixup.getLoc(), + "relocation specifier " + + AArch64::getSpecifierName(A64E->getSpecifier()) + + " unsupported on COFF targets"); return COFF::IMAGE_REL_ARM64_ABSOLUTE; // Dummy return value } } @@ -83,9 +84,10 @@ unsigned AArch64WinCOFFObjectWriter::getRelocType( switch (FixupKind) { default: { if (auto *A64E = dyn_cast<MCSpecifierExpr>(Expr)) { - Ctx.reportError(Fixup.getLoc(), "relocation specifier " + - AArch64::getSpecifierName(*A64E) + - " unsupported on COFF targets"); + Ctx.reportError(Fixup.getLoc(), + "relocation specifier " + + AArch64::getSpecifierName(A64E->getSpecifier()) + + " unsupported on COFF targets"); } else { MCFixupKindInfo Info = MAB.getFixupKindInfo(Fixup.getKind()); Ctx.reportError(Fixup.getLoc(), Twine("relocation type ") + Info.Name + diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp new file mode 100644 index 0000000..5dfaa891 --- /dev/null +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -0,0 +1,696 @@ +//===- MachineSMEABIPass.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the SME ABI requirements for ZA state. This includes +// implementing the lazy ZA state save schemes around calls. +// +//===----------------------------------------------------------------------===// +// +// This pass works by collecting instructions that require ZA to be in a +// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state +// transitions to ensure ZA is in the required state before instructions. State +// transitions represent actions such as setting up or restoring a lazy save. +// Certain points within a function may also have predefined states independent +// of any instructions, for example, a "shared_za" function is always entered +// and exited in the "ACTIVE" state. +// +// To handle ZA state across control flow, we make use of edge bundling. This +// assigns each block an "incoming" and "outgoing" edge bundle (representing +// incoming and outgoing edges). Initially, these are unique to each block; +// then, in the process of forming bundles, the outgoing block of a block is +// joined with the incoming bundle of all successors. The result is that each +// bundle can be assigned a single ZA state, which ensures the state required by +// all a blocks' successors is the same, and that each basic block will always +// be entered with the same ZA state. 
This eliminates the need for splitting +// edges to insert state transitions or "phi" nodes for ZA states. +// +// See below for a simple example of edge bundling. +// +// The following shows a conditionally executed basic block (BB1): +// +// if (cond) +// BB1 +// BB2 +// +// Initial Bundles Joined Bundles +// +// ┌──0──┐ ┌──0──┐ +// │ BB0 │ │ BB0 │ +// └──1──┘ └──1──┘ +// ├───────┐ ├───────┐ +// ▼ │ ▼ │ +// ┌──2──┐ │ ─────► ┌──1──┐ │ +// │ BB1 │ ▼ │ BB1 │ ▼ +// └──3──┘ ┌──4──┐ └──1──┘ ┌──1──┐ +// └───►4 BB2 │ └───►1 BB2 │ +// └──5──┘ └──2──┘ +// +// On the left are the initial per-block bundles, and on the right are the +// joined bundles (which are the result of the EdgeBundles analysis). + +#include "AArch64InstrInfo.h" +#include "AArch64MachineFunctionInfo.h" +#include "AArch64Subtarget.h" +#include "MCTargetDesc/AArch64AddressingModes.h" +#include "llvm/ADT/BitmaskEnum.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/CodeGen/EdgeBundles.h" +#include "llvm/CodeGen/LivePhysRegs.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "aarch64-machine-sme-abi" + +namespace { + +enum ZAState { + // Any/unknown state (not valid) + ANY = 0, + + // ZA is in use and active (i.e. within the accumulator) + ACTIVE, + + // A ZA save has been set up or committed (i.e. ZA is dormant or off) + LOCAL_SAVED, + + // ZA is off or a lazy save has been set up by the caller + CALLER_DORMANT, + + // ZA is off + OFF, + + // The number of ZA states (not a valid state) + NUM_ZA_STATE +}; + +/// A bitmask enum to record live physical registers that the "emit*" routines +/// may need to preserve. Note: This only tracks registers we may clobber. +enum LiveRegs : uint8_t { + None = 0, + NZCV = 1 << 0, + W0 = 1 << 1, + W0_HI = 1 << 2, + X0 = W0 | W0_HI, + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI) +}; + +/// Holds the virtual registers live physical registers have been saved to. +struct PhysRegSave { + LiveRegs PhysLiveRegs; + Register StatusFlags = AArch64::NoRegister; + Register X0Save = AArch64::NoRegister; +}; + +static bool isLegalEdgeBundleZAState(ZAState State) { + switch (State) { + case ZAState::ACTIVE: + case ZAState::LOCAL_SAVED: + return true; + default: + return false; + } +} +struct TPIDR2State { + int FrameIndex = -1; +}; + +StringRef getZAStateString(ZAState State) { +#define MAKE_CASE(V) \ + case V: \ + return #V; + switch (State) { + MAKE_CASE(ZAState::ANY) + MAKE_CASE(ZAState::ACTIVE) + MAKE_CASE(ZAState::LOCAL_SAVED) + MAKE_CASE(ZAState::CALLER_DORMANT) + MAKE_CASE(ZAState::OFF) + default: + llvm_unreachable("Unexpected ZAState"); + } +#undef MAKE_CASE +} + +static bool isZAorZTRegOp(const TargetRegisterInfo &TRI, + const MachineOperand &MO) { + if (!MO.isReg() || !MO.getReg().isPhysical()) + return false; + return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) { + return AArch64::MPR128RegClass.contains(SR) || + AArch64::ZTRRegClass.contains(SR); + }); +} + +/// Returns the required ZA state needed before \p MI and an iterator pointing +/// to where any code required to change the ZA state should be inserted. 
+static std::pair<ZAState, MachineBasicBlock::iterator> +getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI, + bool ZAOffAtReturn) { + MachineBasicBlock::iterator InsertPt(MI); + + if (MI.getOpcode() == AArch64::InOutZAUsePseudo) + return {ZAState::ACTIVE, std::prev(InsertPt)}; + + if (MI.getOpcode() == AArch64::RequiresZASavePseudo) + return {ZAState::LOCAL_SAVED, std::prev(InsertPt)}; + + if (MI.isReturn()) + return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt}; + + for (auto &MO : MI.operands()) { + if (isZAorZTRegOp(TRI, MO)) + return {ZAState::ACTIVE, InsertPt}; + } + + return {ZAState::ANY, InsertPt}; +} + +struct MachineSMEABI : public MachineFunctionPass { + inline static char ID = 0; + + MachineSMEABI() : MachineFunctionPass(ID) {} + + bool runOnMachineFunction(MachineFunction &MF) override; + + StringRef getPassName() const override { return "Machine SME ABI pass"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<EdgeBundlesWrapperLegacy>(); + AU.addPreservedID(MachineLoopInfoID); + AU.addPreservedID(MachineDominatorsID); + MachineFunctionPass::getAnalysisUsage(AU); + } + + /// Collects the needed ZA state (and live registers) before each instruction + /// within the machine function. + void collectNeededZAStates(SMEAttrs); + + /// Assigns each edge bundle a ZA state based on the needed states of blocks + /// that have incoming or outgoing edges in that bundle. + void assignBundleZAStates(); + + /// Inserts code to handle changes between ZA states within the function. + /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA. + void insertStateChanges(); + + // Emission routines for private and shared ZA functions (using lazy saves). + void emitNewZAPrologue(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); + void emitRestoreLazySave(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + LiveRegs PhysLiveRegs); + void emitSetupLazySave(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); + void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); + void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, + bool ClearTPIDR2); + + void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, + ZAState From, ZAState To, LiveRegs PhysLiveRegs); + + /// Save live physical registers to virtual registers. + PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, DebugLoc DL); + /// Restore physical registers from a save of their previous values. + void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, DebugLoc DL); + + /// Get or create a TPIDR2 block in this function. + TPIDR2State getTPIDR2Block(); + +private: + /// Contains the needed ZA state (and live registers) at an instruction. + struct InstInfo { + ZAState NeededState{ZAState::ANY}; + MachineBasicBlock::iterator InsertPt; + LiveRegs PhysLiveRegs = LiveRegs::None; + }; + + /// Contains the needed ZA state for each instruction in a block. + /// Instructions that do not require a ZA state are not recorded. + struct BlockInfo { + ZAState FixedEntryState{ZAState::ANY}; + SmallVector<InstInfo> Insts; + LiveRegs PhysLiveRegsAtExit = LiveRegs::None; + }; + + // All pass state that must be cleared between functions. 
+ struct PassState { + SmallVector<BlockInfo> Blocks; + SmallVector<ZAState> BundleStates; + std::optional<TPIDR2State> TPIDR2Block; + } State; + + MachineFunction *MF = nullptr; + EdgeBundles *Bundles = nullptr; + const AArch64Subtarget *Subtarget = nullptr; + const AArch64RegisterInfo *TRI = nullptr; + const TargetInstrInfo *TII = nullptr; + MachineRegisterInfo *MRI = nullptr; +}; + +void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { + assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) && + "Expected function to have ZA/ZT0 state!"); + + State.Blocks.resize(MF->getNumBlockIDs()); + for (MachineBasicBlock &MBB : *MF) { + BlockInfo &Block = State.Blocks[MBB.getNumber()]; + if (MBB.isEntryBlock()) { + // Entry block: + Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface() + ? ZAState::CALLER_DORMANT + : ZAState::ACTIVE; + } else if (MBB.isEHPad()) { + // EH entry block: + Block.FixedEntryState = ZAState::LOCAL_SAVED; + } + + LiveRegUnits LiveUnits(*TRI); + LiveUnits.addLiveOuts(MBB); + + auto GetPhysLiveRegs = [&] { + LiveRegs PhysLiveRegs = LiveRegs::None; + if (!LiveUnits.available(AArch64::NZCV)) + PhysLiveRegs |= LiveRegs::NZCV; + // We have to track W0 and X0 separately as otherwise things can get + // confused if we attempt to preserve X0 but only W0 was defined. + if (!LiveUnits.available(AArch64::W0)) + PhysLiveRegs |= LiveRegs::W0; + if (!LiveUnits.available(AArch64::W0_HI)) + PhysLiveRegs |= LiveRegs::W0_HI; + return PhysLiveRegs; + }; + + Block.PhysLiveRegsAtExit = GetPhysLiveRegs(); + auto FirstTerminatorInsertPt = MBB.getFirstTerminator(); + for (MachineInstr &MI : reverse(MBB)) { + MachineBasicBlock::iterator MBBI(MI); + LiveUnits.stepBackward(MI); + LiveRegs PhysLiveRegs = GetPhysLiveRegs(); + auto [NeededState, InsertPt] = getZAStateBeforeInst( + *TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface()); + assert((InsertPt == MBBI || + InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) && + "Unexpected state change insertion point!"); + // TODO: Do something to avoid state changes where NZCV is live. + if (MBBI == FirstTerminatorInsertPt) + Block.PhysLiveRegsAtExit = PhysLiveRegs; + if (NeededState != ZAState::ANY) + Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs}); + } + + // Reverse vector (as we had to iterate backwards for liveness). + std::reverse(Block.Insts.begin(), Block.Insts.end()); + } +} + +void MachineSMEABI::assignBundleZAStates() { + State.BundleStates.resize(Bundles->getNumBundles()); + for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) { + LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n'); + + // Attempt to assign a ZA state for this bundle that minimizes state + // transitions. Edges within loops are given a higher weight as we assume + // they will be executed more than once. + // TODO: We should propagate desired incoming/outgoing states through blocks + // that have the "ANY" state first to make better global decisions. + int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0}; + for (unsigned BlockID : Bundles->getBlocks(I)) { + LLVM_DEBUG(dbgs() << "- bb." 
<< BlockID); + + const BlockInfo &Block = State.Blocks[BlockID]; + if (Block.Insts.empty()) { + LLVM_DEBUG(dbgs() << " (no state preference)\n"); + continue; + } + bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I; + bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I; + + ZAState DesiredIncomingState = Block.Insts.front().NeededState; + if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) { + EdgeStateCounts[DesiredIncomingState]++; + LLVM_DEBUG(dbgs() << " DesiredIncomingState: " + << getZAStateString(DesiredIncomingState)); + } + ZAState DesiredOutgoingState = Block.Insts.back().NeededState; + if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) { + EdgeStateCounts[DesiredOutgoingState]++; + LLVM_DEBUG(dbgs() << " DesiredOutgoingState: " + << getZAStateString(DesiredOutgoingState)); + } + LLVM_DEBUG(dbgs() << '\n'); + } + + ZAState BundleState = + ZAState(max_element(EdgeStateCounts) - EdgeStateCounts); + + // Force ZA to be active in bundles that don't have a preferred state. + // TODO: Something better here (to avoid extra mode switches). + if (BundleState == ZAState::ANY) + BundleState = ZAState::ACTIVE; + + LLVM_DEBUG({ + dbgs() << "Chosen ZA state: " << getZAStateString(BundleState) << '\n' + << "Edge counts:"; + for (auto [State, Count] : enumerate(EdgeStateCounts)) + dbgs() << " " << getZAStateString(ZAState(State)) << ": " << Count; + dbgs() << "\n\n"; + }); + + State.BundleStates[I] = BundleState; + } +} + +void MachineSMEABI::insertStateChanges() { + for (MachineBasicBlock &MBB : *MF) { + const BlockInfo &Block = State.Blocks[MBB.getNumber()]; + ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(), + /*Out=*/false)]; + + ZAState CurrentState = Block.FixedEntryState; + if (CurrentState == ZAState::ANY) + CurrentState = InState; + + for (auto &Inst : Block.Insts) { + if (CurrentState != Inst.NeededState) + emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState, + Inst.PhysLiveRegs); + CurrentState = Inst.NeededState; + } + + if (MBB.succ_empty()) + continue; + + ZAState OutState = + State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)]; + if (CurrentState != OutState) + emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState, + Block.PhysLiveRegsAtExit); + } +} + +TPIDR2State MachineSMEABI::getTPIDR2Block() { + if (State.TPIDR2Block) + return *State.TPIDR2Block; + MachineFrameInfo &MFI = MF->getFrameInfo(); + State.TPIDR2Block = TPIDR2State{MFI.CreateStackObject(16, Align(16), false)}; + return *State.TPIDR2Block; +} + +static DebugLoc getDebugLoc(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { + if (MBBI != MBB.end()) + return MBBI->getDebugLoc(); + return DebugLoc(); +} + +void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { + DebugLoc DL = getDebugLoc(MBB, MBBI); + + // Get pointer to TPIDR2 block. + Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass); + Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2) + .addFrameIndex(getTPIDR2Block().FrameIndex) + .addImm(0) + .addImm(0); + BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr) + .addReg(TPIDR2); + // Set TPIDR2_EL0 to point to TPIDR2 block. 
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR)) + .addImm(AArch64SysReg::TPIDR2_EL0) + .addReg(TPIDR2Ptr); +} + +PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + DebugLoc DL) { + PhysRegSave RegSave{PhysLiveRegs}; + if (PhysLiveRegs & LiveRegs::NZCV) { + RegSave.StatusFlags = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), RegSave.StatusFlags) + .addImm(AArch64SysReg::NZCV) + .addReg(AArch64::NZCV, RegState::Implicit); + } + // Note: Preserving X0 is "free" as this is before register allocation, so + // the register allocator is still able to optimize these copies. + if (PhysLiveRegs & LiveRegs::W0) { + RegSave.X0Save = MRI->createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI + ? &AArch64::GPR64RegClass + : &AArch64::GPR32RegClass); + BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), RegSave.X0Save) + .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0); + } + return RegSave; +} + +void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + DebugLoc DL) { + if (RegSave.StatusFlags != AArch64::NoRegister) + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR)) + .addImm(AArch64SysReg::NZCV) + .addReg(RegSave.StatusFlags) + .addReg(AArch64::NZCV, RegState::ImplicitDefine); + + if (RegSave.X0Save != AArch64::NoRegister) + BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), + RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0) + .addReg(RegSave.X0Save); +} + +void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + LiveRegs PhysLiveRegs) { + auto *TLI = Subtarget->getTargetLowering(); + DebugLoc DL = getDebugLoc(MBB, MBBI); + Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + Register TPIDR2 = AArch64::X0; + + // TODO: Emit these within the restore MBB to prevent unnecessary saves. + PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL); + + // Enable ZA. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) + .addImm(AArch64SVCR::SVCRZA) + .addImm(1); + // Get current TPIDR2_EL0. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), TPIDR2EL0) + .addImm(AArch64SysReg::TPIDR2_EL0); + // Get pointer to TPIDR2 block. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2) + .addFrameIndex(getTPIDR2Block().FrameIndex) + .addImm(0) + .addImm(0); + // (Conditionally) restore ZA state. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::RestoreZAPseudo)) + .addReg(TPIDR2EL0) + .addReg(TPIDR2) + .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_RESTORE)) + .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + // Zero TPIDR2_EL0. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR)) + .addImm(AArch64SysReg::TPIDR2_EL0) + .addReg(AArch64::XZR); + + restorePhyRegSave(RegSave, MBB, MBBI, DL); +} + +void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + bool ClearTPIDR2) { + DebugLoc DL = getDebugLoc(MBB, MBBI); + + if (ClearTPIDR2) + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR)) + .addImm(AArch64SysReg::TPIDR2_EL0) + .addReg(AArch64::XZR); + + // Disable ZA. 
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) + .addImm(AArch64SVCR::SVCRZA) + .addImm(0); +} + +void MachineSMEABI::emitAllocateLazySaveBuffer( + MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) { + MachineFrameInfo &MFI = MF->getFrameInfo(); + + DebugLoc DL = getDebugLoc(MBB, MBBI); + Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + Register Buffer = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + + // Calculate SVL. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::RDSVLI_XI), SVL).addImm(1); + + // 1. Allocate the lazy save buffer. + { + // TODO This function grows the stack with a subtraction, which doesn't work + // on Windows. Some refactoring to share the functionality in + // LowerWindowsDYNAMIC_STACKALLOC will be required once the Windows ABI + // supports SME + assert(!Subtarget->isTargetWindows() && + "Lazy ZA save is not yet supported on Windows"); + // Get original stack pointer. + BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), SP) + .addReg(AArch64::SP); + // Allocate a lazy-save buffer object of the size given, normally SVL * SVL + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSUBXrrr), Buffer) + .addReg(SVL) + .addReg(SVL) + .addReg(SP); + BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), AArch64::SP) + .addReg(Buffer); + // We have just allocated a variable sized object, tell this to PEI. + MFI.CreateVariableSizedObject(Align(16), nullptr); + } + + // 2. Setup the TPIDR2 block. + { + // Note: This case just needs to do `SVL << 48`. It is not implemented as we + // generally don't support big-endian SVE/SME. + if (!Subtarget->isLittleEndian()) + reportFatalInternalError( + "TPIDR2 block initialization is not supported on big-endian targets"); + + // Store buffer pointer and num_za_save_slices. + // Bytes 10-15 are implicitly zeroed. + BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi)) + .addReg(Buffer) + .addReg(SVL) + .addFrameIndex(getTPIDR2Block().FrameIndex) + .addImm(0); + } +} + +void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { + auto *TLI = Subtarget->getTargetLowering(); + DebugLoc DL = getDebugLoc(MBB, MBBI); + + // Get current TPIDR2_EL0. + Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS)) + .addReg(TPIDR2EL0, RegState::Define) + .addImm(AArch64SysReg::TPIDR2_EL0); + // If TPIDR2_EL0 is non-zero, commit the lazy save. + // NOTE: Functions that only use ZT0 don't need to zero ZA. + bool ZeroZA = + MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs().hasZAState(); + auto CommitZASave = + BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo)) + .addReg(TPIDR2EL0) + .addImm(ZeroZA ? 1 : 0) + .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE)) + .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + if (ZeroZA) + CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine); + // Enable ZA (as ZA could have previously been in the OFF state). + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) + .addImm(AArch64SVCR::SVCRZA) + .addImm(1); +} + +void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB, + MachineBasicBlock::iterator InsertPt, + ZAState From, ZAState To, + LiveRegs PhysLiveRegs) { + + // ZA not used. 
+ if (From == ZAState::ANY || To == ZAState::ANY) + return; + + // If we're exiting from the CALLER_DORMANT state that means this new ZA + // function did not touch ZA (so ZA was never turned on). + if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF) + return; + + // TODO: Avoid setting up the save buffer if there's no transition to + // LOCAL_SAVED. + if (From == ZAState::CALLER_DORMANT) { + assert(MBB.getParent() + ->getInfo<AArch64FunctionInfo>() + ->getSMEFnAttrs() + .hasPrivateZAInterface() && + "CALLER_DORMANT state requires private ZA interface"); + assert(&MBB == &MBB.getParent()->front() && + "CALLER_DORMANT state only valid in entry block"); + emitNewZAPrologue(MBB, MBB.getFirstNonPHI()); + if (To == ZAState::ACTIVE) + return; // Nothing more to do (ZA is active after the prologue). + + // Note: "emitNewZAPrologue" zeros ZA, so we may need to setup a lazy save + // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this + // case by changing the placement of the zero instruction. + From = ZAState::ACTIVE; + } + + if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED) + emitSetupLazySave(MBB, InsertPt); + else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE) + emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs); + else if (To == ZAState::OFF) { + assert(From != ZAState::CALLER_DORMANT && + "CALLER_DORMANT to OFF should have already been handled"); + emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED); + } else { + dbgs() << "Error: Transition from " << getZAStateString(From) << " to " + << getZAStateString(To) << '\n'; + llvm_unreachable("Unimplemented state transition"); + } +} + +} // end anonymous namespace + +INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI", + false, false) + +bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) { + if (!MF.getSubtarget<AArch64Subtarget>().hasSME()) + return false; + + auto *AFI = MF.getInfo<AArch64FunctionInfo>(); + SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs(); + if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State()) + return false; + + assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!"); + + // Reset pass state. + State = PassState{}; + this->MF = &MF; + Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles(); + Subtarget = &MF.getSubtarget<AArch64Subtarget>(); + TII = Subtarget->getInstrInfo(); + TRI = Subtarget->getRegisterInfo(); + MRI = &MF.getRegInfo(); + + collectNeededZAStates(SMEFnAttrs); + assignBundleZAStates(); + insertStateChanges(); + + // Allocate save buffer (if needed). 
+ if (State.TPIDR2Block) { + MachineBasicBlock &EntryBlock = MF.front(); + emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI()); + } + + return true; +} + +FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); } diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp index 4af4d49..2008516 100644 --- a/llvm/lib/Target/AArch64/SMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp @@ -15,11 +15,16 @@ #include "AArch64.h" #include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/RuntimeLibcalls.h" +#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; @@ -33,9 +38,13 @@ struct SMEABI : public FunctionPass { bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetPassConfig>(); + } + private: bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder, - SMEAttrs FnAttrs); + SMEAttrs FnAttrs, const TargetLowering &TLI); }; } // end anonymous namespace @@ -51,14 +60,16 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); } //===----------------------------------------------------------------------===// // Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0. -void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { +void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, const TargetLowering &TLI, + bool ZT0IsUndef = false) { auto &Ctx = M->getContext(); auto *TPIDR2SaveTy = FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false); auto Attrs = AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible"); + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_SAVE; FunctionCallee Callee = - M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs); + M->getOrInsertFunction(TLI.getLibcallName(LC), TPIDR2SaveTy, Attrs); CallInst *Call = Builder.CreateCall(Callee); // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark @@ -67,8 +78,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { if (ZT0IsUndef) Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef")); - Call->setCallingConv( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0); + Call->setCallingConv(TLI.getLibcallCallingConv(LC)); // A save to TPIDR2 should be followed by clearing TPIDR2_EL0. Function *WriteIntr = @@ -98,7 +108,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { /// interface if it does not share ZA or ZT0. /// bool SMEABI::updateNewStateFunctions(Module *M, Function *F, - IRBuilder<> &Builder, SMEAttrs FnAttrs) { + IRBuilder<> &Builder, SMEAttrs FnAttrs, + const TargetLowering &TLI) { LLVMContext &Context = F->getContext(); BasicBlock *OrigBB = &F->getEntryBlock(); Builder.SetInsertPoint(&OrigBB->front()); @@ -124,7 +135,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F, // Create a call __arm_tpidr2_save, which commits the lazy save. 
Builder.SetInsertPoint(&SaveBB->back()); - emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0()); + emitTPIDR2Save(M, Builder, TLI, /*ZT0IsUndef=*/FnAttrs.isNewZT0()); // Enable pstate.za at the start of the function. Builder.SetInsertPoint(&OrigBB->front()); @@ -172,10 +183,14 @@ bool SMEABI::runOnFunction(Function &F) { if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za")) return false; + const TargetMachine &TM = + getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); + const TargetLowering &TLI = *TM.getSubtargetImpl(F)->getTargetLowering(); + bool Changed = false; SMEAttrs FnAttrs(F); if (FnAttrs.isNewZA() || FnAttrs.isNewZT0()) - Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs); + Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs, TLI); return Changed; } diff --git a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp index bd28716..85cca1d 100644 --- a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp +++ b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp @@ -80,16 +80,10 @@ static bool isMatchingStartStopPair(const MachineInstr *MI1, if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask()) return false; - // This optimisation is unlikely to happen in practice for conditional - // smstart/smstop pairs as the virtual registers for pstate.sm will always - // be different. - // TODO: For this optimisation to apply to conditional smstart/smstop, - // this pass will need to do more work to remove redundant calls to - // __arm_sme_state. - // Only consider conditional start/stop pairs which read the same register - // holding the original value of pstate.sm, as some conditional start/stops - // require the state on entry to the function. + // holding the original value of pstate.sm. This is somewhat over conservative + // as all conditional streaming mode changes only look at the state on entry + // to the function. if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) { Register Reg1 = MI1->getOperand(3).getReg(); Register Reg2 = MI2->getOperand(3).getReg(); @@ -134,13 +128,6 @@ bool SMEPeepholeOpt::optimizeStartStopPairs( bool Changed = false; MachineInstr *Prev = nullptr; - SmallVector<MachineInstr *, 4> ToBeRemoved; - - // Convenience function to reset the matching of a sequence. - auto Reset = [&]() { - Prev = nullptr; - ToBeRemoved.clear(); - }; // Walk through instructions in the block trying to find pairs of smstart // and smstop nodes that cancel each other out. We only permit a limited @@ -162,14 +149,10 @@ bool SMEPeepholeOpt::optimizeStartStopPairs( // that we marked for deletion in between. Prev->eraseFromParent(); MI.eraseFromParent(); - for (MachineInstr *TBR : ToBeRemoved) - TBR->eraseFromParent(); - ToBeRemoved.clear(); Prev = nullptr; Changed = true; NumSMChangesRemoved += 2; } else { - Reset(); Prev = &MI; } continue; @@ -185,7 +168,7 @@ bool SMEPeepholeOpt::optimizeStartStopPairs( // of streaming mode. If not, the algorithm should reset. switch (MI.getOpcode()) { default: - Reset(); + Prev = nullptr; break; case AArch64::COALESCER_BARRIER_FPR16: case AArch64::COALESCER_BARRIER_FPR32: @@ -199,7 +182,7 @@ bool SMEPeepholeOpt::optimizeStartStopPairs( // concrete example/test-case. 
if (isSVERegOp(TRI, MRI, MI.getOperand(0)) || isSVERegOp(TRI, MRI, MI.getOperand(1))) - Reset(); + Prev = nullptr; break; case AArch64::ADJCALLSTACKDOWN: case AArch64::ADJCALLSTACKUP: @@ -207,12 +190,6 @@ bool SMEPeepholeOpt::optimizeStartStopPairs( case AArch64::ADDXri: // We permit these as they don't generate SVE/NEON instructions. break; - case AArch64::VGRestorePseudo: - case AArch64::VGSavePseudo: - // When the smstart/smstop are removed, we should also remove - // the pseudos that save/restore the VG value for CFI info. - ToBeRemoved.push_back(&MI); - break; case AArch64::MSRpstatesvcrImm1: case AArch64::MSRpstatePseudo: llvm_unreachable("Should have been handled"); diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index a0320f9..74e4a7f 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -315,10 +315,16 @@ def addsub_imm8_opt_lsl_i16 : imm8_opt_lsl<16, "uint16_t", SVEAddSubImmOperand16 def addsub_imm8_opt_lsl_i32 : imm8_opt_lsl<32, "uint32_t", SVEAddSubImmOperand32>; def addsub_imm8_opt_lsl_i64 : imm8_opt_lsl<64, "uint64_t", SVEAddSubImmOperand64>; -def SVEAddSubImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i8>", []>; -def SVEAddSubImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i16>", []>; -def SVEAddSubImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i32>", []>; -def SVEAddSubImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubImm<MVT::i64>", []>; +let Complexity = 1 in { +def SVEAddSubImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i8, false>", []>; +def SVEAddSubImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i16, false>", []>; +def SVEAddSubImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i32, false>", []>; +def SVEAddSubImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubImm<MVT::i64, false>", []>; + +def SVEAddSubNegImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i8, true>", []>; +def SVEAddSubNegImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i16, true>", []>; +def SVEAddSubNegImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i32, true>", []>; +def SVEAddSubNegImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubImm<MVT::i64, true>", []>; def SVEAddSubSSatNegImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i8, true>", []>; def SVEAddSubSSatNegImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i16, true>", []>; @@ -329,6 +335,7 @@ def SVEAddSubSSatPosImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MV def SVEAddSubSSatPosImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i16, false>", []>; def SVEAddSubSSatPosImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i32, false>", []>; def SVEAddSubSSatPosImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubSSatImm<MVT::i64, false>", []>; +} // Complexity = 1 def SVECpyDupImm8Pat : ComplexPattern<i32, 2, "SelectSVECpyDupImm<MVT::i8>", []>; def SVECpyDupImm16Pat : ComplexPattern<i32, 2, "SelectSVECpyDupImm<MVT::i16>", []>; @@ -809,6 +816,11 @@ let hasNoSchedulingInfo = 1 in { Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2, zprty:$Zs3), []> { let FalseLanes = flags; } + + class UnpredRegImmPseudo<ZPRRegOp zprty, Operand immty> + : SVEPseudo2Instr<NAME, 0>, + Pseudo<(outs zprty:$Zd), (ins zprty:$Zs, immty:$imm), []> { + } } // @@ -1885,13 +1897,14 @@ class sve_int_perm_extract_i<string asm> let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = 
DestructiveOther; + let DestructiveInstType = Destructive2xRegImmUnpred; let ElementSize = ElementSizeNone; let hasSideEffects = 0; } -multiclass sve_int_perm_extract_i<string asm, SDPatternOperator op> { - def NAME : sve_int_perm_extract_i<asm>; +multiclass sve_int_perm_extract_i<string asm, SDPatternOperator op, string Ps> { + def NAME : sve_int_perm_extract_i<asm>, + SVEPseudo2Instr<Ps, 1>; def : SVE_3_Op_Imm_Pat<nxv16i8, op, nxv16i8, nxv16i8, i32, imm0_255, !cast<Instruction>(NAME)>; @@ -5148,11 +5161,14 @@ multiclass sve_int_dup_imm<string asm> { (!cast<Instruction>(NAME # _D) ZPR64:$Zd, cpy_imm8_opt_lsl_i64:$imm), 1>; def : InstAlias<"fmov $Zd, #0.0", - (!cast<Instruction>(NAME # _H) ZPR16:$Zd, 0, 0), 1>; + (!cast<Instruction>(NAME # _H) ZPR16:$Zd, + (cpy_imm8_opt_lsl_i16 0, 0)), 1>; def : InstAlias<"fmov $Zd, #0.0", - (!cast<Instruction>(NAME # _S) ZPR32:$Zd, 0, 0), 1>; + (!cast<Instruction>(NAME # _S) ZPR32:$Zd, + (cpy_imm8_opt_lsl_i32 0, 0)), 1>; def : InstAlias<"fmov $Zd, #0.0", - (!cast<Instruction>(NAME # _D) ZPR64:$Zd, 0, 0), 1>; + (!cast<Instruction>(NAME # _D) ZPR64:$Zd, + (cpy_imm8_opt_lsl_i64 0, 0)), 1>; } class sve_int_dup_fpimm<bits<2> sz8_64, Operand fpimmtype, @@ -5212,7 +5228,8 @@ class sve_int_arith_imm0<bits<2> sz8_64, bits<3> opc, string asm, let hasSideEffects = 0; } -multiclass sve_int_arith_imm0<bits<3> opc, string asm, SDPatternOperator op> { +multiclass sve_int_arith_imm0<bits<3> opc, string asm, SDPatternOperator op, + SDPatternOperator inv_op = null_frag> { def _B : sve_int_arith_imm0<0b00, opc, asm, ZPR8, addsub_imm8_opt_lsl_i8>; def _H : sve_int_arith_imm0<0b01, opc, asm, ZPR16, addsub_imm8_opt_lsl_i16>; def _S : sve_int_arith_imm0<0b10, opc, asm, ZPR32, addsub_imm8_opt_lsl_i32>; @@ -5222,6 +5239,12 @@ multiclass sve_int_arith_imm0<bits<3> opc, string asm, SDPatternOperator op> { def : SVE_1_Op_Imm_OptLsl_Pat<nxv8i16, op, ZPR16, i32, SVEAddSubImm16Pat, !cast<Instruction>(NAME # _H)>; def : SVE_1_Op_Imm_OptLsl_Pat<nxv4i32, op, ZPR32, i32, SVEAddSubImm32Pat, !cast<Instruction>(NAME # _S)>; def : SVE_1_Op_Imm_OptLsl_Pat<nxv2i64, op, ZPR64, i64, SVEAddSubImm64Pat, !cast<Instruction>(NAME # _D)>; + + // Extra patterns for add(x, splat(-ve)) -> sub(x, +ve). There is no i8 + // pattern as all i8 constants can be handled by an add. 
+ def : SVE_1_Op_Imm_OptLsl_Pat<nxv8i16, inv_op, ZPR16, i32, SVEAddSubNegImm16Pat, !cast<Instruction>(NAME # _H)>; + def : SVE_1_Op_Imm_OptLsl_Pat<nxv4i32, inv_op, ZPR32, i32, SVEAddSubNegImm32Pat, !cast<Instruction>(NAME # _S)>; + def : SVE_1_Op_Imm_OptLsl_Pat<nxv2i64, inv_op, ZPR64, i64, SVEAddSubNegImm64Pat, !cast<Instruction>(NAME # _D)>; } multiclass sve_int_arith_imm0_ssat<bits<3> opc, string asm, SDPatternOperator op, @@ -5543,11 +5566,14 @@ multiclass sve_int_dup_imm_pred_merge<string asm, SDPatternOperator op> { nxv2i64, nxv2i1, i64, SVECpyDupImm64Pat>; def : InstAlias<"fmov $Zd, $Pg/m, #0.0", - (!cast<Instruction>(NAME # _H) ZPR16:$Zd, PPRAny:$Pg, 0, 0), 0>; + (!cast<Instruction>(NAME # _H) ZPR16:$Zd, PPRAny:$Pg, + (cpy_imm8_opt_lsl_i16 0, 0)), 0>; def : InstAlias<"fmov $Zd, $Pg/m, #0.0", - (!cast<Instruction>(NAME # _S) ZPR32:$Zd, PPRAny:$Pg, 0, 0), 0>; + (!cast<Instruction>(NAME # _S) ZPR32:$Zd, PPRAny:$Pg, + (cpy_imm8_opt_lsl_i32 0, 0)), 0>; def : InstAlias<"fmov $Zd, $Pg/m, #0.0", - (!cast<Instruction>(NAME # _D) ZPR64:$Zd, PPRAny:$Pg, 0, 0), 0>; + (!cast<Instruction>(NAME # _D) ZPR64:$Zd, PPRAny:$Pg, + (cpy_imm8_opt_lsl_i64 0, 0)), 0>; def : Pat<(vselect PPRAny:$Pg, (SVEDup0), (nxv8f16 ZPR:$Zd)), (!cast<Instruction>(NAME # _H) $Zd, $Pg, 0, 0)>; diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 271094f..dd6fa16 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -7,17 +7,14 @@ //===----------------------------------------------------------------------===// #include "AArch64SMEAttributes.h" +#include "AArch64ISelLowering.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/RuntimeLibcalls.h" #include <cassert> using namespace llvm; -void SMEAttrs::set(unsigned M, bool Enable) { - if (Enable) - Bitmask |= M; - else - Bitmask &= ~M; - +void SMEAttrs::validate() const { // Streaming Mode Attrs assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) && "SM_Enabled and SM_Compatible are mutually exclusive"); @@ -77,19 +74,36 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { Bitmask |= encodeZT0State(StateValue::New); } -void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) { +void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, + const AArch64TargetLowering &TLI) { + RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); + if (Impl == RTLIB::Unsupported) + return; unsigned KnownAttrs = SMEAttrs::Normal; - if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state") - KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine); - if (FuncName == "__arm_tpidr2_restore") + RTLIB::Libcall LC = RTLIB::RuntimeLibcallsInfo::getLibcallFromImpl(Impl); + switch (LC) { + case RTLIB::SMEABI_SME_STATE: + case RTLIB::SMEABI_TPIDR2_SAVE: + case RTLIB::SMEABI_GET_CURRENT_VG: + case RTLIB::SMEABI_SME_STATE_SIZE: + case RTLIB::SMEABI_SME_SAVE: + case RTLIB::SMEABI_SME_RESTORE: + KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; + break; + case RTLIB::SMEABI_ZA_DISABLE: + case RTLIB::SMEABI_TPIDR2_RESTORE: KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) | SMEAttrs::SME_ABI_Routine; - if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" || - FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr") + break; + case RTLIB::SC_MEMCPY: + case RTLIB::SC_MEMMOVE: + case RTLIB::SC_MEMSET: + case RTLIB::SC_MEMCHR: KnownAttrs |= 
SMEAttrs::SM_Compatible; - if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" || - FuncName == "__arm_sme_state_size") - KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; + break; + default: + break; + } set(KnownAttrs); } @@ -110,11 +124,11 @@ bool SMECallAttrs::requiresSMChange() const { return true; } -SMECallAttrs::SMECallAttrs(const CallBase &CB) +SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI) : CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal), Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) { if (auto *CalledFunction = CB.getCalledFunction()) - CalledFn = SMEAttrs(*CalledFunction, SMEAttrs::InferAttrsFromName::Yes); + CalledFn = SMEAttrs(*CalledFunction, TLI); // FIXME: We probably should not allow SME attributes on direct calls but // clang duplicates streaming mode attributes at each callsite. diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index f1be0ecb..d26e3cd 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -13,6 +13,8 @@ namespace llvm { +class AArch64TargetLowering; + class Function; class CallBase; class AttributeList; @@ -48,19 +50,27 @@ public: CallSiteFlags_Mask = ZT0_Undef }; - enum class InferAttrsFromName { No, Yes }; - SMEAttrs() = default; SMEAttrs(unsigned Mask) { set(Mask); } - SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No) + SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr) : SMEAttrs(F.getAttributes()) { - if (Infer == InferAttrsFromName::Yes) - addKnownFunctionAttrs(F.getName()); + if (TLI) + addKnownFunctionAttrs(F.getName(), *TLI); } SMEAttrs(const AttributeList &L); - SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); }; + SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) { + addKnownFunctionAttrs(FuncName, TLI); + }; - void set(unsigned M, bool Enable = true); + void set(unsigned M, bool Enable = true) { + if (Enable) + Bitmask |= M; + else + Bitmask &= ~M; +#ifndef NDEBUG + validate(); +#endif + } // Interfaces to query PSTATE.SM bool hasStreamingBody() const { return Bitmask & SM_Body; } @@ -146,7 +156,9 @@ public: } private: - void addKnownFunctionAttrs(StringRef FuncName); + void addKnownFunctionAttrs(StringRef FuncName, + const AArch64TargetLowering &TLI); + void validate() const; }; /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has @@ -163,7 +175,7 @@ public: SMEAttrs Callsite = SMEAttrs::Normal) : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} - SMECallAttrs(const CallBase &CB); + SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI); SMEAttrs &caller() { return CallerFn; } SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; } @@ -194,7 +206,7 @@ public: } bool requiresEnablingZAAfterCall() const { - return requiresLazySave() || requiresDisablingZABeforeCall(); + return requiresDisablingZABeforeCall(); } bool requiresPreservingAllZAState() const { |