Diffstat (limited to 'llvm/lib/Target/AArch64')
79 files changed, 10221 insertions, 1673 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h index 8d0ff41..a8e15c3 100644 --- a/llvm/lib/Target/AArch64/AArch64.h +++ b/llvm/lib/Target/AArch64/AArch64.h @@ -26,11 +26,14 @@ namespace llvm { class AArch64RegisterBankInfo; class AArch64Subtarget; class AArch64TargetMachine; +enum class CodeGenOptLevel; class FunctionPass; class InstructionSelector; +class ModulePass; FunctionPass *createAArch64DeadRegisterDefinitions(); FunctionPass *createAArch64RedundantCopyEliminationPass(); +FunctionPass *createAArch64RedundantCondBranchPass(); FunctionPass *createAArch64CondBrTuning(); FunctionPass *createAArch64CompressJumpTablesPass(); FunctionPass *createAArch64ConditionalCompares(); @@ -60,7 +63,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass(); FunctionPass *createAArch64CollectLOHPass(); FunctionPass *createSMEABIPass(); FunctionPass *createSMEPeepholeOptPass(); -FunctionPass *createMachineSMEABIPass(); +FunctionPass *createMachineSMEABIPass(CodeGenOptLevel); ModulePass *createSVEIntrinsicOptsPass(); InstructionSelector * createAArch64InstructionSelector(const AArch64TargetMachine &, @@ -101,6 +104,7 @@ void initializeAArch64PostSelectOptimizePass(PassRegistry &); void initializeAArch64PreLegalizerCombinerPass(PassRegistry &); void initializeAArch64PromoteConstantPass(PassRegistry&); void initializeAArch64RedundantCopyEliminationPass(PassRegistry&); +void initializeAArch64RedundantCondBranchPass(PassRegistry &); void initializeAArch64SIMDInstrOptPass(PassRegistry &); void initializeAArch64SLSHardeningPass(PassRegistry &); void initializeAArch64SpeculationHardeningPass(PassRegistry &); diff --git a/llvm/lib/Target/AArch64/AArch64.td b/llvm/lib/Target/AArch64/AArch64.td index a4529a5..1a4367b8 100644 --- a/llvm/lib/Target/AArch64/AArch64.td +++ b/llvm/lib/Target/AArch64/AArch64.td @@ -40,6 +40,8 @@ include "AArch64SchedPredExynos.td" include "AArch64SchedPredNeoverse.td" include "AArch64Combine.td" +defm : RemapAllTargetPseudoPointerOperands<GPR64sp>; + def AArch64InstrInfo : InstrInfo; //===----------------------------------------------------------------------===// @@ -133,6 +135,8 @@ include "AArch64SchedNeoverseN2.td" include "AArch64SchedNeoverseN3.td" include "AArch64SchedNeoverseV1.td" include "AArch64SchedNeoverseV2.td" +include "AArch64SchedNeoverseV3.td" +include "AArch64SchedNeoverseV3AE.td" include "AArch64SchedOryon.td" include "AArch64Processors.td" diff --git a/llvm/lib/Target/AArch64/AArch64A53Fix835769.cpp b/llvm/lib/Target/AArch64/AArch64A53Fix835769.cpp index a51f630..407714a 100644 --- a/llvm/lib/Target/AArch64/AArch64A53Fix835769.cpp +++ b/llvm/lib/Target/AArch64/AArch64A53Fix835769.cpp @@ -178,11 +178,10 @@ static void insertNopBeforeInstruction(MachineBasicBlock &MBB, MachineInstr* MI, MachineInstr *I = getLastNonPseudo(MBB, TII); assert(I && "Expected instruction"); DebugLoc DL = I->getDebugLoc(); - BuildMI(I->getParent(), DL, TII->get(AArch64::HINT)).addImm(0); - } - else { + BuildMI(I->getParent(), DL, TII->get(AArch64::NOP)); + } else { DebugLoc DL = MI->getDebugLoc(); - BuildMI(MBB, MI, DL, TII->get(AArch64::HINT)).addImm(0); + BuildMI(MBB, MI, DL, TII->get(AArch64::NOP)); } ++NumNopsAdded; diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp index 1169f26..d0c4b1b 100644 --- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp @@ -655,25 +655,22 @@ Function 
*AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) { BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit); IRBuilder<> B(BB); - // Load the global symbol as a pointer to the check function. - Value *GuardFn; - if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf")) - GuardFn = GuardFnCFGlobal; - else - GuardFn = GuardFnGlobal; - LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn); - - // Create new call instruction. The CFGuard check should always be a call, - // even if the original CallBase is an Invoke or CallBr instruction. + // Create new call instruction. The call check should always be a call, + // even if the original CallBase is an Invoke or CallBr instruction. + // This is treated as a direct call, so do not use GuardFnCFGlobal. + LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFnGlobal); Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes()); CallInst *GuardCheck = B.CreateCall( GuardFnType, GuardCheckLoad, {F, Thunk}); + Value *GuardCheckDest = B.CreateExtractValue(GuardCheck, 0); + Value *GuardFinalDest = B.CreateExtractValue(GuardCheck, 1); // Ensure that the first argument is passed in the correct register. GuardCheck->setCallingConv(CallingConv::CFGuard_Check); SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args())); - CallInst *Call = B.CreateCall(Arm64Ty, GuardCheck, Args); + OperandBundleDef OB("cfguardtarget", GuardFinalDest); + CallInst *Call = B.CreateCall(Arm64Ty, GuardCheckDest, Args, OB); Call->setTailCallKind(llvm::CallInst::TCK_MustTail); if (Call->getType()->isVoidTy()) @@ -773,11 +770,21 @@ void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) { CallInst *GuardCheck = B.CreateCall(GuardFnType, GuardCheckLoad, {CalledOperand, Thunk}, Bundles); + Value *GuardCheckDest = B.CreateExtractValue(GuardCheck, 0); + Value *GuardFinalDest = B.CreateExtractValue(GuardCheck, 1); // Ensure that the first argument is passed in the correct register.
GuardCheck->setCallingConv(CallingConv::CFGuard_Check); - CB->setCalledOperand(GuardCheck); + // Update the call: set the callee, and add a bundle with the final + // destination. + CB->setCalledOperand(GuardCheckDest); + OperandBundleDef OB("cfguardtarget", GuardFinalDest); + auto *NewCall = CallBase::addOperandBundle(CB, LLVMContext::OB_cfguardtarget, + OB, CB->getIterator()); + NewCall->copyMetadata(*CB); + CB->replaceAllUsesWith(NewCall); + CB->eraseFromParent(); } bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { @@ -795,7 +802,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { I64Ty = Type::getInt64Ty(M->getContext()); VoidTy = Type::getVoidTy(M->getContext()); - GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false); + GuardFnType = + FunctionType::get(StructType::get(PtrTy, PtrTy), {PtrTy, PtrTy}, false); DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false); GuardFnCFGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", PtrTy); GuardFnGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall", PtrTy); diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index c31a090..57431616 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -49,12 +49,14 @@ #include "llvm/IR/Module.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCContext.h" +#include "llvm/MC/MCExpr.h" #include "llvm/MC/MCInst.h" #include "llvm/MC/MCInstBuilder.h" #include "llvm/MC/MCSectionELF.h" #include "llvm/MC/MCSectionMachO.h" #include "llvm/MC/MCStreamer.h" #include "llvm/MC/MCSymbol.h" +#include "llvm/MC/MCValue.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -95,6 +97,7 @@ class AArch64AsmPrinter : public AsmPrinter { bool EnableImportCallOptimization = false; DenseMap<MCSection *, std::vector<std::pair<MCSymbol *, MCSymbol *>>> SectionToImportedFunctionCalls; + unsigned PAuthIFuncNextUniqueID = 1; public: static char ID; @@ -162,8 +165,7 @@ public: Register ScratchReg, AArch64PACKey::ID Key, AArch64PAuth::AuthCheckMethod Method, - bool ShouldTrap, - const MCSymbol *OnFailure); + const MCSymbol *OnFailure = nullptr); // Check authenticated LR before tail calling. void emitPtrauthTailCallHardening(const MachineInstr *TC); @@ -174,7 +176,12 @@ public: const MachineOperand *AUTAddrDisc, Register Scratch, std::optional<AArch64PACKey::ID> PACKey, - uint64_t PACDisc, Register PACAddrDisc); + uint64_t PACDisc, Register PACAddrDisc, Value *DS); + + // Emit R_AARCH64_PATCHINST, the deactivation symbol relocation. Returns true + // if no instruction should be emitted because the deactivation symbol is + // defined in the current module so this function emitted a NOP instead. + bool emitDeactivationSymbolRelocation(Value *DS); // Emit the sequence for PAC. void emitPtrauthSign(const MachineInstr *MI); @@ -212,6 +219,10 @@ public: // authenticating) void LowerLOADgotAUTH(const MachineInstr &MI); + const MCExpr *emitPAuthRelocationAsIRelative( + const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr); + /// tblgen'erated driver function for lowering simple MI->MC /// pseudo instructions.
bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst); @@ -461,7 +472,7 @@ void AArch64AsmPrinter::emitSled(const MachineInstr &MI, SledKind Kind) { EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::B).addImm(8)); for (int8_t I = 0; I < NoopsInSledCount; I++) - EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::HINT).addImm(0)); + EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::NOP)); OutStreamer->emitLabel(Target); recordSled(CurSled, MI, Kind, 2); @@ -1266,9 +1277,7 @@ void AArch64AsmPrinter::PrintDebugValueComment(const MachineInstr *MI, // Frame address. Currently handles register +- offset only. assert(MI->isIndirectDebugValue()); OS << '['; - for (unsigned I = 0, E = std::distance(MI->debug_operands().begin(), - MI->debug_operands().end()); - I < E; ++I) { + for (unsigned I = 0, E = llvm::size(MI->debug_operands()); I < E; ++I) { if (I != 0) OS << ", "; printOperand(MI, I, OS); @@ -1688,7 +1697,7 @@ void AArch64AsmPrinter::LowerSTACKMAP(MCStreamer &OutStreamer, StackMaps &SM, // Emit nops. for (unsigned i = 0; i < NumNOPBytes; i += 4) - EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::HINT).addImm(0)); + EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::NOP)); } // Lower a patchpoint of the form: @@ -1722,7 +1731,7 @@ void AArch64AsmPrinter::LowerPATCHPOINT(MCStreamer &OutStreamer, StackMaps &SM, assert((NumBytes - EncodedBytes) % 4 == 0 && "Invalid number of NOP bytes requested!"); for (unsigned i = EncodedBytes; i < NumBytes; i += 4) - EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::HINT).addImm(0)); + EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::NOP)); } void AArch64AsmPrinter::LowerSTATEPOINT(MCStreamer &OutStreamer, StackMaps &SM, @@ -1731,7 +1740,7 @@ void AArch64AsmPrinter::LowerSTATEPOINT(MCStreamer &OutStreamer, StackMaps &SM, if (unsigned PatchBytes = SOpers.getNumPatchBytes()) { assert(PatchBytes % 4 == 0 && "Invalid number of NOP bytes requested!"); for (unsigned i = 0; i < PatchBytes; i += 4) - EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::HINT).addImm(0)); + EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::NOP)); } else { // Lower call target and choose correct opcode const MachineOperand &CallTarget = SOpers.getCallTarget(); @@ -1939,14 +1948,19 @@ Register AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc, return ScratchReg; } -/// Emits a code sequence to check an authenticated pointer value. +/// Emit a code sequence to check an authenticated pointer value. +/// +/// This function emits a sequence of instructions that checks if TestedReg was +/// authenticated successfully. On success, execution continues at the next +/// instruction after the sequence. /// -/// If OnFailure argument is passed, jump there on check failure instead -/// of proceeding to the next instruction (only if ShouldTrap is false). 
+/// The action performed on failure depends on the OnFailure argument: +/// * if OnFailure is not nullptr, control is transferred to that label after +/// clearing the PAC field +/// * otherwise, BRK instruction is emitted to generate an error void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue( Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key, - AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap, - const MCSymbol *OnFailure) { + AArch64PAuth::AuthCheckMethod Method, const MCSymbol *OnFailure) { // Insert a sequence to check if authentication of TestedReg succeeded, // such as: // @@ -1983,7 +1997,7 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue( .addReg(getWRegFromXReg(ScratchReg)) .addReg(TestedReg) .addImm(0)); - assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error"); + assert(!OnFailure && "DummyLoad always traps on error"); return; } @@ -2037,15 +2051,14 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue( llvm_unreachable("Unsupported check method"); } - if (ShouldTrap) { - assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap"); + if (!OnFailure) { // Trapping sequences do a 'brk'. // brk #<0xc470 + aut key> EmitToStreamer(MCInstBuilder(AArch64::BRK).addImm(0xc470 | Key)); } else { // Non-trapping checked sequences return the stripped result in TestedReg, - // skipping over success-only code (such as re-signing the pointer) if - // there is one. + // skipping over success-only code (such as re-signing the pointer) by + // jumping to OnFailure label. // Note that this can introduce an authentication oracle (such as based on // the high bits of the re-signed value). @@ -2070,12 +2083,9 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue( MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg)); } - if (OnFailure) { - // b Lend - EmitToStreamer( - MCInstBuilder(AArch64::B) - .addExpr(MCSymbolRefExpr::create(OnFailure, OutContext))); - } + // b Lend + const auto *OnFailureExpr = MCSymbolRefExpr::create(OnFailure, OutContext); + EmitToStreamer(MCInstBuilder(AArch64::B).addExpr(OnFailureExpr)); } // If the auth check succeeds, we can continue. @@ -2102,16 +2112,35 @@ void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) { "Neither x16 nor x17 is available as a scratch register"); AArch64PACKey::ID Key = AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA; - emitPtrauthCheckAuthenticatedValue( - AArch64::LR, ScratchReg, Key, LRCheckMethod, - /*ShouldTrap=*/true, /*OnFailure=*/nullptr); + emitPtrauthCheckAuthenticatedValue(AArch64::LR, ScratchReg, Key, + LRCheckMethod); +} + +bool AArch64AsmPrinter::emitDeactivationSymbolRelocation(Value *DS) { + if (!DS) + return false; + + if (isa<GlobalAlias>(DS)) { + // Just emit the nop directly. 
+ EmitToStreamer(MCInstBuilder(AArch64::NOP)); + return true; + } + MCSymbol *Dot = OutContext.createTempSymbol(); + OutStreamer->emitLabel(Dot); + const MCExpr *DeactDotExpr = MCSymbolRefExpr::create(Dot, OutContext); + + const MCExpr *DSExpr = MCSymbolRefExpr::create( + OutContext.getOrCreateSymbol(DS->getName()), OutContext); + OutStreamer->emitRelocDirective(*DeactDotExpr, "R_AARCH64_PATCHINST", DSExpr, + SMLoc()); + return false; } void AArch64AsmPrinter::emitPtrauthAuthResign( Register AUTVal, AArch64PACKey::ID AUTKey, uint64_t AUTDisc, const MachineOperand *AUTAddrDisc, Register Scratch, std::optional<AArch64PACKey::ID> PACKey, uint64_t PACDisc, - Register PACAddrDisc) { + Register PACAddrDisc, Value *DS) { const bool IsAUTPAC = PACKey.has_value(); // We expand AUT/AUTPAC into a sequence of the form @@ -2158,15 +2187,17 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( bool AUTZero = AUTDiscReg == AArch64::XZR; unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero); - // autiza x16 ; if AUTZero - // autia x16, x17 ; if !AUTZero - MCInst AUTInst; - AUTInst.setOpcode(AUTOpc); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - if (!AUTZero) - AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); - EmitToStreamer(*OutStreamer, AUTInst); + if (!emitDeactivationSymbolRelocation(DS)) { + // autiza x16 ; if AUTZero + // autia x16, x17 ; if !AUTZero + MCInst AUTInst; + AUTInst.setOpcode(AUTOpc); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + if (!AUTZero) + AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); + EmitToStreamer(*OutStreamer, AUTInst); + } // Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done. if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) @@ -2178,9 +2209,8 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( if (IsAUTPAC && !ShouldTrap) EndSym = createTempSymbol("resign_end_"); - emitPtrauthCheckAuthenticatedValue(AUTVal, Scratch, AUTKey, - AArch64PAuth::AuthCheckMethod::XPAC, - ShouldTrap, EndSym); + emitPtrauthCheckAuthenticatedValue( + AUTVal, Scratch, AUTKey, AArch64PAuth::AuthCheckMethod::XPAC, EndSym); } // We already emitted unchecked and checked-but-non-trapping AUTs. 
@@ -2231,6 +2261,9 @@ void AArch64AsmPrinter::emitPtrauthSign(const MachineInstr *MI) { bool IsZeroDisc = DiscReg == AArch64::XZR; unsigned Opc = getPACOpcodeForKey(Key, IsZeroDisc); + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + // paciza x16 ; if IsZeroDisc // pacia x16, x17 ; if !IsZeroDisc MCInst PACInst; @@ -2303,6 +2336,214 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) { EmitToStreamer(*OutStreamer, BRInst); } +static void emitAddress(MCStreamer &Streamer, MCRegister Reg, + const MCExpr *Expr, bool DSOLocal, + const MCSubtargetInfo &STI) { + MCValue Val; + if (!Expr->evaluateAsRelocatable(Val, nullptr)) + report_fatal_error("emitAddress could not evaluate"); + if (DSOLocal) { + Streamer.emitInstruction( + MCInstBuilder(AArch64::ADRP) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(Expr, AArch64::S_ABS_PAGE, + Streamer.getContext())), + STI); + Streamer.emitInstruction( + MCInstBuilder(AArch64::ADDXri) + .addReg(Reg) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(Expr, AArch64::S_LO12, + Streamer.getContext())) + .addImm(0), + STI); + } else { + auto *SymRef = + MCSymbolRefExpr::create(Val.getAddSym(), Streamer.getContext()); + Streamer.emitInstruction( + MCInstBuilder(AArch64::ADRP) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(SymRef, AArch64::S_GOT_PAGE, + Streamer.getContext())), + STI); + Streamer.emitInstruction( + MCInstBuilder(AArch64::LDRXui) + .addReg(Reg) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(SymRef, AArch64::S_GOT_LO12, + Streamer.getContext())), + STI); + if (Val.getConstant()) + Streamer.emitInstruction(MCInstBuilder(AArch64::ADDXri) + .addReg(Reg) + .addReg(Reg) + .addImm(Val.getConstant()) + .addImm(0), + STI); + } +} + +static bool targetSupportsPAuthRelocation(const Triple &TT, + const MCExpr *Target, + const MCExpr *DSExpr) { + // No released version of glibc supports PAuth relocations. + if (TT.isOSGlibc()) + return false; + + // We emit PAuth constants as IRELATIVE relocations in cases where the + // constant cannot be represented as a PAuth relocation: + // 1) There is a deactivation symbol. + // 2) The signed value is not a symbol. + return !DSExpr && !isa<MCConstantExpr>(Target); +} + +static bool targetSupportsIRelativeRelocation(const Triple &TT) { + // IFUNCs are ELF-only. + if (!TT.isOSBinFormatELF()) + return false; + + // musl doesn't support IFUNCs. + if (TT.isMusl()) + return false; + + return true; +} + +// Emit an ifunc resolver that returns a signed pointer to the specified target, +// and return a FUNCINIT reference to the resolver. In the linked binary, this +// function becomes the target of an IRELATIVE relocation. This resolver is used +// to relocate signed pointers in global variable initializers in special cases +// where the standard R_AARCH64_AUTH_ABS64 relocation would not work. 
+// +// Example (signed null pointer, not address discriminated): +// +// .8byte .Lpauth_ifunc0 +// .pushsection .text.startup,"ax",@progbits +// .Lpauth_ifunc0: +// mov x0, #0 +// mov x1, #12345 +// b __emupac_pacda +// +// Example (signed null pointer, address discriminated): +// +// .Ltmp: +// .8byte .Lpauth_ifunc0 +// .pushsection .text.startup,"ax",@progbits +// .Lpauth_ifunc0: +// mov x0, #0 +// adrp x1, .Ltmp +// add x1, x1, :lo12:.Ltmp +// b __emupac_pacda +// .popsection +// +// Example (signed pointer to symbol, not address discriminated): +// +// .Ltmp: +// .8byte .Lpauth_ifunc0 +// .pushsection .text.startup,"ax",@progbits +// .Lpauth_ifunc0: +// adrp x0, symbol +// add x0, x0, :lo12:symbol +// mov x1, #12345 +// b __emupac_pacda +// .popsection +// +// Example (signed null pointer, not address discriminated, with deactivation +// symbol ds): +// +// .8byte .Lpauth_ifunc0 +// .pushsection .text.startup,"ax",@progbits +// .Lpauth_ifunc0: +// mov x0, #0 +// mov x1, #12345 +// .reloc ., R_AARCH64_PATCHINST, ds +// b __emupac_pacda +// ret +// .popsection +const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( + const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) { + const Triple &TT = TM.getTargetTriple(); + + // We only emit an IRELATIVE relocation if the target supports IRELATIVE and + // does not support the kind of PAuth relocation that we are trying to emit. + if (targetSupportsPAuthRelocation(TT, Target, DSExpr) || + !targetSupportsIRelativeRelocation(TT)) + return nullptr; + + // For now, only the DA key is supported. + if (KeyID != AArch64PACKey::DA) + return nullptr; + + std::unique_ptr<MCSubtargetInfo> STI( + TM.getTarget().createMCSubtargetInfo(TT, "", "")); + assert(STI && "Unable to create subtarget info"); + this->STI = static_cast<const AArch64Subtarget *>(&*STI); + + MCSymbol *Place = OutStreamer->getContext().createTempSymbol(); + OutStreamer->emitLabel(Place); + OutStreamer->pushSection(); + + OutStreamer->switchSection(OutStreamer->getContext().getELFSection( + ".text.startup", ELF::SHT_PROGBITS, ELF::SHF_ALLOC | ELF::SHF_EXECINSTR, + 0, "", true, PAuthIFuncNextUniqueID++, nullptr)); + + MCSymbol *IRelativeSym = + OutStreamer->getContext().createLinkerPrivateSymbol("pauth_ifunc"); + OutStreamer->emitLabel(IRelativeSym); + if (isa<MCConstantExpr>(Target)) { + OutStreamer->emitInstruction(MCInstBuilder(AArch64::MOVZXi) + .addReg(AArch64::X0) + .addExpr(Target) + .addImm(0), + *STI); + } else { + emitAddress(*OutStreamer, AArch64::X0, Target, IsDSOLocal, *STI); + } + if (HasAddressDiversity) { + auto *PlacePlusDisc = MCBinaryExpr::createAdd( + MCSymbolRefExpr::create(Place, OutStreamer->getContext()), + MCConstantExpr::create(static_cast<int16_t>(Disc), + OutStreamer->getContext()), + OutStreamer->getContext()); + emitAddress(*OutStreamer, AArch64::X1, PlacePlusDisc, /*IsDSOLocal=*/true, + *STI); + } else { + emitMOVZ(AArch64::X1, Disc, 0); + } + + if (DSExpr) { + MCSymbol *PrePACInst = OutStreamer->getContext().createTempSymbol(); + OutStreamer->emitLabel(PrePACInst); + + auto *PrePACInstExpr = + MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext()); + OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_PATCHINST", + DSExpr, SMLoc()); + } + + // We don't know the subtarget because this is being emitted for a global + // initializer. 
Because the performance of IFUNC resolvers is unimportant, we + // always call the EmuPAC runtime, which will end up using the PAC instruction + // if the target supports PAC. + MCSymbol *EmuPAC = + OutStreamer->getContext().getOrCreateSymbol("__emupac_pacda"); + const MCSymbolRefExpr *EmuPACRef = + MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext()); + OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef), + *STI); + + // We need a RET despite the above tail call because the deactivation symbol + // may replace the tail call with a NOP. + if (DSExpr) + OutStreamer->emitInstruction( + MCInstBuilder(AArch64::RET).addReg(AArch64::LR), *STI); + OutStreamer->popSection(); + + return MCSymbolRefExpr::create(IRelativeSym, AArch64::S_FUNCINIT, + OutStreamer->getContext()); +} + const MCExpr * AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { MCContext &Ctx = OutContext; @@ -2314,22 +2555,26 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { auto *BaseGVB = dyn_cast<GlobalValue>(BaseGV); - // If we can't understand the referenced ConstantExpr, there's nothing - // else we can do: emit an error. - if (!BaseGVB) { - BaseGV->getContext().emitError( - "cannot resolve target base/addend of ptrauth constant"); - return nullptr; + const MCExpr *Sym; + if (BaseGVB) { + // If there is an addend, turn that into the appropriate MCExpr. + Sym = MCSymbolRefExpr::create(getSymbol(BaseGVB), Ctx); + if (Offset.sgt(0)) + Sym = MCBinaryExpr::createAdd( + Sym, MCConstantExpr::create(Offset.getSExtValue(), Ctx), Ctx); + else if (Offset.slt(0)) + Sym = MCBinaryExpr::createSub( + Sym, MCConstantExpr::create((-Offset).getSExtValue(), Ctx), Ctx); + } else { + Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx); } - // If there is an addend, turn that into the appropriate MCExpr. - const MCExpr *Sym = MCSymbolRefExpr::create(getSymbol(BaseGVB), Ctx); - if (Offset.sgt(0)) - Sym = MCBinaryExpr::createAdd( - Sym, MCConstantExpr::create(Offset.getSExtValue(), Ctx), Ctx); - else if (Offset.slt(0)) - Sym = MCBinaryExpr::createSub( - Sym, MCConstantExpr::create((-Offset).getSExtValue(), Ctx), Ctx); + const MCExpr *DSExpr = nullptr; + if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) { + if (isa<GlobalAlias>(DS)) + return Sym; + DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx); + } uint64_t KeyID = CPA.getKey()->getZExtValue(); // We later rely on valid KeyID value in AArch64PACKeyIDToString call from @@ -2348,6 +2593,16 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { Disc = 0; } + // Check if we need to represent this with an IRELATIVE and emit it if so. + if (auto *IFuncSym = emitPAuthRelocationAsIRelative( + Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), + BaseGVB && BaseGVB->isDSOLocal(), DSExpr)) + return IFuncSym; + + if (DSExpr) + report_fatal_error("deactivation symbols unsupported in constant " + "expressions on this target"); + // Finally build the complete @AUTH expr. 
return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), Ctx); @@ -2519,9 +2774,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) { : AArch64PACKey::DA); emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AuthKey, - AArch64PAuth::AuthCheckMethod::XPAC, - /*ShouldTrap=*/true, - /*OnFailure=*/nullptr); + AArch64PAuth::AuthCheckMethod::XPAC); } } else { EmitToStreamer(MCInstBuilder(AArch64::LDRXui) @@ -2654,9 +2907,7 @@ void AArch64AsmPrinter::LowerLOADgotAUTH(const MachineInstr &MI) { (AuthOpcode == AArch64::AUTIA ? AArch64PACKey::IA : AArch64PACKey::DA); emitPtrauthCheckAuthenticatedValue(AuthResultReg, AArch64::X17, AuthKey, - AArch64PAuth::AuthCheckMethod::XPAC, - /*ShouldTrap=*/true, - /*OnFailure=*/nullptr); + AArch64PAuth::AuthCheckMethod::XPAC); emitMovXReg(DstReg, AuthResultReg); } @@ -2677,23 +2928,32 @@ AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) { void AArch64AsmPrinter::emitCBPseudoExpansion(const MachineInstr *MI) { bool IsImm = false; - bool Is32Bit = false; + unsigned Width = 0; switch (MI->getOpcode()) { default: llvm_unreachable("This is not a CB pseudo instruction"); + case AArch64::CBBAssertExt: + IsImm = false; + Width = 8; + break; + case AArch64::CBHAssertExt: + IsImm = false; + Width = 16; + break; case AArch64::CBWPrr: - Is32Bit = true; + Width = 32; break; case AArch64::CBXPrr: - Is32Bit = false; + Width = 64; break; case AArch64::CBWPri: IsImm = true; - Is32Bit = true; + Width = 32; break; case AArch64::CBXPri: IsImm = true; + Width = 64; break; } @@ -2703,61 +2963,61 @@ void AArch64AsmPrinter::emitCBPseudoExpansion(const MachineInstr *MI) { bool NeedsImmDec = false; bool NeedsImmInc = false; +#define GET_CB_OPC(IsImm, Width, ImmCond, RegCond) \ + (IsImm \ + ? (Width == 32 ? AArch64::CB##ImmCond##Wri : AArch64::CB##ImmCond##Xri) \ + : (Width == 8 \ + ? AArch64::CBB##RegCond##Wrr \ + : (Width == 16 ? AArch64::CBH##RegCond##Wrr \ + : (Width == 32 ? AArch64::CB##RegCond##Wrr \ + : AArch64::CB##RegCond##Xrr)))) + unsigned MCOpC; + // Decide if we need to either swap register operands or increment/decrement // immediate operands - unsigned MCOpC; switch (CC) { default: llvm_unreachable("Invalid CB condition code"); case AArch64CC::EQ: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBEQWri : AArch64::CBEQXri) - : (Is32Bit ? AArch64::CBEQWrr : AArch64::CBEQXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ EQ, /* Reg-Reg */ EQ); break; case AArch64CC::NE: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBNEWri : AArch64::CBNEXri) - : (Is32Bit ? AArch64::CBNEWrr : AArch64::CBNEXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ NE, /* Reg-Reg */ NE); break; case AArch64CC::HS: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri) - : (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ HI, /* Reg-Reg */ HS); NeedsImmDec = IsImm; break; case AArch64CC::LO: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri) - : (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ LO, /* Reg-Reg */ HI); NeedsRegSwap = !IsImm; break; case AArch64CC::HI: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri) - : (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ HI, /* Reg-Reg */ HI); break; case AArch64CC::LS: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri) - : (Is32Bit ? 
AArch64::CBHSWrr : AArch64::CBHSXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ LO, /* Reg-Reg */ HS); NeedsRegSwap = !IsImm; NeedsImmInc = IsImm; break; case AArch64CC::GE: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri) - : (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ GT, /* Reg-Reg */ GE); NeedsImmDec = IsImm; break; case AArch64CC::LT: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri) - : (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ LT, /* Reg-Reg */ GT); NeedsRegSwap = !IsImm; break; case AArch64CC::GT: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri) - : (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ GT, /* Reg-Reg */ GT); break; case AArch64CC::LE: - MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri) - : (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr); + MCOpC = GET_CB_OPC(IsImm, Width, /* Reg-Imm */ LT, /* Reg-Reg */ GE); NeedsRegSwap = !IsImm; NeedsImmInc = IsImm; break; } +#undef GET_CB_OPC MCInst Inst; Inst.setOpcode(MCOpC); @@ -2947,17 +3207,18 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { } case AArch64::AUTx16x17: - emitPtrauthAuthResign(AArch64::X16, - (AArch64PACKey::ID)MI->getOperand(0).getImm(), - MI->getOperand(1).getImm(), &MI->getOperand(2), - AArch64::X17, std::nullopt, 0, 0); + emitPtrauthAuthResign( + AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), + MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, + std::nullopt, 0, 0, MI->getDeactivationSymbol()); return; case AArch64::AUTxMxN: emitPtrauthAuthResign(MI->getOperand(0).getReg(), (AArch64PACKey::ID)MI->getOperand(3).getImm(), MI->getOperand(4).getImm(), &MI->getOperand(5), - MI->getOperand(1).getReg(), std::nullopt, 0, 0); + MI->getOperand(1).getReg(), std::nullopt, 0, 0, + MI->getDeactivationSymbol()); return; case AArch64::AUTPAC: @@ -2965,7 +3226,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, (AArch64PACKey::ID)MI->getOperand(3).getImm(), - MI->getOperand(4).getImm(), MI->getOperand(5).getReg()); + MI->getOperand(4).getImm(), MI->getOperand(5).getReg(), + MI->getDeactivationSymbol()); return; case AArch64::PAC: @@ -3364,6 +3626,22 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { TS->emitARM64WinCFIPACSignLR(); return; + case AArch64::SEH_SaveAnyRegI: + assert(MI->getOperand(1).getImm() <= 1008 && + "SaveAnyRegQP SEH opcode offset must fit into 6 bits"); + TS->emitARM64WinCFISaveAnyRegI(MI->getOperand(0).getImm(), + MI->getOperand(1).getImm()); + return; + + case AArch64::SEH_SaveAnyRegIP: + assert(MI->getOperand(1).getImm() - MI->getOperand(0).getImm() == 1 && + "Non-consecutive registers not allowed for save_any_reg"); + assert(MI->getOperand(2).getImm() <= 1008 && + "SaveAnyRegQP SEH opcode offset must fit into 6 bits"); + TS->emitARM64WinCFISaveAnyRegIP(MI->getOperand(0).getImm(), + MI->getOperand(2).getImm()); + return; + case AArch64::SEH_SaveAnyRegQP: assert(MI->getOperand(1).getImm() - MI->getOperand(0).getImm() == 1 && "Non-consecutive registers not allowed for save_any_reg"); @@ -3422,12 +3700,17 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { } case AArch64::CBWPri: case AArch64::CBXPri: + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: 
case AArch64::CBWPrr: case AArch64::CBXPrr: emitCBPseudoExpansion(MI); return; } + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + // Finally, do the automated lowerings for everything else. MCInst TmpInst; MCInstLowering.Lower(MI, TmpInst); diff --git a/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp b/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp index 137ff89..57934ae 100644 --- a/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp +++ b/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp @@ -18,6 +18,7 @@ #include "AArch64MachineFunctionInfo.h" #include "AArch64Subtarget.h" +#include "Utils/AArch64BaseInfo.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" @@ -47,6 +48,8 @@ public: StringRef getPassName() const override { return AARCH64_BRANCH_TARGETS_NAME; } private: + const AArch64Subtarget *Subtarget; + void addBTI(MachineBasicBlock &MBB, bool CouldCall, bool CouldJump, bool NeedsWinCFI); }; @@ -75,6 +78,8 @@ bool AArch64BranchTargets::runOnMachineFunction(MachineFunction &MF) { << "********** Function: " << MF.getName() << '\n'); const Function &F = MF.getFunction(); + Subtarget = &MF.getSubtarget<AArch64Subtarget>(); + // LLVM does not consider basic blocks which are the targets of jump tables // to be address-taken (the address can't escape anywhere else), but they are // used for indirect branches, so need BTI instructions. @@ -100,9 +105,8 @@ bool AArch64BranchTargets::runOnMachineFunction(MachineFunction &MF) { // a BTI, and pointing the indirect branch at that. For non-ELF targets we // can't rely on that, so we assume that `CouldCall` is _always_ true due // to the risk of long-branch thunks at link time. - if (&MBB == &*MF.begin() && - (!MF.getSubtarget<AArch64Subtarget>().isTargetELF() || - (F.hasAddressTaken() || !F.hasLocalLinkage()))) + if (&MBB == &*MF.begin() && (!Subtarget->isTargetELF() || + (F.hasAddressTaken() || !F.hasLocalLinkage()))) CouldCall = true; // If the block itself is address-taken, it could be indirectly branched @@ -132,16 +136,7 @@ void AArch64BranchTargets::addBTI(MachineBasicBlock &MBB, bool CouldCall, << (CouldCall ? "c" : "") << " to " << MBB.getName() << "\n"); - const AArch64InstrInfo *TII = static_cast<const AArch64InstrInfo *>( - MBB.getParent()->getSubtarget().getInstrInfo()); - - unsigned HintNum = 32; - if (CouldCall) - HintNum |= 2; - if (CouldJump) - HintNum |= 4; - assert(HintNum != 32 && "No target kinds!"); - + unsigned HintNum = getBTIHintNum(CouldCall, CouldJump); auto MBBI = MBB.begin(); // If the block starts with EH_LABEL(s), skip them first. @@ -155,13 +150,16 @@ void AArch64BranchTargets::addBTI(MachineBasicBlock &MBB, bool CouldCall, ++MBBI) ; - // SCTLR_EL1.BT[01] is set to 0 by default which means - // PACI[AB]SP are implicitly BTI C so no BTI C instruction is needed there. + // PACI[AB]SP are implicitly BTI c so insertion of a BTI can be skipped in + // this case. Depending on the runtime value of SCTLR_EL1.BT[01], they are not + // equivalent to a BTI jc, which still requires an additional BTI. if (MBBI != MBB.end() && ((HintNum & BTIMask) == BTIC) && (MBBI->getOpcode() == AArch64::PACIASP || MBBI->getOpcode() == AArch64::PACIBSP)) return; + const AArch64InstrInfo *TII = Subtarget->getInstrInfo(); + // Insert BTI exactly at the first executable instruction. 
const DebugLoc DL = MBB.findDebugLoc(MBBI); MachineInstr *BTI = BuildMI(MBB, MBBI, DL, TII->get(AArch64::HINT)) diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td index 1b5a713..e2a79a4 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -162,7 +162,13 @@ def RetCC_AArch64_AAPCS : CallingConv<[ ]>; let Entry = 1 in -def CC_AArch64_Win64PCS : CallingConv<AArch64_Common>; +def CC_AArch64_Win64PCS : CallingConv<!listconcat( + [ + // 'CFGuardTarget' is used for Arm64EC; it passes its parameter in X9. + CCIfCFGuardTarget<CCAssignToReg<[X9]>> + ], + AArch64_Common) +>; // Vararg functions on windows pass floats in integer registers let Entry = 1 in @@ -177,6 +183,9 @@ def CC_AArch64_Win64_VarArg : CallingConv<[ // a stack layout compatible with the x64 calling convention. let Entry = 1 in def CC_AArch64_Arm64EC_VarArg : CallingConv<[ + // 'CFGuardTarget' is used for Arm64EC; it passes its parameter in X9. + CCIfCFGuardTarget<CCAssignToReg<[X9]>>, + CCIfNest<CCAssignToReg<[X15]>>, // Convert small floating-point values to integer. @@ -345,7 +354,7 @@ def CC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[ let Entry = 1 in def RetCC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[ - CCIfType<[i64], CCAssignToReg<[X11]>> + CCIfType<[i64], CCAssignToReg<[X11, X9]>> ]>; @@ -601,6 +610,12 @@ def CSR_Win_AArch64_AAPCS_SwiftError def CSR_Win_AArch64_AAPCS_SwiftTail : CalleeSavedRegs<(sub CSR_Win_AArch64_AAPCS, X20, X22)>; +def CSR_Win_AArch64_RT_MostRegs + : CalleeSavedRegs<(add CSR_Win_AArch64_AAPCS, (sequence "X%u", 9, 15))>; + +def CSR_Win_AArch64_RT_AllRegs + : CalleeSavedRegs<(add CSR_Win_AArch64_RT_MostRegs, (sequence "Q%u", 8, 31))>; + // The Control Flow Guard check call uses a custom calling convention that also // preserves X0-X8 and Q0-Q7. def CSR_Win_AArch64_CFGuard_Check : CalleeSavedRegs<(add CSR_Win_AArch64_AAPCS, diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td index b3ec65c..2783147 100644 --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -366,6 +366,7 @@ def AArch64PostLegalizerCombiner select_to_minmax, or_to_bsp, combine_concat_vector, commute_constant_to_rhs, extract_vec_elt_combines, push_freeze_to_prevent_poison_from_propagating, - combine_mul_cmlt, combine_use_vector_truncate, - extmultomull, truncsat_combines, lshr_of_trunc_of_lshr]> { + combine_mul_cmlt, combine_use_vector_truncate, + extmultomull, truncsat_combines, lshr_of_trunc_of_lshr, + funnel_shift_from_or_shift_constants_are_legal]> { } diff --git a/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp b/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp index cb831963..7712d2a 100644 --- a/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp +++ b/llvm/lib/Target/AArch64/AArch64ConditionalCompares.cpp @@ -629,8 +629,7 @@ void SSACCmpConv::convert(SmallVectorImpl<MachineBasicBlock *> &RemovedBlocks) { } const MCInstrDesc &MCID = TII->get(Opc); // Create a dummy virtual register for the SUBS def. - Register DestReg = - MRI->createVirtualRegister(TII->getRegClass(MCID, 0, TRI)); + Register DestReg = MRI->createVirtualRegister(TII->getRegClass(MCID, 0)); // Insert a SUBS Rn, #0 instruction instead of the cbz / cbnz. 
BuildMI(*Head, Head->end(), TermDL, MCID) .addReg(DestReg, RegState::Define | RegState::Dead) @@ -638,8 +637,7 @@ void SSACCmpConv::convert(SmallVectorImpl<MachineBasicBlock *> &RemovedBlocks) { .addImm(0) .addImm(0); // SUBS uses the GPR*sp register classes. - MRI->constrainRegClass(HeadCond[2].getReg(), - TII->getRegClass(MCID, 1, TRI)); + MRI->constrainRegClass(HeadCond[2].getReg(), TII->getRegClass(MCID, 1)); } Head->splice(Head->end(), CmpBB, CmpBB->begin(), CmpBB->end()); @@ -686,10 +684,10 @@ void SSACCmpConv::convert(SmallVectorImpl<MachineBasicBlock *> &RemovedBlocks) { unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(CmpBBTailCC); const MCInstrDesc &MCID = TII->get(Opc); MRI->constrainRegClass(CmpMI->getOperand(FirstOp).getReg(), - TII->getRegClass(MCID, 0, TRI)); + TII->getRegClass(MCID, 0)); if (CmpMI->getOperand(FirstOp + 1).isReg()) MRI->constrainRegClass(CmpMI->getOperand(FirstOp + 1).getReg(), - TII->getRegClass(MCID, 1, TRI)); + TII->getRegClass(MCID, 1)); MachineInstrBuilder MIB = BuildMI(*Head, CmpMI, CmpMI->getDebugLoc(), MCID) .add(CmpMI->getOperand(FirstOp)); // Register Rn if (isZBranch) diff --git a/llvm/lib/Target/AArch64/AArch64DeadRegisterDefinitionsPass.cpp b/llvm/lib/Target/AArch64/AArch64DeadRegisterDefinitionsPass.cpp index 75361f5..4ff49a6 100644 --- a/llvm/lib/Target/AArch64/AArch64DeadRegisterDefinitionsPass.cpp +++ b/llvm/lib/Target/AArch64/AArch64DeadRegisterDefinitionsPass.cpp @@ -156,7 +156,7 @@ void AArch64DeadRegisterDefinitions::processMachineBasicBlock( LLVM_DEBUG(dbgs() << " Ignoring, def is tied operand.\n"); continue; } - const TargetRegisterClass *RC = TII->getRegClass(Desc, I, TRI); + const TargetRegisterClass *RC = TII->getRegClass(Desc, I); unsigned NewReg; if (RC == nullptr) { LLVM_DEBUG(dbgs() << " Ignoring, register is not a GPR.\n"); diff --git a/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp b/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp index e9660ac1..ae58184 100644 --- a/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp +++ b/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp @@ -549,6 +549,8 @@ void AArch64_IMM::expandMOVImm(uint64_t Imm, unsigned BitSize, // Prefer MOVZ/MOVN over ORR because of the rules for the "mov" alias. if ((BitSize / 16) - OneChunks <= 1 || (BitSize / 16) - ZeroChunks <= 1) { expandMOVImmSimple(Imm, BitSize, OneChunks, ZeroChunks, Insn); + assert(Insn.size() == 1 && + "Move of immediate should have expanded to a single MOVZ/MOVN"); return; } diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp index 1e607f4..60e6a82 100644 --- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp +++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp @@ -1063,6 +1063,7 @@ AArch64ExpandPseudo::expandCommitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) { MachineInstr &MI = *MBBI; DebugLoc DL = MI.getDebugLoc(); + [[maybe_unused]] auto *RI = MBB.getParent()->getSubtarget().getRegisterInfo(); // Compare TPIDR2_EL0 against 0. Commit ZA if TPIDR2_EL0 is non-zero. MachineInstrBuilder Branch = @@ -1073,21 +1074,25 @@ AArch64ExpandPseudo::expandCommitZASave(MachineBasicBlock &MBB, MachineInstrBuilder MIB = BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::BL)); // Copy operands (mainly the regmask) from the pseudo. - for (unsigned I = 2; I < MI.getNumOperands(); ++I) + for (unsigned I = 3; I < MI.getNumOperands(); ++I) MIB.add(MI.getOperand(I)); // Clear TPIDR2_EL0. 
BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::MSR)) .addImm(AArch64SysReg::TPIDR2_EL0) .addReg(AArch64::XZR); bool ZeroZA = MI.getOperand(1).getImm() != 0; + bool ZeroZT0 = MI.getOperand(2).getImm() != 0; if (ZeroZA) { - [[maybe_unused]] auto *TRI = - MBB.getParent()->getSubtarget().getRegisterInfo(); - assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!"); + assert(MI.definesRegister(AArch64::ZAB0, RI) && "should define ZA!"); BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::ZERO_M)) .addImm(ZERO_ALL_ZA_MASK) .addDef(AArch64::ZAB0, RegState::ImplicitDefine); } + if (ZeroZT0) { + assert(MI.definesRegister(AArch64::ZT0, RI) && "should define ZT0!"); + BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::ZERO_T)) + .addDef(AArch64::ZT0); + } MI.eraseFromParent(); return &EndBB; @@ -1712,6 +1717,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, } case AArch64::InOutZAUsePseudo: case AArch64::RequiresZASavePseudo: + case AArch64::RequiresZT0SavePseudo: case AArch64::SMEStateAllocPseudo: case AArch64::COALESCER_BARRIER_FPR16: case AArch64::COALESCER_BARRIER_FPR32: @@ -1871,7 +1877,7 @@ bool AArch64ExpandPseudo::expandMBB(MachineBasicBlock &MBB) { } bool AArch64ExpandPseudo::runOnMachineFunction(MachineFunction &MF) { - TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); + TII = MF.getSubtarget<AArch64Subtarget>().getInstrInfo(); bool Modified = false; for (auto &MBB : MF) diff --git a/llvm/lib/Target/AArch64/AArch64FMV.td b/llvm/lib/Target/AArch64/AArch64FMV.td index b0f76ec..1293999 100644 --- a/llvm/lib/Target/AArch64/AArch64FMV.td +++ b/llvm/lib/Target/AArch64/AArch64FMV.td @@ -83,3 +83,14 @@ def : FMVExtension<"sve2-sha3", "SVE_SHA3">; def : FMVExtension<"sve2-sm4", "SVE_SM4">; def : FMVExtension<"wfxt", "WFXT">; def : FMVExtension<"cssc", "CSSC">; + +// Extensions which allow the user to override version priority. +// 8-bits allow 256-1 priority levels (excluding all zeros). 
+def : FMVExtension<"P0", "P0">; +def : FMVExtension<"P1", "P1">; +def : FMVExtension<"P2", "P2">; +def : FMVExtension<"P3", "P3">; +def : FMVExtension<"P4", "P4">; +def : FMVExtension<"P5", "P5">; +def : FMVExtension<"P6", "P6">; +def : FMVExtension<"P7", "P7">; diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp index cf34498..0246c74 100644 --- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp +++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp @@ -16,10 +16,10 @@ #include "AArch64CallingConvention.h" #include "AArch64MachineFunctionInfo.h" #include "AArch64RegisterInfo.h" +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "Utils/AArch64BaseInfo.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" @@ -81,10 +81,7 @@ namespace { class AArch64FastISel final : public FastISel { class Address { public: - using BaseKind = enum { - RegBase, - FrameIndexBase - }; + enum BaseKind { RegBase, FrameIndexBase }; private: BaseKind Kind = RegBase; diff --git a/llvm/lib/Target/AArch64/AArch64Features.td b/llvm/lib/Target/AArch64/AArch64Features.td index 0e94b78..066724b 100644 --- a/llvm/lib/Target/AArch64/AArch64Features.td +++ b/llvm/lib/Target/AArch64/AArch64Features.td @@ -394,9 +394,6 @@ def FeatureTRBE : Extension<"trbe", "TRBE", "FEAT_TRBE", def FeatureETE : Extension<"ete", "ETE", "FEAT_ETE", "Enable Embedded Trace Extension", [FeatureTRBE]>; -def FeatureTME : ExtensionWithMArch<"tme", "TME", "FEAT_TME", - "Enable Transactional Memory Extension">; - //===----------------------------------------------------------------------===// // Armv9.1 Architecture Extensions //===----------------------------------------------------------------------===// @@ -626,6 +623,22 @@ def FeatureF16F32MM : ExtensionWithMArch<"f16f32mm", "F16F32MM", "FEAT_F16F32MM" "Enable Armv9.7-A Advanced SIMD half-precision matrix multiply-accumulate to single-precision", [FeatureNEON, FeatureFullFP16]>; //===----------------------------------------------------------------------===// +// Future Architecture Technologies +//===----------------------------------------------------------------------===// + +def FeatureMOPS_GO: ExtensionWithMArch<"mops-go", "MOPS_GO", "FEAT_MOPS_GO", + "Enable memset acceleration granule only">; + +def FeatureBTIE: ExtensionWithMArch<"btie", "BTIE", "FEAT_BTIE", + "Enable Enhanced Branch Target Identification extension">; + +def FeatureS1POE2: ExtensionWithMArch<"poe2", "POE2", "FEAT_S1POE2", + "Enable Stage 1 Permission Overlays Extension 2 instructions">; + +def FeatureTEV: ExtensionWithMArch<"tev", "TEV", "FEAT_TEV", + "Enable TIndex Exception-like Vector instructions">; + +//===----------------------------------------------------------------------===// // Other Features //===----------------------------------------------------------------------===// @@ -881,6 +894,11 @@ def FeatureUseFixedOverScalableIfEqualCost : SubtargetFeature<"use-fixed-over-sc "UseFixedOverScalableIfEqualCost", "true", "Prefer fixed width loop vectorization over scalable if the cost-model assigns equal costs">; +def FeatureDisableMaximizeScalableBandwidth : SubtargetFeature< "disable-maximize-scalable-bandwidth", + "DisableMaximizeScalableBandwidth", "true", + "Determine the maximum scalable vector length for a loop by the " + "largest scalar type rather than the smallest">; + // For performance reasons we prefer to use ldapr to 
ldapur on certain cores. def FeatureAvoidLDAPUR : SubtargetFeature<"avoid-ldapur", "AvoidLDAPUR", "true", "Prefer add+ldapr to offset ldapur">; diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index c76689f..c2f5c03 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -218,10 +218,10 @@ #include "AArch64MachineFunctionInfo.h" #include "AArch64PrologueEpilogue.h" #include "AArch64RegisterInfo.h" +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "MCTargetDesc/AArch64MCTargetDesc.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ValueTracking.h" @@ -644,10 +644,10 @@ bool AArch64FrameLowering::hasReservedCallFrame( MachineBasicBlock::iterator AArch64FrameLowering::eliminateCallFramePseudoInstr( MachineFunction &MF, MachineBasicBlock &MBB, MachineBasicBlock::iterator I) const { - const AArch64InstrInfo *TII = - static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); - const AArch64TargetLowering *TLI = - MF.getSubtarget<AArch64Subtarget>().getTargetLowering(); + + const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>(); + const AArch64InstrInfo *TII = Subtarget.getInstrInfo(); + const AArch64TargetLowering *TLI = Subtarget.getTargetLowering(); [[maybe_unused]] MachineFrameInfo &MFI = MF.getFrameInfo(); DebugLoc DL = I->getDebugLoc(); unsigned Opc = I->getOpcode(); @@ -973,8 +973,7 @@ bool AArch64FrameLowering::shouldSignReturnAddressEverywhere( if (MF.getTarget().getMCAsmInfo()->usesWindowsCFI()) return false; const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); - bool SignReturnAddressAll = AFI->shouldSignReturnAddress(/*SpillsLR=*/false); - return SignReturnAddressAll; + return AFI->getSignReturnAddressCondition() == SignReturnAddress::All; } // Given a load or a store instruction, generate an appropriate unwinding SEH @@ -1082,14 +1081,24 @@ AArch64FrameLowering::insertSEH(MachineBasicBlock::iterator MBBI, case AArch64::LDPXi: { Register Reg0 = MBBI->getOperand(0).getReg(); Register Reg1 = MBBI->getOperand(1).getReg(); + + int SEHReg0 = RegInfo->getSEHRegNum(Reg0); + int SEHReg1 = RegInfo->getSEHRegNum(Reg1); + if (Reg0 == AArch64::FP && Reg1 == AArch64::LR) MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveFPLR)) .addImm(Imm * 8) .setMIFlag(Flag); - else + else if (SEHReg0 >= 19 && SEHReg1 >= 19) MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveRegP)) - .addImm(RegInfo->getSEHRegNum(Reg0)) - .addImm(RegInfo->getSEHRegNum(Reg1)) + .addImm(SEHReg0) + .addImm(SEHReg1) + .addImm(Imm * 8) + .setMIFlag(Flag); + else + MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveAnyRegIP)) + .addImm(SEHReg0) + .addImm(SEHReg1) .addImm(Imm * 8) .setMIFlag(Flag); break; @@ -1097,10 +1106,16 @@ AArch64FrameLowering::insertSEH(MachineBasicBlock::iterator MBBI, case AArch64::STRXui: case AArch64::LDRXui: { int Reg = RegInfo->getSEHRegNum(MBBI->getOperand(0).getReg()); - MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveReg)) - .addImm(Reg) - .addImm(Imm * 8) - .setMIFlag(Flag); + if (Reg >= 19) + MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveReg)) + .addImm(Reg) + .addImm(Imm * 8) + .setMIFlag(Flag); + else + MIB = BuildMI(MF, DL, TII.get(AArch64::SEH_SaveAnyRegI)) + .addImm(Reg) + .addImm(Imm * 8) + .setMIFlag(Flag); break; } case AArch64::STRDui: @@ -1319,8 +1334,8 @@ StackOffset 
AArch64FrameLowering::getStackOffset(const MachineFunction &MF, // TODO: This function currently does not work for scalable vectors. int AArch64FrameLowering::getSEHFrameIndexOffset(const MachineFunction &MF, int FI) const { - const auto *RegInfo = static_cast<const AArch64RegisterInfo *>( - MF.getSubtarget().getRegisterInfo()); + const AArch64RegisterInfo *RegInfo = + MF.getSubtarget<AArch64Subtarget>().getRegisterInfo(); int ObjectOffset = MF.getFrameInfo().getObjectOffset(FI); return RegInfo->getLocalAddressRegister(MF) == AArch64::FP ? getFPOffset(MF, ObjectOffset).getFixed() @@ -1343,10 +1358,9 @@ StackOffset AArch64FrameLowering::resolveFrameOffsetReference( TargetStackID::Value StackID, Register &FrameReg, bool PreferFP, bool ForSimm) const { const auto &MFI = MF.getFrameInfo(); - const auto *RegInfo = static_cast<const AArch64RegisterInfo *>( - MF.getSubtarget().getRegisterInfo()); - const auto *AFI = MF.getInfo<AArch64FunctionInfo>(); const auto &Subtarget = MF.getSubtarget<AArch64Subtarget>(); + const AArch64RegisterInfo *RegInfo = Subtarget.getRegisterInfo(); + const auto *AFI = MF.getInfo<AArch64FunctionInfo>(); int64_t FPOffset = getFPOffset(MF, ObjectOffset).getFixed(); int64_t Offset = getStackOffset(MF, ObjectOffset).getFixed(); @@ -1466,7 +1480,7 @@ StackOffset AArch64FrameLowering::resolveFrameOffsetReference( return FPOffset; } FrameReg = RegInfo->hasBasePointer(MF) ? RegInfo->getBaseRegister() - : (unsigned)AArch64::SP; + : MCRegister(AArch64::SP); return SPOffset; } @@ -1539,8 +1553,10 @@ static bool produceCompactUnwindFrame(const AArch64FrameLowering &AFL, !AFL.requiresSaveVG(MF) && !AFI->isSVECC(); } -static bool invalidateWindowsRegisterPairing(unsigned Reg1, unsigned Reg2, - bool NeedsWinCFI, bool IsFirst, +static bool invalidateWindowsRegisterPairing(bool SpillExtendedVolatile, + unsigned SpillCount, unsigned Reg1, + unsigned Reg2, bool NeedsWinCFI, + bool IsFirst, const TargetRegisterInfo *TRI) { // If we are generating register pairs for a Windows function that requires // EH support, then pair consecutive registers only. There are no unwind @@ -1553,8 +1569,18 @@ static bool invalidateWindowsRegisterPairing(unsigned Reg1, unsigned Reg2, return true; if (!NeedsWinCFI) return false; + + // ARM64EC introduced `save_any_regp`, which expects 16-byte alignment. + // This is handled by only allowing paired spills for registers spilled at + // even positions (which should be 16-byte aligned, as other GPRs/FPRs are + // 8-bytes). We carve out an exception for {FP,LR}, which does not require + // 16-byte alignment in the uop representation. if (TRI->getEncodingValue(Reg2) == TRI->getEncodingValue(Reg1) + 1) - return false; + return SpillExtendedVolatile + ? !((Reg1 == AArch64::FP && Reg2 == AArch64::LR) || + (SpillCount % 2) == 0) + : false; + // If pairing a GPR with LR, the pair can be described by the save_lrpair // opcode. If this is the first register pair, it would end up with a // predecrement, but there's no save_lrpair_x opcode, so we can only do this @@ -1570,12 +1596,15 @@ static bool invalidateWindowsRegisterPairing(unsigned Reg1, unsigned Reg2, /// WindowsCFI requires that only consecutive registers can be paired. /// LR and FP need to be allocated together when the frame needs to save /// the frame-record. This means any other register pairing with LR is invalid. 
-static bool invalidateRegisterPairing(unsigned Reg1, unsigned Reg2, - bool UsesWinAAPCS, bool NeedsWinCFI, - bool NeedsFrameRecord, bool IsFirst, +static bool invalidateRegisterPairing(bool SpillExtendedVolatile, + unsigned SpillCount, unsigned Reg1, + unsigned Reg2, bool UsesWinAAPCS, + bool NeedsWinCFI, bool NeedsFrameRecord, + bool IsFirst, const TargetRegisterInfo *TRI) { if (UsesWinAAPCS) - return invalidateWindowsRegisterPairing(Reg1, Reg2, NeedsWinCFI, IsFirst, + return invalidateWindowsRegisterPairing(SpillExtendedVolatile, SpillCount, + Reg1, Reg2, NeedsWinCFI, IsFirst, TRI); // If we need to store the frame record, don't pair any register @@ -1589,8 +1618,8 @@ static bool invalidateRegisterPairing(unsigned Reg1, unsigned Reg2, namespace { struct RegPairInfo { - unsigned Reg1 = AArch64::NoRegister; - unsigned Reg2 = AArch64::NoRegister; + Register Reg1; + Register Reg2; int FrameIdx; int Offset; enum RegType { GPR, FPR64, FPR128, PPR, ZPR, VG } Type; @@ -1598,21 +1627,21 @@ struct RegPairInfo { RegPairInfo() = default; - bool isPaired() const { return Reg2 != AArch64::NoRegister; } + bool isPaired() const { return Reg2.isValid(); } bool isScalable() const { return Type == PPR || Type == ZPR; } }; } // end anonymous namespace -unsigned findFreePredicateReg(BitVector &SavedRegs) { +MCRegister findFreePredicateReg(BitVector &SavedRegs) { for (unsigned PReg = AArch64::P8; PReg <= AArch64::P15; ++PReg) { if (SavedRegs.test(PReg)) { unsigned PNReg = PReg - AArch64::P0 + AArch64::PN0; - return PNReg; + return MCRegister(PNReg); } } - return AArch64::NoRegister; + return MCRegister(); } // The multivector LD/ST are available only for SME or SVE2p1 targets @@ -1673,6 +1702,19 @@ void computeCalleeSaveRegisterPairs(const AArch64FrameLowering &AFL, } bool FPAfterSVECalleeSaves = IsWindows && AFI->getSVECalleeSavedStackSize(); + // Windows AAPCS has x9-x15 as volatile registers, x16-x17 as intra-procedural + // scratch, x18 as platform reserved. However, clang has extended calling + // conventions such as preserve_most and preserve_all which treat these as + // CSR. As such, the ARM64 unwind uOPs bias registers by 19. We use ARM64EC + // uOPs which have separate restrictions. We need to check for that. + // + // NOTE: we currently do not account for the D registers as LLVM does not + // support non-ABI compliant D register spills. + bool SpillExtendedVolatile = + IsWindows && llvm::any_of(CSI, [](const CalleeSavedInfo &CSI) { + const auto &Reg = CSI.getReg(); + return Reg >= AArch64::X0 && Reg <= AArch64::X18; + }); int ZPRByteOffset = 0; int PPRByteOffset = 0; @@ -1734,17 +1776,19 @@ void computeCalleeSaveRegisterPairs(const AArch64FrameLowering &AFL, if (unsigned(i + RegInc) < Count && !HasCSHazardPadding) { MCRegister NextReg = CSI[i + RegInc].getReg(); bool IsFirst = i == FirstReg; + unsigned SpillCount = NeedsWinCFI ?
FirstReg - i : i; switch (RPI.Type) { case RegPairInfo::GPR: if (AArch64::GPR64RegClass.contains(NextReg) && - !invalidateRegisterPairing(RPI.Reg1, NextReg, IsWindows, - NeedsWinCFI, NeedsFrameRecord, IsFirst, - TRI)) + !invalidateRegisterPairing( + SpillExtendedVolatile, SpillCount, RPI.Reg1, NextReg, IsWindows, + NeedsWinCFI, NeedsFrameRecord, IsFirst, TRI)) RPI.Reg2 = NextReg; break; case RegPairInfo::FPR64: if (AArch64::FPR64RegClass.contains(NextReg) && - !invalidateWindowsRegisterPairing(RPI.Reg1, NextReg, NeedsWinCFI, + !invalidateWindowsRegisterPairing(SpillExtendedVolatile, SpillCount, + RPI.Reg1, NextReg, NeedsWinCFI, IsFirst, TRI)) RPI.Reg2 = NextReg; break; @@ -1930,8 +1974,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( } bool PTrueCreated = false; for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) { - unsigned Reg1 = RPI.Reg1; - unsigned Reg2 = RPI.Reg2; + Register Reg1 = RPI.Reg1; + Register Reg2 = RPI.Reg2; unsigned StrOpc; // Issue sequence of spills for cs regs. The first spill may be converted @@ -1967,7 +2011,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( break; } - unsigned X0Scratch = AArch64::NoRegister; + Register X0Scratch; auto RestoreX0 = make_scope_exit([&] { if (X0Scratch != AArch64::NoRegister) BuildMI(MBB, MI, DL, TII.get(TargetOpcode::COPY), AArch64::X0) @@ -2009,11 +2053,15 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( } } - LLVM_DEBUG(dbgs() << "CSR spill: (" << printReg(Reg1, TRI); - if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI); - dbgs() << ") -> fi#(" << RPI.FrameIdx; - if (RPI.isPaired()) dbgs() << ", " << RPI.FrameIdx + 1; - dbgs() << ")\n"); + LLVM_DEBUG({ + dbgs() << "CSR spill: (" << printReg(Reg1, TRI); + if (RPI.isPaired()) + dbgs() << ", " << printReg(Reg2, TRI); + dbgs() << ") -> fi#(" << RPI.FrameIdx; + if (RPI.isPaired()) + dbgs() << ", " << RPI.FrameIdx + 1; + dbgs() << ")\n"; + }); assert((!NeedsWinCFI || !(Reg1 == AArch64::LR && Reg2 == AArch64::FP)) && "Windows unwdinding requires a consecutive (FP,LR) pair"); @@ -2143,8 +2191,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( bool PTrueCreated = false; for (const RegPairInfo &RPI : RegPairs) { - unsigned Reg1 = RPI.Reg1; - unsigned Reg2 = RPI.Reg2; + Register Reg1 = RPI.Reg1; + Register Reg2 = RPI.Reg2; // Issue sequence of restores for cs regs. The last restore may be converted // to a post-increment load later by emitEpilogue if the callee-save stack @@ -2176,11 +2224,15 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( case RegPairInfo::VG: continue; } - LLVM_DEBUG(dbgs() << "CSR restore: (" << printReg(Reg1, TRI); - if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI); - dbgs() << ") -> fi#(" << RPI.FrameIdx; - if (RPI.isPaired()) dbgs() << ", " << RPI.FrameIdx + 1; - dbgs() << ")\n"); + LLVM_DEBUG({ + dbgs() << "CSR restore: (" << printReg(Reg1, TRI); + if (RPI.isPaired()) + dbgs() << ", " << printReg(Reg2, TRI); + dbgs() << ") -> fi#(" << RPI.FrameIdx; + if (RPI.isPaired()) + dbgs() << ", " << RPI.FrameIdx + 1; + dbgs() << ")\n"; + }); // Windows unwind codes require consecutive registers if registers are // paired. Make the switch here, so that the code below will save (x,x+1) @@ -2357,36 +2409,41 @@ void AArch64FrameLowering::determineStackHazardSlot( AFI->setStackHazardSlotIndex(ID); } - // Determine if we should use SplitSVEObjects. This should only be used if - // there's a possibility of a stack hazard between PPRs and ZPRs or FPRs. 
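For reference, a standalone sketch (not part of the patch; the names and the tiny test harness are illustrative) of the pairing rule that the new SpillExtendedVolatile and SpillCount parameters encode above: when Windows extended-volatile GPRs (x0-x18) are among the callee-saves, adjacent registers may only be paired if the first register of the pair sits at an even spill position, i.e. on the 16-byte boundary that save_any_regp expects, with the {FP, LR} frame record exempt.

#include <cassert>

// Illustrative stand-ins for the real register identities.
enum Reg { X19, X20, FP, LR };

// Returns true when a candidate (Reg1, Reg2) pair must be split.
// Mirrors the parity check added to invalidateWindowsRegisterPairing:
// with save_any_regp in play, only pairs starting at an even spill
// position (16-byte aligned) are allowed, except {FP, LR}.
static bool mustSplitPair(bool SpillExtendedVolatile, unsigned SpillCount,
                          Reg Reg1, Reg Reg2) {
  if (!SpillExtendedVolatile)
    return false;                  // ordinary Windows pairing rules apply
  if (Reg1 == FP && Reg2 == LR)
    return false;                  // the frame record is always pairable
  return (SpillCount % 2) != 0;    // odd position -> not 16-byte aligned
}

int main() {
  assert(!mustSplitPair(true, 0, X19, X20)); // even slot: pair is fine
  assert(mustSplitPair(true, 1, X19, X20));  // odd slot: must split
  assert(!mustSplitPair(true, 3, FP, LR));   // FP/LR exempt from the rule
}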
+ if (!AFI->hasStackHazardSlotIndex()) + return; + if (SplitSVEObjects) { - if (!HasPPRCSRs && !HasPPRStackObjects) { - LLVM_DEBUG( - dbgs() << "Not using SplitSVEObjects as no PPRs are on the stack\n"); + CallingConv::ID CC = MF.getFunction().getCallingConv(); + if (AFI->isSVECC() || CC == CallingConv::AArch64_SVE_VectorCall) { + AFI->setSplitSVEObjects(true); + LLVM_DEBUG(dbgs() << "Using SplitSVEObjects for SVE CC function\n"); return; } - if (!HasFPRCSRs && !HasFPRStackObjects) { + // We only use SplitSVEObjects in non-SVE CC functions if there's a + // possibility of a stack hazard between PPRs and ZPRs/FPRs. + LLVM_DEBUG(dbgs() << "Determining if SplitSVEObjects should be used in " + "non-SVE CC function...\n"); + + // If another calling convention is explicitly set FPRs can't be promoted to + // ZPR callee-saves. + if (!is_contained({CallingConv::C, CallingConv::Fast}, CC)) { LLVM_DEBUG( dbgs() - << "Not using SplitSVEObjects as no FPRs or ZPRs are on the stack\n"); + << "Calling convention is not supported with SplitSVEObjects\n"); return; } - const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); - if (MFI.hasVarSizedObjects() || TRI->hasStackRealignment(MF)) { - LLVM_DEBUG(dbgs() << "SplitSVEObjects is not supported with variable " - "sized objects or realignment\n"); + if (!HasPPRCSRs && !HasPPRStackObjects) { + LLVM_DEBUG( + dbgs() << "Not using SplitSVEObjects as no PPRs are on the stack\n"); return; } - // If another calling convention is explicitly set FPRs can't be promoted to - // ZPR callee-saves. - if (!is_contained({CallingConv::C, CallingConv::Fast, - CallingConv::AArch64_SVE_VectorCall}, - MF.getFunction().getCallingConv())) { + if (!HasFPRCSRs && !HasFPRStackObjects) { LLVM_DEBUG( - dbgs() << "Calling convention is not supported with SplitSVEObjects"); + dbgs() + << "Not using SplitSVEObjects as no FPRs or ZPRs are on the stack\n"); return; } @@ -2395,6 +2452,7 @@ void AArch64FrameLowering::determineStackHazardSlot( assert(Subtarget.isSVEorStreamingSVEAvailable() && "Expected SVE to be available for PPRs"); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); // With SplitSVEObjects the CS hazard padding is placed between the // PPRs and ZPRs. If there are any FPR CS there would be a hazard between // them and the CS GRPs. Avoid this by promoting all FPR CS to ZPRs. @@ -2435,8 +2493,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>(); TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS); - const AArch64RegisterInfo *RegInfo = static_cast<const AArch64RegisterInfo *>( - MF.getSubtarget().getRegisterInfo()); + const AArch64RegisterInfo *RegInfo = Subtarget.getRegisterInfo(); AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); unsigned UnspilledCSGPR = AArch64::NoRegister; unsigned UnspilledCSGPRPaired = AArch64::NoRegister; @@ -2444,9 +2501,8 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, MachineFrameInfo &MFI = MF.getFrameInfo(); const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs(); - unsigned BasePointerReg = RegInfo->hasBasePointer(MF) - ? RegInfo->getBaseRegister() - : (unsigned)AArch64::NoRegister; + MCRegister BasePointerReg = + RegInfo->hasBasePointer(MF) ? 
RegInfo->getBaseRegister() : MCRegister(); unsigned ExtraCSSpill = 0; bool HasUnpairedGPR64 = false; @@ -2456,7 +2512,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, // Figure out which callee-saved registers to save/restore. for (unsigned i = 0; CSRegs[i]; ++i) { - const unsigned Reg = CSRegs[i]; + const MCRegister Reg = CSRegs[i]; // Add the base pointer register to SavedRegs if it is callee-save. if (Reg == BasePointerReg) @@ -2470,7 +2526,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, } bool RegUsed = SavedRegs.test(Reg); - unsigned PairedReg = AArch64::NoRegister; + MCRegister PairedReg; const bool RegIsGPR64 = AArch64::GPR64RegClass.contains(Reg); if (RegIsGPR64 || AArch64::FPR64RegClass.contains(Reg) || AArch64::FPR128RegClass.contains(Reg)) { @@ -2522,8 +2578,8 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); // Find a suitable predicate register for the multi-vector spill/fill // instructions. - unsigned PnReg = findFreePredicateReg(SavedRegs); - if (PnReg != AArch64::NoRegister) + MCRegister PnReg = findFreePredicateReg(SavedRegs); + if (PnReg.isValid()) AFI->setPredicateRegForFillSpill(PnReg); // If no free callee-save has been found assign one. if (!AFI->getPredicateRegForFillSpill() && @@ -2558,7 +2614,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, unsigned PPRCSStackSize = 0; const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); for (unsigned Reg : SavedRegs.set_bits()) { - auto *RC = TRI->getMinimalPhysRegClass(Reg); + auto *RC = TRI->getMinimalPhysRegClass(MCRegister(Reg)); assert(RC && "expected register class!"); auto SpillSize = TRI->getSpillSize(*RC); bool IsZPR = AArch64::ZPRRegClass.contains(Reg); @@ -2600,7 +2656,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, LLVM_DEBUG({ dbgs() << "*** determineCalleeSaves\nSaved CSRs:"; for (unsigned Reg : SavedRegs.set_bits()) - dbgs() << ' ' << printReg(Reg, RegInfo); + dbgs() << ' ' << printReg(MCRegister(Reg), RegInfo); dbgs() << "\n"; }); @@ -2699,8 +2755,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, bool AArch64FrameLowering::assignCalleeSavedSpillSlots( MachineFunction &MF, const TargetRegisterInfo *RegInfo, - std::vector<CalleeSavedInfo> &CSI, unsigned &MinCSFrameIndex, - unsigned &MaxCSFrameIndex) const { + std::vector<CalleeSavedInfo> &CSI) const { bool NeedsWinCFI = needsWinCFI(MF); unsigned StackHazardSize = getStackHazardSize(MF); // To match the canonical windows frame layout, reverse the list of @@ -2723,10 +2778,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots( if (UsesWinAAPCS && hasFP(MF) && AFI->hasSwiftAsyncContext()) { int FrameIdx = MFI.CreateStackObject(8, Align(16), true); AFI->setSwiftAsyncContextFrameIdx(FrameIdx); - if ((unsigned)FrameIdx < MinCSFrameIndex) - MinCSFrameIndex = FrameIdx; - if ((unsigned)FrameIdx > MaxCSFrameIndex) - MaxCSFrameIndex = FrameIdx; + MFI.setIsCalleeSavedObjectIndex(FrameIdx, true); } // Insert VG into the list of CSRs, immediately before LR if saved. 
@@ -2756,31 +2808,21 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots( LLVM_DEBUG(dbgs() << "Created CSR Hazard at slot " << HazardSlotIndex << "\n"); AFI->setStackHazardCSRSlotIndex(HazardSlotIndex); - if ((unsigned)HazardSlotIndex < MinCSFrameIndex) - MinCSFrameIndex = HazardSlotIndex; - if ((unsigned)HazardSlotIndex > MaxCSFrameIndex) - MaxCSFrameIndex = HazardSlotIndex; + MFI.setIsCalleeSavedObjectIndex(HazardSlotIndex, true); } unsigned Size = RegInfo->getSpillSize(*RC); Align Alignment(RegInfo->getSpillAlign(*RC)); int FrameIdx = MFI.CreateStackObject(Size, Alignment, true); CS.setFrameIdx(FrameIdx); - - if ((unsigned)FrameIdx < MinCSFrameIndex) - MinCSFrameIndex = FrameIdx; - if ((unsigned)FrameIdx > MaxCSFrameIndex) - MaxCSFrameIndex = FrameIdx; + MFI.setIsCalleeSavedObjectIndex(FrameIdx, true); // Grab 8 bytes below FP for the extended asynchronous frame info. if (hasFP(MF) && AFI->hasSwiftAsyncContext() && !UsesWinAAPCS && Reg == AArch64::FP) { FrameIdx = MFI.CreateStackObject(8, Alignment, true); AFI->setSwiftAsyncContextFrameIdx(FrameIdx); - if ((unsigned)FrameIdx < MinCSFrameIndex) - MinCSFrameIndex = FrameIdx; - if ((unsigned)FrameIdx > MaxCSFrameIndex) - MaxCSFrameIndex = FrameIdx; + MFI.setIsCalleeSavedObjectIndex(FrameIdx, true); } LastReg = Reg; } @@ -2792,10 +2834,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots( LLVM_DEBUG(dbgs() << "Created CSR Hazard at slot " << HazardSlotIndex << "\n"); AFI->setStackHazardCSRSlotIndex(HazardSlotIndex); - if ((unsigned)HazardSlotIndex < MinCSFrameIndex) - MinCSFrameIndex = HazardSlotIndex; - if ((unsigned)HazardSlotIndex > MaxCSFrameIndex) - MaxCSFrameIndex = HazardSlotIndex; + MFI.setIsCalleeSavedObjectIndex(HazardSlotIndex, true); } return true; @@ -2912,9 +2951,8 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF, } for (int FI = 0, E = MFI.getObjectIndexEnd(); FI != E; ++FI) { - if (FI == StackProtectorFI || MFI.isDeadObjectIndex(FI)) - continue; - if (MaxCSFrameIndex >= FI && FI >= MinCSFrameIndex) + if (FI == StackProtectorFI || MFI.isDeadObjectIndex(FI) || + MFI.isCalleeSavedObjectIndex(FI)) continue; if (MFI.getStackID(FI) != TargetStackID::ScalableVector && diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.h b/llvm/lib/Target/AArch64/AArch64FrameLowering.h index 32a9bd8..97db18d 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.h +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.h @@ -88,11 +88,10 @@ public: bool hasReservedCallFrame(const MachineFunction &MF) const override; - bool assignCalleeSavedSpillSlots(MachineFunction &MF, - const TargetRegisterInfo *TRI, - std::vector<CalleeSavedInfo> &CSI, - unsigned &MinCSFrameIndex, - unsigned &MaxCSFrameIndex) const override; + bool + assignCalleeSavedSpillSlots(MachineFunction &MF, + const TargetRegisterInfo *TRI, + std::vector<CalleeSavedInfo> &CSI) const override; void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs, RegScavenger *RS) const override; diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index e7b2d20..54ad7be 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -513,6 +513,9 @@ private: bool SelectAnyPredicate(SDValue N); bool SelectCmpBranchUImm6Operand(SDNode *P, SDValue N, SDValue &Imm); + + template <bool MatchCBB> + bool SelectCmpBranchExtOperand(SDValue N, SDValue &Reg, SDValue &ExtType); }; class AArch64DAGToDAGISelLegacy : public SelectionDAGISelLegacy { @@ -1554,7 
+1557,10 @@ void AArch64DAGToDAGISel::SelectPtrauthAuth(SDNode *N) { extractPtrauthBlendDiscriminators(AUTDisc, CurDAG); if (!Subtarget->isX16X17Safer()) { - SDValue Ops[] = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + std::vector<SDValue> Ops = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + // Copy deactivation symbol if present. + if (N->getNumOperands() > 4) + Ops.push_back(N->getOperand(4)); SDNode *AUT = CurDAG->getMachineNode(AArch64::AUTxMxN, DL, MVT::i64, MVT::i64, Ops); @@ -4400,43 +4406,46 @@ bool AArch64DAGToDAGISel::SelectSVEArithImm(SDValue N, MVT VT, SDValue &Imm) { bool AArch64DAGToDAGISel::SelectSVELogicalImm(SDValue N, MVT VT, SDValue &Imm, bool Invert) { - if (auto CNode = dyn_cast<ConstantSDNode>(N)) { - uint64_t ImmVal = CNode->getZExtValue(); - SDLoc DL(N); + uint64_t ImmVal; + if (auto CI = dyn_cast<ConstantSDNode>(N)) + ImmVal = CI->getZExtValue(); + else if (auto CFP = dyn_cast<ConstantFPSDNode>(N)) + ImmVal = CFP->getValueAPF().bitcastToAPInt().getZExtValue(); + else + return false; - if (Invert) - ImmVal = ~ImmVal; + if (Invert) + ImmVal = ~ImmVal; - // Shift mask depending on type size. - switch (VT.SimpleTy) { - case MVT::i8: - ImmVal &= 0xFF; - ImmVal |= ImmVal << 8; - ImmVal |= ImmVal << 16; - ImmVal |= ImmVal << 32; - break; - case MVT::i16: - ImmVal &= 0xFFFF; - ImmVal |= ImmVal << 16; - ImmVal |= ImmVal << 32; - break; - case MVT::i32: - ImmVal &= 0xFFFFFFFF; - ImmVal |= ImmVal << 32; - break; - case MVT::i64: - break; - default: - llvm_unreachable("Unexpected type"); - } - - uint64_t encoding; - if (AArch64_AM::processLogicalImmediate(ImmVal, 64, encoding)) { - Imm = CurDAG->getTargetConstant(encoding, DL, MVT::i64); - return true; - } + // Shift mask depending on type size. + switch (VT.SimpleTy) { + case MVT::i8: + ImmVal &= 0xFF; + ImmVal |= ImmVal << 8; + ImmVal |= ImmVal << 16; + ImmVal |= ImmVal << 32; + break; + case MVT::i16: + ImmVal &= 0xFFFF; + ImmVal |= ImmVal << 16; + ImmVal |= ImmVal << 32; + break; + case MVT::i32: + ImmVal &= 0xFFFFFFFF; + ImmVal |= ImmVal << 32; + break; + case MVT::i64: + break; + default: + llvm_unreachable("Unexpected type"); } - return false; + + uint64_t encoding; + if (!AArch64_AM::processLogicalImmediate(ImmVal, 64, encoding)) + return false; + + Imm = CurDAG->getTargetConstant(encoding, SDLoc(N), MVT::i64); + return true; } // SVE shift intrinsics allow shift amounts larger than the element's bitwidth. 
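A minimal standalone sketch of the immediate replication the refactored SelectSVELogicalImm performs before handing the value to processLogicalImmediate: the element-sized immediate is splatted across 64 bits so a single 64-bit bitmask-immediate check covers every element width (the function name and asserts below are illustrative, not LLVM code).

#include <cassert>
#include <cstdint>

// Splat an element-sized immediate across 64 bits, as done before the
// logical-immediate encoding check. ElemBits is 8, 16, 32 or 64.
static uint64_t splatToI64(uint64_t Imm, unsigned ElemBits) {
  switch (ElemBits) {
  case 8:
    Imm &= 0xFF;
    Imm |= Imm << 8;
    [[fallthrough]];
  case 16:
    Imm &= 0xFFFF;          // no-op when falling through from the 8-bit case
    Imm |= Imm << 16;
    [[fallthrough]];
  case 32:
    Imm &= 0xFFFFFFFF;      // likewise a no-op on fallthrough
    Imm |= Imm << 32;
    break;
  case 64:
    break;
  }
  return Imm;
}

int main() {
  assert(splatToI64(0xAB, 8) == 0xABABABABABABABABULL);
  assert(splatToI64(0x1234, 16) == 0x1234123412341234ULL);
}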
@@ -6220,6 +6229,26 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { AArch64::FMINNM_VG4_4ZZ_S, AArch64::FMINNM_VG4_4ZZ_D})) SelectDestructiveMultiIntrinsic(Node, 4, false, Op); return; + case Intrinsic::aarch64_sve_fscale_single_x4: + SelectDestructiveMultiIntrinsic(Node, 4, false, AArch64::BFSCALE_4ZZ); + return; + case Intrinsic::aarch64_sve_fscale_single_x2: + SelectDestructiveMultiIntrinsic(Node, 2, false, AArch64::BFSCALE_2ZZ); + return; + case Intrinsic::aarch64_sve_fmul_single_x4: + if (auto Op = SelectOpcodeFromVT<SelectTypeKind::FP>( + Node->getValueType(0), + {AArch64::BFMUL_4ZZ, AArch64::FMUL_4ZZ_H, AArch64::FMUL_4ZZ_S, + AArch64::FMUL_4ZZ_D})) + SelectDestructiveMultiIntrinsic(Node, 4, false, Op); + return; + case Intrinsic::aarch64_sve_fmul_single_x2: + if (auto Op = SelectOpcodeFromVT<SelectTypeKind::FP>( + Node->getValueType(0), + {AArch64::BFMUL_2ZZ, AArch64::FMUL_2ZZ_H, AArch64::FMUL_2ZZ_S, + AArch64::FMUL_2ZZ_D})) + SelectDestructiveMultiIntrinsic(Node, 2, false, Op); + return; case Intrinsic::aarch64_sve_fmaxnm_x2: if (auto Op = SelectOpcodeFromVT<SelectTypeKind::FP>( Node->getValueType(0), @@ -6248,6 +6277,26 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { AArch64::FMINNM_VG4_4Z4Z_S, AArch64::FMINNM_VG4_4Z4Z_D})) SelectDestructiveMultiIntrinsic(Node, 4, true, Op); return; + case Intrinsic::aarch64_sve_fscale_x4: + SelectDestructiveMultiIntrinsic(Node, 4, true, AArch64::BFSCALE_4Z4Z); + return; + case Intrinsic::aarch64_sve_fscale_x2: + SelectDestructiveMultiIntrinsic(Node, 2, true, AArch64::BFSCALE_2Z2Z); + return; + case Intrinsic::aarch64_sve_fmul_x4: + if (auto Op = SelectOpcodeFromVT<SelectTypeKind::FP>( + Node->getValueType(0), + {AArch64::BFMUL_4Z4Z, AArch64::FMUL_4Z4Z_H, AArch64::FMUL_4Z4Z_S, + AArch64::FMUL_4Z4Z_D})) + SelectDestructiveMultiIntrinsic(Node, 4, true, Op); + return; + case Intrinsic::aarch64_sve_fmul_x2: + if (auto Op = SelectOpcodeFromVT<SelectTypeKind::FP>( + Node->getValueType(0), + {AArch64::BFMUL_2Z2Z, AArch64::FMUL_2Z2Z_H, AArch64::FMUL_2Z2Z_S, + AArch64::FMUL_2Z2Z_D})) + SelectDestructiveMultiIntrinsic(Node, 2, true, Op); + return; case Intrinsic::aarch64_sve_fcvtzs_x2: SelectCVTIntrinsic(Node, 2, AArch64::FCVTZS_2Z2Z_StoS); return; @@ -7697,3 +7746,31 @@ bool AArch64DAGToDAGISel::SelectCmpBranchUImm6Operand(SDNode *P, SDValue N, return false; } + +template <bool MatchCBB> +bool AArch64DAGToDAGISel::SelectCmpBranchExtOperand(SDValue N, SDValue &Reg, + SDValue &ExtType) { + + // Use an invalid shift-extend value to indicate we don't need to extend later + if (N.getOpcode() == ISD::AssertZext || N.getOpcode() == ISD::AssertSext) { + EVT Ty = cast<VTSDNode>(N.getOperand(1))->getVT(); + if (Ty != (MatchCBB ? 
MVT::i8 : MVT::i16)) + return false; + Reg = N.getOperand(0); + ExtType = CurDAG->getSignedTargetConstant(AArch64_AM::InvalidShiftExtend, + SDLoc(N), MVT::i32); + return true; + } + + AArch64_AM::ShiftExtendType ET = getExtendTypeForNode(N); + + if ((MatchCBB && (ET == AArch64_AM::UXTB || ET == AArch64_AM::SXTB)) || + (!MatchCBB && (ET == AArch64_AM::UXTH || ET == AArch64_AM::SXTH))) { + Reg = N.getOperand(0); + ExtType = + CurDAG->getTargetConstant(getExtendEncoding(ET), SDLoc(N), MVT::i32); + return true; + } + + return false; +} diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index d16b116..30eb190 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -16,11 +16,11 @@ #include "AArch64MachineFunctionInfo.h" #include "AArch64PerfectShuffle.h" #include "AArch64RegisterInfo.h" +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" #include "AArch64TargetMachine.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "Utils/AArch64BaseInfo.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -50,6 +50,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SDPatternMatch.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/TargetCallingConv.h" @@ -104,7 +105,6 @@ #include <vector> using namespace llvm; -using namespace llvm::PatternMatch; #define DEBUG_TYPE "aarch64-lower" @@ -387,7 +387,7 @@ extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) { AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, const AArch64Subtarget &STI) - : TargetLowering(TM), Subtarget(&STI) { + : TargetLowering(TM, STI), Subtarget(&STI) { // AArch64 doesn't have comparisons which set GPRs or setcc instructions, so // we have to make something up. Arbitrarily, choose ZeroOrOne. 
setBooleanContents(ZeroOrOneBooleanContent); @@ -445,6 +445,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass); + // Add sve predicate as counter type + addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass); + // Add legal sve data types addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass); @@ -473,15 +476,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, } } - if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) { - addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass); - setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1); - setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1); - - setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom); - setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand); - } - // Compute derived properties from the register classes computeRegisterProperties(Subtarget->getRegisterInfo()); @@ -536,7 +530,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FREM, MVT::f32, Expand); setOperationAction(ISD::FREM, MVT::f64, Expand); - setOperationAction(ISD::FREM, MVT::f80, Expand); setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand); @@ -1052,15 +1045,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // Lower READCYCLECOUNTER using an mrs from CNTVCT_EL0. setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal); - if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr && - getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) { - // Issue __sincos_stret if available. - setOperationAction(ISD::FSINCOS, MVT::f64, Custom); - setOperationAction(ISD::FSINCOS, MVT::f32, Custom); - } else { - setOperationAction(ISD::FSINCOS, MVT::f64, Expand); - setOperationAction(ISD::FSINCOS, MVT::f32, Expand); - } + // Issue __sincos_stret if available. + setOperationAction(ISD::FSINCOS, MVT::f64, Expand); + setOperationAction(ISD::FSINCOS, MVT::f32, Expand); // Make floating-point constants legal for the large code model, so they don't // become loads from the constant pool. @@ -1180,6 +1167,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::SHL); setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE); + setTargetDAGCombine(ISD::CTPOP); // In case of strict alignment, avoid an excessive number of byte wide stores. 
MaxStoresPerMemsetOptSize = 8; @@ -1438,12 +1426,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::BITCAST, MVT::v2i16, Custom); setOperationAction(ISD::BITCAST, MVT::v4i8, Custom); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i8, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i8, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom); setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom); setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom); setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom); setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i16, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i16, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i16, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i16, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i16, Custom); // ADDP custom lowering for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) @@ -1523,6 +1523,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32}) setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom); + + for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) + setOperationAction(ISD::FMA, VT, Custom); } if (Subtarget->isSVEorStreamingSVEAvailable()) { @@ -1590,6 +1593,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::AVGCEILS, VT, Custom); setOperationAction(ISD::AVGCEILU, VT, Custom); + setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, VT, Custom); + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom); + if (!Subtarget->isLittleEndian()) setOperationAction(ISD::BITCAST, VT, Custom); @@ -1614,6 +1621,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 }) setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal); + // Promote predicate as counter load/stores to standard predicates. + setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1); + setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1); + + // Predicate as counter legalization actions. 
+ setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom); + setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand); + for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); @@ -1774,17 +1789,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); + } - if (Subtarget->hasSVEB16B16() && - Subtarget->isNonStreamingSVEorSME2Available()) { - setOperationAction(ISD::FADD, VT, Legal); + if (Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available()) { + // Note: Use SVE for bfloat16 operations when +sve-b16b16 is available. + for (auto VT : {MVT::v4bf16, MVT::v8bf16, MVT::nxv2bf16, MVT::nxv4bf16, + MVT::nxv8bf16}) { + setOperationAction(ISD::FADD, VT, Custom); setOperationAction(ISD::FMA, VT, Custom); setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMAXNUM, VT, Custom); setOperationAction(ISD::FMINIMUM, VT, Custom); setOperationAction(ISD::FMINNUM, VT, Custom); - setOperationAction(ISD::FMUL, VT, Legal); - setOperationAction(ISD::FSUB, VT, Legal); + setOperationAction(ISD::FMUL, VT, Custom); + setOperationAction(ISD::FSUB, VT, Custom); } } @@ -1800,22 +1819,37 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (!Subtarget->hasSVEB16B16() || !Subtarget->isNonStreamingSVEorSME2Available()) { - for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM, - ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) { - setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32); - setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32); - setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32); + for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) { + MVT PromotedVT = VT.changeVectorElementType(MVT::f32); + setOperationPromotedToType(ISD::FADD, VT, PromotedVT); + setOperationPromotedToType(ISD::FMA, VT, PromotedVT); + setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FSUB, VT, PromotedVT); + + if (VT != MVT::nxv2bf16 && Subtarget->hasBF16()) + setOperationAction(ISD::FMUL, VT, Custom); + else + setOperationPromotedToType(ISD::FMUL, VT, PromotedVT); } + + if (Subtarget->hasBF16() && Subtarget->isNeonAvailable()) + setOperationAction(ISD::FMUL, MVT::v8bf16, Custom); } setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom); - // NEON doesn't support integer divides, but SVE does + // A number of operations like MULH and integer divides are not supported by + // NEON but are available in SVE. for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) { setOperationAction(ISD::SDIV, VT, Custom); setOperationAction(ISD::UDIV, VT, Custom); + setOperationAction(ISD::MULHS, VT, Custom); + setOperationAction(ISD::MULHU, VT, Custom); } // NEON doesn't support 64-bit vector integer muls, but SVE does. 
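MULHS/MULHU are now marked Custom for every NEON integer vector type rather than only the 64-bit-element ones. As a reminder of the node semantics this covers, a scalar reference model (not the lowering itself):

#include <cassert>
#include <cstdint>

// Scalar reference semantics of MULHS/MULHU for 32-bit elements: the high
// half of the double-width product.
static int32_t mulhs32(int32_t a, int32_t b) {
  return static_cast<int32_t>((static_cast<int64_t>(a) * b) >> 32);
}
static uint32_t mulhu32(uint32_t a, uint32_t b) {
  return static_cast<uint32_t>((static_cast<uint64_t>(a) * b) >> 32);
}

int main() {
  assert(mulhs32(INT32_MIN, 2) == -1);   // (-2^31 * 2) >> 32
  assert(mulhu32(0x80000000u, 2) == 1);  // (2^31 * 2) >> 32
}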
@@ -1852,10 +1886,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::CTLZ, MVT::v1i64, Custom); setOperationAction(ISD::CTLZ, MVT::v2i64, Custom); setOperationAction(ISD::CTTZ, MVT::v1i64, Custom); - setOperationAction(ISD::MULHS, MVT::v1i64, Custom); - setOperationAction(ISD::MULHS, MVT::v2i64, Custom); - setOperationAction(ISD::MULHU, MVT::v1i64, Custom); - setOperationAction(ISD::MULHU, MVT::v2i64, Custom); setOperationAction(ISD::SMAX, MVT::v1i64, Custom); setOperationAction(ISD::SMAX, MVT::v2i64, Custom); setOperationAction(ISD::SMIN, MVT::v1i64, Custom); @@ -1877,8 +1907,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_AND, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); - setOperationAction(ISD::MULHS, VT, Custom); - setOperationAction(ISD::MULHU, VT, Custom); } // Use SVE for vectors with more than 2 elements. @@ -1921,6 +1949,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal); setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal); } + + // Handle floating-point partial reduction + if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) { + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32, + MVT::nxv8f16, Legal); + // We can use SVE2p1 fdot to emulate the fixed-length variant. + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::v4f32, + MVT::v8f16, Custom); + } } // Handle non-aliasing elements mask @@ -1956,10 +1993,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); // We can lower types that have <vscale x {2|4}> elements to compact. - for (auto VT : - {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32, - MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32}) + for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, + MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, + MVT::nxv4i32, MVT::nxv4f32}) { setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom); + // Use a custom lowering for masked stores that could be a supported + // compressing store. Note: These types still use the normal (Legal) + // lowering for non-compressing masked stores. + setOperationAction(ISD::MSTORE, VT, Custom); + } // If we have SVE, we can use SVE logic for legal (or smaller than legal) // NEON vectors in the lowest bits of the SVE register. @@ -2288,6 +2330,11 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { MVT::getVectorVT(MVT::i8, NumElts * 8), Custom); } + if (Subtarget->hasSVE2p1() && VT.getVectorElementType() == MVT::f32) { + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT, + MVT::getVectorVT(MVT::f16, NumElts * 2), Custom); + } + // Lower fixed length vector operations to scalable equivalents. setOperationAction(ISD::ABDS, VT, Default); setOperationAction(ISD::ABDU, VT, Default); @@ -2547,7 +2594,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant( return false; // Exit early if we demand all bits. - if (DemandedBits.popcount() == Size) + if (DemandedBits.isAllOnes()) return false; unsigned NewOpc; @@ -3863,22 +3910,30 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS, /// \param MustBeFirst Set to true if this subtree needs to be negated and we /// cannot do the negation naturally. 
We are required to /// emit the subtree first in this case. +/// \param PreferFirst Set to true if processing this subtree first may +/// result in more efficient code. /// \param WillNegate Is true if are called when the result of this /// subexpression must be negated. This happens when the /// outer expression is an OR. We can use this fact to know /// that we have a double negation (or (or ...) ...) that /// can be implemented for free. -static bool canEmitConjunction(const SDValue Val, bool &CanNegate, - bool &MustBeFirst, bool WillNegate, +static bool canEmitConjunction(SelectionDAG &DAG, const SDValue Val, + bool &CanNegate, bool &MustBeFirst, + bool &PreferFirst, bool WillNegate, unsigned Depth = 0) { if (!Val.hasOneUse()) return false; unsigned Opcode = Val->getOpcode(); if (Opcode == ISD::SETCC) { - if (Val->getOperand(0).getValueType() == MVT::f128) + EVT VT = Val->getOperand(0).getValueType(); + if (VT == MVT::f128) return false; CanNegate = true; MustBeFirst = false; + // Designate this operation as a preferred first operation if the result + // of a SUB operation can be reused. + PreferFirst = DAG.doesNodeExist(ISD::SUB, DAG.getVTList(VT), + {Val->getOperand(0), Val->getOperand(1)}); return true; } // Protect against exponential runtime and stack overflow. @@ -3890,11 +3945,15 @@ static bool canEmitConjunction(const SDValue Val, bool &CanNegate, SDValue O1 = Val->getOperand(1); bool CanNegateL; bool MustBeFirstL; - if (!canEmitConjunction(O0, CanNegateL, MustBeFirstL, IsOR, Depth+1)) + bool PreferFirstL; + if (!canEmitConjunction(DAG, O0, CanNegateL, MustBeFirstL, PreferFirstL, + IsOR, Depth + 1)) return false; bool CanNegateR; bool MustBeFirstR; - if (!canEmitConjunction(O1, CanNegateR, MustBeFirstR, IsOR, Depth+1)) + bool PreferFirstR; + if (!canEmitConjunction(DAG, O1, CanNegateR, MustBeFirstR, PreferFirstR, + IsOR, Depth + 1)) return false; if (MustBeFirstL && MustBeFirstR) @@ -3917,6 +3976,7 @@ static bool canEmitConjunction(const SDValue Val, bool &CanNegate, CanNegate = false; MustBeFirst = MustBeFirstL || MustBeFirstR; } + PreferFirst = PreferFirstL || PreferFirstR; return true; } return false; @@ -3978,19 +4038,25 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val, SDValue LHS = Val->getOperand(0); bool CanNegateL; bool MustBeFirstL; - bool ValidL = canEmitConjunction(LHS, CanNegateL, MustBeFirstL, IsOR); + bool PreferFirstL; + bool ValidL = canEmitConjunction(DAG, LHS, CanNegateL, MustBeFirstL, + PreferFirstL, IsOR); assert(ValidL && "Valid conjunction/disjunction tree"); (void)ValidL; SDValue RHS = Val->getOperand(1); bool CanNegateR; bool MustBeFirstR; - bool ValidR = canEmitConjunction(RHS, CanNegateR, MustBeFirstR, IsOR); + bool PreferFirstR; + bool ValidR = canEmitConjunction(DAG, RHS, CanNegateR, MustBeFirstR, + PreferFirstR, IsOR); assert(ValidR && "Valid conjunction/disjunction tree"); (void)ValidR; - // Swap sub-tree that must come first to the right side. - if (MustBeFirstL) { + bool ShouldFirstL = PreferFirstL && !PreferFirstR && !MustBeFirstR; + + // Swap sub-tree that must or should come first to the right side. 
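The new PreferFirst flag is a heuristic rather than a correctness requirement: when the DAG already contains a SUB of the compared operands, emitting that compare first lets the flag-setting SUBS double as the subtraction, whereas a CCMP produces no reusable value. A source-level shape where this can pay off (illustrative only, not taken from the patch's tests):

// With `a - b` needed both for the branch and as the return value, compiling
// the `a != b` leg of the conjunction first may allow a single SUBS to
// produce both the flags and the reusable difference.
int diffIfUnequalAndPositive(int a, int b, int c) {
  if (a != b && c > 0)
    return a - b;
  return 0;
}

Compiling such a function for AArch64 at -O2 is one way to inspect how the conjunction is ordered; the exact output depends on the rest of the pipeline.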
+ if (MustBeFirstL || ShouldFirstL) { assert(!MustBeFirstR && "Valid conjunction/disjunction tree"); std::swap(LHS, RHS); std::swap(CanNegateL, CanNegateR); @@ -4046,7 +4112,9 @@ static SDValue emitConjunction(SelectionDAG &DAG, SDValue Val, AArch64CC::CondCode &OutCC) { bool DummyCanNegate; bool DummyMustBeFirst; - if (!canEmitConjunction(Val, DummyCanNegate, DummyMustBeFirst, false)) + bool DummyPreferFirst; + if (!canEmitConjunction(DAG, Val, DummyCanNegate, DummyMustBeFirst, + DummyPreferFirst, false)) return SDValue(); return emitConjunctionRec(DAG, Val, OutCC, false, SDValue(), AArch64CC::AL); @@ -4492,6 +4560,26 @@ static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG, return DAG.getMergeValues({Sum, OutFlag}, DL); } +static SDValue lowerIntNeonIntrinsic(SDValue Op, unsigned Opcode, + SelectionDAG &DAG) { + SDLoc DL(Op); + auto getFloatVT = [](EVT VT) { + assert((VT == MVT::i32 || VT == MVT::i64) && "Unexpected VT"); + return VT == MVT::i32 ? MVT::f32 : MVT::f64; + }; + auto bitcastToFloat = [&](SDValue Val) { + return DAG.getBitcast(getFloatVT(Val.getValueType()), Val); + }; + SmallVector<SDValue, 2> NewOps; + NewOps.reserve(Op.getNumOperands() - 1); + + for (unsigned I = 1, E = Op.getNumOperands(); I < E; ++I) + NewOps.push_back(bitcastToFloat(Op.getOperand(I))); + EVT OrigVT = Op.getValueType(); + SDValue OpNode = DAG.getNode(Opcode, DL, getFloatVT(OrigVT), NewOps); + return DAG.getBitcast(OrigVT, OpNode); +} + static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { // Let legalize expand this if it isn't a legal type yet. if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType())) @@ -5346,35 +5434,6 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op, return SDValue(); } -SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op, - SelectionDAG &DAG) const { - // For iOS, we want to call an alternative entry point: __sincos_stret, - // which returns the values in two S / D registers. - SDLoc DL(Op); - SDValue Arg = Op.getOperand(0); - EVT ArgVT = Arg.getValueType(); - Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext()); - - ArgListTy Args; - Args.emplace_back(Arg, ArgTy); - - RTLIB::Libcall LC = ArgVT == MVT::f64 ? RTLIB::SINCOS_STRET_F64 - : RTLIB::SINCOS_STRET_F32; - const char *LibcallName = getLibcallName(LC); - SDValue Callee = - DAG.getExternalSymbol(LibcallName, getPointerTy(DAG.getDataLayout())); - - StructType *RetTy = StructType::get(ArgTy, ArgTy); - TargetLowering::CallLoweringInfo CLI(DAG); - CallingConv::ID CC = getLibcallCallingConv(LC); - CLI.setDebugLoc(DL) - .setChain(DAG.getEntryNode()) - .setLibCallee(CC, RetTy, Callee, std::move(Args)); - - std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI); - return CallResult.first; -} - static MVT getSVEContainerType(EVT ContentTy); SDValue @@ -5578,9 +5637,10 @@ SDValue AArch64TargetLowering::LowerGET_ROUNDING(SDValue Op, SDLoc DL(Op); SDValue Chain = Op.getOperand(0); - SDValue FPCR_64 = DAG.getNode( - ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, - {Chain, DAG.getConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)}); + SDValue FPCR_64 = + DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, + {Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, + MVT::i64)}); Chain = FPCR_64.getValue(1); SDValue FPCR_32 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPCR_64); SDValue FltRounds = DAG.getNode(ISD::ADD, DL, MVT::i32, FPCR_32, @@ -5666,7 +5726,8 @@ SDValue AArch64TargetLowering::LowerSET_FPMODE(SDValue Op, // Set new value of FPCR. 
SDValue Ops2[] = { - Chain, DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), FPCR}; + Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), + FPCR}; return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2); } @@ -5689,9 +5750,9 @@ SDValue AArch64TargetLowering::LowerRESET_FPMODE(SDValue Op, DAG.getConstant(AArch64::ReservedFPControlBits, DL, MVT::i64)); // Set new value of FPCR. - SDValue Ops2[] = {Chain, - DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), - FPSCRMasked}; + SDValue Ops2[] = { + Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), + FPSCRMasked}; return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2); } @@ -5769,8 +5830,10 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const { if (VT.is64BitVector()) { if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && isNullConstant(N0.getOperand(1)) && + N0.getOperand(0).getValueType().is128BitVector() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && - isNullConstant(N1.getOperand(1))) { + isNullConstant(N1.getOperand(1)) && + N1.getOperand(0).getValueType().is128BitVector()) { N0 = N0.getOperand(0); N1 = N1.getOperand(0); VT = N0.getValueType(); @@ -6363,26 +6426,46 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Op.getOperand(1).getValueType(), Op.getOperand(1), Op.getOperand(2))); return SDValue(); + case Intrinsic::aarch64_neon_sqrshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQRSHL, DAG); + case Intrinsic::aarch64_neon_sqshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSHL, DAG); + case Intrinsic::aarch64_neon_uqrshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQRSHL, DAG); + case Intrinsic::aarch64_neon_uqshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSHL, DAG); case Intrinsic::aarch64_neon_sqadd: if (Op.getValueType().isVector()) return DAG.getNode(ISD::SADDSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQADD, DAG); + case Intrinsic::aarch64_neon_sqsub: if (Op.getValueType().isVector()) return DAG.getNode(ISD::SSUBSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSUB, DAG); + case Intrinsic::aarch64_neon_uqadd: if (Op.getValueType().isVector()) return DAG.getNode(ISD::UADDSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQADD, DAG); case Intrinsic::aarch64_neon_uqsub: if (Op.getValueType().isVector()) return DAG.getNode(ISD::USUBSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSUB, DAG); + case Intrinsic::aarch64_neon_sqdmulls_scalar: + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQDMULL, DAG); case Intrinsic::aarch64_sve_whilelt: return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/true, /*IsEqual=*/false); @@ -6416,9 +6499,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_lastb: return DAG.getNode(AArch64ISD::LASTB, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - case Intrinsic::aarch64_sve_rev: - return DAG.getNode(ISD::VECTOR_REVERSE, DL, Op.getValueType(), - 
Op.getOperand(1)); case Intrinsic::aarch64_sve_tbl: return DAG.getNode(AArch64ISD::TBL, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); @@ -6744,8 +6824,34 @@ bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend, return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2; } +/// Helper function to check if a small vector load can be optimized. +static bool isEligibleForSmallVectorLoadOpt(LoadSDNode *LD, + const AArch64Subtarget &Subtarget) { + if (!Subtarget.isNeonAvailable()) + return false; + if (LD->isVolatile()) + return false; + + EVT MemVT = LD->getMemoryVT(); + if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8 && MemVT != MVT::v2i16) + return false; + + Align Alignment = LD->getAlign(); + Align RequiredAlignment = Align(MemVT.getStoreSize().getFixedValue()); + if (Subtarget.requiresStrictAlign() && Alignment < RequiredAlignment) + return false; + + return true; +} + bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { EVT ExtVT = ExtVal.getValueType(); + // Small, illegal vectors can be extended inreg. + if (auto *Load = dyn_cast<LoadSDNode>(ExtVal.getOperand(0))) { + if (ExtVT.isFixedLengthVector() && ExtVT.getStoreSizeInBits() <= 128 && + isEligibleForSmallVectorLoadOpt(Load, *Subtarget)) + return true; + } if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors()) return false; @@ -7204,12 +7310,86 @@ SDValue AArch64TargetLowering::LowerStore128(SDValue Op, return Result; } +/// Helper function to optimize loads of extended small vectors. +/// These patterns would otherwise get scalarized into inefficient sequences. +static SDValue tryLowerSmallVectorExtLoad(LoadSDNode *Load, SelectionDAG &DAG) { + const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); + if (!isEligibleForSmallVectorLoadOpt(Load, Subtarget)) + return SDValue(); + + EVT MemVT = Load->getMemoryVT(); + EVT ResVT = Load->getValueType(0); + unsigned NumElts = ResVT.getVectorNumElements(); + unsigned DstEltBits = ResVT.getScalarSizeInBits(); + unsigned SrcEltBits = MemVT.getScalarSizeInBits(); + + unsigned ExtOpcode; + switch (Load->getExtensionType()) { + case ISD::EXTLOAD: + case ISD::ZEXTLOAD: + ExtOpcode = ISD::ZERO_EXTEND; + break; + case ISD::SEXTLOAD: + ExtOpcode = ISD::SIGN_EXTEND; + break; + case ISD::NON_EXTLOAD: + return SDValue(); + } + + SDLoc DL(Load); + SDValue Chain = Load->getChain(); + SDValue BasePtr = Load->getBasePtr(); + const MachinePointerInfo &PtrInfo = Load->getPointerInfo(); + Align Alignment = Load->getAlign(); + + // Load the data as an FP scalar to avoid issues with integer loads. 
+ unsigned LoadBits = MemVT.getStoreSizeInBits(); + MVT ScalarLoadType = MVT::getFloatingPointVT(LoadBits); + SDValue ScalarLoad = + DAG.getLoad(ScalarLoadType, DL, Chain, BasePtr, PtrInfo, Alignment); + + MVT ScalarToVecTy = MVT::getVectorVT(ScalarLoadType, 128 / LoadBits); + SDValue ScalarToVec = + DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarToVecTy, ScalarLoad); + MVT BitcastTy = + MVT::getVectorVT(MVT::getIntegerVT(SrcEltBits), 128 / SrcEltBits); + SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, BitcastTy, ScalarToVec); + + SDValue Res = Bitcast; + unsigned CurrentEltBits = Res.getValueType().getScalarSizeInBits(); + unsigned CurrentNumElts = Res.getValueType().getVectorNumElements(); + while (CurrentEltBits < DstEltBits) { + if (Res.getValueSizeInBits() >= 128) { + CurrentNumElts = CurrentNumElts / 2; + MVT ExtractVT = + MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Res, + DAG.getConstant(0, DL, MVT::i64)); + } + CurrentEltBits = CurrentEltBits * 2; + MVT ExtVT = + MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts); + Res = DAG.getNode(ExtOpcode, DL, ExtVT, Res); + } + + if (CurrentNumElts != NumElts) { + MVT FinalVT = MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), NumElts); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FinalVT, Res, + DAG.getConstant(0, DL, MVT::i64)); + } + + return DAG.getMergeValues({Res, ScalarLoad.getValue(1)}, DL); +} + SDValue AArch64TargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); LoadSDNode *LoadNode = cast<LoadSDNode>(Op); assert(LoadNode && "Expected custom lowering of a load node"); + if (SDValue Result = tryLowerSmallVectorExtLoad(LoadNode, DAG)) + return Result; + if (LoadNode->getMemoryVT() == MVT::i64x8) { SmallVector<SDValue, 8> Ops; SDValue Base = LoadNode->getBasePtr(); @@ -7228,37 +7408,38 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op, return DAG.getMergeValues({Loaded, Chain}, DL); } - // Custom lowering for extending v4i8 vector loads. - EVT VT = Op->getValueType(0); - assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32"); - - if (LoadNode->getMemoryVT() != MVT::v4i8) - return SDValue(); - - // Avoid generating unaligned loads. - if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4)) - return SDValue(); + return SDValue(); +} - unsigned ExtType; - if (LoadNode->getExtensionType() == ISD::SEXTLOAD) - ExtType = ISD::SIGN_EXTEND; - else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD || - LoadNode->getExtensionType() == ISD::EXTLOAD) - ExtType = ISD::ZERO_EXTEND; - else - return SDValue(); +// Convert to ContainerVT with no-op casts where possible. +static SDValue convertToSVEContainerType(SDLoc DL, SDValue Vec, EVT ContainerVT, + SelectionDAG &DAG) { + EVT VecVT = Vec.getValueType(); + if (VecVT.isFloatingPoint()) { + // Use no-op casts for floating-point types. + EVT PackedVT = getPackedSVEVectorVT(VecVT.getScalarType()); + Vec = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedVT, Vec); + Vec = DAG.getNode(AArch64ISD::NVCAST, DL, ContainerVT, Vec); + } else { + // Extend integers (may not be a no-op). 
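A standalone sketch (plain integers instead of MVTs; names are illustrative) of the widen-in-steps loop in tryLowerSmallVectorExtLoad above: starting from the 128-bit bitcast of the scalar FP load, the vector is repeatedly halved in lane count and doubled in lane width until the destination element size is reached, so a sign-extending v2i8 -> v2i64 load walks v16i8 -> v8i16 -> v4i32 -> v2i64.

#include <cstdio>

// Model of the step-wise widening: halve the lane count whenever the vector
// is already 128 bits wide, then double the element width, until the
// destination element width is reached.
int main() {
  unsigned SrcEltBits = 8, DstEltBits = 64, NumElts = 2; // v2i8 -> v2i64
  unsigned EltBits = SrcEltBits;
  unsigned Lanes = 128 / SrcEltBits; // lanes in the 128-bit bitcast (v16i8)
  std::printf("start: v%ui%u\n", Lanes, EltBits);
  while (EltBits < DstEltBits) {
    if (Lanes * EltBits >= 128)
      Lanes /= 2;        // EXTRACT_SUBVECTOR of the low half
    EltBits *= 2;        // {SIGN,ZERO}_EXTEND to the next width
    std::printf("step:  v%ui%u\n", Lanes, EltBits);
  }
  if (Lanes != NumElts)
    std::printf("final EXTRACT_SUBVECTOR to v%ui%u\n", NumElts, EltBits);
}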
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec); + } + return Vec; +} - SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(), - LoadNode->getBasePtr(), MachinePointerInfo()); - SDValue Chain = Load.getValue(1); - SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load); - SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec); - SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC); - Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext, - DAG.getConstant(0, DL, MVT::i64)); - if (VT == MVT::v4i32) - Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext); - return DAG.getMergeValues({Ext, Chain}, DL); +// Convert to VecVT with no-op casts where possible. +static SDValue convertFromSVEContainerType(SDLoc DL, SDValue Vec, EVT VecVT, + SelectionDAG &DAG) { + if (VecVT.isFloatingPoint()) { + // Use no-op casts for floating-point types. + EVT PackedVT = getPackedSVEVectorVT(VecVT.getScalarType()); + Vec = DAG.getNode(AArch64ISD::NVCAST, DL, PackedVT, Vec); + Vec = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VecVT, Vec); + } else { + // Truncate integers (may not be a no-op). + Vec = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Vec); + } + return Vec; } SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op, @@ -7312,49 +7493,49 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op, // Get legal type for compact instruction EVT ContainerVT = getSVEContainerType(VecVT); - EVT CastVT = VecVT.changeVectorElementTypeToInteger(); - // Convert to i32 or i64 for smaller types, as these are the only supported + // Convert to 32 or 64 bits for smaller types, as these are the only supported // sizes for compact. - if (ContainerVT != VecVT) { - Vec = DAG.getBitcast(CastVT, Vec); - Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec); - } + Vec = convertToSVEContainerType(DL, Vec, ContainerVT, DAG); SDValue Compressed = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(), - DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec); + DAG.getTargetConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, + Vec); // compact fills with 0s, so if our passthru is all 0s, do nothing here. if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) { SDValue Offset = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, - DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask); + DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, + Mask); SDValue IndexMask = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, MaskVT, - DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64), + DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64), DAG.getConstant(0, DL, MVT::i64), Offset); Compressed = DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru); } + // If we changed the element type before, we need to convert it back. + if (ElmtVT.isFloatingPoint()) + Compressed = convertFromSVEContainerType(DL, Compressed, VecVT, DAG); + // Extracting from a legal SVE type before truncating produces better code. if (IsFixedLength) { - Compressed = DAG.getNode( - ISD::EXTRACT_SUBVECTOR, DL, - FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()), - Compressed, DAG.getConstant(0, DL, MVT::i64)); - CastVT = FixedVecVT.changeVectorElementTypeToInteger(); + EVT FixedSubVector = VecVT.isInteger() + ? 
FixedVecVT.changeVectorElementType( + ContainerVT.getVectorElementType()) + : FixedVecVT; + Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedSubVector, + Compressed, DAG.getConstant(0, DL, MVT::i64)); VecVT = FixedVecVT; } - // If we changed the element type before, we need to convert it back. - if (ContainerVT != VecVT) { - Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed); - Compressed = DAG.getBitcast(VecVT, Compressed); - } + if (VecVT.isInteger()) + Compressed = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Compressed); return Compressed; } @@ -7462,10 +7643,10 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) { DAG.getUNDEF(ExpVT), Exp, Zero); SDValue VPg = getPTrue(DAG, DL, XVT.changeVectorElementType(MVT::i1), AArch64SVEPredPattern::all); - SDValue FScale = - DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XVT, - DAG.getConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64), - VPg, VX, VExp); + SDValue FScale = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, XVT, + DAG.getTargetConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64), VPg, + VX, VExp); SDValue Final = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero); if (X.getValueType() != XScalarTy) @@ -7552,6 +7733,117 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op, EndOfTrmp); } +SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + if (VT.getScalarType() != MVT::bf16 || + (Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available())) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + + assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering"); + assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) && + "Unexpected FMUL VT"); + + auto MakeGetIntrinsic = [&](Intrinsic::ID IID) { + return [&, IID](EVT VT, auto... Ops) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Ops...); + }; + }; + + auto Reinterpret = [&](SDValue Value, EVT VT) { + EVT SrcVT = Value.getValueType(); + if (VT == SrcVT) + return Value; + if (SrcVT.isFixedLengthVector()) + return convertToScalableVector(DAG, VT, Value); + if (VT.isFixedLengthVector()) + return convertFromScalableVector(DAG, VT, Value); + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value); + }; + + bool UseSVEBFMLAL = VT.isScalableVector(); + auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); + auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); + + // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant. + // This does not match BFCVTN[2], so we use SVE to convert back to bf16. + auto BFMLALB = + MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb + : Intrinsic::aarch64_neon_bfmlalb); + auto BFMLALT = + MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt + : Intrinsic::aarch64_neon_bfmlalt); + + EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32; + SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags()); + SDValue Pg = getPredicateForVector(DAG, DL, AccVT); + + // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom + // instructions. These result in two f32 vectors, which can be converted back + // to bf16 with FCVT and FCVTNT. + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + + // All SVE intrinsics expect to operate on full bf16 vector types. 
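A scalar model of what one lane of the custom bf16 FMUL lowering computes (not the DAG code itself): both operands are widened to f32, which is exact, multiplied, and the product is narrowed back to bf16. In the lowering the widening multiply is done by BFMLALB/BFMLALT against a zero accumulator and the narrowing by FCVT/FCVTNT; the truncating narrow below is a simplification that ignores the rounding those instructions perform and is only meant to show the data flow.

#include <cassert>
#include <cstdint>
#include <cstring>

// Scalar model of one bf16 FMUL lane: widen to f32 (exact), multiply,
// narrow back to bf16. Truncating narrow; rounding mode is ignored.
static float bf16ToF32(uint16_t B) {
  uint32_t Bits = static_cast<uint32_t>(B) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}
static uint16_t f32ToBF16Trunc(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return static_cast<uint16_t>(Bits >> 16);
}
static uint16_t bf16Mul(uint16_t A, uint16_t B) {
  return f32ToBF16Trunc(bf16ToF32(A) * bf16ToF32(B));
}

int main() {
  assert(bf16Mul(0x3FC0, 0x3FC0) == 0x4010); // 1.5 * 1.5 == 2.25, exact in bf16
}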
+ if (UseSVEBFMLAL) { + LHS = Reinterpret(LHS, MVT::nxv8bf16); + RHS = Reinterpret(RHS, MVT::nxv8bf16); + } + + SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32); + SDValue BottomBF16 = + FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32); + // Note: nxv4bf16 only uses even lanes. + if (VT == MVT::nxv4bf16) + return Reinterpret(BottomBF16, VT); + + SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32); + SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32); + return Reinterpret(TopBF16, VT); +} + +SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const { + SDValue OpA = Op->getOperand(0); + SDValue OpB = Op->getOperand(1); + SDValue OpC = Op->getOperand(2); + EVT VT = Op.getValueType(); + SDLoc DL(Op); + + assert(VT.isVector() && "Scalar fma lowering should be handled by patterns"); + + // Bail early if we're definitely not looking to merge FNEGs into the FMA. + if (VT != MVT::v8f16 && VT != MVT::v4f32 && VT != MVT::v2f64) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); + + if (OpC.getOpcode() != ISD::FNEG) + return useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) + ? LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED) + : Op; // Fallback to NEON lowering. + + // Convert FMA/FNEG nodes to SVE to enable the following patterns: + // fma(a, b, neg(c)) -> fnmls(a, b, c) + // fma(neg(a), b, neg(c)) -> fnmla(a, b, c) + // fma(a, neg(b), neg(c)) -> fnmla(a, b, c) + SDValue Pg = getPredicateForVector(DAG, DL, VT); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + + auto ConvertToScalableFnegMt = [&](SDValue Op) { + if (Op.getOpcode() == ISD::FNEG) + Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU); + return convertToScalableVector(DAG, ContainerVT, Op); + }; + + OpA = ConvertToScalableFnegMt(OpA); + OpB = ConvertToScalableFnegMt(OpB); + OpC = ConvertToScalableFnegMt(OpC); + + SDValue ScalableRes = + DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC); + return convertFromScalableVector(DAG, VT, ScalableRes); +} + SDValue AArch64TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { LLVM_DEBUG(dbgs() << "Custom lowering: "); @@ -7626,9 +7918,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::FSUB: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED); case ISD::FMUL: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + return LowerFMUL(Op, DAG); case ISD::FMA: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); + return LowerFMA(Op, DAG); case ISD::FDIV: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED); case ISD::FNEG: @@ -7673,6 +7965,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerEXTRACT_VECTOR_ELT(Op, DAG); case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG); + case ISD::ANY_EXTEND_VECTOR_INREG: + case ISD::SIGN_EXTEND_VECTOR_INREG: + return LowerEXTEND_VECTOR_INREG(Op, DAG); case ISD::ZERO_EXTEND_VECTOR_INREG: return LowerZERO_EXTEND_VECTOR_INREG(Op, DAG); case ISD::VECTOR_SHUFFLE: @@ -7723,8 +8018,6 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::FP_TO_SINT_SAT: case ISD::FP_TO_UINT_SAT: return LowerFP_TO_INT_SAT(Op, DAG); - case ISD::FSINCOS: - return LowerFSINCOS(Op, DAG); case ISD::GET_ROUNDING: return LowerGET_ROUNDING(Op, DAG); case ISD::SET_ROUNDING: @@ -7756,7 +8049,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::STORE: return LowerSTORE(Op, DAG); case ISD::MSTORE: 
- return LowerFixedLengthVectorMStoreToSVE(Op, DAG); + return LowerMSTORE(Op, DAG); case ISD::MGATHER: return LowerMGATHER(Op, DAG); case ISD::MSCATTER: @@ -7911,6 +8204,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: return LowerPARTIAL_REDUCE_MLA(Op, DAG); } } @@ -8130,7 +8424,7 @@ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL, TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout())); SDValue TPIDR2_EL0 = DAG.getNode( ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); + DAG.getTargetConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); // Copy the address of the TPIDR2 block into X0 before 'calling' the // RESTORE_ZA pseudo. SDValue Glue; @@ -8145,7 +8439,7 @@ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL, // Finally reset the TPIDR2_EL0 register to 0. Chain = DAG.getNode( ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), + DAG.getTargetConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), DAG.getConstant(0, DL, MVT::i64)); TPIDR2.Uses++; return Chain; @@ -8462,7 +8756,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments( Subtarget->isWindowsArm64EC()) && "Indirect arguments should be scalable on most subtargets"); - uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue(); + TypeSize PartSize = VA.getValVT().getStoreSize(); unsigned NumParts = 1; if (Ins[i].Flags.isInConsecutiveRegs()) { while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast()) @@ -8479,16 +8773,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments( InVals.push_back(ArgValue); NumParts--; if (NumParts > 0) { - SDValue BytesIncrement; - if (PartLoad.isScalableVector()) { - BytesIncrement = DAG.getVScale( - DL, Ptr.getValueType(), - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize)); - } else { - BytesIncrement = DAG.getConstant( - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL, - Ptr.getValueType()); - } + SDValue BytesIncrement = + DAG.getTypeSize(DL, Ptr.getValueType(), PartSize); Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement, SDNodeFlags::NoUnsignedWrap); ExtraArgLocs++; @@ -8735,15 +9021,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments( } } - if (getTM().useNewSMEABILowering()) { - // Clear new ZT0 state. TODO: Move this to the SME ABI pass. - if (Attrs.isNewZT0()) - Chain = DAG.getNode( - ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32), - DAG.getTargetConstant(0, DL, MVT::i32)); - } - return Chain; } @@ -9028,11 +9305,12 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( CallingConv::ID CallerCC = CallerF.getCallingConv(); // SME Streaming functions are not eligible for TCO as they may require - // the streaming mode or ZA to be restored after returning from the call. + // the streaming mode or ZA/ZT0 to be restored after returning from the call. 
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, getRuntimeLibcallsInfo(), CLI); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState() || + CallAttrs.requiresPreservingZT0() || CallAttrs.caller().hasStreamingBody()) return false; @@ -9465,6 +9743,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState()) ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE; + else if (CallAttrs.requiresPreservingZT0()) + ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE; else if (CallAttrs.caller().hasZAState() || CallAttrs.caller().hasZT0State()) ZAMarkerNode = AArch64ISD::INOUT_ZA_USE; @@ -9552,7 +9832,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); Chain = DAG.getNode( ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), + DAG.getTargetConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), TPIDR2ObjAddr); OptimizationRemarkEmitter ORE(&MF.getFunction()); ORE.emit([&]() { @@ -9584,7 +9864,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, SDValue ZTFrameIdx; MachineFrameInfo &MFI = MF.getFrameInfo(); - bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0(); + bool ShouldPreserveZT0 = + !UseNewSMEABILowering && CallAttrs.requiresPreservingZT0(); // If the caller has ZT0 state which will not be preserved by the callee, // spill ZT0 before the call. @@ -9597,7 +9878,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // If caller shares ZT0 but the callee is not shared ZA, we need to stop // PSTATE.ZA before the call if there is no lazy-save active. - bool DisableZA = CallAttrs.requiresDisablingZABeforeCall(); + bool DisableZA = + !UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall(); assert((!DisableZA || !RequiresLazySave) && "Lazy-save should have PSTATE.SM=1 on entry to the function"); @@ -9616,8 +9898,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // using a chain can result in incorrect scheduling. The markers refer to // the position just before the CALLSEQ_START (though occur after as // CALLSEQ_START lacks in-glue). 
- Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other), - {Chain, Chain.getValue(1)}); + Chain = + DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other, MVT::Glue), + {Chain, Chain.getValue(1)}); } } @@ -9698,8 +9981,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, assert((isScalable || Subtarget->isWindowsArm64EC()) && "Indirect arguments should be scalable on most subtargets"); - uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue(); - uint64_t PartSize = StoreSize; + TypeSize StoreSize = VA.getValVT().getStoreSize(); + TypeSize PartSize = StoreSize; unsigned NumParts = 1; if (Outs[i].Flags.isInConsecutiveRegs()) { while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast()) @@ -9710,7 +9993,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext()); Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty); MachineFrameInfo &MFI = MF.getFrameInfo(); - int FI = MFI.CreateStackObject(StoreSize, Alignment, false); + int FI = + MFI.CreateStackObject(StoreSize.getKnownMinValue(), Alignment, false); if (isScalable) { bool IsPred = VA.getValVT() == MVT::aarch64svcount || VA.getValVT().getVectorElementType() == MVT::i1; @@ -9731,16 +10015,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, NumParts--; if (NumParts > 0) { - SDValue BytesIncrement; - if (isScalable) { - BytesIncrement = DAG.getVScale( - DL, Ptr.getValueType(), - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize)); - } else { - BytesIncrement = DAG.getConstant( - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL, - Ptr.getValueType()); - } + SDValue BytesIncrement = + DAG.getTypeSize(DL, Ptr.getValueType(), PartSize); MPI = MachinePointerInfo(MPI.getAddrSpace()); Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement, SDNodeFlags::NoUnsignedWrap); @@ -10033,6 +10309,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (InGlue.getNode()) Ops.push_back(InGlue); + if (CLI.DeactivationSymbol) + Ops.push_back(DAG.getDeactivationSymbol(CLI.DeactivationSymbol)); + // If we're doing a tall call, use a TC_RETURN here rather than an // actual call instruction. if (IsTailCall) { @@ -10082,7 +10361,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, getSMToggleCondition(CallAttrs)); } - if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()) + if (!UseNewSMEABILowering && + (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())) // Unconditionally resume ZA. Result = DAG.getNode( AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result, @@ -10622,16 +10902,41 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr, const SDLoc &DL, SelectionDAG &DAG) const { EVT PtrVT = getPointerTy(DAG.getDataLayout()); + auto &MF = DAG.getMachineFunction(); + auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); + SDValue Glue; SDValue Chain = DAG.getEntryNode(); SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue); + SMECallAttrs TLSCallAttrs(FuncInfo->getSMEFnAttrs(), {}, SMEAttrs::Normal); + bool RequiresSMChange = TLSCallAttrs.requiresSMChange(); + + auto ChainAndGlue = [](SDValue Chain) -> std::pair<SDValue, SDValue> { + return {Chain, Chain.getValue(1)}; + }; + + if (RequiresSMChange) + std::tie(Chain, Glue) = + ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue, + getSMToggleCondition(TLSCallAttrs))); + unsigned Opcode = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT() ? 
AArch64ISD::TLSDESC_AUTH_CALLSEQ : AArch64ISD::TLSDESC_CALLSEQ; - Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr}); - SDValue Glue = Chain.getValue(1); + SDValue Ops[] = {Chain, SymAddr, Glue}; + std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode( + Opcode, DL, NodeTys, Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back())); + + if (TLSCallAttrs.requiresLazySave()) + std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode( + AArch64ISD::REQUIRES_ZA_SAVE, DL, NodeTys, {Chain, Chain.getValue(1)})); + + if (RequiresSMChange) + std::tie(Chain, Glue) = + ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue, + getSMToggleCondition(TLSCallAttrs))); return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue); } @@ -11366,9 +11671,10 @@ SDValue AArch64TargetLowering::LowerMinMax(SDValue Op, break; } + // Note: This lowering only overrides NEON for v1i64 and v2i64, where we + // prefer using SVE if available. if (VT.isScalableVector() || - useSVEForFixedLengthVectorVT( - VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors())) { + useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true)) { switch (Opcode) { default: llvm_unreachable("Wrong instruction"); @@ -11539,7 +11845,12 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { } if (LHS.getValueType().isInteger()) { - + if (Subtarget->hasCSSC() && CC == ISD::SETNE && isNullConstant(RHS)) { + SDValue One = DAG.getConstant(1, DL, LHS.getValueType()); + SDValue UMin = DAG.getNode(ISD::UMIN, DL, LHS.getValueType(), LHS, One); + SDValue Res = DAG.getZExtOrTrunc(UMin, DL, VT); + return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res; + } simplifySetCCIntoEq(CC, LHS, RHS, DAG, DL); SDValue CCVal; @@ -13443,8 +13754,8 @@ SDValue ReconstructShuffleWithRuntimeMask(SDValue Op, SelectionDAG &DAG) { return DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, VT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), SourceVec, - MaskSourceVec); + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), + SourceVec, MaskSourceVec); } // Gather data to see if the operation can be modelled as a @@ -14300,14 +14611,16 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask, V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V1Cst); Shuffle = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst, + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), + V1Cst, DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen))); } else { if (IndexLen == 8) { V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V2Cst); Shuffle = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst, + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), + V1Cst, DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen))); } else { // FIXME: We cannot, for the moment, emit a TBL2 instruction because we @@ -14318,8 +14631,8 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask, // IndexLen)); Shuffle = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), V1Cst, - V2Cst, + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), + V1Cst, V2Cst, DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen))); } } @@ -14487,6 +14800,40 @@ static SDValue tryToConvertShuffleOfTbl2ToTbl4(SDValue Op, Tbl2->getOperand(1), 
Tbl2->getOperand(2), TBLMask}); } +SDValue +AArch64TargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + assert(VT.isScalableVector() && "Unexpected result type!"); + + bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG; + unsigned UnpackOpcode = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO; + + // Repeatedly unpack Val until the result is of the desired type. + SDValue Val = Op.getOperand(0); + switch (Val.getSimpleValueType().SimpleTy) { + default: + return SDValue(); + case MVT::nxv16i8: + Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv8i16, Val); + if (VT == MVT::nxv8i16) + break; + [[fallthrough]]; + case MVT::nxv8i16: + Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv4i32, Val); + if (VT == MVT::nxv4i32) + break; + [[fallthrough]]; + case MVT::nxv4i32: + Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv2i64, Val); + assert(VT == MVT::nxv2i64 && "Unexpected result type!"); + break; + } + + return Val; +} + // Baseline legalization for ZERO_EXTEND_VECTOR_INREG will blend-in zeros, // but we don't have an appropriate instruction, // so custom-lower it as ZIP1-with-zeros. @@ -14495,6 +14842,10 @@ AArch64TargetLowering::LowerZERO_EXTEND_VECTOR_INREG(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); EVT VT = Op.getValueType(); + + if (VT.isScalableVector()) + return LowerEXTEND_VECTOR_INREG(Op, DAG); + SDValue SrcOp = Op.getOperand(0); EVT SrcVT = SrcOp.getValueType(); assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 && @@ -14604,17 +14955,20 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, } unsigned WhichResult; - if (isZIPMask(ShuffleMask, NumElts, WhichResult)) { + unsigned OperandOrder; + if (isZIPMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2; - return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2); + return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2, + OperandOrder == 0 ? V2 : V1); } if (isUZPMask(ShuffleMask, NumElts, WhichResult)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2; return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2); } - if (isTRNMask(ShuffleMask, NumElts, WhichResult)) { + if (isTRNMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2; - return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2); + return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2, + OperandOrder == 0 ? 
V2 : V1); } if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) { @@ -16326,9 +16680,9 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { isREVMask(M, EltSize, NumElts, 16) || isEXTMask(M, VT, DummyBool, DummyUnsigned) || isSingletonEXTMask(M, VT, DummyUnsigned) || - isTRNMask(M, NumElts, DummyUnsigned) || + isTRNMask(M, NumElts, DummyUnsigned, DummyUnsigned) || isUZPMask(M, NumElts, DummyUnsigned) || - isZIPMask(M, NumElts, DummyUnsigned) || + isZIPMask(M, NumElts, DummyUnsigned, DummyUnsigned) || isTRN_v_undef_Mask(M, VT, DummyUnsigned) || isUZP_v_undef_Mask(M, VT, DummyUnsigned) || isZIP_v_undef_Mask(M, VT, DummyUnsigned) || @@ -16472,10 +16826,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op, if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0), DAG.getTargetConstant(Cnt, DL, MVT::i32)); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, - DAG.getConstant(Intrinsic::aarch64_neon_ushl, DL, - MVT::i32), - Op.getOperand(0), Op.getOperand(1)); + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getTargetConstant(Intrinsic::aarch64_neon_ushl, DL, MVT::i32), + Op.getOperand(0), Op.getOperand(1)); case ISD::SRA: case ISD::SRL: if (VT.isScalableVector() && @@ -16977,7 +17331,7 @@ SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op, template <unsigned NumVecs> static bool setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL, - AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) { + AArch64TargetLowering::IntrinsicInfo &Info, const CallBase &CI) { Info.opc = ISD::INTRINSIC_VOID; // Retrieve EC from first vector argument. const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType()); @@ -17002,7 +17356,7 @@ setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL, /// MemIntrinsicNodes. The associated MachineMemOperands record the alignment /// specified in the intrinsic calls. bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, - const CallInst &I, + const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const { auto &DL = I.getDataLayout(); @@ -17590,6 +17944,7 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion( // udot instruction. if (SrcWidth * 4 <= DstWidth) { if (all_of(I->users(), [&](auto *U) { + using namespace llvm::PatternMatch; auto *SingleUser = cast<Instruction>(&*U); if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value())))) return true; @@ -17861,6 +18216,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad( // into shift / and masks. For the moment we do this just for uitofp (not // zext) to avoid issues with widening instructions. 
if (Shuffles.size() == 4 && all_of(Shuffles, [](ShuffleVectorInst *SI) { + using namespace llvm::PatternMatch; return SI->hasOneUse() && match(SI->user_back(), m_UIToFP(m_Value())) && SI->getType()->getScalarSizeInBits() * 4 == SI->user_back()->getType()->getScalarSizeInBits(); @@ -18569,7 +18925,7 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd( case MVT::f64: return true; case MVT::bf16: - return VT.isScalableVector() && Subtarget->hasSVEB16B16() && + return VT.isScalableVector() && Subtarget->hasBF16() && Subtarget->isNonStreamingSVEorSME2Available(); default: break; @@ -18752,6 +19108,15 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, return (Index == 0 || Index == ResVT.getVectorMinNumElements()); } +bool AArch64TargetLowering::shouldOptimizeMulOverflowWithZeroHighBits( + LLVMContext &Context, EVT VT) const { + if (getTypeAction(Context, VT) != TypeExpandInteger) + return false; + + EVT LegalTy = EVT::getIntegerVT(Context, VT.getSizeInBits() / 2); + return getTypeAction(Context, LegalTy) == TargetLowering::TypeLegal; +} + /// Turn vector tests of the signbit in the form of: /// xor (sra X, elt_size(X)-1), -1 /// into: @@ -19314,20 +19679,37 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor, return CSNeg; } -static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) { +static bool IsSVECntIntrinsic(SDValue S) { switch(getIntrinsicID(S.getNode())) { default: break; case Intrinsic::aarch64_sve_cntb: - return 8; case Intrinsic::aarch64_sve_cnth: - return 16; case Intrinsic::aarch64_sve_cntw: - return 32; case Intrinsic::aarch64_sve_cntd: - return 64; + return true; + } + return false; +} + +// Returns the maximum (scalable) value that can be returned by an SVE count +// intrinsic. Returns std::nullopt if \p Op is not aarch64_sve_cnt*. +static std::optional<ElementCount> getMaxValueForSVECntIntrinsic(SDValue Op) { + Intrinsic::ID IID = getIntrinsicID(Op.getNode()); + if (IID == Intrinsic::aarch64_sve_cntp) + return Op.getOperand(1).getValueType().getVectorElementCount(); + switch (IID) { + case Intrinsic::aarch64_sve_cntd: + return ElementCount::getScalable(2); + case Intrinsic::aarch64_sve_cntw: + return ElementCount::getScalable(4); + case Intrinsic::aarch64_sve_cnth: + return ElementCount::getScalable(8); + case Intrinsic::aarch64_sve_cntb: + return ElementCount::getScalable(16); + default: + return std::nullopt; } - return {}; } /// Calculates what the pre-extend type is, based on the extension @@ -19971,7 +20353,9 @@ static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG, return Res; EVT VT = N->getValueType(0); - if (VT != MVT::f32 && VT != MVT::f64) + if (VT != MVT::f16 && VT != MVT::f32 && VT != MVT::f64) + return SDValue(); + if (VT == MVT::f16 && !Subtarget->hasFullFP16()) return SDValue(); // Only optimize when the source and destination types have the same width. @@ -20069,7 +20453,7 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG, : Intrinsic::aarch64_neon_vcvtfp2fxu; SDValue FixConv = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy, - DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), + DAG.getTargetConstant(IntrinsicOpcode, DL, MVT::i32), Op->getOperand(0), DAG.getTargetConstant(C, DL, MVT::i32)); // We can handle smaller integers by generating an extra trunc. 
if (IntBits < FloatBits) @@ -21623,9 +22007,8 @@ static SDValue performBuildVectorCombine(SDNode *N, SDValue LowLanesSrcVec = Elt0->getOperand(0)->getOperand(0); if (LowLanesSrcVec.getValueType() == MVT::v2f64) { SDValue HighLanes; - if (Elt2->getOpcode() == ISD::UNDEF && - Elt3->getOpcode() == ISD::UNDEF) { - HighLanes = DAG.getUNDEF(MVT::v2f32); + if (Elt2->isUndef() && Elt3->isUndef()) { + HighLanes = DAG.getPOISON(MVT::v2f32); } else if (Elt2->getOpcode() == ISD::FP_ROUND && Elt3->getOpcode() == ISD::FP_ROUND && isa<ConstantSDNode>(Elt2->getOperand(1)) && @@ -22328,6 +22711,69 @@ static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift); } +// Attempt to combine the following patterns: +// SUB x, (CSET LO, (CMP a, b)) -> SBC x, 0, (CMP a, b) +// SUB (SUB x, y), (CSET LO, (CMP a, b)) -> SBC x, y, (CMP a, b) +// The CSET may be preceded by a ZEXT. +static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() != ISD::SUB) + return SDValue(); + + EVT VT = N->getValueType(0); + if (VT != MVT::i32 && VT != MVT::i64) + return SDValue(); + + SDValue N1 = N->getOperand(1); + if (N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) + N1 = N1.getOperand(0); + if (!N1.hasOneUse() || getCSETCondCode(N1) != AArch64CC::LO) + return SDValue(); + + SDValue Flags = N1.getOperand(3); + if (Flags.getOpcode() != AArch64ISD::SUBS) + return SDValue(); + + SDLoc DL(N); + SDValue N0 = N->getOperand(0); + if (N0->getOpcode() == ISD::SUB) + return DAG.getNode(AArch64ISD::SBC, DL, VT, N0.getOperand(0), + N0.getOperand(1), Flags); + return DAG.getNode(AArch64ISD::SBC, DL, VT, N0, DAG.getConstant(0, DL, VT), + Flags); +} + +// add(trunc(ashr(A, C)), trunc(lshr(A, BW-1))), with C >= BW +// -> +// X = trunc(ashr(A, C)); add(x, lshr(X, BW-1) +// The original converts into ashr+lshr+xtn+xtn+add. The second becomes +// ashr+xtn+usra. The first form has less total latency due to more parallelism, +// but more micro-ops and seems to be slower in practice. +static SDValue performAddTruncShiftCombine(SDNode *N, SelectionDAG &DAG) { + using namespace llvm::SDPatternMatch; + EVT VT = N->getValueType(0); + if (VT != MVT::v2i32 && VT != MVT::v4i16 && VT != MVT::v8i8) + return SDValue(); + + SDValue AShr, LShr; + if (!sd_match(N, m_Add(m_Trunc(m_Value(AShr)), m_Trunc(m_Value(LShr))))) + return SDValue(); + if (AShr.getOpcode() != AArch64ISD::VASHR) + std::swap(AShr, LShr); + if (AShr.getOpcode() != AArch64ISD::VASHR || + LShr.getOpcode() != AArch64ISD::VLSHR || + AShr.getOperand(0) != LShr.getOperand(0) || + AShr.getConstantOperandVal(1) < VT.getScalarSizeInBits() || + LShr.getConstantOperandVal(1) != VT.getScalarSizeInBits() * 2 - 1) + return SDValue(); + + SDLoc DL(N); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, AShr); + SDValue Shift = DAG.getNode( + AArch64ISD::VLSHR, DL, VT, Trunc, + DAG.getTargetConstant(VT.getScalarSizeInBits() - 1, DL, MVT::i32)); + return DAG.getNode(ISD::ADD, DL, VT, Trunc, Shift); +} + static SDValue performAddSubCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { // Try to change sum of two reductions. 
@@ -22349,6 +22795,10 @@ static SDValue performAddSubCombine(SDNode *N, return Val; if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG)) return Val; + if (SDValue Val = performSubWithBorrowCombine(N, DCI.DAG)) + return Val; + if (SDValue Val = performAddTruncShiftCombine(N, DCI.DAG)) + return Val; if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG)) return Val; @@ -23000,11 +23450,15 @@ static SDValue performIntrinsicCombine(SDNode *N, return DAG.getNode(ISD::OR, SDLoc(N), N->getValueType(0), N->getOperand(2), N->getOperand(3)); case Intrinsic::aarch64_sve_sabd_u: - return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0), - N->getOperand(2), N->getOperand(3)); + if (SDValue V = convertMergedOpToPredOp(N, ISD::ABDS, DAG, true)) + return V; + return DAG.getNode(AArch64ISD::ABDS_PRED, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2), N->getOperand(3)); case Intrinsic::aarch64_sve_uabd_u: - return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0), - N->getOperand(2), N->getOperand(3)); + if (SDValue V = convertMergedOpToPredOp(N, ISD::ABDU, DAG, true)) + return V; + return DAG.getNode(AArch64ISD::ABDU_PRED, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2), N->getOperand(3)); case Intrinsic::aarch64_sve_sdiv_u: return DAG.getNode(AArch64ISD::SDIV_PRED, SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2), N->getOperand(3)); @@ -23927,7 +24381,7 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); // uzp1(x, undef) -> concat(truncate(x), undef) - if (Op1.getOpcode() == ISD::UNDEF) { + if (Op1.isUndef()) { EVT BCVT = MVT::Other, HalfVT = MVT::Other; switch (ResVT.getSimpleVT().SimpleTy) { default: @@ -26070,7 +26524,7 @@ static SDValue performCSELCombine(SDNode *N, // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1 // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1 if (SDValue Folded = foldCSELofCTTZ(N, DAG)) - return Folded; + return Folded; // CSEL a, b, cc, SUBS(x, y) -> CSEL a, b, swapped(cc), SUBS(y, x) // if SUB(y, x) already exists and we can produce a swapped predicate for cc. @@ -26095,29 +26549,6 @@ static SDValue performCSELCombine(SDNode *N, } } - // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't - // use overflow flags, to avoid the comparison with zero. In case of success, - // this also replaces the original SUB(x,y) with the newly created SUBS(x,y). - // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB - // nodes with their SUBS equivalent as is already done for other flag-setting - // operators, in which case doing the replacement here becomes redundant. 
- if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) && - isNullConstant(Cond.getOperand(1))) { - SDValue Sub = Cond.getOperand(0); - AArch64CC::CondCode CC = - static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2)); - if (Sub.getOpcode() == ISD::SUB && - (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI || - CC == AArch64CC::PL)) { - SDLoc DL(N); - SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(), - Sub.getOperand(0), Sub.getOperand(1)); - DCI.CombineTo(Sub.getNode(), Subs); - DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1)); - return SDValue(N, 0); - } - } - // CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z if (SDValue CondLast = foldCSELofLASTB(N, DAG)) return CondLast; @@ -26396,8 +26827,7 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { SDValue L1 = LHS->getOperand(1); SDValue L2 = LHS->getOperand(2); - if (L0.getOpcode() == ISD::UNDEF && isNullConstant(L2) && - isSignExtInReg(L1)) { + if (L0.isUndef() && isNullConstant(L2) && isSignExtInReg(L1)) { SDLoc DL(N); SDValue Shl = L1.getOperand(0); SDValue NewLHS = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, @@ -26661,22 +27091,25 @@ static SDValue performSelectCombine(SDNode *N, assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) && "Scalar-SETCC feeding SELECT has unexpected result type!"); - // If NumMaskElts == 0, the comparison is larger than select result. The - // largest real NEON comparison is 64-bits per lane, which means the result is - // at most 32-bits and an illegal vector. Just bail out for now. - EVT SrcVT = N0.getOperand(0).getValueType(); - // Don't try to do this optimization when the setcc itself has i1 operands. // There are no legal vectors of i1, so this would be pointless. v1f16 is // ruled out to prevent the creation of setcc that need to be scalarized. + EVT SrcVT = N0.getOperand(0).getValueType(); if (SrcVT == MVT::i1 || (SrcVT.isFloatingPoint() && SrcVT.getSizeInBits() <= 16)) return SDValue(); - int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits(); + // If NumMaskElts == 0, the comparison is larger than select result. The + // largest real NEON comparison is 64-bits per lane, which means the result is + // at most 32-bits and an illegal vector. Just bail out for now. + unsigned NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits(); if (!ResVT.isVector() || NumMaskElts == 0) return SDValue(); + // Avoid creating vectors with excessive VFs before legalization. + if (DCI.isBeforeLegalize() && NumMaskElts != ResVT.getVectorNumElements()) + return SDValue(); + SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts); EVT CCVT = SrcVT.changeVectorElementTypeToInteger(); @@ -27325,8 +27758,8 @@ static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG, // ...and remap the intrinsic `aarch64_sve_prf<T>_gather_scalar_offset` to // `aarch64_sve_prfb_gather_uxtw_index`. 
SDLoc DL(N); - Ops[1] = DAG.getConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, DL, - MVT::i64); + Ops[1] = DAG.getTargetConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, + DL, MVT::i64); return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops); } @@ -27877,6 +28310,35 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) { {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); } +static SDValue performCTPOPCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + using namespace llvm::SDPatternMatch; + if (!DCI.isBeforeLegalize()) + return SDValue(); + + // ctpop(zext(bitcast(vector_mask))) -> neg(signed_reduce_add(vector_mask)) + SDValue Mask; + if (!sd_match(N->getOperand(0), m_ZExt(m_BitCast(m_Value(Mask))))) + return SDValue(); + + EVT VT = N->getValueType(0); + EVT MaskVT = Mask.getValueType(); + + if (VT.isVector() || !MaskVT.isFixedLengthVector() || + MaskVT.getVectorElementType() != MVT::i1) + return SDValue(); + + EVT ReduceInVT = + EVT::getVectorVT(*DAG.getContext(), VT, MaskVT.getVectorElementCount()); + + SDLoc DL(N); + // Sign extend to best fit ZeroOrNegativeOneBooleanContent. + SDValue ExtMask = DAG.getNode(ISD::SIGN_EXTEND, DL, ReduceInVT, Mask); + SDValue NegPopCount = DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, ExtMask); + return DAG.getNegative(NegPopCount, DL, VT); +} + SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -28222,6 +28684,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performScalarToVectorCombine(N, DCI, DAG); case ISD::SHL: return performSHLCombine(N, DCI, DAG); + case ISD::CTPOP: + return performCTPOPCombine(N, DCI, DAG); } return SDValue(); } @@ -28568,7 +29032,8 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults( if ((Index != 0) && (Index != ResEC.getKnownMinValue())) return; - unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI; + unsigned Opcode = (Index == 0) ? (unsigned)ISD::ANY_EXTEND_VECTOR_INREG + : (unsigned)AArch64ISD::UUNPKHI; EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext()); SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0)); @@ -29295,12 +29760,26 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { AI->getOperation() == AtomicRMWInst::FMinimum)) return AtomicExpansionKind::None; - // Nand is not supported in LSE. // Leave 128 bits to LLSC or CmpXChg. - if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128 && - !AI->isFloatingPointOperation()) { - if (Subtarget->hasLSE()) - return AtomicExpansionKind::None; + if (Size < 128 && !AI->isFloatingPointOperation()) { + if (Subtarget->hasLSE()) { + // Nand is not supported in LSE. + switch (AI->getOperation()) { + case AtomicRMWInst::Xchg: + case AtomicRMWInst::Add: + case AtomicRMWInst::Sub: + case AtomicRMWInst::And: + case AtomicRMWInst::Or: + case AtomicRMWInst::Xor: + case AtomicRMWInst::Max: + case AtomicRMWInst::Min: + case AtomicRMWInst::UMax: + case AtomicRMWInst::UMin: + return AtomicExpansionKind::None; + default: + break; + } + } if (Subtarget->outlineAtomics()) { // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far. 
// Don't outline them unless @@ -29308,11 +29787,16 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf // (2) low level libgcc and compiler-rt support implemented by: // min/max outline atomics helpers - if (AI->getOperation() != AtomicRMWInst::Min && - AI->getOperation() != AtomicRMWInst::Max && - AI->getOperation() != AtomicRMWInst::UMin && - AI->getOperation() != AtomicRMWInst::UMax) { + switch (AI->getOperation()) { + case AtomicRMWInst::Xchg: + case AtomicRMWInst::Add: + case AtomicRMWInst::Sub: + case AtomicRMWInst::And: + case AtomicRMWInst::Or: + case AtomicRMWInst::Xor: return AtomicExpansionKind::None; + default: + break; } } } @@ -30119,6 +30603,43 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE( Store->isTruncatingStore()); } +SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + auto *Store = cast<MaskedStoreSDNode>(Op); + EVT VT = Store->getValue().getValueType(); + if (VT.isFixedLengthVector()) + return LowerFixedLengthVectorMStoreToSVE(Op, DAG); + + if (!Store->isCompressingStore()) + return SDValue(); + + EVT MaskVT = Store->getMask().getValueType(); + EVT MaskExtVT = getPromotedVTForPredicate(MaskVT); + EVT MaskReduceVT = MaskExtVT.getScalarType(); + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + + SDValue MaskExt = + DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask()); + SDValue CntActive = + DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt); + if (MaskReduceVT != MVT::i64) + CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive); + + SDValue CompressedValue = + DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(), + Store->getMask(), DAG.getPOISON(VT)); + SDValue CompressedMask = + DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive); + + return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue, + Store->getBasePtr(), Store->getOffset(), + CompressedMask, Store->getMemoryVT(), + Store->getMemOperand(), Store->getAddressingMode(), + Store->isTruncatingStore(), + /*isCompressing=*/false); +} + SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE( SDValue Op, SelectionDAG &DAG) const { auto *Store = cast<MaskedStoreSDNode>(Op); @@ -30133,7 +30654,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE( return DAG.getMaskedStore( Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(), Mask, Store->getMemoryVT(), Store->getMemOperand(), - Store->getAddressingMode(), Store->isTruncatingStore()); + Store->getAddressingMode(), Store->isTruncatingStore(), + Store->isCompressingStore()); } SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE( @@ -31160,10 +31682,10 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2, SDValue Shuffle; if (IsSingleOp) - Shuffle = - DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT, - DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32), - Op1, SVEMask); + Shuffle = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT, + DAG.getTargetConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32), Op1, + SVEMask); else if (Subtarget.hasSVE2()) { if (!MinMaxEqual) { unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt; @@ -31182,10 +31704,10 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2, SVEMask = convertToScalableVector( DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask); } - 
Shuffle = - DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT, - DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32), - Op1, Op2, SVEMask); + Shuffle = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT, + DAG.getTargetConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32), Op1, + Op2, SVEMask); } Shuffle = convertFromScalableVector(DAG, VT, Shuffle); return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle); @@ -31267,15 +31789,23 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE( } unsigned WhichResult; - if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) && + unsigned OperandOrder; + if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult, + OperandOrder) && WhichResult == 0) return convertFromScalableVector( - DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op2)); + DAG, VT, + DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, + OperandOrder == 0 ? Op1 : Op2, + OperandOrder == 0 ? Op2 : Op1)); - if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) { + if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult, + OperandOrder)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2; - return convertFromScalableVector( - DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2)); + SDValue TRN = + DAG.getNode(Opc, DL, ContainerVT, OperandOrder == 0 ? Op1 : Op2, + OperandOrder == 0 ? Op2 : Op1); + return convertFromScalableVector(DAG, VT, TRN); } if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult == 0) @@ -31315,10 +31845,14 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE( return convertFromScalableVector(DAG, VT, Op); } - if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) && + if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult, + OperandOrder) && WhichResult != 0) return convertFromScalableVector( - DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op2)); + DAG, VT, + DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, + OperandOrder == 0 ? Op1 : Op2, + OperandOrder == 0 ? Op2 : Op1)); if (isUZPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2; @@ -31345,8 +31879,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE( unsigned SegmentElts = VT.getVectorNumElements() / Segments; if (std::optional<unsigned> Lane = isDUPQMask(ShuffleMask, Segments, SegmentElts)) { - SDValue IID = - DAG.getConstant(Intrinsic::aarch64_sve_dup_laneq, DL, MVT::i64); + SDValue IID = DAG.getTargetConstant(Intrinsic::aarch64_sve_dup_laneq, + DL, MVT::i64); return convertFromScalableVector( DAG, VT, DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT, @@ -31493,22 +32027,24 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode( return false; } case ISD::INTRINSIC_WO_CHAIN: { - if (auto ElementSize = IsSVECntIntrinsic(Op)) { - unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits(); - if (!MaxSVEVectorSizeInBits) - MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector; - unsigned MaxElements = MaxSVEVectorSizeInBits / *ElementSize; - // The SVE count intrinsics don't support the multiplier immediate so we - // don't have to account for that here. The value returned may be slightly - // over the true required bits, as this is based on the "ALL" pattern. The - // other patterns are also exposed by these intrinsics, but they all - // return a value that's strictly less than "ALL". 
- unsigned RequiredBits = llvm::bit_width(MaxElements); - unsigned BitWidth = Known.Zero.getBitWidth(); - if (RequiredBits < BitWidth) - Known.Zero.setHighBits(BitWidth - RequiredBits); + std::optional<ElementCount> MaxCount = getMaxValueForSVECntIntrinsic(Op); + if (!MaxCount) return false; - } + unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits(); + if (!MaxSVEVectorSizeInBits) + MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector; + unsigned VscaleMax = MaxSVEVectorSizeInBits / 128; + unsigned MaxValue = MaxCount->getKnownMinValue() * VscaleMax; + // The SVE count intrinsics don't support the multiplier immediate so we + // don't have to account for that here. The value returned may be slightly + // over the true required bits, as this is based on the "ALL" pattern. The + // other patterns are also exposed by these intrinsics, but they all + // return a value that's strictly less than "ALL". + unsigned RequiredBits = llvm::bit_width(MaxValue); + unsigned BitWidth = Known.Zero.getBitWidth(); + if (RequiredBits < BitWidth) + Known.Zero.setHighBits(BitWidth - RequiredBits); + return false; } } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 2cb8ed2..e8c026d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -206,7 +206,7 @@ public: EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *MBB) const override; - bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, + bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override; @@ -333,6 +333,11 @@ public: return TargetLowering::shouldFormOverflowOp(Opcode, VT, true); } + // Return true if the target wants to optimize the mul overflow intrinsic + // for the given \p VT. 
+ bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context, + EVT VT) const override; + Value *emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr, AtomicOrdering Ord) const override; Value *emitStoreConditional(IRBuilderBase &Builder, Value *Val, Value *Addr, @@ -609,6 +614,8 @@ private: SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const; SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerFMA(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const; @@ -708,6 +715,7 @@ private: SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, SelectionDAG &DAG) const; SDValue LowerZERO_EXTEND_VECTOR_INREG(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG) const; @@ -745,7 +753,6 @@ private: SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerXOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; - SDValue LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const; SDValue LowerLOOP_DEPENDENCE_MASK(SDValue Op, SelectionDAG &DAG) const; SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVSCALE(SDValue Op, SelectionDAG &DAG) const; @@ -756,6 +763,7 @@ private: SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const; diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 58a53af..4d2e740 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -415,6 +415,12 @@ def CmpBranchUImm6Operand_64b let WantsParent = true; } +def CmpBranchBExtOperand + : ComplexPattern<i32, 2, "SelectCmpBranchExtOperand<true>", []> {} + +def CmpBranchHExtOperand + : ComplexPattern<i32, 2, "SelectCmpBranchExtOperand<false>", []> {} + def UImm6Plus1Operand : AsmOperandClass { let Name = "UImm6P1"; let DiagnosticType = "InvalidImm1_64"; @@ -1712,28 +1718,6 @@ class RtSystemI<bit L, dag oops, dag iops, string asm, string operands, let Inst{4-0} = Rt; } -// System instructions for transactional memory extension -class TMBaseSystemI<bit L, bits<4> CRm, bits<3> op2, dag oops, dag iops, - string asm, string operands, list<dag> pattern> - : BaseSystemI<L, oops, iops, asm, operands, pattern>, - Sched<[WriteSys]> { - let Inst{20-12} = 0b000110011; - let Inst{11-8} = CRm; - let Inst{7-5} = op2; - let DecoderMethod = ""; - - let mayLoad = 1; - let mayStore = 1; -} - -// System instructions for transactional memory - single input operand -class TMSystemI<bits<4> CRm, string asm, list<dag> pattern> - : TMBaseSystemI<0b1, CRm, 0b011, - (outs GPR64:$Rt), (ins), asm, "\t$Rt", pattern> { - bits<5> Rt; - let Inst{4-0} = Rt; -} - // System instructions that pass a register argument // This class 
assumes the register is for input rather than output. class RegInputSystemI<bits<4> CRm, bits<3> Op2, string asm, @@ -1744,23 +1728,6 @@ class RegInputSystemI<bits<4> CRm, bits<3> Op2, string asm, let Inst{7-5} = Op2; } -// System instructions for transactional memory - no operand -class TMSystemINoOperand<bits<4> CRm, string asm, list<dag> pattern> - : TMBaseSystemI<0b0, CRm, 0b011, (outs), (ins), asm, "", pattern> { - let Inst{4-0} = 0b11111; -} - -// System instructions for exit from transactions -class TMSystemException<bits<3> op1, string asm, list<dag> pattern> - : I<(outs), (ins timm64_0_65535:$imm), asm, "\t$imm", "", pattern>, - Sched<[WriteSys]> { - bits<16> imm; - let Inst{31-24} = 0b11010100; - let Inst{23-21} = op1; - let Inst{20-5} = imm; - let Inst{4-0} = 0b00000; -} - class APASI : SimpleSystemI<0, (ins GPR64:$Xt), "apas", "\t$Xt">, Sched<[]> { bits<5> Xt; let Inst{20-5} = 0b0111001110000000; @@ -1909,6 +1876,21 @@ def CMHPriorityHint_op : Operand<i32> { }]; } +def TIndexHintOperand : AsmOperandClass { + let Name = "TIndexHint"; + let ParserMethod = "tryParseTIndexHint"; +} + +def TIndexhint_op : Operand<i32> { + let ParserMatchClass = TIndexHintOperand; + let PrintMethod = "printTIndexHintOp"; + let MCOperandPredicate = [{ + if (!MCOp.isImm()) + return false; + return AArch64TIndexHint::lookupTIndexByEncoding(MCOp.getImm()) != nullptr; + }]; +} + class MRSI : RtSystemI<1, (outs GPR64:$Rt), (ins mrs_sysreg_op:$systemreg), "mrs", "\t$Rt, $systemreg"> { bits<16> systemreg; @@ -2365,6 +2347,7 @@ class BImm<bit op, dag iops, string asm, list<dag> pattern> let Inst{25-0} = addr; let DecoderMethod = "DecodeUnconditionalBranch"; + let supportsDeactivationSymbol = true; } class BranchImm<bit op, string asm, list<dag> pattern> @@ -2422,6 +2405,7 @@ class SignAuthOneData<bits<3> opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = Rn; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthZero<bits<3> opcode_prefix, bits<2> opcode, string asm, @@ -2435,6 +2419,7 @@ class SignAuthZero<bits<3> opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = 0b11111; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthTwoOperand<bits<4> opc, string asm, @@ -7715,16 +7700,21 @@ multiclass SIMDThreeScalarD<bit U, bits<5> opc, string asm, } multiclass SIMDThreeScalarBHSD<bit U, bits<5> opc, string asm, - SDPatternOperator OpNode, SDPatternOperator SatOp> { + SDPatternOperator OpNode, SDPatternOperator G_OpNode, SDPatternOperator SatOp> { def v1i64 : BaseSIMDThreeScalar<U, 0b111, opc, FPR64, asm, [(set (v1i64 FPR64:$Rd), (SatOp (v1i64 FPR64:$Rn), (v1i64 FPR64:$Rm)))]>; def v1i32 : BaseSIMDThreeScalar<U, 0b101, opc, FPR32, asm, []>; def v1i16 : BaseSIMDThreeScalar<U, 0b011, opc, FPR16, asm, []>; def v1i8 : BaseSIMDThreeScalar<U, 0b001, opc, FPR8 , asm, []>; - def : Pat<(i64 (OpNode (i64 FPR64:$Rn), (i64 FPR64:$Rm))), + def : Pat<(i64 (G_OpNode (i64 FPR64:$Rn), (i64 FPR64:$Rm))), (!cast<Instruction>(NAME#"v1i64") FPR64:$Rn, FPR64:$Rm)>; - def : Pat<(i32 (OpNode (i32 FPR32:$Rn), (i32 FPR32:$Rm))), + def : Pat<(i32 (G_OpNode (i32 FPR32:$Rn), (i32 FPR32:$Rm))), + (!cast<Instruction>(NAME#"v1i32") FPR32:$Rn, FPR32:$Rm)>; + + def : Pat<(f64 (OpNode FPR64:$Rn, FPR64:$Rm)), + (!cast<Instruction>(NAME#"v1i64") FPR64:$Rn, FPR64:$Rm)>; + def : Pat<(f32 (OpNode FPR32:$Rn, FPR32:$Rm)), (!cast<Instruction>(NAME#"v1i32") FPR32:$Rn, FPR32:$Rm)>; } @@ -7810,7 +7800,7 @@ multiclass SIMDThreeScalarMixedHS<bit U, 
bits<5> opc, string asm, def i32 : BaseSIMDThreeScalarMixed<U, 0b10, opc, (outs FPR64:$Rd), (ins FPR32:$Rn, FPR32:$Rm), asm, "", - [(set (i64 FPR64:$Rd), (OpNode (i32 FPR32:$Rn), (i32 FPR32:$Rm)))]>; + [(set (f64 FPR64:$Rd), (OpNode FPR32:$Rn, FPR32:$Rm))]>; } let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in @@ -9815,7 +9805,8 @@ multiclass SIMDIndexedLongSD<bit U, bits<4> opc, string asm, multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm, SDPatternOperator VecAcc, - SDPatternOperator ScalAcc> { + SDPatternOperator ScalAcc, + SDPatternOperator G_ScalAcc> { def v4i16_indexed : BaseSIMDIndexedTied<0, U, 0, 0b01, opc, V128, V64, V128_lo, VectorIndexH, @@ -9884,7 +9875,7 @@ multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm, let Inst{20} = idx{0}; } - def : Pat<(i32 (ScalAcc (i32 FPR32Op:$Rd), + def : Pat<(i32 (G_ScalAcc (i32 FPR32Op:$Rd), (i32 (vector_extract (v4i32 (int_aarch64_neon_sqdmull (v4i16 V64:$Rn), @@ -9896,7 +9887,19 @@ multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm, (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rm, dsub), (i64 0))>; - def : Pat<(i32 (ScalAcc (i32 FPR32Op:$Rd), + def : Pat<(f32 (ScalAcc FPR32Op:$Rd, + (bitconvert (i32 (vector_extract + (v4i32 (int_aarch64_neon_sqdmull + (v4i16 V64:$Rn), + (v4i16 V64:$Rm))), + (i64 0)))))), + (!cast<Instruction>(NAME # v1i32_indexed) + FPR32Op:$Rd, + (f16 (EXTRACT_SUBREG V64:$Rn, hsub)), + (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rm, dsub), + (i64 0))>; + + def : Pat<(i32 (G_ScalAcc (i32 FPR32Op:$Rd), (i32 (vector_extract (v4i32 (int_aarch64_neon_sqdmull (v4i16 V64:$Rn), @@ -9909,15 +9912,27 @@ multiclass SIMDIndexedLongSQDMLXSDTied<bit U, bits<4> opc, string asm, V128_lo:$Rm, VectorIndexH:$idx)>; + def : Pat<(f32 (ScalAcc FPR32Op:$Rd, + (bitconvert (i32 (vector_extract + (v4i32 (int_aarch64_neon_sqdmull + (v4i16 V64:$Rn), + (dup_v8i16 (v8i16 V128_lo:$Rm), + VectorIndexH:$idx))), + (i64 0)))))), + (!cast<Instruction>(NAME # v1i32_indexed) + FPR32Op:$Rd, + (f16 (EXTRACT_SUBREG V64:$Rn, hsub)), + V128_lo:$Rm, + VectorIndexH:$idx)>; + def v1i64_indexed : BaseSIMDIndexedTied<1, U, 1, 0b10, opc, FPR64Op, FPR32Op, V128, VectorIndexS, asm, ".s", "", "", ".s", - [(set (i64 FPR64Op:$dst), - (ScalAcc (i64 FPR64Op:$Rd), - (i64 (int_aarch64_neon_sqdmulls_scalar - (i32 FPR32Op:$Rn), - (i32 (vector_extract (v4i32 V128:$Rm), - VectorIndexS:$idx))))))]> { + [(set (f64 FPR64Op:$dst), + (ScalAcc FPR64Op:$Rd, + (AArch64sqdmull FPR32Op:$Rn, + (bitconvert (i32 (vector_extract (v4i32 V128:$Rm), + VectorIndexS:$idx))))))]> { bits<2> idx; let Inst{11} = idx{1}; @@ -12588,12 +12603,10 @@ class MOPSMemoryCopy<bits<2> opcode, bits<2> op1, bits<2> op2, string asm> class MOPSMemoryMove<bits<2> opcode, bits<2> op1, bits<2> op2, string asm> : MOPSMemoryCopyMoveBase<1, opcode, op1, op2, asm>; -class MOPSMemorySetBase<bit isTagging, bits<2> opcode, bit op1, bit op2, - string asm> - : I<(outs GPR64common:$Rd_wb, GPR64:$Rn_wb), - (ins GPR64common:$Rd, GPR64:$Rn, GPR64:$Rm), - asm, "\t[$Rd]!, $Rn!, $Rm", - "$Rd = $Rd_wb,$Rn = $Rn_wb", []>, +class MOPSMemorySetBase<dag ins, string operands, bit isTagging, bits<2> opcode, + bit op1, bit op2, bit op3, string asm> + : I<(outs GPR64common:$Rd_wb, GPR64:$Rn_wb), ins, + asm, operands, "$Rd = $Rd_wb,$Rn = $Rn_wb", []>, Sched<[]> { bits<5> Rd; bits<5> Rn; @@ -12605,20 +12618,34 @@ class MOPSMemorySetBase<bit isTagging, bits<2> opcode, bit op1, bit op2, let Inst{15-14} = opcode; let Inst{13} = op2; let Inst{12} = op1; - let Inst{11-10} = 0b01; + let Inst{11} = 0b0; + let 
Inst{10} = op3; let Inst{9-5} = Rn; let Inst{4-0} = Rd; - let DecoderMethod = "DecodeSETMemOpInstruction"; let mayLoad = 0; let mayStore = 1; } -class MOPSMemorySet<bits<2> opcode, bit op1, bit op2, string asm> - : MOPSMemorySetBase<0, opcode, op1, op2, asm>; +class MOPSMemorySet<bits<2> opcode, bit op1, bit op2, bit op3, string asm> + : MOPSMemorySetBase<(ins GPR64common:$Rd, GPR64:$Rn, GPR64:$Rm), + "\t[$Rd]!, $Rn!, $Rm", 0, opcode, op1, op2, op3, asm> { + let DecoderMethod = "DecodeSETMemOpInstruction"; +} + +class MOPSMemorySetTagging<bits<2> opcode, bit op1, bit op2, bit op3, string asm> + : MOPSMemorySetBase<(ins GPR64common:$Rd, GPR64:$Rn, GPR64:$Rm), + "\t[$Rd]!, $Rn!, $Rm", 1, opcode, op1, op2, op3, asm> { + let DecoderMethod = "DecodeSETMemOpInstruction"; +} -class MOPSMemorySetTagging<bits<2> opcode, bit op1, bit op2, string asm> - : MOPSMemorySetBase<1, opcode, op1, op2, asm>; +class MOPSGoMemorySetTagging<bits<2> opcode, bit op1, bit op2, bit op3, string asm> + : MOPSMemorySetBase<(ins GPR64common:$Rd, GPR64:$Rn), + "\t[$Rd]!, $Rn!", 1, opcode, op1, op2, op3, asm> { + // No `Rm` operand, as all bits must be set to 1 + let Inst{20-16} = 0b11111; + let DecoderMethod = "DecodeSETMemGoOpInstruction"; +} multiclass MOPSMemoryCopyInsns<bits<2> opcode, string asm> { def "" : MOPSMemoryCopy<opcode, 0b00, 0b00, asm>; @@ -12659,17 +12686,27 @@ multiclass MOPSMemoryMoveInsns<bits<2> opcode, string asm> { } multiclass MOPSMemorySetInsns<bits<2> opcode, string asm> { - def "" : MOPSMemorySet<opcode, 0, 0, asm>; - def T : MOPSMemorySet<opcode, 1, 0, asm # "t">; - def N : MOPSMemorySet<opcode, 0, 1, asm # "n">; - def TN : MOPSMemorySet<opcode, 1, 1, asm # "tn">; + def "" : MOPSMemorySet<opcode, 0, 0, 1, asm>; + def T : MOPSMemorySet<opcode, 1, 0, 1, asm # "t">; + def N : MOPSMemorySet<opcode, 0, 1, 1, asm # "n">; + def TN : MOPSMemorySet<opcode, 1, 1, 1, asm # "tn">; } multiclass MOPSMemorySetTaggingInsns<bits<2> opcode, string asm> { - def "" : MOPSMemorySetTagging<opcode, 0, 0, asm>; - def T : MOPSMemorySetTagging<opcode, 1, 0, asm # "t">; - def N : MOPSMemorySetTagging<opcode, 0, 1, asm # "n">; - def TN : MOPSMemorySetTagging<opcode, 1, 1, asm # "tn">; + def "" : MOPSMemorySetTagging<opcode, 0, 0, 1, asm>; + def T : MOPSMemorySetTagging<opcode, 1, 0, 1, asm # "t">; + def N : MOPSMemorySetTagging<opcode, 0, 1, 1, asm # "n">; + def TN : MOPSMemorySetTagging<opcode, 1, 1, 1, asm # "tn">; +} + +//---------------------------------------------------------------------------- +// MOPS Granule Only - FEAT_MOPS_GO +//---------------------------------------------------------------------------- +multiclass MOPSGoMemorySetTaggingInsns<bits<2> opcode, string asm> { + def "" : MOPSGoMemorySetTagging<opcode, 0, 0, 0, asm>; + def T : MOPSGoMemorySetTagging<opcode, 1, 0, 0, asm # "t">; + def N : MOPSGoMemorySetTagging<opcode, 0, 1, 0, asm # "n">; + def TN : MOPSGoMemorySetTagging<opcode, 1, 1, 0, asm # "tn">; } //---------------------------------------------------------------------------- @@ -13198,8 +13235,22 @@ multiclass CmpBranchRegisterAlias<string mnemonic, string insn> { } class CmpBranchRegisterPseudo<RegisterClass regtype> - : Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, regtype:$Rm, am_brcmpcond:$Target), []>, - Sched<[WriteBr]> { + : Pseudo<(outs), + (ins ccode:$Cond, regtype:$Rt, regtype:$Rm, am_brcmpcond:$Target), + []>, + Sched<[WriteBr]> { + let isBranch = 1; + let isTerminator = 1; +} + +// Cmpbr pseudo instruction, encoding potentially folded zero-, sign-extension, +// assertzext and/or 
assersext. +class CmpBranchExtRegisterPseudo + : Pseudo<(outs), + (ins ccode:$Cond, GPR32:$Rt, GPR32:$Rm, am_brcmpcond:$Target, + simm8_32b:$ExtRt, simm8_32b:$ExtRm), + []>, + Sched<[WriteBr]> { let isBranch = 1; let isTerminator = 1; } @@ -13292,18 +13343,24 @@ multiclass AtomicFPStore<bit R, bits<3> op0, string asm> { def H : BaseAtomicFPStore<FPR16, 0b01, R, op0, asm>; } -class BaseSIMDThreeSameVectorFP8MatrixMul<string asm, bits<2> size, string kind> +class BaseSIMDThreeSameVectorFP8MatrixMul<string asm, bits<2> size, string kind, list<dag> pattern> : BaseSIMDThreeSameVectorTied<1, 1, {size, 0}, 0b11101, - V128, asm, ".16b", []> { + V128, asm, ".16b", pattern> { let AsmString = !strconcat(asm, "{\t$Rd", kind, ", $Rn.16b, $Rm.16b", "|", kind, "\t$Rd, $Rn, $Rm}"); } -multiclass SIMDThreeSameVectorFP8MatrixMul<string asm>{ - def v8f16: BaseSIMDThreeSameVectorFP8MatrixMul<asm, 0b00, ".8h">{ +multiclass SIMDThreeSameVectorFP8MatrixMul<string asm, SDPatternOperator OpNode>{ + def v8f16: BaseSIMDThreeSameVectorFP8MatrixMul<asm, 0b00, ".8h", + [(set (v8f16 V128:$dst), (OpNode (v8f16 V128:$Rd), + (v16i8 V128:$Rn), + (v16i8 V128:$Rm)))]> { let Predicates = [HasNEON, HasF8F16MM]; } - def v4f32: BaseSIMDThreeSameVectorFP8MatrixMul<asm, 0b10, ".4s">{ + def v4f32: BaseSIMDThreeSameVectorFP8MatrixMul<asm, 0b10, ".4s", + [(set (v4f32 V128:$dst), (OpNode (v4f32 V128:$Rd), + (v16i8 V128:$Rn), + (v16i8 V128:$Rm)))]> { let Predicates = [HasNEON, HasF8F32MM]; } } @@ -13338,3 +13395,84 @@ class STCPHInst<string asm> : I< let Inst{7-5} = 0b100; let Inst{4-0} = 0b11111; } + +//--- +// Permission Overlays Extension 2 (FEAT_S1POE2) +//--- + +class TCHANGERegInst<string asm, bit isB> : I< + (outs GPR64:$Xd), + (ins GPR64:$Xn, TIndexhint_op:$nb), + asm, "\t$Xd, $Xn, $nb", "", []>, Sched<[]> { + bits<5> Xd; + bits<5> Xn; + bits<1> nb; + let Inst{31-19} = 0b1101010110000; + let Inst{18} = isB; + let Inst{17} = nb; + let Inst{16-10} = 0b0000000; + let Inst{9-5} = Xn; + let Inst{4-0} = Xd; +} + +class TCHANGEImmInst<string asm, bit isB> : I< + (outs GPR64:$Xd), + (ins imm0_127:$imm, TIndexhint_op:$nb), + asm, "\t$Xd, $imm, $nb", "", []>, Sched<[]> { + bits<5> Xd; + bits<7> imm; + bits<1> nb; + let Inst{31-19} = 0b1101010110010; + let Inst{18} = isB; + let Inst{17} = nb; + let Inst{16-12} = 0b00000; + let Inst{11-5} = imm; + let Inst{4-0} = Xd; +} + +class TENTERInst<string asm> : I< + (outs), + (ins imm0_127:$imm, TIndexhint_op:$nb), + asm, "\t$imm, $nb", "", []>, Sched<[]> { + bits<7> imm; + bits<1> nb; + let Inst{31-18} = 0b11010100111000; + let Inst{17} = nb; + let Inst{16-12} = 0b00000; + let Inst{11-5} = imm; + let Inst{4-0} = 0b00000; +} + +class TEXITInst<string asm> : I< + (outs), + (ins TIndexhint_op:$nb), + asm, "\t$nb", "", []>, Sched<[]> { + bits<1> nb; + let Inst{31-11} = 0b110101101111111100000; + let Inst{10} = nb; + let Inst{9-0} = 0b1111100000; +} + + +multiclass TCHANGEReg<string asm , bit isB> { + def NAME : TCHANGERegInst<asm, isB>; + def : InstAlias<asm # "\t$Xd, $Xn", + (!cast<Instruction>(NAME) GPR64:$Xd, GPR64:$Xn, 0), 1>; +} + +multiclass TCHANGEImm<string asm, bit isB> { + def NAME : TCHANGEImmInst<asm, isB>; + def : InstAlias<asm # "\t$Xd, $imm", + (!cast<Instruction>(NAME) GPR64:$Xd, imm0_127:$imm, 0), 1>; +} + +multiclass TENTER<string asm> { + def NAME : TENTERInst<asm>; + def : InstAlias<asm # "\t$imm", + (!cast<Instruction>(NAME) imm0_127:$imm, 0), 1>; +} + +multiclass TEXIT<string asm> { + def NAME : TEXITInst<asm>; + def : InstAlias<asm, (!cast<Instruction>(NAME) 0), 1>; +} diff 
--git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td index 30b7b03..7d99786 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td +++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td @@ -149,6 +149,13 @@ def G_VLSHR : AArch64GenericInstruction { let hasSideEffects = 0; } +// Float truncation using round to odd +def G_FPTRUNC_ODD : AArch64GenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$src); + let hasSideEffects = false; +} + // Represents an integer to FP conversion on the FPR bank. def G_SITOF : AArch64GenericInstruction { let OutOperandList = (outs type0:$dst); @@ -197,6 +204,12 @@ def G_SMULL : AArch64GenericInstruction { let hasSideEffects = 0; } +def G_PMULL : AArch64GenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$src1, type1:$src2); + let hasSideEffects = 0; +} + def G_UADDLP : AArch64GenericInstruction { let OutOperandList = (outs type0:$dst); let InOperandList = (ins type0:$src1); @@ -273,6 +286,7 @@ def : GINodeEquiv<G_FCMGT, AArch64fcmgt>; def : GINodeEquiv<G_BSP, AArch64bsp>; +def : GINodeEquiv<G_PMULL, AArch64pmull>; def : GINodeEquiv<G_UMULL, AArch64umull>; def : GINodeEquiv<G_SMULL, AArch64smull>; @@ -290,6 +304,8 @@ def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>; def : GINodeEquiv<G_AARCH64_PREFETCH, AArch64Prefetch>; +def : GINodeEquiv<G_FPTRUNC_ODD, AArch64fcvtxn_n>; + // These are patterns that we only use for GlobalISel via the importer. def : Pat<(f32 (fadd (vector_extract (v2f32 FPR64:$Rn), (i64 0)), (vector_extract (v2f32 FPR64:$Rn), (i64 1)))), diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 457e540..f82180f 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -91,7 +91,7 @@ static cl::opt<unsigned> GatherOptSearchLimit( "machine-combiner gather pattern optimization")); AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI) - : AArch64GenInstrInfo(STI, AArch64::ADJCALLSTACKDOWN, + : AArch64GenInstrInfo(STI, RI, AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP, AArch64::CATCHRET), RI(STI.getTargetTriple(), STI.getHwMode()), Subtarget(STI) {} @@ -122,7 +122,7 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const { NumBytes = Desc.getSize() ? 
Desc.getSize() : 4; const auto *MFI = MF->getInfo<AArch64FunctionInfo>(); - if (!MFI->shouldSignReturnAddress(MF)) + if (!MFI->shouldSignReturnAddress(*MF)) return NumBytes; const auto &STI = MF->getSubtarget<AArch64Subtarget>(); @@ -241,6 +241,17 @@ static void parseCondBranch(MachineInstr *LastInst, MachineBasicBlock *&Target, Cond.push_back(LastInst->getOperand(1)); Cond.push_back(LastInst->getOperand(2)); break; + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: + Target = LastInst->getOperand(3).getMBB(); + Cond.push_back(MachineOperand::CreateImm(-1)); // -1 + Cond.push_back(MachineOperand::CreateImm(LastInst->getOpcode())); // Opc + Cond.push_back(LastInst->getOperand(0)); // Cond + Cond.push_back(LastInst->getOperand(1)); // Op0 + Cond.push_back(LastInst->getOperand(2)); // Op1 + Cond.push_back(LastInst->getOperand(4)); // Ext0 + Cond.push_back(LastInst->getOperand(5)); // Ext1 + break; } } @@ -264,6 +275,8 @@ static unsigned getBranchDisplacementBits(unsigned Opc) { return BCCDisplacementBits; case AArch64::CBWPri: case AArch64::CBXPri: + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: case AArch64::CBWPrr: case AArch64::CBXPrr: return CBDisplacementBits; @@ -298,6 +311,8 @@ AArch64InstrInfo::getBranchDestBlock(const MachineInstr &MI) const { return MI.getOperand(1).getMBB(); case AArch64::CBWPri: case AArch64::CBXPri: + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: case AArch64::CBWPrr: case AArch64::CBXPrr: return MI.getOperand(3).getMBB(); @@ -580,9 +595,11 @@ bool AArch64InstrInfo::reverseBranchCondition( Cond[1].setImm(AArch64::TBZX); break; - // Cond is { -1, Opcode, CC, Op0, Op1 } + // Cond is { -1, Opcode, CC, Op0, Op1, ... } case AArch64::CBWPri: case AArch64::CBXPri: + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: case AArch64::CBWPrr: case AArch64::CBXPrr: { // Pseudos using standard 4bit Arm condition codes @@ -654,6 +671,12 @@ void AArch64InstrInfo::instantiateCondBranch( MIB.add(Cond[4]); MIB.addMBB(TBB); + + // cb[b,h] + if (Cond.size() > 5) { + MIB.addImm(Cond[5].getImm()); + MIB.addImm(Cond[6].getImm()); + } } } @@ -685,6 +708,53 @@ unsigned AArch64InstrInfo::insertBranch( return 2; } +bool llvm::optimizeTerminators(MachineBasicBlock *MBB, + const TargetInstrInfo &TII) { + for (MachineInstr &MI : MBB->terminators()) { + unsigned Opc = MI.getOpcode(); + switch (Opc) { + case AArch64::CBZW: + case AArch64::CBZX: + case AArch64::TBZW: + case AArch64::TBZX: + // CBZ/TBZ with WZR/XZR -> unconditional B + if (MI.getOperand(0).getReg() == AArch64::WZR || + MI.getOperand(0).getReg() == AArch64::XZR) { + DEBUG_WITH_TYPE("optimizeTerminators", + dbgs() << "Removing always taken branch: " << MI); + MachineBasicBlock *Target = TII.getBranchDestBlock(MI); + SmallVector<MachineBasicBlock *> Succs(MBB->successors()); + for (auto *S : Succs) + if (S != Target) + MBB->removeSuccessor(S); + DebugLoc DL = MI.getDebugLoc(); + while (MBB->rbegin() != &MI) + MBB->rbegin()->eraseFromParent(); + MI.eraseFromParent(); + BuildMI(MBB, DL, TII.get(AArch64::B)).addMBB(Target); + return true; + } + break; + case AArch64::CBNZW: + case AArch64::CBNZX: + case AArch64::TBNZW: + case AArch64::TBNZX: + // CBNZ/TBNZ with WZR/XZR -> never taken, remove branch and successor + if (MI.getOperand(0).getReg() == AArch64::WZR || + MI.getOperand(0).getReg() == AArch64::XZR) { + DEBUG_WITH_TYPE("optimizeTerminators", + dbgs() << "Removing never taken branch: " << MI); + MachineBasicBlock *Target = TII.getBranchDestBlock(MI); + 
MI.getParent()->removeSuccessor(Target); + MI.eraseFromParent(); + return true; + } + break; + } + } + return false; +} + // Find the original register that VReg is copied from. static unsigned removeCopies(const MachineRegisterInfo &MRI, unsigned VReg) { while (Register::isVirtualRegister(VReg)) { @@ -931,44 +1001,122 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB, // We must insert a cmp, that is a subs // 0 1 2 3 4 // Cond is { -1, Opcode, CC, Op0, Op1 } - unsigned SUBSOpC, SUBSDestReg; + + unsigned SubsOpc, SubsDestReg; bool IsImm = false; CC = static_cast<AArch64CC::CondCode>(Cond[2].getImm()); switch (Cond[1].getImm()) { default: llvm_unreachable("Unknown branch opcode in Cond"); case AArch64::CBWPri: - SUBSOpC = AArch64::SUBSWri; - SUBSDestReg = AArch64::WZR; + SubsOpc = AArch64::SUBSWri; + SubsDestReg = AArch64::WZR; IsImm = true; break; case AArch64::CBXPri: - SUBSOpC = AArch64::SUBSXri; - SUBSDestReg = AArch64::XZR; + SubsOpc = AArch64::SUBSXri; + SubsDestReg = AArch64::XZR; IsImm = true; break; case AArch64::CBWPrr: - SUBSOpC = AArch64::SUBSWrr; - SUBSDestReg = AArch64::WZR; + SubsOpc = AArch64::SUBSWrr; + SubsDestReg = AArch64::WZR; IsImm = false; break; case AArch64::CBXPrr: - SUBSOpC = AArch64::SUBSXrr; - SUBSDestReg = AArch64::XZR; + SubsOpc = AArch64::SUBSXrr; + SubsDestReg = AArch64::XZR; IsImm = false; break; } if (IsImm) - BuildMI(MBB, I, DL, get(SUBSOpC), SUBSDestReg) + BuildMI(MBB, I, DL, get(SubsOpc), SubsDestReg) .addReg(Cond[3].getReg()) .addImm(Cond[4].getImm()) .addImm(0); else - BuildMI(MBB, I, DL, get(SUBSOpC), SUBSDestReg) + BuildMI(MBB, I, DL, get(SubsOpc), SubsDestReg) .addReg(Cond[3].getReg()) .addReg(Cond[4].getReg()); - } + } break; + case 7: { // cb[b,h] + // We must insert a cmp, that is a subs, but also zero- or sign-extensions + // that have been folded. For the first operand we codegen an explicit + // extension, for the second operand we fold the extension into cmp. 
+ // 0 1 2 3 4 5 6 + // Cond is { -1, Opcode, CC, Op0, Op1, Ext0, Ext1 } + + // We need a new register for the now explicitly extended register + Register Reg = Cond[4].getReg(); + if (Cond[5].getImm() != AArch64_AM::InvalidShiftExtend) { + unsigned ExtOpc; + unsigned ExtBits; + AArch64_AM::ShiftExtendType ExtendType = + AArch64_AM::getExtendType(Cond[5].getImm()); + switch (ExtendType) { + default: + llvm_unreachable("Unknown shift-extend for CB instruction"); + case AArch64_AM::SXTB: + assert( + Cond[1].getImm() == AArch64::CBBAssertExt && + "Unexpected compare-and-branch instruction for SXTB shift-extend"); + ExtOpc = AArch64::SBFMWri; + ExtBits = AArch64_AM::encodeLogicalImmediate(0xff, 32); + break; + case AArch64_AM::SXTH: + assert( + Cond[1].getImm() == AArch64::CBHAssertExt && + "Unexpected compare-and-branch instruction for SXTH shift-extend"); + ExtOpc = AArch64::SBFMWri; + ExtBits = AArch64_AM::encodeLogicalImmediate(0xffff, 32); + break; + case AArch64_AM::UXTB: + assert( + Cond[1].getImm() == AArch64::CBBAssertExt && + "Unexpected compare-and-branch instruction for UXTB shift-extend"); + ExtOpc = AArch64::ANDWri; + ExtBits = AArch64_AM::encodeLogicalImmediate(0xff, 32); + break; + case AArch64_AM::UXTH: + assert( + Cond[1].getImm() == AArch64::CBHAssertExt && + "Unexpected compare-and-branch instruction for UXTH shift-extend"); + ExtOpc = AArch64::ANDWri; + ExtBits = AArch64_AM::encodeLogicalImmediate(0xffff, 32); + break; + } + + // Build the explicit extension of the first operand + Reg = MRI.createVirtualRegister(&AArch64::GPR32spRegClass); + MachineInstrBuilder MBBI = + BuildMI(MBB, I, DL, get(ExtOpc), Reg).addReg(Cond[4].getReg()); + if (ExtOpc != AArch64::ANDWri) + MBBI.addImm(0); + MBBI.addImm(ExtBits); + } + + // Now, subs with an extended second operand + if (Cond[6].getImm() != AArch64_AM::InvalidShiftExtend) { + AArch64_AM::ShiftExtendType ExtendType = + AArch64_AM::getExtendType(Cond[6].getImm()); + MRI.constrainRegClass(Reg, MRI.getRegClass(Cond[3].getReg())); + MRI.constrainRegClass(Cond[3].getReg(), &AArch64::GPR32spRegClass); + BuildMI(MBB, I, DL, get(AArch64::SUBSWrx), AArch64::WZR) + .addReg(Cond[3].getReg()) + .addReg(Reg) + .addImm(AArch64_AM::getArithExtendImm(ExtendType, 0)); + } // If no extension is needed, just a regular subs + else { + MRI.constrainRegClass(Reg, MRI.getRegClass(Cond[3].getReg())); + MRI.constrainRegClass(Cond[3].getReg(), &AArch64::GPR32spRegClass); + BuildMI(MBB, I, DL, get(AArch64::SUBSWrr), AArch64::WZR) + .addReg(Cond[3].getReg()) + .addReg(Reg); + } + + CC = static_cast<AArch64CC::CondCode>(Cond[2].getImm()); + } break; } unsigned Opc = 0; @@ -1043,6 +1191,28 @@ static bool isCheapImmediate(const MachineInstr &MI, unsigned BitSize) { return Is.size() <= 2; } +// Check if a COPY instruction is cheap. +static bool isCheapCopy(const MachineInstr &MI, const AArch64RegisterInfo &RI) { + assert(MI.isCopy() && "Expected COPY instruction"); + const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + + // Cross-bank copies (e.g., between GPR and FPR) are expensive on AArch64, + // typically requiring an FMOV instruction with a 2-6 cycle latency. 
+ auto GetRegClass = [&](Register Reg) -> const TargetRegisterClass * { + if (Reg.isVirtual()) + return MRI.getRegClass(Reg); + if (Reg.isPhysical()) + return RI.getMinimalPhysRegClass(Reg); + return nullptr; + }; + const TargetRegisterClass *DstRC = GetRegClass(MI.getOperand(0).getReg()); + const TargetRegisterClass *SrcRC = GetRegClass(MI.getOperand(1).getReg()); + if (DstRC && SrcRC && !RI.getCommonSubClass(DstRC, SrcRC)) + return false; + + return MI.isAsCheapAsAMove(); +} + // FIXME: this implementation should be micro-architecture dependent, so a // micro-architecture target hook should be introduced here in future. bool AArch64InstrInfo::isAsCheapAsAMove(const MachineInstr &MI) const { @@ -1056,6 +1226,9 @@ bool AArch64InstrInfo::isAsCheapAsAMove(const MachineInstr &MI) const { default: return MI.isAsCheapAsAMove(); + case TargetOpcode::COPY: + return isCheapCopy(MI, RI); + case AArch64::ADDWrs: case AArch64::ADDXrs: case AArch64::SUBWrs: @@ -1217,6 +1390,8 @@ bool AArch64InstrInfo::isSEHInstruction(const MachineInstr &MI) { case AArch64::SEH_EpilogStart: case AArch64::SEH_EpilogEnd: case AArch64::SEH_PACSignLR: + case AArch64::SEH_SaveAnyRegI: + case AArch64::SEH_SaveAnyRegIP: case AArch64::SEH_SaveAnyRegQP: case AArch64::SEH_SaveAnyRegQPX: case AArch64::SEH_AllocZ: @@ -1774,10 +1949,24 @@ static unsigned sForm(MachineInstr &Instr) { case AArch64::ADDSWri: case AArch64::ADDSXrr: case AArch64::ADDSXri: + case AArch64::ADDSWrx: + case AArch64::ADDSXrx: case AArch64::SUBSWrr: case AArch64::SUBSWri: + case AArch64::SUBSWrx: case AArch64::SUBSXrr: case AArch64::SUBSXri: + case AArch64::SUBSXrx: + case AArch64::ANDSWri: + case AArch64::ANDSWrr: + case AArch64::ANDSWrs: + case AArch64::ANDSXri: + case AArch64::ANDSXrr: + case AArch64::ANDSXrs: + case AArch64::BICSWrr: + case AArch64::BICSXrr: + case AArch64::BICSWrs: + case AArch64::BICSXrs: return Instr.getOpcode(); case AArch64::ADDWrr: @@ -1788,6 +1977,10 @@ static unsigned sForm(MachineInstr &Instr) { return AArch64::ADDSXrr; case AArch64::ADDXri: return AArch64::ADDSXri; + case AArch64::ADDWrx: + return AArch64::ADDSWrx; + case AArch64::ADDXrx: + return AArch64::ADDSXrx; case AArch64::ADCWr: return AArch64::ADCSWr; case AArch64::ADCXr: @@ -1800,6 +1993,10 @@ static unsigned sForm(MachineInstr &Instr) { return AArch64::SUBSXrr; case AArch64::SUBXri: return AArch64::SUBSXri; + case AArch64::SUBWrx: + return AArch64::SUBSWrx; + case AArch64::SUBXrx: + return AArch64::SUBSXrx; case AArch64::SBCWr: return AArch64::SBCSWr; case AArch64::SBCXr: @@ -1808,6 +2005,22 @@ static unsigned sForm(MachineInstr &Instr) { return AArch64::ANDSWri; case AArch64::ANDXri: return AArch64::ANDSXri; + case AArch64::ANDWrr: + return AArch64::ANDSWrr; + case AArch64::ANDWrs: + return AArch64::ANDSWrs; + case AArch64::ANDXrr: + return AArch64::ANDSXrr; + case AArch64::ANDXrs: + return AArch64::ANDSXrs; + case AArch64::BICWrr: + return AArch64::BICSWrr; + case AArch64::BICXrr: + return AArch64::BICSXrr; + case AArch64::BICWrs: + return AArch64::BICSWrs; + case AArch64::BICXrs: + return AArch64::BICSXrs; } } @@ -1945,6 +2158,25 @@ static bool isSUBSRegImm(unsigned Opcode) { return Opcode == AArch64::SUBSWri || Opcode == AArch64::SUBSXri; } +static bool isANDOpcode(MachineInstr &MI) { + unsigned Opc = sForm(MI); + switch (Opc) { + case AArch64::ANDSWri: + case AArch64::ANDSWrr: + case AArch64::ANDSWrs: + case AArch64::ANDSXri: + case AArch64::ANDSXrr: + case AArch64::ANDSXrs: + case AArch64::BICSWrr: + case AArch64::BICSXrr: + case AArch64::BICSWrs: + case 
AArch64::BICSXrs: + return true; + default: + return false; + } +} + /// Check if CmpInstr can be substituted by MI. /// /// CmpInstr can be substituted: @@ -1982,7 +2214,8 @@ static bool canInstrSubstituteCmpInstr(MachineInstr &MI, MachineInstr &CmpInstr, // 1) MI and CmpInstr set N and V to the same value. // 2) If MI is add/sub with no-signed-wrap, it produces a poison value when // signed overflow occurs, so CmpInstr could still be simplified away. - if (NZVCUsed->V && !MI.getFlag(MachineInstr::NoSWrap)) + // Note that Ands and Bics instructions always clear the V flag. + if (NZVCUsed->V && !MI.getFlag(MachineInstr::NoSWrap) && !isANDOpcode(MI)) return false; AccessKind AccessToCheck = AK_Write; @@ -2392,11 +2625,10 @@ bool AArch64InstrInfo::isFPRCopy(const MachineInstr &MI) { return false; } -Register AArch64InstrInfo::isLoadFromStackSlot(const MachineInstr &MI, - int &FrameIndex) const { - switch (MI.getOpcode()) { +static bool isFrameLoadOpcode(int Opcode) { + switch (Opcode) { default: - break; + return false; case AArch64::LDRWui: case AArch64::LDRXui: case AArch64::LDRBui: @@ -2405,22 +2637,27 @@ Register AArch64InstrInfo::isLoadFromStackSlot(const MachineInstr &MI, case AArch64::LDRDui: case AArch64::LDRQui: case AArch64::LDR_PXI: - if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && - MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { - FrameIndex = MI.getOperand(1).getIndex(); - return MI.getOperand(0).getReg(); - } - break; + return true; } +} + +Register AArch64InstrInfo::isLoadFromStackSlot(const MachineInstr &MI, + int &FrameIndex) const { + if (!isFrameLoadOpcode(MI.getOpcode())) + return Register(); - return 0; + if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && + MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { + FrameIndex = MI.getOperand(1).getIndex(); + return MI.getOperand(0).getReg(); + } + return Register(); } -Register AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI, - int &FrameIndex) const { - switch (MI.getOpcode()) { +static bool isFrameStoreOpcode(int Opcode) { + switch (Opcode) { default: - break; + return false; case AArch64::STRWui: case AArch64::STRXui: case AArch64::STRBui: @@ -2429,14 +2666,63 @@ Register AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI, case AArch64::STRDui: case AArch64::STRQui: case AArch64::STR_PXI: - if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && - MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { - FrameIndex = MI.getOperand(1).getIndex(); - return MI.getOperand(0).getReg(); - } - break; + return true; + } +} + +Register AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI, + int &FrameIndex) const { + if (!isFrameStoreOpcode(MI.getOpcode())) + return Register(); + + if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && + MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { + FrameIndex = MI.getOperand(1).getIndex(); + return MI.getOperand(0).getReg(); } - return 0; + return Register(); +} + +Register AArch64InstrInfo::isStoreToStackSlotPostFE(const MachineInstr &MI, + int &FrameIndex) const { + if (!isFrameStoreOpcode(MI.getOpcode())) + return Register(); + + if (Register Reg = isStoreToStackSlot(MI, FrameIndex)) + return Reg; + + SmallVector<const MachineMemOperand *, 1> Accesses; + if (hasStoreToStackSlot(MI, Accesses)) { + if (Accesses.size() > 1) + return Register(); + + FrameIndex = + cast<FixedStackPseudoSourceValue>(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return 
MI.getOperand(0).getReg(); + } + return Register(); +} + +Register AArch64InstrInfo::isLoadFromStackSlotPostFE(const MachineInstr &MI, + int &FrameIndex) const { + if (!isFrameLoadOpcode(MI.getOpcode())) + return Register(); + + if (Register Reg = isLoadFromStackSlot(MI, FrameIndex)) + return Reg; + + SmallVector<const MachineMemOperand *, 1> Accesses; + if (hasLoadFromStackSlot(MI, Accesses)) { + if (Accesses.size() > 1) + return Register(); + + FrameIndex = + cast<FixedStackPseudoSourceValue>(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return MI.getOperand(0).getReg(); + } + return Register(); } /// Check all MachineMemOperands for a hint to suppress pairing. @@ -5616,7 +5902,6 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, Register SrcReg, bool isKill, int FI, const TargetRegisterClass *RC, - const TargetRegisterInfo *TRI, Register VReg, MachineInstr::MIFlag Flags) const { MachineFunction &MF = *MBB.getParent(); @@ -5630,7 +5915,7 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB, bool Offset = true; MCRegister PNRReg = MCRegister::NoRegister; unsigned StackID = TargetStackID::Default; - switch (TRI->getSpillSize(*RC)) { + switch (RI.getSpillSize(*RC)) { case 1: if (AArch64::FPR8RegClass.hasSubClassEq(RC)) Opc = AArch64::STRBui; @@ -5793,10 +6078,12 @@ static void loadRegPairFromStackSlot(const TargetRegisterInfo &TRI, .addMemOperand(MMO); } -void AArch64InstrInfo::loadRegFromStackSlot( - MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, Register DestReg, - int FI, const TargetRegisterClass *RC, const TargetRegisterInfo *TRI, - Register VReg, MachineInstr::MIFlag Flags) const { +void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + Register DestReg, int FI, + const TargetRegisterClass *RC, + Register VReg, + MachineInstr::MIFlag Flags) const { MachineFunction &MF = *MBB.getParent(); MachineFrameInfo &MFI = MF.getFrameInfo(); MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI); @@ -5808,7 +6095,7 @@ void AArch64InstrInfo::loadRegFromStackSlot( bool Offset = true; unsigned StackID = TargetStackID::Default; Register PNRReg = MCRegister::NoRegister; - switch (TRI->getSpillSize(*RC)) { + switch (TRI.getSpillSize(*RC)) { case 1: if (AArch64::FPR8RegClass.hasSubClassEq(RC)) Opc = AArch64::LDRBui; @@ -6444,10 +6731,10 @@ MachineInstr *AArch64InstrInfo::foldMemoryOperandImpl( "Mismatched register size in non subreg COPY"); if (IsSpill) storeRegToStackSlot(MBB, InsertPt, SrcReg, SrcMO.isKill(), FrameIndex, - getRegClass(SrcReg), &TRI, Register()); + getRegClass(SrcReg), Register()); else loadRegFromStackSlot(MBB, InsertPt, DstReg, FrameIndex, - getRegClass(DstReg), &TRI, Register()); + getRegClass(DstReg), Register()); return &*--InsertPt; } @@ -6465,8 +6752,7 @@ MachineInstr *AArch64InstrInfo::foldMemoryOperandImpl( assert(SrcMO.getSubReg() == 0 && "Unexpected subreg on physical register"); storeRegToStackSlot(MBB, InsertPt, AArch64::XZR, SrcMO.isKill(), - FrameIndex, &AArch64::GPR64RegClass, &TRI, - Register()); + FrameIndex, &AArch64::GPR64RegClass, Register()); return &*--InsertPt; } @@ -6500,7 +6786,7 @@ MachineInstr *AArch64InstrInfo::foldMemoryOperandImpl( assert(TRI.getRegSizeInBits(*getRegClass(SrcReg)) == TRI.getRegSizeInBits(*FillRC) && "Mismatched regclass size on folded subreg COPY"); - loadRegFromStackSlot(MBB, InsertPt, DstReg, FrameIndex, FillRC, &TRI, + loadRegFromStackSlot(MBB, InsertPt, DstReg, FrameIndex, 
FillRC, Register()); MachineInstr &LoadMI = *--InsertPt; MachineOperand &LoadDst = LoadMI.getOperand(0); @@ -6662,12 +6948,10 @@ bool llvm::rewriteAArch64FrameIndex(MachineInstr &MI, unsigned FrameRegIdx, void AArch64InstrInfo::insertNoop(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI) const { DebugLoc DL; - BuildMI(MBB, MI, DL, get(AArch64::HINT)).addImm(0); + BuildMI(MBB, MI, DL, get(AArch64::NOP)); } -MCInst AArch64InstrInfo::getNop() const { - return MCInstBuilder(AArch64::HINT).addImm(0); -} +MCInst AArch64InstrInfo::getNop() const { return MCInstBuilder(AArch64::NOP); } // AArch64 supports MachineCombiner. bool AArch64InstrInfo::useMachineCombiner() const { return true; } @@ -9259,6 +9543,8 @@ bool AArch64InstrInfo::optimizeCondBranch(MachineInstr &MI) const { case AArch64::Bcc: case AArch64::CBWPri: case AArch64::CBXPri: + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: case AArch64::CBWPrr: case AArch64::CBXPrr: return false; @@ -9555,8 +9841,8 @@ outliningCandidatesSigningScopeConsensus(const outliner::Candidate &a, const auto &MFIa = a.getMF()->getInfo<AArch64FunctionInfo>(); const auto &MFIb = b.getMF()->getInfo<AArch64FunctionInfo>(); - return MFIa->shouldSignReturnAddress(false) == MFIb->shouldSignReturnAddress(false) && - MFIa->shouldSignReturnAddress(true) == MFIb->shouldSignReturnAddress(true); + return MFIa->getSignReturnAddressCondition() == + MFIb->getSignReturnAddressCondition(); } static bool @@ -9588,6 +9874,27 @@ AArch64InstrInfo::getOutliningCandidateInfo( unsigned NumBytesToCreateFrame = 0; + // Avoid splitting ADRP ADD/LDR pair into outlined functions. + // These instructions are fused together by the scheduler. + // Any candidate where ADRP is the last instruction should be rejected + // as that will lead to splitting ADRP pair. + MachineInstr &LastMI = RepeatedSequenceLocs[0].back(); + MachineInstr &FirstMI = RepeatedSequenceLocs[0].front(); + if (LastMI.getOpcode() == AArch64::ADRP && + (LastMI.getOperand(1).getTargetFlags() & AArch64II::MO_PAGE) != 0 && + (LastMI.getOperand(1).getTargetFlags() & AArch64II::MO_GOT) != 0) { + return std::nullopt; + } + + // Similarly any candidate where the first instruction is ADD/LDR with a + // page offset should be rejected to avoid ADRP splitting. + if ((FirstMI.getOpcode() == AArch64::ADDXri || + FirstMI.getOpcode() == AArch64::LDRXui) && + (FirstMI.getOperand(2).getTargetFlags() & AArch64II::MO_PAGEOFF) != 0 && + (FirstMI.getOperand(2).getTargetFlags() & AArch64II::MO_GOT) != 0) { + return std::nullopt; + } + // We only allow outlining for functions having exactly matching return // address signing attributes, i.e., all share the same value for the // attribute "sign-return-address" and all share the same type of key they @@ -9626,10 +9933,11 @@ AArch64InstrInfo::getOutliningCandidateInfo( // Performing a tail call may require extra checks when PAuth is enabled. // If PAuth is disabled, set it to zero for uniformity. 
unsigned NumBytesToCheckLRInTCEpilogue = 0; - if (RepeatedSequenceLocs[0] - .getMF() - ->getInfo<AArch64FunctionInfo>() - ->shouldSignReturnAddress(true)) { + const auto RASignCondition = RepeatedSequenceLocs[0] + .getMF() + ->getInfo<AArch64FunctionInfo>() + ->getSignReturnAddressCondition(); + if (RASignCondition != SignReturnAddress::None) { // One PAC and one AUT instructions NumBytesToCreateFrame += 8; @@ -10433,7 +10741,9 @@ void AArch64InstrInfo::buildOutlinedFrame( Et = MBB.insert(Et, LDRXpost); } - bool ShouldSignReturnAddr = FI->shouldSignReturnAddress(!IsLeafFunction); + auto RASignCondition = FI->getSignReturnAddressCondition(); + bool ShouldSignReturnAddr = AArch64FunctionInfo::shouldSignReturnAddress( + RASignCondition, !IsLeafFunction); // If this is a tail call outlined function, then there's already a return. if (OF.FrameConstructionID == MachineOutlinerTailCall || @@ -10994,8 +11304,6 @@ static Register cloneInstr(const MachineInstr *MI, unsigned ReplaceOprNum, MachineBasicBlock::iterator InsertTo) { MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); const TargetInstrInfo *TII = MBB.getParent()->getSubtarget().getInstrInfo(); - const TargetRegisterInfo *TRI = - MBB.getParent()->getSubtarget().getRegisterInfo(); MachineInstr *NewMI = MBB.getParent()->CloneMachineInstr(MI); Register Result = 0; for (unsigned I = 0; I < NewMI->getNumOperands(); ++I) { @@ -11004,8 +11312,7 @@ static Register cloneInstr(const MachineInstr *MI, unsigned ReplaceOprNum, MRI.getRegClass(NewMI->getOperand(0).getReg())); NewMI->getOperand(I).setReg(Result); } else if (I == ReplaceOprNum) { - MRI.constrainRegClass(ReplaceReg, - TII->getRegClass(NewMI->getDesc(), I, TRI)); + MRI.constrainRegClass(ReplaceReg, TII->getRegClass(NewMI->getDesc(), I)); NewMI->getOperand(I).setReg(ReplaceReg); } } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h index 179574a..d237721 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -205,6 +205,15 @@ public: Register isStoreToStackSlot(const MachineInstr &MI, int &FrameIndex) const override; + /// Check for post-frame ptr elimination stack locations as well. This uses a + /// heuristic so it isn't reliable for correctness. + Register isStoreToStackSlotPostFE(const MachineInstr &MI, + int &FrameIndex) const override; + /// Check for post-frame ptr elimination stack locations as well. This uses a + /// heuristic so it isn't reliable for correctness. + Register isLoadFromStackSlotPostFE(const MachineInstr &MI, + int &FrameIndex) const override; + /// Does this instruction set its full destination register to zero? 
static bool isGPRZero(const MachineInstr &MI); @@ -353,14 +362,13 @@ public: void storeRegToStackSlot( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, Register SrcReg, - bool isKill, int FrameIndex, const TargetRegisterClass *RC, - const TargetRegisterInfo *TRI, Register VReg, + bool isKill, int FrameIndex, const TargetRegisterClass *RC, Register VReg, MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override; void loadRegFromStackSlot( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, Register DestReg, int FrameIndex, const TargetRegisterClass *RC, - const TargetRegisterInfo *TRI, Register VReg, + Register VReg, MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override; // This tells target independent code that it is okay to pass instructions @@ -697,6 +705,8 @@ int isAArch64FrameOffsetLegal(const MachineInstr &MI, StackOffset &Offset, unsigned *OutUnscaledOp = nullptr, int64_t *EmittableOffset = nullptr); +bool optimizeTerminators(MachineBasicBlock *MBB, const TargetInstrInfo &TII); + static inline bool isUncondBranchOpcode(int Opc) { return Opc == AArch64::B; } static inline bool isCondBranchOpcode(int Opc) { @@ -712,6 +722,8 @@ static inline bool isCondBranchOpcode(int Opc) { case AArch64::TBNZX: case AArch64::CBWPri: case AArch64::CBXPri: + case AArch64::CBBAssertExt: + case AArch64::CBHAssertExt: case AArch64::CBWPrr: case AArch64::CBXPrr: return true; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index b9e299e..7ee094a 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -346,10 +346,10 @@ def HasCCDP : Predicate<"Subtarget->hasCCDP()">, AssemblerPredicateWithAll<(all_of FeatureCacheDeepPersist), "ccdp">; def HasBTI : Predicate<"Subtarget->hasBTI()">, AssemblerPredicateWithAll<(all_of FeatureBranchTargetId), "bti">; +def HasBTIE : Predicate<"Subtarget->hasBTIE()">, + AssemblerPredicateWithAll<(all_of FeatureBTIE), "btie">; def HasMTE : Predicate<"Subtarget->hasMTE()">, AssemblerPredicateWithAll<(all_of FeatureMTE), "mte">; -def HasTME : Predicate<"Subtarget->hasTME()">, - AssemblerPredicateWithAll<(all_of FeatureTME), "tme">; def HasETE : Predicate<"Subtarget->hasETE()">, AssemblerPredicateWithAll<(all_of FeatureETE), "ete">; def HasTRBE : Predicate<"Subtarget->hasTRBE()">, @@ -405,6 +405,12 @@ def HasMTETC : Predicate<"Subtarget->hasMTETC()">, AssemblerPredicateWithAll<(all_of FeatureMTETC), "mtetc">; def HasGCIE : Predicate<"Subtarget->hasGCIE()">, AssemblerPredicateWithAll<(all_of FeatureGCIE), "gcie">; +def HasMOPS_GO : Predicate<"Subtarget->hasMOPS_GO()">, + AssemblerPredicateWithAll<(all_of FeatureMOPS_GO), "mops-go">; +def HasS1POE2 : Predicate<"Subtarget->hasS1POE2()">, + AssemblerPredicateWithAll<(all_of FeatureS1POE2), "poe2">; +def HasTEV : Predicate<"Subtarget->hasTEV()">, + AssemblerPredicateWithAll<(all_of FeatureTEV), "tev">; def IsLE : Predicate<"Subtarget->isLittleEndian()">; def IsBE : Predicate<"!Subtarget->isLittleEndian()">; def IsWindows : Predicate<"Subtarget->isTargetWindows()">; @@ -639,29 +645,34 @@ def nontrunc_masked_store : (masked_st node:$val, node:$ptr, undef, node:$pred), [{ return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() && cast<MaskedStoreSDNode>(N)->isUnindexed() && - !cast<MaskedStoreSDNode>(N)->isNonTemporal(); + !cast<MaskedStoreSDNode>(N)->isNonTemporal() && + !cast<MaskedStoreSDNode>(N)->isCompressingStore(); }]>; // truncating masked store fragments. 
def trunc_masked_store : PatFrag<(ops node:$val, node:$ptr, node:$pred), (masked_st node:$val, node:$ptr, undef, node:$pred), [{ return cast<MaskedStoreSDNode>(N)->isTruncatingStore() && - cast<MaskedStoreSDNode>(N)->isUnindexed(); + cast<MaskedStoreSDNode>(N)->isUnindexed() && + !cast<MaskedStoreSDNode>(N)->isCompressingStore(); }]>; def trunc_masked_store_i8 : PatFrag<(ops node:$val, node:$ptr, node:$pred), (trunc_masked_store node:$val, node:$ptr, node:$pred), [{ - return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8; + return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 && + !cast<MaskedStoreSDNode>(N)->isCompressingStore(); }]>; def trunc_masked_store_i16 : PatFrag<(ops node:$val, node:$ptr, node:$pred), (trunc_masked_store node:$val, node:$ptr, node:$pred), [{ - return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16; + return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 && + !cast<MaskedStoreSDNode>(N)->isCompressingStore(); }]>; def trunc_masked_store_i32 : PatFrag<(ops node:$val, node:$ptr, node:$pred), (trunc_masked_store node:$val, node:$ptr, node:$pred), [{ - return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32; + return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 && + !cast<MaskedStoreSDNode>(N)->isCompressingStore(); }]>; def non_temporal_store : @@ -669,7 +680,8 @@ def non_temporal_store : (masked_st node:$val, node:$ptr, undef, node:$pred), [{ return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() && cast<MaskedStoreSDNode>(N)->isUnindexed() && - cast<MaskedStoreSDNode>(N)->isNonTemporal(); + cast<MaskedStoreSDNode>(N)->isNonTemporal() && + !cast<MaskedStoreSDNode>(N)->isCompressingStore(); }]>; multiclass masked_gather_scatter<PatFrags GatherScatterOp> { @@ -1012,6 +1024,18 @@ def AArch64fcvtnu_half : SDNode<"AArch64ISD::FCVTNU_HALF", SDTFPExtendOp>; def AArch64fcvtps_half : SDNode<"AArch64ISD::FCVTPS_HALF", SDTFPExtendOp>; def AArch64fcvtpu_half : SDNode<"AArch64ISD::FCVTPU_HALF", SDTFPExtendOp>; +def AArch64sqadd: SDNode<"AArch64ISD::SQADD", SDTFPBinOp>; +def AArch64sqrshl: SDNode<"AArch64ISD::SQRSHL", SDTFPBinOp>; +def AArch64sqshl: SDNode<"AArch64ISD::SQSHL", SDTFPBinOp>; +def AArch64sqsub: SDNode<"AArch64ISD::SQSUB", SDTFPBinOp>; +def AArch64uqadd: SDNode<"AArch64ISD::UQADD", SDTFPBinOp>; +def AArch64uqrshl: SDNode<"AArch64ISD::UQRSHL", SDTFPBinOp>; +def AArch64uqshl: SDNode<"AArch64ISD::UQSHL", SDTFPBinOp>; +def AArch64uqsub: SDNode<"AArch64ISD::UQSUB", SDTFPBinOp>; +def AArch64sqdmull: SDNode<"AArch64ISD::SQDMULL", + SDTypeProfile<1, 2, [ SDTCisSameAs<1, 2>, + SDTCisFP<0>, SDTCisFP<1>]>>; + //def Aarch64softf32tobf16v8: SDNode<"AArch64ISD::", SDTFPRoundOp>; // Vector immediate ops @@ -1034,11 +1058,11 @@ def AArch64uitof: SDNode<"AArch64ISD::UITOF", SDT_AArch64ITOF>; // offset of a variable into X0, using the TLSDesc model. 
def AArch64tlsdesc_callseq : SDNode<"AArch64ISD::TLSDESC_CALLSEQ", SDT_AArch64TLSDescCallSeq, - [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>; + [SDNPOutGlue, SDNPOptInGlue, SDNPHasChain, SDNPVariadic]>; def AArch64tlsdesc_auth_callseq : SDNode<"AArch64ISD::TLSDESC_AUTH_CALLSEQ", SDT_AArch64TLSDescCallSeq, - [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>; + [SDNPOutGlue, SDNPOptInGlue, SDNPHasChain, SDNPVariadic]>; def AArch64WrapperLarge : SDNode<"AArch64ISD::WrapperLarge", SDT_AArch64WrapperLarge>; @@ -1178,7 +1202,7 @@ def AArch64msrr : SDNode<"AArch64ISD::MSRR", SDTCisVT<2, i64>]>, [SDNPHasChain]>; -def SD_AArch64rshrnb : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>, SDTCisInt<2>]>; +def SD_AArch64rshrnb : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>, SDTCisVT<2, i32>]>; // Vector narrowing shift by immediate (bottom) def AArch64rshrnb : SDNode<"AArch64ISD::RSHRNB_I", SD_AArch64rshrnb>; def AArch64rshrnb_pf : PatFrags<(ops node:$rs, node:$i), @@ -1520,7 +1544,6 @@ let hasSideEffects = 1, isCodeGenOnly = 1, isTerminator = 1, isBarrier = 1 in { //===----------------------------------------------------------------------===// def HINT : HintI<"hint">; -def : InstAlias<"nop", (HINT 0b000)>; def : InstAlias<"yield",(HINT 0b001)>; def : InstAlias<"wfe", (HINT 0b010)>; def : InstAlias<"wfi", (HINT 0b011)>; @@ -1530,6 +1553,11 @@ def : InstAlias<"dgh", (HINT 0b110)>; def : InstAlias<"esb", (HINT 0b10000)>, Requires<[HasRAS]>; def : InstAlias<"csdb", (HINT 20)>; +let CRm = 0b0000, hasSideEffects = 0 in +def NOP : SystemNoOperands<0b000, "hint\t#0">; + +def : InstAlias<"nop", (NOP)>; + let Predicates = [HasPCDPHINT] in { def STSHH: STSHHI; } @@ -1540,6 +1568,7 @@ let Predicates = [HasPCDPHINT] in { // should not emit these mnemonics unless BTI is enabled. def : InstAlias<"bti", (HINT 32), 0>; def : InstAlias<"bti $op", (HINT btihint_op:$op), 0>; +def : InstAlias<"bti r", (HINT 32)>, Requires<[HasBTIE]>; def : InstAlias<"bti", (HINT 32)>, Requires<[HasBTI]>; def : InstAlias<"bti $op", (HINT btihint_op:$op)>, Requires<[HasBTI]>; @@ -1805,14 +1834,22 @@ def : SHA3_pattern<EOR3, int_aarch64_crypto_eor3u, v8i16>; def : SHA3_pattern<EOR3, int_aarch64_crypto_eor3u, v4i32>; def : SHA3_pattern<EOR3, int_aarch64_crypto_eor3u, v2i64>; -class EOR3_pattern<ValueType VecTy> - : Pat<(xor (xor (VecTy V128:$Vn), (VecTy V128:$Vm)), (VecTy V128:$Va)), - (EOR3 (VecTy V128:$Vn), (VecTy V128:$Vm), (VecTy V128:$Va))>; +multiclass EOR3_pattern<ValueType Vec128Ty, ValueType Vec64Ty>{ + def : Pat<(xor (xor (Vec128Ty V128:$Vn), (Vec128Ty V128:$Vm)), (Vec128Ty V128:$Va)), + (EOR3 (Vec128Ty V128:$Vn), (Vec128Ty V128:$Vm), (Vec128Ty V128:$Va))>; + def : Pat<(xor (xor (Vec64Ty V64:$Vn), (Vec64Ty V64:$Vm)), (Vec64Ty V64:$Va)), + (EXTRACT_SUBREG + (EOR3 + (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vn, dsub), + (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vm, dsub), + (INSERT_SUBREG (IMPLICIT_DEF), V64:$Va, dsub)), + dsub)>; +} -def : EOR3_pattern<v16i8>; -def : EOR3_pattern<v8i16>; -def : EOR3_pattern<v4i32>; -def : EOR3_pattern<v2i64>; +defm : EOR3_pattern<v16i8, v8i8>; +defm : EOR3_pattern<v8i16, v4i16>; +defm : EOR3_pattern<v4i32, v2i32>; +defm : EOR3_pattern<v2i64, v1i64>; class BCAX_pattern<ValueType VecTy> : Pat<(xor (VecTy V128:$Vn), (and (VecTy V128:$Vm), (vnot (VecTy V128:$Va)))), @@ -2200,6 +2237,7 @@ let Predicates = [HasPAuth] in { let Size = 12; let Defs = [X16, X17]; let usesCustomInserter = 1; + let supportsDeactivationSymbol = true; } // A standalone pattern is used, so that literal 0 can be passed as $Disc. 
@@ -2480,22 +2518,6 @@ def : InstAlias<"sys $op1, $Cn, $Cm, $op2", sys_cr_op:$Cm, imm0_7:$op2, XZR)>; -let Predicates = [HasTME] in { - -def TSTART : TMSystemI<0b0000, "tstart", - [(set GPR64:$Rt, (int_aarch64_tstart))]>; - -def TCOMMIT : TMSystemINoOperand<0b0000, "tcommit", [(int_aarch64_tcommit)]>; - -def TCANCEL : TMSystemException<0b011, "tcancel", - [(int_aarch64_tcancel timm64_0_65535:$imm)]>; - -def TTEST : TMSystemI<0b0001, "ttest", [(set GPR64:$Rt, (int_aarch64_ttest))]> { - let mayLoad = 0; - let mayStore = 0; -} -} // HasTME - //===----------------------------------------------------------------------===// // Move immediate instructions. //===----------------------------------------------------------------------===// @@ -2753,6 +2775,20 @@ def : Pat<(AArch64sub_flag GPR64:$Rn, neg_addsub_shifted_imm64:$imm), (ADDSXri GPR64:$Rn, neg_addsub_shifted_imm64:$imm)>; } + +def trunc_isWorthFoldingALU : PatFrag<(ops node:$src), (trunc $src)> { + let PredicateCode = [{ return isWorthFoldingALU(SDValue(N, 0)); }]; + let GISelPredicateCode = [{ return isWorthFoldingIntoExtendedReg(MI, MRI, false); }]; +} + +// Patterns for (add X, trunc(shift(Y))), for which we can generate 64bit instructions. +def : Pat<(add GPR32:$Rn, (trunc_isWorthFoldingALU arith_shifted_reg64:$Rm)), + (EXTRACT_SUBREG (ADDXrs (INSERT_SUBREG (IMPLICIT_DEF), GPR32:$Rn, sub_32), + arith_shifted_reg64:$Rm), sub_32)>; +def : Pat<(sub GPR32:$Rn, (trunc_isWorthFoldingALU arith_shifted_reg64:$Rm)), + (EXTRACT_SUBREG (SUBXrs (INSERT_SUBREG (IMPLICIT_DEF), GPR32:$Rn, sub_32), + arith_shifted_reg64:$Rm), sub_32)>; + def : InstAlias<"neg $dst, $src", (SUBWrs GPR32:$dst, WZR, (arith_shifted_reg32 GPR32:$src, 0)), 3>; @@ -4436,6 +4472,11 @@ defm PRFUM : PrefetchUnscaled<0b11, 0, 0b10, "prfum", [(AArch64Prefetch timm:$Rt, (am_unscaled64 GPR64sp:$Rn, simm9:$offset))]>; +// PRFM falls back to PRFUM for negative or unaligned offsets (not a multiple +// of 8). 
+def : InstAlias<"prfm $Rt, [$Rn, $offset]", + (PRFUMi prfop:$Rt, GPR64sp:$Rn, simm9_offset_fb64:$offset), 0>; + //--- // (unscaled immediate, unprivileged) defm LDTRX : LoadUnprivileged<0b11, 0, 0b01, GPR64, "ldtr">; @@ -5658,6 +5699,8 @@ let isPseudo = 1 in { def SEH_EpilogStart : Pseudo<(outs), (ins), []>, Sched<[]>; def SEH_EpilogEnd : Pseudo<(outs), (ins), []>, Sched<[]>; def SEH_PACSignLR : Pseudo<(outs), (ins), []>, Sched<[]>; + def SEH_SaveAnyRegI : Pseudo<(outs), (ins i32imm:$reg0, i32imm:$offs), []>, Sched<[]>; + def SEH_SaveAnyRegIP : Pseudo<(outs), (ins i32imm:$reg0, i32imm:$reg1, i32imm:$offs), []>, Sched<[]>; def SEH_SaveAnyRegQP : Pseudo<(outs), (ins i32imm:$reg0, i32imm:$reg1, i32imm:$offs), []>, Sched<[]>; def SEH_SaveAnyRegQPX : Pseudo<(outs), (ins i32imm:$reg0, i32imm:$reg1, i32imm:$offs), []>, Sched<[]>; def SEH_AllocZ : Pseudo<(outs), (ins i32imm:$offs), []>, Sched<[]>; @@ -6406,19 +6449,19 @@ defm FCMGT : SIMDThreeScalarFPCmp<1, 1, 0b100, "fcmgt", AArch64fcmgt>; defm FMULX : SIMDFPThreeScalar<0, 0, 0b011, "fmulx", int_aarch64_neon_fmulx, HasNEONandIsStreamingSafe>; defm FRECPS : SIMDFPThreeScalar<0, 0, 0b111, "frecps", int_aarch64_neon_frecps, HasNEONandIsStreamingSafe>; defm FRSQRTS : SIMDFPThreeScalar<0, 1, 0b111, "frsqrts", int_aarch64_neon_frsqrts, HasNEONandIsStreamingSafe>; -defm SQADD : SIMDThreeScalarBHSD<0, 0b00001, "sqadd", int_aarch64_neon_sqadd, saddsat>; +defm SQADD : SIMDThreeScalarBHSD<0, 0b00001, "sqadd", AArch64sqadd, int_aarch64_neon_sqadd, saddsat>; defm SQDMULH : SIMDThreeScalarHS< 0, 0b10110, "sqdmulh", int_aarch64_neon_sqdmulh>; defm SQRDMULH : SIMDThreeScalarHS< 1, 0b10110, "sqrdmulh", int_aarch64_neon_sqrdmulh>; -defm SQRSHL : SIMDThreeScalarBHSD<0, 0b01011, "sqrshl", int_aarch64_neon_sqrshl, int_aarch64_neon_sqrshl>; -defm SQSHL : SIMDThreeScalarBHSD<0, 0b01001, "sqshl", int_aarch64_neon_sqshl, int_aarch64_neon_sqshl>; -defm SQSUB : SIMDThreeScalarBHSD<0, 0b00101, "sqsub", int_aarch64_neon_sqsub, ssubsat>; +defm SQRSHL : SIMDThreeScalarBHSD<0, 0b01011, "sqrshl", AArch64sqrshl, int_aarch64_neon_sqrshl, int_aarch64_neon_sqrshl>; +defm SQSHL : SIMDThreeScalarBHSD<0, 0b01001, "sqshl", AArch64sqshl, int_aarch64_neon_sqshl, int_aarch64_neon_sqshl>; +defm SQSUB : SIMDThreeScalarBHSD<0, 0b00101, "sqsub", AArch64sqsub, int_aarch64_neon_sqsub, ssubsat>; defm SRSHL : SIMDThreeScalarD< 0, 0b01010, "srshl", int_aarch64_neon_srshl>; defm SSHL : SIMDThreeScalarD< 0, 0b01000, "sshl", int_aarch64_neon_sshl>; defm SUB : SIMDThreeScalarD< 1, 0b10000, "sub", sub>; -defm UQADD : SIMDThreeScalarBHSD<1, 0b00001, "uqadd", int_aarch64_neon_uqadd, uaddsat>; -defm UQRSHL : SIMDThreeScalarBHSD<1, 0b01011, "uqrshl", int_aarch64_neon_uqrshl, int_aarch64_neon_uqrshl>; -defm UQSHL : SIMDThreeScalarBHSD<1, 0b01001, "uqshl", int_aarch64_neon_uqshl, int_aarch64_neon_uqshl>; -defm UQSUB : SIMDThreeScalarBHSD<1, 0b00101, "uqsub", int_aarch64_neon_uqsub, usubsat>; +defm UQADD : SIMDThreeScalarBHSD<1, 0b00001, "uqadd", AArch64uqadd, int_aarch64_neon_uqadd, uaddsat>; +defm UQRSHL : SIMDThreeScalarBHSD<1, 0b01011, "uqrshl", AArch64uqrshl, int_aarch64_neon_uqrshl, int_aarch64_neon_uqrshl>; +defm UQSHL : SIMDThreeScalarBHSD<1, 0b01001, "uqshl", AArch64uqshl, int_aarch64_neon_uqshl, int_aarch64_neon_uqshl>; +defm UQSUB : SIMDThreeScalarBHSD<1, 0b00101, "uqsub", AArch64uqsub, int_aarch64_neon_uqsub, usubsat>; defm URSHL : SIMDThreeScalarD< 1, 0b01010, "urshl", int_aarch64_neon_urshl>; defm USHL : SIMDThreeScalarD< 1, 0b01000, "ushl", int_aarch64_neon_ushl>; let Predicates = [HasRDM] in { 
@@ -6469,17 +6512,16 @@ def : InstAlias<"faclt $dst, $src1, $src2", // Advanced SIMD three scalar instructions (mixed operands). //===----------------------------------------------------------------------===// defm SQDMULL : SIMDThreeScalarMixedHS<0, 0b11010, "sqdmull", - int_aarch64_neon_sqdmulls_scalar>; + AArch64sqdmull>; defm SQDMLAL : SIMDThreeScalarMixedTiedHS<0, 0b10010, "sqdmlal">; defm SQDMLSL : SIMDThreeScalarMixedTiedHS<0, 0b10110, "sqdmlsl">; -def : Pat<(i64 (int_aarch64_neon_sqadd (i64 FPR64:$Rd), - (i64 (int_aarch64_neon_sqdmulls_scalar (i32 FPR32:$Rn), - (i32 FPR32:$Rm))))), +def : Pat<(f64 (AArch64sqadd FPR64:$Rd, + (AArch64sqdmull FPR32:$Rn, FPR32:$Rm))), (SQDMLALi32 FPR64:$Rd, FPR32:$Rn, FPR32:$Rm)>; -def : Pat<(i64 (int_aarch64_neon_sqsub (i64 FPR64:$Rd), - (i64 (int_aarch64_neon_sqdmulls_scalar (i32 FPR32:$Rn), - (i32 FPR32:$Rm))))), + +def : Pat<(f64 (AArch64sqsub FPR64:$Rd, + (AArch64sqdmull FPR32:$Rn, FPR32:$Rm))), (SQDMLSLi32 FPR64:$Rd, FPR32:$Rn, FPR32:$Rm)>; //===----------------------------------------------------------------------===// @@ -6799,6 +6841,49 @@ defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fp_to_uint_sat_gi, ftrunc, "F defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, fp_to_sint_sat_gi, fround, "FCVTAS">; defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fp_to_uint_sat_gi, fround, "FCVTAU">; +let Predicates = [HasFPRCVT] in { + def : Pat<(f32 (bitconvert (i32 (any_lround f16:$Rn)))), + (FCVTASSHr f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (any_lround f16:$Rn)))), + (FCVTASDHr f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (any_llround f16:$Rn)))), + (FCVTASDHr f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (any_lround f32:$Rn)))), + (FCVTASDSr f32:$Rn)>; + def : Pat<(f32 (bitconvert (i32 (any_lround f64:$Rn)))), + (FCVTASSDr f64:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (any_llround f32:$Rn)))), + (FCVTASDSr f32:$Rn)>; +} +def : Pat<(f32 (bitconvert (i32 (any_lround f32:$Rn)))), + (FCVTASv1i32 f32:$Rn)>; +def : Pat<(f64 (bitconvert (i64 (any_lround f64:$Rn)))), + (FCVTASv1i64 f64:$Rn)>; +def : Pat<(f64 (bitconvert (i64 (any_llround f64:$Rn)))), + (FCVTASv1i64 f64:$Rn)>; + +let Predicates = [HasFPRCVT] in { + def : Pat<(f32 (bitconvert (i32 (any_lrint f16:$Rn)))), + (FCVTZSSHr (FRINTXHr f16:$Rn))>; + def : Pat<(f64 (bitconvert (i64 (any_lrint f16:$Rn)))), + (FCVTZSDHr (FRINTXHr f16:$Rn))>; + def : Pat<(f64 (bitconvert (i64 (any_llrint f16:$Rn)))), + (FCVTZSDHr (FRINTXHr f16:$Rn))>; + def : Pat<(f64 (bitconvert (i64 (any_lrint f32:$Rn)))), + (FCVTZSDSr (FRINTXSr f32:$Rn))>; + def : Pat<(f32 (bitconvert (i32 (any_lrint f64:$Rn)))), + (FCVTZSSDr (FRINTXDr f64:$Rn))>; + def : Pat<(f64 (bitconvert (i64 (any_llrint f32:$Rn)))), + (FCVTZSDSr (FRINTXSr f32:$Rn))>; +} +def : Pat<(f32 (bitconvert (i32 (any_lrint f32:$Rn)))), + (FCVTZSv1i32 (FRINTXSr f32:$Rn))>; +def : Pat<(f64 (bitconvert (i64 (any_lrint f64:$Rn)))), + (FCVTZSv1i64 (FRINTXDr f64:$Rn))>; +def : Pat<(f64 (bitconvert (i64 (any_llrint f64:$Rn)))), + (FCVTZSv1i64 (FRINTXDr f64:$Rn))>; + + // f16 -> s16 conversions let Predicates = [HasFullFP16] in { def : Pat<(i16(fp_to_sint_sat_gi f16:$Rn)), (FCVTZSv1f16 f16:$Rn)>; @@ -7008,6 +7093,19 @@ multiclass UIntToFPROLoadPat<ValueType DstTy, ValueType SrcTy, sub))>; } +let Predicates = [HasNEONandIsSME2p2StreamingSafe, HasFullFP16] in { +defm : UIntToFPROLoadPat<f16, i32, zextloadi8, + UCVTFv1i16, ro8, LDRBroW, LDRBroX, bsub>; +def : Pat <(f16 (uint_to_fp (i32 + (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), + (UCVTFv1i16 (INSERT_SUBREG 
(f16 (IMPLICIT_DEF)), + (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub))>; +def : Pat <(f16 (uint_to_fp (i32 + (zextloadi8 (am_unscaled8 GPR64sp:$Rn, simm9:$offset))))), + (UCVTFv1i16 (INSERT_SUBREG (f16 (IMPLICIT_DEF)), + (LDURBi GPR64sp:$Rn, simm9:$offset), bsub))>; +} + defm : UIntToFPROLoadPat<f32, i32, zextloadi8, UCVTFv1i32, ro8, LDRBroW, LDRBroX, bsub>; def : Pat <(f32 (uint_to_fp (i32 @@ -8348,6 +8446,7 @@ def : InstAlias<"orr.4s $Vd, $imm", (ORRv4i32 V128:$Vd, imm0_255:$imm, 0)>; } // AdvSIMD FMOV +let isReMaterializable = 1, isAsCheapAsAMove = 1 in { def FMOVv2f64_ns : SIMDModifiedImmVectorNoShift<1, 1, 0, 0b1111, V128, fpimm8, "fmov", ".2d", [(set (v2f64 V128:$Rd), (AArch64fmov imm0_255:$imm8))]>; @@ -8365,6 +8464,7 @@ def FMOVv8f16_ns : SIMDModifiedImmVectorNoShift<1, 0, 1, 0b1111, V128, fpimm8, "fmov", ".8h", [(set (v8f16 V128:$Rd), (AArch64fmov imm0_255:$imm8))]>; } // Predicates = [HasNEON, HasFullFP16] +} // AdvSIMD MOVI @@ -8692,9 +8792,9 @@ defm SMLSL : SIMDVectorIndexedLongSDTied<0, 0b0110, "smlsl", TriOpFrag<(sub node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>; defm SMULL : SIMDVectorIndexedLongSD<0, 0b1010, "smull", AArch64smull>; defm SQDMLAL : SIMDIndexedLongSQDMLXSDTied<0, 0b0011, "sqdmlal", saddsat, - int_aarch64_neon_sqadd>; + AArch64sqadd, int_aarch64_neon_sqadd>; defm SQDMLSL : SIMDIndexedLongSQDMLXSDTied<0, 0b0111, "sqdmlsl", ssubsat, - int_aarch64_neon_sqsub>; + AArch64sqsub, int_aarch64_neon_sqsub>; defm SQRDMLAH : SIMDIndexedSQRDMLxHSDTied<1, 0b1101, "sqrdmlah", int_aarch64_neon_sqrdmlah>; defm SQRDMLSH : SIMDIndexedSQRDMLxHSDTied<1, 0b1111, "sqrdmlsh", @@ -10853,6 +10953,15 @@ let Predicates = [HasMOPS, HasMTE], Defs = [NZCV], Size = 12, mayLoad = 0, maySt } //----------------------------------------------------------------------------- +// MOPS Granule Only Protection (FEAT_MOPS_GO) + +let Predicates = [HasMOPS_GO, HasMTE] in { + defm SETGOP : MOPSGoMemorySetTaggingInsns<0b00, "setgop">; + defm SETGOM : MOPSGoMemorySetTaggingInsns<0b01, "setgom">; + defm SETGOE : MOPSGoMemorySetTaggingInsns<0b10, "setgoe">; +} + +//----------------------------------------------------------------------------- // v8.3 Pointer Authentication late patterns def : Pat<(int_ptrauth_blend GPR64:$Rd, imm64_0_65535:$imm), @@ -11303,23 +11412,37 @@ let Predicates = [HasCMPBR] in { defm : CmpBranchWRegisterAlias<"cbhlt", "CBHGT">; // Pseudos for codegen - def CBWPrr : CmpBranchRegisterPseudo<GPR32>; - def CBXPrr : CmpBranchRegisterPseudo<GPR64>; - def CBWPri : CmpBranchImmediatePseudo<GPR32, uimm6_32b>; - def CBXPri : CmpBranchImmediatePseudo<GPR64, uimm6_64b>; - - def : Pat<(AArch64CB i32:$Cond, GPR32:$Rn, CmpBranchUImm6Operand_32b:$Imm, - bb:$Target), - (CBWPri i32:$Cond, GPR32:$Rn, uimm6_32b:$Imm, - am_brcmpcond:$Target)>; - def : Pat<(AArch64CB i32:$Cond, GPR64:$Rn, CmpBranchUImm6Operand_64b:$Imm, - bb:$Target), - (CBXPri i32:$Cond, GPR64:$Rn, uimm6_64b:$Imm, - am_brcmpcond:$Target)>; - def : Pat<(AArch64CB i32:$Cond, GPR32:$Rn, GPR32:$Rt, bb:$Target), - (CBWPrr ccode:$Cond, GPR32:$Rn, GPR32:$Rt, am_brcmpcond:$Target)>; - def : Pat<(AArch64CB i32:$Cond, GPR64:$Rn, GPR64:$Rt, bb:$Target), - (CBXPrr ccode:$Cond, GPR64:$Rn, GPR64:$Rt, am_brcmpcond:$Target)>; + def CBBAssertExt : CmpBranchExtRegisterPseudo; + def CBHAssertExt : CmpBranchExtRegisterPseudo; + def CBWPrr : CmpBranchRegisterPseudo<GPR32>; + def CBXPrr : CmpBranchRegisterPseudo<GPR64>; + def CBWPri : CmpBranchImmediatePseudo<GPR32, uimm6_32b>; + def CBXPri : CmpBranchImmediatePseudo<GPR64, uimm6_64b>; + + def : 
Pat<(AArch64CB i32:$Cond, GPR32:$Rn, CmpBranchUImm6Operand_32b:$Imm, + bb:$Target), + (CBWPri i32:$Cond, GPR32:$Rn, uimm6_32b:$Imm, am_brcmpcond:$Target)>; + def : Pat<(AArch64CB i32:$Cond, GPR64:$Rn, CmpBranchUImm6Operand_64b:$Imm, + bb:$Target), + (CBXPri i32:$Cond, GPR64:$Rn, uimm6_64b:$Imm, am_brcmpcond:$Target)>; + def : Pat<(AArch64CB i32:$Cond, GPR32:$Rn, GPR32:$Rt, bb:$Target), + (CBWPrr ccode:$Cond, GPR32:$Rn, GPR32:$Rt, am_brcmpcond:$Target)>; + def : Pat<(AArch64CB i32:$Cond, GPR64:$Rn, GPR64:$Rt, bb:$Target), + (CBXPrr ccode:$Cond, GPR64:$Rn, GPR64:$Rt, am_brcmpcond:$Target)>; + + def : Pat<(AArch64CB i32:$Cond, + (CmpBranchBExtOperand GPR32:$Rn, simm8_32b:$ExtTypeRn), + (CmpBranchBExtOperand GPR32:$Rt, simm8_32b:$ExtTypeRt), + bb:$Target), + (CBBAssertExt ccode:$Cond, GPR32:$Rn, GPR32:$Rt, bb:$Target, + simm8_32b:$ExtTypeRn, simm8_32b:$ExtTypeRt)>; + + def : Pat<(AArch64CB i32:$Cond, + (CmpBranchHExtOperand GPR32:$Rn, simm8_32b:$ExtTypeRn), + (CmpBranchHExtOperand GPR32:$Rt, simm8_32b:$ExtTypeRt), + bb:$Target), + (CBHAssertExt ccode:$Cond, GPR32:$Rn, GPR32:$Rt, bb:$Target, + simm8_32b:$ExtTypeRn, simm8_32b:$ExtTypeRt)>; } // HasCMPBR @@ -11407,7 +11530,7 @@ let Predicates = [HasF16F32MM] in defm FMMLA : SIMDThreeSameVectorFMLAWiden<"fmmla">; let Uses = [FPMR, FPCR] in - defm FMMLA : SIMDThreeSameVectorFP8MatrixMul<"fmmla">; + defm FMMLA : SIMDThreeSameVectorFP8MatrixMul<"fmmla", int_aarch64_neon_fmmla>; //===----------------------------------------------------------------------===// // Contention Management Hints (FEAT_CMH) @@ -11418,6 +11541,26 @@ let Predicates = [HasCMH] in { def STCPH : STCPHInst<"stcph">; // Store Concurrent Priority Hint instruction } +//===----------------------------------------------------------------------===// +// Permission Overlays Extension 2 (FEAT_S1POE2) +//===----------------------------------------------------------------------===// + +let Predicates = [HasS1POE2] in { + defm TCHANGEBrr : TCHANGEReg<"tchangeb", true>; + defm TCHANGEFrr : TCHANGEReg<"tchangef", false>; + defm TCHANGEBri : TCHANGEImm<"tchangeb", true>; + defm TCHANGEFri : TCHANGEImm<"tchangef", false>; +} + +//===----------------------------------------------------------------------===// +// TIndex Exception-like Vector (FEAT_TEV) +//===----------------------------------------------------------------------===// + +let Predicates = [HasTEV] in { + defm TENTER : TENTER<"tenter">; + defm TEXIT : TEXIT<"texit">; +} + include "AArch64InstrAtomics.td" include "AArch64SVEInstrInfo.td" include "AArch64SMEInstrInfo.td" diff --git a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp index e69fa32..45599de 100644 --- a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp +++ b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp @@ -42,7 +42,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/raw_ostream.h" #include <cassert> #include <cstdint> #include <functional> @@ -1386,6 +1385,25 @@ AArch64LoadStoreOpt::mergePairedInsns(MachineBasicBlock::iterator I, if (MOP.isReg() && MOP.isKill()) DefinedInBB.addReg(MOP.getReg()); + // Copy over any implicit-def operands. This is like MI.copyImplicitOps, but + // only copies implicit defs and makes sure that each operand is only added + // once in case of duplicates. 
+ auto CopyImplicitOps = [&](MachineBasicBlock::iterator MI1, + MachineBasicBlock::iterator MI2) { + SmallSetVector<Register, 4> Ops; + for (const MachineOperand &MO : + llvm::drop_begin(MI1->operands(), MI1->getDesc().getNumOperands())) + if (MO.isReg() && MO.isImplicit() && MO.isDef()) + Ops.insert(MO.getReg()); + for (const MachineOperand &MO : + llvm::drop_begin(MI2->operands(), MI2->getDesc().getNumOperands())) + if (MO.isReg() && MO.isImplicit() && MO.isDef()) + Ops.insert(MO.getReg()); + for (auto Op : Ops) + MIB.addDef(Op, RegState::Implicit); + }; + CopyImplicitOps(I, Paired); + // Erase the old instructions. I->eraseFromParent(); Paired->eraseFromParent(); diff --git a/llvm/lib/Target/AArch64/AArch64LowerHomogeneousPrologEpilog.cpp b/llvm/lib/Target/AArch64/AArch64LowerHomogeneousPrologEpilog.cpp index d67182d..d69f12e 100644 --- a/llvm/lib/Target/AArch64/AArch64LowerHomogeneousPrologEpilog.cpp +++ b/llvm/lib/Target/AArch64/AArch64LowerHomogeneousPrologEpilog.cpp @@ -483,16 +483,17 @@ bool AArch64LowerHomogeneousPE::lowerEpilog( assert(MI.getOpcode() == AArch64::HOM_Epilog); auto Return = NextMBBI; + MachineInstr *HelperCall = nullptr; if (shouldUseFrameHelper(MBB, NextMBBI, Regs, FrameHelperType::EpilogTail)) { // When MBB ends with a return, emit a tail-call to the epilog helper auto *EpilogTailHelper = getOrCreateFrameHelper(M, MMI, Regs, FrameHelperType::EpilogTail); - BuildMI(MBB, MBBI, DL, TII->get(AArch64::TCRETURNdi)) - .addGlobalAddress(EpilogTailHelper) - .addImm(0) - .setMIFlag(MachineInstr::FrameDestroy) - .copyImplicitOps(MI) - .copyImplicitOps(*Return); + HelperCall = BuildMI(MBB, MBBI, DL, TII->get(AArch64::TCRETURNdi)) + .addGlobalAddress(EpilogTailHelper) + .addImm(0) + .setMIFlag(MachineInstr::FrameDestroy) + .copyImplicitOps(MI) + .copyImplicitOps(*Return); NextMBBI = std::next(Return); Return->removeFromParent(); } else if (shouldUseFrameHelper(MBB, NextMBBI, Regs, @@ -500,10 +501,10 @@ bool AArch64LowerHomogeneousPE::lowerEpilog( // The default epilog helper case. auto *EpilogHelper = getOrCreateFrameHelper(M, MMI, Regs, FrameHelperType::Epilog); - BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL)) - .addGlobalAddress(EpilogHelper) - .setMIFlag(MachineInstr::FrameDestroy) - .copyImplicitOps(MI); + HelperCall = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL)) + .addGlobalAddress(EpilogHelper) + .setMIFlag(MachineInstr::FrameDestroy) + .copyImplicitOps(MI); } else { // Fall back to no-helper. for (int I = 0; I < Size - 2; I += 2) @@ -512,6 +513,12 @@ bool AArch64LowerHomogeneousPE::lowerEpilog( emitLoad(MF, MBB, MBBI, *TII, Regs[Size - 2], Regs[Size - 1], Size, true); } + // Make sure all explicit definitions are preserved in the helper call; + // implicit ones are already handled by copyImplicitOps. 
+ if (HelperCall) + for (auto &Def : MBBI->defs()) + HelperCall->addRegisterDefined(Def.getReg(), + MF.getRegInfo().getTargetRegisterInfo()); MBBI->removeFromParent(); return true; } @@ -649,7 +656,7 @@ bool AArch64LowerHomogeneousPE::runOnMBB(MachineBasicBlock &MBB) { } bool AArch64LowerHomogeneousPE::runOnMachineFunction(MachineFunction &MF) { - TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); + TII = MF.getSubtarget<AArch64Subtarget>().getInstrInfo(); bool Modified = false; for (auto &MBB : MF) diff --git a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp index 04e76c7..d25db89 100644 --- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp +++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp @@ -595,17 +595,17 @@ bool AArch64MIPeepholeOpt::splitTwoPartImm( // Determine register classes for destinations and register operands const TargetRegisterClass *FirstInstrDstRC = - TII->getRegClass(TII->get(Opcode.first), 0, TRI); + TII->getRegClass(TII->get(Opcode.first), 0); const TargetRegisterClass *FirstInstrOperandRC = - TII->getRegClass(TII->get(Opcode.first), 1, TRI); + TII->getRegClass(TII->get(Opcode.first), 1); const TargetRegisterClass *SecondInstrDstRC = (Opcode.first == Opcode.second) ? FirstInstrDstRC - : TII->getRegClass(TII->get(Opcode.second), 0, TRI); + : TII->getRegClass(TII->get(Opcode.second), 0); const TargetRegisterClass *SecondInstrOperandRC = (Opcode.first == Opcode.second) ? FirstInstrOperandRC - : TII->getRegClass(TII->get(Opcode.second), 1, TRI); + : TII->getRegClass(TII->get(Opcode.second), 1); // Get old registers destinations and new register destinations Register DstReg = MI.getOperand(0).getReg(); @@ -784,14 +784,14 @@ bool AArch64MIPeepholeOpt::visitUBFMXri(MachineInstr &MI) { } const TargetRegisterClass *DstRC64 = - TII->getRegClass(TII->get(MI.getOpcode()), 0, TRI); + TII->getRegClass(TII->get(MI.getOpcode()), 0); const TargetRegisterClass *DstRC32 = TRI->getSubRegisterClass(DstRC64, AArch64::sub_32); assert(DstRC32 && "Destination register class of UBFMXri doesn't have a " "sub_32 subregister class"); const TargetRegisterClass *SrcRC64 = - TII->getRegClass(TII->get(MI.getOpcode()), 1, TRI); + TII->getRegClass(TII->get(MI.getOpcode()), 1); const TargetRegisterClass *SrcRC32 = TRI->getSubRegisterClass(SrcRC64, AArch64::sub_32); assert(SrcRC32 && "Source register class of UBFMXri doesn't have a sub_32 " diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp index 343fd81..baeab6a 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp @@ -16,6 +16,7 @@ #include "AArch64MachineFunctionInfo.h" #include "AArch64InstrInfo.h" #include "AArch64Subtarget.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" @@ -63,24 +64,21 @@ void AArch64FunctionInfo::initializeBaseYamlFields( setHasStreamingModeChanges(*YamlMFI.HasStreamingModeChanges); } -static std::pair<bool, bool> GetSignReturnAddress(const Function &F) { +static SignReturnAddress GetSignReturnAddress(const Function &F) { if (F.hasFnAttribute("ptrauth-returns")) - return {true, false}; // non-leaf + return SignReturnAddress::NonLeaf; + // The function should be signed in the following situations: + // - sign-return-address=all + // - sign-return-address=non-leaf and the function spills the LR if
(!F.hasFnAttribute("sign-return-address")) - return {false, false}; + return SignReturnAddress::None; StringRef Scope = F.getFnAttribute("sign-return-address").getValueAsString(); - if (Scope == "none") - return {false, false}; - - if (Scope == "all") - return {true, true}; - - assert(Scope == "non-leaf"); - return {true, false}; + return StringSwitch<SignReturnAddress>(Scope) + .Case("none", SignReturnAddress::None) + .Case("non-leaf", SignReturnAddress::NonLeaf) + .Case("all", SignReturnAddress::All); } static bool ShouldSignWithBKey(const Function &F, const AArch64Subtarget &STI) { @@ -116,7 +114,7 @@ AArch64FunctionInfo::AArch64FunctionInfo(const Function &F, // HasRedZone here. if (F.hasFnAttribute(Attribute::NoRedZone)) HasRedZone = false; - std::tie(SignReturnAddress, SignReturnAddressAll) = GetSignReturnAddress(F); + SignCondition = GetSignReturnAddress(F); SignWithBKey = ShouldSignWithBKey(F, *STI); HasELFSignedGOT = hasELFSignedGOTHelper(F, STI); // TODO: skip functions that have no instrumented allocas for optimization @@ -169,23 +167,28 @@ MachineFunctionInfo *AArch64FunctionInfo::clone( return DestMF.cloneInfo<AArch64FunctionInfo>(*this); } -bool AArch64FunctionInfo::shouldSignReturnAddress(bool SpillsLR) const { - if (!SignReturnAddress) - return false; - if (SignReturnAddressAll) - return true; - return SpillsLR; -} - static bool isLRSpilled(const MachineFunction &MF) { return llvm::any_of( MF.getFrameInfo().getCalleeSavedInfo(), [](const auto &Info) { return Info.getReg() == AArch64::LR; }); } +bool AArch64FunctionInfo::shouldSignReturnAddress(SignReturnAddress Condition, + bool IsLRSpilled) { + switch (Condition) { + case SignReturnAddress::None: + return false; + case SignReturnAddress::NonLeaf: + return IsLRSpilled; + case SignReturnAddress::All: + return true; + } + llvm_unreachable("Unknown SignReturnAddress enum"); +} + bool AArch64FunctionInfo::shouldSignReturnAddress( const MachineFunction &MF) const { - return shouldSignReturnAddress(isLRSpilled(MF)); + return shouldSignReturnAddress(SignCondition, isLRSpilled(MF)); } bool AArch64FunctionInfo::needsShadowCallStackPrologueEpilogue( diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index f680a5e..00e0c25 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -13,8 +13,8 @@ #ifndef LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H #define LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -42,6 +42,15 @@ struct TPIDR2Object { unsigned Uses = 0; }; +/// Condition of signing the return address in a function. +/// +/// Corresponds to possible values of "sign-return-address" function attribute. +enum class SignReturnAddress { + None, + NonLeaf, + All, +}; + /// AArch64FunctionInfo - This class is derived from MachineFunctionInfo and /// contains private AArch64-specific information for each MachineFunction. class AArch64FunctionInfo final : public MachineFunctionInfo { @@ -170,13 +179,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { // CalleeSavedStackSize) to the address of the frame record. 
int CalleeSaveBaseToFrameRecordOffset = 0; - /// SignReturnAddress is true if PAC-RET is enabled for the function with - /// defaults being sign non-leaf functions only, with the B key. - bool SignReturnAddress = false; - - /// SignReturnAddressAll modifies the default PAC-RET mode to signing leaf - /// functions as well. - bool SignReturnAddressAll = false; + /// SignCondition controls when PAC-RET protection should be used. + SignReturnAddress SignCondition = SignReturnAddress::None; /// SignWithBKey modifies the default PAC-RET mode to signing with the B key. bool SignWithBKey = false; @@ -591,8 +595,14 @@ public: CalleeSaveBaseToFrameRecordOffset = Offset; } + static bool shouldSignReturnAddress(SignReturnAddress Condition, + bool IsLRSpilled); + bool shouldSignReturnAddress(const MachineFunction &MF) const; - bool shouldSignReturnAddress(bool SpillsLR) const; + + SignReturnAddress getSignReturnAddressCondition() const { + return SignCondition; + } bool needsShadowCallStackPrologueEpilogue(MachineFunction &MF) const; diff --git a/llvm/lib/Target/AArch64/AArch64MacroFusion.h b/llvm/lib/Target/AArch64/AArch64MacroFusion.h index 62da054..7682dbf 100644 --- a/llvm/lib/Target/AArch64/AArch64MacroFusion.h +++ b/llvm/lib/Target/AArch64/AArch64MacroFusion.h @@ -23,6 +23,6 @@ namespace llvm { /// to AArch64TargetMachine::createMachineScheduler() to have an effect. std::unique_ptr<ScheduleDAGMutation> createAArch64MacroFusionDAGMutation(); -} // llvm +} // namespace llvm #endif // LLVM_LIB_TARGET_AARCH64_AARCH64MACROFUSION_H diff --git a/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h b/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h index f7beca1..c7d6b31 100644 --- a/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h +++ b/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h @@ -6622,35 +6622,52 @@ inline unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) { } /// Return true for zip1 or zip2 masks of the form: -/// <0, 8, 1, 9, 2, 10, 3, 11> or -/// <4, 12, 5, 13, 6, 14, 7, 15> +/// <0, 8, 1, 9, 2, 10, 3, 11> (WhichResultOut = 0, OperandOrderOut = 0) or +/// <4, 12, 5, 13, 6, 14, 7, 15> (WhichResultOut = 1, OperandOrderOut = 0) or +/// <8, 0, 9, 1, 10, 2, 11, 3> (WhichResultOut = 0, OperandOrderOut = 1) or +/// <12, 4, 13, 5, 14, 6, 15, 7> (WhichResultOut = 1, OperandOrderOut = 1) inline bool isZIPMask(ArrayRef<int> M, unsigned NumElts, - unsigned &WhichResultOut) { + unsigned &WhichResultOut, unsigned &OperandOrderOut) { if (NumElts % 2 != 0) return false; - // Check the first non-undef element for which half to use. - unsigned WhichResult = 2; - for (unsigned i = 0; i != NumElts / 2; i++) { - if (M[i * 2] >= 0) { - WhichResult = ((unsigned)M[i * 2] == i ? 0 : 1); - break; - } else if (M[i * 2 + 1] >= 0) { - WhichResult = ((unsigned)M[i * 2 + 1] == NumElts + i ? 0 : 1); - break; - } - } - if (WhichResult == 2) - return false; + // "Variant" refers to the distinction between zip1 and zip2, while + // "Order" refers to the sequence of input registers (matching vs flipped). + bool Variant0Order0 = true; // WhichResultOut = 0, OperandOrderOut = 0 + bool Variant1Order0 = true; // WhichResultOut = 1, OperandOrderOut = 0 + bool Variant0Order1 = true; // WhichResultOut = 0, OperandOrderOut = 1 + bool Variant1Order1 = true; // WhichResultOut = 1, OperandOrderOut = 1 // Check all elements match.
- unsigned Idx = WhichResult * NumElts / 2; for (unsigned i = 0; i != NumElts; i += 2) { - if ((M[i] >= 0 && (unsigned)M[i] != Idx) || - (M[i + 1] >= 0 && (unsigned)M[i + 1] != Idx + NumElts)) - return false; - Idx += 1; + if (M[i] >= 0) { + unsigned EvenElt = (unsigned)M[i]; + if (EvenElt != i / 2) + Variant0Order0 = false; + if (EvenElt != NumElts / 2 + i / 2) + Variant1Order0 = false; + if (EvenElt != NumElts + i / 2) + Variant0Order1 = false; + if (EvenElt != NumElts + NumElts / 2 + i / 2) + Variant1Order1 = false; + } + if (M[i + 1] >= 0) { + unsigned OddElt = (unsigned)M[i + 1]; + if (OddElt != NumElts + i / 2) + Variant0Order0 = false; + if (OddElt != NumElts + NumElts / 2 + i / 2) + Variant1Order0 = false; + if (OddElt != i / 2) + Variant0Order1 = false; + if (OddElt != NumElts / 2 + i / 2) + Variant1Order1 = false; + } } - WhichResultOut = WhichResult; + + if (Variant0Order0 + Variant1Order0 + Variant0Order1 + Variant1Order1 != 1) + return false; + + WhichResultOut = (Variant0Order0 || Variant0Order1) ? 0 : 1; + OperandOrderOut = (Variant0Order0 || Variant1Order0) ? 0 : 1; return true; } @@ -6682,18 +6699,53 @@ inline bool isUZPMask(ArrayRef<int> M, unsigned NumElts, } /// Return true for trn1 or trn2 masks of the form: -/// <0, 8, 2, 10, 4, 12, 6, 14> or -/// <1, 9, 3, 11, 5, 13, 7, 15> +/// <0, 8, 2, 10, 4, 12, 6, 14> (WhichResultOut = 0, OperandOrderOut = 0) or +/// <1, 9, 3, 11, 5, 13, 7, 15> (WhichResultOut = 1, OperandOrderOut = 0) or +/// <8, 0, 10, 2, 12, 4, 14, 6> (WhichResultOut = 0, OperandOrderOut = 1) or +/// <9, 1, 11, 3, 13, 5, 15, 7> (WhichResultOut = 1, OperandOrderOut = 1) or inline bool isTRNMask(ArrayRef<int> M, unsigned NumElts, - unsigned &WhichResult) { + unsigned &WhichResultOut, unsigned &OperandOrderOut) { if (NumElts % 2 != 0) return false; - WhichResult = (M[0] == 0 ? 0 : 1); - for (unsigned i = 0; i < NumElts; i += 2) { - if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) || - (M[i + 1] >= 0 && (unsigned)M[i + 1] != i + NumElts + WhichResult)) - return false; + + // "Result" corresponds to "WhichResultOut", selecting between trn1 and trn2. + // "Order" corresponds to "OperandOrderOut", selecting the order of operands + // for the instruction (flipped or not). + bool Result0Order0 = true; // WhichResultOut = 0, OperandOrderOut = 0 + bool Result1Order0 = true; // WhichResultOut = 1, OperandOrderOut = 0 + bool Result0Order1 = true; // WhichResultOut = 0, OperandOrderOut = 1 + bool Result1Order1 = true; // WhichResultOut = 1, OperandOrderOut = 1 + // Check all elements match. + for (unsigned i = 0; i != NumElts; i += 2) { + if (M[i] >= 0) { + unsigned EvenElt = (unsigned)M[i]; + if (EvenElt != i) + Result0Order0 = false; + if (EvenElt != i + 1) + Result1Order0 = false; + if (EvenElt != NumElts + i) + Result0Order1 = false; + if (EvenElt != NumElts + i + 1) + Result1Order1 = false; + } + if (M[i + 1] >= 0) { + unsigned OddElt = (unsigned)M[i + 1]; + if (OddElt != NumElts + i) + Result0Order0 = false; + if (OddElt != NumElts + i + 1) + Result1Order0 = false; + if (OddElt != i) + Result0Order1 = false; + if (OddElt != i + 1) + Result1Order1 = false; + } } + + if (Result0Order0 + Result1Order0 + Result0Order1 + Result1Order1 != 1) + return false; + + WhichResultOut = (Result0Order0 || Result0Order1) ? 0 : 1; + OperandOrderOut = (Result0Order0 || Result1Order0) ? 
0 : 1; return true; } diff --git a/llvm/lib/Target/AArch64/AArch64Processors.td b/llvm/lib/Target/AArch64/AArch64Processors.td index 81f5d07..120415f 100644 --- a/llvm/lib/Target/AArch64/AArch64Processors.td +++ b/llvm/lib/Target/AArch64/AArch64Processors.td @@ -593,6 +593,7 @@ def TuneNeoverseN2 : SubtargetFeature<"neoversen2", "ARMProcFamily", "NeoverseN2 FeatureALULSLFast, FeaturePostRAScheduler, FeatureEnableSelectOptimize, + FeatureDisableMaximizeScalableBandwidth, FeaturePredictableSelectIsExpensive]>; def TuneNeoverseN3 : SubtargetFeature<"neoversen3", "ARMProcFamily", "NeoverseN3", @@ -626,6 +627,7 @@ def TuneNeoverseV1 : SubtargetFeature<"neoversev1", "ARMProcFamily", "NeoverseV1 FeaturePostRAScheduler, FeatureEnableSelectOptimize, FeaturePredictableSelectIsExpensive, + FeatureDisableMaximizeScalableBandwidth, FeatureNoSVEFPLD1R]>; def TuneNeoverseV2 : SubtargetFeature<"neoversev2", "ARMProcFamily", "NeoverseV2", @@ -1272,11 +1274,11 @@ def : ProcessorModel<"cortex-x2", NeoverseV2Model, ProcessorFeatures.X2, [TuneX2]>; def : ProcessorModel<"cortex-x3", NeoverseV2Model, ProcessorFeatures.X3, [TuneX3]>; -def : ProcessorModel<"cortex-x4", NeoverseV2Model, ProcessorFeatures.X4, +def : ProcessorModel<"cortex-x4", NeoverseV3Model, ProcessorFeatures.X4, [TuneX4]>; -def : ProcessorModel<"cortex-x925", NeoverseV2Model, ProcessorFeatures.X925, +def : ProcessorModel<"cortex-x925", NeoverseV3Model, ProcessorFeatures.X925, [TuneX925]>; -def : ProcessorModel<"gb10", NeoverseV2Model, ProcessorFeatures.GB10, +def : ProcessorModel<"gb10", NeoverseV3Model, ProcessorFeatures.GB10, [TuneX925]>; def : ProcessorModel<"grace", NeoverseV2Model, ProcessorFeatures.Grace, [TuneNeoverseV2]>; @@ -1295,9 +1297,9 @@ def : ProcessorModel<"neoverse-v1", NeoverseV1Model, ProcessorFeatures.NeoverseV1, [TuneNeoverseV1]>; def : ProcessorModel<"neoverse-v2", NeoverseV2Model, ProcessorFeatures.NeoverseV2, [TuneNeoverseV2]>; -def : ProcessorModel<"neoverse-v3", NeoverseV2Model, +def : ProcessorModel<"neoverse-v3", NeoverseV3Model, ProcessorFeatures.NeoverseV3, [TuneNeoverseV3]>; -def : ProcessorModel<"neoverse-v3ae", NeoverseV2Model, +def : ProcessorModel<"neoverse-v3ae", NeoverseV3AEModel, ProcessorFeatures.NeoverseV3AE, [TuneNeoverseV3AE]>; def : ProcessorModel<"exynos-m3", ExynosM3Model, ProcessorFeatures.ExynosM3, [TuneExynosM3]>; diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp index 7e03b97..965585f 100644 --- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp +++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp @@ -253,6 +253,8 @@ static void fixupSEHOpcode(MachineBasicBlock::iterator MBBI, case AArch64::SEH_SaveReg: case AArch64::SEH_SaveFRegP: case AArch64::SEH_SaveFReg: + case AArch64::SEH_SaveAnyRegI: + case AArch64::SEH_SaveAnyRegIP: case AArch64::SEH_SaveAnyRegQP: case AArch64::SEH_SaveAnyRegQPX: ImmOpnd = &MBBI->getOperand(ImmIdx); @@ -370,6 +372,22 @@ SVEFrameSizes AArch64PrologueEpilogueCommon::getSVEStackFrameSizes() const { {ZPRCalleeSavesSize, PPRLocalsSize + ZPRLocalsSize}}; } +SVEStackAllocations AArch64PrologueEpilogueCommon::getSVEStackAllocations( + SVEFrameSizes const &SVE) { + StackOffset AfterZPRs = SVE.ZPR.LocalsSize; + StackOffset BeforePPRs = SVE.ZPR.CalleeSavesSize + SVE.PPR.CalleeSavesSize; + StackOffset AfterPPRs = {}; + if (SVELayout == SVEStackLayout::Split) { + BeforePPRs = SVE.PPR.CalleeSavesSize; + // If there are no ZPR CSRs, place all local allocations after the ZPRs. 
+ if (SVE.ZPR.CalleeSavesSize) + AfterPPRs += SVE.PPR.LocalsSize + SVE.ZPR.CalleeSavesSize; + else + AfterZPRs += SVE.PPR.LocalsSize; // Group allocation of locals. + } + return {BeforePPRs, AfterPPRs, AfterZPRs}; +} + struct SVEPartitions { struct { MachineBasicBlock::iterator Begin, End; @@ -687,16 +705,19 @@ void AArch64PrologueEmitter::emitPrologue() { // All of the remaining stack allocations are for locals. determineLocalsStackSize(NumBytes, PrologueSaveSize); + auto [PPR, ZPR] = getSVEStackFrameSizes(); + SVEStackAllocations SVEAllocs = getSVEStackAllocations({PPR, ZPR}); + MachineBasicBlock::iterator FirstGPRSaveI = PrologueBeginI; if (SVELayout == SVEStackLayout::CalleeSavesAboveFrameRecord) { + assert(!SVEAllocs.AfterPPRs && + "unexpected SVE allocs after PPRs with CalleeSavesAboveFrameRecord"); // If we're doing SVE saves first, we need to immediately allocate space // for fixed objects, then space for the SVE callee saves. // // Windows unwind requires that the scalable size is a multiple of 16; // that's handled when the callee-saved size is computed. - auto SaveSize = - StackOffset::getScalable(AFI->getSVECalleeSavedStackSize()) + - StackOffset::getFixed(FixedObject); + auto SaveSize = SVEAllocs.BeforePPRs + StackOffset::getFixed(FixedObject); allocateStackSpace(PrologueBeginI, 0, SaveSize, false, StackOffset{}, /*FollowupAllocs=*/true); NumBytes -= FixedObject; @@ -764,12 +785,11 @@ void AArch64PrologueEmitter::emitPrologue() { if (AFL.windowsRequiresStackProbe(MF, NumBytes + RealignmentPadding)) emitWindowsStackProbe(AfterGPRSavesI, DL, NumBytes, RealignmentPadding); - auto [PPR, ZPR] = getSVEStackFrameSizes(); - StackOffset SVECalleeSavesSize = ZPR.CalleeSavesSize + PPR.CalleeSavesSize; StackOffset NonSVELocalsSize = StackOffset::getFixed(NumBytes); + SVEAllocs.AfterZPRs += NonSVELocalsSize; + StackOffset CFAOffset = StackOffset::getFixed(MFI.getStackSize()) - NonSVELocalsSize; - MachineBasicBlock::iterator AfterSVESavesI = AfterGPRSavesI; // Allocate space for the callee saves and PPR locals (if any). 
if (SVELayout != SVEStackLayout::CalleeSavesAboveFrameRecord) { @@ -780,31 +800,23 @@ void AArch64PrologueEmitter::emitPrologue() { if (EmitAsyncCFI) emitCalleeSavedSVELocations(AfterSVESavesI); - StackOffset AllocateBeforePPRs = SVECalleeSavesSize; - StackOffset AllocateAfterPPRs = PPR.LocalsSize; - if (SVELayout == SVEStackLayout::Split) { - AllocateBeforePPRs = PPR.CalleeSavesSize; - AllocateAfterPPRs = PPR.LocalsSize + ZPR.CalleeSavesSize; - } - allocateStackSpace(PPRRange.Begin, 0, AllocateBeforePPRs, + allocateStackSpace(PPRRange.Begin, 0, SVEAllocs.BeforePPRs, EmitAsyncCFI && !HasFP, CFAOffset, - MFI.hasVarSizedObjects() || AllocateAfterPPRs || - ZPR.LocalsSize || NonSVELocalsSize); - CFAOffset += AllocateBeforePPRs; + MFI.hasVarSizedObjects() || SVEAllocs.AfterPPRs || + SVEAllocs.AfterZPRs); + CFAOffset += SVEAllocs.BeforePPRs; assert(PPRRange.End == ZPRRange.Begin && "Expected ZPR callee saves after PPR locals"); - allocateStackSpace(PPRRange.End, RealignmentPadding, AllocateAfterPPRs, + allocateStackSpace(PPRRange.End, 0, SVEAllocs.AfterPPRs, EmitAsyncCFI && !HasFP, CFAOffset, - MFI.hasVarSizedObjects() || ZPR.LocalsSize || - NonSVELocalsSize); - CFAOffset += AllocateAfterPPRs; + MFI.hasVarSizedObjects() || SVEAllocs.AfterZPRs); + CFAOffset += SVEAllocs.AfterPPRs; } else { assert(SVELayout == SVEStackLayout::CalleeSavesAboveFrameRecord); - // Note: With CalleeSavesAboveFrameRecord, the SVE CS have already been - // allocated (and separate PPR locals are not supported, all SVE locals, - // both PPR and ZPR, are within the ZPR locals area). - assert(!PPR.LocalsSize && "Unexpected PPR locals!"); - CFAOffset += SVECalleeSavesSize; + // Note: With CalleeSavesAboveFrameRecord, the SVE CS (BeforePPRs) have + // already been allocated. PPR locals (included in AfterPPRs) are not + // supported (note: this is asserted above). + CFAOffset += SVEAllocs.BeforePPRs; } // Allocate space for the rest of the frame including ZPR locals. Align the @@ -815,9 +827,9 @@ void AArch64PrologueEmitter::emitPrologue() { // FIXME: in the case of dynamic re-alignment, NumBytes doesn't have the // correct value here, as NumBytes also includes padding bytes, which // shouldn't be counted here. - allocateStackSpace( - AfterSVESavesI, RealignmentPadding, ZPR.LocalsSize + NonSVELocalsSize, - EmitAsyncCFI && !HasFP, CFAOffset, MFI.hasVarSizedObjects()); + allocateStackSpace(AfterSVESavesI, RealignmentPadding, SVEAllocs.AfterZPRs, + EmitAsyncCFI && !HasFP, CFAOffset, + MFI.hasVarSizedObjects()); } // If we need a base pointer, set it up here. It's whatever the value of the @@ -1308,6 +1320,26 @@ AArch64EpilogueEmitter::AArch64EpilogueEmitter(MachineFunction &MF, SEHEpilogueStartI = MBB.end(); } +void AArch64EpilogueEmitter::moveSPBelowFP(MachineBasicBlock::iterator MBBI, + StackOffset Offset) { + // Other combinations could be supported, but are not currently needed. + assert(Offset.getScalable() < 0 && Offset.getFixed() <= 0 && + "expected negative offset (with optional fixed portion)"); + Register Base = AArch64::FP; + if (int64_t FixedOffset = Offset.getFixed()) { + // If we have a negative fixed offset, we need to subtract it in a + // temporary register first (to avoid briefly deallocating the scalable + // portion of the offset).
+ Base = MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); + emitFrameOffset(MBB, MBBI, DL, Base, AArch64::FP, + StackOffset::getFixed(FixedOffset), TII, + MachineInstr::FrameDestroy); + } + emitFrameOffset(MBB, MBBI, DL, AArch64::SP, Base, + StackOffset::getScalable(Offset.getScalable()), TII, + MachineInstr::FrameDestroy); +} + void AArch64EpilogueEmitter::emitEpilogue() { MachineBasicBlock::iterator EpilogueEndI = MBB.getLastNonDebugInstr(); if (MBB.end() != EpilogueEndI) { @@ -1408,6 +1440,7 @@ void AArch64EpilogueEmitter::emitEpilogue() { AfterCSRPopSize += ProloguePopSize; } } + // Move past the restores of the callee-saved registers. // If we plan on combining the sp bump of the local stack size and the callee // save stack size, we might need to adjust the CSR save and restore offsets. @@ -1472,27 +1505,25 @@ void AArch64EpilogueEmitter::emitEpilogue() { assert(NumBytes >= 0 && "Negative stack allocation size!?"); StackOffset SVECalleeSavesSize = ZPR.CalleeSavesSize + PPR.CalleeSavesSize; - StackOffset SVEStackSize = - SVECalleeSavesSize + PPR.LocalsSize + ZPR.LocalsSize; - MachineBasicBlock::iterator RestoreBegin = ZPRRange.Begin; - MachineBasicBlock::iterator RestoreEnd = PPRRange.End; + SVEStackAllocations SVEAllocs = getSVEStackAllocations({PPR, ZPR}); // Deallocate the SVE area. if (SVELayout == SVEStackLayout::CalleeSavesAboveFrameRecord) { - StackOffset SVELocalsSize = ZPR.LocalsSize + PPR.LocalsSize; + assert(!SVEAllocs.AfterPPRs && + "unexpected SVE allocs after PPRs with CalleeSavesAboveFrameRecord"); // If the callee-save area is before FP, restoring the FP implicitly - // deallocates non-callee-save SVE allocations. Otherwise, deallocate them + // deallocates non-callee-save SVE allocations. Otherwise, deallocate them // explicitly. if (!AFI->isStackRealigned() && !MFI.hasVarSizedObjects()) { emitFrameOffset(MBB, FirstGPRRestoreI, DL, AArch64::SP, AArch64::SP, - SVELocalsSize, TII, MachineInstr::FrameDestroy, false, - NeedsWinCFI, &HasWinCFI); + SVEAllocs.AfterZPRs, TII, MachineInstr::FrameDestroy, + false, NeedsWinCFI, &HasWinCFI); } // Deallocate callee-save SVE registers. - emitFrameOffset(MBB, RestoreEnd, DL, AArch64::SP, AArch64::SP, - SVECalleeSavesSize, TII, MachineInstr::FrameDestroy, false, - NeedsWinCFI, &HasWinCFI); + emitFrameOffset(MBB, PPRRange.End, DL, AArch64::SP, AArch64::SP, + SVEAllocs.BeforePPRs, TII, MachineInstr::FrameDestroy, + false, NeedsWinCFI, &HasWinCFI); } else if (AFI->hasSVEStackSize()) { // If we have stack realignment or variable-sized objects we must use the FP // to restore SVE callee saves (as there is an unknown amount of @@ -1501,69 +1532,53 @@ void AArch64EpilogueEmitter::emitEpilogue() { (AFI->isStackRealigned() || MFI.hasVarSizedObjects()) ? AArch64::FP : AArch64::SP; if (SVECalleeSavesSize && BaseForSVEDealloc == AArch64::FP) { - // TODO: Support stack realigment and variable-sized objects. - assert( - SVELayout != SVEStackLayout::Split && - "unexpected stack realignment or variable sized objects with split " - "SVE stack objects"); - - Register CalleeSaveBase = AArch64::FP; - if (int64_t CalleeSaveBaseOffset = - AFI->getCalleeSaveBaseToFrameRecordOffset()) { - // If we have have an non-zero offset to the non-SVE CS base we need to - // compute the base address by subtracting the offest in a temporary - // register first (to avoid briefly deallocating the SVE CS). 
- CalleeSaveBase = MBB.getParent()->getRegInfo().createVirtualRegister( - &AArch64::GPR64RegClass); - emitFrameOffset(MBB, RestoreBegin, DL, CalleeSaveBase, AArch64::FP, - StackOffset::getFixed(-CalleeSaveBaseOffset), TII, - MachineInstr::FrameDestroy); + if (ZPR.CalleeSavesSize || SVELayout != SVEStackLayout::Split) { + // The offset from the frame-pointer to the start of the ZPR saves. + StackOffset FPOffsetZPR = + -SVECalleeSavesSize - PPR.LocalsSize - + StackOffset::getFixed(AFI->getCalleeSaveBaseToFrameRecordOffset()); + // Deallocate the stack space by moving the SP to the start of the + // ZPR/PPR callee-save area. + moveSPBelowFP(ZPRRange.Begin, FPOffsetZPR); } - // The code below will deallocate the stack space space by moving the SP - // to the start of the SVE callee-save area. - emitFrameOffset(MBB, RestoreBegin, DL, AArch64::SP, CalleeSaveBase, - -SVECalleeSavesSize, TII, MachineInstr::FrameDestroy); - } else if (BaseForSVEDealloc == AArch64::SP) { - auto CFAOffset = - SVEStackSize + StackOffset::getFixed(NumBytes + PrologueSaveSize); - - if (SVECalleeSavesSize) { - // Deallocate the non-SVE locals first before we can deallocate (and - // restore callee saves) from the SVE area. - auto NonSVELocals = StackOffset::getFixed(NumBytes); - emitFrameOffset(MBB, ZPRRange.Begin, DL, AArch64::SP, AArch64::SP, - NonSVELocals, TII, MachineInstr::FrameDestroy, false, - NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, CFAOffset); - CFAOffset -= NonSVELocals; - NumBytes = 0; + // With split SVE, the predicates are stored in a separate area above the + // ZPR saves, so we must adjust the stack to the start of the PPRs. + if (PPR.CalleeSavesSize && SVELayout == SVEStackLayout::Split) { + // The offset from the frame-pointer to the start of the PPR saves. + StackOffset FPOffsetPPR = -PPR.CalleeSavesSize; + // Move to the start of the PPR area. + assert(!FPOffsetPPR.getFixed() && "expected only scalable offset"); + emitFrameOffset(MBB, ZPRRange.End, DL, AArch64::SP, AArch64::FP, + FPOffsetPPR, TII, MachineInstr::FrameDestroy); } - - if (ZPR.LocalsSize) { - emitFrameOffset(MBB, ZPRRange.Begin, DL, AArch64::SP, AArch64::SP, - ZPR.LocalsSize, TII, MachineInstr::FrameDestroy, false, - NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, CFAOffset); - CFAOffset -= ZPR.LocalsSize; - } - - StackOffset SVECalleeSavesToDealloc = SVECalleeSavesSize; - if (SVELayout == SVEStackLayout::Split && - (PPR.LocalsSize || ZPR.CalleeSavesSize)) { - assert(PPRRange.Begin == ZPRRange.End && - "Expected PPR restores after ZPR"); - emitFrameOffset(MBB, PPRRange.Begin, DL, AArch64::SP, AArch64::SP, - PPR.LocalsSize + ZPR.CalleeSavesSize, TII, - MachineInstr::FrameDestroy, false, NeedsWinCFI, - &HasWinCFI, EmitCFI && !HasFP, CFAOffset); - CFAOffset -= PPR.LocalsSize + ZPR.CalleeSavesSize; - SVECalleeSavesToDealloc -= ZPR.CalleeSavesSize; + } else if (BaseForSVEDealloc == AArch64::SP) { + auto NonSVELocals = StackOffset::getFixed(NumBytes); + auto CFAOffset = NonSVELocals + StackOffset::getFixed(PrologueSaveSize) + + SVEAllocs.totalSize(); + + if (SVECalleeSavesSize || SVELayout == SVEStackLayout::Split) { + // Deallocate non-SVE locals now. This is needed to reach the SVE callee + // saves, but may also allow combining stack hazard bumps for split SVE.
+ SVEAllocs.AfterZPRs += NonSVELocals; + NumBytes -= NonSVELocals.getFixed(); } - - // If split SVE is on, this dealloc PPRs, otherwise, deallocs ZPRs + PPRs: - if (SVECalleeSavesToDealloc) - emitFrameOffset(MBB, PPRRange.End, DL, AArch64::SP, AArch64::SP, - SVECalleeSavesToDealloc, TII, - MachineInstr::FrameDestroy, false, NeedsWinCFI, - &HasWinCFI, EmitCFI && !HasFP, CFAOffset); + // To deallocate the SVE stack adjust by the allocations in reverse. + emitFrameOffset(MBB, ZPRRange.Begin, DL, AArch64::SP, AArch64::SP, + SVEAllocs.AfterZPRs, TII, MachineInstr::FrameDestroy, + false, NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, + CFAOffset); + CFAOffset -= SVEAllocs.AfterZPRs; + assert(PPRRange.Begin == ZPRRange.End && + "Expected PPR restores after ZPR"); + emitFrameOffset(MBB, PPRRange.Begin, DL, AArch64::SP, AArch64::SP, + SVEAllocs.AfterPPRs, TII, MachineInstr::FrameDestroy, + false, NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, + CFAOffset); + CFAOffset -= SVEAllocs.AfterPPRs; + emitFrameOffset(MBB, PPRRange.End, DL, AArch64::SP, AArch64::SP, + SVEAllocs.BeforePPRs, TII, MachineInstr::FrameDestroy, + false, NeedsWinCFI, &HasWinCFI, EmitCFI && !HasFP, + CFAOffset); } if (EmitCFI) diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h index bccadda..7f297b5 100644 --- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h +++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.h @@ -33,6 +33,11 @@ struct SVEFrameSizes { } PPR, ZPR; }; +struct SVEStackAllocations { + StackOffset BeforePPRs, AfterPPRs, AfterZPRs; + StackOffset totalSize() const { return BeforePPRs + AfterPPRs + AfterZPRs; } +}; + class AArch64PrologueEpilogueCommon { public: AArch64PrologueEpilogueCommon(MachineFunction &MF, MachineBasicBlock &MBB, @@ -66,6 +71,7 @@ protected: bool shouldCombineCSRLocalStackBump(uint64_t StackBumpBytes) const; SVEFrameSizes getSVEStackFrameSizes() const; + SVEStackAllocations getSVEStackAllocations(SVEFrameSizes const &); MachineFunction &MF; MachineBasicBlock &MBB; @@ -174,6 +180,10 @@ public: private: bool shouldCombineCSRLocalStackBump(uint64_t StackBumpBytes) const; + /// A helper for moving the SP to a negative offset from the FP, without + /// deallocating any stack in the range FP to FP + Offset. + void moveSPBelowFP(MachineBasicBlock::iterator MBBI, StackOffset Offset); + void emitSwiftAsyncContextFramePointer(MachineBasicBlock::iterator MBBI, const DebugLoc &DL) const; diff --git a/llvm/lib/Target/AArch64/AArch64RedundantCondBranchPass.cpp b/llvm/lib/Target/AArch64/AArch64RedundantCondBranchPass.cpp new file mode 100644 index 0000000..1a5a9f0 --- /dev/null +++ b/llvm/lib/Target/AArch64/AArch64RedundantCondBranchPass.cpp @@ -0,0 +1,63 @@ +//=- AArch64RedundantCondBranch.cpp - Remove redundant conditional branches -=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Late in the pipeline, especially with zero phi operands propagated after tail +// duplications, we can end up with CBZ/CBNZ/TBZ/TBNZ with a zero register. This +// simple pass looks at the terminators to a block, removing the redundant +// instructions where necessary. 
+// +//===----------------------------------------------------------------------===// + +#include "AArch64.h" +#include "AArch64InstrInfo.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; + +#define DEBUG_TYPE "aarch64-redundantcondbranch" + +namespace { +class AArch64RedundantCondBranch : public MachineFunctionPass { +public: + static char ID; + AArch64RedundantCondBranch() : MachineFunctionPass(ID) {} + + bool runOnMachineFunction(MachineFunction &MF) override; + + MachineFunctionProperties getRequiredProperties() const override { + return MachineFunctionProperties().setNoVRegs(); + } + StringRef getPassName() const override { + return "AArch64 Redundant Conditional Branch Elimination"; + } +}; +char AArch64RedundantCondBranch::ID = 0; +} // namespace + +INITIALIZE_PASS(AArch64RedundantCondBranch, "aarch64-redundantcondbranch", + "AArch64 Redundant Conditional Branch Elimination pass", false, + false) + +bool AArch64RedundantCondBranch::runOnMachineFunction(MachineFunction &MF) { + if (skipFunction(MF.getFunction())) + return false; + + const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); + + bool Changed = false; + for (MachineBasicBlock &MBB : MF) + Changed |= optimizeTerminators(&MBB, TII); + return Changed; +} + +FunctionPass *llvm::createAArch64RedundantCondBranchPass() { + return new AArch64RedundantCondBranch(); +} diff --git a/llvm/lib/Target/AArch64/AArch64RedundantCopyElimination.cpp b/llvm/lib/Target/AArch64/AArch64RedundantCopyElimination.cpp index 84015e5..9dc721e 100644 --- a/llvm/lib/Target/AArch64/AArch64RedundantCopyElimination.cpp +++ b/llvm/lib/Target/AArch64/AArch64RedundantCopyElimination.cpp @@ -50,6 +50,7 @@ // to use WZR/XZR directly in some cases. //===----------------------------------------------------------------------===// #include "AArch64.h" +#include "AArch64InstrInfo.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" @@ -475,6 +476,7 @@ bool AArch64RedundantCopyElimination::runOnMachineFunction( return false; TRI = MF.getSubtarget().getRegisterInfo(); MRI = &MF.getRegInfo(); + const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); // Resize the clobbered and used register unit trackers. We do this once per // function. 
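Both the new AArch64RedundantCondBranch pass and AArch64RedundantCopyElimination now run each machine basic block through the same optimizeTerminators() helper (both files newly include AArch64InstrInfo.h, which is presumably where the helper is declared; its body is not shown in this diff). The following is only a rough sketch of the kind of rewrite such a helper performs on CBZ/CBNZ/TBZ/TBNZ whose register operand is WZR/XZR; the real helper's name, signature, opcode handling and CFG bookkeeping may differ.

// Illustrative sketch only; not code from this patch.
static bool eraseRedundantZeroRegBranches(MachineBasicBlock &MBB,
                                          const TargetInstrInfo &TII) {
  bool Changed = false;
  for (MachineInstr &MI : llvm::make_early_inc_range(MBB.terminators())) {
    unsigned Opc = MI.getOpcode();
    bool AlwaysTaken = Opc == AArch64::CBZW || Opc == AArch64::CBZX ||
                       Opc == AArch64::TBZW || Opc == AArch64::TBZX;
    bool NeverTaken = Opc == AArch64::CBNZW || Opc == AArch64::CBNZX ||
                      Opc == AArch64::TBNZW || Opc == AArch64::TBNZX;
    if (!AlwaysTaken && !NeverTaken)
      continue;
    Register Reg = MI.getOperand(0).getReg();
    if (Reg != AArch64::WZR && Reg != AArch64::XZR)
      continue;
    if (AlwaysTaken) {
      // A cbz/tbz of the zero register always branches: fold it into an
      // unconditional branch to the same destination.
      MachineBasicBlock *Dest = TII.getBranchDestBlock(MI);
      BuildMI(MBB, MI, MI.getDebugLoc(), TII.get(AArch64::B)).addMBB(Dest);
    }
    // A cbnz/tbnz of the zero register never branches: simply drop it. A
    // complete implementation would also update MBB's successor list.
    MI.eraseFromParent();
    Changed = true;
  }
  return Changed;
}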
@@ -484,8 +486,10 @@ bool AArch64RedundantCopyElimination::runOnMachineFunction( OptBBUsedRegs.init(*TRI); bool Changed = false; - for (MachineBasicBlock &MBB : MF) + for (MachineBasicBlock &MBB : MF) { + Changed |= optimizeTerminators(&MBB, TII); Changed |= optimizeBlock(&MBB); + } return Changed; } diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp index 79975b0..ab1df70 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -15,10 +15,10 @@ #include "AArch64FrameLowering.h" #include "AArch64InstrInfo.h" #include "AArch64MachineFunctionInfo.h" +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "MCTargetDesc/AArch64InstPrinter.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/BitVector.h" #include "llvm/BinaryFormat/Dwarf.h" #include "llvm/CodeGen/LiveRegMatrix.h" @@ -71,142 +71,126 @@ bool AArch64RegisterInfo::regNeedsCFI(MCRegister Reg, const MCPhysReg * AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { assert(MF && "Invalid MachineFunction pointer."); + auto &AFI = *MF->getInfo<AArch64FunctionInfo>(); + const auto &F = MF->getFunction(); + const auto *TLI = MF->getSubtarget<AArch64Subtarget>().getTargetLowering(); + const bool Darwin = MF->getSubtarget<AArch64Subtarget>().isTargetDarwin(); + const bool Windows = MF->getSubtarget<AArch64Subtarget>().isTargetWindows(); + + if (TLI->supportSwiftError() && + F.getAttributes().hasAttrSomewhere(Attribute::SwiftError)) { + if (Darwin) + return CSR_Darwin_AArch64_AAPCS_SwiftError_SaveList; + if (Windows) + return CSR_Win_AArch64_AAPCS_SwiftError_SaveList; + return CSR_AArch64_AAPCS_SwiftError_SaveList; + } - if (MF->getFunction().getCallingConv() == CallingConv::GHC) + switch (F.getCallingConv()) { + case CallingConv::GHC: // GHC set of callee saved regs is empty as all those regs are // used for passing STG regs around return CSR_AArch64_NoRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::PreserveNone) + + case CallingConv::PreserveNone: + // FIXME: Windows likely needs this to be altered for proper unwinding. return CSR_AArch64_NoneRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::AnyReg) + + case CallingConv::AnyReg: return CSR_AArch64_AllRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::ARM64EC_Thunk_X64) + case CallingConv::ARM64EC_Thunk_X64: return CSR_Win_AArch64_Arm64EC_Thunk_SaveList; - // Darwin has its own CSR_AArch64_AAPCS_SaveList, which means most CSR save - // lists depending on that will need to have their Darwin variant as well.
- if (MF->getSubtarget<AArch64Subtarget>().isTargetDarwin()) - return getDarwinCalleeSavedRegs(MF); + case CallingConv::PreserveMost: + if (Darwin) + return CSR_Darwin_AArch64_RT_MostRegs_SaveList; + if (Windows) + return CSR_Win_AArch64_RT_MostRegs_SaveList; + return CSR_AArch64_RT_MostRegs_SaveList; + + case CallingConv::PreserveAll: + if (Darwin) + return CSR_Darwin_AArch64_RT_AllRegs_SaveList; + if (Windows) + return CSR_Win_AArch64_RT_AllRegs_SaveList; + return CSR_AArch64_RT_AllRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::CFGuard_Check) + case CallingConv::CFGuard_Check: + if (Darwin) + report_fatal_error( + "Calling convention CFGuard_Check is unsupported on Darwin."); return CSR_Win_AArch64_CFGuard_Check_SaveList; - if (MF->getSubtarget<AArch64Subtarget>().isTargetWindows()) { - if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering() - ->supportSwiftError() && - MF->getFunction().getAttributes().hasAttrSomewhere( - Attribute::SwiftError)) - return CSR_Win_AArch64_AAPCS_SwiftError_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::SwiftTail) + + case CallingConv::SwiftTail: + if (Darwin) + return CSR_Darwin_AArch64_AAPCS_SwiftTail_SaveList; + if (Windows) return CSR_Win_AArch64_AAPCS_SwiftTail_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::AArch64_VectorCall) + return CSR_AArch64_AAPCS_SwiftTail_SaveList; + + case CallingConv::AArch64_VectorCall: + if (Darwin) + return CSR_Darwin_AArch64_AAVPCS_SaveList; + if (Windows) return CSR_Win_AArch64_AAVPCS_SaveList; - if (AFI.hasSVE_AAPCS(*MF)) - return CSR_Win_AArch64_SVE_AAPCS_SaveList; - return CSR_Win_AArch64_AAPCS_SaveList; - } - if (MF->getFunction().getCallingConv() == CallingConv::AArch64_VectorCall) return CSR_AArch64_AAVPCS_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall) + + case CallingConv::AArch64_SVE_VectorCall: + if (Darwin) + report_fatal_error( + "Calling convention SVE_VectorCall is unsupported on Darwin."); + if (Windows) + return CSR_Win_AArch64_SVE_AAPCS_SaveList; return CSR_AArch64_SVE_AAPCS_SaveList; - if (MF->getFunction().getCallingConv() == - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) + + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: report_fatal_error( "Calling convention " "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is only " "supported to improve calls to SME ACLE save/restore/disable-za " "functions, and is not intended to be used beyond that scope."); - if (MF->getFunction().getCallingConv() == - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1) + + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1: report_fatal_error( "Calling convention " "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 is " "only supported to improve calls to SME ACLE __arm_get_current_vg " "function, and is not intended to be used beyond that scope."); - if (MF->getFunction().getCallingConv() == - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) + + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: report_fatal_error( "Calling convention " "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " "only supported to improve calls to SME ACLE __arm_sme_state " "and is not intended to be used beyond that scope."); - if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering() - ->supportSwiftError() && - MF->getFunction().getAttributes().hasAttrSomewhere( - Attribute::SwiftError)) - 
return CSR_AArch64_AAPCS_SwiftError_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::SwiftTail) - return CSR_AArch64_AAPCS_SwiftTail_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::PreserveMost) - return CSR_AArch64_RT_MostRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::PreserveAll) - return CSR_AArch64_RT_AllRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::Win64) - // This is for OSes other than Windows; Windows is a separate case further - // above. + + case CallingConv::Win64: + if (Darwin) + return CSR_Darwin_AArch64_AAPCS_Win64_SaveList; + if (Windows) + return CSR_Win_AArch64_AAPCS_SaveList; return CSR_AArch64_AAPCS_X18_SaveList; - if (AFI.hasSVE_AAPCS(*MF)) - return CSR_AArch64_SVE_AAPCS_SaveList; - return CSR_AArch64_AAPCS_SaveList; -} -const MCPhysReg * -AArch64RegisterInfo::getDarwinCalleeSavedRegs(const MachineFunction *MF) const { - assert(MF && "Invalid MachineFunction pointer."); - assert(MF->getSubtarget<AArch64Subtarget>().isTargetDarwin() && - "Invalid subtarget for getDarwinCalleeSavedRegs"); - auto &AFI = *MF->getInfo<AArch64FunctionInfo>(); + case CallingConv::CXX_FAST_TLS: + if (Darwin) + return AFI.isSplitCSR() ? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList + : CSR_Darwin_AArch64_CXX_TLS_SaveList; + // FIXME: this likely should be a `report_fatal_error` condition, however, + // that would be a departure from the previously implemented behaviour. + LLVM_FALLTHROUGH; - if (MF->getFunction().getCallingConv() == CallingConv::CFGuard_Check) - report_fatal_error( - "Calling convention CFGuard_Check is unsupported on Darwin."); - if (MF->getFunction().getCallingConv() == CallingConv::AArch64_VectorCall) - return CSR_Darwin_AArch64_AAVPCS_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall) - report_fatal_error( - "Calling convention SVE_VectorCall is unsupported on Darwin."); - if (MF->getFunction().getCallingConv() == - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) - report_fatal_error( - "Calling convention " - "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " - "only supported to improve calls to SME ACLE save/restore/disable-za " - "functions, and is not intended to be used beyond that scope."); - if (MF->getFunction().getCallingConv() == - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1) - report_fatal_error( - "Calling convention " - "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 is " - "only supported to improve calls to SME ACLE __arm_get_current_vg " - "function, and is not intended to be used beyond that scope."); - if (MF->getFunction().getCallingConv() == - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) - report_fatal_error( - "Calling convention " - "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " - "only supported to improve calls to SME ACLE __arm_sme_state " - "and is not intended to be used beyond that scope."); - if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS) - return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR() - ? 
CSR_Darwin_AArch64_CXX_TLS_PE_SaveList - : CSR_Darwin_AArch64_CXX_TLS_SaveList; - if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering() - ->supportSwiftError() && - MF->getFunction().getAttributes().hasAttrSomewhere( - Attribute::SwiftError)) - return CSR_Darwin_AArch64_AAPCS_SwiftError_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::SwiftTail) - return CSR_Darwin_AArch64_AAPCS_SwiftTail_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::PreserveMost) - return CSR_Darwin_AArch64_RT_MostRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::PreserveAll) - return CSR_Darwin_AArch64_RT_AllRegs_SaveList; - if (MF->getFunction().getCallingConv() == CallingConv::Win64) - return CSR_Darwin_AArch64_AAPCS_Win64_SaveList; - if (AFI.hasSVE_AAPCS(*MF)) - return CSR_Darwin_AArch64_SVE_AAPCS_SaveList; - return CSR_Darwin_AArch64_AAPCS_SaveList; + default: + if (Darwin) + return AFI.hasSVE_AAPCS(*MF) ? CSR_Darwin_AArch64_SVE_AAPCS_SaveList + : CSR_Darwin_AArch64_AAPCS_SaveList; + if (Windows) + return AFI.hasSVE_AAPCS(*MF) ? CSR_Win_AArch64_SVE_AAPCS_SaveList + : CSR_Win_AArch64_AAPCS_SaveList; + return AFI.hasSVE_AAPCS(*MF) ? CSR_AArch64_SVE_AAPCS_SaveList + : CSR_AArch64_AAPCS_SaveList; + } } const MCPhysReg *AArch64RegisterInfo::getCalleeSavedRegsViaCopy( @@ -620,7 +604,7 @@ AArch64RegisterInfo::getCrossCopyRegClass(const TargetRegisterClass *RC) const { return RC; } -unsigned AArch64RegisterInfo::getBaseRegister() const { return AArch64::X19; } +MCRegister AArch64RegisterInfo::getBaseRegister() const { return AArch64::X19; } bool AArch64RegisterInfo::hasBasePointer(const MachineFunction &MF) const { const MachineFrameInfo &MFI = MF.getFrameInfo(); @@ -891,7 +875,7 @@ AArch64RegisterInfo::materializeFrameBaseRegister(MachineBasicBlock *MBB, const MCInstrDesc &MCID = TII->get(AArch64::ADDXri); MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo(); Register BaseReg = MRI.createVirtualRegister(&AArch64::GPR64spRegClass); - MRI.constrainRegClass(BaseReg, TII->getRegClass(MCID, 0, this)); + MRI.constrainRegClass(BaseReg, TII->getRegClass(MCID, 0)); unsigned Shifter = AArch64_AM::getShifterImm(AArch64_AM::LSL, 0); BuildMI(*MBB, Ins, DL, MCID, BaseReg) @@ -1117,24 +1101,89 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC, } } -// FORM_TRANSPOSED_REG_TUPLE nodes are created to improve register allocation -// where a consecutive multi-vector tuple is constructed from the same indices -// of multiple strided loads. This may still result in unnecessary copies -// between the loads and the tuple. Here we try to return a hint to assign the -// contiguous ZPRMulReg starting at the same register as the first operand of -// the pseudo, which should be a subregister of the first strided load. +// We add regalloc hints for different cases: +// * Choosing a better destination operand for predicated SVE instructions +// where the inactive lanes are undef, by choosing a register that is not +// unique to the other operands of the instruction. +// +// * Improve register allocation for SME multi-vector instructions where we can +// benefit from the strided- and contiguous register multi-vector tuples. // -// For example, if the first strided load has been assigned $z16_z20_z24_z28 -// and the operands of the pseudo are each accessing subregister zsub2, we -// should look through through Order to find a contiguous register which -// begins with $z24 (i.e. $z24_z25_z26_z27). 
+ // Here FORM_TRANSPOSED_REG_TUPLE nodes are created to improve register + // allocation where a consecutive multi-vector tuple is constructed from the + // same indices of multiple strided loads. This may still result in + // unnecessary copies between the loads and the tuple. Here we try to return a + // hint to assign the contiguous ZPRMulReg starting at the same register as + // the first operand of the pseudo, which should be a subregister of the first + // strided load. // + // For example, if the first strided load has been assigned $z16_z20_z24_z28 + // and the operands of the pseudo are each accessing subregister zsub2, we + // should look through Order to find a contiguous register which + // begins with $z24 (i.e. $z24_z25_z26_z27). bool AArch64RegisterInfo::getRegAllocationHints( Register VirtReg, ArrayRef<MCPhysReg> Order, SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF, const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const { - auto &ST = MF.getSubtarget<AArch64Subtarget>(); + const AArch64InstrInfo *TII = + MF.getSubtarget<AArch64Subtarget>().getInstrInfo(); + const MachineRegisterInfo &MRI = MF.getRegInfo(); + + // For predicated SVE instructions where the inactive lanes are undef, + // pick a destination register that is not unique to avoid introducing + // a movprfx. + const TargetRegisterClass *RegRC = MRI.getRegClass(VirtReg); + if (AArch64::ZPRRegClass.hasSubClassEq(RegRC)) { + bool ConsiderOnlyHints = TargetRegisterInfo::getRegAllocationHints( + VirtReg, Order, Hints, MF, VRM); + + for (const MachineOperand &DefOp : MRI.def_operands(VirtReg)) { + const MachineInstr &Def = *DefOp.getParent(); + if (DefOp.isImplicit() || + (TII->get(Def.getOpcode()).TSFlags & AArch64::FalseLanesMask) != + AArch64::FalseLanesUndef) + continue; + + unsigned InstFlags = + TII->get(AArch64::getSVEPseudoMap(Def.getOpcode())).TSFlags; + + for (MCPhysReg R : Order) { + auto AddHintIfSuitable = [&](MCPhysReg R, + const MachineOperand &MO) -> bool { + // R is a suitable register hint if R can reuse one of the other + // source operands. + if (VRM->getPhys(MO.getReg()) != R) + return false; + Hints.push_back(R); + return true; + }; + + switch (InstFlags & AArch64::DestructiveInstTypeMask) { + default: + break; + case AArch64::DestructiveTernaryCommWithRev: + AddHintIfSuitable(R, Def.getOperand(2)) || + AddHintIfSuitable(R, Def.getOperand(3)) || + AddHintIfSuitable(R, Def.getOperand(4)); + break; + case AArch64::DestructiveBinaryComm: + case AArch64::DestructiveBinaryCommWithRev: + AddHintIfSuitable(R, Def.getOperand(2)) || + AddHintIfSuitable(R, Def.getOperand(3)); + break; + case AArch64::DestructiveBinary: + case AArch64::DestructiveBinaryImm: + AddHintIfSuitable(R, Def.getOperand(2)); + break; + } + } + } + + if (Hints.size()) + return ConsiderOnlyHints; + } + if (!ST.hasSME() || !ST.isStreaming()) return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF, VRM); @@ -1147,8 +1196,7 @@ bool AArch64RegisterInfo::getRegAllocationHints( // FORM_TRANSPOSED_REG_TUPLE pseudo, we want to favour reducing copy // instructions over reducing the number of clobbered callee-save registers, // so we add the strided registers as a hint.
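The ZPR-destination hinting added above is about avoiding MOVPRFX when destructive predicated pseudos with undef inactive lanes are expanded. A rough illustration of the effect, written as comments (hand-made example; register numbers are not from the patch):

// Hypothetical expansion of a destructive SVE add pseudo whose inactive
// lanes are undef. If the allocator picks a destination unrelated to the
// sources, the expansion needs a prefix copy:
//   movprfx z2, z0
//   add     z2.s, p0/m, z2.s, z1.s
// If the hint steers the destination onto one of the source registers
// (z0 here, or z1 for the commutative/reversed forms), no prefix is needed:
//   add     z0.s, p0/m, z0.s, z1.s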
- const MachineRegisterInfo &MRI = MF.getRegInfo(); - unsigned RegID = MRI.getRegClass(VirtReg)->getID(); + unsigned RegID = RegRC->getID(); if (RegID == AArch64::ZPR2StridedOrContiguousRegClassID || RegID == AArch64::ZPR4StridedOrContiguousRegClassID) { diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h index 47d76f3..89d1802 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h @@ -46,7 +46,6 @@ public: /// Code Generation virtual methods... const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override; - const MCPhysReg *getDarwinCalleeSavedRegs(const MachineFunction *MF) const; const MCPhysReg * getCalleeSavedRegsViaCopy(const MachineFunction *MF) const; const uint32_t *getCallPreservedMask(const MachineFunction &MF, @@ -124,7 +123,7 @@ public: bool requiresVirtualBaseRegisters(const MachineFunction &MF) const override; bool hasBasePointer(const MachineFunction &MF) const; - unsigned getBaseRegister() const; + MCRegister getBaseRegister() const; bool isArgumentRegister(const MachineFunction &MF, MCRegister Reg) const override; diff --git a/llvm/lib/Target/AArch64/AArch64SIMDInstrOpt.cpp b/llvm/lib/Target/AArch64/AArch64SIMDInstrOpt.cpp index d695f26..b4a4f4c 100644 --- a/llvm/lib/Target/AArch64/AArch64SIMDInstrOpt.cpp +++ b/llvm/lib/Target/AArch64/AArch64SIMDInstrOpt.cpp @@ -33,6 +33,7 @@ //===----------------------------------------------------------------------===// #include "AArch64InstrInfo.h" +#include "AArch64Subtarget.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" @@ -49,8 +50,8 @@ #include "llvm/MC/MCInstrDesc.h" #include "llvm/MC/MCSchedule.h" #include "llvm/Pass.h" -#include <unordered_map> #include <map> +#include <unordered_map> using namespace llvm; @@ -67,7 +68,7 @@ namespace { struct AArch64SIMDInstrOpt : public MachineFunctionPass { static char ID; - const TargetInstrInfo *TII; + const AArch64InstrInfo *TII; MachineRegisterInfo *MRI; TargetSchedModel SchedModel; @@ -694,13 +695,9 @@ bool AArch64SIMDInstrOpt::runOnMachineFunction(MachineFunction &MF) { if (skipFunction(MF.getFunction())) return false; - TII = MF.getSubtarget().getInstrInfo(); MRI = &MF.getRegInfo(); - const TargetSubtargetInfo &ST = MF.getSubtarget(); - const AArch64InstrInfo *AAII = - static_cast<const AArch64InstrInfo *>(ST.getInstrInfo()); - if (!AAII) - return false; + const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>(); + TII = ST.getInstrInfo(); SchedModel.init(&ST); if (!SchedModel.hasInstrSchedModel()) return false; diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/AArch64SMEAttributes.cpp index 085c8588..085c8588 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/AArch64SMEAttributes.cpp diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/AArch64SMEAttributes.h index 28c397e..28c397e 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/AArch64SMEAttributes.h diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 752b1858..b099f15 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -102,25 +102,32 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)), let hasSideEffects = 1, isMeta = 1 in { def 
InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; + def RequiresZT0SavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>; } def SMEStateAllocPseudo : Pseudo<(outs), (ins), []>, Sched<[]>; def CommitZASavePseudo : Pseudo<(outs), - (ins GPR64:$tpidr2_el0, i1imm:$zero_za, i64imm:$commit_routine, variable_ops), []>, + (ins GPR64:$tpidr2_el0, i1imm:$zero_za, i1imm:$zero_zt0, + i64imm:$commit_routine, variable_ops), []>, Sched<[]>; def AArch64_inout_za_use : SDNode<"AArch64ISD::INOUT_ZA_USE", SDTypeProfile<0, 0,[]>, - [SDNPHasChain, SDNPInGlue]>; + [SDNPHasChain, SDNPInGlue, SDNPOutGlue]>; def : Pat<(AArch64_inout_za_use), (InOutZAUsePseudo)>; def AArch64_requires_za_save : SDNode<"AArch64ISD::REQUIRES_ZA_SAVE", SDTypeProfile<0, 0,[]>, - [SDNPHasChain, SDNPInGlue]>; + [SDNPHasChain, SDNPInGlue, SDNPOutGlue]>; def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>; +def AArch64_requires_zt0_save + : SDNode<"AArch64ISD::REQUIRES_ZT0_SAVE", SDTypeProfile<0, 0, []>, + [SDNPHasChain, SDNPInGlue, SDNPOutGlue]>; +def : Pat<(AArch64_requires_zt0_save), (RequiresZT0SavePseudo)>; + def AArch64_sme_state_alloc : SDNode<"AArch64ISD::SME_STATE_ALLOC", SDTypeProfile<0, 0,[]>, [SDNPHasChain]>; diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 3b268dc..c923b6e 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -341,11 +341,11 @@ def AArch64urhadd : PatFrags<(ops node:$pg, node:$op1, node:$op2), def AArch64saba : PatFrags<(ops node:$op1, node:$op2, node:$op3), [(int_aarch64_sve_saba node:$op1, node:$op2, node:$op3), - (add node:$op1, (AArch64sabd_p (SVEAllActive), node:$op2, node:$op3))]>; + (add node:$op1, (AArch64sabd_p (SVEAnyPredicate), node:$op2, node:$op3))]>; def AArch64uaba : PatFrags<(ops node:$op1, node:$op2, node:$op3), [(int_aarch64_sve_uaba node:$op1, node:$op2, node:$op3), - (add node:$op1, (AArch64uabd_p (SVEAllActive), node:$op2, node:$op3))]>; + (add node:$op1, (AArch64uabd_p (SVEAnyPredicate), node:$op2, node:$op3))]>; def AArch64usra : PatFrags<(ops node:$op1, node:$op2, node:$op3), [(int_aarch64_sve_usra node:$op1, node:$op2, node:$op3), @@ -375,6 +375,11 @@ def AArch64fclamp : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm), node:$Zm) ]>; +def AArch64fdot : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm), + [(int_aarch64_sve_fdot_x2 node:$Zd, node:$Zn, node:$Zm), + (partial_reduce_fmla node:$Zd, node:$Zn, node:$Zm) + ]>; + def SDT_AArch64FCVT : SDTypeProfile<1, 3, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameAs<0,3> @@ -457,6 +462,7 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx), def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm), [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm), (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))), + (AArch64fma_p node:$pg, node:$zn, (AArch64fneg_mt node:$pg, node:$zm, (undef)), (AArch64fneg_mt node:$pg, node:$za, (undef))), (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>; def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm), @@ -984,7 +990,7 @@ let Predicates = [HasSVE_or_SME] in { (DUP_ZR_D (MOVi64imm (bitcast_fpimm_to_i64 f64:$val)))>; // Duplicate FP immediate into all vector 
elements - let AddedComplexity = 2 in { + let AddedComplexity = 3 in { def : Pat<(nxv8f16 (splat_vector fpimm16:$imm8)), (FDUP_ZI_H fpimm16:$imm8)>; def : Pat<(nxv4f16 (splat_vector fpimm16:$imm8)), @@ -2578,6 +2584,11 @@ let Predicates = [HasBF16, HasSVE_or_SME] in { defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>; defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>; + def : Pat<(nxv4f32 (AArch64fmla_p (SVEAllActive), nxv4f32:$acc, + (nxv4f32 (AArch64fcvte_mt (SVEAllActive), nxv4bf16:$Zn, (undef))), + (nxv4f32 (AArch64fcvte_mt (SVEAllActive), nxv4bf16:$Zm, (undef))))), + (BFMLALB_ZZZ nxv4f32:$acc, ZPR:$Zn, ZPR:$Zm)>; + defm BFCVT_ZPmZ : sve_bfloat_convert<"bfcvt", int_aarch64_sve_fcvt_bf16f32_v2, AArch64fcvtr_mt>; defm BFCVTNT_ZPmZ : sve_bfloat_convert_top<"bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32_v2>; } // End HasBF16, HasSVE_or_SME @@ -3592,6 +3603,18 @@ let Predicates = [HasSVE_or_SME] in { def : Pat<(sext (i32 (vector_extract nxv4i32:$vec, VectorIndexS:$index))), (SMOVvi32to64 (v4i32 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexS:$index)>; + + // Extracts of ``unsigned'' i8 or i16 elements lead to the zero-extend being + // transformed to an AND mask. The mask is redundant since UMOV already zeroes + // the high bits of the destination register. + def : Pat<(i32 (and (vector_extract nxv16i8:$vec, VectorIndexB:$index), 0xff)), + (UMOVvi8 (v16i8 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexB:$index)>; + def : Pat<(i32 (and (vector_extract nxv8i16:$vec, VectorIndexH:$index), 0xffff)), + (UMOVvi16 (v8i16 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexH:$index)>; + def : Pat<(i64 (and (i64 (anyext (i32 (vector_extract nxv16i8:$vec, VectorIndexB:$index)))), (i64 0xff))), + (SUBREG_TO_REG (i64 0), (i32 (UMOVvi8 (v16i8 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexB:$index)), sub_32)>; + def : Pat<(i64 (and (i64 (anyext (i32 (vector_extract nxv8i16:$vec, VectorIndexH:$index)))), (i64 0xffff))), + (SUBREG_TO_REG (i64 0), (i32 (UMOVvi16 (v8i16 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexH:$index)), sub_32)>; } // End HasNEON // Extract first element from vector. 
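For the UMOV extract patterns added in the hunk above, a rough before/after of what they fold (illustrative IR and register numbers, not taken from the patch):

// Given roughly:
//   %e = extractelement <vscale x 16 x i8> %v, i32 3
//   %z = zext i8 %e to i32
// the zero-extend is legalized into an AND with 0xff, and previously the
// backend kept the mask:
//   umov w8, v0.b[3]
//   and  w0, w8, #0xff
// With the new patterns the AND is folded away, since UMOV already zeroes
// the upper bits of its destination register:
//   umov w0, v0.b[3]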
@@ -3684,7 +3707,7 @@ let Predicates = [HasSVE, HasMatMulFP32] in { } // End HasSVE, HasMatMulFP32 let Predicates = [HasSVE_F16F32MM] in { - def FMLLA_ZZZ_HtoS : sve_fp_matrix_mla<0b001, "fmmla", ZPR32, ZPR16>; + defm FMLLA_ZZZ_HtoS : sve_fp_matrix_mla<0b001, "fmmla", ZPR32, ZPR16, int_aarch64_sve_fmmla, nxv4f32, nxv8f16>; } // End HasSVE_F16F32MM let Predicates = [HasSVE, HasMatMulFP64] in { @@ -3863,12 +3886,12 @@ let Predicates = [HasSVE2_or_SME] in { defm SQRSHLR_ZPmZ : sve2_int_arith_pred<0b011100, "sqrshlr", null_frag, "SQRSHLR_ZPZZ", DestructiveBinaryCommWithRev, "SQRSHL_ZPmZ", /*isReverseInstr*/ 1>; defm UQRSHLR_ZPmZ : sve2_int_arith_pred<0b011110, "uqrshlr", null_frag, "UQRSHLR_ZPZZ", DestructiveBinaryCommWithRev, "UQRSHL_ZPmZ", /*isReverseInstr*/ 1>; - defm SRSHL_ZPZZ : sve_int_bin_pred_all_active_bhsd<int_aarch64_sve_srshl>; - defm URSHL_ZPZZ : sve_int_bin_pred_all_active_bhsd<int_aarch64_sve_urshl>; - defm SQSHL_ZPZZ : sve_int_bin_pred_all_active_bhsd<int_aarch64_sve_sqshl>; - defm UQSHL_ZPZZ : sve_int_bin_pred_all_active_bhsd<int_aarch64_sve_uqshl>; - defm SQRSHL_ZPZZ : sve_int_bin_pred_all_active_bhsd<int_aarch64_sve_sqrshl>; - defm UQRSHL_ZPZZ : sve_int_bin_pred_all_active_bhsd<int_aarch64_sve_uqrshl>; + defm SRSHL_ZPZZ : sve_int_bin_pred_bhsd<int_aarch64_sve_srshl_u>; + defm URSHL_ZPZZ : sve_int_bin_pred_bhsd<int_aarch64_sve_urshl_u>; + defm SQSHL_ZPZZ : sve_int_bin_pred_bhsd<int_aarch64_sve_sqshl_u>; + defm UQSHL_ZPZZ : sve_int_bin_pred_bhsd<int_aarch64_sve_uqshl_u>; + defm SQRSHL_ZPZZ : sve_int_bin_pred_bhsd<int_aarch64_sve_sqrshl_u>; + defm UQRSHL_ZPZZ : sve_int_bin_pred_bhsd<int_aarch64_sve_uqrshl_u>; } // End HasSVE2_or_SME let Predicates = [HasSVE2_or_SME, UseExperimentalZeroingPseudos] in { @@ -3887,6 +3910,9 @@ let Predicates = [HasSVE2_or_SME] in { defm URSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1101, "urshr", "URSHR_ZPZI", AArch64urshri_p>; defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left< 0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>; + defm SQSHL_ZPZI : sve_int_shift_pred_bhsd<int_aarch64_sve_sqshl_u, SVEShiftImmL8, SVEShiftImmL16, SVEShiftImmL32, SVEShiftImmL64>; + defm UQSHL_ZPZI : sve_int_shift_pred_bhsd<int_aarch64_sve_uqshl_u, SVEShiftImmL8, SVEShiftImmL16, SVEShiftImmL32, SVEShiftImmL64>; + // SVE2 integer add/subtract long defm SADDLB_ZZZ : sve2_wide_int_arith_long<0b00000, "saddlb", int_aarch64_sve_saddlb>; defm SADDLT_ZZZ : sve2_wide_int_arith_long<0b00001, "saddlt", int_aarch64_sve_saddlt>; @@ -4251,7 +4277,7 @@ defm PSEL_PPPRI : sve2_int_perm_sel_p<"psel", int_aarch64_sve_psel>; let Predicates = [HasSVE2p1_or_SME2] in { defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>; -defm FDOT_ZZZ_S : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>; +defm FDOT_ZZZ_S : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, AArch64fdot>; defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>; defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>; @@ -4744,11 +4770,11 @@ defm FMLALLTT_ZZZ : sve2_fp8_mla<0b011, ZPR32, "fmlalltt", nxv4f32, int_aarch64_ } // End HasSSVE_FP8FMA let Predicates = [HasSVE2, HasF8F32MM] in { - def FMMLA_ZZZ_BtoS : sve2_fp8_mmla<0b0, ZPR32, "fmmla">; + defm FMMLA_ZZZ_BtoS : sve2_fp8_fmmla<0b0, ZPR32, "fmmla", nxv4f32>; } let Predicates = [HasSVE2, HasF8F16MM] in { - def FMMLA_ZZZ_BtoH : sve2_fp8_mmla<0b1, ZPR16, "fmmla">; + defm FMMLA_ZZZ_BtoH : sve2_fp8_fmmla<0b1, ZPR16, "fmmla", 
nxv8f16>; } let Predicates = [HasSSVE_FP8DOT2] in { diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN1.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN1.td index 50142af..80e5bff 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN1.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN1.td @@ -286,9 +286,6 @@ def : SchedAlias<WriteBrReg, N1Write_1c_1B>; // Branch and link, register def : InstRW<[N1Write_1c_1B_1I], (instrs BL, BLR)>; -// Compare and branch -def : InstRW<[N1Write_1c_1B], (instregex "^[CT]BN?Z[XW]$")>; - // Arithmetic and Logical Instructions // ----------------------------------------------------------------------------- diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td index 50f1011..a02130f 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td @@ -72,6 +72,10 @@ def : WriteRes<WriteLDHi, []> { let Latency = 4; } // Define customized scheduler read/write types specific to the Neoverse N2. //===----------------------------------------------------------------------===// + +// Define generic 0 micro-op types +def N2Write_0c : SchedWriteRes<[]> { let Latency = 0; } + // Define generic 1 micro-op types def N2Write_1c_1B : SchedWriteRes<[N2UnitB]> { let Latency = 1; } @@ -646,6 +650,21 @@ def N2Write_11c_9L01_9S_9V : SchedWriteRes<[N2UnitL01, N2UnitL01, N2UnitL01, } //===----------------------------------------------------------------------===// +// Define predicate-controlled types + +def N2Write_0or1c_1I : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [N2Write_0c]>, + SchedVar<NoSchedPred, [N2Write_1c_1I]>]>; + +def N2Write_0or2c_1V : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [N2Write_0c]>, + SchedVar<NoSchedPred, [N2Write_2c_1V]>]>; + +def N2Write_0or3c_1M0 : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [N2Write_0c]>, + SchedVar<NoSchedPred, [N2Write_3c_1M0]>]>; + +//===----------------------------------------------------------------------===// // Define types for arithmetic and logical ops with short shifts def N2Write_Arith : SchedWriteVariant<[ SchedVar<IsCheapLSL, [N2Write_1c_1I]>, @@ -680,6 +699,7 @@ def : InstRW<[N2Write_1c_1B_1S], (instrs BL, BLR)>; // ALU, basic // ALU, basic, flagset def : SchedAlias<WriteI, N2Write_1c_1I>; +def : InstRW<[N2Write_0or1c_1I], (instregex "^MOVZ[WX]i$")>; // ALU, extend and shift def : SchedAlias<WriteIEReg, N2Write_2c_1M>; @@ -691,7 +711,8 @@ def : SchedAlias<WriteISReg, N2Write_Arith>; // Logical, shift, no flagset def : InstRW<[N2Write_1c_1I], - (instregex "^(AND|BIC|EON|EOR|ORN|ORR)[WX]rs$")>; + (instregex "^(AND|BIC|EON|EOR|ORN)[WX]rs$")>; +def : InstRW<[N2Write_0or1c_1I], (instregex "^ORR[WX]rs$")>; // Logical, shift, flagset def : InstRW<[N2Write_Logical], (instregex "^(AND|BIC)S[WX]rs$")>; @@ -882,8 +903,7 @@ def : SchedAlias<WriteFImm, N2Write_2c_1V>; def : InstRW<[N2Write_2c_1V], (instrs FMOVHr, FMOVSr, FMOVDr)>; // FP transfer, from gen to low half of vec reg -def : InstRW<[N2Write_3c_1M0], (instrs FMOVWHr, FMOVXHr, FMOVWSr, FMOVXDr, - FMOVHWr, FMOVHXr, FMOVSWr, FMOVDXr)>; +def : InstRW<[N2Write_0or3c_1M0], (instrs FMOVWHr, FMOVXHr, FMOVWSr, FMOVXDr)>; // FP transfer, from gen to high half of vec reg def : InstRW<[N2Write_5c_1M0_1V], (instrs FMOVXDHighr)>; @@ -1225,6 +1245,8 @@ def : InstRW<[N2Write_3c_1V0], (instrs BFCVT)>; // ASIMD unzip/zip // Handled by SchedAlias<WriteV[dq], ...> +def : InstRW<[N2Write_0or2c_1V], (instrs MOVID, MOVIv2d_ns)>; + // ASIMD duplicate, 
gen reg def : InstRW<[N2Write_3c_1M0], (instregex "^DUPv.+gpr")>; diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td index 411b372..22e6d11 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td @@ -49,6 +49,12 @@ def N3UnitM : ProcResGroup<[N3UnitM0, N3UnitM1]>; def N3UnitL : ProcResGroup<[N3UnitL01, N3UnitL2]>; def N3UnitI : ProcResGroup<[N3UnitS, N3UnitM0, N3UnitM1]>; +// Group required for modelling SVE gather loads throughput +def N3UnitVL : ProcResGroup<[N3UnitL01, N3UnitV0, N3UnitV1]>; +// Unused group to fix: "error: proc resource group overlaps with N3UnitVL but +// no supergroup contains both." +def : ProcResGroup<[N3UnitL01, N3UnitL2, N3UnitV0, N3UnitV1]>; + //===----------------------------------------------------------------------===// def : ReadAdvance<ReadI, 0>; @@ -75,7 +81,7 @@ def : WriteRes<WriteHint, []> { let Latency = 1; } def N3Write_0c : SchedWriteRes<[]> { let Latency = 0; - let NumMicroOps = 0; + let NumMicroOps = 1; } def N3Write_4c : SchedWriteRes<[]> { @@ -321,6 +327,12 @@ def N3Write_6c_2I_2L : SchedWriteRes<[N3UnitI, N3UnitI, N3UnitL, N3UnitL]> { let NumMicroOps = 4; } +def N3Write_6c_2L01_2V : SchedWriteRes<[N3UnitVL]> { + let Latency = 6; + let NumMicroOps = 4; + let ReleaseAtCycles = [5]; +} + def N3Write_6c_4V0 : SchedWriteRes<[N3UnitV0, N3UnitV0, N3UnitV0, N3UnitV0]> { let Latency = 6; let NumMicroOps = 4; @@ -553,6 +565,126 @@ def N3Write_16c_16V0 : SchedWriteRes<[N3UnitV0, N3UnitV0, N3UnitV0, N3UnitV0, let NumMicroOps = 16; } + +//===----------------------------------------------------------------------===// +// Define predicate-controlled types + +def N3Write_0or1c_1I : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [N3Write_0c]>, + SchedVar<NoSchedPred, [N3Write_1c_1I]>]>; + +def N3Write_0or2c_1V : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [N3Write_0c]>, + SchedVar<NoSchedPred, [N3Write_2c_1V]>]>; + +def N3Write_0or2c_1M : SchedWriteVariant<[ + SchedVar<NeoverseAllActivePredicate, [N3Write_0c]>, + SchedVar<NoSchedPred, [N3Write_2c_1M]>]>; + +def N3Write_0or3c_1M0 : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [N3Write_0c]>, + SchedVar<NoSchedPred, [N3Write_3c_1M0]>]>; +//===----------------------------------------------------------------------===// +// Define forwarded types +// NOTE: SOG, p. 19, n. 2: Accumulator forwarding is not supported for +// consumers of 64 bit multiply high operations? 
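As a reading aid for the forwarded types defined in the hunk below: each entry pairs a SchedWriteRes for the producing instruction with a SchedReadAdvance on the operand that benefits from accumulator forwarding, so a dependent chain through that operand sees the write latency minus the advance. A minimal sketch of the idiom follows; ExampleWr_Acc and ExampleRd_Acc are hypothetical names used only for illustration and are not part of the patch.

// Sketch only (hypothetical names, assumed to live inside the N3 model
// scope): a producer that completes in 4 cycles, whose consumers read the
// accumulator operand 3 cycles early. A dependent chain through that operand
// therefore sees max(4 - 3, 0) = 1 cycle of latency, which is what lets
// back-to-back multiply-accumulates sustain roughly one per cycle.
def ExampleWr_Acc : SchedWriteRes<[N3UnitV0]> { let Latency = 4; }
def ExampleRd_Acc : SchedReadAdvance<3, [ExampleWr_Acc]>;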
+ +def N3Wr_FMA : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_FMA : SchedReadAdvance<2, [WriteFMul, N3Wr_FMA]>; + +def N3Wr_VMA : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_VMA : SchedReadAdvance<3, [N3Wr_VMA]>; + +def N3Wr_VMAL : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_VMAL : SchedReadAdvance<3, [N3Wr_VMAL]>; + +def N3Wr_VMAH : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_VMAH : SchedReadAdvance<2, [N3Wr_VMAH]>; + +def N3Wr_VMASL : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_VMASL : SchedReadAdvance<2, [N3Wr_VMASL]>; + +def N3Wr_ADA : SchedWriteRes<[N3UnitV1]> { let Latency = 4; } +def N3Rd_ADA : SchedReadAdvance<3, [N3Wr_ADA]>; + +def N3Wr_VDOT : SchedWriteRes<[N3UnitV]> { let Latency = 3; } +def N3Rd_VDOT : SchedReadAdvance<2, [N3Wr_VDOT]>; + +def N3Wr_VMMA : SchedWriteRes<[N3UnitV]> { let Latency = 3; } +def N3Rd_VMMA : SchedReadAdvance<2, [N3Wr_VMMA]>; + +def N3Wr_FCMA : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_FCMA : SchedReadAdvance<2, [N3Wr_FCMA]>; + +def N3Wr_FPM : SchedWriteRes<[N3UnitV]> { let Latency = 3; } +def N3Wr_FPMA : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_FPMA : SchedReadAdvance<2, [N3Wr_FPM, N3Wr_FPMA]>; + +def N3Wr_FPMAL : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_FPMAL : SchedReadAdvance<2, [N3Wr_FPMAL]>; + +def N3Wr_BFD : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_BFD : SchedReadAdvance<2, [N3Wr_BFD]>; + +def N3Wr_BFMMA : SchedWriteRes<[N3UnitV]> { let Latency = 5; } +def N3Rd_BFMMA : SchedReadAdvance<2, [N3Wr_BFMMA]>; + +def N3Wr_BFMLA : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_BFMLA : SchedReadAdvance<2, [N3Wr_BFMLA]>; + +def N3Wr_CRC : SchedWriteRes<[N3UnitM0]> { let Latency = 2; } +def N3Rd_CRC : SchedReadAdvance<1, [N3Wr_CRC]>; + +def N3Wr_ZA : SchedWriteRes<[N3UnitV1]> { let Latency = 4; } +def N3Rd_ZA : SchedReadAdvance<3, [N3Wr_ZA]>; +def N3Wr_ZPA : SchedWriteRes<[N3UnitV1]> { let Latency = 4; } +def N3Rd_ZPA : SchedReadAdvance<3, [N3Wr_ZPA]>; +def N3Wr_ZSA : SchedWriteRes<[N3UnitV1]> { let Latency = 4; } +def N3Rd_ZSA : SchedReadAdvance<3, [N3Wr_ZSA]>; + +def N3Wr_ZDOTB : SchedWriteRes<[N3UnitV]> { let Latency = 3; } +def N3Rd_ZDOTB : SchedReadAdvance<2, [N3Wr_ZDOTB]>; +def N3Wr_ZDOTH : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_ZDOTH : SchedReadAdvance<3, [N3Wr_ZDOTH]>; + +def N3Wr_ZCMABHS : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_ZCMABHS : SchedReadAdvance<3, [N3Wr_ZCMABHS]>; +def N3Wr_ZCMAD : SchedWriteRes<[N3UnitV0, N3UnitV0]> { let Latency = 5; } +def N3Rd_ZCMAD : SchedReadAdvance<2, [N3Wr_ZCMAD]>; + +def N3Wr_ZMMA : SchedWriteRes<[N3UnitV]> { let Latency = 3; } +def N3Rd_ZMMA : SchedReadAdvance<2, [N3Wr_ZMMA]>; + +def N3Wr_ZMABHS : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_ZMABHS : SchedReadAdvance<3, [N3Wr_ZMABHS]>; +def N3Wr_ZMAD : SchedWriteRes<[N3UnitV0, N3UnitV0]> { let Latency = 5; } +def N3Rd_ZMAD : SchedReadAdvance<2, [N3Wr_ZMAD]>; + +def N3Wr_ZMAL : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Rd_ZMAL : SchedReadAdvance<3, [N3Wr_ZMAL]>; + +def N3Wr_ZMASQL : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Wr_ZMASQBHS : SchedWriteRes<[N3UnitV0]> { let Latency = 4; } +def N3Wr_ZMASQD : SchedWriteRes<[N3UnitV0, N3UnitV0]> { let Latency = 5; } +def N3Rd_ZMASQ : SchedReadAdvance<2, [N3Wr_ZMASQL, N3Wr_ZMASQBHS, + N3Wr_ZMASQD]>; + +def N3Wr_ZFCMA : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_ZFCMA : SchedReadAdvance<2, [N3Wr_ZFCMA]>; + +def 
N3Wr_ZFMA : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_ZFMA : SchedReadAdvance<2, [N3Wr_ZFMA]>; + +def N3Wr_ZFMAL : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_ZFMAL : SchedReadAdvance<2, [N3Wr_ZFMAL]>; + +def N3Wr_ZBFDOT : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_ZBFDOT : SchedReadAdvance<2, [N3Wr_ZBFDOT]>; +def N3Wr_ZBFMMA : SchedWriteRes<[N3UnitV]> { let Latency = 5; } +def N3Rd_ZBFMMA : SchedReadAdvance<2, [N3Wr_ZBFMMA]>; +def N3Wr_ZBFMAL : SchedWriteRes<[N3UnitV]> { let Latency = 4; } +def N3Rd_ZBFMAL : SchedReadAdvance<2, [N3Wr_ZBFMAL]>; + // Miscellaneous // ----------------------------------------------------------------------------- @@ -581,6 +713,7 @@ def : InstRW<[N3Write_1c_1B_1S], (instrs BL, BLR)>; // Conditional compare // Conditional select def : SchedAlias<WriteI, N3Write_1c_1I>; +def : InstRW<[N3Write_0or1c_1I], (instregex "^MOVZ[WX]i$")>; // ALU, extend and shift def : SchedAlias<WriteIEReg, N3Write_2c_1M>; @@ -610,7 +743,8 @@ def : InstRW<[N3Write_1c_1I], (instrs GMI, SUBP, SUBPS)>; // Logical, shift, no flagset def : InstRW<[N3Write_1c_1I], - (instregex "^(AND|BIC|EON|EOR|ORN|ORR)[WX]rs$")>; + (instregex "^(AND|BIC|EON|EOR|ORN)[WX]rs$")>; +def : InstRW<[N3Write_0or1c_1I], (instregex "^ORR[WX]rs$")>; // Logical, shift, flagset def : InstRW<[N3Write_2c_1M], (instregex "^(AND|BIC)S[WX]rs$")>; @@ -832,10 +966,11 @@ def : SchedAlias<WriteFDiv , N3Write_7c_1V0>; def : InstRW<[N3Write_12c_1V0], (instrs FDIVDrr, FSQRTDr)>; // FP multiply -def : SchedAlias<WriteFMul, N3Write_3c_1V>; +def : WriteRes<WriteFMul, [N3UnitV]> { let Latency = 3; } // FP multiply accumulate -def : InstRW<[N3Write_4c_1V], (instregex "^(FMADD|FMSUB|FNMADD|FNMSUB)[DHS]rrr$")>; +def : InstRW<[N3Wr_FMA, ReadDefault, ReadDefault, N3Rd_FMA], + (instregex "^(FMADD|FMSUB|FNMADD|FNMSUB)[DHS]rrr$")>; // FP round to integral def : InstRW<[N3Write_3c_1V0], (instregex "^FRINT([AIMNPXZ]|32X|64X|32Z|64Z)[DHS]r$")>; @@ -855,10 +990,11 @@ def : SchedAlias<WriteFCvt, N3Write_3c_1V0>; def : SchedAlias<WriteFImm, N3Write_2c_1V>; // FP move, register -def : InstRW<[N3Write_2c_1V], (instrs FMOVHr, FMOVSr, FMOVDr)>; +def : InstRW<[N3Write_2c_1V], (instrs FMOVHr)>; +def : InstRW<[N3Write_0c], (instrs FMOVSr, FMOVDr)>; // FP transfer, from gen to low half of vec reg -def : InstRW<[N3Write_3c_1M0], (instrs FMOVWHr, FMOVXHr, FMOVWSr, FMOVXDr)>; +def : InstRW<[N3Write_0or3c_1M0], (instrs FMOVWHr, FMOVXHr, FMOVWSr, FMOVXDr)>; // FP transfer, from gen to high half of vec reg def : InstRW<[N3Write_5c_1M0_1V], (instrs FMOVXDHighr)>; @@ -962,6 +1098,8 @@ def : InstRW<[WriteAdr, N3Write_2c_1L01_1V_1I], (instregex "^STP[SDQ](post|pre)$ // ASIMD compare // ASIMD logical // ASIMD max/min, basic and pair-wise +def : InstRW<[N3Write_0or2c_1V], (instrs ORRv16i8, ORRv8i8)>; + def : SchedAlias<WriteVd, N3Write_2c_1V>; def : SchedAlias<WriteVq, N3Write_2c_1V>; @@ -969,9 +1107,9 @@ def : SchedAlias<WriteVq, N3Write_2c_1V>; // ASIMD absolute diff accum long // ASIMD pairwise add and accumulate long // ASIMD shift accumulate -def : InstRW<[N3Write_4c_1V1], (instregex "^[SU]ABAL?v", +def : InstRW<[N3Wr_ADA, N3Rd_ADA], (instregex "^[SU]ABAL?v", "^[SU]ADALPv", - "^[SU]R?SRAv")>; + "^[SU]R?SRA(v|d)")>; // ASIMD arith, reduce, 4H/4S def : InstRW<[N3Write_3c_1V1], (instregex "^[SU]?ADDL?Vv4i(16|32)v$")>; @@ -984,10 +1122,11 @@ def : InstRW<[N3Write_6c_2V1], (instregex "^[SU]?ADDL?Vv16i8v$")>; // ASIMD dot product // ASIMD dot product using signed and unsigned integers -def : InstRW<[N3Write_3c_1V], (instregex 
"^([SU]|SU|US)DOT(lane)?(v8|v16)i8$")>; +def : InstRW<[N3Wr_VDOT, N3Rd_VDOT], + (instregex "^([SU]|SU|US)DOT(lane)?(v8|v16)i8$")>; // ASIMD matrix multiply-accumulate -def : InstRW<[N3Write_3c_1V], (instrs SMMLA, UMMLA, USMMLA)>; +def : InstRW<[N3Wr_VMMA, N3Rd_VMMA], (instrs SMMLA, UMMLA, USMMLA)>; // ASIMD max/min, reduce, 4H/4S def : InstRW<[N3Write_3c_1V1], (instregex "^[SU](MAX|MIN)Vv4i(16|32)v$")>; @@ -1002,39 +1141,39 @@ def : InstRW<[N3Write_6c_2V1], (instregex "[SU](MAX|MIN)Vv16i8v$")>; def : InstRW<[N3Write_4c_1V0], (instregex "^MULv", "^SQ(R)?DMULHv")>; // ASIMD multiply accumulate -def : InstRW<[N3Write_4c_1V0], (instregex "^MLAv", "^MLSv")>; +def : InstRW<[N3Wr_VMA, N3Rd_VMA], (instregex "^MLAv", "^MLSv")>; // ASIMD multiply accumulate high -def : InstRW<[N3Write_4c_1V0], (instregex "^SQRDMLAHv", "^SQRDMLSHv")>; +def : InstRW<[N3Wr_VMAH, N3Rd_VMAH], (instregex "^SQRDMLAHv", "^SQRDMLSHv")>; // ASIMD multiply accumulate long -def : InstRW<[N3Write_4c_1V0], (instregex "^[SU]MLALv", "^[SU]MLSLv")>; +def : InstRW<[N3Wr_VMAL, N3Rd_VMAL], (instregex "^[SU]MLALv", "^[SU]MLSLv")>; // ASIMD multiply accumulate saturating long -def : InstRW<[N3Write_4c_1V0], (instregex "^SQDMLALv", "^SQDMLSLv")>; +def : InstRW<[N3Wr_VMASL, N3Rd_VMASL], (instregex "^SQDMLAL(v|i16|i32)", "^SQDMLSL(v|i16|i32)")>; // ASIMD multiply/multiply long (8x8) polynomial, D-form // ASIMD multiply/multiply long (8x8) polynomial, Q-form def : InstRW<[N3Write_2c_1V0], (instregex "^PMULL?(v8i8|v16i8)$")>; // ASIMD multiply long -def : InstRW<[N3Write_4c_1V0], (instregex "^[SU]MULLv", "^SQDMULLv")>; +def : InstRW<[N3Write_4c_1V0], (instregex "^[SU]MULLv", "^SQDMULL(v|i16|i32)")>; // ASIMD shift by immed, basic -def : InstRW<[N3Write_2c_1V1], (instregex "^SHLv", "^SHLLv", "^SHRNv", - "^SSHLLv", "^SSHRv", "^USHLLv", - "^USHRv")>; +def : InstRW<[N3Write_2c_1V1], (instregex "^SHL(v|d)", "^SHLLv", "^SHRNv", + "^SSHLLv", "^SSHR(v|d)", "^USHLLv", + "^USHR(v|d)")>; // ASIMD shift by immed and insert, basic -def : InstRW<[N3Write_2c_1V1], (instregex "^SLIv", "^SRIv")>; +def : InstRW<[N3Write_2c_1V1], (instregex "^SLI(v|d)", "^SRI(v|d)")>; // ASIMD shift by immed, complex def : InstRW<[N3Write_4c_1V1], - (instregex "^RSHRNv", "^SQRSHRNv", "^SQRSHRUNv", + (instregex "^RSHRNv", "^SQRSHRN[vbhs]", "^SQRSHRUN[vbhs]", "^(SQSHLU?|UQSHL)[bhsd]$", "^(SQSHLU?|UQSHL)(v8i8|v16i8|v4i16|v8i16|v2i32|v4i32|v2i64)_shift$", - "^SQSHRNv", "^SQSHRUNv", "^SRSHRv", "^UQRSHRNv", - "^UQSHRNv", "^URSHRv")>; + "^SQSHRN[vbhs]", "^SQSHRUN[vbhs]", "^SRSHR(v|d)", + "^UQRSHRN[vbhs]", "^UQSHRN[vbhs]","^URSHR(v|d)")>; // ASIMD shift by register, basic def : InstRW<[N3Write_2c_1V1], (instregex "^[SU]SHLv")>; @@ -1058,7 +1197,7 @@ def : InstRW<[N3Write_4c_1V1], def : InstRW<[N3Write_3c_1V], (instregex "^FCADDv")>; // ASIMD FP complex multiply add -def : InstRW<[N3Write_4c_1V], (instregex "^FCMLAv")>; +def : InstRW<[N3Wr_FCMA, N3Rd_FCMA], (instregex "^FCMLAv")>; // ASIMD FP convert, long (F16 to F32) def : InstRW<[N3Write_4c_2V0], (instregex "^FCVTL(v4|v8)i16")>; @@ -1070,16 +1209,16 @@ def : InstRW<[N3Write_3c_1V0], (instregex "^FCVTL(v2|v4)i32")>; def : InstRW<[N3Write_4c_2V0], (instregex "^FCVTN(v4|v8)i16")>; // ASIMD FP convert, narrow (F64 to F32) -def : InstRW<[N3Write_3c_1V0], (instregex "^FCVTN(v2|v4)i32", +def : InstRW<[N3Write_3c_1V0], (instregex "^FCVTN(v2|v4)i32", "^FCVTXNv1i64", "^FCVTXN(v2|v4)f32")>; // ASIMD FP convert, other, D-form F32 and Q-form F64 -def : InstRW<[N3Write_3c_1V0], (instregex "^[FSU]CVT[AMNPZ][SU]v2f(32|64)$", - 
"^[SU]CVTFv2f(32|64)$")>; +def : InstRW<[N3Write_3c_1V0], (instregex "^[FSU]CVT[AMNPZ][SU](v2f(32|64)|s|d|v1i32|v1i64|v2i32_shift|v2i64_shift)$", + "^[SU]CVTF(v2f(32|64)|s|d|v1i32|v1i64|v2i32_shift|v2i64_shift)$")>; // ASIMD FP convert, other, D-form F16 and Q-form F32 -def : InstRW<[N3Write_4c_2V0], (instregex "^[FSU]CVT[AMNPZ][SU]v4f(16|32)$", - "^[SU]CVTFv4f(16|32)$")>; +def : InstRW<[N3Write_4c_2V0], (instregex "^[FSU]CVT[AMNPZ][SU](v4f(16|32)|v4i(16|32)_shift)$", + "^[SU]CVTF(v4f(16|32)|v4i(16|32)_shift)$")>; // ASIMD FP convert, other, Q-form F16 def : InstRW<[N3Write_6c_4V0], (instregex "^[FSU]CVT[AMNPZ][SU]v8f16$", @@ -1114,13 +1253,13 @@ def : InstRW<[N3Write_4c_2V], (instregex "^(FMAX|FMIN)(NM)?Vv4(i16|i32)v$")>; def : InstRW<[N3Write_6c_3V], (instregex "^(FMAX|FMIN)(NM)?Vv8i16v$")>; // ASIMD FP multiply -def : InstRW<[N3Write_3c_1V], (instregex "^FMULv", "^FMULXv")>; +def : InstRW<[N3Wr_FPM], (instregex "^FMULv", "^FMULX(v|32|64)")>; // ASIMD FP multiply accumulate -def : InstRW<[N3Write_4c_1V], (instregex "^FMLAv", "^FMLSv")>; +def : InstRW<[N3Wr_FPMA, N3Rd_FPMA], (instregex "^FMLAv", "^FMLSv")>; // ASIMD FP multiply accumulate long -def : InstRW<[N3Write_4c_1V], (instregex "^FMLALv", "^FMLSLv")>; +def : InstRW<[N3Wr_FPMAL, N3Rd_FPMAL], (instregex "^FMLALv", "^FMLSLv")>; // ASIMD FP round, D-form F32 and Q-form F64 def : InstRW<[N3Write_3c_1V0], @@ -1157,13 +1296,14 @@ def : InstRW<[N3Write_13c_2V0], (instrs FSQRTv2f64)>; def : InstRW<[N3Write_4c_2V0], (instrs BFCVTN, BFCVTN2)>; // ASIMD dot product -def : InstRW<[N3Write_4c_1V], (instrs BFDOTv4bf16, BFDOTv8bf16)>; +def : InstRW<[N3Wr_BFD, N3Rd_BFD], (instrs BFDOTv4bf16, BFDOTv8bf16)>; // ASIMD matrix multiply accumulate -def : InstRW<[N3Write_5c_1V], (instrs BFMMLA)>; +def : InstRW<[N3Wr_BFMMA, N3Rd_BFMMA], (instrs BFMMLA)>; // ASIMD multiply accumulate long -def : InstRW<[N3Write_4c_1V], (instrs BFMLALB, BFMLALBIdx, BFMLALT, BFMLALTIdx)>; +def : InstRW<[N3Wr_BFMLA, N3Rd_BFMLA], + (instrs BFMLALB, BFMLALBIdx, BFMLALT, BFMLALTIdx)>; // Scalar convert, F32 to BF16 def : InstRW<[N3Write_3c_1V0], (instrs BFCVT)>; @@ -1186,6 +1326,7 @@ def : InstRW<[N3Write_3c_1V0], (instrs BFCVT)>; // ASIMD transpose // ASIMD unzip/zip // Covered by WriteV[dq] +def : InstRW<[N3Write_0or2c_1V], (instrs MOVID, MOVIv2d_ns)>; // ASIMD duplicate, gen reg def : InstRW<[N3Write_3c_1M0], (instregex "^DUPv.+gpr")>; @@ -1201,9 +1342,9 @@ def : InstRW<[N3Write_4c_2V0], (instrs URECPEv4i32, URSQRTEv4i32)>; // ASIMD reciprocal and square root estimate, D-form F32 and scalar forms def : InstRW<[N3Write_3c_1V0], (instrs FRECPEv1f16, FRECPEv1i32, - FRECPEv1i64, FRECPEv2f32, + FRECPEv1i64, FRECPEv2f32, FRECPEv2f64, FRSQRTEv1f16, FRSQRTEv1i32, - FRSQRTEv1i64, FRSQRTEv2f32)>; + FRSQRTEv1i64, FRSQRTEv2f32, FRSQRTEv2f64)>; // ASIMD reciprocal and square root estimate, D-form F16 and Q-form F32 def : InstRW<[N3Write_4c_2V0], (instrs FRECPEv4f16, FRECPEv4f32, @@ -1216,7 +1357,7 @@ def : InstRW<[N3Write_6c_4V0], (instrs FRECPEv8f16, FRSQRTEv8f16)>; def : InstRW<[N3Write_3c_1V0], (instregex "^FRECPXv")>; // ASIMD reciprocal step -def : InstRW<[N3Write_4c_1V], (instregex "^FRECPSv", "^FRSQRTSv")>; +def : InstRW<[N3Write_4c_1V], (instregex "^FRECPS(v|32|64)", "^FRSQRTS(v|32|64)")>; // ASIMD table lookup, 3 table regs def : InstRW<[N3Write_4c_2V], (instrs TBLv8i8Three, TBLv16i8Three)>; @@ -1502,7 +1643,7 @@ def : InstRW<[N3Write_4c_1V0], (instrs SM4E, SM4ENCKEY)>; // ----------------------------------------------------------------------------- // CRC checksum ops -def : 
InstRW<[N3Write_2c_1M0], (instregex "^CRC32")>; +def : InstRW<[N3Wr_CRC, N3Rd_CRC], (instregex "^CRC32")>; // SVE Predicate instructions // ----------------------------------------------------------------------------- @@ -1564,10 +1705,11 @@ def : InstRW<[N3Write_2c_1M], (instregex "^REV_PP_[BHSD]")>; def : InstRW<[N3Write_1c_1M], (instrs SEL_PPPP)>; // Predicate set -def : InstRW<[N3Write_2c_1M], (instregex "^PFALSE", "^PTRUE_[BHSD]")>; +def : InstRW<[N3Write_0c], (instrs PFALSE)>; +def : InstRW<[N3Write_0or2c_1M], (instregex "^PTRUE_[BHSD]")>; // Predicate set/initialize, set flags -def : InstRW<[N3Write_2c_1M], (instregex "^PTRUES_[BHSD]")>; +def : InstRW<[N3Write_0or2c_1M], (instregex "^PTRUES_[BHSD]")>; // Predicate find first/next def : InstRW<[N3Write_2c_1M], (instregex "^PFIRST_B$", "^PNEXT_[BHSD]$")>; @@ -1592,10 +1734,10 @@ def : InstRW<[N3Write_2c_1V], (instregex "^[SU]ABD_ZPmZ_[BHSD]", "^[SU]ABD_ZPZZ_[BHSD]")>; // Arithmetic, absolute diff accum -def : InstRW<[N3Write_4c_1V1], (instregex "^[SU]ABA_ZZZ_[BHSD]$")>; +def : InstRW<[N3Wr_ZA, N3Rd_ZA], (instregex "^[SU]ABA_ZZZ_[BHSD]$")>; // Arithmetic, absolute diff accum long -def : InstRW<[N3Write_4c_1V1], (instregex "^[SU]ABAL[TB]_ZZZ_[HSD]$")>; +def : InstRW<[N3Wr_ZA, N3Rd_ZA], (instregex "^[SU]ABAL[TB]_ZZZ_[HSD]$")>; // Arithmetic, absolute diff long def : InstRW<[N3Write_2c_1V], (instregex "^[SU]ABDL[TB]_ZZZ_[HSD]$")>; @@ -1629,7 +1771,8 @@ def : InstRW<[N3Write_2c_1V], (instregex "^(AD|SB)CL[BT]_ZZZ_[SD]$")>; def : InstRW<[N3Write_2c_1V], (instregex "^ADDP_ZPmZ_[BHSD]$")>; // Arithmetic, pairwise add and accum long -def : InstRW<[N3Write_4c_1V1], (instregex "^[SU]ADALP_ZPmZ_[HSD]$")>; +def : InstRW<[N3Wr_ZPA, ReadDefault, N3Rd_ZPA], + (instregex "^[SU]ADALP_ZPmZ_[HSD]$")>; // Arithmetic, shift def : InstRW<[N3Write_2c_1V1], @@ -1642,7 +1785,7 @@ def : InstRW<[N3Write_2c_1V1], "^(ASRR|LSLR|LSRR)_ZPmZ_[BHSD]")>; // Arithmetic, shift and accumulate -def : InstRW<[N3Write_4c_1V1], +def : InstRW<[N3Wr_ZSA, N3Rd_ZSA], (instregex "^(SRSRA|SSRA|URSRA|USRA)_ZZI_[BHSD]$")>; // Arithmetic, shift by immediate @@ -1688,16 +1831,17 @@ def : InstRW<[N3Write_2c_1V], def : InstRW<[N3Write_2c_1V], (instregex "^(SQ)?CADD_ZZI_[BHSD]$")>; // Complex dot product 8-bit element -def : InstRW<[N3Write_3c_1V], (instrs CDOT_ZZZ_S, CDOT_ZZZI_S)>; +def : InstRW<[N3Wr_ZDOTB, N3Rd_ZDOTB], (instrs CDOT_ZZZ_S, CDOT_ZZZI_S)>; // Complex dot product 16-bit element -def : InstRW<[N3Write_4c_1V0], (instrs CDOT_ZZZ_D, CDOT_ZZZI_D)>; +def : InstRW<[N3Wr_ZDOTH, N3Rd_ZDOTH], (instrs CDOT_ZZZ_D, CDOT_ZZZI_D)>; // Complex multiply-add B, H, S element size -def : InstRW<[N3Write_4c_1V0], (instregex "^CMLA_ZZZ_[BHS]$", "^CMLA_ZZZI_[HS]$")>; +def : InstRW<[N3Wr_ZCMABHS, N3Rd_ZCMABHS], + (instregex "^CMLA_ZZZ_[BHS]$", "^CMLA_ZZZI_[HS]$")>; // Complex multiply-add D element size -def : InstRW<[N3Write_5c_2V0], (instrs CMLA_ZZZ_D)>; +def : InstRW<[N3Wr_ZCMAD, N3Rd_ZCMAD], (instrs CMLA_ZZZ_D)>; // Conditional extract operations, scalar form def : InstRW<[N3Write_8c_1M0_1V], (instregex "^CLAST[AB]_RPZ_[BHSD]$")>; @@ -1736,13 +1880,14 @@ def : InstRW<[N3Write_16c_16V0], (instregex "^[SU]DIVR?_ZPmZ_D", "^[SU]DIV_ZPZZ_D")>; // Dot product, 8 bit -def : InstRW<[N3Write_3c_1V], (instregex "^[SU]DOT_ZZZI?_BtoS$")>; +def : InstRW<[N3Wr_ZDOTB, N3Rd_ZDOTB], (instregex "^[SU]DOT_ZZZI?_BtoS$")>; // Dot product, 8 bit, using signed and unsigned integers -def : InstRW<[N3Write_3c_1V], (instrs SUDOT_ZZZI, USDOT_ZZZI, USDOT_ZZZ)>; +def : InstRW<[N3Wr_ZDOTB, N3Rd_ZDOTB], + (instrs 
SUDOT_ZZZI, USDOT_ZZZI, USDOT_ZZZ)>; // Dot product, 16 bit -def : InstRW<[N3Write_4c_1V0], (instregex "^[SU]DOT_ZZZI?_HtoD$")>; +def : InstRW<[N3Wr_ZDOTH, N3Rd_ZDOTH], (instregex "^[SU]DOT_ZZZI?_HtoD$")>; // Duplicate, immediate and indexed form def : InstRW<[N3Write_2c_1V], (instregex "^DUP_ZI_[BHSD]$", @@ -1790,10 +1935,11 @@ def : InstRW<[N3Write_5c_1M0_1V], (instregex "^INDEX_(IR|RI|RR)_D$")>; // Logical def : InstRW<[N3Write_2c_1V], (instregex "^(AND|EOR|ORR)_ZI", - "^(AND|BIC|EOR|ORR)_ZZZ", + "^(AND|BIC|EOR)_ZZZ", "^EOR(BT|TB)_ZZZ_[BHSD]", "^(AND|BIC|EOR|NOT|ORR)_(ZPmZ|ZPZZ)_[BHSD]", "^NOT_ZPmZ_[BHSD]")>; +def : InstRW<[N3Write_0or2c_1V], (instrs ORR_ZZZ)>; // Max/min, basic and pairwise def : InstRW<[N3Write_2c_1V], (instregex "^[SU](MAX|MIN)_ZI_[BHSD]", @@ -1804,7 +1950,7 @@ def : InstRW<[N3Write_2c_1V], (instregex "^[SU](MAX|MIN)_ZI_[BHSD]", def : InstRW<[N3Write_2c_1V], (instregex "^N?MATCH_PPzZZ_[BH]$")>; // Matrix multiply-accumulate -def : InstRW<[N3Write_3c_1V], (instrs SMMLA_ZZZ, UMMLA_ZZZ, USMMLA_ZZZ)>; +def : InstRW<[N3Wr_ZMMA, N3Rd_ZMMA], (instrs SMMLA_ZZZ, UMMLA_ZZZ, USMMLA_ZZZ)>; // Move prefix def : InstRW<[N3Write_2c_1V], (instregex "^MOVPRFX_ZP[mz]Z_[BHSD]$", @@ -1827,20 +1973,22 @@ def : InstRW<[N3Write_4c_1V0], (instregex "^[SU]MULL[BT]_ZZZI_[SD]$", "^[SU]MULL[BT]_ZZZ_[HSD]$")>; // Multiply accumulate, B, H, S element size -def : InstRW<[N3Write_4c_1V0], (instregex "^ML[AS]_ZZZI_[BHS]$", - "^(ML[AS]|MAD|MSB)_(ZPmZZ|ZPZZZ)_[BHS]")>; +def : InstRW<[N3Wr_ZMABHS, ReadDefault, N3Rd_ZMABHS], + (instregex "^ML[AS]_ZZZI_[BHS]$", + "^(ML[AS]|MAD|MSB)_(ZPmZZ|ZPZZZ)_[BHS]")>; // Multiply accumulate, D element size -def : InstRW<[N3Write_5c_2V0], (instregex "^ML[AS]_ZZZI_D$", +def : InstRW<[N3Wr_ZMAD, ReadDefault, N3Rd_ZMAD], (instregex "^ML[AS]_ZZZI_D$", "^(ML[AS]|MAD|MSB)_(ZPmZZ|ZPZZZ)_D")>; // Multiply accumulate long -def : InstRW<[N3Write_4c_1V0], (instregex "^[SU]ML[AS]L[BT]_ZZZ_[HSD]$", +def : InstRW<[N3Wr_ZMAL, N3Rd_ZMAL], (instregex "^[SU]ML[AS]L[BT]_ZZZ_[HSD]$", "^[SU]ML[AS]L[BT]_ZZZI_[SD]$")>; // Multiply accumulate saturating doubling long regular -def : InstRW<[N3Write_4c_1V0], (instregex "^SQDML[AS](LB|LT|LBT)_ZZZ_[HSD]$", - "^SQDML[AS](LB|LT)_ZZZI_[SD]$")>; +def : InstRW<[N3Wr_ZMASQL, N3Rd_ZMASQ], + (instregex "^SQDML[AS](LB|LT|LBT)_ZZZ_[HSD]$", + "^SQDML[AS](LB|LT)_ZZZI_[SD]$")>; // Multiply saturating doubling high, B, H, S element size def : InstRW<[N3Write_4c_1V0], (instregex "^SQDMULH_ZZZ_[BHS]$", @@ -1854,13 +2002,13 @@ def : InstRW<[N3Write_4c_1V0], (instregex "^SQDMULL[BT]_ZZZ_[HSD]$", "^SQDMULL[BT]_ZZZI_[SD]$")>; // Multiply saturating rounding doubling regular/complex accumulate, B, H, S element size -def : InstRW<[N3Write_4c_1V0], (instregex "^SQRDML[AS]H_ZZZ_[BHS]$", +def : InstRW<[N3Wr_ZMASQBHS, N3Rd_ZMASQ], (instregex "^SQRDML[AS]H_ZZZ_[BHS]$", "^SQRDCMLAH_ZZZ_[BHS]$", "^SQRDML[AS]H_ZZZI_[HS]$", "^SQRDCMLAH_ZZZI_[HS]$")>; // Multiply saturating rounding doubling regular/complex accumulate, D element size -def : InstRW<[N3Write_5c_2V0], (instregex "^SQRDML[AS]H_ZZZI?_D$", +def : InstRW<[N3Wr_ZMASQD, N3Rd_ZMASQ], (instregex "^SQRDML[AS]H_ZZZI?_D$", "^SQRDCMLAH_ZZZ_D$")>; // Multiply saturating rounding doubling regular/complex, B, H, S element size @@ -1926,7 +2074,6 @@ def : InstRW<[N3Write_2c_1V], (instregex "^FAB[SD]_ZPmZ_[HSD]", // Floating point arithmetic def : InstRW<[N3Write_2c_1V], (instregex "^F(ADD|SUB)_(ZPm[IZ]|ZZZ)_[HSD]", "^F(ADD|SUB)_ZPZ[IZ]_[HSD]", - "^FADDP_ZPmZZ_[HSD]", "^FNEG_ZPmZ_[HSD]", "^FSUBR_ZPm[IZ]_[HSD]", 
"^FSUBR_(ZPZI|ZPZZ)_[HSD]")>; @@ -1949,8 +2096,9 @@ def : InstRW<[N3Write_2c_1V], (instregex "^FAC(GE|GT)_PPzZZ_[HSD]$", def : InstRW<[N3Write_3c_1V], (instregex "^FCADD_ZPmZ_[HSD]$")>; // Floating point complex multiply add -def : InstRW<[N3Write_4c_1V], (instregex "^FCMLA_ZPmZZ_[HSD]$", - "^FCMLA_ZZZI_[HS]$")>; +def : InstRW<[N3Wr_ZFCMA, ReadDefault, N3Rd_ZFCMA], + (instregex "^FCMLA_ZPmZZ_[HSD]")>; +def : InstRW<[N3Wr_ZFCMA, N3Rd_ZFCMA], (instregex "^FCMLA_ZZZI_[HS]")>; // Floating point convert, long or narrow (F16 to F32 or F32 to F16) def : InstRW<[N3Write_4c_2V0], (instregex "^FCVT_ZPmZ_(HtoS|StoH)", @@ -2001,7 +2149,8 @@ def : InstRW<[N3Write_10c_4V0], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_S")>; def : InstRW<[N3Write_13c_2V0], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_D")>; // Floating point arith, min/max pairwise -def : InstRW<[N3Write_3c_1V], (instregex "^F(MAX|MIN)(NM)?P_ZPmZZ_[HSD]")>; +def : InstRW<[N3Write_3c_1V], (instregex "^FADDP_ZPmZZ_[HSD]", + "^F(MAX|MIN)(NM)?P_ZPmZZ_[HSD]")>; // Floating point min/max def : InstRW<[N3Write_2c_1V], (instregex "^F(MAX|MIN)(NM)?_ZPm[IZ]_[HSD]", @@ -2014,12 +2163,15 @@ def : InstRW<[N3Write_3c_1V], (instregex "^(FSCALE|FMULX)_ZPmZ_[HSD]", "^FMUL_ZPZ[IZ]_[HSD]")>; // Floating point multiply accumulate -def : InstRW<[N3Write_4c_1V], (instregex "^F(N?M(AD|SB)|N?ML[AS])_ZPmZZ_[HSD]$", - "^FN?ML[AS]_ZPZZZ_[HSD]", - "^FML[AS]_ZZZI_[HSD]$")>; +def : InstRW<[N3Wr_ZFMA, ReadDefault, N3Rd_ZFMA], + (instregex "^FN?ML[AS]_ZPmZZ_[HSD]", + "^FN?(MAD|MSB)_ZPmZZ_[HSD]")>; +def : InstRW<[N3Wr_ZFMA, N3Rd_ZFMA], + (instregex "^FML[AS]_ZZZI_[HSD]", + "^FN?ML[AS]_ZPZZZ_[HSD]")>; // Floating point multiply add/sub accumulate long -def : InstRW<[N3Write_4c_1V], (instregex "^FML[AS]L[BT]_ZZZI?_SHH$")>; +def : InstRW<[N3Wr_ZFMAL, N3Rd_ZFMAL], (instregex "^FML[AS]L[BT]_ZZZI?_SHH$")>; // Floating point reciprocal estimate, F16 def : InstRW<[N3Write_6c_4V0], (instregex "^FR(ECP|SQRT)E_ZZ_H", "^FRECPX_ZPmZ_H")>; @@ -2079,13 +2231,13 @@ def : InstRW<[N3Write_3c_1V], (instregex "^FTS(MUL|SEL)_ZZZ_[HSD]$")>; def : InstRW<[N3Write_4c_2V0], (instrs BFCVT_ZPmZ, BFCVTNT_ZPmZ)>; // Dot product -def : InstRW<[N3Write_4c_1V], (instrs BFDOT_ZZI, BFDOT_ZZZ)>; +def : InstRW<[N3Wr_ZBFDOT, N3Rd_ZBFDOT], (instrs BFDOT_ZZI, BFDOT_ZZZ)>; // Matrix multiply accumulate -def : InstRW<[N3Write_5c_1V], (instrs BFMMLA_ZZZ_HtoS)>; +def : InstRW<[N3Wr_ZBFMMA, N3Rd_ZBFMMA], (instrs BFMMLA_ZZZ_HtoS)>; // Multiply accumulate long -def : InstRW<[N3Write_4c_1V], (instregex "^BFMLAL[BT]_ZZZ(I)?$")>; +def : InstRW<[N3Wr_ZBFMAL, N3Rd_ZBFMAL], (instregex "^BFMLAL[BT]_ZZZ(I)?$")>; // SVE Load instructions // ----------------------------------------------------------------------------- @@ -2130,8 +2282,8 @@ def : InstRW<[N3Write_7c_4L], (instregex "^LDNT1[BHW]_ZZR_S$", "^LDNT1S[BH]_ZZR_S$")>; // Non temporal gather load, vector + scalar 64-bit element size -def : InstRW<[N3Write_6c_2L], (instregex "^LDNT1S?[BHW]_ZZR_D$")>; -def : InstRW<[N3Write_6c_2L], (instrs LDNT1D_ZZR_D)>; +def : InstRW<[N3Write_6c_2L01_2V], (instregex "^LDNT1S?[BHW]_ZZR_D$")>; +def : InstRW<[N3Write_6c_2L01_2V], (instrs LDNT1D_ZZR_D)>; // Contiguous first faulting load, scalar + scalar def : InstRW<[N3Write_6c_1L], (instregex "^LDFF1[BHWD]$", @@ -2180,11 +2332,11 @@ def : InstRW<[N3Write_7c_4L], (instregex "^GLD(FF)?1S?[BH]_S_IMM$", "^GLD(FF)?1W_IMM$")>; // Gather load, vector + imm, 64-bit element size -def : InstRW<[N3Write_6c_2L], (instregex "^GLD(FF)?1S?[BHW]_D_IMM$", +def : InstRW<[N3Write_6c_2L01_2V], (instregex "^GLD(FF)?1S?[BHW]_D_IMM$", 
"^GLD(FF)?1D_IMM$")>; // Gather load, 64-bit element size -def : InstRW<[N3Write_6c_2L], +def : InstRW<[N3Write_6c_2L01_2V], (instregex "^GLD(FF)?1S?[BHW]_D_[SU]XTW(_SCALED)?$", "^GLD(FF)?1S?[BHW]_D(_SCALED)?$", "^GLD(FF)?1D_[SU]XTW(_SCALED)?$", diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td index 3cbfc59..ac5e889 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td @@ -94,6 +94,7 @@ def : WriteRes<WriteHint, []> { let Latency = 1; } let Latency = 0, NumMicroOps = 0 in def V1Write_0c_0Z : SchedWriteRes<[]>; +def V1Write_0c : SchedWriteRes<[]> { let Latency = 0; } //===----------------------------------------------------------------------===// // Define generic 1 micro-op types @@ -473,6 +474,17 @@ def V1Write_11c_9L01_9S_9V : SchedWriteRes<[V1UnitL01, V1UnitL01, V1UnitL01, V1UnitV, V1UnitV, V1UnitV]>; //===----------------------------------------------------------------------===// +// Define predicate-controlled types + +def V1Write_0or1c_1I : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V1Write_0c]>, + SchedVar<NoSchedPred, [V1Write_1c_1I]>]>; + +def V1Write_0or3c_1M0 : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V1Write_0c]>, + SchedVar<NoSchedPred, [V1Write_3c_1M0]>]>; + +//===----------------------------------------------------------------------===// // Define forwarded types // NOTE: SOG, p. 20, n. 2: Accumulator forwarding is not supported for @@ -580,9 +592,6 @@ def : SchedAlias<WriteBrReg, V1Write_1c_1B>; // Branch and link, register def : InstRW<[V1Write_1c_1B_1S], (instrs BL, BLR)>; -// Compare and branch -def : InstRW<[V1Write_1c_1B], (instregex "^[CT]BN?Z[XW]$")>; - // Arithmetic and Logical Instructions // ----------------------------------------------------------------------------- @@ -603,6 +612,7 @@ def : InstRW<[V1Write_1c_1I_1Flg], "^(ADC|SBC)S[WX]r$", "^ANDS[WX]ri$", "^(AND|BIC)S[WX]rr$")>; +def : InstRW<[V1Write_0or1c_1I], (instregex "^MOVZ[WX]i$")>; // ALU, extend and shift def : SchedAlias<WriteIEReg, V1Write_2c_1M>; @@ -623,7 +633,8 @@ def : InstRW<[V1WriteISRegS], (instregex "^(ADD|SUB)S(([WX]r[sx])|Xrx64)$")>; // Logical, shift, no flagset -def : InstRW<[V1Write_1c_1I], (instregex "^(AND|BIC|EON|EOR|ORN|ORR)[WX]rs$")>; +def : InstRW<[V1Write_1c_1I], (instregex "^(AND|BIC|EON|EOR|ORN)[WX]rs$")>; +def : InstRW<[V1Write_0or1c_1I], (instregex "^ORR[WX]rs$")>; // Logical, shift, flagset def : InstRW<[V1Write_2c_1M_1Flg], (instregex "^(AND|BIC)S[WX]rs$")>; diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV3.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV3.td new file mode 100644 index 0000000..e23576a --- /dev/null +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV3.td @@ -0,0 +1,2777 @@ +//=- AArch64SchedNeoverseV3.td - NeoverseV3 Scheduling Defs --*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the scheduling model for the Arm Neoverse V3 processors. 
+// All information is taken from the V3 Software Optimization guide: +// +// https://developer.arm.com/documentation/109678/300/?lang=en +// +//===----------------------------------------------------------------------===// + +def NeoverseV3Model : SchedMachineModel { + let IssueWidth = 10; // Expect best value to be slightly higher than V2 + let MicroOpBufferSize = 320; // Entries in micro-op re-order buffer. NOTE: Copied from Neoverse-V2 + let LoadLatency = 4; // Optimistic load latency. + let MispredictPenalty = 10; // Extra cycles for mispredicted branch. NOTE: Copied from N2. + let LoopMicroOpBufferSize = 16; // NOTE: Copied from Cortex-A57. + let CompleteModel = 1; + + list<Predicate> UnsupportedFeatures = !listconcat(SMEUnsupported.F, + [HasSVE2p1, HasSVEB16B16, + HasCPA, HasCSSC]); +} + +//===----------------------------------------------------------------------===// +// Define each kind of processor resource and number available on Neoverse V3. +// Instructions are first fetched and then decoded into internal macro-ops +// (MOPs). From there, the MOPs proceed through register renaming and dispatch +// stages. A MOP can be split into two micro-ops further down the pipeline +// after the decode stage. Once dispatched, micro-ops wait for their operands +// and issue out-of-order to one of twenty-one issue pipelines. Each issue +// pipeline can accept one micro-op per cycle. + +let SchedModel = NeoverseV3Model in { + +// Define the (21) issue ports. +def V3UnitB : ProcResource<3>; // Branch 0/1/2 +def V3UnitS0 : ProcResource<1>; // Integer single-cycle 0 +def V3UnitS1 : ProcResource<1>; // Integer single-cycle 1 +def V3UnitS2 : ProcResource<1>; // Integer single-cycle 2 +def V3UnitS3 : ProcResource<1>; // Integer single-cycle 3 +def V3UnitS4 : ProcResource<1>; // Integer single-cycle 4 +def V3UnitS5 : ProcResource<1>; // Integer single-cycle 5 +def V3UnitM0 : ProcResource<1>; // Integer single/multicycle 0 +def V3UnitM1 : ProcResource<1>; // Integer single/multicycle 1 +def V3UnitV0 : ProcResource<1>; // FP/ASIMD 0 +def V3UnitV1 : ProcResource<1>; // FP/ASIMD 1 +def V3UnitV2 : ProcResource<1>; // FP/ASIMD 2 +def V3UnitV3 : ProcResource<1>; // FP/ASIMD 3 +def V3UnitLS0 : ProcResource<1>; // Load/Store 0 +def V3UnitL12 : ProcResource<2>; // Load 1/2 +def V3UnitST1 : ProcResource<1>; // Store 1 +def V3UnitD : ProcResource<2>; // Store data 0/1 +def V3UnitFlg : ProcResource<4>; // Flags + +def V3UnitS : ProcResGroup<[V3UnitS0, V3UnitS1, V3UnitS2, V3UnitS3, V3UnitS4, V3UnitS5]>; // Integer single-cycle 0/1/2/3/4/5 +def V3UnitI : ProcResGroup<[V3UnitS0, V3UnitS1, V3UnitS2, V3UnitS3, V3UnitS4, V3UnitS5, V3UnitM0, V3UnitM1]>; // Integer single-cycle 0/1/2/3/4/5 and single/multicycle 0/1 +def V3UnitM : ProcResGroup<[V3UnitM0, V3UnitM1]>; // Integer single/multicycle 0/1 +def V3UnitLSA : ProcResGroup<[V3UnitLS0, V3UnitL12, V3UnitST1]>; // Supergroup of L+SA +def V3UnitL : ProcResGroup<[V3UnitLS0, V3UnitL12]>; // Load/Store 0 and Load 1/2 +def V3UnitSA : ProcResGroup<[V3UnitLS0, V3UnitST1]>; // Load/Store 0 and Store 1 +def V3UnitV : ProcResGroup<[V3UnitV0, V3UnitV1, V3UnitV2, V3UnitV3]>; // FP/ASIMD 0/1/2/3 +def V3UnitV01 : ProcResGroup<[V3UnitV0, V3UnitV1]>; // FP/ASIMD 0/1 +def V3UnitV02 : ProcResGroup<[V3UnitV0, V3UnitV2]>; // FP/ASIMD 0/2 +def V3UnitV13 : ProcResGroup<[V3UnitV1, V3UnitV3]>; // FP/ASIMD 1/3 + +// Define commonly used read types. + +// No forwarding is provided for these types. 
+def : ReadAdvance<ReadI, 0>; +def : ReadAdvance<ReadISReg, 0>; +def : ReadAdvance<ReadIEReg, 0>; +def : ReadAdvance<ReadIM, 0>; +def : ReadAdvance<ReadIMA, 0>; +def : ReadAdvance<ReadID, 0>; +def : ReadAdvance<ReadExtrHi, 0>; +def : ReadAdvance<ReadAdrBase, 0>; +def : ReadAdvance<ReadST, 0>; +def : ReadAdvance<ReadVLD, 0>; + +// NOTE: Copied from N2. +def : WriteRes<WriteAtomic, []> { let Unsupported = 1; } +def : WriteRes<WriteBarrier, []> { let Latency = 1; } +def : WriteRes<WriteHint, []> { let Latency = 1; } +def : WriteRes<WriteLDHi, []> { let Latency = 4; } + +//===----------------------------------------------------------------------===// +// Define customized scheduler read/write types specific to the Neoverse V3. + +//===----------------------------------------------------------------------===// + +// Define generic 0 micro-op types +def V3Write_0c : SchedWriteRes<[]> { let Latency = 0; } + +// Define generic 1 micro-op types + +def V3Write_1c_1B : SchedWriteRes<[V3UnitB]> { let Latency = 1; } +def V3Write_1c_1F_1Flg : SchedWriteRes<[V3UnitI, V3UnitFlg]> { let Latency = 1; } +def V3Write_1c_1I : SchedWriteRes<[V3UnitI]> { let Latency = 1; } +def V3Write_1c_1M : SchedWriteRes<[V3UnitM]> { let Latency = 1; } +def V3Write_1c_1SA : SchedWriteRes<[V3UnitSA]> { let Latency = 1; } +def V3Write_2c_1M : SchedWriteRes<[V3UnitM]> { let Latency = 2; } +def V3Write_2c_1M_1Flg : SchedWriteRes<[V3UnitM, V3UnitFlg]> { let Latency = 2; } +def V3Write_3c_1M : SchedWriteRes<[V3UnitM]> { let Latency = 3; } +def V3Write_2c_1M0 : SchedWriteRes<[V3UnitM0]> { let Latency = 2; } +def V3Write_3c_1M0 : SchedWriteRes<[V3UnitM0]> { let Latency = 3; } +def V3Write_4c_1M0 : SchedWriteRes<[V3UnitM0]> { let Latency = 4; } +def V3Write_12c_1M0 : SchedWriteRes<[V3UnitM0]> { let Latency = 12; + let ReleaseAtCycles = [12]; } +def V3Write_20c_1M0 : SchedWriteRes<[V3UnitM0]> { let Latency = 20; + let ReleaseAtCycles = [20]; } +def V3Write_4c_1L : SchedWriteRes<[V3UnitL]> { let Latency = 4; } +def V3Write_6c_1L : SchedWriteRes<[V3UnitL]> { let Latency = 6; } +def V3Write_2c_1V : SchedWriteRes<[V3UnitV]> { let Latency = 2; } +def V3Write_2c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 2; } +def V3Write_3c_1V : SchedWriteRes<[V3UnitV]> { let Latency = 3; } +def V3Write_3c_1V01 : SchedWriteRes<[V3UnitV01]> { let Latency = 3; + let ReleaseAtCycles = [2]; } +def V3Write_4c_1V : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Write_5c_1V : SchedWriteRes<[V3UnitV]> { let Latency = 5; } +def V3Write_6c_1V : SchedWriteRes<[V3UnitV]> { let Latency = 6; } +def V3Write_12c_1V : SchedWriteRes<[V3UnitV]> { let Latency = 12; } +def V3Write_3c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 3; } +def V3Write_3c_1V02 : SchedWriteRes<[V3UnitV02]> { let Latency = 3; } +def V3Write_4c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 4; } +def V3Write_4c_1V02 : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Write_9c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 9; } +def V3Write_10c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 10; } +def V3Write_8c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 8; } +def V3Write_12c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 12; + let ReleaseAtCycles = [11]; } +def V3Write_13c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 13; } +def V3Write_15c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 15; } +def V3Write_13c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 13; } +def V3Write_16c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 16; } +def V3Write_16c_1V02 : 
SchedWriteRes<[V3UnitV02]> { let Latency = 16; + let ReleaseAtCycles = [8]; } +def V3Write_20c_1V0 : SchedWriteRes<[V3UnitV0]> { let Latency = 20; + let ReleaseAtCycles = [20]; } +def V3Write_2c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 2; } +def V3Write_2c_1V13 : SchedWriteRes<[V3UnitV13]> { let Latency = 2; } +def V3Write_3c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 3; } +def V3Write_3c_1V13 : SchedWriteRes<[V3UnitV13]> { let Latency = 3; } +def V3Write_4c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 4; } +def V3Write_6c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 6; } +def V3Write_10c_1V1 : SchedWriteRes<[V3UnitV1]> { let Latency = 10; } +def V3Write_6c_1SA : SchedWriteRes<[V3UnitSA]> { let Latency = 6; } + +//===----------------------------------------------------------------------===// +// Define generic 2 micro-op types + +def V3Write_1c_1B_1S : SchedWriteRes<[V3UnitB, V3UnitS]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3Write_6c_1M0_1B : SchedWriteRes<[V3UnitM0, V3UnitB]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_9c_1M0_1L : SchedWriteRes<[V3UnitM0, V3UnitL]> { + let Latency = 9; + let NumMicroOps = 2; +} + +def V3Write_3c_1I_1M : SchedWriteRes<[V3UnitI, V3UnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3Write_1c_2M : SchedWriteRes<[V3UnitM, V3UnitM]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3Write_3c_2M : SchedWriteRes<[V3UnitM, V3UnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3Write_4c_2M : SchedWriteRes<[V3UnitM, V3UnitM]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3Write_5c_1L_1I : SchedWriteRes<[V3UnitL, V3UnitI]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3Write_6c_1I_1L : SchedWriteRes<[V3UnitI, V3UnitL]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_7c_1I_1L : SchedWriteRes<[V3UnitI, V3UnitL]> { + let Latency = 7; + let NumMicroOps = 2; +} + +def V3Write_1c_1SA_1D : SchedWriteRes<[V3UnitSA, V3UnitD]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3Write_5c_1M0_1V : SchedWriteRes<[V3UnitM0, V3UnitV]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3Write_2c_1SA_1V01 : SchedWriteRes<[V3UnitSA, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3Write_2c_2V01 : SchedWriteRes<[V3UnitV01, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3Write_4c_1SA_1V01 : SchedWriteRes<[V3UnitSA, V3UnitV01]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3Write_5c_1V13_1V : SchedWriteRes<[V3UnitV13, V3UnitV]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3Write_4c_2V0 : SchedWriteRes<[V3UnitV0, V3UnitV0]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3Write_4c_2V02 : SchedWriteRes<[V3UnitV02, V3UnitV02]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3Write_4c_2V : SchedWriteRes<[V3UnitV, V3UnitV]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3Write_6c_2V : SchedWriteRes<[V3UnitV, V3UnitV]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_6c_2L : SchedWriteRes<[V3UnitL, V3UnitL]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_8c_1L_1V : SchedWriteRes<[V3UnitL, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 2; +} + +def V3Write_4c_1SA_1V : SchedWriteRes<[V3UnitSA, V3UnitV]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3Write_3c_1M0_1M : SchedWriteRes<[V3UnitM0, V3UnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3Write_4c_1M0_1M : SchedWriteRes<[V3UnitM0, V3UnitM]> { + let Latency = 4; + let NumMicroOps = 2; +} + 
+def V3Write_1c_1M0_1M : SchedWriteRes<[V3UnitM0, V3UnitM]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3Write_2c_1M0_1M : SchedWriteRes<[V3UnitM0, V3UnitM]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3Write_6c_2V1 : SchedWriteRes<[V3UnitV1, V3UnitV1]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_5c_2V0 : SchedWriteRes<[V3UnitV0, V3UnitV0]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3Write_5c_2V02 : SchedWriteRes<[V3UnitV02, V3UnitV02]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3Write_5c_1V1_1M0 : SchedWriteRes<[V3UnitV1, V3UnitM0]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3Write_6c_1V1_1M0 : SchedWriteRes<[V3UnitV1, V3UnitM0]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_7c_1M0_1V02 : SchedWriteRes<[V3UnitM0, V3UnitV02]> { + let Latency = 7; + let NumMicroOps = 2; +} + +def V3Write_2c_1V0_1M : SchedWriteRes<[V3UnitV0, V3UnitM]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3Write_3c_1V0_1M : SchedWriteRes<[V3UnitV0, V3UnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3Write_6c_1V_1V13 : SchedWriteRes<[V3UnitV, V3UnitV13]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_6c_1L_1M : SchedWriteRes<[V3UnitL, V3UnitM]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_6c_1L_1I : SchedWriteRes<[V3UnitL, V3UnitI]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_6c_2V13 : SchedWriteRes<[V3UnitV13, V3UnitV13]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3Write_8c_1M0_1V01 : SchedWriteRes<[V3UnitM0, V3UnitV01]> { + let Latency = 8; + let NumMicroOps = 2; +} + +//===----------------------------------------------------------------------===// +// Define generic 3 micro-op types + +def V3Write_1c_1SA_1D_1I : SchedWriteRes<[V3UnitSA, V3UnitD, V3UnitI]> { + let Latency = 1; + let NumMicroOps = 3; +} + +def V3Write_2c_1SA_1V01_1I : SchedWriteRes<[V3UnitSA, V3UnitV01, V3UnitI]> { + let Latency = 2; + let NumMicroOps = 3; +} + +def V3Write_2c_1SA_2V01 : SchedWriteRes<[V3UnitSA, V3UnitV01, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 3; +} + +def V3Write_4c_1SA_2V01 : SchedWriteRes<[V3UnitSA, V3UnitV01, V3UnitV01]> { + let Latency = 4; + let NumMicroOps = 3; +} + +def V3Write_9c_1L_2V : SchedWriteRes<[V3UnitL, V3UnitV, V3UnitV]> { + let Latency = 9; + let NumMicroOps = 3; +} + +def V3Write_4c_3V : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 4; + let NumMicroOps = 3; +} + +def V3Write_7c_1M_1M0_1V : SchedWriteRes<[V3UnitM, V3UnitM0, V3UnitV]> { + let Latency = 7; + let NumMicroOps = 3; +} + +def V3Write_2c_1SA_1I_1V01 : SchedWriteRes<[V3UnitSA, V3UnitI, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 3; +} + +def V3Write_6c_3L : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL]> { + let Latency = 6; + let NumMicroOps = 3; +} + +def V3Write_6c_3V : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 6; + let NumMicroOps = 3; +} + +def V3Write_8c_1L_2V : SchedWriteRes<[V3UnitL, V3UnitV, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 3; +} + +//===----------------------------------------------------------------------===// +// Define generic 4 micro-op types + +def V3Write_2c_1SA_2V01_1I : SchedWriteRes<[V3UnitSA, V3UnitV01, V3UnitV01, + V3UnitI]> { + let Latency = 2; + let NumMicroOps = 4; +} + +def V3Write_2c_2SA_2V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 4; +} + +def V3Write_4c_2SA_2V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01]> 
{ + let Latency = 4; + let NumMicroOps = 4; +} + +def V3Write_5c_1I_3L : SchedWriteRes<[V3UnitI, V3UnitL, V3UnitL, V3UnitL]> { + let Latency = 5; + let NumMicroOps = 4; +} + +def V3Write_6c_4V0 : SchedWriteRes<[V3UnitV0, V3UnitV0, V3UnitV0, V3UnitV0]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3Write_8c_4V : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3Write_6c_2V_2V13 : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV13, + V3UnitV13]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3Write_8c_2V_2V13 : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV13, + V3UnitV13]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3Write_6c_4V02 : SchedWriteRes<[V3UnitV02, V3UnitV02, V3UnitV02, + V3UnitV02]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3Write_6c_4V : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3Write_8c_2L_2V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitV, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3Write_9c_2L_2V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitV, V3UnitV]> { + let Latency = 9; + let NumMicroOps = 4; +} + +def V3Write_2c_2SA_2V : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitV, + V3UnitV]> { + let Latency = 2; + let NumMicroOps = 4; +} + +def V3Write_4c_2SA_2V : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitV, + V3UnitV]> { + let Latency = 4; + let NumMicroOps = 4; +} + +def V3Write_8c_2M0_2V02 : SchedWriteRes<[V3UnitM0, V3UnitM0, V3UnitV02, + V3UnitV02]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3Write_8c_2V_2V1 : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV1, + V3UnitV1]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3Write_4c_2M0_2M : SchedWriteRes<[V3UnitM0, V3UnitM0, V3UnitM, + V3UnitM]> { + let Latency = 4; + let NumMicroOps = 4; +} + +def V3Write_5c_2M0_2M : SchedWriteRes<[V3UnitM0, V3UnitM0, V3UnitM, + V3UnitM]> { + let Latency = 5; + let NumMicroOps = 4; +} + +def V3Write_6c_2I_2L : SchedWriteRes<[V3UnitI, V3UnitI, V3UnitL, V3UnitL]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3Write_7c_4L : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, V3UnitL]> { + let Latency = 7; + let NumMicroOps = 4; +} + +def V3Write_6c_1SA_3V01 : SchedWriteRes<[V3UnitSA, V3UnitV01, V3UnitV01, + V3UnitV01]> { + let Latency = 6; + let NumMicroOps = 4; +} + +//===----------------------------------------------------------------------===// +// Define generic 5 micro-op types + +def V3Write_2c_1SA_2V01_2I : SchedWriteRes<[V3UnitSA, V3UnitV01, V3UnitV01, + V3UnitI, V3UnitI]> { + let Latency = 2; + let NumMicroOps = 5; +} + +def V3Write_8c_2L_3V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitV, V3UnitV, + V3UnitV]> { + let Latency = 8; + let NumMicroOps = 5; +} + +def V3Write_9c_1L_4V : SchedWriteRes<[V3UnitL, V3UnitV, V3UnitV, V3UnitV, + V3UnitV]> { + let Latency = 9; + let NumMicroOps = 5; +} + +def V3Write_10c_1L_4V : SchedWriteRes<[V3UnitL, V3UnitV, V3UnitV, V3UnitV, + V3UnitV]> { + let Latency = 10; + let NumMicroOps = 5; +} + +def V3Write_6c_5V : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV, V3UnitV, + V3UnitV]> { + let Latency = 6; + let NumMicroOps = 5; +} + +//===----------------------------------------------------------------------===// +// Define generic 6 micro-op types + +def V3Write_8c_3L_3V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, + V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 6; +} + +def V3Write_9c_3L_3V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, + V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 
9; + let NumMicroOps = 6; +} + +def V3Write_9c_2L_4V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitV, + V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3Write_9c_2L_2V_2I : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitV, + V3UnitV, V3UnitI, V3UnitI]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3Write_9c_2V_4V13 : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV13, + V3UnitV13, V3UnitV13, V3UnitV13]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3Write_2c_3SA_3V : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 2; + let NumMicroOps = 6; +} + +def V3Write_4c_2SA_4V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 4; + let NumMicroOps = 6; +} + +def V3Write_5c_2SA_4V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 5; + let NumMicroOps = 6; +} + +def V3Write_2c_3SA_3V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 6; +} + +def V3Write_4c_2SA_2I_2V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitI, + V3UnitI, V3UnitV01, V3UnitV01]> { + let Latency = 4; + let NumMicroOps = 6; +} + +//===----------------------------------------------------------------------===// +// Define generic 7 micro-op types + +def V3Write_8c_3L_4V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, + V3UnitV, V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 7; +} + +//===----------------------------------------------------------------------===// +// Define generic 8 micro-op types + +def V3Write_2c_4SA_4V : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitV, V3UnitV, V3UnitV, + V3UnitV]> { + let Latency = 2; + let NumMicroOps = 8; +} + +def V3Write_2c_4SA_4V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01]> { + let Latency = 2; + let NumMicroOps = 8; +} + +def V3Write_6c_2SA_6V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01]> { + let Latency = 6; + let NumMicroOps = 8; +} + +def V3Write_8c_4L_4V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, V3UnitL, + V3UnitV, V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 8; + let NumMicroOps = 8; +} + +//===----------------------------------------------------------------------===// +// Define generic 9 micro-op types + +def V3Write_6c_3SA_6V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 6; + let NumMicroOps = 9; +} + +def V3Write_10c_1L_8V : SchedWriteRes<[V3UnitL, V3UnitV, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV, V3UnitV, + V3UnitV]> { + let Latency = 10; + let NumMicroOps = 9; +} + +def V3Write_10c_3V_3L_3I : SchedWriteRes<[V3UnitV, V3UnitV, V3UnitV, + V3UnitL, V3UnitL, V3UnitL, + V3UnitI, V3UnitI, V3UnitI]> { + let Latency = 10; + let NumMicroOps = 9; +} + +//===----------------------------------------------------------------------===// +// Define generic 10 micro-op types + +def V3Write_9c_6L_4V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, V3UnitL, + V3UnitL, V3UnitL, V3UnitV, V3UnitV, + V3UnitV, V3UnitV]> { + let Latency = 9; + let NumMicroOps = 10; +} + +//===----------------------------------------------------------------------===// +// Define generic 12 micro-op types + +def V3Write_5c_4SA_8V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitV01, V3UnitV01, + 
V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 5; + let NumMicroOps = 12; +} + +def V3Write_9c_4L_8V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, + V3UnitL, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 9; + let NumMicroOps = 12; +} + +def V3Write_10c_4L_8V : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, + V3UnitL, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV]> { + let Latency = 10; + let NumMicroOps = 12; +} + +def V3Write_4c_6SA_6V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 4; + let NumMicroOps = 12; +} + +//===----------------------------------------------------------------------===// +// Define generic 16 micro-op types + +def V3Write_7c_4SA_12V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01]> { + let Latency = 7; + let NumMicroOps = 16; +} + +def V3Write_10c_4L_8V_4I : SchedWriteRes<[V3UnitL, V3UnitL, V3UnitL, + V3UnitL, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV, + V3UnitV, V3UnitV, V3UnitV, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI]> { + let Latency = 10; + let NumMicroOps = 16; +} + +//===----------------------------------------------------------------------===// +// Define generic 18 micro-op types + +def V3Write_7c_9SA_9V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01]> { + let Latency = 7; + let NumMicroOps = 18; +} + +//===----------------------------------------------------------------------===// +// Define generic 27 micro-op types + +def V3Write_7c_9SA_9I_9V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01]> { + let Latency = 7; + let NumMicroOps = 27; +} + +//===----------------------------------------------------------------------===// +// Define generic 36 micro-op types + +def V3Write_11c_18SA_18V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, V3UnitSA, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01]> { + let Latency = 11; + let NumMicroOps = 36; +} + +//===----------------------------------------------------------------------===// +// Define generic 54 micro-op types + +def V3Write_11c_18SA_18I_18V01 : SchedWriteRes<[V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitSA, V3UnitSA, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitI, V3UnitI, V3UnitI, + V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + 
V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01, + V3UnitV01, V3UnitV01]> { + let Latency = 11; + let NumMicroOps = 54; +} + +//===----------------------------------------------------------------------===// +// Define predicate-controlled types + +def V3Write_ArithI : SchedWriteVariant<[ + SchedVar<IsCheapLSL, [V3Write_1c_1I]>, + SchedVar<NoSchedPred, [V3Write_2c_1M]>]>; + +def V3Write_ArithF : SchedWriteVariant<[ + SchedVar<IsCheapLSL, [V3Write_1c_1F_1Flg]>, + SchedVar<NoSchedPred, [V3Write_2c_1M_1Flg]>]>; + +def V3Write_Logical : SchedWriteVariant<[ + SchedVar<NeoverseNoLSL, [V3Write_1c_1F_1Flg]>, + SchedVar<NoSchedPred, [V3Write_2c_1M_1Flg]>]>; + +def V3Write_Extr : SchedWriteVariant<[ + SchedVar<IsRORImmIdiomPred, [V3Write_1c_1I]>, + SchedVar<NoSchedPred, [V3Write_3c_1I_1M]>]>; + +def V3Write_LdrHQ : SchedWriteVariant<[ + SchedVar<NeoverseHQForm, [V3Write_7c_1I_1L]>, + SchedVar<NoSchedPred, [V3Write_6c_1L]>]>; + +def V3Write_StrHQ : SchedWriteVariant<[ + SchedVar<NeoverseHQForm, [V3Write_2c_1SA_1V01_1I]>, + SchedVar<NoSchedPred, [V3Write_2c_1SA_1V01]>]>; + +def V3Write_0or1c_1I : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V3Write_0c]>, + SchedVar<NoSchedPred, [V3Write_1c_1I]>]>; + +def V3Write_0or2c_1V : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V3Write_0c]>, + SchedVar<NoSchedPred, [V3Write_2c_1V]>]>; + +def V3Write_0or3c_1M0 : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V3Write_0c]>, + SchedVar<NoSchedPred, [V3Write_3c_1M0]>]>; + +def V3Write_2or3c_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3Write_3c_1M]>, + SchedVar<NoSchedPred, [V3Write_2c_1M]>]>; + +def V3Write_1or2c_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3Write_2c_1M]>, + SchedVar<NoSchedPred, [V3Write_1c_1M]>]>; + +def V3Write_3or4c_1M0_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3Write_4c_1M0_1M]>, + SchedVar<NoSchedPred, [V3Write_3c_1M0_1M]>]>; + +def V3Write_2or3c_1V0 : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3Write_3c_1V0]>, + SchedVar<NoSchedPred, [V3Write_2c_1V0]>]>; + +def V3Write_2or3c_1V0_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3Write_3c_1V0_1M]>, + SchedVar<NoSchedPred, [V3Write_2c_1V0_1M]>]>; + +def V3Write_IncDec : SchedWriteVariant<[ + SchedVar<NeoverseCheapIncDec, [V3Write_1c_1I]>, + SchedVar<NoSchedPred, [V3Write_2c_1M]>]>; + +//===----------------------------------------------------------------------===// +// Define forwarded types + +// NOTE: SOG, p. 16, n. 2: Accumulator forwarding is not supported for +// consumers of 64 bit multiply high operations? 
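A quick orientation note on the forwarding pairs that follow: each V3Wr_* SchedWriteRes gives the producer's result latency, and the matching V3Rd_* SchedReadAdvance lets a consumer that reads through it pick the result up early, so a dependent accumulate chain effectively runs at Latency minus Advance cycles per link. A minimal sketch of the idiom, using hypothetical ExampleWr_MAC/ExampleRd_MAC names and the V3Unit* resources defined earlier in this file; illustrative only, not part of the patch:

def ExampleWr_MAC : SchedWriteRes<[V3UnitV]> { let Latency = 4; } // result ready after 4 cycles
def ExampleRd_MAC : SchedReadAdvance<2, [ExampleWr_MAC]>;         // a read of that result may issue 2 cycles early
// Attached to instructions via InstRW, a chain of such accumulates behaves
// as if each link took 4 - 2 = 2 cycles:
// def : InstRW<[ExampleWr_MAC, ExampleRd_MAC], (instregex "...")>;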
+def V3Wr_IM : SchedWriteRes<[V3UnitM]> { let Latency = 2; } + +def V3Wr_FMA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_FMA : SchedReadAdvance<2, [WriteFMul, V3Wr_FMA]>; + +def V3Wr_VA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_VA : SchedReadAdvance<3, [V3Wr_VA]>; + +def V3Wr_VDOT : SchedWriteRes<[V3UnitV]> { let Latency = 3; } +def V3Rd_VDOT : SchedReadAdvance<2, [V3Wr_VDOT]>; + +def V3Wr_VMMA : SchedWriteRes<[V3UnitV]> { let Latency = 3; } +def V3Rd_VMMA : SchedReadAdvance<2, [V3Wr_VMMA]>; + +def V3Wr_VMA : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Rd_VMA : SchedReadAdvance<3, [V3Wr_VMA]>; + +def V3Wr_VMAH : SchedWriteRes<[V3UnitV02, V3UnitV02]> { let Latency = 4; } +def V3Rd_VMAH : SchedReadAdvance<2, [V3Wr_VMAH]>; + +def V3Wr_VMAL : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Rd_VMAL : SchedReadAdvance<3, [V3Wr_VMAL]>; + +def V3Wr_VPA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_VPA : SchedReadAdvance<3, [V3Wr_VPA]>; + +def V3Wr_VSA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_VSA : SchedReadAdvance<3, [V3Wr_VSA]>; + +def V3Wr_VFCMA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_VFCMA : SchedReadAdvance<2, [V3Wr_VFCMA]>; + +def V3Wr_VFM : SchedWriteRes<[V3UnitV]> { let Latency = 3; } +def V3Wr_VFMA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_VFMA : SchedReadAdvance<2, [V3Wr_VFM, V3Wr_VFMA]>; + +def V3Wr_VFMAL : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_VFMAL : SchedReadAdvance<2, [V3Wr_VFMAL]>; + +def V3Wr_VBFDOT : SchedWriteRes<[V3UnitV]> { let Latency = 5; } +def V3Rd_VBFDOT : SchedReadAdvance<2, [V3Wr_VBFDOT]>; +def V3Wr_VBFMMA : SchedWriteRes<[V3UnitV]> { let Latency = 6; } +def V3Rd_VBFMMA : SchedReadAdvance<2, [V3Wr_VBFMMA]>; +def V3Wr_VBFMAL : SchedWriteRes<[V3UnitV]> { let Latency = 5; } +def V3Rd_VBFMAL : SchedReadAdvance<3, [V3Wr_VBFMAL]>; + +def V3Wr_CRC : SchedWriteRes<[V3UnitM0]> { let Latency = 2; } +def V3Rd_CRC : SchedReadAdvance<1, [V3Wr_CRC]>; + +def V3Wr_ZA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_ZA : SchedReadAdvance<3, [V3Wr_ZA]>; +def V3Wr_ZPA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_ZPA : SchedReadAdvance<3, [V3Wr_ZPA]>; +def V3Wr_ZSA : SchedWriteRes<[V3UnitV13]> { let Latency = 4; } +def V3Rd_ZSA : SchedReadAdvance<3, [V3Wr_ZSA]>; + +def V3Wr_ZDOTB : SchedWriteRes<[V3UnitV]> { let Latency = 3; } +def V3Rd_ZDOTB : SchedReadAdvance<2, [V3Wr_ZDOTB]>; +def V3Wr_ZDOTH : SchedWriteRes<[V3UnitV02]> { let Latency = 3; } +def V3Rd_ZDOTH : SchedReadAdvance<2, [V3Wr_ZDOTH]>; + +// NOTE: SOG p. 43: Complex multiply-add B, H, S element size: How to reduce +// throughput to 1 in case of forwarding? 
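On the note above: the machine model cannot make issue rate conditional on whether the accumulator is actually forwarded, but the knob LLVM does offer for capping throughput is ReleaseAtCycles on a SchedWriteRes, the same mechanism the *_rc write types further down use for the divide/square-root and associative-add entries. A purely hypothetical sketch, not what this patch does, assuming V3UnitV02 groups two pipelines so that holding a slot for two cycles sustains one such op per cycle:

// Hypothetical, for illustration only: occupying a V02 slot for 2 cycles
// caps the group at 1 op/cycle, independent of the 4-cycle result latency.
def ExampleWr_ZCMA_SlowIssue : SchedWriteRes<[V3UnitV02]> {
  let Latency = 4;
  let ReleaseAtCycles = [2];
}
// Such a cap would apply whether or not forwarding occurs, which is
// presumably why the note above leaves the question open.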
+def V3Wr_ZCMABHS : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Rd_ZCMABHS : SchedReadAdvance<3, [V3Wr_ZCMABHS]>; +def V3Wr_ZCMAD : SchedWriteRes<[V3UnitV02, V3UnitV02]> { let Latency = 5; } +def V3Rd_ZCMAD : SchedReadAdvance<2, [V3Wr_ZCMAD]>; + +def V3Wr_ZMMA : SchedWriteRes<[V3UnitV]> { let Latency = 3; } +def V3Rd_ZMMA : SchedReadAdvance<2, [V3Wr_ZMMA]>; + +def V3Wr_ZMABHS : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Rd_ZMABHS : SchedReadAdvance<3, [V3Wr_ZMABHS]>; +def V3Wr_ZMAD : SchedWriteRes<[V3UnitV02, V3UnitV02]> { let Latency = 5; } +def V3Rd_ZMAD : SchedReadAdvance<2, [V3Wr_ZMAD]>; + +def V3Wr_ZMAL : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Rd_ZMAL : SchedReadAdvance<3, [V3Wr_ZMAL]>; + +def V3Wr_ZMASQL : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Wr_ZMASQBHS : SchedWriteRes<[V3UnitV02]> { let Latency = 4; } +def V3Wr_ZMASQD : SchedWriteRes<[V3UnitV02, V3UnitV02]> { let Latency = 5; } +def V3Rd_ZMASQ : SchedReadAdvance<2, [V3Wr_ZMASQL, V3Wr_ZMASQBHS, + V3Wr_ZMASQD]>; + +def V3Wr_ZFCMA : SchedWriteRes<[V3UnitV]> { let Latency = 5; } +def V3Rd_ZFCMA : SchedReadAdvance<3, [V3Wr_ZFCMA]>; + +def V3Wr_ZFMA : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_ZFMA : SchedReadAdvance<2, [V3Wr_ZFMA]>; + +def V3Wr_ZFMAL : SchedWriteRes<[V3UnitV]> { let Latency = 4; } +def V3Rd_ZFMAL : SchedReadAdvance<2, [V3Wr_ZFMAL]>; + +def V3Wr_ZBFDOT : SchedWriteRes<[V3UnitV]> { let Latency = 5; } +def V3Rd_ZBFDOT : SchedReadAdvance<2, [V3Wr_ZBFDOT]>; +def V3Wr_ZBFMMA : SchedWriteRes<[V3UnitV]> { let Latency = 6; } +def V3Rd_ZBFMMA : SchedReadAdvance<2, [V3Wr_ZBFMMA]>; +def V3Wr_ZBFMAL : SchedWriteRes<[V3UnitV]> { let Latency = 5; } +def V3Rd_ZBFMAL : SchedReadAdvance<3, [V3Wr_ZBFMAL]>; + +//===----------------------------------------------------------------------===// +// Define types with long resource cycles (rc) + +def V3Write_6c_1V1_5rc : SchedWriteRes<[V3UnitV1]> { let Latency = 6; let ReleaseAtCycles = [ 5]; } +def V3Write_9c_1V1_2rc : SchedWriteRes<[V3UnitV1]> { let Latency = 9; let ReleaseAtCycles = [ 2]; } +def V3Write_9c_1V1_4rc : SchedWriteRes<[V3UnitV1]> { let Latency = 9; let ReleaseAtCycles = [ 4]; } +def V3Write_10c_1V1_9rc : SchedWriteRes<[V3UnitV1]> { let Latency = 10; let ReleaseAtCycles = [ 9]; } +def V3Write_11c_1V1_4rc : SchedWriteRes<[V3UnitV1]> { let Latency = 11; let ReleaseAtCycles = [ 4]; } +def V3Write_13c_1V1_8rc : SchedWriteRes<[V3UnitV1]> { let Latency = 13; let ReleaseAtCycles = [8]; } +def V3Write_14c_1V1_2rc : SchedWriteRes<[V3UnitV1]> { let Latency = 14; let ReleaseAtCycles = [2]; } + +// Miscellaneous +// ----------------------------------------------------------------------------- + +def : InstRW<[WriteI], (instrs COPY)>; + +// §3.3 Branch instructions +// ----------------------------------------------------------------------------- + +// Branch, immed +// Compare and branch +def : SchedAlias<WriteBr, V3Write_1c_1B>; + +// Branch, register +def : SchedAlias<WriteBrReg, V3Write_1c_1B>; + +// Branch and link, immed +// Branch and link, register +def : InstRW<[V3Write_1c_1B_1S], (instrs BL, BLR)>; + +// §3.4 Arithmetic and Logical Instructions +// ----------------------------------------------------------------------------- + +// ALU, basic +def : SchedAlias<WriteI, V3Write_1c_1I>; + +// ALU, basic, flagset +def : InstRW<[V3Write_1c_1F_1Flg], + (instregex "^(ADD|SUB)S[WX]r[ir]$", + "^(ADC|SBC)S[WX]r$", + "^ANDS[WX]ri$", + "^(AND|BIC)S[WX]rr$")>; +def : InstRW<[V3Write_0or1c_1I], (instregex "^MOVZ[WX]i$")>; + +// ALU, 
extend and shift +def : SchedAlias<WriteIEReg, V3Write_2c_1M>; + +// Arithmetic, LSL shift, shift <= 4 +// Arithmetic, flagset, LSL shift, shift <= 4 +// Arithmetic, LSR/ASR/ROR shift or LSL shift > 4 +def : SchedAlias<WriteISReg, V3Write_ArithI>; +def : InstRW<[V3Write_ArithF], + (instregex "^(ADD|SUB)S[WX]rs$")>; + +// Arithmetic, immediate to logical address tag +def : InstRW<[V3Write_2c_1M], (instrs ADDG, SUBG)>; + +// Conditional compare +def : InstRW<[V3Write_1c_1F_1Flg], (instregex "^CCM[NP][WX][ir]")>; + +// Convert floating-point condition flags +// Flag manipulation instructions +def : WriteRes<WriteSys, []> { let Latency = 1; } + +// Insert Random Tags +def : InstRW<[V3Write_2c_1M], (instrs IRG, IRGstack)>; + +// Insert Tag Mask +// Subtract Pointer +def : InstRW<[V3Write_1c_1I], (instrs GMI, SUBP)>; + +// Subtract Pointer, flagset +def : InstRW<[V3Write_1c_1F_1Flg], (instrs SUBPS)>; + +// Logical, shift, no flagset +def : InstRW<[V3Write_1c_1I], (instregex "^(AND|BIC|EON|EOR|ORN)[WX]rs$")>; +def : InstRW<[V3Write_0or1c_1I], (instregex "^ORR[WX]rs$")>; + +// Logical, shift, flagset +def : InstRW<[V3Write_Logical], (instregex "^(AND|BIC)S[WX]rs$")>; + +// Move and shift instructions +// ----------------------------------------------------------------------------- + +def : SchedAlias<WriteImm, V3Write_1c_1I>; + +// §3.5 Divide and multiply instructions +// ----------------------------------------------------------------------------- + +// SDIV, UDIV +def : SchedAlias<WriteID32, V3Write_12c_1M0>; +def : SchedAlias<WriteID64, V3Write_20c_1M0>; + +def : SchedAlias<WriteIM32, V3Write_2c_1M>; +def : SchedAlias<WriteIM64, V3Write_2c_1M>; + +// Multiply +// Multiply accumulate, W-form +// Multiply accumulate, X-form +def : InstRW<[V3Wr_IM], (instregex "^M(ADD|SUB)[WX]rrr$")>; + +// Multiply accumulate long +// Multiply long +def : InstRW<[V3Wr_IM], (instregex "^(S|U)M(ADD|SUB)Lrrr$")>; + +// Multiply high +def : InstRW<[V3Write_3c_1M], (instrs SMULHrr, UMULHrr)>; + +// §3.6 Pointer Authentication Instructions (v8.3 PAC) +// ----------------------------------------------------------------------------- + +// Authenticate data address +// Authenticate instruction address +// Compute pointer authentication code for data address +// Compute pointer authentication code, using generic key +// Compute pointer authentication code for instruction address +def : InstRW<[V3Write_4c_1M0], (instregex "^AUT", "^PAC")>; + +// Branch and link, register, with pointer authentication +// Branch, register, with pointer authentication +// Branch, return, with pointer authentication +def : InstRW<[V3Write_6c_1M0_1B], (instrs BLRAA, BLRAAZ, BLRAB, BLRABZ, BRAA, + BRAAZ, BRAB, BRABZ, RETAA, RETAB, + ERETAA, ERETAB)>; + + +// Load register, with pointer authentication +def : InstRW<[V3Write_9c_1M0_1L], (instregex "^LDRA[AB](indexed|writeback)")>; + +// Strip pointer authentication code +def : InstRW<[V3Write_2c_1M0], (instrs XPACD, XPACI, XPACLRI)>; + +// §3.7 Miscellaneous data-processing instructions +// ----------------------------------------------------------------------------- + +// Address generation +def : InstRW<[V3Write_1c_1I], (instrs ADR, ADRP)>; + +// Bitfield extract, one reg +// Bitfield extract, two regs +def : SchedAlias<WriteExtr, V3Write_Extr>; +def : InstRW<[V3Write_Extr], (instrs EXTRWrri, EXTRXrri)>; + +// Bitfield move, basic +def : SchedAlias<WriteIS, V3Write_1c_1I>; + +// Bitfield move, insert +def : InstRW<[V3Write_2c_1M], (instregex "^BFM[WX]ri$")>; + +// §3.8 Load instructions +// 
----------------------------------------------------------------------------- + +// NOTE: SOG p. 19: Throughput of LDN?P X-form should be 2, but reported as 3. + +def : SchedAlias<WriteLD, V3Write_4c_1L>; +def : SchedAlias<WriteLDIdx, V3Write_4c_1L>; + +// Load register, literal +def : InstRW<[V3Write_5c_1L_1I], (instrs LDRWl, LDRXl, LDRSWl, PRFMl)>; + +// Load pair, signed immed offset, signed words +def : InstRW<[V3Write_5c_1I_3L, WriteLDHi], (instrs LDPSWi)>; + +// Load pair, immed post-index or immed pre-index, signed words +def : InstRW<[WriteAdr, V3Write_5c_1I_3L, WriteLDHi], + (instregex "^LDPSW(post|pre)$")>; + +// §3.9 Store instructions +// ----------------------------------------------------------------------------- + +// NOTE: SOG, p. 20: Unsure if STRH uses pipeline I. + +def : SchedAlias<WriteST, V3Write_1c_1SA_1D>; +def : SchedAlias<WriteSTIdx, V3Write_1c_1SA_1D>; +def : SchedAlias<WriteSTP, V3Write_1c_1SA_1D>; +def : SchedAlias<WriteAdr, V3Write_1c_1I>; + +// §3.10 Tag load instructions +// ----------------------------------------------------------------------------- + +// Load allocation tag +// Load multiple allocation tags +def : InstRW<[V3Write_4c_1L], (instrs LDG, LDGM)>; + +// §3.11 Tag store instructions +// ----------------------------------------------------------------------------- + +// Store allocation tags to one or two granules, post-index +// Store allocation tags to one or two granules, pre-index +// Store allocation tag to one or two granules, zeroing, post-index +// Store Allocation Tag to one or two granules, zeroing, pre-index +// Store allocation tag and reg pair to memory, post-Index +// Store allocation tag and reg pair to memory, pre-Index +def : InstRW<[V3Write_1c_1SA_1D_1I], (instrs STGPreIndex, STGPostIndex, + ST2GPreIndex, ST2GPostIndex, + STZGPreIndex, STZGPostIndex, + STZ2GPreIndex, STZ2GPostIndex, + STGPpre, STGPpost)>; + +// Store allocation tags to one or two granules, signed offset +// Store allocation tag to two granules, zeroing, signed offset +// Store allocation tag and reg pair to memory, signed offset +// Store multiple allocation tags +def : InstRW<[V3Write_1c_1SA_1D], (instrs STGi, ST2Gi, STZGi, + STZ2Gi, STGPi, STGM, STZGM)>; + +// §3.12 FP data processing instructions +// ----------------------------------------------------------------------------- + +// FP absolute value +// FP arithmetic +// FP min/max +// FP negate +// FP select +def : SchedAlias<WriteF, V3Write_2c_1V>; + +// FP compare +def : SchedAlias<WriteFCmp, V3Write_2c_1V0>; + +// FP divide, square root +def : SchedAlias<WriteFDiv, V3Write_6c_1V1>; + +// FP divide, H-form +def : InstRW<[V3Write_6c_1V1], (instrs FDIVHrr)>; +// FP divide, S-form +def : InstRW<[V3Write_8c_1V1], (instrs FDIVSrr)>; +// FP divide, D-form +def : InstRW<[V3Write_13c_1V1], (instrs FDIVDrr)>; + +// FP square root, H-form +def : InstRW<[V3Write_6c_1V1], (instrs FSQRTHr)>; +// FP square root, S-form +def : InstRW<[V3Write_8c_1V1], (instrs FSQRTSr)>; +// FP square root, D-form +def : InstRW<[V3Write_13c_1V1], (instrs FSQRTDr)>; + +// FP multiply +def : WriteRes<WriteFMul, [V3UnitV]> { let Latency = 3; } + +// FP multiply accumulate +def : InstRW<[V3Wr_FMA, ReadDefault, ReadDefault, V3Rd_FMA], + (instregex "^FN?M(ADD|SUB)[HSD]rrr$")>; + +// FP round to integral +def : InstRW<[V3Write_3c_1V02], (instregex "^FRINT[AIMNPXZ][HSD]r$", + "^FRINT(32|64)[XZ][SD]r$")>; + +// §3.13 FP miscellaneous instructions +// ----------------------------------------------------------------------------- + +// FP convert, 
from gen to vec reg +def : InstRW<[V3Write_3c_1M0], (instregex "^[SU]CVTF[SU][WX][HSD]ri$")>; + +// FP convert, from vec to gen reg +def : InstRW<[V3Write_3c_1V01], + (instregex "^FCVT[AMNPZ][SU][SU][WX][HSD]ri?$")>; + +// FP convert, Javascript from vec to gen reg +def : SchedAlias<WriteFCvt, V3Write_3c_1V0>; + +// FP convert, from vec to vec reg +def : InstRW<[V3Write_3c_1V02], (instrs FCVTSHr, FCVTDHr, FCVTHSr, FCVTDSr, + FCVTHDr, FCVTSDr, FCVTXNv1i64)>; + +// FP move, immed +// FP move, register +def : SchedAlias<WriteFImm, V3Write_2c_1V>; + +// FP transfer, from gen to low half of vec reg +def : InstRW<[V3Write_0or3c_1M0], + (instrs FMOVWHr, FMOVXHr, FMOVWSr, FMOVXDr)>; + +// FP transfer, from gen to high half of vec reg +def : InstRW<[V3Write_5c_1M0_1V], (instrs FMOVXDHighr)>; + +// FP transfer, from vec to gen reg +def : SchedAlias<WriteFCopy, V3Write_2c_2V01>; + +// §3.14 FP load instructions +// ----------------------------------------------------------------------------- + +// Load vector reg, literal, S/D/Q forms +def : InstRW<[V3Write_7c_1I_1L], (instregex "^LDR[SDQ]l$")>; + +// Load vector reg, unscaled immed +def : InstRW<[V3Write_6c_1L], (instregex "^LDUR[BHSDQ]i$")>; + +// Load vector reg, immed post-index +// Load vector reg, immed pre-index +def : InstRW<[WriteAdr, V3Write_6c_1I_1L], + (instregex "^LDR[BHSDQ](pre|post)$")>; + +// Load vector reg, unsigned immed +def : InstRW<[V3Write_6c_1L], (instregex "^LDR[BHSDQ]ui$")>; + +// Load vector reg, register offset, basic +// Load vector reg, register offset, scale, S/D-form +// Load vector reg, register offset, scale, H/Q-form +// Load vector reg, register offset, extend +// Load vector reg, register offset, extend, scale, S/D-form +// Load vector reg, register offset, extend, scale, H/Q-form +def : InstRW<[V3Write_LdrHQ, ReadAdrBase], (instregex "^LDR[BHSDQ]ro[WX]$")>; + +// Load vector pair, immed offset, S/D-form +def : InstRW<[V3Write_6c_1L, WriteLDHi], (instregex "^LDN?P[SD]i$")>; + +// Load vector pair, immed offset, Q-form +def : InstRW<[V3Write_6c_2L, WriteLDHi], (instrs LDPQi, LDNPQi)>; + +// Load vector pair, immed post-index, S/D-form +// Load vector pair, immed pre-index, S/D-form +def : InstRW<[WriteAdr, V3Write_6c_1I_1L, WriteLDHi], + (instregex "^LDP[SD](pre|post)$")>; + +// Load vector pair, immed post-index, Q-form +// Load vector pair, immed pre-index, Q-form +def : InstRW<[WriteAdr, V3Write_6c_2I_2L, WriteLDHi], (instrs LDPQpost, + LDPQpre)>; + +// §3.15 FP store instructions +// ----------------------------------------------------------------------------- + +// Store vector reg, unscaled immed, B/H/S/D-form +// Store vector reg, unscaled immed, Q-form +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^STUR[BHSDQ]i$")>; + +// Store vector reg, immed post-index, B/H/S/D-form +// Store vector reg, immed post-index, Q-form +// Store vector reg, immed pre-index, B/H/S/D-form +// Store vector reg, immed pre-index, Q-form +def : InstRW<[WriteAdr, V3Write_2c_1SA_1V01_1I], + (instregex "^STR[BHSDQ](pre|post)$")>; + +// Store vector reg, unsigned immed, B/H/S/D-form +// Store vector reg, unsigned immed, Q-form +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^STR[BHSDQ]ui$")>; + +// Store vector reg, register offset, basic, B/H/S/D-form +// Store vector reg, register offset, basic, Q-form +// Store vector reg, register offset, scale, H-form +// Store vector reg, register offset, scale, S/D-form +// Store vector reg, register offset, scale, Q-form +// Store vector reg, register offset, extend, B/H/S/D-form +// Store 
vector reg, register offset, extend, Q-form +// Store vector reg, register offset, extend, scale, H-form +// Store vector reg, register offset, extend, scale, S/D-form +// Store vector reg, register offset, extend, scale, Q-form +def : InstRW<[V3Write_StrHQ, ReadAdrBase], + (instregex "^STR[BHSDQ]ro[WX]$")>; + +// Store vector pair, immed offset, S-form +// Store vector pair, immed offset, D-form +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^STN?P[SD]i$")>; + +// Store vector pair, immed offset, Q-form +def : InstRW<[V3Write_2c_1SA_2V01], (instrs STPQi, STNPQi)>; + +// Store vector pair, immed post-index, S-form +// Store vector pair, immed post-index, D-form +// Store vector pair, immed pre-index, S-form +// Store vector pair, immed pre-index, D-form +def : InstRW<[WriteAdr, V3Write_2c_1SA_1V01_1I], + (instregex "^STP[SD](pre|post)$")>; + +// Store vector pair, immed post-index, Q-form +def : InstRW<[V3Write_2c_1SA_2V01_1I], (instrs STPQpost)>; + +// Store vector pair, immed pre-index, Q-form +def : InstRW<[V3Write_2c_1SA_2V01_2I], (instrs STPQpre)>; + +// §3.16 ASIMD integer instructions +// ----------------------------------------------------------------------------- + +// ASIMD absolute diff +// ASIMD absolute diff long +// ASIMD arith, basic +// ASIMD arith, complex +// ASIMD arith, pair-wise +// ASIMD compare +// ASIMD logical +// ASIMD max/min, basic and pair-wise +def : SchedAlias<WriteVd, V3Write_2c_1V>; +def : SchedAlias<WriteVq, V3Write_2c_1V>; + +// ASIMD absolute diff accum +// ASIMD absolute diff accum long +def : InstRW<[V3Wr_VA, V3Rd_VA], (instregex "^[SU]ABAL?v")>; + +// ASIMD arith, reduce, 4H/4S +def : InstRW<[V3Write_3c_1V13], (instregex "^(ADDV|[SU]ADDLV)v4(i16|i32)v$")>; + +// ASIMD arith, reduce, 8B/8H +def : InstRW<[V3Write_5c_1V13_1V], + (instregex "^(ADDV|[SU]ADDLV)v8(i8|i16)v$")>; + +// ASIMD arith, reduce, 16B +def : InstRW<[V3Write_6c_2V13], (instregex "^(ADDV|[SU]ADDLV)v16i8v$")>; + +// ASIMD dot product +// ASIMD dot product using signed and unsigned integers +def : InstRW<[V3Wr_VDOT, V3Rd_VDOT], + (instregex "^([SU]|SU|US)DOT(lane)?(v8|v16)i8$")>; + +// ASIMD matrix multiply-accumulate +def : InstRW<[V3Wr_VMMA, V3Rd_VMMA], (instrs SMMLA, UMMLA, USMMLA)>; + +// ASIMD max/min, reduce, 4H/4S +def : InstRW<[V3Write_3c_1V13], (instregex "^[SU](MAX|MIN)Vv4i16v$", + "^[SU](MAX|MIN)Vv4i32v$")>; + +// ASIMD max/min, reduce, 8B/8H +def : InstRW<[V3Write_5c_1V13_1V], (instregex "^[SU](MAX|MIN)Vv8i8v$", + "^[SU](MAX|MIN)Vv8i16v$")>; + +// ASIMD max/min, reduce, 16B +def : InstRW<[V3Write_6c_2V13], (instregex "[SU](MAX|MIN)Vv16i8v$")>; + +// ASIMD multiply +def : InstRW<[V3Write_4c_1V02], (instregex "^MULv", "^SQ(R)?DMULHv")>; + +// ASIMD multiply accumulate +def : InstRW<[V3Wr_VMA, V3Rd_VMA], (instregex "^MLAv", "^MLSv")>; + +// ASIMD multiply accumulate high +def : InstRW<[V3Wr_VMAH, V3Rd_VMAH], (instregex "^SQRDMLAHv", "^SQRDMLSHv")>; + +// ASIMD multiply accumulate long +def : InstRW<[V3Wr_VMAL, V3Rd_VMAL], (instregex "^[SU]MLALv", "^[SU]MLSLv")>; + +// ASIMD multiply accumulate saturating long +def : InstRW<[V3Write_4c_1V02], (instregex "^SQDML[AS]L[iv]")>; + +// ASIMD multiply/multiply long (8x8) polynomial, D-form +// ASIMD multiply/multiply long (8x8) polynomial, Q-form +def : InstRW<[V3Write_3c_1V], (instregex "^PMULL?(v8i8|v16i8)$")>; + +// ASIMD multiply long +def : InstRW<[V3Write_3c_1V02], (instregex "^[SU]MULLv", "^SQDMULL[iv]")>; + +// ASIMD pairwise add and accumulate long +def : InstRW<[V3Wr_VPA, V3Rd_VPA], (instregex "^[SU]ADALPv")>; + +// ASIMD 
shift accumulate +def : InstRW<[V3Wr_VSA, V3Rd_VSA], (instregex "^[SU]SRA[dv]", "^[SU]RSRA[dv]")>; + +// ASIMD shift by immed, basic +def : InstRW<[V3Write_2c_1V], (instregex "^SHL[dv]", "^SHLLv", "^SHRNv", + "^SSHLLv", "^SSHR[dv]", "^USHLLv", + "^USHR[dv]")>; + +// ASIMD shift by immed and insert, basic +def : InstRW<[V3Write_2c_1V], (instregex "^SLI[dv]", "^SRI[dv]")>; + +// ASIMD shift by immed, complex +def : InstRW<[V3Write_4c_1V], + (instregex "^RSHRNv", "^SQRSHRU?N[bhsv]", "^(SQSHLU?|UQSHL)[bhsd]$", + "^(SQSHLU?|UQSHL)(v8i8|v16i8|v4i16|v8i16|v2i32|v4i32|v2i64)_shift$", + "^SQSHRU?N[bhsv]", "^SRSHR[dv]", "^UQRSHRN[bhsv]", + "^UQSHRN[bhsv]", "^URSHR[dv]")>; + +// ASIMD shift by register, basic +def : InstRW<[V3Write_2c_1V], (instregex "^[SU]SHLv")>; + +// ASIMD shift by register, complex +def : InstRW<[V3Write_4c_1V], + (instregex "^[SU]RSHLv", "^[SU]QRSHLv", + "^[SU]QSHL(v1i8|v1i16|v1i32|v1i64|v8i8|v16i8|v4i16|v8i16|v2i32|v4i32|v2i64)$")>; + +// §3.17 ASIMD floating-point instructions +// ----------------------------------------------------------------------------- + +// ASIMD FP absolute value/difference +// ASIMD FP arith, normal +// ASIMD FP compare +// ASIMD FP complex add +// ASIMD FP max/min, normal +// ASIMD FP max/min, pairwise +// ASIMD FP negate +// Handled by SchedAlias<WriteV[dq], ...> + +// ASIMD FP complex multiply add +def : InstRW<[V3Wr_VFCMA, V3Rd_VFCMA], (instregex "^FCMLAv")>; + +// ASIMD FP convert, long (F16 to F32) +def : InstRW<[V3Write_4c_2V02], (instregex "^FCVTL(v4|v8)i16")>; + +// ASIMD FP convert, long (F32 to F64) +def : InstRW<[V3Write_3c_1V02], (instregex "^FCVTL(v2|v4)i32")>; + +// ASIMD FP convert, narrow (F32 to F16) +def : InstRW<[V3Write_4c_2V02], (instregex "^FCVTN(v4|v8)i16")>; + +// ASIMD FP convert, narrow (F64 to F32) +def : InstRW<[V3Write_3c_1V02], (instregex "^FCVTN(v2|v4)i32", + "^FCVTXN(v2|v4)f32")>; + +// ASIMD FP convert, other, D-form F32 and Q-form F64 +def : InstRW<[V3Write_3c_1V02], (instregex "^FCVT[AMNPZ][SU]v2f(32|64)$", + "^FCVT[AMNPZ][SU]v2i(32|64)_shift$", + "^FCVT[AMNPZ][SU]v1i64$", + "^FCVTZ[SU]d$", + "^[SU]CVTFv2f(32|64)$", + "^[SU]CVTFv2i(32|64)_shift$", + "^[SU]CVTFv1i64$", + "^[SU]CVTFd$")>; + +// ASIMD FP convert, other, D-form F16 and Q-form F32 +def : InstRW<[V3Write_4c_2V02], (instregex "^FCVT[AMNPZ][SU]v4f(16|32)$", + "^FCVT[AMNPZ][SU]v4i(16|32)_shift$", + "^FCVT[AMNPZ][SU]v1i32$", + "^FCVTZ[SU]s$", + "^[SU]CVTFv4f(16|32)$", + "^[SU]CVTFv4i(16|32)_shift$", + "^[SU]CVTFv1i32$", + "^[SU]CVTFs$")>; + +// ASIMD FP convert, other, Q-form F16 +def : InstRW<[V3Write_6c_4V02], (instregex "^FCVT[AMNPZ][SU]v8f16$", + "^FCVT[AMNPZ][SU]v8i16_shift$", + "^FCVT[AMNPZ][SU]v1f16$", + "^FCVTZ[SU]h$", + "^[SU]CVTFv8f16$", + "^[SU]CVTFv8i16_shift$", + "^[SU]CVTFv1i16$", + "^[SU]CVTFh$")>; + +// ASIMD FP divide, D-form, F16 +def : InstRW<[V3Write_9c_1V1_4rc], (instrs FDIVv4f16)>; + +// ASIMD FP divide, D-form, F32 +def : InstRW<[V3Write_9c_1V1_2rc], (instrs FDIVv2f32)>; + +// ASIMD FP divide, Q-form, F16 +def : InstRW<[V3Write_13c_1V1_8rc], (instrs FDIVv8f16)>; + +// ASIMD FP divide, Q-form, F32 +def : InstRW<[V3Write_11c_1V1_4rc], (instrs FDIVv4f32)>; + +// ASIMD FP divide, Q-form, F64 +def : InstRW<[V3Write_14c_1V1_2rc], (instrs FDIVv2f64)>; + +// ASIMD FP max/min, reduce, F32 and D-form F16 +def : InstRW<[V3Write_4c_2V], (instregex "^(FMAX|FMIN)(NM)?Vv4(i16|i32)v$")>; + +// ASIMD FP max/min, reduce, Q-form F16 +def : InstRW<[V3Write_6c_3V], (instregex "^(FMAX|FMIN)(NM)?Vv8i16v$")>; + +// ASIMD FP multiply +def : InstRW<[V3Wr_VFM], 
(instregex "^FMULv", "^FMULXv")>; + +// ASIMD FP multiply accumulate +def : InstRW<[V3Wr_VFMA, V3Rd_VFMA], (instregex "^FMLAv", "^FMLSv")>; + +// ASIMD FP multiply accumulate long +def : InstRW<[V3Wr_VFMAL, V3Rd_VFMAL], (instregex "^FML[AS]L2?(lane)?v")>; + +// ASIMD FP round, D-form F32 and Q-form F64 +def : InstRW<[V3Write_3c_1V02], + (instregex "^FRINT[AIMNPXZ]v2f(32|64)$", + "^FRINT(32|64)[XZ]v2f(32|64)$")>; + +// ASIMD FP round, D-form F16 and Q-form F32 +def : InstRW<[V3Write_4c_2V02], + (instregex "^FRINT[AIMNPXZ]v4f(16|32)$", + "^FRINT(32|64)[XZ]v4f32$")>; + +// ASIMD FP round, Q-form F16 +def : InstRW<[V3Write_6c_4V02], (instregex "^FRINT[AIMNPXZ]v8f16$")>; + +// ASIMD FP square root, D-form, F16 +def : InstRW<[V3Write_9c_1V1_4rc], (instrs FSQRTv4f16)>; + +// ASIMD FP square root, D-form, F32 +def : InstRW<[V3Write_9c_1V1_2rc], (instrs FSQRTv2f32)>; + +// ASIMD FP square root, Q-form, F16 +def : InstRW<[V3Write_13c_1V1_8rc], (instrs FSQRTv8f16)>; + +// ASIMD FP square root, Q-form, F32 +def : InstRW<[V3Write_11c_1V1_4rc], (instrs FSQRTv4f32)>; + +// ASIMD FP square root, Q-form, F64 +def : InstRW<[V3Write_14c_1V1_2rc], (instrs FSQRTv2f64)>; + +// §3.18 ASIMD BFloat16 (BF16) instructions +// ----------------------------------------------------------------------------- + +// ASIMD convert, F32 to BF16 +def : InstRW<[V3Write_4c_2V02], (instrs BFCVTN, BFCVTN2)>; + +// ASIMD dot product +def : InstRW<[V3Wr_VBFDOT, V3Rd_VBFDOT], (instrs BFDOTv4bf16, BFDOTv8bf16)>; + +// ASIMD matrix multiply accumulate +def : InstRW<[V3Wr_VBFMMA, V3Rd_VBFMMA], (instrs BFMMLA)>; + +// ASIMD multiply accumulate long +def : InstRW<[V3Wr_VBFMAL, V3Rd_VBFMAL], (instrs BFMLALB, BFMLALBIdx, BFMLALT, + BFMLALTIdx)>; + +// Scalar convert, F32 to BF16 +def : InstRW<[V3Write_3c_1V02], (instrs BFCVT)>; + +// §3.19 ASIMD miscellaneous instructions +// ----------------------------------------------------------------------------- + +// ASIMD bit reverse +// ASIMD bitwise insert +// ASIMD count +// ASIMD duplicate, element +// ASIMD extract +// ASIMD extract narrow +// ASIMD insert, element to element +// ASIMD move, FP immed +// ASIMD move, integer immed +// ASIMD reverse +// ASIMD table lookup extension, 1 table reg +// ASIMD transpose +// ASIMD unzip/zip +// Handled by SchedAlias<WriteV[dq], ...> +def : InstRW<[V3Write_0or2c_1V], (instrs MOVID, MOVIv2d_ns)>; + +// ASIMD duplicate, gen reg +def : InstRW<[V3Write_3c_1M0], (instregex "^DUPv.+gpr")>; + +// ASIMD extract narrow, saturating +def : InstRW<[V3Write_4c_1V], (instregex "^[SU]QXTNv", "^SQXTUNv")>; + +// ASIMD reciprocal and square root estimate, D-form U32 +def : InstRW<[V3Write_3c_1V02], (instrs URECPEv2i32, URSQRTEv2i32)>; + +// ASIMD reciprocal and square root estimate, Q-form U32 +def : InstRW<[V3Write_4c_2V02], (instrs URECPEv4i32, URSQRTEv4i32)>; + +// ASIMD reciprocal and square root estimate, D-form F32 and scalar forms +def : InstRW<[V3Write_3c_1V02], (instrs FRECPEv1f16, FRECPEv1i32, + FRECPEv1i64, FRECPEv2f32, + FRSQRTEv1f16, FRSQRTEv1i32, + FRSQRTEv1i64, FRSQRTEv2f32)>; + +// ASIMD reciprocal and square root estimate, D-form F16 and Q-form F32 +def : InstRW<[V3Write_4c_2V02], (instrs FRECPEv4f16, FRECPEv4f32, + FRSQRTEv4f16, FRSQRTEv4f32)>; + +// ASIMD reciprocal and square root estimate, Q-form F16 +def : InstRW<[V3Write_6c_4V02], (instrs FRECPEv8f16, FRSQRTEv8f16)>; + +// ASIMD reciprocal exponent +def : InstRW<[V3Write_3c_1V02], (instregex "^FRECPXv")>; + +// ASIMD reciprocal step +def : InstRW<[V3Write_4c_1V], (instregex "^FRECPS(32|64|v)", + 
"^FRSQRTS(32|64|v)")>; + +// ASIMD table lookup, 1 or 2 table regs +def : InstRW<[V3Write_2c_1V], (instrs TBLv8i8One, TBLv16i8One, + TBLv8i8Two, TBLv16i8Two)>; + +// ASIMD table lookup, 3 table regs +def : InstRW<[V3Write_4c_2V], (instrs TBLv8i8Three, TBLv16i8Three)>; + +// ASIMD table lookup, 4 table regs +def : InstRW<[V3Write_4c_3V], (instrs TBLv8i8Four, TBLv16i8Four)>; + +// ASIMD table lookup extension, 2 table reg +def : InstRW<[V3Write_4c_2V], (instrs TBXv8i8Two, TBXv16i8Two)>; + +// ASIMD table lookup extension, 3 table reg +def : InstRW<[V3Write_6c_3V], (instrs TBXv8i8Three, TBXv16i8Three)>; + +// ASIMD table lookup extension, 4 table reg +def : InstRW<[V3Write_6c_5V], (instrs TBXv8i8Four, TBXv16i8Four)>; + +// ASIMD transfer, element to gen reg +def : InstRW<[V3Write_2c_2V01], (instregex "^[SU]MOVv")>; + +// ASIMD transfer, gen reg to element +def : InstRW<[V3Write_5c_1M0_1V], (instregex "^INSvi(8|16|32|64)gpr$")>; + +// §3.20 ASIMD load instructions +// ----------------------------------------------------------------------------- + +// ASIMD load, 1 element, multiple, 1 reg, D-form +def : InstRW<[V3Write_6c_1L], (instregex "^LD1Onev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_1L], + (instregex "^LD1Onev(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 1 reg, Q-form +def : InstRW<[V3Write_6c_1L], (instregex "^LD1Onev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_1L], + (instregex "^LD1Onev(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, multiple, 2 reg, D-form +def : InstRW<[V3Write_6c_2L], (instregex "^LD1Twov(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_2L], + (instregex "^LD1Twov(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 2 reg, Q-form +def : InstRW<[V3Write_6c_2L], (instregex "^LD1Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_2L], + (instregex "^LD1Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, multiple, 3 reg, D-form +def : InstRW<[V3Write_6c_3L], (instregex "^LD1Threev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_3L], + (instregex "^LD1Threev(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 3 reg, Q-form +def : InstRW<[V3Write_6c_3L], (instregex "^LD1Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_3L], + (instregex "^LD1Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, multiple, 4 reg, D-form +def : InstRW<[V3Write_7c_4L], (instregex "^LD1Fourv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_7c_4L], + (instregex "^LD1Fourv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 4 reg, Q-form +def : InstRW<[V3Write_7c_4L], (instregex "^LD1Fourv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_7c_4L], + (instregex "^LD1Fourv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, one lane, B/H/S +// ASIMD load, 1 element, one lane, D +def : InstRW<[V3Write_8c_1L_1V], (instregex "LD1i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_1V], (instregex "LD1i(8|16|32|64)_POST$")>; + +// ASIMD load, 1 element, all lanes, D-form, B/H/S +// ASIMD load, 1 element, all lanes, D-form, D +def : InstRW<[V3Write_8c_1L_1V], (instregex "LD1Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_1V], (instregex "LD1Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, all lanes, Q-form +def : InstRW<[V3Write_8c_1L_1V], (instregex "LD1Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_1V], (instregex "LD1Rv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 2 element, multiple, D-form, B/H/S +def : InstRW<[V3Write_8c_1L_2V], (instregex 
"LD2Twov(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_2V], (instregex "LD2Twov(8b|4h|2s)_POST$")>; + +// ASIMD load, 2 element, multiple, Q-form, B/H/S +// ASIMD load, 2 element, multiple, Q-form, D +def : InstRW<[V3Write_8c_2L_2V], (instregex "LD2Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_2L_2V], (instregex "LD2Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 2 element, one lane, B/H +// ASIMD load, 2 element, one lane, S +// ASIMD load, 2 element, one lane, D +def : InstRW<[V3Write_8c_1L_2V], (instregex "LD2i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_2V], (instregex "LD2i(8|16|32|64)_POST$")>; + +// ASIMD load, 2 element, all lanes, D-form, B/H/S +// ASIMD load, 2 element, all lanes, D-form, D +def : InstRW<[V3Write_8c_1L_2V], (instregex "LD2Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_2V], (instregex "LD2Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 2 element, all lanes, Q-form +def : InstRW<[V3Write_8c_1L_2V], (instregex "LD2Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_1L_2V], (instregex "LD2Rv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 3 element, multiple, D-form, B/H/S +def : InstRW<[V3Write_8c_2L_3V], (instregex "LD3Threev(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3Write_8c_2L_3V], (instregex "LD3Threev(8b|4h|2s)_POST$")>; + +// ASIMD load, 3 element, multiple, Q-form, B/H/S +// ASIMD load, 3 element, multiple, Q-form, D +def : InstRW<[V3Write_8c_3L_3V], (instregex "LD3Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_3L_3V], (instregex "LD3Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 3 element, one lane, B/H +// ASIMD load, 3 element, one lane, S +// ASIMD load, 3 element, one lane, D +def : InstRW<[V3Write_8c_2L_3V], (instregex "LD3i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_8c_2L_3V], (instregex "LD3i(8|16|32|64)_POST$")>; + +// ASIMD load, 3 element, all lanes, D-form, B/H/S +// ASIMD load, 3 element, all lanes, D-form, D +def : InstRW<[V3Write_8c_2L_3V], (instregex "LD3Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_2L_3V], (instregex "LD3Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 3 element, all lanes, Q-form, B/H/S +// ASIMD load, 3 element, all lanes, Q-form, D +def : InstRW<[V3Write_8c_3L_3V], (instregex "LD3Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_3L_3V], (instregex "LD3Rv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 4 element, multiple, D-form, B/H/S +def : InstRW<[V3Write_8c_3L_4V], (instregex "LD4Fourv(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3Write_8c_3L_4V], (instregex "LD4Fourv(8b|4h|2s)_POST$")>; + +// ASIMD load, 4 element, multiple, Q-form, B/H/S +// ASIMD load, 4 element, multiple, Q-form, D +def : InstRW<[V3Write_9c_6L_4V], (instregex "LD4Fourv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_9c_6L_4V], (instregex "LD4Fourv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 4 element, one lane, B/H +// ASIMD load, 4 element, one lane, S +// ASIMD load, 4 element, one lane, D +def : InstRW<[V3Write_8c_3L_4V], (instregex "LD4i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_8c_3L_4V], (instregex "LD4i(8|16|32|64)_POST$")>; + +// ASIMD load, 4 element, all lanes, D-form, B/H/S +// ASIMD load, 4 element, all lanes, D-form, D +def : InstRW<[V3Write_8c_3L_4V], (instregex "LD4Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_8c_3L_4V], (instregex "LD4Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 4 element, all lanes, Q-form, B/H/S +// ASIMD load, 4 element, all lanes, Q-form, D +def : InstRW<[V3Write_8c_4L_4V], (instregex "LD4Rv(16b|8h|4s|2d)$")>; +def : 
InstRW<[WriteAdr, V3Write_8c_4L_4V], (instregex "LD4Rv(16b|8h|4s|2d)_POST$")>; + +// §3.21 ASIMD store instructions +// ----------------------------------------------------------------------------- + +// ASIMD store, 1 element, multiple, 1 reg, D-form +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "ST1Onev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_1SA_1V01], (instregex "ST1Onev(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 1 reg, Q-form +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "ST1Onev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_1SA_1V01], (instregex "ST1Onev(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, multiple, 2 reg, D-form +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "ST1Twov(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_1SA_1V01], (instregex "ST1Twov(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 2 reg, Q-form +def : InstRW<[V3Write_2c_2SA_2V01], (instregex "ST1Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_2SA_2V01], (instregex "ST1Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, multiple, 3 reg, D-form +def : InstRW<[V3Write_2c_2SA_2V01], (instregex "ST1Threev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_2SA_2V01], (instregex "ST1Threev(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 3 reg, Q-form +def : InstRW<[V3Write_2c_3SA_3V01], (instregex "ST1Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_3SA_3V01], (instregex "ST1Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, multiple, 4 reg, D-form +def : InstRW<[V3Write_2c_2SA_2V01], (instregex "ST1Fourv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_2SA_2V01], (instregex "ST1Fourv(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 4 reg, Q-form +def : InstRW<[V3Write_2c_4SA_4V01], (instregex "ST1Fourv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_2c_4SA_4V01], (instregex "ST1Fourv(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, one lane, B/H/S +// ASIMD store, 1 element, one lane, D +def : InstRW<[V3Write_4c_1SA_2V01], (instregex "ST1i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_4c_1SA_2V01], (instregex "ST1i(8|16|32|64)_POST$")>; + +// ASIMD store, 2 element, multiple, D-form, B/H/S +def : InstRW<[V3Write_4c_1SA_2V01], (instregex "ST2Twov(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3Write_4c_1SA_2V01], (instregex "ST2Twov(8b|4h|2s)_POST$")>; + +// ASIMD store, 2 element, multiple, Q-form, B/H/S +// ASIMD store, 2 element, multiple, Q-form, D +def : InstRW<[V3Write_4c_2SA_4V01], (instregex "ST2Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_4c_2SA_4V01], (instregex "ST2Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 2 element, one lane, B/H/S +// ASIMD store, 2 element, one lane, D +def : InstRW<[V3Write_4c_1SA_2V01], (instregex "ST2i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_4c_1SA_2V01], (instregex "ST2i(8|16|32|64)_POST$")>; + +// ASIMD store, 3 element, multiple, D-form, B/H/S +def : InstRW<[V3Write_5c_2SA_4V01], (instregex "ST3Threev(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3Write_5c_2SA_4V01], (instregex "ST3Threev(8b|4h|2s)_POST$")>; + +// ASIMD store, 3 element, multiple, Q-form, B/H/S +// ASIMD store, 3 element, multiple, Q-form, D +def : InstRW<[V3Write_6c_3SA_6V01], (instregex "ST3Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3Write_6c_3SA_6V01], (instregex "ST3Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 3 element, one lane, B/H +// ASIMD store, 3 element, one lane, S +// ASIMD store, 3 element, one lane, D 
+def : InstRW<[V3Write_5c_2SA_4V01], (instregex "ST3i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3Write_5c_2SA_4V01], (instregex "ST3i(8|16|32|64)_POST$")>; + +// ASIMD store, 4 element, multiple, D-form, B/H/S +def : InstRW<[V3Write_6c_2SA_6V01], (instregex "ST4Fourv(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3Write_6c_2SA_6V01], (instregex "ST4Fourv(8b|4h|2s)_POST$")>; + +// ASIMD store, 4 element, multiple, Q-form, B/H/S +def : InstRW<[V3Write_7c_4SA_12V01], (instregex "ST4Fourv(16b|8h|4s)$")>; +def : InstRW<[WriteAdr, V3Write_7c_4SA_12V01], (instregex "ST4Fourv(16b|8h|4s)_POST$")>; + +// ASIMD store, 4 element, multiple, Q-form, D +def : InstRW<[V3Write_5c_4SA_8V01], (instregex "ST4Fourv(2d)$")>; +def : InstRW<[WriteAdr, V3Write_5c_4SA_8V01], (instregex "ST4Fourv(2d)_POST$")>; + +// ASIMD store, 4 element, one lane, B/H/S +def : InstRW<[V3Write_6c_1SA_3V01], (instregex "ST4i(8|16|32)$")>; +def : InstRW<[WriteAdr, V3Write_6c_1SA_3V01], (instregex "ST4i(8|16|32)_POST$")>; + +// ASIMD store, 4 element, one lane, D +def : InstRW<[V3Write_4c_2SA_4V01], (instregex "ST4i(64)$")>; +def : InstRW<[WriteAdr, V3Write_4c_2SA_4V01], (instregex "ST4i(64)_POST$")>; + +// §3.22 Cryptography extensions +// ----------------------------------------------------------------------------- + +// Crypto AES ops +def : InstRW<[V3Write_2c_1V], (instregex "^AES[DE]rr$", "^AESI?MCrr")>; + +// Crypto polynomial (64x64) multiply long +def : InstRW<[V3Write_2c_1V], (instrs PMULLv1i64, PMULLv2i64)>; + +// Crypto SHA1 hash acceleration op +// Crypto SHA1 schedule acceleration ops +def : InstRW<[V3Write_2c_1V0], (instregex "^SHA1(H|SU0|SU1)")>; + +// Crypto SHA1 hash acceleration ops +// Crypto SHA256 hash acceleration ops +def : InstRW<[V3Write_4c_1V0], (instregex "^SHA1[CMP]", "^SHA256H2?")>; + +// Crypto SHA256 schedule acceleration ops +def : InstRW<[V3Write_2c_1V0], (instregex "^SHA256SU[01]")>; + +// Crypto SHA512 hash acceleration ops +def : InstRW<[V3Write_2c_1V0], (instregex "^SHA512(H|H2|SU0|SU1)")>; + +// Crypto SHA3 ops +def : InstRW<[V3Write_2c_1V], (instrs BCAX, EOR3, RAX1, XAR)>; + +// Crypto SM3 ops +def : InstRW<[V3Write_2c_1V0], (instregex "^SM3PARTW[12]$", "^SM3SS1$", + "^SM3TT[12][AB]$")>; + +// Crypto SM4 ops +def : InstRW<[V3Write_4c_1V0], (instrs SM4E, SM4ENCKEY)>; + +// §3.23 CRC +// ----------------------------------------------------------------------------- + +def : InstRW<[V3Wr_CRC, V3Rd_CRC], (instregex "^CRC32")>; + +// §3.24 SVE Predicate instructions +// ----------------------------------------------------------------------------- + +// Loop control, based on predicate +def : InstRW<[V3Write_2or3c_1M], (instrs BRKA_PPmP, BRKA_PPzP, + BRKB_PPmP, BRKB_PPzP)>; + +// Loop control, based on predicate and flag setting +def : InstRW<[V3Write_2or3c_1M], (instrs BRKAS_PPzP, BRKBS_PPzP)>; + +// Loop control, propagating +def : InstRW<[V3Write_2or3c_1M], (instrs BRKN_PPzP, BRKPA_PPzPP, + BRKPB_PPzPP)>; + +// Loop control, propagating and flag setting +def : InstRW<[V3Write_2or3c_1M], (instrs BRKNS_PPzP, BRKPAS_PPzPP, + BRKPBS_PPzPP)>; + +// Loop control, based on GPR +def : InstRW<[V3Write_3c_2M], + (instregex "^WHILE(GE|GT|HI|HS|LE|LO|LS|LT)_P(WW|XX)_[BHSD]")>; +def : InstRW<[V3Write_3c_2M], (instregex "^WHILE(RW|WR)_PXX_[BHSD]")>; + +// Loop terminate +def : InstRW<[V3Write_1c_2M], (instregex "^CTERM(EQ|NE)_(WW|XX)")>; + +// Predicate counting scalar +def : InstRW<[V3Write_2c_1M], (instrs ADDPL_XXI, ADDVL_XXI, RDVLI_XI)>; +def : InstRW<[V3Write_2c_1M], + (instregex 
"^(CNT|SQDEC|SQINC|UQDEC|UQINC)[BHWD]_XPiI", + "^SQ(DEC|INC)[BHWD]_XPiWdI", + "^UQ(DEC|INC)[BHWD]_WPiI")>; + +// Predicate counting scalar, ALL, {1,2,4} +def : InstRW<[V3Write_IncDec], (instregex "^(DEC|INC)[BHWD]_XPiI")>; + +// Predicate counting scalar, active predicate +def : InstRW<[V3Write_2c_1M], + (instregex "^CNTP_XPP_[BHSD]", + "^(DEC|INC|SQDEC|SQINC|UQDEC|UQINC)P_XP_[BHSD]", + "^(UQDEC|UQINC)P_WP_[BHSD]", + "^(SQDEC|SQINC)P_XPWd_[BHSD]")>; + +// Predicate counting vector, active predicate +def : InstRW<[V3Write_7c_1M_1M0_1V], + (instregex "^(DEC|INC|SQDEC|SQINC|UQDEC|UQINC)P_ZP_[HSD]")>; + +// Predicate logical +def : InstRW<[V3Write_1or2c_1M], + (instregex "^(AND|BIC|EOR|NAND|NOR|ORN|ORR)_PPzPP")>; + +// Predicate logical, flag setting +def : InstRW<[V3Write_1or2c_1M], + (instregex "^(ANDS|BICS|EORS|NANDS|NORS|ORNS|ORRS)_PPzPP")>; + +// Predicate reverse +def : InstRW<[V3Write_2c_1M], (instregex "^REV_PP_[BHSD]")>; + +// Predicate select +def : InstRW<[V3Write_1c_1M], (instrs SEL_PPPP)>; + +// Predicate set +def : InstRW<[V3Write_2c_1M], (instregex "^PFALSE", "^PTRUE_[BHSD]")>; + +// Predicate set/initialize, set flags +def : InstRW<[V3Write_2c_1M], (instregex "^PTRUES_[BHSD]")>; + +// Predicate find first/next +def : InstRW<[V3Write_2c_1M], (instregex "^PFIRST_B", "^PNEXT_[BHSD]")>; + +// Predicate test +def : InstRW<[V3Write_1c_1M], (instrs PTEST_PP)>; + +// Predicate transpose +def : InstRW<[V3Write_2c_1M], (instregex "^TRN[12]_PPP_[BHSD]")>; + +// Predicate unpack and widen +def : InstRW<[V3Write_2c_1M], (instrs PUNPKHI_PP, PUNPKLO_PP)>; + +// Predicate zip/unzip +def : InstRW<[V3Write_2c_1M], (instregex "^(ZIP|UZP)[12]_PPP_[BHSD]")>; + +// §3.25 SVE integer instructions +// ----------------------------------------------------------------------------- + +// Arithmetic, absolute diff +def : InstRW<[V3Write_2c_1V], (instregex "^[SU]ABD_ZPmZ_[BHSD]", + "^[SU]ABD_ZPZZ_[BHSD]")>; + +// Arithmetic, absolute diff accum +def : InstRW<[V3Wr_ZA, V3Rd_ZA], (instregex "^[SU]ABA_ZZZ_[BHSD]")>; + +// Arithmetic, absolute diff accum long +def : InstRW<[V3Wr_ZA, V3Rd_ZA], (instregex "^[SU]ABAL[TB]_ZZZ_[HSD]")>; + +// Arithmetic, absolute diff long +def : InstRW<[V3Write_2c_1V], (instregex "^[SU]ABDL[TB]_ZZZ_[HSD]")>; + +// Arithmetic, basic +def : InstRW<[V3Write_2c_1V], + (instregex "^(ABS|ADD|CNOT|NEG|SUB|SUBR)_ZPmZ_[BHSD]", + "^(ADD|SUB)_ZZZ_[BHSD]", + "^(ADD|SUB|SUBR)_ZPZZ_[BHSD]", + "^(ADD|SUB|SUBR)_ZI_[BHSD]", + "^ADR_[SU]XTW_ZZZ_D_[0123]", + "^ADR_LSL_ZZZ_[SD]_[0123]", + "^[SU](ADD|SUB)[LW][BT]_ZZZ_[HSD]", + "^SADDLBT_ZZZ_[HSD]", + "^[SU]H(ADD|SUB|SUBR)_ZPmZ_[BHSD]", + "^SSUBL(BT|TB)_ZZZ_[HSD]")>; + +// Arithmetic, complex +def : InstRW<[V3Write_2c_1V], + (instregex "^R?(ADD|SUB)HN[BT]_ZZZ_[BHS]", + "^SQ(ABS|ADD|NEG|SUB|SUBR)_ZPmZ_[BHSD]", + "^[SU]Q(ADD|SUB)_ZZZ_[BHSD]", + "^[SU]Q(ADD|SUB)_ZI_[BHSD]", + "^(SRH|SUQ|UQ|USQ|URH)ADD_ZPmZ_[BHSD]", + "^(UQSUB|UQSUBR)_ZPmZ_[BHSD]")>; + +// Arithmetic, large integer +def : InstRW<[V3Write_2c_1V], (instregex "^(AD|SB)CL[BT]_ZZZ_[SD]")>; + +// Arithmetic, pairwise add +def : InstRW<[V3Write_2c_1V], (instregex "^ADDP_ZPmZ_[BHSD]")>; + +// Arithmetic, pairwise add and accum long +def : InstRW<[V3Wr_ZPA, ReadDefault, V3Rd_ZPA], + (instregex "^[SU]ADALP_ZPmZ_[HSD]")>; + +// Arithmetic, shift +def : InstRW<[V3Write_2c_1V13], + (instregex "^(ASR|LSL|LSR)_WIDE_ZPmZ_[BHS]", + "^(ASR|LSL|LSR)_WIDE_ZZZ_[BHS]", + "^(ASR|LSL|LSR)_ZPmI_[BHSD]", + "^(ASR|LSL|LSR)_ZPmZ_[BHSD]", + "^(ASR|LSL|LSR)_ZZI_[BHSD]", + "^(ASR|LSL|LSR)_ZPZ[IZ]_[BHSD]", + 
"^(ASRR|LSLR|LSRR)_ZPmZ_[BHSD]")>; + +// Arithmetic, shift and accumulate +def : InstRW<[V3Wr_ZSA, V3Rd_ZSA], (instregex "^[SU]R?SRA_ZZI_[BHSD]")>; + +// Arithmetic, shift by immediate +def : InstRW<[V3Write_2c_1V], (instregex "^SHRN[BT]_ZZI_[BHS]", + "^[SU]SHLL[BT]_ZZI_[HSD]")>; + +// Arithmetic, shift by immediate and insert +def : InstRW<[V3Write_2c_1V], (instregex "^(SLI|SRI)_ZZI_[BHSD]")>; + +// Arithmetic, shift complex +def : InstRW<[V3Write_4c_1V], + (instregex "^(SQ)?RSHRU?N[BT]_ZZI_[BHS]", + "^(SQRSHL|SQRSHLR|SQSHL|SQSHLR|UQRSHL|UQRSHLR|UQSHL|UQSHLR)_ZPmZ_[BHSD]", + "^[SU]QR?SHL_ZPZZ_[BHSD]", + "^(SQSHL|SQSHLU|UQSHL)_(ZPmI|ZPZI)_[BHSD]", + "^SQSHRU?N[BT]_ZZI_[BHS]", + "^UQR?SHRN[BT]_ZZI_[BHS]")>; + +// Arithmetic, shift right for divide +def : InstRW<[V3Write_4c_1V], (instregex "^ASRD_(ZPmI|ZPZI)_[BHSD]")>; + +// Arithmetic, shift rounding +def : InstRW<[V3Write_4c_1V], (instregex "^[SU]RSHLR?_ZPmZ_[BHSD]", + "^[SU]RSHL_ZPZZ_[BHSD]", + "^[SU]RSHR_(ZPmI|ZPZI)_[BHSD]")>; + +// Bit manipulation +def : InstRW<[V3Write_6c_2V1], (instregex "^(BDEP|BEXT|BGRP)_ZZZ_[BHSD]")>; + +// Bitwise select +def : InstRW<[V3Write_2c_1V], (instregex "^(BSL|BSL1N|BSL2N|NBSL)_ZZZZ")>; + +// Count/reverse bits +def : InstRW<[V3Write_2c_1V], (instregex "^(CLS|CLZ|CNT|RBIT)_ZPmZ_[BHSD]")>; + +// Broadcast logical bitmask immediate to vector +def : InstRW<[V3Write_2c_1V], (instrs DUPM_ZI)>; + +// Compare and set flags +def : InstRW<[V3Write_2or3c_1V0], + (instregex "^CMP(EQ|GE|GT|HI|HS|LE|LO|LS|LT|NE)_PPzZ[IZ]_[BHSD]", + "^CMP(EQ|GE|GT|HI|HS|LE|LO|LS|LT|NE)_WIDE_PPzZZ_[BHS]")>; + +// Complex add +def : InstRW<[V3Write_2c_1V], (instregex "^(SQ)?CADD_ZZI_[BHSD]")>; + +// Complex dot product 8-bit element +def : InstRW<[V3Wr_ZDOTB, V3Rd_ZDOTB], (instrs CDOT_ZZZ_S, CDOT_ZZZI_S)>; + +// Complex dot product 16-bit element +def : InstRW<[V3Wr_ZDOTH, V3Rd_ZDOTH], (instrs CDOT_ZZZ_D, CDOT_ZZZI_D)>; + +// Complex multiply-add B, H, S element size +def : InstRW<[V3Wr_ZCMABHS, V3Rd_ZCMABHS], (instregex "^CMLA_ZZZ_[BHS]", + "^CMLA_ZZZI_[HS]")>; + +// Complex multiply-add D element size +def : InstRW<[V3Wr_ZCMAD, V3Rd_ZCMAD], (instrs CMLA_ZZZ_D)>; + +// Conditional extract operations, scalar form +def : InstRW<[V3Write_8c_1M0_1V01], (instregex "^CLAST[AB]_RPZ_[BHSD]")>; + +// Conditional extract operations, SIMD&FP scalar and vector forms +def : InstRW<[V3Write_3c_1V1], (instregex "^CLAST[AB]_[VZ]PZ_[BHSD]", + "^COMPACT_ZPZ_[SD]", + "^SPLICE_ZPZZ?_[BHSD]")>; + +// Convert to floating point, 64b to float or convert to double +def : InstRW<[V3Write_3c_1V02], (instregex "^[SU]CVTF_ZPmZ_Dto[HSD]", + "^[SU]CVTF_ZPmZ_StoD")>; + +// Convert to floating point, 32b to single or half +def : InstRW<[V3Write_4c_2V02], (instregex "^[SU]CVTF_ZPmZ_Sto[HS]")>; + +// Convert to floating point, 16b to half +def : InstRW<[V3Write_6c_4V02], (instregex "^[SU]CVTF_ZPmZ_HtoH")>; + +// Copy, scalar +def : InstRW<[V3Write_5c_1M0_1V], (instregex "^CPY_ZPmR_[BHSD]")>; + +// Copy, scalar SIMD&FP or imm +def : InstRW<[V3Write_2c_1V], (instregex "^CPY_ZPm[IV]_[BHSD]", + "^CPY_ZPzI_[BHSD]")>; + +// Divides, 32 bit +def : InstRW<[V3Write_12c_1V0], (instregex "^[SU]DIVR?_ZPmZ_S", + "^[SU]DIV_ZPZZ_S")>; + +// Divides, 64 bit +def : InstRW<[V3Write_20c_1V0], (instregex "^[SU]DIVR?_ZPmZ_D", + "^[SU]DIV_ZPZZ_D")>; + +// Dot product, 8 bit +def : InstRW<[V3Wr_ZDOTB, V3Rd_ZDOTB], (instregex "^[SU]DOT_ZZZI?_BtoS")>; + +// Dot product, 8 bit, using signed and unsigned integers +def : InstRW<[V3Wr_ZDOTB, V3Rd_ZDOTB], (instrs SUDOT_ZZZI, USDOT_ZZZI, USDOT_ZZZ)>; 
+ +// Dot product, 16 bit +def : InstRW<[V3Wr_ZDOTH, V3Rd_ZDOTH], (instregex "^[SU]DOT_ZZZI?_HtoD")>; + +// Duplicate, immediate and indexed form +def : InstRW<[V3Write_2c_1V], (instregex "^DUP_ZI_[BHSD]", + "^DUP_ZZI_[BHSDQ]")>; + +// Duplicate, scalar form +def : InstRW<[V3Write_3c_1M0], (instregex "^DUP_ZR_[BHSD]")>; + +// Extend, sign or zero +def : InstRW<[V3Write_2c_1V], (instregex "^[SU]XTB_ZPmZ_[HSD]", + "^[SU]XTH_ZPmZ_[SD]", + "^[SU]XTW_ZPmZ_[D]")>; + +// Extract +def : InstRW<[V3Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; + +// Extract narrow saturating +def : InstRW<[V3Write_4c_1V], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]", + "^SQXTUN[BT]_ZZ_[BHS]")>; + +// Extract operation, SIMD and FP scalar form +def : InstRW<[V3Write_3c_1V1], (instregex "^LAST[AB]_VPZ_[BHSD]")>; + +// Extract operation, scalar +def : InstRW<[V3Write_6c_1V1_1M0], (instregex "^LAST[AB]_RPZ_[BHSD]")>; + +// Histogram operations +def : InstRW<[V3Write_2c_1V], (instregex "^HISTCNT_ZPzZZ_[SD]", + "^HISTSEG_ZZZ")>; + +// Horizontal operations, B, H, S form, immediate operands only +def : InstRW<[V3Write_4c_1V02], (instregex "^INDEX_II_[BHS]")>; + +// Horizontal operations, B, H, S form, scalar, immediate operands/ scalar +// operands only / immediate, scalar operands +def : InstRW<[V3Write_7c_1M0_1V02], (instregex "^INDEX_(IR|RI|RR)_[BHS]")>; + +// Horizontal operations, D form, immediate operands only +def : InstRW<[V3Write_5c_2V02], (instrs INDEX_II_D)>; + +// Horizontal operations, D form, scalar, immediate operands)/ scalar operands +// only / immediate, scalar operands +def : InstRW<[V3Write_8c_2M0_2V02], (instregex "^INDEX_(IR|RI|RR)_D")>; + +// insert operation, SIMD and FP scalar form +def : InstRW<[V3Write_2c_1V], (instregex "^INSR_ZV_[BHSD]")>; + +// insert operation, scalar +def : InstRW<[V3Write_5c_1V1_1M0], (instregex "^INSR_ZR_[BHSD]")>; + +// Logical +def : InstRW<[V3Write_2c_1V], + (instregex "^(AND|EOR|ORR)_ZI", + "^(AND|BIC|EOR|ORR)_ZZZ", + "^EOR(BT|TB)_ZZZ_[BHSD]", + "^(AND|BIC|EOR|NOT|ORR)_(ZPmZ|ZPZZ)_[BHSD]", + "^NOT_ZPmZ_[BHSD]")>; + +// Max/min, basic and pairwise +def : InstRW<[V3Write_2c_1V], (instregex "^[SU](MAX|MIN)_ZI_[BHSD]", + "^[SU](MAX|MIN)P?_ZPmZ_[BHSD]", + "^[SU](MAX|MIN)_ZPZZ_[BHSD]")>; + +// Matching operations +// FIXME: SOG p. 44, n. 5: If the consuming instruction has a flag source, the +// latency for this instruction is 4 cycles. 
+def : InstRW<[V3Write_2or3c_1V0_1M], (instregex "^N?MATCH_PPzZZ_[BH]")>; + +// Matrix multiply-accumulate +def : InstRW<[V3Wr_ZMMA, V3Rd_ZMMA], (instrs SMMLA_ZZZ, UMMLA_ZZZ, USMMLA_ZZZ)>; + +// Move prefix +def : InstRW<[V3Write_2c_1V], (instregex "^MOVPRFX_ZP[mz]Z_[BHSD]", + "^MOVPRFX_ZZ")>; + +// Multiply, B, H, S element size +def : InstRW<[V3Write_4c_1V02], (instregex "^MUL_(ZI|ZPmZ|ZZZI|ZZZ)_[BHS]", + "^MUL_ZPZZ_[BHS]", + "^[SU]MULH_(ZPmZ|ZZZ)_[BHS]", + "^[SU]MULH_ZPZZ_[BHS]")>; + +// Multiply, D element size +def : InstRW<[V3Write_5c_2V02], (instregex "^MUL_(ZI|ZPmZ|ZZZI|ZZZ)_D", + "^MUL_ZPZZ_D", + "^[SU]MULH_(ZPmZ|ZZZ)_D", + "^[SU]MULH_ZPZZ_D")>; + +// Multiply long +def : InstRW<[V3Write_4c_1V02], (instregex "^[SU]MULL[BT]_ZZZI_[SD]", + "^[SU]MULL[BT]_ZZZ_[HSD]")>; + +// Multiply accumulate, B, H, S element size +def : InstRW<[V3Wr_ZMABHS, V3Rd_ZMABHS], + (instregex "^ML[AS]_ZZZI_[HS]", "^ML[AS]_ZPZZZ_[BHS]")>; +def : InstRW<[V3Wr_ZMABHS, ReadDefault, V3Rd_ZMABHS], + (instregex "^(ML[AS]|MAD|MSB)_ZPmZZ_[BHS]")>; + +// Multiply accumulate, D element size +def : InstRW<[V3Wr_ZMAD, V3Rd_ZMAD], + (instregex "^ML[AS]_ZZZI_D", "^ML[AS]_ZPZZZ_D")>; +def : InstRW<[V3Wr_ZMAD, ReadDefault, V3Rd_ZMAD], + (instregex "^(ML[AS]|MAD|MSB)_ZPmZZ_D")>; + +// Multiply accumulate long +def : InstRW<[V3Wr_ZMAL, V3Rd_ZMAL], (instregex "^[SU]ML[AS]L[BT]_ZZZ_[HSD]", + "^[SU]ML[AS]L[BT]_ZZZI_[SD]")>; + +// Multiply accumulate saturating doubling long regular +def : InstRW<[V3Wr_ZMASQL, V3Rd_ZMASQ], + (instregex "^SQDML[AS]L(B|T|BT)_ZZZ_[HSD]", + "^SQDML[AS]L[BT]_ZZZI_[SD]")>; + +// Multiply saturating doubling high, B, H, S element size +def : InstRW<[V3Write_4c_1V02], (instregex "^SQDMULH_ZZZ_[BHS]", + "^SQDMULH_ZZZI_[HS]")>; + +// Multiply saturating doubling high, D element size +def : InstRW<[V3Write_5c_2V02], (instrs SQDMULH_ZZZ_D, SQDMULH_ZZZI_D)>; + +// Multiply saturating doubling long +def : InstRW<[V3Write_4c_1V02], (instregex "^SQDMULL[BT]_ZZZ_[HSD]", + "^SQDMULL[BT]_ZZZI_[SD]")>; + +// Multiply saturating rounding doubling regular/complex accumulate, B, H, S +// element size +def : InstRW<[V3Wr_ZMASQBHS, V3Rd_ZMASQ], (instregex "^SQRDML[AS]H_ZZZ_[BHS]", + "^SQRDCMLAH_ZZZ_[BHS]", + "^SQRDML[AS]H_ZZZI_[HS]", + "^SQRDCMLAH_ZZZI_[HS]")>; + +// Multiply saturating rounding doubling regular/complex accumulate, D element +// size +def : InstRW<[V3Wr_ZMASQD, V3Rd_ZMASQ], (instregex "^SQRDML[AS]H_ZZZI?_D", + "^SQRDCMLAH_ZZZ_D")>; + +// Multiply saturating rounding doubling regular/complex, B, H, S element size +def : InstRW<[V3Write_4c_1V02], (instregex "^SQRDMULH_ZZZ_[BHS]", + "^SQRDMULH_ZZZI_[HS]")>; + +// Multiply saturating rounding doubling regular/complex, D element size +def : InstRW<[V3Write_5c_2V02], (instregex "^SQRDMULH_ZZZI?_D")>; + +// Multiply/multiply long, (8x8) polynomial +def : InstRW<[V3Write_2c_1V], (instregex "^PMUL_ZZZ_B", + "^PMULL[BT]_ZZZ_[HDQ]")>; + +// Predicate counting vector +def : InstRW<[V3Write_2c_1V], (instregex "^([SU]Q)?(DEC|INC)[HWD]_ZPiI")>; + +// Reciprocal estimate +def : InstRW<[V3Write_4c_2V02], (instregex "^URECPE_ZPmZ_S", "^URSQRTE_ZPmZ_S")>; + +// Reduction, arithmetic, B form +def : InstRW<[V3Write_9c_2V_4V13], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_B")>; + +// Reduction, arithmetic, H form +def : InstRW<[V3Write_8c_2V_2V13], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_H")>; + +// Reduction, arithmetic, S form +def : InstRW<[V3Write_6c_2V_2V13], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_S")>; + +// Reduction, arithmetic, D form +def : InstRW<[V3Write_4c_2V], (instregex 
"^[SU](ADD|MAX|MIN)V_VPZ_D")>; + +// Reduction, logical +def : InstRW<[V3Write_6c_1V_1V13], (instregex "^(AND|EOR|OR)V_VPZ_[BHSD]")>; + +// Reverse, vector +def : InstRW<[V3Write_2c_1V], (instregex "^REV_ZZ_[BHSD]", + "^REVB_ZPmZ_[HSD]", + "^REVH_ZPmZ_[SD]", + "^REVW_ZPmZ_D")>; + +// Select, vector form +def : InstRW<[V3Write_2c_1V], (instregex "^SEL_ZPZZ_[BHSD]")>; + +// Table lookup +def : InstRW<[V3Write_2c_1V], (instregex "^TBL_ZZZZ?_[BHSD]")>; + +// Table lookup extension +def : InstRW<[V3Write_2c_1V], (instregex "^TBX_ZZZ_[BHSD]")>; + +// Transpose, vector form +def : InstRW<[V3Write_2c_1V], (instregex "^TRN[12]_ZZZ_[BHSDQ]")>; + +// Unpack and extend +def : InstRW<[V3Write_2c_1V], (instregex "^[SU]UNPK(HI|LO)_ZZ_[HSD]")>; + +// Zip/unzip +def : InstRW<[V3Write_2c_1V], (instregex "^(UZP|ZIP)[12]_ZZZ_[BHSDQ]")>; + +// §3.26 SVE floating-point instructions +// ----------------------------------------------------------------------------- + +// Floating point absolute value/difference +def : InstRW<[V3Write_2c_1V], (instregex "^FAB[SD]_ZPmZ_[HSD]", + "^FABD_ZPZZ_[HSD]", + "^FABS_ZPmZ_[HSD]")>; + +// Floating point arithmetic +def : InstRW<[V3Write_2c_1V], (instregex "^F(ADD|SUB)_(ZPm[IZ]|ZZZ)_[HSD]", + "^F(ADD|SUB)_ZPZ[IZ]_[HSD]", + "^FADDP_ZPmZZ_[HSD]", + "^FNEG_ZPmZ_[HSD]", + "^FSUBR_ZPm[IZ]_[HSD]", + "^FSUBR_(ZPZI|ZPZZ)_[HSD]")>; + +// Floating point associative add, F16 +def : InstRW<[V3Write_10c_1V1_9rc], (instrs FADDA_VPZ_H)>; + +// Floating point associative add, F32 +def : InstRW<[V3Write_6c_1V1_5rc], (instrs FADDA_VPZ_S)>; + +// Floating point associative add, F64 +def : InstRW<[V3Write_4c_1V], (instrs FADDA_VPZ_D)>; + +// Floating point compare +def : InstRW<[V3Write_2c_1V0], (instregex "^FACG[ET]_PPzZZ_[HSD]", + "^FCM(EQ|GE|GT|NE)_PPzZ[0Z]_[HSD]", + "^FCM(LE|LT)_PPzZ0_[HSD]", + "^FCMUO_PPzZZ_[HSD]")>; + +// Floating point complex add +def : InstRW<[V3Write_3c_1V], (instregex "^FCADD_ZPmZ_[HSD]")>; + +// Floating point complex multiply add +def : InstRW<[V3Wr_ZFCMA, ReadDefault, V3Rd_ZFCMA], (instregex "^FCMLA_ZPmZZ_[HSD]")>; +def : InstRW<[V3Wr_ZFCMA, V3Rd_ZFCMA], (instregex "^FCMLA_ZZZI_[HS]")>; + +// Floating point convert, long or narrow (F16 to F32 or F32 to F16) +def : InstRW<[V3Write_4c_2V02], (instregex "^FCVT_ZPmZ_(HtoS|StoH)", + "^FCVTLT_ZPmZ_HtoS", + "^FCVTNT_ZPmZ_StoH")>; + +// Floating point convert, long or narrow (F16 to F64, F32 to F64, F64 to F32 +// or F64 to F16) +def : InstRW<[V3Write_3c_1V02], (instregex "^FCVT_ZPmZ_(HtoD|StoD|DtoS|DtoH)", + "^FCVTLT_ZPmZ_StoD", + "^FCVTNT_ZPmZ_DtoS")>; + +// Floating point convert, round to odd +def : InstRW<[V3Write_3c_1V02], (instrs FCVTX_ZPmZ_DtoS, FCVTXNT_ZPmZ_DtoS)>; + +// Floating point base2 log, F16 +def : InstRW<[V3Write_6c_4V02], (instregex "^FLOGB_(ZPmZ|ZPZZ)_H")>; + +// Floating point base2 log, F32 +def : InstRW<[V3Write_4c_2V02], (instregex "^FLOGB_(ZPmZ|ZPZZ)_S")>; + +// Floating point base2 log, F64 +def : InstRW<[V3Write_3c_1V02], (instregex "^FLOGB_(ZPmZ|ZPZZ)_D")>; + +// Floating point convert to integer, F16 +def : InstRW<[V3Write_6c_4V02], (instregex "^FCVTZ[SU]_ZPmZ_HtoH")>; + +// Floating point convert to integer, F32 +def : InstRW<[V3Write_4c_2V02], (instregex "^FCVTZ[SU]_ZPmZ_(HtoS|StoS)")>; + +// Floating point convert to integer, F64 +def : InstRW<[V3Write_3c_1V02], + (instregex "^FCVTZ[SU]_ZPmZ_(HtoD|StoD|DtoS|DtoD)")>; + +// Floating point copy +def : InstRW<[V3Write_2c_1V], (instregex "^FCPY_ZPmI_[HSD]", + "^FDUP_ZI_[HSD]")>; + +// Floating point divide, F16 +def : 
InstRW<[V3Write_13c_1V1_8rc], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_H")>; + +// Floating point divide, F32 +def : InstRW<[V3Write_11c_1V1_4rc], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_S")>; + +// Floating point divide, F64 +def : InstRW<[V3Write_14c_1V1_2rc], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_D")>; + +// Floating point min/max pairwise +def : InstRW<[V3Write_2c_1V], (instregex "^F(MAX|MIN)(NM)?P_ZPmZZ_[HSD]")>; + +// Floating point min/max +def : InstRW<[V3Write_2c_1V], (instregex "^F(MAX|MIN)(NM)?_ZPm[IZ]_[HSD]", + "^F(MAX|MIN)(NM)?_ZPZ[IZ]_[HSD]")>; + +// Floating point multiply +def : InstRW<[V3Write_3c_1V], (instregex "^(FSCALE|FMULX)_ZPmZ_[HSD]", + "^FMULX_ZPZZ_[HSD]", + "^FMUL_(ZPm[IZ]|ZZZI?)_[HSD]", + "^FMUL_ZPZ[IZ]_[HSD]")>; + +// Floating point multiply accumulate +def : InstRW<[V3Wr_ZFMA, ReadDefault, V3Rd_ZFMA], + (instregex "^FN?ML[AS]_ZPmZZ_[HSD]", + "^FN?(MAD|MSB)_ZPmZZ_[HSD]")>; +def : InstRW<[V3Wr_ZFMA, V3Rd_ZFMA], + (instregex "^FML[AS]_ZZZI_[HSD]", + "^FN?ML[AS]_ZPZZZ_[HSD]")>; + +// Floating point multiply add/sub accumulate long +def : InstRW<[V3Wr_ZFMAL, V3Rd_ZFMAL], (instregex "^FML[AS]L[BT]_ZZZI?_SHH")>; + +// Floating point reciprocal estimate, F16 +def : InstRW<[V3Write_6c_4V02], (instregex "^FR(ECP|SQRT)E_ZZ_H", "^FRECPX_ZPmZ_H")>; + +// Floating point reciprocal estimate, F32 +def : InstRW<[V3Write_4c_2V02], (instregex "^FR(ECP|SQRT)E_ZZ_S", "^FRECPX_ZPmZ_S")>; + +// Floating point reciprocal estimate, F64 +def : InstRW<[V3Write_3c_1V02], (instregex "^FR(ECP|SQRT)E_ZZ_D", "^FRECPX_ZPmZ_D")>; + +// Floating point reciprocal step +def : InstRW<[V3Write_4c_1V], (instregex "^F(RECPS|RSQRTS)_ZZZ_[HSD]")>; + +// Floating point reduction, F16 +def : InstRW<[V3Write_8c_4V], + (instregex "^(FADDV|FMAXNMV|FMAXV|FMINNMV|FMINV)_VPZ_H")>; + +// Floating point reduction, F32 +def : InstRW<[V3Write_6c_3V], + (instregex "^(FADDV|FMAXNMV|FMAXV|FMINNMV|FMINV)_VPZ_S")>; + +// Floating point reduction, F64 +def : InstRW<[V3Write_4c_2V], + (instregex "^(FADDV|FMAXNMV|FMAXV|FMINNMV|FMINV)_VPZ_D")>; + +// Floating point round to integral, F16 +def : InstRW<[V3Write_6c_4V02], (instregex "^FRINT[AIMNPXZ]_ZPmZ_H")>; + +// Floating point round to integral, F32 +def : InstRW<[V3Write_4c_2V02], (instregex "^FRINT[AIMNPXZ]_ZPmZ_S")>; + +// Floating point round to integral, F64 +def : InstRW<[V3Write_3c_1V02], (instregex "^FRINT[AIMNPXZ]_ZPmZ_D")>; + +// Floating point square root, F16 +def : InstRW<[V3Write_13c_1V1_8rc], (instregex "^FSQRT_ZPmZ_H")>; + +// Floating point square root, F32 +def : InstRW<[V3Write_11c_1V1_4rc], (instregex "^FSQRT_ZPmZ_S")>; + +// Floating point square root, F64 +def : InstRW<[V3Write_14c_1V1_2rc], (instregex "^FSQRT_ZPmZ_D")>; + +// Floating point trigonometric exponentiation +def : InstRW<[V3Write_3c_1V1], (instregex "^FEXPA_ZZ_[HSD]")>; + +// Floating point trigonometric multiply add +def : InstRW<[V3Write_4c_1V], (instregex "^FTMAD_ZZI_[HSD]")>; + +// Floating point trigonometric, miscellaneous +def : InstRW<[V3Write_3c_1V], (instregex "^FTS(MUL|SEL)_ZZZ_[HSD]")>; + +// §3.27 SVE BFloat16 (BF16) instructions +// ----------------------------------------------------------------------------- + +// Convert, F32 to BF16 +def : InstRW<[V3Write_4c_1V02], (instrs BFCVT_ZPmZ, BFCVTNT_ZPmZ)>; + +// Dot product +def : InstRW<[V3Wr_ZBFDOT, V3Rd_ZBFDOT], (instrs BFDOT_ZZI, BFDOT_ZZZ)>; + +// Matrix multiply accumulate +def : InstRW<[V3Wr_ZBFMMA, V3Rd_ZBFMMA], (instrs BFMMLA_ZZZ_HtoS)>; + +// Multiply accumulate long +def : InstRW<[V3Wr_ZBFMAL, V3Rd_ZBFMAL], (instregex 
"^BFMLAL[BT]_ZZZI?")>; + +// §3.28 SVE Load instructions +// ----------------------------------------------------------------------------- + +// Load vector +def : InstRW<[V3Write_6c_1L], (instrs LDR_ZXI)>; + +// Load predicate +def : InstRW<[V3Write_6c_1L_1M], (instrs LDR_PXI)>; + +// Contiguous load, scalar + imm +def : InstRW<[V3Write_6c_1L], (instregex "^LD1[BHWD]_IMM$", + "^LD1S?B_[HSD]_IMM$", + "^LD1S?H_[SD]_IMM$", + "^LD1S?W_D_IMM$" )>; +// Contiguous load, scalar + scalar +def : InstRW<[V3Write_6c_1L], (instregex "^LD1[BHWD]$", + "^LD1S?B_[HSD]$", + "^LD1S?H_[SD]$", + "^LD1S?W_D$" )>; + +// Contiguous load broadcast, scalar + imm +def : InstRW<[V3Write_6c_1L], (instregex "^LD1R[BHWD]_IMM$", + "^LD1RS?B_[HSD]_IMM$", + "^LD1RS?H_[SD]_IMM$", + "^LD1RW_D_IMM$", + "^LD1RSW_IMM$", + "^LD1RQ_[BHWD]_IMM$")>; + +// Contiguous load broadcast, scalar + scalar +def : InstRW<[V3Write_6c_1L], (instregex "^LD1RQ_[BHWD]$")>; + +// Non temporal load, scalar + imm +// Non temporal load, scalar + scalar +def : InstRW<[V3Write_6c_1L], (instregex "^LDNT1[BHWD]_ZR[IR]$")>; + +// Non temporal gather load, vector + scalar 32-bit element size +def : InstRW<[V3Write_9c_2L_4V], (instregex "^LDNT1[BHW]_ZZR_S$", + "^LDNT1S[BH]_ZZR_S$")>; + +// Non temporal gather load, vector + scalar 64-bit element size +def : InstRW<[V3Write_9c_2L_2V], (instregex "^LDNT1S?[BHW]_ZZR_D$")>; +def : InstRW<[V3Write_9c_2L_2V], (instrs LDNT1D_ZZR_D)>; + +// Contiguous first faulting load, scalar + scalar +def : InstRW<[V3Write_6c_1L_1I], (instregex "^LDFF1[BHWD]$", + "^LDFF1S?B_[HSD]$", + "^LDFF1S?H_[SD]$", + "^LDFF1S?W_D$")>; + +// Contiguous non faulting load, scalar + imm +def : InstRW<[V3Write_6c_1L], (instregex "^LDNF1[BHWD]_IMM$", + "^LDNF1S?B_[HSD]_IMM$", + "^LDNF1S?H_[SD]_IMM$", + "^LDNF1S?W_D_IMM$")>; + +// Contiguous Load two structures to two vectors, scalar + imm +def : InstRW<[V3Write_8c_2L_2V], (instregex "^LD2[BHWD]_IMM$")>; + +// Contiguous Load two structures to two vectors, scalar + scalar +def : InstRW<[V3Write_9c_2L_2V_2I], (instregex "^LD2[BHWD]$")>; + +// Contiguous Load three structures to three vectors, scalar + imm +def : InstRW<[V3Write_9c_3L_3V], (instregex "^LD3[BHWD]_IMM$")>; + +// Contiguous Load three structures to three vectors, scalar + scalar +def : InstRW<[V3Write_10c_3V_3L_3I], (instregex "^LD3[BHWD]$")>; + +// Contiguous Load four structures to four vectors, scalar + imm +def : InstRW<[V3Write_9c_4L_8V], (instregex "^LD4[BHWD]_IMM$")>; + +// Contiguous Load four structures to four vectors, scalar + scalar +def : InstRW<[V3Write_10c_4L_8V_4I], (instregex "^LD4[BHWD]$")>; + +// Gather load, vector + imm, 32-bit element size +def : InstRW<[V3Write_9c_1L_4V], (instregex "^GLD(FF)?1S?[BH]_S_IMM$", + "^GLD(FF)?1W_IMM$")>; + +// Gather load, vector + imm, 64-bit element size +def : InstRW<[V3Write_9c_1L_4V], (instregex "^GLD(FF)?1S?[BHW]_D_IMM$", + "^GLD(FF)?1D_IMM$")>; + +// Gather load, 32-bit scaled offset +def : InstRW<[V3Write_10c_1L_8V], + (instregex "^GLD(FF)?1S?H_S_[SU]XTW_SCALED$", + "^GLD(FF)?1W_[SU]XTW_SCALED")>; + +// Gather load, 64-bit scaled offset +// NOTE: These instructions are not specified in the SOG. 
+def : InstRW<[V3Write_10c_1L_4V], + (instregex "^GLD(FF)?1S?[HW]_D_([SU]XTW_)?SCALED$", + "^GLD(FF)?1D_([SU]XTW_)?SCALED$")>; + +// Gather load, 32-bit unpacked unscaled offset +def : InstRW<[V3Write_9c_1L_4V], (instregex "^GLD(FF)?1S?[BH]_S_[SU]XTW$", + "^GLD(FF)?1W_[SU]XTW$")>; + +// Gather load, 64-bit unpacked unscaled offset +// NOTE: These instructions are not specified in the SOG. +def : InstRW<[V3Write_9c_1L_2V], + (instregex "^GLD(FF)?1S?[BHW]_D(_[SU]XTW)?$", + "^GLD(FF)?1D(_[SU]XTW)?$")>; + +// §3.29 SVE Store instructions +// ----------------------------------------------------------------------------- + +// Store from predicate reg +def : InstRW<[V3Write_1c_1SA], (instrs STR_PXI)>; + +// Store from vector reg +def : InstRW<[V3Write_2c_1SA_1V01], (instrs STR_ZXI)>; + +// Contiguous store, scalar + imm +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^ST1[BHWD]_IMM$", + "^ST1B_[HSD]_IMM$", + "^ST1H_[SD]_IMM$", + "^ST1W_D_IMM$")>; + +// Contiguous store, scalar + scalar +def : InstRW<[V3Write_2c_1SA_1I_1V01], (instregex "^ST1H(_[SD])?$")>; +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^ST1[BWD]$", + "^ST1B_[HSD]$", + "^ST1W_D$")>; + +// Contiguous store two structures from two vectors, scalar + imm +def : InstRW<[V3Write_4c_1SA_1V01], (instregex "^ST2[BHWD]_IMM$")>; + +// Contiguous store two structures from two vectors, scalar + scalar +def : InstRW<[V3Write_4c_2SA_2I_2V01], (instrs ST2H)>; +def : InstRW<[V3Write_4c_2SA_2V01], (instregex "^ST2[BWD]$")>; + +// Contiguous store three structures from three vectors, scalar + imm +def : InstRW<[V3Write_7c_9SA_9V01], (instregex "^ST3[BHWD]_IMM$")>; + +// Contiguous store three structures from three vectors, scalar + scalar +def : InstRW<[V3Write_7c_9SA_9I_9V01], (instregex "^ST3[BHWD]$")>; + +// Contiguous store four structures from four vectors, scalar + imm +def : InstRW<[V3Write_11c_18SA_18V01], (instregex "^ST4[BHWD]_IMM$")>; + +// Contiguous store four structures from four vectors, scalar + scalar +def : InstRW<[V3Write_11c_18SA_18I_18V01], (instregex "^ST4[BHWD]$")>; + +// Non temporal store, scalar + imm +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^STNT1[BHWD]_ZRI$")>; + +// Non temporal store, scalar + scalar +def : InstRW<[V3Write_2c_1SA_1I_1V01], (instrs STNT1H_ZRR)>; +def : InstRW<[V3Write_2c_1SA_1V01], (instregex "^STNT1[BWD]_ZRR$")>; + +// Scatter non temporal store, vector + scalar 32-bit element size +def : InstRW<[V3Write_4c_6SA_6V01], (instregex "^STNT1[BHW]_ZZR_S")>; + +// Scatter non temporal store, vector + scalar 64-bit element size +def : InstRW<[V3Write_2c_3SA_3V01], (instregex "^STNT1[BHWD]_ZZR_D")>; + +// Scatter store vector + imm 32-bit element size +def : InstRW<[V3Write_4c_6SA_6V01], (instregex "^SST1[BH]_S_IMM$", + "^SST1W_IMM$")>; + +// Scatter store vector + imm 64-bit element size +def : InstRW<[V3Write_2c_3SA_3V01], (instregex "^SST1[BHW]_D_IMM$", + "^SST1D_IMM$")>; + +// Scatter store, 32-bit scaled offset +def : InstRW<[V3Write_4c_6SA_6V01], + (instregex "^SST1(H_S|W)_[SU]XTW_SCALED$")>; + +// Scatter store, 32-bit unpacked unscaled offset +def : InstRW<[V3Write_2c_3SA_3V01], (instregex "^SST1[BHW]_D_[SU]XTW$", + "^SST1D_[SU]XTW$")>; + +// Scatter store, 32-bit unpacked scaled offset +def : InstRW<[V3Write_2c_3SA_3V01], (instregex "^SST1[HW]_D_[SU]XTW_SCALED$", + "^SST1D_[SU]XTW_SCALED$")>; + +// Scatter store, 32-bit unscaled offset +def : InstRW<[V3Write_4c_6SA_6V01], (instregex "^SST1[BH]_S_[SU]XTW$", + "^SST1W_[SU]XTW$")>; + +// Scatter store, 64-bit scaled offset +def : 
InstRW<[V3Write_2c_3SA_3V01], (instregex "^SST1[HW]_D_SCALED$", + "^SST1D_SCALED$")>; + +// Scatter store, 64-bit unscaled offset +def : InstRW<[V3Write_2c_3SA_3V01], (instregex "^SST1[BHW]_D$", + "^SST1D$")>; + +// §3.30 SVE Miscellaneous instructions +// ----------------------------------------------------------------------------- + +// Read first fault register, unpredicated +def : InstRW<[V3Write_2c_1M0], (instrs RDFFR_P)>; + +// Read first fault register, predicated +def : InstRW<[V3Write_3or4c_1M0_1M], (instrs RDFFR_PPz)>; + +// Read first fault register and set flags +def : InstRW<[V3Write_3or4c_1M0_1M], (instrs RDFFRS_PPz)>; + +// Set first fault register +// Write to first fault register +def : InstRW<[V3Write_2c_1M0], (instrs SETFFR, WRFFR)>; + +// Prefetch +// NOTE: This is not specified in the SOG. +def : InstRW<[V3Write_4c_1L], (instregex "^PRF[BHWD]")>; + +// §3.31 SVE Cryptographic instructions +// ----------------------------------------------------------------------------- + +// Crypto AES ops +def : InstRW<[V3Write_2c_1V], (instregex "^AES[DE]_ZZZ_B$", + "^AESI?MC_ZZ_B$")>; + +// Crypto SHA3 ops +def : InstRW<[V3Write_2c_1V], (instregex "^(BCAX|EOR3)_ZZZZ$", + "^RAX1_ZZZ_D$", + "^XAR_ZZZI_[BHSD]$")>; + +// Crypto SM4 ops +def : InstRW<[V3Write_4c_1V0], (instregex "^SM4E(KEY)?_ZZZ_S$")>; + +} diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV3AE.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV3AE.td new file mode 100644 index 0000000..0f1ec66 --- /dev/null +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV3AE.td @@ -0,0 +1,2705 @@ +//=- AArch64SchedNeoverseV3AE.td - NeoverseV3AE Scheduling Defs --*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the scheduling model for the Arm Neoverse V3AE processors. +// All information is taken from the V3AE Software Optimisation guide: +// +// https://developer.arm.com/documentation/109703/300/?lang=en +// +//===----------------------------------------------------------------------===// + +def NeoverseV3AEModel : SchedMachineModel { + let IssueWidth = 10; // Expect best value to be slightly higher than V2 + let MicroOpBufferSize = 320; // Entries in micro-op re-order buffer. NOTE: Copied from Neoverse-V2 + let LoadLatency = 4; // Optimistic load latency. + let MispredictPenalty = 10; // Extra cycles for mispredicted branch. NOTE: Copied from N2. + let LoopMicroOpBufferSize = 16; // NOTE: Copied from Cortex-A57. + let CompleteModel = 1; + + list<Predicate> UnsupportedFeatures = !listconcat(SMEUnsupported.F, + [HasSVE2p1, HasSVEB16B16, + HasCPA, HasCSSC]); +} + +//===----------------------------------------------------------------------===// +// Define each kind of processor resource and number available on Neoverse V3AE. +// Instructions are first fetched and then decoded into internal macro-ops +// (MOPs). From there, the MOPs proceed through register renaming and dispatch +// stages. A MOP can be split into two micro-ops further down the pipeline +// after the decode stage. Once dispatched, micro-ops wait for their operands +// and issue out-of-order to one of nineteen issue pipelines. Each issue +// pipeline can accept one micro-op per cycle. + +let SchedModel = NeoverseV3AEModel in { + +// Define the (19) issue ports. 
+def V3AEUnitB : ProcResource<3>; // Branch 0/1/2 +def V3AEUnitS0 : ProcResource<1>; // Integer single-cycle 0 +def V3AEUnitS1 : ProcResource<1>; // Integer single-cycle 1 +def V3AEUnitS2 : ProcResource<1>; // Integer single-cycle 2 +def V3AEUnitS3 : ProcResource<1>; // Integer single-cycle 3 +def V3AEUnitS4 : ProcResource<1>; // Integer single-cycle 4 +def V3AEUnitS5 : ProcResource<1>; // Integer single-cycle 5 +def V3AEUnitM0 : ProcResource<1>; // Integer single/multicycle 0 +def V3AEUnitM1 : ProcResource<1>; // Integer single/multicycle 1 +def V3AEUnitV0 : ProcResource<1>; // FP/ASIMD 0 +def V3AEUnitV1 : ProcResource<1>; // FP/ASIMD 1 +def V3AEUnitLS0 : ProcResource<1>; // Load/Store 0 +def V3AEUnitL12 : ProcResource<2>; // Load 1/2 +def V3AEUnitST1 : ProcResource<1>; // Store 1 +def V3AEUnitD : ProcResource<2>; // Store data 0/1 +def V3AEUnitFlg : ProcResource<4>; // Flags + +def V3AEUnitS : ProcResGroup<[V3AEUnitS0, V3AEUnitS1, V3AEUnitS2, V3AEUnitS3, V3AEUnitS4, V3AEUnitS5]>; // Integer single-cycle 0/1/2/3/4/5 +def V3AEUnitI : ProcResGroup<[V3AEUnitS0, V3AEUnitS1, V3AEUnitS2, V3AEUnitS3, V3AEUnitS4, V3AEUnitS5, V3AEUnitM0, V3AEUnitM1]>; // Integer single-cycle 0/1/2/3/4/5 and single/multicycle 0/1 +def V3AEUnitM : ProcResGroup<[V3AEUnitM0, V3AEUnitM1]>; // Integer single/multicycle 0/1 +def V3AEUnitLSA : ProcResGroup<[V3AEUnitLS0, V3AEUnitL12, V3AEUnitST1]>; // Supergroup of L+SA +def V3AEUnitL : ProcResGroup<[V3AEUnitLS0, V3AEUnitL12]>; // Load/Store 0 and Load 1/2 +def V3AEUnitSA : ProcResGroup<[V3AEUnitLS0, V3AEUnitST1]>; // Load/Store 0 and Store 1 +def V3AEUnitV : ProcResGroup<[V3AEUnitV0, V3AEUnitV1]>; // FP/ASIMD 0/1 + +// Define commonly used read types. + +// No forwarding is provided for these types. +def : ReadAdvance<ReadI, 0>; +def : ReadAdvance<ReadISReg, 0>; +def : ReadAdvance<ReadIEReg, 0>; +def : ReadAdvance<ReadIM, 0>; +def : ReadAdvance<ReadIMA, 0>; +def : ReadAdvance<ReadID, 0>; +def : ReadAdvance<ReadExtrHi, 0>; +def : ReadAdvance<ReadAdrBase, 0>; +def : ReadAdvance<ReadST, 0>; +def : ReadAdvance<ReadVLD, 0>; + +// NOTE: Copied from N2. +def : WriteRes<WriteAtomic, []> { let Unsupported = 1; } +def : WriteRes<WriteBarrier, []> { let Latency = 1; } +def : WriteRes<WriteHint, []> { let Latency = 1; } +def : WriteRes<WriteLDHi, []> { let Latency = 4; } + +//===----------------------------------------------------------------------===// +// Define customized scheduler read/write types specific to the Neoverse V3AE. 
+ +//===----------------------------------------------------------------------===// + +// Define generic 0 micro-op types +def V3AEWrite_0c : SchedWriteRes<[]> { let Latency = 0; } + +// Define generic 1 micro-op types + +def V3AEWrite_1c_1B : SchedWriteRes<[V3AEUnitB]> { let Latency = 1; } +def V3AEWrite_1c_1F_1Flg : SchedWriteRes<[V3AEUnitI, V3AEUnitFlg]> { let Latency = 1; } +def V3AEWrite_1c_1I : SchedWriteRes<[V3AEUnitI]> { let Latency = 1; } +def V3AEWrite_1c_1M : SchedWriteRes<[V3AEUnitM]> { let Latency = 1; } +def V3AEWrite_1c_1SA : SchedWriteRes<[V3AEUnitSA]> { let Latency = 1; } +def V3AEWrite_2c_1M : SchedWriteRes<[V3AEUnitM]> { let Latency = 2; } +def V3AEWrite_2c_1M_1Flg : SchedWriteRes<[V3AEUnitM, V3AEUnitFlg]> { let Latency = 2; } +def V3AEWrite_3c_1M : SchedWriteRes<[V3AEUnitM]> { let Latency = 3; } +def V3AEWrite_2c_1M0 : SchedWriteRes<[V3AEUnitM0]> { let Latency = 2; } +def V3AEWrite_3c_1M0 : SchedWriteRes<[V3AEUnitM0]> { let Latency = 3; } +def V3AEWrite_4c_1M0 : SchedWriteRes<[V3AEUnitM0]> { let Latency = 4; } +def V3AEWrite_12c_1M0 : SchedWriteRes<[V3AEUnitM0]> { let Latency = 12; + let ReleaseAtCycles = [12]; } +def V3AEWrite_20c_1M0 : SchedWriteRes<[V3AEUnitM0]> { let Latency = 20; + let ReleaseAtCycles = [20]; } +def V3AEWrite_4c_1L : SchedWriteRes<[V3AEUnitL]> { let Latency = 4; } +def V3AEWrite_6c_1L : SchedWriteRes<[V3AEUnitL]> { let Latency = 6; } +def V3AEWrite_2c_1V : SchedWriteRes<[V3AEUnitV]> { let Latency = 2; } +def V3AEWrite_2c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 2; } +def V3AEWrite_3c_1V : SchedWriteRes<[V3AEUnitV]> { let Latency = 3; } +def V3AEWrite_4c_1V : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AEWrite_5c_1V : SchedWriteRes<[V3AEUnitV]> { let Latency = 5; } +def V3AEWrite_6c_1V : SchedWriteRes<[V3AEUnitV]> { let Latency = 6; } +def V3AEWrite_12c_1V : SchedWriteRes<[V3AEUnitV]> { let Latency = 12; } +def V3AEWrite_3c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 3; } +def V3AEWrite_4c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AEWrite_9c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 9; } +def V3AEWrite_10c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 10; } +def V3AEWrite_8c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 8; } +def V3AEWrite_12c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 12; + let ReleaseAtCycles = [11]; } +def V3AEWrite_13c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 13; } +def V3AEWrite_15c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 15; } +def V3AEWrite_13c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 13; + let ReleaseAtCycles = [8]; } +def V3AEWrite_16c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 16; } +def V3AEWrite_20c_1V0 : SchedWriteRes<[V3AEUnitV0]> { let Latency = 20; + let ReleaseAtCycles = [20]; } +def V3AEWrite_2c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 2; } +def V3AEWrite_3c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 3; } +def V3AEWrite_4c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 4; } +def V3AEWrite_6c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 6; } +def V3AEWrite_10c_1V1 : SchedWriteRes<[V3AEUnitV1]> { let Latency = 10; } +def V3AEWrite_6c_1SA : SchedWriteRes<[V3AEUnitSA]> { let Latency = 6; } + +//===----------------------------------------------------------------------===// +// Define generic 2 micro-op types + +def V3AEWrite_1c_1B_1S : SchedWriteRes<[V3AEUnitB, V3AEUnitS]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_1M0_1B : SchedWriteRes<[V3AEUnitM0, V3AEUnitB]> { + let 
Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_9c_1M0_1L : SchedWriteRes<[V3AEUnitM0, V3AEUnitL]> { + let Latency = 9; + let NumMicroOps = 2; +} + +def V3AEWrite_3c_1I_1M : SchedWriteRes<[V3AEUnitI, V3AEUnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3AEWrite_1c_2M : SchedWriteRes<[V3AEUnitM, V3AEUnitM]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3AEWrite_3c_2M : SchedWriteRes<[V3AEUnitM, V3AEUnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3AEWrite_4c_2M : SchedWriteRes<[V3AEUnitM, V3AEUnitM]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3AEWrite_5c_1L_1I : SchedWriteRes<[V3AEUnitL, V3AEUnitI]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_1I_1L : SchedWriteRes<[V3AEUnitI, V3AEUnitL]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_7c_1I_1L : SchedWriteRes<[V3AEUnitI, V3AEUnitL]> { + let Latency = 7; + let NumMicroOps = 2; +} + +def V3AEWrite_1c_1SA_1D : SchedWriteRes<[V3AEUnitSA, V3AEUnitD]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3AEWrite_5c_1M0_1V : SchedWriteRes<[V3AEUnitM0, V3AEUnitV]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3AEWrite_2c_1SA_1V : SchedWriteRes<[V3AEUnitSA, V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3AEWrite_2c_2V : SchedWriteRes<[V3AEUnitV, V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3AEWrite_5c_1V1_1V : SchedWriteRes<[V3AEUnitV1, V3AEUnitV]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3AEWrite_4c_2V0 : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3AEWrite_4c_2V : SchedWriteRes<[V3AEUnitV, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_2V : SchedWriteRes<[V3AEUnitV, V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_2L : SchedWriteRes<[V3AEUnitL, V3AEUnitL]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_8c_1L_1V : SchedWriteRes<[V3AEUnitL, V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 2; +} + +def V3AEWrite_4c_1SA_1V : SchedWriteRes<[V3AEUnitSA, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3AEWrite_3c_1M0_1M : SchedWriteRes<[V3AEUnitM0, V3AEUnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3AEWrite_4c_1M0_1M : SchedWriteRes<[V3AEUnitM0, V3AEUnitM]> { + let Latency = 4; + let NumMicroOps = 2; +} + +def V3AEWrite_1c_1M0_1M : SchedWriteRes<[V3AEUnitM0, V3AEUnitM]> { + let Latency = 1; + let NumMicroOps = 2; +} + +def V3AEWrite_2c_1M0_1M : SchedWriteRes<[V3AEUnitM0, V3AEUnitM]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_2V1 : SchedWriteRes<[V3AEUnitV1, V3AEUnitV1]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_5c_2V0 : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3AEWrite_5c_1V1_1M0 : SchedWriteRes<[V3AEUnitV1, V3AEUnitM0]> { + let Latency = 5; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_1V1_1M0 : SchedWriteRes<[V3AEUnitV1, V3AEUnitM0]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_7c_1M0_1V0 : SchedWriteRes<[V3AEUnitM0, V3AEUnitV0]> { + let Latency = 7; + let NumMicroOps = 2; +} + +def V3AEWrite_2c_1V0_1M : SchedWriteRes<[V3AEUnitV0, V3AEUnitM]> { + let Latency = 2; + let NumMicroOps = 2; +} + +def V3AEWrite_3c_1V0_1M : SchedWriteRes<[V3AEUnitV0, V3AEUnitM]> { + let Latency = 3; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_1V_1V1 : SchedWriteRes<[V3AEUnitV, V3AEUnitV1]> { + let Latency = 6; + let NumMicroOps = 
2; +} + +def V3AEWrite_6c_1L_1M : SchedWriteRes<[V3AEUnitL, V3AEUnitM]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_6c_1L_1I : SchedWriteRes<[V3AEUnitL, V3AEUnitI]> { + let Latency = 6; + let NumMicroOps = 2; +} + +def V3AEWrite_8c_1M0_1V : SchedWriteRes<[V3AEUnitM0, V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 2; +} + +//===----------------------------------------------------------------------===// +// Define generic 3 micro-op types + +def V3AEWrite_1c_1SA_1D_1I : SchedWriteRes<[V3AEUnitSA, V3AEUnitD, V3AEUnitI]> { + let Latency = 1; + let NumMicroOps = 3; +} + +def V3AEWrite_2c_1SA_1V_1I : SchedWriteRes<[V3AEUnitSA, V3AEUnitV, V3AEUnitI]> { + let Latency = 2; + let NumMicroOps = 3; +} + +def V3AEWrite_2c_1SA_2V : SchedWriteRes<[V3AEUnitSA, V3AEUnitV, V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 3; +} + +def V3AEWrite_4c_1SA_2V : SchedWriteRes<[V3AEUnitSA, V3AEUnitV, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 3; +} + +def V3AEWrite_9c_1L_2V : SchedWriteRes<[V3AEUnitL, V3AEUnitV, V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 3; +} + +def V3AEWrite_4c_3V : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 3; +} + +def V3AEWrite_7c_1M_1M0_1V : SchedWriteRes<[V3AEUnitM, V3AEUnitM0, V3AEUnitV]> { + let Latency = 7; + let NumMicroOps = 3; +} + +def V3AEWrite_2c_1SA_1I_1V : SchedWriteRes<[V3AEUnitSA, V3AEUnitI, V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 3; +} + +def V3AEWrite_6c_3L : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL]> { + let Latency = 6; + let NumMicroOps = 3; +} + +def V3AEWrite_6c_3V : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 3; +} + +def V3AEWrite_8c_1L_2V : SchedWriteRes<[V3AEUnitL, V3AEUnitV, V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 3; +} + +//===----------------------------------------------------------------------===// +// Define generic 4 micro-op types + +def V3AEWrite_2c_1SA_2V_1I : SchedWriteRes<[V3AEUnitSA, V3AEUnitV, V3AEUnitV, + V3AEUnitI]> { + let Latency = 2; + let NumMicroOps = 4; +} + +def V3AEWrite_5c_1I_3L : SchedWriteRes<[V3AEUnitI, V3AEUnitL, V3AEUnitL, V3AEUnitL]> { + let Latency = 5; + let NumMicroOps = 4; +} + +def V3AEWrite_6c_4V0 : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0, V3AEUnitV0, V3AEUnitV0]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3AEWrite_8c_4V : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3AEWrite_6c_2V_2V1 : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV1, + V3AEUnitV1]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3AEWrite_6c_4V : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3AEWrite_8c_2L_2V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitV, V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3AEWrite_9c_2L_2V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitV, V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 4; +} + +def V3AEWrite_2c_2SA_2V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitV, + V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 4; +} + +def V3AEWrite_4c_2SA_2V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitV, + V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 4; +} + +def V3AEWrite_8c_2M0_2V0 : SchedWriteRes<[V3AEUnitM0, V3AEUnitM0, V3AEUnitV0, + V3AEUnitV0]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3AEWrite_8c_2V_2V1 : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV1, 
+ V3AEUnitV1]> { + let Latency = 8; + let NumMicroOps = 4; +} + +def V3AEWrite_4c_2M0_2M : SchedWriteRes<[V3AEUnitM0, V3AEUnitM0, V3AEUnitM, + V3AEUnitM]> { + let Latency = 4; + let NumMicroOps = 4; +} + +def V3AEWrite_5c_2M0_2M : SchedWriteRes<[V3AEUnitM0, V3AEUnitM0, V3AEUnitM, + V3AEUnitM]> { + let Latency = 5; + let NumMicroOps = 4; +} + +def V3AEWrite_6c_2I_2L : SchedWriteRes<[V3AEUnitI, V3AEUnitI, V3AEUnitL, V3AEUnitL]> { + let Latency = 6; + let NumMicroOps = 4; +} + +def V3AEWrite_7c_4L : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, V3AEUnitL]> { + let Latency = 7; + let NumMicroOps = 4; +} + +def V3AEWrite_6c_1SA_3V : SchedWriteRes<[V3AEUnitSA, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 4; +} + +//===----------------------------------------------------------------------===// +// Define generic 5 micro-op types + +def V3AEWrite_2c_1SA_2V_2I : SchedWriteRes<[V3AEUnitSA, V3AEUnitV, V3AEUnitV, + V3AEUnitI, V3AEUnitI]> { + let Latency = 2; + let NumMicroOps = 5; +} + +def V3AEWrite_8c_2L_3V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 5; +} + +def V3AEWrite_9c_1L_4V : SchedWriteRes<[V3AEUnitL, V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 5; +} + +def V3AEWrite_10c_1L_4V : SchedWriteRes<[V3AEUnitL, V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 10; + let NumMicroOps = 5; +} + +def V3AEWrite_6c_5V : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 5; +} + +//===----------------------------------------------------------------------===// +// Define generic 6 micro-op types + +def V3AEWrite_8c_3L_3V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 6; +} + +def V3AEWrite_9c_3L_3V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3AEWrite_9c_2L_4V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3AEWrite_9c_2L_2V_2I : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitV, + V3AEUnitV, V3AEUnitI, V3AEUnitI]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3AEWrite_9c_2V_4V1 : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV1, + V3AEUnitV1, V3AEUnitV1, V3AEUnitV1]> { + let Latency = 9; + let NumMicroOps = 6; +} + +def V3AEWrite_2c_3SA_3V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 6; +} + +def V3AEWrite_4c_2SA_4V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 6; +} + +def V3AEWrite_5c_2SA_4V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 5; + let NumMicroOps = 6; +} + +def V3AEWrite_4c_2SA_2I_2V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitI, + V3AEUnitI, V3AEUnitV, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 6; +} + +//===----------------------------------------------------------------------===// +// Define generic 7 micro-op types + +def V3AEWrite_8c_3L_4V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 7; +} + 
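The names of these generic write types encode the result latency and the per-port micro-op breakdown: V3AEWrite_8c_3L_4V, for instance, is an 8-cycle result issued as three L micro-ops plus four V micro-ops (seven in total). A minimal sketch of how such a type is defined and later consumed by an InstRW mapping; the _example name is illustrative, and the STNT1 scatter-store regex simply mirrors the corresponding V3 entry earlier in this diff:

def V3AEWrite_2c_3SA_3V_example : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA,
                                                 V3AEUnitV, V3AEUnitV, V3AEUnitV]> {
  let Latency = 2;      // result latency in cycles
  let NumMicroOps = 6;  // three SA micro-ops plus three V micro-ops
}
def : InstRW<[V3AEWrite_2c_3SA_3V_example], (instregex "^STNT1[BHWD]_ZZR_D")>;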
+//===----------------------------------------------------------------------===// +// Define generic 8 micro-op types + +def V3AEWrite_2c_4SA_4V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 2; + let NumMicroOps = 8; +} + +def V3AEWrite_4c_4SA_4V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV]> { + let Latency = 4; + let NumMicroOps = 8; +} + +def V3AEWrite_6c_2SA_6V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 8; +} + +def V3AEWrite_8c_4L_4V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 8; + let NumMicroOps = 8; +} + +//===----------------------------------------------------------------------===// +// Define generic 9 micro-op types + +def V3AEWrite_6c_3SA_6V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 6; + let NumMicroOps = 9; +} + +def V3AEWrite_10c_1L_8V : SchedWriteRes<[V3AEUnitL, V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 10; + let NumMicroOps = 9; +} + +def V3AEWrite_10c_3V_3L_3I : SchedWriteRes<[V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitI, V3AEUnitI, V3AEUnitI]> { + let Latency = 10; + let NumMicroOps = 9; +} + +//===----------------------------------------------------------------------===// +// Define generic 10 micro-op types + +def V3AEWrite_9c_6L_4V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitL, V3AEUnitL, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 10; +} + +//===----------------------------------------------------------------------===// +// Define generic 12 micro-op types + +def V3AEWrite_5c_4SA_8V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 5; + let NumMicroOps = 12; +} + +def V3AEWrite_9c_4L_8V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitL, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 9; + let NumMicroOps = 12; +} + +def V3AEWrite_10c_4L_8V : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitL, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 10; + let NumMicroOps = 12; +} + +//===----------------------------------------------------------------------===// +// Define generic 16 micro-op types + +def V3AEWrite_7c_4SA_12V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 7; + let NumMicroOps = 16; +} + +def V3AEWrite_10c_4L_8V_4I : SchedWriteRes<[V3AEUnitL, V3AEUnitL, V3AEUnitL, + V3AEUnitL, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI]> { + let Latency = 10; + let NumMicroOps = 16; +} + +//===----------------------------------------------------------------------===// +// Define generic 18 micro-op types + +def 
V3AEWrite_7c_9SA_9V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 7; + let NumMicroOps = 18; +} + +//===----------------------------------------------------------------------===// +// Define generic 27 micro-op types + +def V3AEWrite_7c_9SA_9I_9V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 7; + let NumMicroOps = 27; +} + +//===----------------------------------------------------------------------===// +// Define generic 36 micro-op types + +def V3AEWrite_11c_18SA_18V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, V3AEUnitSA, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV]> { + let Latency = 11; + let NumMicroOps = 36; +} + +//===----------------------------------------------------------------------===// +// Define generic 54 micro-op types + +def V3AEWrite_11c_18SA_18I_18V : SchedWriteRes<[V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitSA, V3AEUnitSA, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitI, V3AEUnitI, V3AEUnitI, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, V3AEUnitV, + V3AEUnitV, V3AEUnitV, + V3AEUnitV]> { + let Latency = 11; + let NumMicroOps = 54; +} + +//===----------------------------------------------------------------------===// +// Define predicate-controlled types + +def V3AEWrite_ArithI : SchedWriteVariant<[ + SchedVar<IsCheapLSL, [V3AEWrite_1c_1I]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1M]>]>; + +def V3AEWrite_ArithF : SchedWriteVariant<[ + SchedVar<IsCheapLSL, [V3AEWrite_1c_1F_1Flg]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1M_1Flg]>]>; + +def V3AEWrite_Logical : SchedWriteVariant<[ + SchedVar<NeoverseNoLSL, [V3AEWrite_1c_1F_1Flg]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1M_1Flg]>]>; + +def V3AEWrite_Extr : SchedWriteVariant<[ + SchedVar<IsRORImmIdiomPred, [V3AEWrite_1c_1I]>, + SchedVar<NoSchedPred, [V3AEWrite_3c_1I_1M]>]>; + +def V3AEWrite_LdrHQ : SchedWriteVariant<[ + SchedVar<NeoverseHQForm, [V3AEWrite_7c_1I_1L]>, + SchedVar<NoSchedPred, [V3AEWrite_6c_1L]>]>; + +def V3AEWrite_StrHQ : SchedWriteVariant<[ + SchedVar<NeoverseHQForm, [V3AEWrite_2c_1SA_1V_1I]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1SA_1V]>]>; + +def V3AEWrite_0or1c_1I : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V3AEWrite_0c]>, + SchedVar<NoSchedPred, [V3AEWrite_1c_1I]>]>; + +def V3AEWrite_0or2c_1V : SchedWriteVariant<[ + 
SchedVar<NeoverseZeroMove, [V3AEWrite_0c]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1V]>]>; + +def V3AEWrite_0or3c_1M0 : SchedWriteVariant<[ + SchedVar<NeoverseZeroMove, [V3AEWrite_0c]>, + SchedVar<NoSchedPred, [V3AEWrite_3c_1M0]>]>; + +def V3AEWrite_2or3c_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3AEWrite_3c_1M]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1M]>]>; + +def V3AEWrite_1or2c_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3AEWrite_2c_1M]>, + SchedVar<NoSchedPred, [V3AEWrite_1c_1M]>]>; + +def V3AEWrite_3or4c_1M0_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3AEWrite_4c_1M0_1M]>, + SchedVar<NoSchedPred, [V3AEWrite_3c_1M0_1M]>]>; + +def V3AEWrite_2or3c_1V0 : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3AEWrite_3c_1V0]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1V0]>]>; + +def V3AEWrite_2or3c_1V0_1M : SchedWriteVariant<[ + SchedVar<NeoversePdIsPg, [V3AEWrite_3c_1V0_1M]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1V0_1M]>]>; + +def V3AEWrite_IncDec : SchedWriteVariant<[ + SchedVar<NeoverseCheapIncDec, [V3AEWrite_1c_1I]>, + SchedVar<NoSchedPred, [V3AEWrite_2c_1M]>]>; + +//===----------------------------------------------------------------------===// +// Define forwarded types + +// NOTE: SOG, p. 16, n. 2: Accumulator forwarding is not supported for +// consumers of 64 bit multiply high operations? +def V3AEWr_IM : SchedWriteRes<[V3AEUnitM]> { let Latency = 2; } + +def V3AEWr_FMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_FMA : SchedReadAdvance<2, [WriteFMul, V3AEWr_FMA]>; + +def V3AEWr_VA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_VA : SchedReadAdvance<3, [V3AEWr_VA]>; + +def V3AEWr_VDOT : SchedWriteRes<[V3AEUnitV]> { let Latency = 3; } +def V3AERd_VDOT : SchedReadAdvance<2, [V3AEWr_VDOT]>; + +def V3AEWr_VMMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 3; } +def V3AERd_VMMA : SchedReadAdvance<2, [V3AEWr_VMMA]>; + +def V3AEWr_VMA : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AERd_VMA : SchedReadAdvance<3, [V3AEWr_VMA]>; + +def V3AEWr_VMAH : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0]> { let Latency = 4; } +def V3AERd_VMAH : SchedReadAdvance<2, [V3AEWr_VMAH]>; + +def V3AEWr_VMAL : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AERd_VMAL : SchedReadAdvance<3, [V3AEWr_VMAL]>; + +def V3AEWr_VPA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_VPA : SchedReadAdvance<3, [V3AEWr_VPA]>; + +def V3AEWr_VSA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_VSA : SchedReadAdvance<3, [V3AEWr_VSA]>; + +def V3AEWr_VFCMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_VFCMA : SchedReadAdvance<2, [V3AEWr_VFCMA]>; + +def V3AEWr_VFM : SchedWriteRes<[V3AEUnitV]> { let Latency = 3; } +def V3AEWr_VFMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_VFMA : SchedReadAdvance<2, [V3AEWr_VFM, V3AEWr_VFMA]>; + +def V3AEWr_VFMAL : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_VFMAL : SchedReadAdvance<2, [V3AEWr_VFMAL]>; + +def V3AEWr_VBFDOT : SchedWriteRes<[V3AEUnitV]> { let Latency = 5; } +def V3AERd_VBFDOT : SchedReadAdvance<2, [V3AEWr_VBFDOT]>; +def V3AEWr_VBFMMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 6; } +def V3AERd_VBFMMA : SchedReadAdvance<2, [V3AEWr_VBFMMA]>; +def V3AEWr_VBFMAL : SchedWriteRes<[V3AEUnitV]> { let Latency = 5; } +def V3AERd_VBFMAL : SchedReadAdvance<3, [V3AEWr_VBFMAL]>; + +def V3AEWr_CRC : SchedWriteRes<[V3AEUnitM0]> { let Latency = 2; } +def V3AERd_CRC : SchedReadAdvance<1, [V3AEWr_CRC]>; + +def V3AEWr_ZA : SchedWriteRes<[V3AEUnitV]> { let 
Latency = 4; } +def V3AERd_ZA : SchedReadAdvance<3, [V3AEWr_ZA]>; +def V3AEWr_ZPA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_ZPA : SchedReadAdvance<3, [V3AEWr_ZPA]>; +def V3AEWr_ZSA : SchedWriteRes<[V3AEUnitV1]> { let Latency = 4; } +def V3AERd_ZSA : SchedReadAdvance<3, [V3AEWr_ZSA]>; + +def V3AEWr_ZDOTB : SchedWriteRes<[V3AEUnitV]> { let Latency = 3; } +def V3AERd_ZDOTB : SchedReadAdvance<2, [V3AEWr_ZDOTB]>; +def V3AEWr_ZDOTH : SchedWriteRes<[V3AEUnitV0]> { let Latency = 3; } +def V3AERd_ZDOTH : SchedReadAdvance<2, [V3AEWr_ZDOTH]>; + +// NOTE: SOG p. 43: Complex multiply-add B, H, S element size: How to reduce +// throughput to 1 in case of forwarding? +def V3AEWr_ZCMABHS : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AERd_ZCMABHS : SchedReadAdvance<3, [V3AEWr_ZCMABHS]>; +def V3AEWr_ZCMAD : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0]> { let Latency = 5; } +def V3AERd_ZCMAD : SchedReadAdvance<2, [V3AEWr_ZCMAD]>; + +def V3AEWr_ZMMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 3; } +def V3AERd_ZMMA : SchedReadAdvance<2, [V3AEWr_ZMMA]>; + +def V3AEWr_ZMABHS : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AERd_ZMABHS : SchedReadAdvance<3, [V3AEWr_ZMABHS]>; +def V3AEWr_ZMAD : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0]> { let Latency = 5; } +def V3AERd_ZMAD : SchedReadAdvance<2, [V3AEWr_ZMAD]>; + +def V3AEWr_ZMAL : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AERd_ZMAL : SchedReadAdvance<3, [V3AEWr_ZMAL]>; + +def V3AEWr_ZMASQL : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AEWr_ZMASQBHS : SchedWriteRes<[V3AEUnitV0]> { let Latency = 4; } +def V3AEWr_ZMASQD : SchedWriteRes<[V3AEUnitV0, V3AEUnitV0]> { let Latency = 5; } +def V3AERd_ZMASQ : SchedReadAdvance<2, [V3AEWr_ZMASQL, V3AEWr_ZMASQBHS, + V3AEWr_ZMASQD]>; + +def V3AEWr_ZFCMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 5; } +def V3AERd_ZFCMA : SchedReadAdvance<3, [V3AEWr_ZFCMA]>; + +def V3AEWr_ZFMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_ZFMA : SchedReadAdvance<2, [V3AEWr_ZFMA]>; + +def V3AEWr_ZFMAL : SchedWriteRes<[V3AEUnitV]> { let Latency = 4; } +def V3AERd_ZFMAL : SchedReadAdvance<2, [V3AEWr_ZFMAL]>; + +def V3AEWr_ZBFDOT : SchedWriteRes<[V3AEUnitV]> { let Latency = 5; } +def V3AERd_ZBFDOT : SchedReadAdvance<2, [V3AEWr_ZBFDOT]>; +def V3AEWr_ZBFMMA : SchedWriteRes<[V3AEUnitV]> { let Latency = 6; } +def V3AERd_ZBFMMA : SchedReadAdvance<2, [V3AEWr_ZBFMMA]>; +def V3AEWr_ZBFMAL : SchedWriteRes<[V3AEUnitV]> { let Latency = 5; } +def V3AERd_ZBFMAL : SchedReadAdvance<3, [V3AEWr_ZBFMAL]>; + +//===----------------------------------------------------------------------===// +// Define types with long resource cycles (rc) + +def V3AEWrite_6c_1V1_5rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 6; let ReleaseAtCycles = [ 5]; } +def V3AEWrite_9c_1V1_2rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 9; let ReleaseAtCycles = [ 2]; } +def V3AEWrite_9c_1V1_4rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 9; let ReleaseAtCycles = [ 4]; } +def V3AEWrite_10c_1V1_9rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 10; let ReleaseAtCycles = [ 9]; } +def V3AEWrite_11c_1V1_4rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 11; let ReleaseAtCycles = [ 4]; } +def V3AEWrite_13c_1V1_8rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 13; let ReleaseAtCycles = [8]; } +def V3AEWrite_14c_1V1_2rc : SchedWriteRes<[V3AEUnitV1]> { let Latency = 14; let ReleaseAtCycles = [2]; } + +// Miscellaneous +// ----------------------------------------------------------------------------- + +def : InstRW<[WriteI], 
(instrs COPY)>; + +// §3.3 Branch instructions +// ----------------------------------------------------------------------------- + +// Branch, immed +// Compare and branch +def : SchedAlias<WriteBr, V3AEWrite_1c_1B>; + +// Branch, register +def : SchedAlias<WriteBrReg, V3AEWrite_1c_1B>; + +// Branch and link, immed +// Branch and link, register +def : InstRW<[V3AEWrite_1c_1B_1S], (instrs BL, BLR)>; + +// §3.4 Arithmetic and Logical Instructions +// ----------------------------------------------------------------------------- + +// ALU, basic +def : SchedAlias<WriteI, V3AEWrite_1c_1I>; + +// ALU, basic, flagset +def : InstRW<[V3AEWrite_1c_1F_1Flg], + (instregex "^(ADD|SUB)S[WX]r[ir]$", + "^(ADC|SBC)S[WX]r$", + "^ANDS[WX]ri$", + "^(AND|BIC)S[WX]rr$")>; +def : InstRW<[V3AEWrite_0or1c_1I], (instregex "^MOVZ[WX]i$")>; + +// ALU, extend and shift +def : SchedAlias<WriteIEReg, V3AEWrite_2c_1M>; + +// Arithmetic, LSL shift, shift <= 4 +// Arithmetic, flagset, LSL shift, shift <= 4 +// Arithmetic, LSR/ASR/ROR shift or LSL shift > 4 +def : SchedAlias<WriteISReg, V3AEWrite_ArithI>; +def : InstRW<[V3AEWrite_ArithF], + (instregex "^(ADD|SUB)S[WX]rs$")>; + +// Arithmetic, immediate to logical address tag +def : InstRW<[V3AEWrite_2c_1M], (instrs ADDG, SUBG)>; + +// Conditional compare +def : InstRW<[V3AEWrite_1c_1F_1Flg], (instregex "^CCM[NP][WX][ir]")>; + +// Convert floating-point condition flags +// Flag manipulation instructions +def : WriteRes<WriteSys, []> { let Latency = 1; } + +// Insert Random Tags +def : InstRW<[V3AEWrite_2c_1M], (instrs IRG, IRGstack)>; + +// Insert Tag Mask +// Subtract Pointer +def : InstRW<[V3AEWrite_1c_1I], (instrs GMI, SUBP)>; + +// Subtract Pointer, flagset +def : InstRW<[V3AEWrite_1c_1F_1Flg], (instrs SUBPS)>; + +// Logical, shift, no flagset +def : InstRW<[V3AEWrite_1c_1I], (instregex "^(AND|BIC|EON|EOR|ORN)[WX]rs$")>; +def : InstRW<[V3AEWrite_0or1c_1I], (instregex "^ORR[WX]rs$")>; + +// Logical, shift, flagset +def : InstRW<[V3AEWrite_Logical], (instregex "^(AND|BIC)S[WX]rs$")>; + +// Move and shift instructions +// ----------------------------------------------------------------------------- + +def : SchedAlias<WriteImm, V3AEWrite_1c_1I>; + +// §3.5 Divide and multiply instructions +// ----------------------------------------------------------------------------- + +// SDIV, UDIV +def : SchedAlias<WriteID32, V3AEWrite_12c_1M0>; +def : SchedAlias<WriteID64, V3AEWrite_20c_1M0>; + +def : SchedAlias<WriteIM32, V3AEWrite_2c_1M>; +def : SchedAlias<WriteIM64, V3AEWrite_2c_1M>; + +// Multiply +// Multiply accumulate, W-form +// Multiply accumulate, X-form +def : InstRW<[V3AEWr_IM], (instregex "^M(ADD|SUB)[WX]rrr$")>; + +// Multiply accumulate long +// Multiply long +def : InstRW<[V3AEWr_IM], (instregex "^(S|U)M(ADD|SUB)Lrrr$")>; + +// Multiply high +def : InstRW<[V3AEWrite_3c_1M], (instrs SMULHrr, UMULHrr)>; + +// §3.6 Pointer Authentication Instructions (v8.3 PAC) +// ----------------------------------------------------------------------------- + +// Authenticate data address +// Authenticate instruction address +// Compute pointer authentication code for data address +// Compute pointer authentication code, using generic key +// Compute pointer authentication code for instruction address +def : InstRW<[V3AEWrite_4c_1M0], (instregex "^AUT", "^PAC")>; + +// Branch and link, register, with pointer authentication +// Branch, register, with pointer authentication +// Branch, return, with pointer authentication +def : InstRW<[V3AEWrite_6c_1M0_1B], (instrs BLRAA, BLRAAZ, BLRAB, 
BLRABZ, BRAA, + BRAAZ, BRAB, BRABZ, RETAA, RETAB, + ERETAA, ERETAB)>; + + +// Load register, with pointer authentication +def : InstRW<[V3AEWrite_9c_1M0_1L], (instregex "^LDRA[AB](indexed|writeback)")>; + +// Strip pointer authentication code +def : InstRW<[V3AEWrite_2c_1M0], (instrs XPACD, XPACI, XPACLRI)>; + +// §3.7 Miscellaneous data-processing instructions +// ----------------------------------------------------------------------------- + +// Address generation +def : InstRW<[V3AEWrite_1c_1I], (instrs ADR, ADRP)>; + +// Bitfield extract, one reg +// Bitfield extract, two regs +def : SchedAlias<WriteExtr, V3AEWrite_Extr>; +def : InstRW<[V3AEWrite_Extr], (instrs EXTRWrri, EXTRXrri)>; + +// Bitfield move, basic +def : SchedAlias<WriteIS, V3AEWrite_1c_1I>; + +// Bitfield move, insert +def : InstRW<[V3AEWrite_2c_1M], (instregex "^BFM[WX]ri$")>; + +// §3.8 Load instructions +// ----------------------------------------------------------------------------- + +// NOTE: SOG p. 19: Throughput of LDN?P X-form should be 2, but reported as 3. + +def : SchedAlias<WriteLD, V3AEWrite_4c_1L>; +def : SchedAlias<WriteLDIdx, V3AEWrite_4c_1L>; + +// Load register, literal +def : InstRW<[V3AEWrite_5c_1L_1I], (instrs LDRWl, LDRXl, LDRSWl, PRFMl)>; + +// Load pair, signed immed offset, signed words +def : InstRW<[V3AEWrite_5c_1I_3L, WriteLDHi], (instrs LDPSWi)>; + +// Load pair, immed post-index or immed pre-index, signed words +def : InstRW<[WriteAdr, V3AEWrite_5c_1I_3L, WriteLDHi], + (instregex "^LDPSW(post|pre)$")>; + +// §3.9 Store instructions +// ----------------------------------------------------------------------------- + +// NOTE: SOG, p. 20: Unsure if STRH uses pipeline I. + +def : SchedAlias<WriteST, V3AEWrite_1c_1SA_1D>; +def : SchedAlias<WriteSTIdx, V3AEWrite_1c_1SA_1D>; +def : SchedAlias<WriteSTP, V3AEWrite_1c_1SA_1D>; +def : SchedAlias<WriteAdr, V3AEWrite_1c_1I>; + +// §3.10 Tag load instructions +// ----------------------------------------------------------------------------- + +// Load allocation tag +// Load multiple allocation tags +def : InstRW<[V3AEWrite_4c_1L], (instrs LDG, LDGM)>; + +// §3.11 Tag store instructions +// ----------------------------------------------------------------------------- + +// Store allocation tags to one or two granules, post-index +// Store allocation tags to one or two granules, pre-index +// Store allocation tag to one or two granules, zeroing, post-index +// Store Allocation Tag to one or two granules, zeroing, pre-index +// Store allocation tag and reg pair to memory, post-Index +// Store allocation tag and reg pair to memory, pre-Index +def : InstRW<[V3AEWrite_1c_1SA_1D_1I], (instrs STGPreIndex, STGPostIndex, + ST2GPreIndex, ST2GPostIndex, + STZGPreIndex, STZGPostIndex, + STZ2GPreIndex, STZ2GPostIndex, + STGPpre, STGPpost)>; + +// Store allocation tags to one or two granules, signed offset +// Store allocation tag to two granules, zeroing, signed offset +// Store allocation tag and reg pair to memory, signed offset +// Store multiple allocation tags +def : InstRW<[V3AEWrite_1c_1SA_1D], (instrs STGi, ST2Gi, STZGi, + STZ2Gi, STGPi, STGM, STZGM)>; + +// §3.12 FP data processing instructions +// ----------------------------------------------------------------------------- + +// FP absolute value +// FP arithmetic +// FP min/max +// FP negate +// FP select +def : SchedAlias<WriteF, V3AEWrite_2c_1V>; + +// FP compare +def : SchedAlias<WriteFCmp, V3AEWrite_2c_1V0>; + +// FP divide, square root +def : SchedAlias<WriteFDiv, V3AEWrite_6c_1V1>; + +// FP divide, 
H-form +def : InstRW<[V3AEWrite_6c_1V1], (instrs FDIVHrr)>; +// FP divide, S-form +def : InstRW<[V3AEWrite_8c_1V1], (instrs FDIVSrr)>; +// FP divide, D-form +def : InstRW<[V3AEWrite_13c_1V1], (instrs FDIVDrr)>; + +// FP square root, H-form +def : InstRW<[V3AEWrite_6c_1V1], (instrs FSQRTHr)>; +// FP square root, S-form +def : InstRW<[V3AEWrite_8c_1V1], (instrs FSQRTSr)>; +// FP square root, D-form +def : InstRW<[V3AEWrite_13c_1V1], (instrs FSQRTDr)>; + +// FP multiply +def : WriteRes<WriteFMul, [V3AEUnitV]> { let Latency = 3; } + +// FP multiply accumulate +def : InstRW<[V3AEWr_FMA, ReadDefault, ReadDefault, V3AERd_FMA], + (instregex "^FN?M(ADD|SUB)[HSD]rrr$")>; + +// FP round to integral +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FRINT[AIMNPXZ][HSD]r$", + "^FRINT(32|64)[XZ][SD]r$")>; + +// §3.13 FP miscellaneous instructions +// ----------------------------------------------------------------------------- + +// FP convert, from gen to vec reg +def : InstRW<[V3AEWrite_3c_1M0], (instregex "^[SU]CVTF[SU][WX][HSD]ri$")>; + +// FP convert, from vec to gen reg +def : InstRW<[V3AEWrite_3c_1V0], + (instregex "^FCVT[AMNPZ][SU][SU][WX][HSD]ri?$")>; + +// FP convert, Javascript from vec to gen reg +def : SchedAlias<WriteFCvt, V3AEWrite_3c_1V0>; + +// FP convert, from vec to vec reg +def : InstRW<[V3AEWrite_3c_1V], (instrs FCVTSHr, FCVTDHr, FCVTHSr, FCVTDSr, + FCVTHDr, FCVTSDr, FCVTXNv1i64)>; + +// FP move, immed +// FP move, register +def : SchedAlias<WriteFImm, V3AEWrite_2c_1V>; + +// FP transfer, from gen to low half of vec reg +def : InstRW<[V3AEWrite_0or3c_1M0], + (instrs FMOVWHr, FMOVXHr, FMOVWSr, FMOVXDr)>; + +// FP transfer, from gen to high half of vec reg +def : InstRW<[V3AEWrite_5c_1M0_1V], (instrs FMOVXDHighr)>; + +// FP transfer, from vec to gen reg +def : SchedAlias<WriteFCopy, V3AEWrite_2c_2V>; + +// §3.14 FP load instructions +// ----------------------------------------------------------------------------- + +// Load vector reg, literal, S/D/Q forms +def : InstRW<[V3AEWrite_7c_1I_1L], (instregex "^LDR[SDQ]l$")>; + +// Load vector reg, unscaled immed +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LDUR[BHSDQ]i$")>; + +// Load vector reg, immed post-index +// Load vector reg, immed pre-index +def : InstRW<[WriteAdr, V3AEWrite_6c_1I_1L], + (instregex "^LDR[BHSDQ](pre|post)$")>; + +// Load vector reg, unsigned immed +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LDR[BHSDQ]ui$")>; + +// Load vector reg, register offset, basic +// Load vector reg, register offset, scale, S/D-form +// Load vector reg, register offset, scale, H/Q-form +// Load vector reg, register offset, extend +// Load vector reg, register offset, extend, scale, S/D-form +// Load vector reg, register offset, extend, scale, H/Q-form +def : InstRW<[V3AEWrite_LdrHQ, ReadAdrBase], (instregex "^LDR[BHSDQ]ro[WX]$")>; + +// Load vector pair, immed offset, S/D-form +def : InstRW<[V3AEWrite_6c_1L, WriteLDHi], (instregex "^LDN?P[SD]i$")>; + +// Load vector pair, immed offset, Q-form +def : InstRW<[V3AEWrite_6c_2L, WriteLDHi], (instrs LDPQi, LDNPQi)>; + +// Load vector pair, immed post-index, S/D-form +// Load vector pair, immed pre-index, S/D-form +def : InstRW<[WriteAdr, V3AEWrite_6c_1I_1L, WriteLDHi], + (instregex "^LDP[SD](pre|post)$")>; + +// Load vector pair, immed post-index, Q-form +// Load vector pair, immed pre-index, Q-form +def : InstRW<[WriteAdr, V3AEWrite_6c_2I_2L, WriteLDHi], (instrs LDPQpost, + LDPQpre)>; + +// §3.15 FP store instructions +// ----------------------------------------------------------------------------- 
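The write classes referenced throughout these sections encode an approximate latency and the pipelines consumed: V3AEWrite_6c_1L, for example, reads as a 6-cycle write occupying one load pipeline, and V3AEWrite_2c_1SA_1V as a 2-cycle write occupying one store-address and one vector pipeline. Their definitions are outside this hunk, but a minimal sketch of the usual pattern is shown below; ExampleWrite_2c_1V is a placeholder name, V3AEUnitV is the only processor resource visible in this excerpt, and the sketch assumes it sits inside the model's let SchedModel block like the entries above.

// Illustrative sketch, not part of this patch.
def ExampleWrite_2c_1V : SchedWriteRes<[V3AEUnitV]> {
  let Latency     = 2;   // cycles until the result is available
  let NumMicroOps = 1;   // one micro-op issued to a V pipeline
}
// Instructions are then bound to it with InstRW/instregex, exactly as in the
// surrounding entries, e.g.:
// def : InstRW<[ExampleWrite_2c_1V], (instregex "...")>;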
+ +// Store vector reg, unscaled immed, B/H/S/D-form +// Store vector reg, unscaled immed, Q-form +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^STUR[BHSDQ]i$")>; + +// Store vector reg, immed post-index, B/H/S/D-form +// Store vector reg, immed post-index, Q-form +// Store vector reg, immed pre-index, B/H/S/D-form +// Store vector reg, immed pre-index, Q-form +def : InstRW<[WriteAdr, V3AEWrite_2c_1SA_1V_1I], + (instregex "^STR[BHSDQ](pre|post)$")>; + +// Store vector reg, unsigned immed, B/H/S/D-form +// Store vector reg, unsigned immed, Q-form +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^STR[BHSDQ]ui$")>; + +// Store vector reg, register offset, basic, B/H/S/D-form +// Store vector reg, register offset, basic, Q-form +// Store vector reg, register offset, scale, H-form +// Store vector reg, register offset, scale, S/D-form +// Store vector reg, register offset, scale, Q-form +// Store vector reg, register offset, extend, B/H/S/D-form +// Store vector reg, register offset, extend, Q-form +// Store vector reg, register offset, extend, scale, H-form +// Store vector reg, register offset, extend, scale, S/D-form +// Store vector reg, register offset, extend, scale, Q-form +def : InstRW<[V3AEWrite_StrHQ, ReadAdrBase], + (instregex "^STR[BHSDQ]ro[WX]$")>; + +// Store vector pair, immed offset, S-form +// Store vector pair, immed offset, D-form +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^STN?P[SD]i$")>; + +// Store vector pair, immed offset, Q-form +def : InstRW<[V3AEWrite_2c_1SA_2V], (instrs STPQi, STNPQi)>; + +// Store vector pair, immed post-index, S-form +// Store vector pair, immed post-index, D-form +// Store vector pair, immed pre-index, S-form +// Store vector pair, immed pre-index, D-form +def : InstRW<[WriteAdr, V3AEWrite_2c_1SA_1V_1I], + (instregex "^STP[SD](pre|post)$")>; + +// Store vector pair, immed post-index, Q-form +def : InstRW<[V3AEWrite_2c_1SA_2V_1I], (instrs STPQpost)>; + +// Store vector pair, immed pre-index, Q-form +def : InstRW<[V3AEWrite_2c_1SA_2V_2I], (instrs STPQpre)>; + +// §3.16 ASIMD integer instructions +// ----------------------------------------------------------------------------- + +// ASIMD absolute diff +// ASIMD absolute diff long +// ASIMD arith, basic +// ASIMD arith, complex +// ASIMD arith, pair-wise +// ASIMD compare +// ASIMD logical +// ASIMD max/min, basic and pair-wise +def : SchedAlias<WriteVd, V3AEWrite_2c_1V>; +def : SchedAlias<WriteVq, V3AEWrite_2c_1V>; + +// ASIMD absolute diff accum +// ASIMD absolute diff accum long +def : InstRW<[V3AEWr_VA, V3AERd_VA], (instregex "^[SU]ABAL?v")>; + +// ASIMD arith, reduce, 4H/4S +def : InstRW<[V3AEWrite_3c_1V1], (instregex "^(ADDV|[SU]ADDLV)v4(i16|i32)v$")>; + +// ASIMD arith, reduce, 8B/8H +def : InstRW<[V3AEWrite_5c_1V1_1V], + (instregex "^(ADDV|[SU]ADDLV)v8(i8|i16)v$")>; + +// ASIMD arith, reduce, 16B +def : InstRW<[V3AEWrite_6c_2V1], (instregex "^(ADDV|[SU]ADDLV)v16i8v$")>; + +// ASIMD dot product +// ASIMD dot product using signed and unsigned integers +def : InstRW<[V3AEWr_VDOT, V3AERd_VDOT], + (instregex "^([SU]|SU|US)DOT(lane)?(v8|v16)i8$")>; + +// ASIMD matrix multiply-accumulate +def : InstRW<[V3AEWr_VMMA, V3AERd_VMMA], (instrs SMMLA, UMMLA, USMMLA)>; + +// ASIMD max/min, reduce, 4H/4S +def : InstRW<[V3AEWrite_3c_1V1], (instregex "^[SU](MAX|MIN)Vv4i16v$", + "^[SU](MAX|MIN)Vv4i32v$")>; + +// ASIMD max/min, reduce, 8B/8H +def : InstRW<[V3AEWrite_5c_1V1_1V], (instregex "^[SU](MAX|MIN)Vv8i8v$", + "^[SU](MAX|MIN)Vv8i16v$")>; + +// ASIMD max/min, reduce, 16B +def : 
InstRW<[V3AEWrite_6c_2V1], (instregex "[SU](MAX|MIN)Vv16i8v$")>; + +// ASIMD multiply +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^MULv", "^SQ(R)?DMULHv")>; + +// ASIMD multiply accumulate +def : InstRW<[V3AEWr_VMA, V3AERd_VMA], (instregex "^MLAv", "^MLSv")>; + +// ASIMD multiply accumulate high +def : InstRW<[V3AEWr_VMAH, V3AERd_VMAH], (instregex "^SQRDMLAHv", "^SQRDMLSHv")>; + +// ASIMD multiply accumulate long +def : InstRW<[V3AEWr_VMAL, V3AERd_VMAL], (instregex "^[SU]MLALv", "^[SU]MLSLv")>; + +// ASIMD multiply accumulate saturating long +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^SQDML[AS]L[iv]")>; + +// ASIMD multiply/multiply long (8x8) polynomial, D-form +// ASIMD multiply/multiply long (8x8) polynomial, Q-form +def : InstRW<[V3AEWrite_3c_1V], (instregex "^PMULL?(v8i8|v16i8)$")>; + +// ASIMD multiply long +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^[SU]MULLv", "^SQDMULL[iv]")>; + +// ASIMD pairwise add and accumulate long +def : InstRW<[V3AEWr_VPA, V3AERd_VPA], (instregex "^[SU]ADALPv")>; + +// ASIMD shift accumulate +def : InstRW<[V3AEWr_VSA, V3AERd_VSA], (instregex "^[SU]SRA[dv]", "^[SU]RSRA[dv]")>; + +// ASIMD shift by immed, basic +def : InstRW<[V3AEWrite_2c_1V], (instregex "^SHL[dv]", "^SHLLv", "^SHRNv", + "^SSHLLv", "^SSHR[dv]", "^USHLLv", + "^USHR[dv]")>; + +// ASIMD shift by immed and insert, basic +def : InstRW<[V3AEWrite_2c_1V], (instregex "^SLI[dv]", "^SRI[dv]")>; + +// ASIMD shift by immed, complex +def : InstRW<[V3AEWrite_4c_1V], + (instregex "^RSHRNv", "^SQRSHRU?N[bhsv]", "^(SQSHLU?|UQSHL)[bhsd]$", + "^(SQSHLU?|UQSHL)(v8i8|v16i8|v4i16|v8i16|v2i32|v4i32|v2i64)_shift$", + "^SQSHRU?N[bhsv]", "^SRSHR[dv]", "^UQRSHRN[bhsv]", + "^UQSHRN[bhsv]", "^URSHR[dv]")>; + +// ASIMD shift by register, basic +def : InstRW<[V3AEWrite_2c_1V], (instregex "^[SU]SHLv")>; + +// ASIMD shift by register, complex +def : InstRW<[V3AEWrite_4c_1V], + (instregex "^[SU]RSHLv", "^[SU]QRSHLv", + "^[SU]QSHL(v1i8|v1i16|v1i32|v1i64|v8i8|v16i8|v4i16|v8i16|v2i32|v4i32|v2i64)$")>; + +// §3.17 ASIMD floating-point instructions +// ----------------------------------------------------------------------------- + +// ASIMD FP absolute value/difference +// ASIMD FP arith, normal +// ASIMD FP compare +// ASIMD FP complex add +// ASIMD FP max/min, normal +// ASIMD FP max/min, pairwise +// ASIMD FP negate +// Handled by SchedAlias<WriteV[dq], ...> + +// ASIMD FP complex multiply add +def : InstRW<[V3AEWr_VFCMA, V3AERd_VFCMA], (instregex "^FCMLAv")>; + +// ASIMD FP convert, long (F16 to F32) +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FCVTL(v4|v8)i16")>; + +// ASIMD FP convert, long (F32 to F64) +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FCVTL(v2|v4)i32")>; + +// ASIMD FP convert, narrow (F32 to F16) +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FCVTN(v4|v8)i16")>; + +// ASIMD FP convert, narrow (F64 to F32) +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FCVTN(v2|v4)i32", + "^FCVTXN(v2|v4)f32")>; + +// ASIMD FP convert, other, D-form F32 and Q-form F64 +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FCVT[AMNPZ][SU]v2f(32|64)$", + "^FCVT[AMNPZ][SU]v2i(32|64)_shift$", + "^FCVT[AMNPZ][SU]v1i64$", + "^FCVTZ[SU]d$", + "^[SU]CVTFv2f(32|64)$", + "^[SU]CVTFv2i(32|64)_shift$", + "^[SU]CVTFv1i64$", + "^[SU]CVTFd$")>; + +// ASIMD FP convert, other, D-form F16 and Q-form F32 +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FCVT[AMNPZ][SU]v4f(16|32)$", + "^FCVT[AMNPZ][SU]v4i(16|32)_shift$", + "^FCVT[AMNPZ][SU]v1i32$", + "^FCVTZ[SU]s$", + "^[SU]CVTFv4f(16|32)$", + "^[SU]CVTFv4i(16|32)_shift$", + "^[SU]CVTFv1i32$", + 
"^[SU]CVTFs$")>; + +// ASIMD FP convert, other, Q-form F16 +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^FCVT[AMNPZ][SU]v8f16$", + "^FCVT[AMNPZ][SU]v8i16_shift$", + "^FCVT[AMNPZ][SU]v1f16$", + "^FCVTZ[SU]h$", + "^[SU]CVTFv8f16$", + "^[SU]CVTFv8i16_shift$", + "^[SU]CVTFv1i16$", + "^[SU]CVTFh$")>; + +// ASIMD FP divide, D-form, F16 +def : InstRW<[V3AEWrite_9c_1V1_4rc], (instrs FDIVv4f16)>; + +// ASIMD FP divide, D-form, F32 +def : InstRW<[V3AEWrite_9c_1V1_2rc], (instrs FDIVv2f32)>; + +// ASIMD FP divide, Q-form, F16 +def : InstRW<[V3AEWrite_13c_1V1_8rc], (instrs FDIVv8f16)>; + +// ASIMD FP divide, Q-form, F32 +def : InstRW<[V3AEWrite_11c_1V1_4rc], (instrs FDIVv4f32)>; + +// ASIMD FP divide, Q-form, F64 +def : InstRW<[V3AEWrite_14c_1V1_2rc], (instrs FDIVv2f64)>; + +// ASIMD FP max/min, reduce, F32 and D-form F16 +def : InstRW<[V3AEWrite_4c_2V], (instregex "^(FMAX|FMIN)(NM)?Vv4(i16|i32)v$")>; + +// ASIMD FP max/min, reduce, Q-form F16 +def : InstRW<[V3AEWrite_6c_3V], (instregex "^(FMAX|FMIN)(NM)?Vv8i16v$")>; + +// ASIMD FP multiply +def : InstRW<[V3AEWr_VFM], (instregex "^FMULv", "^FMULXv")>; + +// ASIMD FP multiply accumulate +def : InstRW<[V3AEWr_VFMA, V3AERd_VFMA], (instregex "^FMLAv", "^FMLSv")>; + +// ASIMD FP multiply accumulate long +def : InstRW<[V3AEWr_VFMAL, V3AERd_VFMAL], (instregex "^FML[AS]L2?(lane)?v")>; + +// ASIMD FP round, D-form F32 and Q-form F64 +def : InstRW<[V3AEWrite_3c_1V0], + (instregex "^FRINT[AIMNPXZ]v2f(32|64)$", + "^FRINT(32|64)[XZ]v2f(32|64)$")>; + +// ASIMD FP round, D-form F16 and Q-form F32 +def : InstRW<[V3AEWrite_4c_2V0], + (instregex "^FRINT[AIMNPXZ]v4f(16|32)$", + "^FRINT(32|64)[XZ]v4f32$")>; + +// ASIMD FP round, Q-form F16 +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^FRINT[AIMNPXZ]v8f16$")>; + +// ASIMD FP square root, D-form, F16 +def : InstRW<[V3AEWrite_9c_1V1_4rc], (instrs FSQRTv4f16)>; + +// ASIMD FP square root, D-form, F32 +def : InstRW<[V3AEWrite_9c_1V1_2rc], (instrs FSQRTv2f32)>; + +// ASIMD FP square root, Q-form, F16 +def : InstRW<[V3AEWrite_13c_1V1_8rc], (instrs FSQRTv8f16)>; + +// ASIMD FP square root, Q-form, F32 +def : InstRW<[V3AEWrite_11c_1V1_4rc], (instrs FSQRTv4f32)>; + +// ASIMD FP square root, Q-form, F64 +def : InstRW<[V3AEWrite_14c_1V1_2rc], (instrs FSQRTv2f64)>; + +// §3.18 ASIMD BFloat16 (BF16) instructions +// ----------------------------------------------------------------------------- + +// ASIMD convert, F32 to BF16 +def : InstRW<[V3AEWrite_4c_2V0], (instrs BFCVTN, BFCVTN2)>; + +// ASIMD dot product +def : InstRW<[V3AEWr_VBFDOT, V3AERd_VBFDOT], (instrs BFDOTv4bf16, BFDOTv8bf16)>; + +// ASIMD matrix multiply accumulate +def : InstRW<[V3AEWr_VBFMMA, V3AERd_VBFMMA], (instrs BFMMLA)>; + +// ASIMD multiply accumulate long +def : InstRW<[V3AEWr_VBFMAL, V3AERd_VBFMAL], (instrs BFMLALB, BFMLALBIdx, BFMLALT, + BFMLALTIdx)>; + +// Scalar convert, F32 to BF16 +def : InstRW<[V3AEWrite_3c_1V0], (instrs BFCVT)>; + +// §3.19 ASIMD miscellaneous instructions +// ----------------------------------------------------------------------------- + +// ASIMD bit reverse +// ASIMD bitwise insert +// ASIMD count +// ASIMD duplicate, element +// ASIMD extract +// ASIMD extract narrow +// ASIMD insert, element to element +// ASIMD move, FP immed +// ASIMD move, integer immed +// ASIMD reverse +// ASIMD table lookup extension, 1 table reg +// ASIMD transpose +// ASIMD unzip/zip +// Handled by SchedAlias<WriteV[dq], ...> +def : InstRW<[V3AEWrite_0or2c_1V], (instrs MOVID, MOVIv2d_ns)>; + +// ASIMD duplicate, gen reg +def : InstRW<[V3AEWrite_3c_1M0], 
(instregex "^DUPv.+gpr")>; + +// ASIMD extract narrow, saturating +def : InstRW<[V3AEWrite_4c_1V], (instregex "^[SU]QXTNv", "^SQXTUNv")>; + +// ASIMD reciprocal and square root estimate, D-form U32 +def : InstRW<[V3AEWrite_3c_1V0], (instrs URECPEv2i32, URSQRTEv2i32)>; + +// ASIMD reciprocal and square root estimate, Q-form U32 +def : InstRW<[V3AEWrite_4c_2V0], (instrs URECPEv4i32, URSQRTEv4i32)>; + +// ASIMD reciprocal and square root estimate, D-form F32 and scalar forms +def : InstRW<[V3AEWrite_3c_1V0], (instrs FRECPEv1f16, FRECPEv1i32, + FRECPEv1i64, FRECPEv2f32, + FRSQRTEv1f16, FRSQRTEv1i32, + FRSQRTEv1i64, FRSQRTEv2f32)>; + +// ASIMD reciprocal and square root estimate, D-form F16 and Q-form F32 +def : InstRW<[V3AEWrite_4c_2V0], (instrs FRECPEv4f16, FRECPEv4f32, + FRSQRTEv4f16, FRSQRTEv4f32)>; + +// ASIMD reciprocal and square root estimate, Q-form F16 +def : InstRW<[V3AEWrite_6c_4V0], (instrs FRECPEv8f16, FRSQRTEv8f16)>; + +// ASIMD reciprocal exponent +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FRECPXv")>; + +// ASIMD reciprocal step +def : InstRW<[V3AEWrite_4c_1V], (instregex "^FRECPS(32|64|v)", + "^FRSQRTS(32|64|v)")>; + +// ASIMD table lookup, 1 or 2 table regs +def : InstRW<[V3AEWrite_2c_1V], (instrs TBLv8i8One, TBLv16i8One, + TBLv8i8Two, TBLv16i8Two)>; + +// ASIMD table lookup, 3 table regs +def : InstRW<[V3AEWrite_4c_2V], (instrs TBLv8i8Three, TBLv16i8Three)>; + +// ASIMD table lookup, 4 table regs +def : InstRW<[V3AEWrite_4c_3V], (instrs TBLv8i8Four, TBLv16i8Four)>; + +// ASIMD table lookup extension, 2 table reg +def : InstRW<[V3AEWrite_4c_2V], (instrs TBXv8i8Two, TBXv16i8Two)>; + +// ASIMD table lookup extension, 3 table reg +def : InstRW<[V3AEWrite_6c_3V], (instrs TBXv8i8Three, TBXv16i8Three)>; + +// ASIMD table lookup extension, 4 table reg +def : InstRW<[V3AEWrite_6c_5V], (instrs TBXv8i8Four, TBXv16i8Four)>; + +// ASIMD transfer, element to gen reg +def : InstRW<[V3AEWrite_2c_2V], (instregex "^[SU]MOVv")>; + +// ASIMD transfer, gen reg to element +def : InstRW<[V3AEWrite_5c_1M0_1V], (instregex "^INSvi(8|16|32|64)gpr$")>; + +// §3.20 ASIMD load instructions +// ----------------------------------------------------------------------------- + +// ASIMD load, 1 element, multiple, 1 reg, D-form +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LD1Onev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_1L], + (instregex "^LD1Onev(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 1 reg, Q-form +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LD1Onev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_1L], + (instregex "^LD1Onev(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, multiple, 2 reg, D-form +def : InstRW<[V3AEWrite_6c_2L], (instregex "^LD1Twov(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_2L], + (instregex "^LD1Twov(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 2 reg, Q-form +def : InstRW<[V3AEWrite_6c_2L], (instregex "^LD1Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_2L], + (instregex "^LD1Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, multiple, 3 reg, D-form +def : InstRW<[V3AEWrite_6c_3L], (instregex "^LD1Threev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_3L], + (instregex "^LD1Threev(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 3 reg, Q-form +def : InstRW<[V3AEWrite_6c_3L], (instregex "^LD1Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_3L], + (instregex "^LD1Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, multiple, 4 reg, D-form +def : 
InstRW<[V3AEWrite_7c_4L], (instregex "^LD1Fourv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_7c_4L], + (instregex "^LD1Fourv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, multiple, 4 reg, Q-form +def : InstRW<[V3AEWrite_7c_4L], (instregex "^LD1Fourv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_7c_4L], + (instregex "^LD1Fourv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 1 element, one lane, B/H/S +// ASIMD load, 1 element, one lane, D +def : InstRW<[V3AEWrite_8c_1L_1V], (instregex "LD1i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_1V], (instregex "LD1i(8|16|32|64)_POST$")>; + +// ASIMD load, 1 element, all lanes, D-form, B/H/S +// ASIMD load, 1 element, all lanes, D-form, D +def : InstRW<[V3AEWrite_8c_1L_1V], (instregex "LD1Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_1V], (instregex "LD1Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 1 element, all lanes, Q-form +def : InstRW<[V3AEWrite_8c_1L_1V], (instregex "LD1Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_1V], (instregex "LD1Rv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 2 element, multiple, D-form, B/H/S +def : InstRW<[V3AEWrite_8c_1L_2V], (instregex "LD2Twov(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_2V], (instregex "LD2Twov(8b|4h|2s)_POST$")>; + +// ASIMD load, 2 element, multiple, Q-form, B/H/S +// ASIMD load, 2 element, multiple, Q-form, D +def : InstRW<[V3AEWrite_8c_2L_2V], (instregex "LD2Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_2L_2V], (instregex "LD2Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 2 element, one lane, B/H +// ASIMD load, 2 element, one lane, S +// ASIMD load, 2 element, one lane, D +def : InstRW<[V3AEWrite_8c_1L_2V], (instregex "LD2i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_2V], (instregex "LD2i(8|16|32|64)_POST$")>; + +// ASIMD load, 2 element, all lanes, D-form, B/H/S +// ASIMD load, 2 element, all lanes, D-form, D +def : InstRW<[V3AEWrite_8c_1L_2V], (instregex "LD2Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_2V], (instregex "LD2Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 2 element, all lanes, Q-form +def : InstRW<[V3AEWrite_8c_1L_2V], (instregex "LD2Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_1L_2V], (instregex "LD2Rv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 3 element, multiple, D-form, B/H/S +def : InstRW<[V3AEWrite_8c_2L_3V], (instregex "LD3Threev(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_2L_3V], (instregex "LD3Threev(8b|4h|2s)_POST$")>; + +// ASIMD load, 3 element, multiple, Q-form, B/H/S +// ASIMD load, 3 element, multiple, Q-form, D +def : InstRW<[V3AEWrite_8c_3L_3V], (instregex "LD3Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_3L_3V], (instregex "LD3Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 3 element, one lane, B/H +// ASIMD load, 3 element, one lane, S +// ASIMD load, 3 element, one lane, D +def : InstRW<[V3AEWrite_8c_2L_3V], (instregex "LD3i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_2L_3V], (instregex "LD3i(8|16|32|64)_POST$")>; + +// ASIMD load, 3 element, all lanes, D-form, B/H/S +// ASIMD load, 3 element, all lanes, D-form, D +def : InstRW<[V3AEWrite_8c_2L_3V], (instregex "LD3Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_2L_3V], (instregex "LD3Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 3 element, all lanes, Q-form, B/H/S +// ASIMD load, 3 element, all lanes, Q-form, D +def : InstRW<[V3AEWrite_8c_3L_3V], (instregex "LD3Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_3L_3V], (instregex 
"LD3Rv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 4 element, multiple, D-form, B/H/S +def : InstRW<[V3AEWrite_8c_3L_4V], (instregex "LD4Fourv(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_3L_4V], (instregex "LD4Fourv(8b|4h|2s)_POST$")>; + +// ASIMD load, 4 element, multiple, Q-form, B/H/S +// ASIMD load, 4 element, multiple, Q-form, D +def : InstRW<[V3AEWrite_9c_6L_4V], (instregex "LD4Fourv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_9c_6L_4V], (instregex "LD4Fourv(16b|8h|4s|2d)_POST$")>; + +// ASIMD load, 4 element, one lane, B/H +// ASIMD load, 4 element, one lane, S +// ASIMD load, 4 element, one lane, D +def : InstRW<[V3AEWrite_8c_3L_4V], (instregex "LD4i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_3L_4V], (instregex "LD4i(8|16|32|64)_POST$")>; + +// ASIMD load, 4 element, all lanes, D-form, B/H/S +// ASIMD load, 4 element, all lanes, D-form, D +def : InstRW<[V3AEWrite_8c_3L_4V], (instregex "LD4Rv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_3L_4V], (instregex "LD4Rv(8b|4h|2s|1d)_POST$")>; + +// ASIMD load, 4 element, all lanes, Q-form, B/H/S +// ASIMD load, 4 element, all lanes, Q-form, D +def : InstRW<[V3AEWrite_8c_4L_4V], (instregex "LD4Rv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_8c_4L_4V], (instregex "LD4Rv(16b|8h|4s|2d)_POST$")>; + +// §3.21 ASIMD store instructions +// ----------------------------------------------------------------------------- + +// ASIMD store, 1 element, multiple, 1 reg, D-form +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "ST1Onev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_1SA_1V], (instregex "ST1Onev(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 1 reg, Q-form +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "ST1Onev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_1SA_1V], (instregex "ST1Onev(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, multiple, 2 reg, D-form +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "ST1Twov(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_1SA_1V], (instregex "ST1Twov(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 2 reg, Q-form +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "ST1Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_2SA_2V], (instregex "ST1Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, multiple, 3 reg, D-form +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "ST1Threev(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_2SA_2V], (instregex "ST1Threev(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 3 reg, Q-form +def : InstRW<[V3AEWrite_2c_3SA_3V], (instregex "ST1Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_3SA_3V], (instregex "ST1Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, multiple, 4 reg, D-form +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "ST1Fourv(8b|4h|2s|1d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_2SA_2V], (instregex "ST1Fourv(8b|4h|2s|1d)_POST$")>; + +// ASIMD store, 1 element, multiple, 4 reg, Q-form +def : InstRW<[V3AEWrite_2c_4SA_4V], (instregex "ST1Fourv(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_2c_4SA_4V], (instregex "ST1Fourv(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 1 element, one lane, B/H/S +// ASIMD store, 1 element, one lane, D +def : InstRW<[V3AEWrite_4c_1SA_2V], (instregex "ST1i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_4c_1SA_2V], (instregex "ST1i(8|16|32|64)_POST$")>; + +// ASIMD store, 2 element, multiple, D-form, B/H/S +def : InstRW<[V3AEWrite_4c_1SA_2V], (instregex 
"ST2Twov(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_4c_1SA_2V], (instregex "ST2Twov(8b|4h|2s)_POST$")>; + +// ASIMD store, 2 element, multiple, Q-form, B/H/S +// ASIMD store, 2 element, multiple, Q-form, D +def : InstRW<[V3AEWrite_4c_2SA_4V], (instregex "ST2Twov(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_4c_2SA_4V], (instregex "ST2Twov(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 2 element, one lane, B/H/S +// ASIMD store, 2 element, one lane, D +def : InstRW<[V3AEWrite_4c_1SA_2V], (instregex "ST2i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_4c_1SA_2V], (instregex "ST2i(8|16|32|64)_POST$")>; + +// ASIMD store, 3 element, multiple, D-form, B/H/S +def : InstRW<[V3AEWrite_5c_2SA_4V], (instregex "ST3Threev(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_5c_2SA_4V], (instregex "ST3Threev(8b|4h|2s)_POST$")>; + +// ASIMD store, 3 element, multiple, Q-form, B/H/S +// ASIMD store, 3 element, multiple, Q-form, D +def : InstRW<[V3AEWrite_6c_3SA_6V], (instregex "ST3Threev(16b|8h|4s|2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_3SA_6V], (instregex "ST3Threev(16b|8h|4s|2d)_POST$")>; + +// ASIMD store, 3 element, one lane, B/H +// ASIMD store, 3 element, one lane, S +// ASIMD store, 3 element, one lane, D +def : InstRW<[V3AEWrite_5c_2SA_4V], (instregex "ST3i(8|16|32|64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_5c_2SA_4V], (instregex "ST3i(8|16|32|64)_POST$")>; + +// ASIMD store, 4 element, multiple, D-form, B/H/S +def : InstRW<[V3AEWrite_6c_2SA_6V], (instregex "ST4Fourv(8b|4h|2s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_2SA_6V], (instregex "ST4Fourv(8b|4h|2s)_POST$")>; + +// ASIMD store, 4 element, multiple, Q-form, B/H/S +def : InstRW<[V3AEWrite_7c_4SA_12V], (instregex "ST4Fourv(16b|8h|4s)$")>; +def : InstRW<[WriteAdr, V3AEWrite_7c_4SA_12V], (instregex "ST4Fourv(16b|8h|4s)_POST$")>; + +// ASIMD store, 4 element, multiple, Q-form, D +def : InstRW<[V3AEWrite_5c_4SA_8V], (instregex "ST4Fourv(2d)$")>; +def : InstRW<[WriteAdr, V3AEWrite_5c_4SA_8V], (instregex "ST4Fourv(2d)_POST$")>; + +// ASIMD store, 4 element, one lane, B/H/S +def : InstRW<[V3AEWrite_6c_1SA_3V], (instregex "ST4i(8|16|32)$")>; +def : InstRW<[WriteAdr, V3AEWrite_6c_1SA_3V], (instregex "ST4i(8|16|32)_POST$")>; + +// ASIMD store, 4 element, one lane, D +def : InstRW<[V3AEWrite_4c_2SA_4V], (instregex "ST4i(64)$")>; +def : InstRW<[WriteAdr, V3AEWrite_4c_2SA_4V], (instregex "ST4i(64)_POST$")>; + +// §3.22 Cryptography extensions +// ----------------------------------------------------------------------------- + +// Crypto AES ops +def : InstRW<[V3AEWrite_2c_1V], (instregex "^AES[DE]rr$", "^AESI?MCrr")>; + +// Crypto polynomial (64x64) multiply long +def : InstRW<[V3AEWrite_2c_1V], (instrs PMULLv1i64, PMULLv2i64)>; + +// Crypto SHA1 hash acceleration op +// Crypto SHA1 schedule acceleration ops +def : InstRW<[V3AEWrite_2c_1V0], (instregex "^SHA1(H|SU0|SU1)")>; + +// Crypto SHA1 hash acceleration ops +// Crypto SHA256 hash acceleration ops +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^SHA1[CMP]", "^SHA256H2?")>; + +// Crypto SHA256 schedule acceleration ops +def : InstRW<[V3AEWrite_2c_1V0], (instregex "^SHA256SU[01]")>; + +// Crypto SHA512 hash acceleration ops +def : InstRW<[V3AEWrite_2c_1V0], (instregex "^SHA512(H|H2|SU0|SU1)")>; + +// Crypto SHA3 ops +def : InstRW<[V3AEWrite_2c_1V], (instrs BCAX, EOR3, RAX1, XAR)>; + +// Crypto SM3 ops +def : InstRW<[V3AEWrite_2c_1V0], (instregex "^SM3PARTW[12]$", "^SM3SS1$", + "^SM3TT[12][AB]$")>; + +// Crypto SM4 ops +def : InstRW<[V3AEWrite_4c_1V0], (instrs SM4E, SM4ENCKEY)>; + 
+// §3.23 CRC +// ----------------------------------------------------------------------------- + +def : InstRW<[V3AEWr_CRC, V3AERd_CRC], (instregex "^CRC32")>; + +// §3.24 SVE Predicate instructions +// ----------------------------------------------------------------------------- + +// Loop control, based on predicate +def : InstRW<[V3AEWrite_2or3c_1M], (instrs BRKA_PPmP, BRKA_PPzP, + BRKB_PPmP, BRKB_PPzP)>; + +// Loop control, based on predicate and flag setting +def : InstRW<[V3AEWrite_2or3c_1M], (instrs BRKAS_PPzP, BRKBS_PPzP)>; + +// Loop control, propagating +def : InstRW<[V3AEWrite_2or3c_1M], (instrs BRKN_PPzP, BRKPA_PPzPP, + BRKPB_PPzPP)>; + +// Loop control, propagating and flag setting +def : InstRW<[V3AEWrite_2or3c_1M], (instrs BRKNS_PPzP, BRKPAS_PPzPP, + BRKPBS_PPzPP)>; + +// Loop control, based on GPR +def : InstRW<[V3AEWrite_3c_2M], + (instregex "^WHILE(GE|GT|HI|HS|LE|LO|LS|LT)_P(WW|XX)_[BHSD]")>; +def : InstRW<[V3AEWrite_3c_2M], (instregex "^WHILE(RW|WR)_PXX_[BHSD]")>; + +// Loop terminate +def : InstRW<[V3AEWrite_1c_2M], (instregex "^CTERM(EQ|NE)_(WW|XX)")>; + +// Predicate counting scalar +def : InstRW<[V3AEWrite_2c_1M], (instrs ADDPL_XXI, ADDVL_XXI, RDVLI_XI)>; +def : InstRW<[V3AEWrite_2c_1M], + (instregex "^(CNT|SQDEC|SQINC|UQDEC|UQINC)[BHWD]_XPiI", + "^SQ(DEC|INC)[BHWD]_XPiWdI", + "^UQ(DEC|INC)[BHWD]_WPiI")>; + +// Predicate counting scalar, ALL, {1,2,4} +def : InstRW<[V3AEWrite_IncDec], (instregex "^(DEC|INC)[BHWD]_XPiI")>; + +// Predicate counting scalar, active predicate +def : InstRW<[V3AEWrite_2c_1M], + (instregex "^CNTP_XPP_[BHSD]", + "^(DEC|INC|SQDEC|SQINC|UQDEC|UQINC)P_XP_[BHSD]", + "^(UQDEC|UQINC)P_WP_[BHSD]", + "^(SQDEC|SQINC)P_XPWd_[BHSD]")>; + +// Predicate counting vector, active predicate +def : InstRW<[V3AEWrite_7c_1M_1M0_1V], + (instregex "^(DEC|INC|SQDEC|SQINC|UQDEC|UQINC)P_ZP_[HSD]")>; + +// Predicate logical +def : InstRW<[V3AEWrite_1or2c_1M], + (instregex "^(AND|BIC|EOR|NAND|NOR|ORN|ORR)_PPzPP")>; + +// Predicate logical, flag setting +def : InstRW<[V3AEWrite_1or2c_1M], + (instregex "^(ANDS|BICS|EORS|NANDS|NORS|ORNS|ORRS)_PPzPP")>; + +// Predicate reverse +def : InstRW<[V3AEWrite_2c_1M], (instregex "^REV_PP_[BHSD]")>; + +// Predicate select +def : InstRW<[V3AEWrite_1c_1M], (instrs SEL_PPPP)>; + +// Predicate set +def : InstRW<[V3AEWrite_2c_1M], (instregex "^PFALSE", "^PTRUE_[BHSD]")>; + +// Predicate set/initialize, set flags +def : InstRW<[V3AEWrite_2c_1M], (instregex "^PTRUES_[BHSD]")>; + +// Predicate find first/next +def : InstRW<[V3AEWrite_2c_1M], (instregex "^PFIRST_B", "^PNEXT_[BHSD]")>; + +// Predicate test +def : InstRW<[V3AEWrite_1c_1M], (instrs PTEST_PP)>; + +// Predicate transpose +def : InstRW<[V3AEWrite_2c_1M], (instregex "^TRN[12]_PPP_[BHSD]")>; + +// Predicate unpack and widen +def : InstRW<[V3AEWrite_2c_1M], (instrs PUNPKHI_PP, PUNPKLO_PP)>; + +// Predicate zip/unzip +def : InstRW<[V3AEWrite_2c_1M], (instregex "^(ZIP|UZP)[12]_PPP_[BHSD]")>; + +// §3.25 SVE integer instructions +// ----------------------------------------------------------------------------- + +// Arithmetic, absolute diff +def : InstRW<[V3AEWrite_2c_1V], (instregex "^[SU]ABD_ZPmZ_[BHSD]", + "^[SU]ABD_ZPZZ_[BHSD]")>; + +// Arithmetic, absolute diff accum +def : InstRW<[V3AEWr_ZA, V3AERd_ZA], (instregex "^[SU]ABA_ZZZ_[BHSD]")>; + +// Arithmetic, absolute diff accum long +def : InstRW<[V3AEWr_ZA, V3AERd_ZA], (instregex "^[SU]ABAL[TB]_ZZZ_[HSD]")>; + +// Arithmetic, absolute diff long +def : InstRW<[V3AEWrite_2c_1V], (instregex "^[SU]ABDL[TB]_ZZZ_[HSD]")>; + +// 
Arithmetic, basic +def : InstRW<[V3AEWrite_2c_1V], + (instregex "^(ABS|ADD|CNOT|NEG|SUB|SUBR)_ZPmZ_[BHSD]", + "^(ADD|SUB)_ZZZ_[BHSD]", + "^(ADD|SUB|SUBR)_ZPZZ_[BHSD]", + "^(ADD|SUB|SUBR)_ZI_[BHSD]", + "^ADR_[SU]XTW_ZZZ_D_[0123]", + "^ADR_LSL_ZZZ_[SD]_[0123]", + "^[SU](ADD|SUB)[LW][BT]_ZZZ_[HSD]", + "^SADDLBT_ZZZ_[HSD]", + "^[SU]H(ADD|SUB|SUBR)_ZPmZ_[BHSD]", + "^SSUBL(BT|TB)_ZZZ_[HSD]")>; + +// Arithmetic, complex +def : InstRW<[V3AEWrite_2c_1V], + (instregex "^R?(ADD|SUB)HN[BT]_ZZZ_[BHS]", + "^SQ(ABS|ADD|NEG|SUB|SUBR)_ZPmZ_[BHSD]", + "^[SU]Q(ADD|SUB)_ZZZ_[BHSD]", + "^[SU]Q(ADD|SUB)_ZI_[BHSD]", + "^(SRH|SUQ|UQ|USQ|URH)ADD_ZPmZ_[BHSD]", + "^(UQSUB|UQSUBR)_ZPmZ_[BHSD]")>; + +// Arithmetic, large integer +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(AD|SB)CL[BT]_ZZZ_[SD]")>; + +// Arithmetic, pairwise add +def : InstRW<[V3AEWrite_2c_1V], (instregex "^ADDP_ZPmZ_[BHSD]")>; + +// Arithmetic, pairwise add and accum long +def : InstRW<[V3AEWr_ZPA, ReadDefault, V3AERd_ZPA], + (instregex "^[SU]ADALP_ZPmZ_[HSD]")>; + +// Arithmetic, shift +def : InstRW<[V3AEWrite_2c_1V1], + (instregex "^(ASR|LSL|LSR)_WIDE_ZPmZ_[BHS]", + "^(ASR|LSL|LSR)_WIDE_ZZZ_[BHS]", + "^(ASR|LSL|LSR)_ZPmI_[BHSD]", + "^(ASR|LSL|LSR)_ZPmZ_[BHSD]", + "^(ASR|LSL|LSR)_ZZI_[BHSD]", + "^(ASR|LSL|LSR)_ZPZ[IZ]_[BHSD]", + "^(ASRR|LSLR|LSRR)_ZPmZ_[BHSD]")>; + +// Arithmetic, shift and accumulate +def : InstRW<[V3AEWr_ZSA, V3AERd_ZSA], (instregex "^[SU]R?SRA_ZZI_[BHSD]")>; + +// Arithmetic, shift by immediate +def : InstRW<[V3AEWrite_2c_1V], (instregex "^SHRN[BT]_ZZI_[BHS]", + "^[SU]SHLL[BT]_ZZI_[HSD]")>; + +// Arithmetic, shift by immediate and insert +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(SLI|SRI)_ZZI_[BHSD]")>; + +// Arithmetic, shift complex +def : InstRW<[V3AEWrite_4c_1V], + (instregex "^(SQ)?RSHRU?N[BT]_ZZI_[BHS]", + "^(SQRSHL|SQRSHLR|SQSHL|SQSHLR|UQRSHL|UQRSHLR|UQSHL|UQSHLR)_ZPmZ_[BHSD]", + "^[SU]QR?SHL_ZPZZ_[BHSD]", + "^(SQSHL|SQSHLU|UQSHL)_(ZPmI|ZPZI)_[BHSD]", + "^SQSHRU?N[BT]_ZZI_[BHS]", + "^UQR?SHRN[BT]_ZZI_[BHS]")>; + +// Arithmetic, shift right for divide +def : InstRW<[V3AEWrite_4c_1V], (instregex "^ASRD_(ZPmI|ZPZI)_[BHSD]")>; + +// Arithmetic, shift rounding +def : InstRW<[V3AEWrite_4c_1V], (instregex "^[SU]RSHLR?_ZPmZ_[BHSD]", + "^[SU]RSHL_ZPZZ_[BHSD]", + "^[SU]RSHR_(ZPmI|ZPZI)_[BHSD]")>; + +// Bit manipulation +def : InstRW<[V3AEWrite_6c_2V1], (instregex "^(BDEP|BEXT|BGRP)_ZZZ_[BHSD]")>; + +// Bitwise select +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(BSL|BSL1N|BSL2N|NBSL)_ZZZZ")>; + +// Count/reverse bits +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(CLS|CLZ|CNT|RBIT)_ZPmZ_[BHSD]")>; + +// Broadcast logical bitmask immediate to vector +def : InstRW<[V3AEWrite_2c_1V], (instrs DUPM_ZI)>; + +// Compare and set flags +def : InstRW<[V3AEWrite_2or3c_1V0], + (instregex "^CMP(EQ|GE|GT|HI|HS|LE|LO|LS|LT|NE)_PPzZ[IZ]_[BHSD]", + "^CMP(EQ|GE|GT|HI|HS|LE|LO|LS|LT|NE)_WIDE_PPzZZ_[BHS]")>; + +// Complex add +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(SQ)?CADD_ZZI_[BHSD]")>; + +// Complex dot product 8-bit element +def : InstRW<[V3AEWr_ZDOTB, V3AERd_ZDOTB], (instrs CDOT_ZZZ_S, CDOT_ZZZI_S)>; + +// Complex dot product 16-bit element +def : InstRW<[V3AEWr_ZDOTH, V3AERd_ZDOTH], (instrs CDOT_ZZZ_D, CDOT_ZZZI_D)>; + +// Complex multiply-add B, H, S element size +def : InstRW<[V3AEWr_ZCMABHS, V3AERd_ZCMABHS], (instregex "^CMLA_ZZZ_[BHS]", + "^CMLA_ZZZI_[HS]")>; + +// Complex multiply-add D element size +def : InstRW<[V3AEWr_ZCMAD, V3AERd_ZCMAD], (instrs CMLA_ZZZ_D)>; + +// Conditional extract operations, scalar form +def : 
InstRW<[V3AEWrite_8c_1M0_1V], (instregex "^CLAST[AB]_RPZ_[BHSD]")>; + +// Conditional extract operations, SIMD&FP scalar and vector forms +def : InstRW<[V3AEWrite_3c_1V1], (instregex "^CLAST[AB]_[VZ]PZ_[BHSD]", + "^COMPACT_ZPZ_[SD]", + "^SPLICE_ZPZZ?_[BHSD]")>; + +// Convert to floating point, 64b to float or convert to double +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^[SU]CVTF_ZPmZ_Dto[HSD]", + "^[SU]CVTF_ZPmZ_StoD")>; + +// Convert to floating point, 32b to single or half +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^[SU]CVTF_ZPmZ_Sto[HS]")>; + +// Convert to floating point, 16b to half +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^[SU]CVTF_ZPmZ_HtoH")>; + +// Copy, scalar +def : InstRW<[V3AEWrite_5c_1M0_1V], (instregex "^CPY_ZPmR_[BHSD]")>; + +// Copy, scalar SIMD&FP or imm +def : InstRW<[V3AEWrite_2c_1V], (instregex "^CPY_ZPm[IV]_[BHSD]", + "^CPY_ZPzI_[BHSD]")>; + +// Divides, 32 bit +def : InstRW<[V3AEWrite_12c_1V0], (instregex "^[SU]DIVR?_ZPmZ_S", + "^[SU]DIV_ZPZZ_S")>; + +// Divides, 64 bit +def : InstRW<[V3AEWrite_20c_1V0], (instregex "^[SU]DIVR?_ZPmZ_D", + "^[SU]DIV_ZPZZ_D")>; + +// Dot product, 8 bit +def : InstRW<[V3AEWr_ZDOTB, V3AERd_ZDOTB], (instregex "^[SU]DOT_ZZZI?_BtoS")>; + +// Dot product, 8 bit, using signed and unsigned integers +def : InstRW<[V3AEWr_ZDOTB, V3AERd_ZDOTB], (instrs SUDOT_ZZZI, USDOT_ZZZI, USDOT_ZZZ)>; + +// Dot product, 16 bit +def : InstRW<[V3AEWr_ZDOTH, V3AERd_ZDOTH], (instregex "^[SU]DOT_ZZZI?_HtoD")>; + +// Duplicate, immediate and indexed form +def : InstRW<[V3AEWrite_2c_1V], (instregex "^DUP_ZI_[BHSD]", + "^DUP_ZZI_[BHSDQ]")>; + +// Duplicate, scalar form +def : InstRW<[V3AEWrite_3c_1M0], (instregex "^DUP_ZR_[BHSD]")>; + +// Extend, sign or zero +def : InstRW<[V3AEWrite_2c_1V], (instregex "^[SU]XTB_ZPmZ_[HSD]", + "^[SU]XTH_ZPmZ_[SD]", + "^[SU]XTW_ZPmZ_[D]")>; + +// Extract +def : InstRW<[V3AEWrite_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>; + +// Extract narrow saturating +def : InstRW<[V3AEWrite_4c_1V], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]", + "^SQXTUN[BT]_ZZ_[BHS]")>; + +// Extract operation, SIMD and FP scalar form +def : InstRW<[V3AEWrite_3c_1V1], (instregex "^LAST[AB]_VPZ_[BHSD]")>; + +// Extract operation, scalar +def : InstRW<[V3AEWrite_6c_1V1_1M0], (instregex "^LAST[AB]_RPZ_[BHSD]")>; + +// Histogram operations +def : InstRW<[V3AEWrite_2c_1V], (instregex "^HISTCNT_ZPzZZ_[SD]", + "^HISTSEG_ZZZ")>; + +// Horizontal operations, B, H, S form, immediate operands only +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^INDEX_II_[BHS]")>; + +// Horizontal operations, B, H, S form, scalar, immediate operands/ scalar +// operands only / immediate, scalar operands +def : InstRW<[V3AEWrite_7c_1M0_1V0], (instregex "^INDEX_(IR|RI|RR)_[BHS]")>; + +// Horizontal operations, D form, immediate operands only +def : InstRW<[V3AEWrite_5c_2V0], (instrs INDEX_II_D)>; + +// Horizontal operations, D form, scalar, immediate operands)/ scalar operands +// only / immediate, scalar operands +def : InstRW<[V3AEWrite_8c_2M0_2V0], (instregex "^INDEX_(IR|RI|RR)_D")>; + +// insert operation, SIMD and FP scalar form +def : InstRW<[V3AEWrite_2c_1V], (instregex "^INSR_ZV_[BHSD]")>; + +// insert operation, scalar +def : InstRW<[V3AEWrite_5c_1V1_1M0], (instregex "^INSR_ZR_[BHSD]")>; + +// Logical +def : InstRW<[V3AEWrite_2c_1V], + (instregex "^(AND|EOR|ORR)_ZI", + "^(AND|BIC|EOR|ORR)_ZZZ", + "^EOR(BT|TB)_ZZZ_[BHSD]", + "^(AND|BIC|EOR|NOT|ORR)_(ZPmZ|ZPZZ)_[BHSD]", + "^NOT_ZPmZ_[BHSD]")>; + +// Max/min, basic and pairwise +def : InstRW<[V3AEWrite_2c_1V], (instregex 
"^[SU](MAX|MIN)_ZI_[BHSD]", + "^[SU](MAX|MIN)P?_ZPmZ_[BHSD]", + "^[SU](MAX|MIN)_ZPZZ_[BHSD]")>; + +// Matching operations +// FIXME: SOG p. 44, n. 5: If the consuming instruction has a flag source, the +// latency for this instruction is 4 cycles. +def : InstRW<[V3AEWrite_2or3c_1V0_1M], (instregex "^N?MATCH_PPzZZ_[BH]")>; + +// Matrix multiply-accumulate +def : InstRW<[V3AEWr_ZMMA, V3AERd_ZMMA], (instrs SMMLA_ZZZ, UMMLA_ZZZ, USMMLA_ZZZ)>; + +// Move prefix +def : InstRW<[V3AEWrite_2c_1V], (instregex "^MOVPRFX_ZP[mz]Z_[BHSD]", + "^MOVPRFX_ZZ")>; + +// Multiply, B, H, S element size +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^MUL_(ZI|ZPmZ|ZZZI|ZZZ)_[BHS]", + "^MUL_ZPZZ_[BHS]", + "^[SU]MULH_(ZPmZ|ZZZ)_[BHS]", + "^[SU]MULH_ZPZZ_[BHS]")>; + +// Multiply, D element size +def : InstRW<[V3AEWrite_5c_2V0], (instregex "^MUL_(ZI|ZPmZ|ZZZI|ZZZ)_D", + "^MUL_ZPZZ_D", + "^[SU]MULH_(ZPmZ|ZZZ)_D", + "^[SU]MULH_ZPZZ_D")>; + +// Multiply long +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^[SU]MULL[BT]_ZZZI_[SD]", + "^[SU]MULL[BT]_ZZZ_[HSD]")>; + +// Multiply accumulate, B, H, S element size +def : InstRW<[V3AEWr_ZMABHS, V3AERd_ZMABHS], + (instregex "^ML[AS]_ZZZI_[HS]", "^ML[AS]_ZPZZZ_[BHS]")>; +def : InstRW<[V3AEWr_ZMABHS, ReadDefault, V3AERd_ZMABHS], + (instregex "^(ML[AS]|MAD|MSB)_ZPmZZ_[BHS]")>; + +// Multiply accumulate, D element size +def : InstRW<[V3AEWr_ZMAD, V3AERd_ZMAD], + (instregex "^ML[AS]_ZZZI_D", "^ML[AS]_ZPZZZ_D")>; +def : InstRW<[V3AEWr_ZMAD, ReadDefault, V3AERd_ZMAD], + (instregex "^(ML[AS]|MAD|MSB)_ZPmZZ_D")>; + +// Multiply accumulate long +def : InstRW<[V3AEWr_ZMAL, V3AERd_ZMAL], (instregex "^[SU]ML[AS]L[BT]_ZZZ_[HSD]", + "^[SU]ML[AS]L[BT]_ZZZI_[SD]")>; + +// Multiply accumulate saturating doubling long regular +def : InstRW<[V3AEWr_ZMASQL, V3AERd_ZMASQ], + (instregex "^SQDML[AS]L(B|T|BT)_ZZZ_[HSD]", + "^SQDML[AS]L[BT]_ZZZI_[SD]")>; + +// Multiply saturating doubling high, B, H, S element size +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^SQDMULH_ZZZ_[BHS]", + "^SQDMULH_ZZZI_[HS]")>; + +// Multiply saturating doubling high, D element size +def : InstRW<[V3AEWrite_5c_2V0], (instrs SQDMULH_ZZZ_D, SQDMULH_ZZZI_D)>; + +// Multiply saturating doubling long +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^SQDMULL[BT]_ZZZ_[HSD]", + "^SQDMULL[BT]_ZZZI_[SD]")>; + +// Multiply saturating rounding doubling regular/complex accumulate, B, H, S +// element size +def : InstRW<[V3AEWr_ZMASQBHS, V3AERd_ZMASQ], (instregex "^SQRDML[AS]H_ZZZ_[BHS]", + "^SQRDCMLAH_ZZZ_[BHS]", + "^SQRDML[AS]H_ZZZI_[HS]", + "^SQRDCMLAH_ZZZI_[HS]")>; + +// Multiply saturating rounding doubling regular/complex accumulate, D element +// size +def : InstRW<[V3AEWr_ZMASQD, V3AERd_ZMASQ], (instregex "^SQRDML[AS]H_ZZZI?_D", + "^SQRDCMLAH_ZZZ_D")>; + +// Multiply saturating rounding doubling regular/complex, B, H, S element size +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^SQRDMULH_ZZZ_[BHS]", + "^SQRDMULH_ZZZI_[HS]")>; + +// Multiply saturating rounding doubling regular/complex, D element size +def : InstRW<[V3AEWrite_5c_2V0], (instregex "^SQRDMULH_ZZZI?_D")>; + +// Multiply/multiply long, (8x8) polynomial +def : InstRW<[V3AEWrite_2c_1V], (instregex "^PMUL_ZZZ_B", + "^PMULL[BT]_ZZZ_[HDQ]")>; + +// Predicate counting vector +def : InstRW<[V3AEWrite_2c_1V], (instregex "^([SU]Q)?(DEC|INC)[HWD]_ZPiI")>; + +// Reciprocal estimate +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^URECPE_ZPmZ_S", "^URSQRTE_ZPmZ_S")>; + +// Reduction, arithmetic, B form +def : InstRW<[V3AEWrite_9c_2V_4V1], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_B")>; + +// 
Reduction, arithmetic, H form +def : InstRW<[V3AEWrite_8c_2V_2V1], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_H")>; + +// Reduction, arithmetic, S form +def : InstRW<[V3AEWrite_6c_2V_2V1], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_S")>; + +// Reduction, arithmetic, D form +def : InstRW<[V3AEWrite_4c_2V], (instregex "^[SU](ADD|MAX|MIN)V_VPZ_D")>; + +// Reduction, logical +def : InstRW<[V3AEWrite_6c_1V_1V1], (instregex "^(AND|EOR|OR)V_VPZ_[BHSD]")>; + +// Reverse, vector +def : InstRW<[V3AEWrite_2c_1V], (instregex "^REV_ZZ_[BHSD]", + "^REVB_ZPmZ_[HSD]", + "^REVH_ZPmZ_[SD]", + "^REVW_ZPmZ_D")>; + +// Select, vector form +def : InstRW<[V3AEWrite_2c_1V], (instregex "^SEL_ZPZZ_[BHSD]")>; + +// Table lookup +def : InstRW<[V3AEWrite_2c_1V], (instregex "^TBL_ZZZZ?_[BHSD]")>; + +// Table lookup extension +def : InstRW<[V3AEWrite_2c_1V], (instregex "^TBX_ZZZ_[BHSD]")>; + +// Transpose, vector form +def : InstRW<[V3AEWrite_2c_1V], (instregex "^TRN[12]_ZZZ_[BHSDQ]")>; + +// Unpack and extend +def : InstRW<[V3AEWrite_2c_1V], (instregex "^[SU]UNPK(HI|LO)_ZZ_[HSD]")>; + +// Zip/unzip +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(UZP|ZIP)[12]_ZZZ_[BHSDQ]")>; + +// §3.26 SVE floating-point instructions +// ----------------------------------------------------------------------------- + +// Floating point absolute value/difference +def : InstRW<[V3AEWrite_2c_1V], (instregex "^FAB[SD]_ZPmZ_[HSD]", + "^FABD_ZPZZ_[HSD]", + "^FABS_ZPmZ_[HSD]")>; + +// Floating point arithmetic +def : InstRW<[V3AEWrite_2c_1V], (instregex "^F(ADD|SUB)_(ZPm[IZ]|ZZZ)_[HSD]", + "^F(ADD|SUB)_ZPZ[IZ]_[HSD]", + "^FADDP_ZPmZZ_[HSD]", + "^FNEG_ZPmZ_[HSD]", + "^FSUBR_ZPm[IZ]_[HSD]", + "^FSUBR_(ZPZI|ZPZZ)_[HSD]")>; + +// Floating point associative add, F16 +def : InstRW<[V3AEWrite_10c_1V1_9rc], (instrs FADDA_VPZ_H)>; + +// Floating point associative add, F32 +def : InstRW<[V3AEWrite_6c_1V1_5rc], (instrs FADDA_VPZ_S)>; + +// Floating point associative add, F64 +def : InstRW<[V3AEWrite_4c_1V], (instrs FADDA_VPZ_D)>; + +// Floating point compare +def : InstRW<[V3AEWrite_2c_1V0], (instregex "^FACG[ET]_PPzZZ_[HSD]", + "^FCM(EQ|GE|GT|NE)_PPzZ[0Z]_[HSD]", + "^FCM(LE|LT)_PPzZ0_[HSD]", + "^FCMUO_PPzZZ_[HSD]")>; + +// Floating point complex add +def : InstRW<[V3AEWrite_3c_1V], (instregex "^FCADD_ZPmZ_[HSD]")>; + +// Floating point complex multiply add +def : InstRW<[V3AEWr_ZFCMA, ReadDefault, V3AERd_ZFCMA], (instregex "^FCMLA_ZPmZZ_[HSD]")>; +def : InstRW<[V3AEWr_ZFCMA, V3AERd_ZFCMA], (instregex "^FCMLA_ZZZI_[HS]")>; + +// Floating point convert, long or narrow (F16 to F32 or F32 to F16) +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FCVT_ZPmZ_(HtoS|StoH)", + "^FCVTLT_ZPmZ_HtoS", + "^FCVTNT_ZPmZ_StoH")>; + +// Floating point convert, long or narrow (F16 to F64, F32 to F64, F64 to F32 +// or F64 to F16) +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FCVT_ZPmZ_(HtoD|StoD|DtoS|DtoH)", + "^FCVTLT_ZPmZ_StoD", + "^FCVTNT_ZPmZ_DtoS")>; + +// Floating point convert, round to odd +def : InstRW<[V3AEWrite_3c_1V0], (instrs FCVTX_ZPmZ_DtoS, FCVTXNT_ZPmZ_DtoS)>; + +// Floating point base2 log, F16 +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^FLOGB_(ZPmZ|ZPZZ)_H")>; + +// Floating point base2 log, F32 +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FLOGB_(ZPmZ|ZPZZ)_S")>; + +// Floating point base2 log, F64 +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FLOGB_(ZPmZ|ZPZZ)_D")>; + +// Floating point convert to integer, F16 +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^FCVTZ[SU]_ZPmZ_HtoH")>; + +// Floating point convert to integer, F32 +def : InstRW<[V3AEWrite_4c_2V0], 
(instregex "^FCVTZ[SU]_ZPmZ_(HtoS|StoS)")>; + +// Floating point convert to integer, F64 +def : InstRW<[V3AEWrite_3c_1V0], + (instregex "^FCVTZ[SU]_ZPmZ_(HtoD|StoD|DtoS|DtoD)")>; + +// Floating point copy +def : InstRW<[V3AEWrite_2c_1V], (instregex "^FCPY_ZPmI_[HSD]", + "^FDUP_ZI_[HSD]")>; + +// Floating point divide, F16 +def : InstRW<[V3AEWrite_13c_1V1_8rc], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_H")>; + +// Floating point divide, F32 +def : InstRW<[V3AEWrite_11c_1V1_4rc], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_S")>; + +// Floating point divide, F64 +def : InstRW<[V3AEWrite_14c_1V1_2rc], (instregex "^FDIVR?_(ZPmZ|ZPZZ)_D")>; + +// Floating point min/max pairwise +def : InstRW<[V3AEWrite_2c_1V], (instregex "^F(MAX|MIN)(NM)?P_ZPmZZ_[HSD]")>; + +// Floating point min/max +def : InstRW<[V3AEWrite_2c_1V], (instregex "^F(MAX|MIN)(NM)?_ZPm[IZ]_[HSD]", + "^F(MAX|MIN)(NM)?_ZPZ[IZ]_[HSD]")>; + +// Floating point multiply +def : InstRW<[V3AEWrite_3c_1V], (instregex "^(FSCALE|FMULX)_ZPmZ_[HSD]", + "^FMULX_ZPZZ_[HSD]", + "^FMUL_(ZPm[IZ]|ZZZI?)_[HSD]", + "^FMUL_ZPZ[IZ]_[HSD]")>; + +// Floating point multiply accumulate +def : InstRW<[V3AEWr_ZFMA, ReadDefault, V3AERd_ZFMA], + (instregex "^FN?ML[AS]_ZPmZZ_[HSD]", + "^FN?(MAD|MSB)_ZPmZZ_[HSD]")>; +def : InstRW<[V3AEWr_ZFMA, V3AERd_ZFMA], + (instregex "^FML[AS]_ZZZI_[HSD]", + "^FN?ML[AS]_ZPZZZ_[HSD]")>; + +// Floating point multiply add/sub accumulate long +def : InstRW<[V3AEWr_ZFMAL, V3AERd_ZFMAL], (instregex "^FML[AS]L[BT]_ZZZI?_SHH")>; + +// Floating point reciprocal estimate, F16 +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^FR(ECP|SQRT)E_ZZ_H", "^FRECPX_ZPmZ_H")>; + +// Floating point reciprocal estimate, F32 +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FR(ECP|SQRT)E_ZZ_S", "^FRECPX_ZPmZ_S")>; + +// Floating point reciprocal estimate, F64 +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FR(ECP|SQRT)E_ZZ_D", "^FRECPX_ZPmZ_D")>; + +// Floating point reciprocal step +def : InstRW<[V3AEWrite_4c_1V], (instregex "^F(RECPS|RSQRTS)_ZZZ_[HSD]")>; + +// Floating point reduction, F16 +def : InstRW<[V3AEWrite_8c_4V], + (instregex "^(FADDV|FMAXNMV|FMAXV|FMINNMV|FMINV)_VPZ_H")>; + +// Floating point reduction, F32 +def : InstRW<[V3AEWrite_6c_3V], + (instregex "^(FADDV|FMAXNMV|FMAXV|FMINNMV|FMINV)_VPZ_S")>; + +// Floating point reduction, F64 +def : InstRW<[V3AEWrite_4c_2V], + (instregex "^(FADDV|FMAXNMV|FMAXV|FMINNMV|FMINV)_VPZ_D")>; + +// Floating point round to integral, F16 +def : InstRW<[V3AEWrite_6c_4V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_H")>; + +// Floating point round to integral, F32 +def : InstRW<[V3AEWrite_4c_2V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_S")>; + +// Floating point round to integral, F64 +def : InstRW<[V3AEWrite_3c_1V0], (instregex "^FRINT[AIMNPXZ]_ZPmZ_D")>; + +// Floating point square root, F16 +def : InstRW<[V3AEWrite_13c_1V1_8rc], (instregex "^FSQRT_ZPmZ_H")>; + +// Floating point square root, F32 +def : InstRW<[V3AEWrite_11c_1V1_4rc], (instregex "^FSQRT_ZPmZ_S")>; + +// Floating point square root, F64 +def : InstRW<[V3AEWrite_14c_1V1_2rc], (instregex "^FSQRT_ZPmZ_D")>; + +// Floating point trigonometric exponentiation +def : InstRW<[V3AEWrite_3c_1V1], (instregex "^FEXPA_ZZ_[HSD]")>; + +// Floating point trigonometric multiply add +def : InstRW<[V3AEWrite_4c_1V], (instregex "^FTMAD_ZZI_[HSD]")>; + +// Floating point trigonometric, miscellaneous +def : InstRW<[V3AEWrite_3c_1V], (instregex "^FTS(MUL|SEL)_ZZZ_[HSD]")>; + +// §3.27 SVE BFloat16 (BF16) instructions +// ----------------------------------------------------------------------------- + +// 
Convert, F32 to BF16 +def : InstRW<[V3AEWrite_4c_1V], (instrs BFCVT_ZPmZ, BFCVTNT_ZPmZ)>; + +// Dot product +def : InstRW<[V3AEWr_ZBFDOT, V3AERd_ZBFDOT], (instrs BFDOT_ZZI, BFDOT_ZZZ)>; + +// Matrix multiply accumulate +def : InstRW<[V3AEWr_ZBFMMA, V3AERd_ZBFMMA], (instrs BFMMLA_ZZZ_HtoS)>; + +// Multiply accumulate long +def : InstRW<[V3AEWr_ZBFMAL, V3AERd_ZBFMAL], (instregex "^BFMLAL[BT]_ZZZI?")>; + +// §3.28 SVE Load instructions +// ----------------------------------------------------------------------------- + +// Load vector +def : InstRW<[V3AEWrite_6c_1L], (instrs LDR_ZXI)>; + +// Load predicate +def : InstRW<[V3AEWrite_6c_1L_1M], (instrs LDR_PXI)>; + +// Contiguous load, scalar + imm +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LD1[BHWD]_IMM$", + "^LD1S?B_[HSD]_IMM$", + "^LD1S?H_[SD]_IMM$", + "^LD1S?W_D_IMM$" )>; +// Contiguous load, scalar + scalar +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LD1[BHWD]$", + "^LD1S?B_[HSD]$", + "^LD1S?H_[SD]$", + "^LD1S?W_D$" )>; + +// Contiguous load broadcast, scalar + imm +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LD1R[BHWD]_IMM$", + "^LD1RS?B_[HSD]_IMM$", + "^LD1RS?H_[SD]_IMM$", + "^LD1RW_D_IMM$", + "^LD1RSW_IMM$", + "^LD1RQ_[BHWD]_IMM$")>; + +// Contiguous load broadcast, scalar + scalar +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LD1RQ_[BHWD]$")>; + +// Non temporal load, scalar + imm +// Non temporal load, scalar + scalar +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LDNT1[BHWD]_ZR[IR]$")>; + +// Non temporal gather load, vector + scalar 32-bit element size +def : InstRW<[V3AEWrite_9c_2L_4V], (instregex "^LDNT1[BHW]_ZZR_S$", + "^LDNT1S[BH]_ZZR_S$")>; + +// Non temporal gather load, vector + scalar 64-bit element size +def : InstRW<[V3AEWrite_9c_2L_2V], (instregex "^LDNT1S?[BHW]_ZZR_D$")>; +def : InstRW<[V3AEWrite_9c_2L_2V], (instrs LDNT1D_ZZR_D)>; + +// Contiguous first faulting load, scalar + scalar +def : InstRW<[V3AEWrite_6c_1L_1I], (instregex "^LDFF1[BHWD]$", + "^LDFF1S?B_[HSD]$", + "^LDFF1S?H_[SD]$", + "^LDFF1S?W_D$")>; + +// Contiguous non faulting load, scalar + imm +def : InstRW<[V3AEWrite_6c_1L], (instregex "^LDNF1[BHWD]_IMM$", + "^LDNF1S?B_[HSD]_IMM$", + "^LDNF1S?H_[SD]_IMM$", + "^LDNF1S?W_D_IMM$")>; + +// Contiguous Load two structures to two vectors, scalar + imm +def : InstRW<[V3AEWrite_8c_2L_2V], (instregex "^LD2[BHWD]_IMM$")>; + +// Contiguous Load two structures to two vectors, scalar + scalar +def : InstRW<[V3AEWrite_9c_2L_2V_2I], (instregex "^LD2[BHWD]$")>; + +// Contiguous Load three structures to three vectors, scalar + imm +def : InstRW<[V3AEWrite_9c_3L_3V], (instregex "^LD3[BHWD]_IMM$")>; + +// Contiguous Load three structures to three vectors, scalar + scalar +def : InstRW<[V3AEWrite_10c_3V_3L_3I], (instregex "^LD3[BHWD]$")>; + +// Contiguous Load four structures to four vectors, scalar + imm +def : InstRW<[V3AEWrite_9c_4L_8V], (instregex "^LD4[BHWD]_IMM$")>; + +// Contiguous Load four structures to four vectors, scalar + scalar +def : InstRW<[V3AEWrite_10c_4L_8V_4I], (instregex "^LD4[BHWD]$")>; + +// Gather load, vector + imm, 32-bit element size +def : InstRW<[V3AEWrite_9c_1L_4V], (instregex "^GLD(FF)?1S?[BH]_S_IMM$", + "^GLD(FF)?1W_IMM$")>; + +// Gather load, vector + imm, 64-bit element size +def : InstRW<[V3AEWrite_9c_1L_4V], (instregex "^GLD(FF)?1S?[BHW]_D_IMM$", + "^GLD(FF)?1D_IMM$")>; + +// Gather load, 32-bit scaled offset +def : InstRW<[V3AEWrite_10c_1L_8V], + (instregex "^GLD(FF)?1S?H_S_[SU]XTW_SCALED$", + "^GLD(FF)?1W_[SU]XTW_SCALED")>; + +// Gather load, 64-bit scaled offset +// NOTE: These 
instructions are not specified in the SOG. +def : InstRW<[V3AEWrite_10c_1L_4V], + (instregex "^GLD(FF)?1S?[HW]_D_([SU]XTW_)?SCALED$", + "^GLD(FF)?1D_([SU]XTW_)?SCALED$")>; + +// Gather load, 32-bit unpacked unscaled offset +def : InstRW<[V3AEWrite_9c_1L_4V], (instregex "^GLD(FF)?1S?[BH]_S_[SU]XTW$", + "^GLD(FF)?1W_[SU]XTW$")>; + +// Gather load, 64-bit unpacked unscaled offset +// NOTE: These instructions are not specified in the SOG. +def : InstRW<[V3AEWrite_9c_1L_2V], + (instregex "^GLD(FF)?1S?[BHW]_D(_[SU]XTW)?$", + "^GLD(FF)?1D(_[SU]XTW)?$")>; + +// §3.29 SVE Store instructions +// ----------------------------------------------------------------------------- + +// Store from predicate reg +def : InstRW<[V3AEWrite_1c_1SA], (instrs STR_PXI)>; + +// Store from vector reg +def : InstRW<[V3AEWrite_2c_1SA_1V], (instrs STR_ZXI)>; + +// Contiguous store, scalar + imm +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^ST1[BHWD]_IMM$", + "^ST1B_[HSD]_IMM$", + "^ST1H_[SD]_IMM$", + "^ST1W_D_IMM$")>; + +// Contiguous store, scalar + scalar +def : InstRW<[V3AEWrite_2c_1SA_1I_1V], (instregex "^ST1H(_[SD])?$")>; +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^ST1[BWD]$", + "^ST1B_[HSD]$", + "^ST1W_D$")>; + +// Contiguous store two structures from two vectors, scalar + imm +def : InstRW<[V3AEWrite_4c_1SA_1V], (instregex "^ST2[BHWD]_IMM$")>; + +// Contiguous store two structures from two vectors, scalar + scalar +def : InstRW<[V3AEWrite_4c_2SA_2I_2V], (instrs ST2H)>; +def : InstRW<[V3AEWrite_4c_2SA_2V], (instregex "^ST2[BWD]$")>; + +// Contiguous store three structures from three vectors, scalar + imm +def : InstRW<[V3AEWrite_7c_9SA_9V], (instregex "^ST3[BHWD]_IMM$")>; + +// Contiguous store three structures from three vectors, scalar + scalar +def : InstRW<[V3AEWrite_7c_9SA_9I_9V], (instregex "^ST3[BHWD]$")>; + +// Contiguous store four structures from four vectors, scalar + imm +def : InstRW<[V3AEWrite_11c_18SA_18V], (instregex "^ST4[BHWD]_IMM$")>; + +// Contiguous store four structures from four vectors, scalar + scalar +def : InstRW<[V3AEWrite_11c_18SA_18I_18V], (instregex "^ST4[BHWD]$")>; + +// Non temporal store, scalar + imm +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^STNT1[BHWD]_ZRI$")>; + +// Non temporal store, scalar + scalar +def : InstRW<[V3AEWrite_2c_1SA_1I_1V], (instrs STNT1H_ZRR)>; +def : InstRW<[V3AEWrite_2c_1SA_1V], (instregex "^STNT1[BWD]_ZRR$")>; + +// Scatter non temporal store, vector + scalar 32-bit element size +def : InstRW<[V3AEWrite_4c_4SA_4V], (instregex "^STNT1[BHW]_ZZR_S")>; + +// Scatter non temporal store, vector + scalar 64-bit element size +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "^STNT1[BHWD]_ZZR_D")>; + +// Scatter store vector + imm 32-bit element size +def : InstRW<[V3AEWrite_4c_4SA_4V], (instregex "^SST1[BH]_S_IMM$", + "^SST1W_IMM$")>; + +// Scatter store vector + imm 64-bit element size +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "^SST1[BHW]_D_IMM$", + "^SST1D_IMM$")>; + +// Scatter store, 32-bit scaled offset +def : InstRW<[V3AEWrite_4c_4SA_4V], + (instregex "^SST1(H_S|W)_[SU]XTW_SCALED$")>; + +// Scatter store, 32-bit unpacked unscaled offset +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "^SST1[BHW]_D_[SU]XTW$", + "^SST1D_[SU]XTW$")>; + +// Scatter store, 32-bit unpacked scaled offset +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "^SST1[HW]_D_[SU]XTW_SCALED$", + "^SST1D_[SU]XTW_SCALED$")>; + +// Scatter store, 32-bit unscaled offset +def : InstRW<[V3AEWrite_4c_4SA_4V], (instregex "^SST1[BH]_S_[SU]XTW$", + "^SST1W_[SU]XTW$")>; + +// Scatter 
store, 64-bit scaled offset +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "^SST1[HW]_D_SCALED$", + "^SST1D_SCALED$")>; + +// Scatter store, 64-bit unscaled offset +def : InstRW<[V3AEWrite_2c_2SA_2V], (instregex "^SST1[BHW]_D$", + "^SST1D$")>; + +// §3.30 SVE Miscellaneous instructions +// ----------------------------------------------------------------------------- + +// Read first fault register, unpredicated +def : InstRW<[V3AEWrite_2c_1M0], (instrs RDFFR_P)>; + +// Read first fault register, predicated +def : InstRW<[V3AEWrite_3or4c_1M0_1M], (instrs RDFFR_PPz)>; + +// Read first fault register and set flags +def : InstRW<[V3AEWrite_3or4c_1M0_1M], (instrs RDFFRS_PPz)>; + +// Set first fault register +// Write to first fault register +def : InstRW<[V3AEWrite_2c_1M0], (instrs SETFFR, WRFFR)>; + +// Prefetch +// NOTE: This is not specified in the SOG. +def : InstRW<[V3AEWrite_4c_1L], (instregex "^PRF[BHWD]")>; + +// §3.31 SVE Cryptographic instructions +// ----------------------------------------------------------------------------- + +// Crypto AES ops +def : InstRW<[V3AEWrite_2c_1V], (instregex "^AES[DE]_ZZZ_B$", + "^AESI?MC_ZZ_B$")>; + +// Crypto SHA3 ops +def : InstRW<[V3AEWrite_2c_1V], (instregex "^(BCAX|EOR3)_ZZZZ$", + "^RAX1_ZZZ_D$", + "^XAR_ZZZI_[BHSD]$")>; + +// Crypto SM4 ops +def : InstRW<[V3AEWrite_4c_1V0], (instregex "^SM4E(KEY)?_ZZZ_S$")>; + +} diff --git a/llvm/lib/Target/AArch64/AArch64SchedPredNeoverse.td b/llvm/lib/Target/AArch64/AArch64SchedPredNeoverse.td index 33b76a4..f841e60 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedPredNeoverse.td +++ b/llvm/lib/Target/AArch64/AArch64SchedPredNeoverse.td @@ -80,5 +80,23 @@ def NeoverseZeroMove : MCSchedPredicate< // MOVI Dd, #0 // MOVI Vd.2D, #0 CheckAll<[CheckOpcode<[MOVID, MOVIv2d_ns]>, - CheckImmOperand<1, 0>]> + CheckImmOperand<1, 0>]>, + // MOV Zd, Zn + CheckAll<[CheckOpcode<[ORR_ZZZ]>, + CheckSameRegOperand<1, 2>]>, + // MOV Vd, Vn + CheckAll<[CheckOpcode<[ORRv16i8, ORRv8i8]>, + CheckSameRegOperand<1, 2>]>, ]>>; + +def NeoverseAllActivePredicate : MCSchedPredicate< + CheckAny<[ + // PTRUE Pd, ALL + // PTRUES Pd, ALL + CheckAll<[ + CheckOpcode<[ + PTRUE_B, PTRUE_H, PTRUE_S, PTRUE_D, + PTRUES_B, PTRUES_H, PTRUES_S, PTRUES_D]>, + CheckIsImmOperand<1>, + CheckImmOperand<1, 31>]>, + ]>>; diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp index d3b1aa6..48e03ad 100644 --- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp @@ -32,35 +32,21 @@ AArch64SelectionDAGInfo::AArch64SelectionDAGInfo() void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG, const SDNode *N) const { + switch (N->getOpcode()) { + case AArch64ISD::WrapperLarge: + // operand #0 must have type i32, but has type i64 + return; + } + SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N); #ifndef NDEBUG // Some additional checks not yet implemented by verifyTargetNode. 
- constexpr MVT FlagsVT = MVT::i32; switch (N->getOpcode()) { - case AArch64ISD::SUBS: - assert(N->getValueType(1) == FlagsVT); - break; - case AArch64ISD::ADC: - case AArch64ISD::SBC: - assert(N->getOperand(2).getValueType() == FlagsVT); - break; - case AArch64ISD::ADCS: - case AArch64ISD::SBCS: - assert(N->getValueType(1) == FlagsVT); - assert(N->getOperand(2).getValueType() == FlagsVT); - break; - case AArch64ISD::CSEL: - case AArch64ISD::CSINC: - case AArch64ISD::BRCOND: - assert(N->getOperand(3).getValueType() == FlagsVT); - break; case AArch64ISD::SADDWT: case AArch64ISD::SADDWB: case AArch64ISD::UADDWT: case AArch64ISD::UADDWB: { - assert(N->getNumValues() == 1 && "Expected one result!"); - assert(N->getNumOperands() == 2 && "Expected two operands!"); EVT VT = N->getValueType(0); EVT Op0VT = N->getOperand(0).getValueType(); EVT Op1VT = N->getOperand(1).getValueType(); @@ -80,8 +66,6 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG, case AArch64ISD::SUNPKHI: case AArch64ISD::UUNPKLO: case AArch64ISD::UUNPKHI: { - assert(N->getNumValues() == 1 && "Expected one result!"); - assert(N->getNumOperands() == 1 && "Expected one operand!"); EVT VT = N->getValueType(0); EVT OpVT = N->getOperand(0).getValueType(); assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() && @@ -98,8 +82,6 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG, case AArch64ISD::UZP2: case AArch64ISD::ZIP1: case AArch64ISD::ZIP2: { - assert(N->getNumValues() == 1 && "Expected one result!"); - assert(N->getNumOperands() == 2 && "Expected two operands!"); EVT VT = N->getValueType(0); EVT Op0VT = N->getOperand(0).getValueType(); EVT Op1VT = N->getOperand(1).getValueType(); @@ -109,11 +91,8 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG, break; } case AArch64ISD::RSHRNB_I: { - assert(N->getNumValues() == 1 && "Expected one result!"); - assert(N->getNumOperands() == 2 && "Expected two operands!"); EVT VT = N->getValueType(0); EVT Op0VT = N->getOperand(0).getValueType(); - EVT Op1VT = N->getOperand(1).getValueType(); assert(VT.isVector() && VT.isInteger() && "Expected integer vector result type!"); assert(Op0VT.isVector() && Op0VT.isInteger() && @@ -122,8 +101,8 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG, "Expected vectors of equal size!"); assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 && "Expected input vector with half the lanes of its result!"); - assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) && - "Expected second operand to be a constant i32!"); + assert(isa<ConstantSDNode>(N->getOperand(1)) && + "Expected second operand to be a constant!"); break; } } diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp index a67bd42..d87bb52 100644 --- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp +++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp @@ -46,7 +46,6 @@ #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" #include <cassert> #include <memory> -#include <utility> using namespace llvm; diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp index 53b00e8..dae4f6a 100644 --- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp +++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp @@ -222,17 +222,8 @@ void AArch64Subtarget::initializeProperties(bool HasMinSize) { PrefetchDistance = 280; MinPrefetchStride = 2048; MaxPrefetchIterationsAhead = 3; - switch (ARMProcFamily) { 
- case AppleA14: - case AppleA15: - case AppleA16: - case AppleA17: - case AppleM4: + if (isAppleMLike()) MaxInterleaveFactor = 4; - break; - default: - break; - } break; case ExynosM3: MaxInterleaveFactor = 4; diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h index 8974965..8553f16 100644 --- a/llvm/lib/Target/AArch64/AArch64Subtarget.h +++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h @@ -169,6 +169,21 @@ public: return ARMProcFamily; } + /// Returns true if the processor is an Apple M-series or aligned A-series + /// (A14 or newer). + bool isAppleMLike() const { + switch (ARMProcFamily) { + case AppleA14: + case AppleA15: + case AppleA16: + case AppleA17: + case AppleM4: + return true; + default: + return false; + } + } + bool isXRaySupported() const override { return true; } /// Returns true if the function has a streaming body. diff --git a/llvm/lib/Target/AArch64/AArch64SystemOperands.td b/llvm/lib/Target/AArch64/AArch64SystemOperands.td index ae46d71..cb09875 100644 --- a/llvm/lib/Target/AArch64/AArch64SystemOperands.td +++ b/llvm/lib/Target/AArch64/AArch64SystemOperands.td @@ -814,6 +814,7 @@ def lookupBTIByName : SearchIndex { let Key = ["Name"]; } +def : BTI<"r", 0b000>; def : BTI<"c", 0b010>; def : BTI<"j", 0b100>; def : BTI<"jc", 0b110>; @@ -833,6 +834,23 @@ class CMHPriorityHint<string name, bits<1> encoding> : SearchableTable { def : CMHPriorityHint<"ph", 0b1>; + +//===----------------------------------------------------------------------===// +// TIndex instruction options. +//===----------------------------------------------------------------------===// + +class TIndex<string name, bits<1> encoding> : SearchableTable { + let SearchableFields = ["Name", "Encoding"]; + let EnumValueField = "Encoding"; + + string Name = name; + bits<1> Encoding; + let Encoding = encoding; +} + +def : TIndex<"nb", 0b1>; + + //===----------------------------------------------------------------------===// // TLBI (translation lookaside buffer invalidate) instruction options. 
//===----------------------------------------------------------------------===// @@ -2584,14 +2602,14 @@ foreach n=0-15 in { //===----------------------------------------------------------------------===// // GIC -class GIC<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2> { +class GIC<string name, bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2, bit needsreg = 1> { string Name = name; bits<14> Encoding; let Encoding{13-11} = op1; let Encoding{10-7} = crn; let Encoding{6-3} = crm; let Encoding{2-0} = op2; - bit NeedsReg = 1; + bit NeedsReg = needsreg; string RequiresStr = [{ {AArch64::FeatureGCIE} }]; } @@ -2668,12 +2686,12 @@ def : GSB<"ack", 0b000, 0b1100, 0b0000, 0b001>; def : GICR<"cdia", 0b000, 0b1100, 0b0011, 0b000>; def : GICR<"cdnmia", 0b000, 0b1100, 0b0011, 0b001>; -// Op1 CRn CRm Op2 +// Op1 CRn CRm Op2, needsreg def : GIC<"cdaff", 0b000, 0b1100, 0b0001, 0b011>; def : GIC<"cddi", 0b000, 0b1100, 0b0010, 0b000>; def : GIC<"cddis", 0b000, 0b1100, 0b0001, 0b000>; def : GIC<"cden", 0b000, 0b1100, 0b0001, 0b001>; -def : GIC<"cdeoi", 0b000, 0b1100, 0b0001, 0b111>; +def : GIC<"cdeoi", 0b000, 0b1100, 0b0001, 0b111, 0>; def : GIC<"cdhm", 0b000, 0b1100, 0b0010, 0b001>; def : GIC<"cdpend", 0b000, 0b1100, 0b0001, 0b100>; def : GIC<"cdpri", 0b000, 0b1100, 0b0001, 0b010>; @@ -2694,3 +2712,161 @@ def : GIC<"ldhm", 0b110, 0b1100, 0b0010, 0b001>; def : GIC<"ldpend", 0b110, 0b1100, 0b0001, 0b100>; def : GIC<"ldpri", 0b110, 0b1100, 0b0001, 0b010>; def : GIC<"ldrcfg", 0b110, 0b1100, 0b0001, 0b101>; + + +// Stage 1 Permission Overlays Extension 2 (FEAT_S1POE2). +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"DPOTBR0_EL1", 0b11, 0b000, 0b0010, 0b0000, 0b110>; +def : RWSysReg<"DPOTBR0_EL12", 0b11, 0b101, 0b0010, 0b0000, 0b110>; +def : RWSysReg<"DPOTBR1_EL1", 0b11, 0b000, 0b0010, 0b0000, 0b111>; +def : RWSysReg<"DPOTBR1_EL12", 0b11, 0b101, 0b0010, 0b0000, 0b111>; +def : RWSysReg<"DPOTBR0_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b110>; +def : RWSysReg<"DPOTBR1_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b111>; +def : RWSysReg<"DPOTBR0_EL3", 0b11, 0b110, 0b0010, 0b0000, 0b110>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"IRTBRU_EL1", 0b11, 0b000, 0b0010, 0b0000, 0b100>; +def : RWSysReg<"IRTBRU_EL12", 0b11, 0b101, 0b0010, 0b0000, 0b100>; +def : RWSysReg<"IRTBRP_EL1", 0b11, 0b000, 0b0010, 0b0000, 0b101>; +def : RWSysReg<"IRTBRP_EL12", 0b11, 0b101, 0b0010, 0b0000, 0b101>; +def : RWSysReg<"IRTBRU_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b100>; +def : RWSysReg<"IRTBRP_EL2", 0b11, 0b100, 0b0010, 0b0000, 0b101>; +def : RWSysReg<"IRTBRP_EL3", 0b11, 0b110, 0b0010, 0b0000, 0b101>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"TTTBRU_EL1", 0b11, 0b000, 0b1010, 0b0010, 0b110>; +def : RWSysReg<"TTTBRU_EL12", 0b11, 0b101, 0b1010, 0b0010, 0b110>; +def : RWSysReg<"TTTBRP_EL1", 0b11, 0b000, 0b1010, 0b0010, 0b111>; +def : RWSysReg<"TTTBRP_EL12", 0b11, 0b101, 0b1010, 0b0010, 0b111>; +def : RWSysReg<"TTTBRU_EL2", 0b11, 0b100, 0b1010, 0b0010, 0b110>; +def : RWSysReg<"TTTBRP_EL2", 0b11, 0b100, 0b1010, 0b0010, 0b111>; +def : RWSysReg<"TTTBRP_EL3", 0b11, 0b110, 0b1010, 0b0010, 0b111>; + +foreach n = 0-15 in { + defvar nb = !cast<bits<4>>(n); + // Op0 Op1 CRn CRm Op2 + def : RWSysReg<"FGDTP"#n#"_EL1", 0b11, 0b000, 0b0011, {0b001,nb{3}}, nb{2-0}>; + def : RWSysReg<"FGDTP"#n#"_EL2", 0b11, 0b100, 0b0011, {0b001,nb{3}}, nb{2-0}>; + def : RWSysReg<"FGDTP"#n#"_EL12", 0b11, 0b101, 0b0011, {0b001,nb{3}}, nb{2-0}>; + def : RWSysReg<"FGDTP"#n#"_EL3", 0b11, 0b110, 0b0011, {0b001,nb{3}}, nb{2-0}>; + + def : RWSysReg<"FGDTU"#n#"_EL1", 0b11, 0b000, 
0b0011, {0b010,nb{3}}, nb{2-0}>; + def : RWSysReg<"FGDTU"#n#"_EL2", 0b11, 0b100, 0b0011, {0b010,nb{3}}, nb{2-0}>; + def : RWSysReg<"FGDTU"#n#"_EL12", 0b11, 0b101, 0b0011, {0b010,nb{3}}, nb{2-0}>; +} + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"LDSTT_EL1", 0b11, 0b000, 0b0010, 0b0001, 0b111>; +def : RWSysReg<"LDSTT_EL12", 0b11, 0b101, 0b0010, 0b0001, 0b111>; +def : RWSysReg<"LDSTT_EL2", 0b11, 0b100, 0b0010, 0b0001, 0b111>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"TINDEX_EL0", 0b11, 0b011, 0b0100, 0b0000, 0b011>; +def : RWSysReg<"TINDEX_EL1", 0b11, 0b000, 0b0100, 0b0000, 0b011>; +def : RWSysReg<"TINDEX_EL2", 0b11, 0b100, 0b0100, 0b0000, 0b011>; +def : RWSysReg<"TINDEX_EL12", 0b11, 0b101, 0b0100, 0b0000, 0b011>; +def : RWSysReg<"TINDEX_EL3", 0b11, 0b110, 0b0100, 0b0000, 0b011>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"STINDEX_EL1", 0b11, 0b000, 0b0100, 0b0000, 0b010>; +def : RWSysReg<"STINDEX_EL2", 0b11, 0b100, 0b0100, 0b0000, 0b010>; +def : RWSysReg<"STINDEX_EL12", 0b11, 0b101, 0b0100, 0b0000, 0b010>; +def : RWSysReg<"STINDEX_EL3", 0b11, 0b110, 0b0100, 0b0000, 0b010>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"TPIDR3_EL0", 0b11, 0b011, 0b1101, 0b0000, 0b000>; +def : RWSysReg<"TPIDR3_EL1", 0b11, 0b000, 0b1101, 0b0000, 0b000>; +def : RWSysReg<"TPIDR3_EL12", 0b11, 0b101, 0b1101, 0b0000, 0b000>; +def : RWSysReg<"TPIDR3_EL2", 0b11, 0b100, 0b1101, 0b0000, 0b000>; +def : RWSysReg<"TPIDR3_EL3", 0b11, 0b110, 0b1101, 0b0000, 0b000>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"VNCCR_EL2", 0b11, 0b100, 0b0010, 0b0010, 0b001>; + +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"DPOCR_EL0", 0b11, 0b011, 0b0100, 0b0101, 0b010>; + +foreach n = 0-15 in { + defvar nb = !cast<bits<4>>(n); + // Op0 Op1 CRn CRm Op2 + def : RWSysReg<"AFGDTP"#n#"_EL1", 0b11, 0b000, 0b0011, {0b011,nb{3}}, nb{2-0}>; + def : RWSysReg<"AFGDTU"#n#"_EL1", 0b11, 0b000, 0b0011, {0b100,nb{3}}, nb{2-0}>; + def : RWSysReg<"AFGDTP"#n#"_EL2", 0b11, 0b100, 0b0011, {0b011,nb{3}}, nb{2-0}>; + def : RWSysReg<"AFGDTU"#n#"_EL2", 0b11, 0b100, 0b0011, {0b100,nb{3}}, nb{2-0}>; + def : RWSysReg<"AFGDTP"#n#"_EL12", 0b11, 0b101, 0b0011, {0b011,nb{3}}, nb{2-0}>; + def : RWSysReg<"AFGDTU"#n#"_EL12", 0b11, 0b101, 0b0011, {0b100,nb{3}}, nb{2-0}>; + def : RWSysReg<"AFGDTP"#n#"_EL3", 0b11, 0b110, 0b0011, {0b011,nb{3}}, nb{2-0}>; +} + +// Extra S1POE2 Hypervisor Configuration Registers +// Op0 Op1 CRn CRm Op2 +def : RWSysReg<"HCRMASK_EL2", 0b11, 0b100, 0b0001, 0b0101, 0b110>; +def : RWSysReg<"HCRXMASK_EL2", 0b11, 0b100, 0b0001, 0b0101, 0b111>; +def : RWSysReg<"NVHCR_EL2", 0b11, 0b100, 0b0001, 0b0101, 0b000>; +def : RWSysReg<"NVHCRX_EL2", 0b11, 0b100, 0b0001, 0b0101, 0b001>; +def : RWSysReg<"NVHCRMASK_EL2", 0b11, 0b100, 0b0001, 0b0101, 0b100>; +def : RWSysReg<"NVHCRXMASK_EL2", 0b11, 0b100, 0b0001, 0b0101, 0b101>; + +// S1POE2 Thread private state extension (FEAT_TPS/TPSP). 
+foreach n = 0-1 in { + defvar nb = !cast<bits<1>>(n); + // Op0 Op1 CRn CRm Op2 + def : RWSysReg<"TPMIN"#n#"_EL0", 0b11, 0b011, 0b0010, 0b0010, {0b1,nb,0}>; + def : RWSysReg<"TPMAX"#n#"_EL0", 0b11, 0b011, 0b0010, 0b0010, {0b1,nb,1}>; + def : RWSysReg<"TPMIN"#n#"_EL1", 0b11, 0b000, 0b0010, 0b0010, {0b1,nb,0}>; + def : RWSysReg<"TPMAX"#n#"_EL1", 0b11, 0b000, 0b0010, 0b0010, {0b1,nb,1}>; + def : RWSysReg<"TPMIN"#n#"_EL2", 0b11, 0b100, 0b0010, 0b0010, {0b1,nb,0}>; + def : RWSysReg<"TPMAX"#n#"_EL2", 0b11, 0b100, 0b0010, 0b0010, {0b1,nb,1}>; + def : RWSysReg<"TPMIN"#n#"_EL12", 0b11, 0b101, 0b0010, 0b0010, {0b1,nb,0}>; + def : RWSysReg<"TPMAX"#n#"_EL12", 0b11, 0b101, 0b0010, 0b0010, {0b1,nb,1}>; +} + +class PLBIEntry<bits<3> op1, bits<4> crn, bits<4> crm, bits<3> op2, string name, + bit needsreg, bit optionalreg> { + string Name = name; + bits<14> Encoding; + let Encoding{13-11} = op1; + let Encoding{10-7} = crn; + let Encoding{6-3} = crm; + let Encoding{2-0} = op2; + bit NeedsReg = needsreg; + bit OptionalReg = optionalreg; + string RequiresStr = [{ {AArch64::FeatureS1POE2} }]; +} + +def PLBITable : GenericTable { + let FilterClass = "PLBIEntry"; + let CppTypeName = "PLBI"; + let Fields = ["Name", "Encoding", "NeedsReg", "OptionalReg", "RequiresStr"]; + + let PrimaryKey = ["Encoding"]; + let PrimaryKeyName = "lookupPLBIByEncoding"; +} + +def lookupPLBIByName : SearchIndex { + let Table = PLBITable; + let Key = ["Name"]; +} + +multiclass PLBI<string name, bits<3> op1, bits<4> crn, bits<3> op2, + bit needsreg, bit optreg> { + // Entries containing "IS" or "OS" allow optional regs when +tlbid enabled + def : PLBIEntry<op1, crn, 0b0111, op2, name, needsreg, 0>; + def : PLBIEntry<op1, crn, 0b0011, op2, name#"IS", needsreg, optreg>; + def : PLBIEntry<op1, crn, 0b0001, op2, name#"OS", needsreg, optreg>; + def : PLBIEntry<op1, crn, 0b1111, op2, name#"NXS", needsreg, 0>; + def : PLBIEntry<op1, crn, 0b1011, op2, name#"ISNXS", needsreg, optreg>; + def : PLBIEntry<op1, crn, 0b1001, op2, name#"OSNXS", needsreg, optreg>; +} + +// CRm defines above six variants of each instruction. It is omitted here. 
+// Op1 CRn Op2 nr optreg +defm : PLBI<"ALLE3", 0b110, 0b1010, 0b000, 0, 0>; +defm : PLBI<"ALLE2", 0b100, 0b1010, 0b000, 0, 1>; +defm : PLBI<"ALLE1", 0b100, 0b1010, 0b100, 0, 1>; +defm : PLBI<"VMALLE1", 0b000, 0b1010, 0b000, 0, 1>; +defm : PLBI<"ASIDE1", 0b000, 0b1010, 0b010, 1, 0>; +defm : PLBI<"PERME3", 0b110, 0b1010, 0b001, 1, 0>; +defm : PLBI<"PERME2", 0b100, 0b1010, 0b001, 1, 0>; +defm : PLBI<"PERME1", 0b000, 0b1010, 0b001, 1, 0>; +defm : PLBI<"PERMAE1", 0b000, 0b1010, 0b011, 1, 0>; diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp index 5b80b08..346e18e 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -53,8 +53,6 @@ #include "llvm/Transforms/Utils/LowerIFunc.h" #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" #include <memory> -#include <optional> -#include <string> using namespace llvm; @@ -262,6 +260,7 @@ LLVMInitializeAArch64Target() { initializeAArch64PostSelectOptimizePass(PR); initializeAArch64PromoteConstantPass(PR); initializeAArch64RedundantCopyEliminationPass(PR); + initializeAArch64RedundantCondBranchPass(PR); initializeAArch64StorePairSuppressPass(PR); initializeFalkorHWPFFixPass(PR); initializeFalkorMarkStridedAccessesLegacyPass(PR); @@ -764,8 +763,8 @@ bool AArch64PassConfig::addGlobalInstructionSelect() { } void AArch64PassConfig::addMachineSSAOptimization() { - if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None) - addPass(createMachineSMEABIPass()); + if (TM->getOptLevel() != CodeGenOptLevel::None && EnableNewSMEABILowering) + addPass(createMachineSMEABIPass(TM->getOptLevel())); if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt) addPass(createSMEPeepholeOptPass()); @@ -798,7 +797,7 @@ bool AArch64PassConfig::addILPOpts() { void AArch64PassConfig::addPreRegAlloc() { if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering) - addPass(createMachineSMEABIPass()); + addPass(createMachineSMEABIPass(CodeGenOptLevel::None)); // Change dead register definitions to refer to the zero register. 
if (TM->getOptLevel() != CodeGenOptLevel::None && @@ -864,6 +863,8 @@ void AArch64PassConfig::addPreEmitPass() { if (TM->getOptLevel() >= CodeGenOptLevel::Aggressive && EnableAArch64CopyPropagation) addPass(createMachineCopyPropagationPass(true)); + if (TM->getOptLevel() != CodeGenOptLevel::None) + addPass(createAArch64RedundantCondBranchPass()); addPass(createAArch64A53Fix835769()); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index fede586..043be55 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -9,8 +9,8 @@ #include "AArch64TargetTransformInfo.h" #include "AArch64ExpandImm.h" #include "AArch64PerfectShuffle.h" +#include "AArch64SMEAttributes.h" #include "MCTargetDesc/AArch64AddressingModes.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -77,6 +77,10 @@ static cl::opt<unsigned> DMBLookaheadThreshold( "dmb-lookahead-threshold", cl::init(10), cl::Hidden, cl::desc("The number of instructions to search for a redundant dmb")); +static cl::opt<int> Aarch64ForceUnrollThreshold( + "aarch64-force-unroll-threshold", cl::init(0), cl::Hidden, + cl::desc("Threshold for forced unrolling of small loops in AArch64")); + namespace { class TailFoldingOption { // These bitfields will only ever be set to something non-zero in operator=, @@ -248,12 +252,23 @@ static bool hasPossibleIncompatibleOps(const Function *F, return false; } -APInt AArch64TTIImpl::getFeatureMask(const Function &F) const { +static void extractAttrFeatures(const Function &F, const AArch64TTIImpl *TTI, + SmallVectorImpl<StringRef> &Features) { StringRef AttributeStr = - isMultiversionedFunction(F) ? "fmv-features" : "target-features"; + TTI->isMultiversionedFunction(F) ? 
"fmv-features" : "target-features"; StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString(); - SmallVector<StringRef, 8> Features; FeatureStr.split(Features, ","); +} + +APInt AArch64TTIImpl::getFeatureMask(const Function &F) const { + SmallVector<StringRef, 8> Features; + extractAttrFeatures(F, this, Features); + return AArch64::getCpuSupportsMask(Features); +} + +APInt AArch64TTIImpl::getPriorityMask(const Function &F) const { + SmallVector<StringRef, 8> Features; + extractAttrFeatures(F, this, Features); return AArch64::getFMVPriority(Features); } @@ -308,9 +323,9 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, return (EffectiveCallerBits & EffectiveCalleeBits) == EffectiveCalleeBits; } -bool AArch64TTIImpl::areTypesABICompatible( - const Function *Caller, const Function *Callee, - const ArrayRef<Type *> &Types) const { +bool AArch64TTIImpl::areTypesABICompatible(const Function *Caller, + const Function *Callee, + ArrayRef<Type *> Types) const { if (!BaseT::areTypesABICompatible(Caller, Callee, Types)) return false; @@ -371,8 +386,13 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, bool AArch64TTIImpl::shouldMaximizeVectorBandwidth( TargetTransformInfo::RegisterKind K) const { assert(K != TargetTransformInfo::RGK_Scalar); - return (K == TargetTransformInfo::RGK_FixedWidthVector && - ST->isNeonAvailable()); + + if (K == TargetTransformInfo::RGK_FixedWidthVector && ST->isNeonAvailable()) + return true; + + return K == TargetTransformInfo::RGK_ScalableVector && + ST->isSVEorStreamingSVEAvailable() && + !ST->disableMaximizeScalableBandwidth(); } /// Calculate the cost of materializing a 64-bit value. This helper @@ -917,8 +937,20 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, if (ICA.getArgs().empty()) break; - // TODO: Add handling for fshl where third argument is not a constant. const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]); + + // ROTR / ROTL is a funnel shift with equal first and second operand. For + // ROTR on integer registers (i32/i64) this can be done in a single ror + // instruction. A fshl with a non-constant shift uses a neg + ror. + if (RetTy->isIntegerTy() && ICA.getArgs()[0] == ICA.getArgs()[1] && + (RetTy->getPrimitiveSizeInBits() == 32 || + RetTy->getPrimitiveSizeInBits() == 64)) { + InstructionCost NegCost = + (ICA.getID() == Intrinsic::fshl && !OpInfoZ.isConstant()) ? 1 : 0; + return 1 + NegCost; + } + + // TODO: Add handling for fshl where third argument is not a constant. if (!OpInfoZ.isConstant()) break; @@ -1032,6 +1064,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, } break; } + case Intrinsic::experimental_vector_extract_last_active: + if (ST->isSVEorStreamingSVEAvailable()) { + auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]); + // This should turn into chained clastb instructions. 
+ return LegalCost; + } + break; default: break; } @@ -1418,10 +1457,22 @@ static SVEIntrinsicInfo constructSVEIntrinsicInfo(IntrinsicInst &II) { case Intrinsic::aarch64_sve_orr: return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_orr_u) .setMatchingIROpcode(Instruction::Or); + case Intrinsic::aarch64_sve_sqrshl: + return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sqrshl_u); + case Intrinsic::aarch64_sve_sqshl: + return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sqshl_u); case Intrinsic::aarch64_sve_sqsub: return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sqsub_u); + case Intrinsic::aarch64_sve_srshl: + return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_srshl_u); + case Intrinsic::aarch64_sve_uqrshl: + return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uqrshl_u); + case Intrinsic::aarch64_sve_uqshl: + return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uqshl_u); case Intrinsic::aarch64_sve_uqsub: return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uqsub_u); + case Intrinsic::aarch64_sve_urshl: + return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_urshl_u); case Intrinsic::aarch64_sve_add_u: return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode( @@ -1863,25 +1914,23 @@ static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC, static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC, IntrinsicInst &II) { - IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1)); - if (!Pg) - return std::nullopt; + Value *Pg = II.getOperand(1); - if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue) - return std::nullopt; + // sve.dup(V, all_active, X) ==> splat(X) + if (isAllActivePredicate(Pg)) { + auto *RetTy = cast<ScalableVectorType>(II.getType()); + Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(), + II.getArgOperand(2)); + return IC.replaceInstUsesWith(II, Splat); + } - const auto PTruePattern = - cast<ConstantInt>(Pg->getOperand(0))->getZExtValue(); - if (PTruePattern != AArch64SVEPredPattern::vl1) + if (!match(Pg, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>( + m_SpecificInt(AArch64SVEPredPattern::vl1)))) return std::nullopt; - // The intrinsic is inserting into lane zero so use an insert instead. 
- auto *IdxTy = Type::getInt64Ty(II.getContext()); - auto *Insert = InsertElementInst::Create( - II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0)); - Insert->insertBefore(II.getIterator()); - Insert->takeName(&II); - + // sve.dup(V, sve.ptrue(vl1), X) ==> insertelement V, X, 0 + Value *Insert = IC.Builder.CreateInsertElement( + II.getArgOperand(0), II.getArgOperand(2), uint64_t(0)); return IC.replaceInstUsesWith(II, Insert); } @@ -2220,7 +2269,7 @@ static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC, return std::nullopt; } -template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc> +template <Intrinsic::ID MulOpc, Intrinsic::ID FuseOpc> static std::optional<Instruction *> instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II, bool MergeIntoAddendOp) { @@ -3000,9 +3049,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const { llvm_unreachable("Unsupported register kind"); } -bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, - ArrayRef<const Value *> Args, - Type *SrcOverrideTy) const { +bool AArch64TTIImpl::isSingleExtWideningInstruction( + unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args, + Type *SrcOverrideTy) const { // A helper that returns a vector type from the given type. The number of // elements in type Ty determines the vector width. auto toVectorTy = [&](Type *ArgTy) { @@ -3020,48 +3069,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64)) return false; - // Determine if the operation has a widening variant. We consider both the - // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the - // instructions. - // - // TODO: Add additional widening operations (e.g., shl, etc.) once we - // verify that their extending operands are eliminated during code - // generation. Type *SrcTy = SrcOverrideTy; switch (Opcode) { - case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2). - case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2). + case Instruction::Add: // UADDW(2), SADDW(2). + case Instruction::Sub: { // USUBW(2), SSUBW(2). // The second operand needs to be an extend if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) { if (!SrcTy) SrcTy = toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType()); - } else + break; + } + + if (Opcode == Instruction::Sub) return false; - break; - case Instruction::Mul: { // SMULL(2), UMULL(2) - // Both operands need to be extends of the same type. - if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) || - (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) { + + // UADDW(2), SADDW(2) can be commutted. + if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) { if (!SrcTy) SrcTy = toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType()); - } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) { - // If one of the operands is a Zext and the other has enough zero bits to - // be treated as unsigned, we can still general a umull, meaning the zext - // is free. - KnownBits Known = - computeKnownBits(isa<ZExtInst>(Args[0]) ? 
Args[1] : Args[0], DL); - if (Args[0]->getType()->getScalarSizeInBits() - - Known.Zero.countLeadingOnes() > - DstTy->getScalarSizeInBits() / 2) - return false; - if (!SrcTy) - SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(), - DstTy->getScalarSizeInBits() / 2)); - } else - return false; - break; + break; + } + return false; } default: return false; @@ -3092,6 +3122,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize; } +Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy, + ArrayRef<const Value *> Args, + Type *SrcOverrideTy) const { + if (Opcode != Instruction::Add && Opcode != Instruction::Sub && + Opcode != Instruction::Mul) + return nullptr; + + // Exit early if DstTy is not a vector type whose elements are one of [i16, + // i32, i64]. SVE doesn't generally have the same set of instructions to + // perform an extend with the add/sub/mul. There are SMULLB style + // instructions, but they operate on top/bottom, requiring some sort of lane + // interleaving to be used with zext/sext. + unsigned DstEltSize = DstTy->getScalarSizeInBits(); + if (!useNeonVector(DstTy) || Args.size() != 2 || + (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64)) + return nullptr; + + auto getScalarSizeWithOverride = [&](const Value *V) { + if (SrcOverrideTy) + return SrcOverrideTy->getScalarSizeInBits(); + return cast<Instruction>(V) + ->getOperand(0) + ->getType() + ->getScalarSizeInBits(); + }; + + unsigned MaxEltSize = 0; + if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) || + (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) { + unsigned EltSize0 = getScalarSizeWithOverride(Args[0]); + unsigned EltSize1 = getScalarSizeWithOverride(Args[1]); + MaxEltSize = std::max(EltSize0, EltSize1); + } else if (isa<SExtInst, ZExtInst>(Args[0]) && + isa<SExtInst, ZExtInst>(Args[1])) { + unsigned EltSize0 = getScalarSizeWithOverride(Args[0]); + unsigned EltSize1 = getScalarSizeWithOverride(Args[1]); + // mul(sext, zext) will become smull(sext, zext) if the extends are large + // enough. + if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2) + return nullptr; + MaxEltSize = DstEltSize / 2; + } else if (Opcode == Instruction::Mul && + (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) { + // If one of the operands is a Zext and the other has enough zero bits + // to be treated as unsigned, we can still generate a umull, meaning the + // zext is free. + KnownBits Known = + computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL); + if (Args[0]->getType()->getScalarSizeInBits() - + Known.Zero.countLeadingOnes() > + DstTy->getScalarSizeInBits() / 2) + return nullptr; + + MaxEltSize = + getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? 
Args[0] : Args[1]); + } else + return nullptr; + + if (MaxEltSize * 2 > DstEltSize) + return nullptr; + + Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2); + if (ExtTy->getPrimitiveSizeInBits() <= 64) + return nullptr; + return ExtTy; +} + // s/urhadd instructions implement the following pattern, making the // extends free: // %x = add ((zext i8 -> i16), 1) @@ -3152,7 +3249,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, if (I && I->hasOneUser()) { auto *SingleUser = cast<Instruction>(*I->user_begin()); SmallVector<const Value *, 4> Operands(SingleUser->operand_values()); - if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) { + if (Type *ExtTy = isBinExtWideningInstruction( + SingleUser->getOpcode(), Dst, Operands, + Src != I->getOperand(0)->getType() ? Src : nullptr)) { + // The cost from Src->Src*2 needs to be added if required, the cost from + // Src*2->ExtTy is free. + if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) { + Type *DoubleSrcTy = + Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2); + return getCastInstrCost(Opcode, DoubleSrcTy, Src, + TTI::CastContextHint::None, CostKind); + } + + return 0; + } + + if (isSingleExtWideningInstruction( + SingleUser->getOpcode(), Dst, Operands, + Src != I->getOperand(0)->getType() ? Src : nullptr)) { // For adds only count the second operand as free if both operands are // extends but not the same operation. (i.e both operands are not free in // add(sext, zext)). @@ -3161,8 +3275,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, (isa<CastInst>(SingleUser->getOperand(1)) && cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode)) return 0; - } else // Others are free so long as isWideningInstruction returned true. + } else { + // Others are free so long as isSingleExtWideningInstruction + // returned true. return 0; + } } // The cast will be free for the s/urhadd instructions @@ -4088,12 +4205,15 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead( std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost( Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info, - TTI::OperandValueInfo Op2Info, bool IncludeTrunc, + TTI::OperandValueInfo Op2Info, bool IncludeTrunc, bool CanUseSVE, std::function<InstructionCost(Type *)> InstCost) const { if (!Ty->getScalarType()->isHalfTy() && !Ty->getScalarType()->isBFloatTy()) return std::nullopt; if (Ty->getScalarType()->isHalfTy() && ST->hasFullFP16()) return std::nullopt; + if (CanUseSVE && Ty->isScalableTy() && ST->hasSVEB16B16() && + ST->isNonStreamingSVEorSME2Available()) + return std::nullopt; Type *PromotedTy = Ty->getWithNewType(Type::getFloatTy(Ty->getContext())); InstructionCost Cost = getCastInstrCost(Instruction::FPExt, PromotedTy, Ty, @@ -4135,12 +4255,26 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost( ISD == ISD::FDIV || ISD == ISD::FREM) if (auto PromotedCost = getFP16BF16PromoteCost( Ty, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/true, + // There is not native support for fdiv/frem even with +sve-b16b16. 
+ /*CanUseSVE=*/ISD != ISD::FDIV && ISD != ISD::FREM, [&](Type *PromotedTy) { return getArithmeticInstrCost(Opcode, PromotedTy, CostKind, Op1Info, Op2Info); })) return *PromotedCost; + // If the operation is a widening instruction (smull or umull) and both + // operands are extends the cost can be cheaper by considering that the + // operation will operate on the narrowest type size possible (double the + // largest input size) and a further extend. + if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) { + if (ExtTy != Ty) + return getArithmeticInstrCost(Opcode, ExtTy, CostKind) + + getCastInstrCost(Instruction::ZExt, Ty, ExtTy, + TTI::CastContextHint::None, CostKind); + return LT.first; + } + switch (ISD) { default: return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, @@ -4374,10 +4508,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost( // - two 2-cost i64 inserts, and // - two 1-cost muls. // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with - // LT.first = 2 the cost is 28. If both operands are extensions it will not - // need to scalarize so the cost can be cheaper (smull or umull). - // so the cost can be cheaper (smull or umull). - if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args)) + // LT.first = 2 the cost is 28. + if (LT.second != MVT::v2i64) return LT.first; return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() * (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) + @@ -4539,7 +4671,8 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost( if (Opcode == Instruction::FCmp) { if (auto PromotedCost = getFP16BF16PromoteCost( ValTy, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/false, - [&](Type *PromotedTy) { + // TODO: Consider costing SVE FCMPs. 
+ /*CanUseSVE=*/false, [&](Type *PromotedTy) { InstructionCost Cost = getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind, Op1Info, Op2Info); @@ -4635,12 +4768,26 @@ bool AArch64TTIImpl::prefersVectorizedAddressing() const { } InstructionCost -AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, - Align Alignment, unsigned AddressSpace, +AArch64TTIImpl::getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + switch (MICA.getID()) { + case Intrinsic::masked_scatter: + case Intrinsic::masked_gather: + return getGatherScatterOpCost(MICA, CostKind); + case Intrinsic::masked_load: + case Intrinsic::masked_store: + return getMaskedMemoryOpCost(MICA, CostKind); + } + return BaseT::getMemIntrinsicInstrCost(MICA, CostKind); +} + +InstructionCost +AArch64TTIImpl::getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const { + Type *Src = MICA.getDataType(); + if (useNeonVector(Src)) - return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace, - CostKind); + return BaseT::getMemIntrinsicInstrCost(MICA, CostKind); auto LT = getTypeLegalizationCost(Src); if (!LT.first.isValid()) return InstructionCost::getInvalid(); @@ -4682,12 +4829,21 @@ static unsigned getSVEGatherScatterOverhead(unsigned Opcode, } } -InstructionCost AArch64TTIImpl::getGatherScatterOpCost( - unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { +InstructionCost +AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + + unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather || + MICA.getID() == Intrinsic::vp_gather) + ? Instruction::Load + : Instruction::Store; + + Type *DataTy = MICA.getDataType(); + Align Alignment = MICA.getAlignment(); + const Instruction *I = MICA.getInst(); + if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy)) - return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); + return BaseT::getMemIntrinsicInstrCost(MICA, CostKind); auto *VT = cast<VectorType>(DataTy); auto LT = getTypeLegalizationCost(DataTy); if (!LT.first.isValid()) @@ -5165,6 +5321,7 @@ void AArch64TTIImpl::getUnrollingPreferences( // inlining. Don't unroll auto-vectorized loops either, though do allow // unrolling of the scalar remainder. bool IsVectorized = getBooleanLoopAttribute(L, "llvm.loop.isvectorized"); + InstructionCost Cost = 0; for (auto *BB : L->getBlocks()) { for (auto &I : *BB) { // Both auto-vectorized loops and the scalar remainder have the @@ -5179,24 +5336,19 @@ void AArch64TTIImpl::getUnrollingPreferences( continue; return; } + + SmallVector<const Value *, 4> Operands(I.operand_values()); + Cost += getInstructionCost(&I, Operands, + TargetTransformInfo::TCK_SizeAndLatency); } } // Apply subtarget-specific unrolling preferences. 
- switch (ST->getProcFamily()) { - case AArch64Subtarget::AppleA14: - case AArch64Subtarget::AppleA15: - case AArch64Subtarget::AppleA16: - case AArch64Subtarget::AppleM4: + if (ST->isAppleMLike()) getAppleRuntimeUnrollPreferences(L, SE, UP, *this); - break; - case AArch64Subtarget::Falkor: - if (EnableFalkorHWPFUnrollFix) - getFalkorUnrollingPreferences(L, SE, UP); - break; - default: - break; - } + else if (ST->getProcFamily() == AArch64Subtarget::Falkor && + EnableFalkorHWPFUnrollFix) + getFalkorUnrollingPreferences(L, SE, UP); // If this is a small, multi-exit loop similar to something like std::find, // then there is typically a performance improvement achieved by unrolling. @@ -5225,6 +5377,11 @@ void AArch64TTIImpl::getUnrollingPreferences( UP.UnrollAndJam = true; UP.UnrollAndJamInnerLoopThreshold = 60; } + + // Force unrolling small loops can be very useful because of the branch + // taken cost of the backedge. + if (Cost < Aarch64ForceUnrollThreshold) + UP.Force = true; } void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE, @@ -5895,6 +6052,15 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, SrcTy = DstTy; } + // Check for identity masks, which we can treat as free for both fixed and + // scalable vector paths. + if (!Mask.empty() && LT.second.isFixedLengthVector() && + (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) && + all_of(enumerate(Mask), [](const auto &M) { + return M.value() < 0 || M.value() == (int)M.index(); + })) + return 0; + // Segmented shuffle matching. if (Kind == TTI::SK_PermuteSingleSrc && isa<FixedVectorType>(SrcTy) && !Mask.empty() && SrcTy->getPrimitiveSizeInBits().isNonZero() && @@ -5942,21 +6108,13 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, all_of(Mask, [](int E) { return E < 8; })) return getPerfectShuffleCost(Mask); - // Check for identity masks, which we can treat as free. - if (!Mask.empty() && LT.second.isFixedLengthVector() && - (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) && - all_of(enumerate(Mask), [](const auto &M) { - return M.value() < 0 || M.value() == (int)M.index(); - })) - return 0; - // Check for other shuffles that are not SK_ kinds but we have native // instructions for, for example ZIP and UZP. 
unsigned Unused; if (LT.second.isFixedLengthVector() && LT.second.getVectorNumElements() == Mask.size() && (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) && - (isZIPMask(Mask, LT.second.getVectorNumElements(), Unused) || + (isZIPMask(Mask, LT.second.getVectorNumElements(), Unused, Unused) || isUZPMask(Mask, LT.second.getVectorNumElements(), Unused) || isREVMask(Mask, LT.second.getScalarSizeInBits(), LT.second.getVectorNumElements(), 16) || @@ -6122,7 +6280,8 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, } static bool containsDecreasingPointers(Loop *TheLoop, - PredicatedScalarEvolution *PSE) { + PredicatedScalarEvolution *PSE, + const DominatorTree &DT) { const auto &Strides = DenseMap<Value *, const SCEV *>(); for (BasicBlock *BB : TheLoop->blocks()) { // Scan the instructions in the block and look for addresses that are @@ -6131,8 +6290,8 @@ static bool containsDecreasingPointers(Loop *TheLoop, if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) { Value *Ptr = getLoadStorePointerOperand(&I); Type *AccessTy = getLoadStoreType(&I); - if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true, - /*ShouldCheckWrap=*/false) + if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, DT, Strides, + /*Assume=*/true, /*ShouldCheckWrap=*/false) .value_or(0) < 0) return true; } @@ -6177,7 +6336,8 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) const { // negative strides. This will require extra work to reverse the loop // predicate, which may be expensive. if (containsDecreasingPointers(TFI->LVL->getLoop(), - TFI->LVL->getPredicatedScalarEvolution())) + TFI->LVL->getPredicatedScalarEvolution(), + *TFI->LVL->getDominatorTree())) Required |= TailFoldingOpts::Reverse; if (Required == TailFoldingOpts::Disabled) Required |= TailFoldingOpts::Simple; @@ -6650,10 +6810,15 @@ bool AArch64TTIImpl::isProfitableToSinkOperands( Ops.push_back(&Ext->getOperandUse(0)); Ops.push_back(&Op); - if (isa<SExtInst>(Ext)) + if (isa<SExtInst>(Ext)) { NumSExts++; - else + } else { NumZExts++; + // A zext(a) is also a sext(zext(a)), if we take more than 2 steps. + if (Ext->getOperand(0)->getType()->getScalarSizeInBits() * 2 < + I->getType()->getScalarSizeInBits()) + NumSExts++; + } continue; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index fe2e849..ecefe2a 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -59,9 +59,17 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> { VECTOR_LDST_FOUR_ELEMENTS }; - bool isWideningInstruction(Type *DstTy, unsigned Opcode, - ArrayRef<const Value *> Args, - Type *SrcOverrideTy = nullptr) const; + /// Given a add/sub/mul operation, detect a widening addl/subl/mull pattern + /// where both operands can be treated like extends. Returns the minimal type + /// needed to compute the operation. + Type *isBinExtWideningInstruction(unsigned Opcode, Type *DstTy, + ArrayRef<const Value *> Args, + Type *SrcOverrideTy = nullptr) const; + /// Given a add/sub operation with a single extend operand, detect a + /// widening addw/subw pattern. + bool isSingleExtWideningInstruction(unsigned Opcode, Type *DstTy, + ArrayRef<const Value *> Args, + Type *SrcOverrideTy = nullptr) const; // A helper function called by 'getVectorInstrCost'. 
// @@ -84,12 +92,13 @@ public: const Function *Callee) const override; bool areTypesABICompatible(const Function *Caller, const Function *Callee, - const ArrayRef<Type *> &Types) const override; + ArrayRef<Type *> Types) const override; unsigned getInlineCallPenalty(const Function *F, const CallBase &Call, unsigned DefaultCallPenalty) const override; APInt getFeatureMask(const Function &F) const override; + APInt getPriorityMask(const Function &F) const override; bool isMultiversionedFunction(const Function &F) const override; @@ -180,15 +189,14 @@ public: unsigned Opcode2) const; InstructionCost - getMaskedMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, - unsigned AddressSpace, - TTI::TargetCostKind CostKind) const override; + getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; - InstructionCost - getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, - bool VariableMask, Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const override; + InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const; + + InstructionCost getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const; bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst, Type *Src) const; @@ -304,7 +312,7 @@ public: } bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) const { - if (!ST->hasSVE()) + if (!ST->isSVEorStreamingSVEAvailable()) return false; // For fixed vectors, avoid scalarization if using SVE for them. @@ -316,15 +324,34 @@ public: } bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const override { return isLegalMaskedLoadStore(DataType, Alignment); } bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const override { return isLegalMaskedLoadStore(DataType, Alignment); } + bool isElementTypeLegalForCompressStore(Type *Ty) const { + return Ty->isFloatTy() || Ty->isDoubleTy() || Ty->isIntegerTy(32) || + Ty->isIntegerTy(64); + } + + bool isLegalMaskedCompressStore(Type *DataType, + Align Alignment) const override { + if (!ST->isSVEAvailable()) + return false; + + if (isa<FixedVectorType>(DataType) && + DataType->getPrimitiveSizeInBits() < 128) + return false; + + return isElementTypeLegalForCompressStore(DataType->getScalarType()); + } + bool isLegalMaskedGatherScatter(Type *DataType) const { if (!ST->isSVEAvailable()) return false; @@ -448,11 +475,10 @@ public: /// FP16 and BF16 operations are lowered to fptrunc(op(fpext, fpext) if the /// architecture features are not present. 
- std::optional<InstructionCost> - getFP16BF16PromoteCost(Type *Ty, TTI::TargetCostKind CostKind, - TTI::OperandValueInfo Op1Info, - TTI::OperandValueInfo Op2Info, bool IncludeTrunc, - std::function<InstructionCost(Type *)> InstCost) const; + std::optional<InstructionCost> getFP16BF16PromoteCost( + Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info, + TTI::OperandValueInfo Op2Info, bool IncludeTrunc, bool CanUseSVE, + std::function<InstructionCost(Type *)> InstCost) const; InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp index 6273cfc..433cb03 100644 --- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp +++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp @@ -88,7 +88,7 @@ private: StringRef Mnemonic; ///< Instruction mnemonic. // Map of register aliases registers via the .req directive. - StringMap<std::pair<RegKind, unsigned>> RegisterReqs; + StringMap<std::pair<RegKind, MCRegister>> RegisterReqs; class PrefixInfo { public: @@ -165,7 +165,7 @@ private: AArch64CC::CondCode parseCondCodeString(StringRef Cond, std::string &Suggestion); bool parseCondCode(OperandVector &Operands, bool invertCondCode); - unsigned matchRegisterNameAlias(StringRef Name, RegKind Kind); + MCRegister matchRegisterNameAlias(StringRef Name, RegKind Kind); bool parseRegister(OperandVector &Operands); bool parseSymbolicImmVal(const MCExpr *&ImmVal); bool parseNeonVectorList(OperandVector &Operands); @@ -268,6 +268,7 @@ private: ParseStatus tryParsePSBHint(OperandVector &Operands); ParseStatus tryParseBTIHint(OperandVector &Operands); ParseStatus tryParseCMHPriorityHint(OperandVector &Operands); + ParseStatus tryParseTIndexHint(OperandVector &Operands); ParseStatus tryParseAdrpLabel(OperandVector &Operands); ParseStatus tryParseAdrLabel(OperandVector &Operands); template <bool AddFPZeroAsLiteral> @@ -373,6 +374,7 @@ private: k_PHint, k_BTIHint, k_CMHPriorityHint, + k_TIndexHint, } Kind; SMLoc StartLoc, EndLoc; @@ -391,7 +393,7 @@ private: }; struct RegOp { - unsigned RegNum; + MCRegister Reg; RegKind Kind; int ElementWidth; @@ -417,7 +419,7 @@ private: }; struct MatrixRegOp { - unsigned RegNum; + MCRegister Reg; unsigned ElementWidth; MatrixKind Kind; }; @@ -427,7 +429,7 @@ private: }; struct VectorListOp { - unsigned RegNum; + MCRegister Reg; unsigned Count; unsigned Stride; unsigned NumElements; @@ -507,6 +509,11 @@ private: unsigned Length; unsigned Val; }; + struct TIndexHintOp { + const char *Data; + unsigned Length; + unsigned Val; + }; struct SVCROp { const char *Data; @@ -534,6 +541,7 @@ private: struct PHintOp PHint; struct BTIHintOp BTIHint; struct CMHPriorityHintOp CMHPriorityHint; + struct TIndexHintOp TIndexHint; struct ShiftExtendOp ShiftExtend; struct SVCROp SVCR; }; @@ -607,6 +615,9 @@ public: case k_CMHPriorityHint: CMHPriorityHint = o.CMHPriorityHint; break; + case k_TIndexHint: + TIndexHint = o.TIndexHint; + break; case k_ShiftExtend: ShiftExtend = o.ShiftExtend; break; @@ -688,12 +699,12 @@ public: MCRegister getReg() const override { assert(Kind == k_Register && "Invalid access!"); - return Reg.RegNum; + return Reg.Reg; } - unsigned getMatrixReg() const { + MCRegister getMatrixReg() const { assert(Kind == k_MatrixRegister && "Invalid access!"); - return MatrixReg.RegNum; + return MatrixReg.Reg; } unsigned getMatrixElementWidth() const { @@ -716,9 +727,9 @@ public: return Reg.EqualityTy; } - unsigned 
getVectorListStart() const { + MCRegister getVectorListStart() const { assert(Kind == k_VectorList && "Invalid access!"); - return VectorList.RegNum; + return VectorList.Reg; } unsigned getVectorListCount() const { @@ -791,6 +802,16 @@ public: return StringRef(CMHPriorityHint.Data, CMHPriorityHint.Length); } + unsigned getTIndexHint() const { + assert(Kind == k_TIndexHint && "Invalid access!"); + return TIndexHint.Val; + } + + StringRef getTIndexHintName() const { + assert(Kind == k_TIndexHint && "Invalid access!"); + return StringRef(TIndexHint.Data, TIndexHint.Length); + } + StringRef getSVCR() const { assert(Kind == k_SVCR && "Invalid access!"); return StringRef(SVCR.Data, SVCR.Length); @@ -1264,15 +1285,15 @@ public: bool isNeonVectorRegLo() const { return Kind == k_Register && Reg.Kind == RegKind::NeonVector && (AArch64MCRegisterClasses[AArch64::FPR128_loRegClassID].contains( - Reg.RegNum) || + Reg.Reg) || AArch64MCRegisterClasses[AArch64::FPR64_loRegClassID].contains( - Reg.RegNum)); + Reg.Reg)); } bool isNeonVectorReg0to7() const { return Kind == k_Register && Reg.Kind == RegKind::NeonVector && (AArch64MCRegisterClasses[AArch64::FPR128_0to7RegClassID].contains( - Reg.RegNum)); + Reg.Reg)); } bool isMatrix() const { return Kind == k_MatrixRegister; } @@ -1401,34 +1422,34 @@ public: bool isGPR32as64() const { return Kind == k_Register && Reg.Kind == RegKind::Scalar && - AArch64MCRegisterClasses[AArch64::GPR64RegClassID].contains(Reg.RegNum); + AArch64MCRegisterClasses[AArch64::GPR64RegClassID].contains(Reg.Reg); } bool isGPR64as32() const { return Kind == k_Register && Reg.Kind == RegKind::Scalar && - AArch64MCRegisterClasses[AArch64::GPR32RegClassID].contains(Reg.RegNum); + AArch64MCRegisterClasses[AArch64::GPR32RegClassID].contains(Reg.Reg); } bool isGPR64x8() const { return Kind == k_Register && Reg.Kind == RegKind::Scalar && AArch64MCRegisterClasses[AArch64::GPR64x8ClassRegClassID].contains( - Reg.RegNum); + Reg.Reg); } bool isWSeqPair() const { return Kind == k_Register && Reg.Kind == RegKind::Scalar && AArch64MCRegisterClasses[AArch64::WSeqPairsClassRegClassID].contains( - Reg.RegNum); + Reg.Reg); } bool isXSeqPair() const { return Kind == k_Register && Reg.Kind == RegKind::Scalar && AArch64MCRegisterClasses[AArch64::XSeqPairsClassRegClassID].contains( - Reg.RegNum); + Reg.Reg); } bool isSyspXzrPair() const { - return isGPR64<AArch64::GPR64RegClassID>() && Reg.RegNum == AArch64::XZR; + return isGPR64<AArch64::GPR64RegClassID>() && Reg.Reg == AArch64::XZR; } template<int64_t Angle, int64_t Remainder> @@ -1495,7 +1516,7 @@ public: isTypedVectorList<VectorKind, NumRegs, NumElements, ElementWidth>(); if (!Res) return DiagnosticPredicate::NoMatch; - if (!AArch64MCRegisterClasses[RegClass].contains(VectorList.RegNum)) + if (!AArch64MCRegisterClasses[RegClass].contains(VectorList.Reg)) return DiagnosticPredicate::NearMatch; return DiagnosticPredicate::Match; } @@ -1507,9 +1528,9 @@ public: ElementWidth, Stride>(); if (!Res) return DiagnosticPredicate::NoMatch; - if ((VectorList.RegNum < (AArch64::Z0 + Stride)) || - ((VectorList.RegNum >= AArch64::Z16) && - (VectorList.RegNum < (AArch64::Z16 + Stride)))) + if ((VectorList.Reg < (AArch64::Z0 + Stride)) || + ((VectorList.Reg >= AArch64::Z16) && + (VectorList.Reg < (AArch64::Z16 + Stride)))) return DiagnosticPredicate::Match; return DiagnosticPredicate::NoMatch; } @@ -1534,6 +1555,7 @@ public: bool isPHint() const { return Kind == k_PHint; } bool isBTIHint() const { return Kind == k_BTIHint; } bool isCMHPriorityHint() const { return Kind 
== k_CMHPriorityHint; } + bool isTIndexHint() const { return Kind == k_TIndexHint; } bool isShiftExtend() const { return Kind == k_ShiftExtend; } bool isShifter() const { if (!isShiftExtend()) @@ -1841,7 +1863,7 @@ public: void addPPRorPNRRegOperands(MCInst &Inst, unsigned N) const { assert(N == 1 && "Invalid number of operands!"); - unsigned Reg = getReg(); + MCRegister Reg = getReg(); // Normalise to PPR if (Reg >= AArch64::PN0 && Reg <= AArch64::PN15) Reg = Reg - AArch64::PN0 + AArch64::P0; @@ -2224,6 +2246,11 @@ public: Inst.addOperand(MCOperand::createImm(getCMHPriorityHint())); } + void addTIndexHintOperands(MCInst &Inst, unsigned N) const { + assert(N == 1 && "Invalid number of operands!"); + Inst.addOperand(MCOperand::createImm(getTIndexHint())); + } + void addShifterOperands(MCInst &Inst, unsigned N) const { assert(N == 1 && "Invalid number of operands!"); unsigned Imm = @@ -2336,13 +2363,12 @@ public: } static std::unique_ptr<AArch64Operand> - CreateReg(unsigned RegNum, RegKind Kind, SMLoc S, SMLoc E, MCContext &Ctx, + CreateReg(MCRegister Reg, RegKind Kind, SMLoc S, SMLoc E, MCContext &Ctx, RegConstraintEqualityTy EqTy = RegConstraintEqualityTy::EqualsReg, AArch64_AM::ShiftExtendType ExtTy = AArch64_AM::LSL, - unsigned ShiftAmount = 0, - unsigned HasExplicitAmount = false) { + unsigned ShiftAmount = 0, unsigned HasExplicitAmount = false) { auto Op = std::make_unique<AArch64Operand>(k_Register, Ctx); - Op->Reg.RegNum = RegNum; + Op->Reg.Reg = Reg; Op->Reg.Kind = Kind; Op->Reg.ElementWidth = 0; Op->Reg.EqualityTy = EqTy; @@ -2354,28 +2380,26 @@ public: return Op; } - static std::unique_ptr<AArch64Operand> - CreateVectorReg(unsigned RegNum, RegKind Kind, unsigned ElementWidth, - SMLoc S, SMLoc E, MCContext &Ctx, - AArch64_AM::ShiftExtendType ExtTy = AArch64_AM::LSL, - unsigned ShiftAmount = 0, - unsigned HasExplicitAmount = false) { + static std::unique_ptr<AArch64Operand> CreateVectorReg( + MCRegister Reg, RegKind Kind, unsigned ElementWidth, SMLoc S, SMLoc E, + MCContext &Ctx, AArch64_AM::ShiftExtendType ExtTy = AArch64_AM::LSL, + unsigned ShiftAmount = 0, unsigned HasExplicitAmount = false) { assert((Kind == RegKind::NeonVector || Kind == RegKind::SVEDataVector || Kind == RegKind::SVEPredicateVector || Kind == RegKind::SVEPredicateAsCounter) && "Invalid vector kind"); - auto Op = CreateReg(RegNum, Kind, S, E, Ctx, EqualsReg, ExtTy, ShiftAmount, + auto Op = CreateReg(Reg, Kind, S, E, Ctx, EqualsReg, ExtTy, ShiftAmount, HasExplicitAmount); Op->Reg.ElementWidth = ElementWidth; return Op; } static std::unique_ptr<AArch64Operand> - CreateVectorList(unsigned RegNum, unsigned Count, unsigned Stride, + CreateVectorList(MCRegister Reg, unsigned Count, unsigned Stride, unsigned NumElements, unsigned ElementWidth, RegKind RegisterKind, SMLoc S, SMLoc E, MCContext &Ctx) { auto Op = std::make_unique<AArch64Operand>(k_VectorList, Ctx); - Op->VectorList.RegNum = RegNum; + Op->VectorList.Reg = Reg; Op->VectorList.Count = Count; Op->VectorList.Stride = Stride; Op->VectorList.NumElements = NumElements; @@ -2586,10 +2610,21 @@ public: } static std::unique_ptr<AArch64Operand> - CreateMatrixRegister(unsigned RegNum, unsigned ElementWidth, MatrixKind Kind, + CreateTIndexHint(unsigned Val, StringRef Str, SMLoc S, MCContext &Ctx) { + auto Op = std::make_unique<AArch64Operand>(k_TIndexHint, Ctx); + Op->TIndexHint.Val = Val; + Op->TIndexHint.Data = Str.data(); + Op->TIndexHint.Length = Str.size(); + Op->StartLoc = S; + Op->EndLoc = S; + return Op; + } + + static std::unique_ptr<AArch64Operand> + 
CreateMatrixRegister(MCRegister Reg, unsigned ElementWidth, MatrixKind Kind, SMLoc S, SMLoc E, MCContext &Ctx) { auto Op = std::make_unique<AArch64Operand>(k_MatrixRegister, Ctx); - Op->MatrixReg.RegNum = RegNum; + Op->MatrixReg.Reg = Reg; Op->MatrixReg.ElementWidth = ElementWidth; Op->MatrixReg.Kind = Kind; Op->StartLoc = S; @@ -2660,9 +2695,9 @@ void AArch64Operand::print(raw_ostream &OS, const MCAsmInfo &MAI) const { break; case k_VectorList: { OS << "<vectorlist "; - unsigned Reg = getVectorListStart(); + MCRegister Reg = getVectorListStart(); for (unsigned i = 0, e = getVectorListCount(); i != e; ++i) - OS << Reg + i * getVectorListStride() << " "; + OS << Reg.id() + i * getVectorListStride() << " "; OS << ">"; break; } @@ -2698,8 +2733,11 @@ void AArch64Operand::print(raw_ostream &OS, const MCAsmInfo &MAI) const { case k_CMHPriorityHint: OS << getCMHPriorityHintName(); break; + case k_TIndexHint: + OS << getTIndexHintName(); + break; case k_MatrixRegister: - OS << "<matrix " << getMatrixReg() << ">"; + OS << "<matrix " << getMatrixReg().id() << ">"; break; case k_MatrixTileList: { OS << "<matrixlist "; @@ -2715,7 +2753,7 @@ void AArch64Operand::print(raw_ostream &OS, const MCAsmInfo &MAI) const { break; } case k_Register: - OS << "<register " << getReg() << ">"; + OS << "<register " << getReg().id() << ">"; if (!getShiftExtendAmount() && !hasShiftExtendAmount()) break; [[fallthrough]]; @@ -3048,53 +3086,53 @@ ParseStatus AArch64AsmParser::tryParseRegister(MCRegister &Reg, SMLoc &StartLoc, } // Matches a register name or register alias previously defined by '.req' -unsigned AArch64AsmParser::matchRegisterNameAlias(StringRef Name, - RegKind Kind) { - unsigned RegNum = 0; - if ((RegNum = matchSVEDataVectorRegName(Name))) - return Kind == RegKind::SVEDataVector ? RegNum : 0; +MCRegister AArch64AsmParser::matchRegisterNameAlias(StringRef Name, + RegKind Kind) { + MCRegister Reg = MCRegister(); + if ((Reg = matchSVEDataVectorRegName(Name))) + return Kind == RegKind::SVEDataVector ? Reg : MCRegister(); - if ((RegNum = matchSVEPredicateVectorRegName(Name))) - return Kind == RegKind::SVEPredicateVector ? RegNum : 0; + if ((Reg = matchSVEPredicateVectorRegName(Name))) + return Kind == RegKind::SVEPredicateVector ? Reg : MCRegister(); - if ((RegNum = matchSVEPredicateAsCounterRegName(Name))) - return Kind == RegKind::SVEPredicateAsCounter ? RegNum : 0; + if ((Reg = matchSVEPredicateAsCounterRegName(Name))) + return Kind == RegKind::SVEPredicateAsCounter ? Reg : MCRegister(); - if ((RegNum = MatchNeonVectorRegName(Name))) - return Kind == RegKind::NeonVector ? RegNum : 0; + if ((Reg = MatchNeonVectorRegName(Name))) + return Kind == RegKind::NeonVector ? Reg : MCRegister(); - if ((RegNum = matchMatrixRegName(Name))) - return Kind == RegKind::Matrix ? RegNum : 0; + if ((Reg = matchMatrixRegName(Name))) + return Kind == RegKind::Matrix ? Reg : MCRegister(); - if (Name.equals_insensitive("zt0")) + if (Name.equals_insensitive("zt0")) return Kind == RegKind::LookupTable ? unsigned(AArch64::ZT0) : 0; // The parsed register must be of RegKind Scalar - if ((RegNum = MatchRegisterName(Name))) - return (Kind == RegKind::Scalar) ? RegNum : 0; + if ((Reg = MatchRegisterName(Name))) + return (Kind == RegKind::Scalar) ? Reg : MCRegister(); - if (!RegNum) { + if (!Reg) { // Handle a few common aliases of registers. 
- if (auto RegNum = StringSwitch<unsigned>(Name.lower()) - .Case("fp", AArch64::FP) - .Case("lr", AArch64::LR) - .Case("x31", AArch64::XZR) - .Case("w31", AArch64::WZR) - .Default(0)) - return Kind == RegKind::Scalar ? RegNum : 0; + if (MCRegister Reg = StringSwitch<unsigned>(Name.lower()) + .Case("fp", AArch64::FP) + .Case("lr", AArch64::LR) + .Case("x31", AArch64::XZR) + .Case("w31", AArch64::WZR) + .Default(0)) + return Kind == RegKind::Scalar ? Reg : MCRegister(); // Check for aliases registered via .req. Canonicalize to lower case. // That's more consistent since register names are case insensitive, and // it's how the original entry was passed in from MC/MCParser/AsmParser. auto Entry = RegisterReqs.find(Name.lower()); if (Entry == RegisterReqs.end()) - return 0; + return MCRegister(); - // set RegNum if the match is the right kind of register + // set Reg if the match is the right kind of register if (Kind == Entry->getValue().first) - RegNum = Entry->getValue().second; + Reg = Entry->getValue().second; } - return RegNum; + return Reg; } unsigned AArch64AsmParser::getNumRegsForRegKind(RegKind K) { @@ -3122,8 +3160,8 @@ ParseStatus AArch64AsmParser::tryParseScalarRegister(MCRegister &RegNum) { return ParseStatus::NoMatch; std::string lowerCase = Tok.getString().lower(); - unsigned Reg = matchRegisterNameAlias(lowerCase, RegKind::Scalar); - if (Reg == 0) + MCRegister Reg = matchRegisterNameAlias(lowerCase, RegKind::Scalar); + if (!Reg) return ParseStatus::NoMatch; RegNum = Reg; @@ -3339,6 +3377,23 @@ ParseStatus AArch64AsmParser::tryParseCMHPriorityHint(OperandVector &Operands) { return ParseStatus::Success; } +/// tryParseTIndexHint - Try to parse a TIndex operand +ParseStatus AArch64AsmParser::tryParseTIndexHint(OperandVector &Operands) { + SMLoc S = getLoc(); + const AsmToken &Tok = getTok(); + if (Tok.isNot(AsmToken::Identifier)) + return TokError("invalid operand for instruction"); + + auto TIndex = AArch64TIndexHint::lookupTIndexByName(Tok.getString()); + if (!TIndex) + return TokError("invalid operand for instruction"); + + Operands.push_back(AArch64Operand::CreateTIndexHint( + TIndex->Encoding, Tok.getString(), S, getContext())); + Lex(); // Eat identifier token. + return ParseStatus::Success; +} + /// tryParseAdrpLabel - Parse and validate a source label for the ADRP /// instruction. ParseStatus AArch64AsmParser::tryParseAdrpLabel(OperandVector &Operands) { @@ -3667,7 +3722,7 @@ ParseStatus AArch64AsmParser::tryParseMatrixRegister(OperandVector &Operands) { } // Try to parse matrix register. 
- unsigned Reg = matchRegisterNameAlias(Name, RegKind::Matrix); + MCRegister Reg = matchRegisterNameAlias(Name, RegKind::Matrix); if (!Reg) return ParseStatus::NoMatch; @@ -3850,7 +3905,6 @@ static const struct Extension { {"rdma", {AArch64::FeatureRDM}}, {"sb", {AArch64::FeatureSB}}, {"ssbs", {AArch64::FeatureSSBS}}, - {"tme", {AArch64::FeatureTME}}, {"fp8", {AArch64::FeatureFP8}}, {"faminmax", {AArch64::FeatureFAMINMAX}}, {"fp8fma", {AArch64::FeatureFP8FMA}}, @@ -3896,6 +3950,10 @@ static const struct Extension { {"f16mm", {AArch64::FeatureF16MM}}, {"f16f32dot", {AArch64::FeatureF16F32DOT}}, {"f16f32mm", {AArch64::FeatureF16F32MM}}, + {"mops-go", {AArch64::FeatureMOPS_GO}}, + {"poe2", {AArch64::FeatureS1POE2}}, + {"tev", {AArch64::FeatureTEV}}, + {"btie", {AArch64::FeatureBTIE}}, }; static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) { @@ -3985,6 +4043,7 @@ bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, bool ExpectRegister = true; bool OptionalRegister = false; bool hasAll = getSTI().hasFeature(AArch64::FeatureAll); + bool hasTLBID = getSTI().hasFeature(AArch64::FeatureTLBID); if (Mnemonic == "ic") { const AArch64IC::IC *IC = AArch64IC::lookupICByName(Op); @@ -4052,7 +4111,7 @@ bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, setRequiredFeatureString(GIC->getRequiredFeatures(), Str); return TokError(Str); } - ExpectRegister = true; + ExpectRegister = GIC->NeedsReg; createSysAlias(GIC->Encoding, Operands, S); } else if (Mnemonic == "gsb") { const AArch64GSB::GSB *GSB = AArch64GSB::lookupGSBByName(Op); @@ -4065,6 +4124,20 @@ bool AArch64AsmParser::parseSysAlias(StringRef Name, SMLoc NameLoc, } ExpectRegister = false; createSysAlias(GSB->Encoding, Operands, S); + } else if (Mnemonic == "plbi") { + const AArch64PLBI::PLBI *PLBI = AArch64PLBI::lookupPLBIByName(Op); + if (!PLBI) + return TokError("invalid operand for PLBI instruction"); + else if (!PLBI->haveFeatures(getSTI().getFeatureBits())) { + std::string Str("PLBI " + std::string(PLBI->Name) + " requires: "); + setRequiredFeatureString(PLBI->getRequiredFeatures(), Str); + return TokError(Str); + } + ExpectRegister = PLBI->NeedsReg; + if (hasAll || hasTLBID) { + OptionalRegister = PLBI->OptionalReg; + } + createSysAlias(PLBI->Encoding, Operands, S); } else if (Mnemonic == "cfp" || Mnemonic == "dvp" || Mnemonic == "cpp" || Mnemonic == "cosp") { @@ -4130,12 +4203,12 @@ bool AArch64AsmParser::parseSyslAlias(StringRef Name, SMLoc NameLoc, SMLoc startLoc = getLoc(); const AsmToken ®Tok = getTok(); StringRef reg = regTok.getString(); - unsigned RegNum = matchRegisterNameAlias(reg.lower(), RegKind::Scalar); - if (!RegNum) + MCRegister Reg = matchRegisterNameAlias(reg.lower(), RegKind::Scalar); + if (!Reg) return TokError("expected register operand"); Operands.push_back(AArch64Operand::CreateReg( - RegNum, RegKind::Scalar, startLoc, getLoc(), getContext(), EqualsReg)); + Reg, RegKind::Scalar, startLoc, getLoc(), getContext(), EqualsReg)); Lex(); // Eat token if (parseToken(AsmToken::Comma)) @@ -4453,7 +4526,7 @@ ParseStatus AArch64AsmParser::tryParseVectorRegister(MCRegister &Reg, // a '.'. 
size_t Start = 0, Next = Name.find('.'); StringRef Head = Name.slice(Start, Next); - unsigned RegNum = matchRegisterNameAlias(Head, MatchKind); + MCRegister RegNum = matchRegisterNameAlias(Head, MatchKind); if (RegNum) { if (Next != StringRef::npos) { @@ -4937,13 +5010,13 @@ ParseStatus AArch64AsmParser::tryParseZTOperand(OperandVector &Operands) { const AsmToken &Tok = getTok(); std::string Name = Tok.getString().lower(); - unsigned RegNum = matchRegisterNameAlias(Name, RegKind::LookupTable); + MCRegister Reg = matchRegisterNameAlias(Name, RegKind::LookupTable); - if (RegNum == 0) + if (!Reg) return ParseStatus::NoMatch; Operands.push_back(AArch64Operand::CreateReg( - RegNum, RegKind::LookupTable, StartLoc, getLoc(), getContext())); + Reg, RegKind::LookupTable, StartLoc, getLoc(), getContext())); Lex(); // Eat register. // Check if register is followed by an index @@ -5439,11 +5512,11 @@ bool AArch64AsmParser::parseInstruction(ParseInstructionInfo &Info, size_t Start = 0, Next = Name.find('.'); StringRef Head = Name.slice(Start, Next); - // IC, DC, AT, TLBI, MLBI, GIC{R}, GSB and Prediction invalidation + // IC, DC, AT, TLBI, MLBI, PLBI, GIC{R}, GSB and Prediction invalidation // instructions are aliases for the SYS instruction. if (Head == "ic" || Head == "dc" || Head == "at" || Head == "tlbi" || Head == "cfp" || Head == "dvp" || Head == "cpp" || Head == "cosp" || - Head == "mlbi" || Head == "gic" || Head == "gsb") + Head == "mlbi" || Head == "plbi" || Head == "gic" || Head == "gsb") return parseSysAlias(Head, NameLoc, Operands); // GICR instructions are aliases for the SYSL instruction. @@ -5925,21 +5998,15 @@ bool AArch64AsmParser::validateInstruction(MCInst &Inst, SMLoc &IDLoc, case AArch64::CPYETWN: case AArch64::CPYETRN: case AArch64::CPYETN: { - MCRegister Xd_wb = Inst.getOperand(0).getReg(); - MCRegister Xs_wb = Inst.getOperand(1).getReg(); - MCRegister Xn_wb = Inst.getOperand(2).getReg(); + // Xd_wb == op0, Xs_wb == op1, Xn_wb == op2 MCRegister Xd = Inst.getOperand(3).getReg(); MCRegister Xs = Inst.getOperand(4).getReg(); MCRegister Xn = Inst.getOperand(5).getReg(); - if (Xd_wb != Xd) - return Error(Loc[0], - "invalid CPY instruction, Xd_wb and Xd do not match"); - if (Xs_wb != Xs) - return Error(Loc[0], - "invalid CPY instruction, Xs_wb and Xs do not match"); - if (Xn_wb != Xn) - return Error(Loc[0], - "invalid CPY instruction, Xn_wb and Xn do not match"); + + assert(Xd == Inst.getOperand(0).getReg() && "Xd_wb and Xd do not match"); + assert(Xs == Inst.getOperand(1).getReg() && "Xs_wb and Xs do not match"); + assert(Xn == Inst.getOperand(2).getReg() && "Xn_wb and Xn do not match"); + if (Xd == Xs) return Error(Loc[0], "invalid CPY instruction, destination and source" " registers are the same"); @@ -5975,17 +6042,14 @@ bool AArch64AsmParser::validateInstruction(MCInst &Inst, SMLoc &IDLoc, case AArch64::MOPSSETGET: case AArch64::MOPSSETGEN: case AArch64::MOPSSETGETN: { - MCRegister Xd_wb = Inst.getOperand(0).getReg(); - MCRegister Xn_wb = Inst.getOperand(1).getReg(); + // Xd_wb == op0, Xn_wb == op1 MCRegister Xd = Inst.getOperand(2).getReg(); MCRegister Xn = Inst.getOperand(3).getReg(); MCRegister Xm = Inst.getOperand(4).getReg(); - if (Xd_wb != Xd) - return Error(Loc[0], - "invalid SET instruction, Xd_wb and Xd do not match"); - if (Xn_wb != Xn) - return Error(Loc[0], - "invalid SET instruction, Xn_wb and Xn do not match"); + + assert(Xd == Inst.getOperand(0).getReg() && "Xd_wb and Xd do not match"); + assert(Xn == Inst.getOperand(1).getReg() && "Xn_wb and Xn do not match"); + if 
(Xd == Xn) return Error(Loc[0], "invalid SET instruction, destination and size" " registers are the same"); @@ -5997,6 +6061,30 @@ bool AArch64AsmParser::validateInstruction(MCInst &Inst, SMLoc &IDLoc, " registers are the same"); break; } + case AArch64::SETGOP: + case AArch64::SETGOPT: + case AArch64::SETGOPN: + case AArch64::SETGOPTN: + case AArch64::SETGOM: + case AArch64::SETGOMT: + case AArch64::SETGOMN: + case AArch64::SETGOMTN: + case AArch64::SETGOE: + case AArch64::SETGOET: + case AArch64::SETGOEN: + case AArch64::SETGOETN: { + // Xd_wb == op0, Xn_wb == op1 + MCRegister Xd = Inst.getOperand(2).getReg(); + MCRegister Xn = Inst.getOperand(3).getReg(); + + assert(Xd == Inst.getOperand(0).getReg() && "Xd_wb and Xd do not match"); + assert(Xn == Inst.getOperand(1).getReg() && "Xn_wb and Xn do not match"); + + if (Xd == Xn) + return Error(Loc[0], "invalid SET instruction, destination and size" + " registers are the same"); + break; + } } // Now check immediate ranges. Separate from the above as there is overlap @@ -7651,7 +7739,7 @@ bool AArch64AsmParser::parseDirectiveReq(StringRef Name, SMLoc L) { if (parseEOL()) return true; - auto pair = std::make_pair(RegisterKind, (unsigned) RegNum); + auto pair = std::make_pair(RegisterKind, RegNum); if (RegisterReqs.insert(std::make_pair(Name, pair)).first->second != pair) Warning(L, "ignoring redefinition of register alias '" + Name + "'"); diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt index 2226ac5..3334b36 100644 --- a/llvm/lib/Target/AArch64/CMakeLists.txt +++ b/llvm/lib/Target/AArch64/CMakeLists.txt @@ -61,6 +61,7 @@ add_llvm_target(AArch64CodeGen AArch64CompressJumpTables.cpp AArch64ConditionOptimizer.cpp AArch64RedundantCopyElimination.cpp + AArch64RedundantCondBranchPass.cpp AArch64ISelDAGToDAG.cpp AArch64ISelLowering.cpp AArch64InstrInfo.cpp @@ -76,6 +77,7 @@ add_llvm_target(AArch64CodeGen AArch64PromoteConstant.cpp AArch64PBQPRegAlloc.cpp AArch64RegisterInfo.cpp + AArch64SMEAttributes.cpp AArch64SLSHardening.cpp AArch64SelectionDAGInfo.cpp AArch64SpeculationHardening.cpp diff --git a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp index dc2feba4..4eb762a 100644 --- a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp +++ b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp @@ -1532,6 +1532,32 @@ static DecodeStatus DecodeSETMemOpInstruction(MCInst &Inst, uint32_t insn, return MCDisassembler::Success; } +static DecodeStatus DecodeSETMemGoOpInstruction(MCInst &Inst, uint32_t insn, + uint64_t Addr, + const MCDisassembler *Decoder) { + unsigned Rd = fieldFromInstruction(insn, 0, 5); + unsigned Rn = fieldFromInstruction(insn, 5, 5); + + // None of the registers may alias: if they do, then the instruction is not + // merely unpredictable but actually entirely unallocated. + if (Rd == Rn) + return MCDisassembler::Fail; + + // Rd and Rn register operands are written back, so they appear + // twice in the operand list, once as outputs and once as inputs. 
+ if (!DecodeSimpleRegisterClass<AArch64::GPR64commonRegClassID, 0, 31>( + Inst, Rd, Addr, Decoder) || + !DecodeSimpleRegisterClass<AArch64::GPR64RegClassID, 0, 32>( + Inst, Rn, Addr, Decoder) || + !DecodeSimpleRegisterClass<AArch64::GPR64commonRegClassID, 0, 31>( + Inst, Rd, Addr, Decoder) || + !DecodeSimpleRegisterClass<AArch64::GPR64RegClassID, 0, 32>( + Inst, Rn, Addr, Decoder)) + return MCDisassembler::Fail; + + return MCDisassembler::Success; +} + static DecodeStatus DecodePRFMRegInstruction(MCInst &Inst, uint32_t insn, uint64_t Addr, const MCDisassembler *Decoder) { diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 79bef76..7907a3c 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -17,8 +17,8 @@ #include "AArch64ISelLowering.h" #include "AArch64MachineFunctionInfo.h" #include "AArch64RegisterInfo.h" +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ObjCARCUtil.h" @@ -1421,6 +1421,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } else if (Info.CFIType) { MIB->setCFIType(MF, Info.CFIType->getZExtValue()); } + MIB->setDeactivationSymbol(MF, Info.DeactivationSymbol); MIB.add(Info.Callee); diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp index 14b0f9a..f9db39e 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp @@ -310,6 +310,8 @@ private: MachineIRBuilder &MIRBuilder) const; MachineInstr *emitSBCS(Register Dst, MachineOperand &LHS, MachineOperand &RHS, MachineIRBuilder &MIRBuilder) const; + MachineInstr *emitCMP(MachineOperand &LHS, MachineOperand &RHS, + MachineIRBuilder &MIRBuilder) const; MachineInstr *emitCMN(MachineOperand &LHS, MachineOperand &RHS, MachineIRBuilder &MIRBuilder) const; MachineInstr *emitTST(MachineOperand &LHS, MachineOperand &RHS, @@ -415,10 +417,10 @@ private: } std::optional<bool> - isWorthFoldingIntoAddrMode(MachineInstr &MI, + isWorthFoldingIntoAddrMode(const MachineInstr &MI, const MachineRegisterInfo &MRI) const; - bool isWorthFoldingIntoExtendedReg(MachineInstr &MI, + bool isWorthFoldingIntoExtendedReg(const MachineInstr &MI, const MachineRegisterInfo &MRI, bool IsAddrOperand) const; ComplexRendererFns @@ -4413,6 +4415,15 @@ AArch64InstructionSelector::emitSBCS(Register Dst, MachineOperand &LHS, } MachineInstr * +AArch64InstructionSelector::emitCMP(MachineOperand &LHS, MachineOperand &RHS, + MachineIRBuilder &MIRBuilder) const { + MachineRegisterInfo &MRI = MIRBuilder.getMF().getRegInfo(); + bool Is32Bit = MRI.getType(LHS.getReg()).getSizeInBits() == 32; + auto RC = Is32Bit ? &AArch64::GPR32RegClass : &AArch64::GPR64RegClass; + return emitSUBS(MRI.createVirtualRegister(RC), LHS, RHS, MIRBuilder); +} + +MachineInstr * AArch64InstructionSelector::emitCMN(MachineOperand &LHS, MachineOperand &RHS, MachineIRBuilder &MIRBuilder) const { MachineRegisterInfo &MRI = MIRBuilder.getMF().getRegInfo(); @@ -4464,8 +4475,7 @@ MachineInstr *AArch64InstructionSelector::emitIntegerCompare( // Fold the compare into a cmn or tst if possible. 
if (auto FoldCmp = tryFoldIntegerCompare(LHS, RHS, Predicate, MIRBuilder)) return FoldCmp; - auto Dst = MRI.cloneVirtualRegister(LHS.getReg()); - return emitSUBS(Dst, LHS, RHS, MIRBuilder); + return emitCMP(LHS, RHS, MIRBuilder); } MachineInstr *AArch64InstructionSelector::emitCSetForFCmp( @@ -4870,9 +4880,8 @@ MachineInstr *AArch64InstructionSelector::emitConjunctionRec( // Produce a normal comparison if we are first in the chain if (!CCOp) { - auto Dst = MRI.cloneVirtualRegister(LHS); if (isa<GICmp>(Cmp)) - return emitSUBS(Dst, Cmp->getOperand(2), Cmp->getOperand(3), MIB); + return emitCMP(Cmp->getOperand(2), Cmp->getOperand(3), MIB); return emitFPCompare(Cmp->getOperand(2).getReg(), Cmp->getOperand(3).getReg(), MIB); } @@ -5666,6 +5675,9 @@ AArch64InstructionSelector::emitConstantVector(Register Dst, Constant *CV, MachineRegisterInfo &MRI) { LLT DstTy = MRI.getType(Dst); unsigned DstSize = DstTy.getSizeInBits(); + assert((DstSize == 64 || DstSize == 128) && + "Unexpected vector constant size"); + if (CV->isNullValue()) { if (DstSize == 128) { auto Mov = @@ -5735,17 +5747,24 @@ AArch64InstructionSelector::emitConstantVector(Register Dst, Constant *CV, // Try to create the new constants with MOVI, and if so generate a fneg // for it. if (auto *NewOp = TryMOVIWithBits(NegBits)) { - Register NewDst = MRI.createVirtualRegister(&AArch64::FPR128RegClass); + Register NewDst = MRI.createVirtualRegister( + DstSize == 64 ? &AArch64::FPR64RegClass : &AArch64::FPR128RegClass); NewOp->getOperand(0).setReg(NewDst); return MIRBuilder.buildInstr(NegOpc, {Dst}, {NewDst}); } return nullptr; }; MachineInstr *R; - if ((R = TryWithFNeg(DefBits, 32, AArch64::FNEGv4f32)) || - (R = TryWithFNeg(DefBits, 64, AArch64::FNEGv2f64)) || + if ((R = TryWithFNeg(DefBits, 32, + DstSize == 64 ? AArch64::FNEGv2f32 + : AArch64::FNEGv4f32)) || + (R = TryWithFNeg(DefBits, 64, + DstSize == 64 ? AArch64::FNEGDr + : AArch64::FNEGv2f64)) || (STI.hasFullFP16() && - (R = TryWithFNeg(DefBits, 16, AArch64::FNEGv8f16)))) + (R = TryWithFNeg(DefBits, 16, + DstSize == 64 ? AArch64::FNEGv4f16 + : AArch64::FNEGv8f16)))) return R; } @@ -7049,7 +7068,7 @@ AArch64InstructionSelector::selectNegArithImmed(MachineOperand &Root) const { /// %9:gpr(p0) = G_PTR_ADD %0, %8(s64) /// %12:gpr(s32) = G_LOAD %9(p0) :: (load (s16)) std::optional<bool> AArch64InstructionSelector::isWorthFoldingIntoAddrMode( - MachineInstr &MI, const MachineRegisterInfo &MRI) const { + const MachineInstr &MI, const MachineRegisterInfo &MRI) const { if (MI.getOpcode() == AArch64::G_SHL) { // Address operands with shifts are free, except for running on subtargets // with AddrLSLSlow14. @@ -7070,7 +7089,7 @@ std::optional<bool> AArch64InstructionSelector::isWorthFoldingIntoAddrMode( /// \p IsAddrOperand whether the def of MI is used as an address operand /// (e.g. feeding into an LDR/STR). bool AArch64InstructionSelector::isWorthFoldingIntoExtendedReg( - MachineInstr &MI, const MachineRegisterInfo &MRI, + const MachineInstr &MI, const MachineRegisterInfo &MRI, bool IsAddrOperand) const { // Always fold if there is one use, or if we're optimizing for size. 
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp index 5f93847..44a1489 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -21,6 +21,7 @@ #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/GlobalISel/Utils.h" #include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/IR/DerivedTypes.h" @@ -289,7 +290,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .moreElementsToNextPow2(0) .lower(); - getActionDefinitionsBuilder({G_ABDS, G_ABDU}) + getActionDefinitionsBuilder( + {G_ABDS, G_ABDU, G_UAVGFLOOR, G_UAVGCEIL, G_SAVGFLOOR, G_SAVGCEIL}) .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32}) .lower(); @@ -430,11 +432,6 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .minScalar(0, s32) .scalarize(0); - getActionDefinitionsBuilder({G_INTRINSIC_LRINT, G_INTRINSIC_LLRINT}) - .legalFor({{s64, MinFPScalar}, {s64, s32}, {s64, s64}}) - .libcallFor({{s64, s128}}) - .minScalarOrElt(1, MinFPScalar); - getActionDefinitionsBuilder({G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2, G_FLOG10, G_FTAN, G_FEXP, G_FEXP2, G_FEXP10, G_FACOS, G_FASIN, G_FATAN, G_FATAN2, G_FCOSH, @@ -449,10 +446,17 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .minScalar(0, s32) .libcallFor({{s32, s32}, {s64, s32}, {s128, s32}}); - // TODO: Libcall support for s128. - // TODO: s16 should be legal with full FP16 support. - getActionDefinitionsBuilder({G_LROUND, G_LLROUND}) - .legalFor({{s64, s32}, {s64, s64}}); + getActionDefinitionsBuilder({G_LROUND, G_INTRINSIC_LRINT}) + .legalFor({{s32, s32}, {s32, s64}, {s64, s32}, {s64, s64}}) + .legalFor(HasFP16, {{s32, s16}, {s64, s16}}) + .minScalar(1, s32) + .libcallFor({{s64, s128}}); + getActionDefinitionsBuilder({G_LLROUND, G_INTRINSIC_LLRINT}) + .legalFor({{s64, s32}, {s64, s64}}) + .legalFor(HasFP16, {{s64, s16}}) + .minScalar(0, s64) + .minScalar(1, s32) + .libcallFor({{s64, s128}}); // TODO: Custom legalization for mismatched types. 
getActionDefinitionsBuilder(G_FCOPYSIGN) @@ -817,14 +821,33 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .legalFor( {{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}}) .libcallFor({{s16, s128}, {s32, s128}, {s64, s128}}) - .clampNumElements(0, v4s16, v4s16) - .clampNumElements(0, v2s32, v2s32) + .moreElementsToNextPow2(1) + .customIf([](const LegalityQuery &Q) { + LLT DstTy = Q.Types[0]; + LLT SrcTy = Q.Types[1]; + return SrcTy.isFixedVector() && DstTy.isFixedVector() && + SrcTy.getScalarSizeInBits() == 64 && + DstTy.getScalarSizeInBits() == 16; + }) + // Clamp based on input + .clampNumElements(1, v4s32, v4s32) + .clampNumElements(1, v2s64, v2s64) .scalarize(0); getActionDefinitionsBuilder(G_FPEXT) .legalFor( {{s32, s16}, {s64, s16}, {s64, s32}, {v4s32, v4s16}, {v2s64, v2s32}}) .libcallFor({{s128, s64}, {s128, s32}, {s128, s16}}) + .moreElementsToNextPow2(0) + .widenScalarIf( + [](const LegalityQuery &Q) { + LLT DstTy = Q.Types[0]; + LLT SrcTy = Q.Types[1]; + return SrcTy.isVector() && DstTy.isVector() && + SrcTy.getScalarSizeInBits() == 16 && + DstTy.getScalarSizeInBits() == 64; + }, + changeElementTo(1, s32)) .clampNumElements(0, v4s32, v4s32) .clampNumElements(0, v2s64, v2s64) .scalarize(0); @@ -1230,7 +1253,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) .legalFor({{v16s8, v8s8}, {v8s16, v4s16}, {v4s32, v2s32}}) .bitcastIf( [=](const LegalityQuery &Query) { - return Query.Types[0].getSizeInBits() <= 128 && + return Query.Types[0].isFixedVector() && + Query.Types[1].isFixedVector() && + Query.Types[0].getSizeInBits() <= 128 && Query.Types[1].getSizeInBits() <= 64; }, [=](const LegalityQuery &Query) { @@ -1464,6 +1489,10 @@ bool AArch64LegalizerInfo::legalizeCustom( return legalizeICMP(MI, MRI, MIRBuilder); case TargetOpcode::G_BITCAST: return legalizeBitcast(MI, Helper); + case TargetOpcode::G_FPTRUNC: + // In order to lower f16 to f64 properly, we need to use f32 as an + // intermediary + return legalizeFptrunc(MI, MIRBuilder, MRI); } llvm_unreachable("expected switch to return"); @@ -1809,6 +1838,9 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, return LowerBinOp(TargetOpcode::G_FMAXNUM); case Intrinsic::aarch64_neon_fminnm: return LowerBinOp(TargetOpcode::G_FMINNUM); + case Intrinsic::aarch64_neon_pmull: + case Intrinsic::aarch64_neon_pmull64: + return LowerBinOp(AArch64::G_PMULL); case Intrinsic::aarch64_neon_smull: return LowerBinOp(AArch64::G_SMULL); case Intrinsic::aarch64_neon_umull: @@ -1817,6 +1849,14 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, return LowerBinOp(TargetOpcode::G_ABDS); case Intrinsic::aarch64_neon_uabd: return LowerBinOp(TargetOpcode::G_ABDU); + case Intrinsic::aarch64_neon_uhadd: + return LowerBinOp(TargetOpcode::G_UAVGFLOOR); + case Intrinsic::aarch64_neon_urhadd: + return LowerBinOp(TargetOpcode::G_UAVGCEIL); + case Intrinsic::aarch64_neon_shadd: + return LowerBinOp(TargetOpcode::G_SAVGFLOOR); + case Intrinsic::aarch64_neon_srhadd: + return LowerBinOp(TargetOpcode::G_SAVGCEIL); case Intrinsic::aarch64_neon_abs: { // Lower the intrinsic to G_ABS. 
MIB.buildInstr(TargetOpcode::G_ABS, {MI.getOperand(0)}, {MI.getOperand(2)}); @@ -2390,3 +2430,80 @@ bool AArch64LegalizerInfo::legalizePrefetch(MachineInstr &MI, MI.eraseFromParent(); return true; } + +bool AArch64LegalizerInfo::legalizeFptrunc(MachineInstr &MI, + MachineIRBuilder &MIRBuilder, + MachineRegisterInfo &MRI) const { + auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs(); + assert(SrcTy.isFixedVector() && isPowerOf2_32(SrcTy.getNumElements()) && + "Expected a power of 2 elements"); + + LLT s16 = LLT::scalar(16); + LLT s32 = LLT::scalar(32); + LLT s64 = LLT::scalar(64); + LLT v2s16 = LLT::fixed_vector(2, s16); + LLT v4s16 = LLT::fixed_vector(4, s16); + LLT v2s32 = LLT::fixed_vector(2, s32); + LLT v4s32 = LLT::fixed_vector(4, s32); + LLT v2s64 = LLT::fixed_vector(2, s64); + + SmallVector<Register> RegsToUnmergeTo; + SmallVector<Register> TruncOddDstRegs; + SmallVector<Register> RegsToMerge; + + unsigned ElemCount = SrcTy.getNumElements(); + + // Find the biggest size chunks we can work with + int StepSize = ElemCount % 4 ? 2 : 4; + + // If we have a power of 2 greater than 2, we need to first unmerge into + // enough pieces + if (ElemCount <= 2) + RegsToUnmergeTo.push_back(Src); + else { + for (unsigned i = 0; i < ElemCount / 2; ++i) + RegsToUnmergeTo.push_back(MRI.createGenericVirtualRegister(v2s64)); + + MIRBuilder.buildUnmerge(RegsToUnmergeTo, Src); + } + + // Create all of the round-to-odd instructions and store them + for (auto SrcReg : RegsToUnmergeTo) { + Register Mid = + MIRBuilder.buildInstr(AArch64::G_FPTRUNC_ODD, {v2s32}, {SrcReg}) + .getReg(0); + TruncOddDstRegs.push_back(Mid); + } + + // Truncate 4s32 to 4s16 if we can to reduce instruction count, otherwise + // truncate 2s32 to 2s16. + unsigned Index = 0; + for (unsigned LoopIter = 0; LoopIter < ElemCount / StepSize; ++LoopIter) { + if (StepSize == 4) { + Register ConcatDst = + MIRBuilder + .buildMergeLikeInstr( + {v4s32}, {TruncOddDstRegs[Index++], TruncOddDstRegs[Index++]}) + .getReg(0); + + RegsToMerge.push_back( + MIRBuilder.buildFPTrunc(v4s16, ConcatDst).getReg(0)); + } else { + RegsToMerge.push_back( + MIRBuilder.buildFPTrunc(v2s16, TruncOddDstRegs[Index++]).getReg(0)); + } + } + + // If there is only one register, replace the destination + if (RegsToMerge.size() == 1) { + MRI.replaceRegWith(Dst, RegsToMerge.pop_back_val()); + MI.eraseFromParent(); + return true; + } + + // Merge the rest of the instructions & replace the register + Register Fin = MIRBuilder.buildMergeLikeInstr(DstTy, RegsToMerge).getReg(0); + MRI.replaceRegWith(Dst, Fin); + MI.eraseFromParent(); + return true; +}
\ No newline at end of file diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h index bcb29432..12b6a6f 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h @@ -67,6 +67,8 @@ private: bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const; bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const; bool legalizeBitcast(MachineInstr &MI, LegalizerHelper &Helper) const; + bool legalizeFptrunc(MachineInstr &MI, MachineIRBuilder &MIRBuilder, + MachineRegisterInfo &MRI) const; const AArch64Subtarget *ST; }; } // End llvm namespace. diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp index 23dcaea..221a7bc 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp @@ -215,14 +215,15 @@ bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI, ShuffleVectorPseudo &MatchInfo) { assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); unsigned WhichResult; + unsigned OperandOrder; ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask(); Register Dst = MI.getOperand(0).getReg(); unsigned NumElts = MRI.getType(Dst).getNumElements(); - if (!isTRNMask(ShuffleMask, NumElts, WhichResult)) + if (!isTRNMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) return false; unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2; - Register V1 = MI.getOperand(1).getReg(); - Register V2 = MI.getOperand(2).getReg(); + Register V1 = MI.getOperand(OperandOrder == 0 ? 1 : 2).getReg(); + Register V2 = MI.getOperand(OperandOrder == 0 ? 2 : 1).getReg(); MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2}); return true; } @@ -252,14 +253,15 @@ bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI, ShuffleVectorPseudo &MatchInfo) { assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); unsigned WhichResult; + unsigned OperandOrder; ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask(); Register Dst = MI.getOperand(0).getReg(); unsigned NumElts = MRI.getType(Dst).getNumElements(); - if (!isZIPMask(ShuffleMask, NumElts, WhichResult)) + if (!isZIPMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) return false; unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2; - Register V1 = MI.getOperand(1).getReg(); - Register V2 = MI.getOperand(2).getReg(); + Register V1 = MI.getOperand(OperandOrder == 0 ? 1 : 2).getReg(); + Register V2 = MI.getOperand(OperandOrder == 0 ? 
2 : 1).getReg(); MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2}); return true; } diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp index 896eab5..29538d0 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp @@ -435,6 +435,8 @@ bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI, Register ExtSrcReg = ExtMI->getOperand(1).getReg(); LLT ExtSrcTy = MRI.getType(ExtSrcReg); LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); + if (ExtSrcTy.getScalarSizeInBits() * 2 > DstTy.getScalarSizeInBits()) + return false; if ((DstTy.getScalarSizeInBits() == 16 && ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) || (DstTy.getScalarSizeInBits() == 32 && @@ -492,7 +494,7 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI, unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2; LLT MidScalarLLT = LLT::scalar(MidScalarSize); - Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0); + Register ZeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0); for (unsigned I = 0; I < WorkingRegisters.size(); I++) { // If the number of elements is too small to build an instruction, extend // its size before applying addlv @@ -508,10 +510,10 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI, // Generate the {U/S}ADDLV instruction, whose output is always double of the // Src's Scalar size - LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32) + LLT AddlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32) : LLT::fixed_vector(2, 64); - Register addlvReg = - B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0); + Register AddlvReg = + B.buildInstr(Opc, {AddlvTy}, {WorkingRegisters[I]}).getReg(0); // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or // v2i64 register. @@ -520,26 +522,26 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI, // Therefore we have to extract/truncate the the value to the right type if (MidScalarSize == 32 || MidScalarSize == 64) { WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT, - {MidScalarLLT}, {addlvReg, zeroReg}) + {MidScalarLLT}, {AddlvReg, ZeroReg}) .getReg(0); } else { - Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT, - {LLT::scalar(32)}, {addlvReg, zeroReg}) + Register ExtractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT, + {LLT::scalar(32)}, {AddlvReg, ZeroReg}) .getReg(0); WorkingRegisters[I] = - B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0); + B.buildTrunc({MidScalarLLT}, {ExtractReg}).getReg(0); } } - Register outReg; + Register OutReg; if (WorkingRegisters.size() > 1) { - outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1]) + OutReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1]) .getReg(0); for (unsigned I = 2; I < WorkingRegisters.size(); I++) { - outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0); + OutReg = B.buildAdd(MidScalarLLT, OutReg, WorkingRegisters[I]).getReg(0); } } else { - outReg = WorkingRegisters[0]; + OutReg = WorkingRegisters[0]; } if (DstTy.getScalarSizeInBits() > MidScalarSize) { @@ -547,9 +549,9 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI, // Src's ScalarType B.buildInstr(std::get<1>(MatchInfo) ? 
TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT, - {DstReg}, {outReg}); + {DstReg}, {OutReg}); } else { - B.buildCopy(DstReg, outReg); + B.buildCopy(DstReg, OutReg); } MI.eraseFromParent(); diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp index 6d2d705..4d3d081 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp @@ -483,6 +483,14 @@ static bool isFPIntrinsic(const MachineRegisterInfo &MRI, case Intrinsic::aarch64_neon_sqadd: case Intrinsic::aarch64_neon_sqsub: case Intrinsic::aarch64_crypto_sha1h: + case Intrinsic::aarch64_neon_srshl: + case Intrinsic::aarch64_neon_urshl: + case Intrinsic::aarch64_neon_sqshl: + case Intrinsic::aarch64_neon_uqshl: + case Intrinsic::aarch64_neon_sqrshl: + case Intrinsic::aarch64_neon_uqrshl: + case Intrinsic::aarch64_neon_ushl: + case Intrinsic::aarch64_neon_sshl: case Intrinsic::aarch64_crypto_sha1c: case Intrinsic::aarch64_crypto_sha1p: case Intrinsic::aarch64_crypto_sha1m: @@ -560,6 +568,7 @@ bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI, case TargetOpcode::G_FCMP: case TargetOpcode::G_LROUND: case TargetOpcode::G_LLROUND: + case AArch64::G_PMULL: return true; case TargetOpcode::G_INTRINSIC: switch (cast<GIntrinsic>(MI).getIntrinsicID()) { diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp index 7a2b679..1f9694c 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp @@ -586,6 +586,11 @@ public: /// Generate the compact unwind encoding from the CFI directives. uint64_t generateCompactUnwindEncoding(const MCDwarfFrameInfo *FI, const MCContext *Ctxt) const override { + // MTE-tagged frames must use DWARF unwinding because compact unwind + // doesn't handle MTE tags + if (FI->IsMTETaggedFrame) + return CU::UNWIND_ARM64_MODE_DWARF; + ArrayRef<MCCFIInstruction> Instrs = FI->Instructions; if (Instrs.empty()) return CU::UNWIND_ARM64_MODE_FRAMELESS; diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp index 5c3e26e..3e4c110 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp @@ -1034,7 +1034,7 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI, if (!GIC || !GIC->haveFeatures(STI.getFeatureBits())) return false; - NeedsReg = true; + NeedsReg = GIC->NeedsReg; Ins = "gic\t"; Name = std::string(GIC->Name); } else { @@ -1047,6 +1047,18 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI, Ins = "gsb\t"; Name = std::string(GSB->Name); } + } else if (CnVal == 10) { + // PLBI aliases + const AArch64PLBI::PLBI *PLBI = AArch64PLBI::lookupPLBIByEncoding(Encoding); + if (!PLBI || !PLBI->haveFeatures(STI.getFeatureBits())) + return false; + + NeedsReg = PLBI->NeedsReg; + if (STI.hasFeature(AArch64::FeatureAll) || + STI.hasFeature(AArch64::FeatureTLBID)) + OptionalReg = PLBI->OptionalReg; + Ins = "plbi\t"; + Name = std::string(PLBI->Name); } else return false; @@ -1114,7 +1126,6 @@ bool AArch64InstPrinter::printSyslAlias(const MCInst *MI, } else return false; - std::string Str; llvm::transform(Name, Name.begin(), ::tolower); O << '\t' << Ins << '\t' << Reg.str() << ", " << Name; @@ -1609,6 +1620,19 @@ void AArch64InstPrinter::printCMHPriorityHintOp(const MCInst 
*MI, AArch64CMHPriorityHint::lookupCMHPriorityHintByEncoding(priorityhint_op); if (PHint) O << PHint->Name; + else + markup(O, Markup::Immediate) << '#' << formatImm(priorityhint_op); +} + +void AArch64InstPrinter::printTIndexHintOp(const MCInst *MI, unsigned OpNum, + const MCSubtargetInfo &STI, + raw_ostream &O) { + unsigned tindexhintop = MI->getOperand(OpNum).getImm(); + auto TIndex = AArch64TIndexHint::lookupTIndexByEncoding(tindexhintop); + if (TIndex) + O << TIndex->Name; + else + markup(O, Markup::Immediate) << '#' << formatImm(tindexhintop); } void AArch64InstPrinter::printFPImmOperand(const MCInst *MI, unsigned OpNum, diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h index 307402d..3f7a3b4 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.h @@ -156,6 +156,9 @@ protected: void printCMHPriorityHintOp(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo &STI, raw_ostream &O); + void printTIndexHintOp(const MCInst *MI, unsigned OpNum, + const MCSubtargetInfo &STI, raw_ostream &O); + void printFPImmOperand(const MCInst *MI, unsigned OpNum, const MCSubtargetInfo &STI, raw_ostream &O); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFStreamer.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFStreamer.cpp index 64f96c5..942e1bd 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFStreamer.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFStreamer.cpp @@ -53,7 +53,7 @@ void AArch64WinCOFFStreamer::emitWindowsUnwindTables() { } void AArch64WinCOFFStreamer::finishImpl() { - emitFrames(nullptr); + emitFrames(); emitWindowsUnwindTables(); MCWinCOFFStreamer::finishImpl(); diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp index 434ea67..b3e1ddb 100644 --- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -72,20 +72,34 @@ using namespace llvm; namespace { -enum ZAState { +// Note: For agnostic ZA, we assume the function is always entered/exited in the +// "ACTIVE" state -- this _may_ not be the case (since OFF is also a +// possibility, but for the purpose of placing ZA saves/restores, that does not +// matter). +enum ZAState : uint8_t { // Any/unknown state (not valid) ANY = 0, // ZA is in use and active (i.e. within the accumulator) ACTIVE, + // ZA is active, but ZT0 has been saved. + // This handles the edge case of sharedZA && !sharesZT0. + ACTIVE_ZT0_SAVED, + // A ZA save has been set up or committed (i.e. ZA is dormant or off) + // If the function uses ZT0 it must also be saved. LOCAL_SAVED, - // ZA is off or a lazy save has been set up by the caller - CALLER_DORMANT, + // ZA has been committed to the lazy save buffer of the current function. + // If the function uses ZT0 it must also be saved. + // ZA is off. + LOCAL_COMMITTED, + + // The ZA/ZT0 state on entry to the function. + ENTRY, - // ZA is off + // ZA is off. OFF, // The number of ZA states (not a valid state) @@ -121,8 +135,10 @@ struct InstInfo { /// Contains the needed ZA state for each instruction in a block. Instructions /// that do not require a ZA state are not recorded. 
struct BlockInfo { - ZAState FixedEntryState{ZAState::ANY}; SmallVector<InstInfo> Insts; + ZAState FixedEntryState{ZAState::ANY}; + ZAState DesiredIncomingState{ZAState::ANY}; + ZAState DesiredOutgoingState{ZAState::ANY}; LiveRegs PhysLiveRegsAtEntry = LiveRegs::None; LiveRegs PhysLiveRegsAtExit = LiveRegs::None; }; @@ -162,6 +178,14 @@ public: return AgnosticZABufferPtr; } + int getZT0SaveSlot(MachineFunction &MF) { + if (ZT0SaveFI) + return *ZT0SaveFI; + MachineFrameInfo &MFI = MF.getFrameInfo(); + ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16)); + return *ZT0SaveFI; + } + /// Returns true if the function must allocate a ZA save buffer on entry. This /// will be the case if, at any point in the function, a ZA save was emitted. bool needsSaveBuffer() const { @@ -171,14 +195,22 @@ public: } private: + std::optional<int> ZT0SaveFI; std::optional<int> TPIDR2BlockFI; Register AgnosticZABufferPtr = AArch64::NoRegister; }; +/// Checks if \p State is a legal edge bundle state. For a state to be a legal +/// bundle state, it must be possible to transition from it to any other bundle +/// state without losing any ZA state. This is the case for ACTIVE/LOCAL_SAVED, +/// as you can transition between those states by saving/restoring ZA. The OFF +/// state would not be legal, as transitioning to it drops the content of ZA. static bool isLegalEdgeBundleZAState(ZAState State) { switch (State) { - case ZAState::ACTIVE: - case ZAState::LOCAL_SAVED: + case ZAState::ACTIVE: // ZA state within the accumulator/ZT0. + case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active). + case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack. + case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack. return true; default: return false; @@ -192,8 +224,10 @@ StringRef getZAStateString(ZAState State) { switch (State) { MAKE_CASE(ZAState::ANY) MAKE_CASE(ZAState::ACTIVE) + MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED) MAKE_CASE(ZAState::LOCAL_SAVED) - MAKE_CASE(ZAState::CALLER_DORMANT) + MAKE_CASE(ZAState::LOCAL_COMMITTED) + MAKE_CASE(ZAState::ENTRY) MAKE_CASE(ZAState::OFF) default: llvm_unreachable("Unexpected ZAState"); @@ -214,18 +248,39 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI, /// Returns the required ZA state needed before \p MI and an iterator pointing /// to where any code required to change the ZA state should be inserted. static std::pair<ZAState, MachineBasicBlock::iterator> -getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI, - bool ZAOffAtReturn) { +getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI, + SMEAttrs SMEFnAttrs) { MachineBasicBlock::iterator InsertPt(MI); + // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are + // intended to mark the position immediately before a call. Due to + // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN, + // so we use std::prev(InsertPt) to get the position before the call. + if (MI.getOpcode() == AArch64::InOutZAUsePseudo) return {ZAState::ACTIVE, std::prev(InsertPt)}; + // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo. if (MI.getOpcode() == AArch64::RequiresZASavePseudo) return {ZAState::LOCAL_SAVED, std::prev(InsertPt)}; - if (MI.isReturn()) + // If we only need to save ZT0 there's two cases to consider: + // 1. The function has ZA state (that we don't need to save). + // - In this case we switch to the "ACTIVE_ZT0_SAVED" state. + // This only saves ZT0. + // 2. 
The function does not have ZA state + // - In this case we switch to "LOCAL_COMMITTED" state. + // This saves ZT0 and turns ZA off. + if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) { + return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED + : ZAState::LOCAL_COMMITTED, + std::prev(InsertPt)}; + } + + if (MI.isReturn()) { + bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface(); return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt}; + } for (auto &MO : MI.operands()) { if (isZAorZTRegOp(TRI, MO)) @@ -238,7 +293,8 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI, struct MachineSMEABI : public MachineFunctionPass { inline static char ID = 0; - MachineSMEABI() : MachineFunctionPass(ID) {} + MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default) + : MachineFunctionPass(ID), OptLevel(OptLevel) {} bool runOnMachineFunction(MachineFunction &MF) override; @@ -267,9 +323,17 @@ struct MachineSMEABI : public MachineFunctionPass { const EdgeBundles &Bundles, ArrayRef<ZAState> BundleStates); + /// Propagates desired states forwards (from predecessors -> successors) if + /// \p Forwards, otherwise, propagates backwards (from successors -> + /// predecessors). + void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true); + + void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, bool IsSave); + // Emission routines for private and shared ZA functions (using lazy saves). - void emitNewZAPrologue(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI); + void emitSMEPrologue(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs); @@ -277,8 +341,8 @@ struct MachineSMEABI : public MachineFunctionPass { MachineBasicBlock::iterator MBBI); void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); - void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, - bool ClearTPIDR2); + void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, + bool ClearTPIDR2, bool On); // Emission routines for agnostic ZA functions. 
void emitSetupFullZASave(MachineBasicBlock &MBB, @@ -335,12 +399,15 @@ struct MachineSMEABI : public MachineFunctionPass { MachineBasicBlock::iterator MBBI, DebugLoc DL); private: + CodeGenOptLevel OptLevel = CodeGenOptLevel::Default; + MachineFunction *MF = nullptr; const AArch64Subtarget *Subtarget = nullptr; const AArch64RegisterInfo *TRI = nullptr; const AArch64FunctionInfo *AFI = nullptr; const TargetInstrInfo *TII = nullptr; MachineRegisterInfo *MRI = nullptr; + MachineLoopInfo *MLI = nullptr; }; static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) { @@ -365,6 +432,17 @@ static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) { LiveUnits.addReg(AArch64::W0_HI); } +[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) { + switch (Opc) { + case AArch64::TLSDESC_CALLSEQ: + case AArch64::TLSDESC_AUTH_CALLSEQ: + case AArch64::ADJCALLSTACKDOWN: + return true; + default: + return false; + } +} + FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) && @@ -379,12 +457,10 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { if (MBB.isEntryBlock()) { // Entry block: - Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface() - ? ZAState::CALLER_DORMANT - : ZAState::ACTIVE; + Block.FixedEntryState = ZAState::ENTRY; } else if (MBB.isEHPad()) { // EH entry block: - Block.FixedEntryState = ZAState::LOCAL_SAVED; + Block.FixedEntryState = ZAState::LOCAL_COMMITTED; } LiveRegUnits LiveUnits(*TRI); @@ -406,10 +482,8 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { PhysLiveRegsAfterSMEPrologue = PhysLiveRegs; } // Note: We treat Agnostic ZA as inout_za with an alternate save/restore. - auto [NeededState, InsertPt] = getZAStateBeforeInst( - *TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface()); - assert((InsertPt == MBBI || - InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) && + auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs); + assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) && "Unexpected state change insertion point!"); // TODO: Do something to avoid state changes where NZCV is live. if (MBBI == FirstTerminatorInsertPt) @@ -422,12 +496,69 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { // Reverse vector (as we had to iterate backwards for liveness). std::reverse(Block.Insts.begin(), Block.Insts.end()); + + // Record the desired states on entry/exit of this block. These are the + // states that would not incur a state transition. + if (!Block.Insts.empty()) { + Block.DesiredIncomingState = Block.Insts.front().NeededState; + Block.DesiredOutgoingState = Block.Insts.back().NeededState; + } } return FunctionInfo{std::move(Blocks), AfterSMEProloguePt, PhysLiveRegsAfterSMEPrologue}; } +void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo, + bool Forwards) { + // If `Forwards`, this propagates desired states from predecessors to + // successors, otherwise, this propagates states from successors to + // predecessors. + auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & { + return Incoming ? 
Block.DesiredIncomingState : Block.DesiredOutgoingState; + }; + + SmallVector<MachineBasicBlock *> Worklist; + for (auto [BlockID, BlockInfo] : enumerate(FnInfo.Blocks)) { + if (!isLegalEdgeBundleZAState(GetBlockState(BlockInfo, Forwards))) + Worklist.push_back(MF->getBlockNumbered(BlockID)); + } + + while (!Worklist.empty()) { + MachineBasicBlock *MBB = Worklist.pop_back_val(); + BlockInfo &Block = FnInfo.Blocks[MBB->getNumber()]; + + // Pick a legal edge bundle state that matches the majority of + // predecessors/successors. + int StateCounts[ZAState::NUM_ZA_STATE] = {0}; + for (MachineBasicBlock *PredOrSucc : + Forwards ? predecessors(MBB) : successors(MBB)) { + BlockInfo &PredOrSuccBlock = FnInfo.Blocks[PredOrSucc->getNumber()]; + ZAState ZAState = GetBlockState(PredOrSuccBlock, !Forwards); + if (isLegalEdgeBundleZAState(ZAState)) + StateCounts[ZAState]++; + } + + ZAState PropagatedState = ZAState(max_element(StateCounts) - StateCounts); + ZAState &CurrentState = GetBlockState(Block, Forwards); + if (PropagatedState != CurrentState) { + CurrentState = PropagatedState; + ZAState &OtherState = GetBlockState(Block, !Forwards); + // Propagate to the incoming/outgoing state if that is also "ANY". + if (OtherState == ZAState::ANY) + OtherState = PropagatedState; + // Push any successors/predecessors that may need updating to the + // worklist. + for (MachineBasicBlock *SuccOrPred : + Forwards ? successors(MBB) : predecessors(MBB)) { + BlockInfo &SuccOrPredBlock = FnInfo.Blocks[SuccOrPred->getNumber()]; + if (!isLegalEdgeBundleZAState(GetBlockState(SuccOrPredBlock, Forwards))) + Worklist.push_back(SuccOrPred); + } + } + } +} + /// Assigns each edge bundle a ZA state based on the needed states of blocks /// that have incoming or outgoing edges in that bundle. SmallVector<ZAState> @@ -440,40 +571,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles, // Attempt to assign a ZA state for this bundle that minimizes state // transitions. Edges within loops are given a higher weight as we assume // they will be executed more than once. - // TODO: We should propagate desired incoming/outgoing states through blocks - // that have the "ANY" state first to make better global decisions. int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0}; for (unsigned BlockID : Bundles.getBlocks(I)) { LLVM_DEBUG(dbgs() << "- bb." 
<< BlockID); const BlockInfo &Block = FnInfo.Blocks[BlockID]; - if (Block.Insts.empty()) { - LLVM_DEBUG(dbgs() << " (no state preference)\n"); - continue; - } bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I; bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I; - ZAState DesiredIncomingState = Block.Insts.front().NeededState; - if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) { - EdgeStateCounts[DesiredIncomingState]++; + bool LegalInEdge = + InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState); + bool LegalOutEgde = + OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState); + if (LegalInEdge) { LLVM_DEBUG(dbgs() << " DesiredIncomingState: " - << getZAStateString(DesiredIncomingState)); + << getZAStateString(Block.DesiredIncomingState)); + EdgeStateCounts[Block.DesiredIncomingState]++; } - ZAState DesiredOutgoingState = Block.Insts.back().NeededState; - if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) { - EdgeStateCounts[DesiredOutgoingState]++; + if (LegalOutEgde) { LLVM_DEBUG(dbgs() << " DesiredOutgoingState: " - << getZAStateString(DesiredOutgoingState)); + << getZAStateString(Block.DesiredOutgoingState)); + EdgeStateCounts[Block.DesiredOutgoingState]++; } + if (!LegalInEdge && !LegalOutEgde) + LLVM_DEBUG(dbgs() << " (no state preference)"); LLVM_DEBUG(dbgs() << '\n'); } ZAState BundleState = ZAState(max_element(EdgeStateCounts) - EdgeStateCounts); - // Force ZA to be active in bundles that don't have a preferred state. - // TODO: Something better here (to avoid extra mode switches). if (BundleState == ZAState::ANY) BundleState = ZAState::ACTIVE; @@ -505,8 +632,8 @@ MachineSMEABI::findStateChangeInsertionPoint( PhysLiveRegs = Block.PhysLiveRegsAtExit; } - if (!(PhysLiveRegs & LiveRegs::NZCV)) - return {InsertPt, PhysLiveRegs}; // Nothing to do (no live flags). + if (PhysLiveRegs == LiveRegs::None) + return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs). // Find the previous state change. We can not move before this point. MachineBasicBlock::iterator PrevStateChangeI; @@ -523,15 +650,21 @@ MachineSMEABI::findStateChangeInsertionPoint( // Note: LiveUnits will only accurately track X0 and NZCV. LiveRegUnits LiveUnits(*TRI); setPhysLiveRegs(LiveUnits, PhysLiveRegs); + auto BestCandidate = std::make_pair(InsertPt, PhysLiveRegs); for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) { // Don't move before/into a call (which may have a state change before it). if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall()) break; LiveUnits.stepBackward(*I); - if (LiveUnits.available(AArch64::NZCV)) - return {I, getPhysLiveRegs(LiveUnits)}; + LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits); + // Find places where NZCV is available, but keep looking for locations where + // both NZCV and X0 are available, which can avoid some copies. 
+ if (!(CurrentPhysLiveRegs & LiveRegs::NZCV)) + BestCandidate = {I, CurrentPhysLiveRegs}; + if (CurrentPhysLiveRegs == LiveRegs::None) + break; } - return {InsertPt, PhysLiveRegs}; + return BestCandidate; } void MachineSMEABI::insertStateChanges(EmitContext &Context, @@ -675,9 +808,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context, restorePhyRegSave(RegSave, MBB, MBBI, DL); } -void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI, - bool ClearTPIDR2) { +void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + bool ClearTPIDR2, bool On) { DebugLoc DL = getDebugLoc(MBB, MBBI); if (ClearTPIDR2) @@ -688,7 +821,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB, // Disable ZA. BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) .addImm(AArch64SVCR::SVCRZA) - .addImm(0); + .addImm(On ? 1 : 0); } void MachineSMEABI::emitAllocateLazySaveBuffer( @@ -746,31 +879,46 @@ void MachineSMEABI::emitAllocateLazySaveBuffer( } } -void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI) { +static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111; + +void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { auto *TLI = Subtarget->getTargetLowering(); DebugLoc DL = getDebugLoc(MBB, MBBI); - // Get current TPIDR2_EL0. - Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass); - BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS)) - .addReg(TPIDR2EL0, RegState::Define) - .addImm(AArch64SysReg::TPIDR2_EL0); - // If TPIDR2_EL0 is non-zero, commit the lazy save. - // NOTE: Functions that only use ZT0 don't need to zero ZA. - bool ZeroZA = AFI->getSMEFnAttrs().hasZAState(); - auto CommitZASave = - BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo)) - .addReg(TPIDR2EL0) - .addImm(ZeroZA ? 1 : 0) - .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE)) - .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); - if (ZeroZA) - CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine); - // Enable ZA (as ZA could have previously been in the OFF state). - BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) - .addImm(AArch64SVCR::SVCRZA) - .addImm(1); + bool ZeroZA = AFI->getSMEFnAttrs().isNewZA(); + bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0(); + if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) { + // Get current TPIDR2_EL0. + Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS)) + .addReg(TPIDR2EL0, RegState::Define) + .addImm(AArch64SysReg::TPIDR2_EL0); + // If TPIDR2_EL0 is non-zero, commit the lazy save. + // NOTE: Functions that only use ZT0 don't need to zero ZA. + auto CommitZASave = + BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo)) + .addReg(TPIDR2EL0) + .addImm(ZeroZA) + .addImm(ZeroZT0) + .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE)) + .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + if (ZeroZA) + CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine); + if (ZeroZT0) + CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine); + // Enable ZA (as ZA could have previously been in the OFF state). 
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) + .addImm(AArch64SVCR::SVCRZA) + .addImm(1); + } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) { + if (ZeroZA) + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M)) + .addImm(ZERO_ALL_ZA_MASK) + .addDef(AArch64::ZAB0, RegState::ImplicitDefine); + if (ZeroZT0) + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0); + } } void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context, @@ -799,6 +947,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context, restorePhyRegSave(RegSave, MBB, MBBI, DL); } +void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + bool IsSave) { + DebugLoc DL = getDebugLoc(MBB, MBBI); + Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass); + + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save) + .addFrameIndex(Context.getZT0SaveSlot(*MF)) + .addImm(0) + .addImm(0); + + if (IsSave) { + BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX)) + .addReg(AArch64::ZT0) + .addReg(ZT0Save); + } else { + BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0) + .addReg(ZT0Save); + } +} + void MachineSMEABI::emitAllocateFullZASaveBuffer( EmitContext &Context, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) { @@ -843,6 +1013,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer( restorePhyRegSave(RegSave, MBB, MBBI, DL); } +struct FromState { + ZAState From; + + constexpr uint8_t to(ZAState To) const { + static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits"); + return uint8_t(From) << 4 | uint8_t(To); + } +}; + +constexpr FromState transitionFrom(ZAState From) { return FromState{From}; } + void MachineSMEABI::emitStateChange(EmitContext &Context, MachineBasicBlock &MBB, MachineBasicBlock::iterator InsertPt, @@ -852,19 +1033,17 @@ void MachineSMEABI::emitStateChange(EmitContext &Context, if (From == ZAState::ANY || To == ZAState::ANY) return; - // If we're exiting from the CALLER_DORMANT state that means this new ZA - // function did not touch ZA (so ZA was never turned on). - if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF) + // If we're exiting from the ENTRY state that means that the function has not + // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up. + if (From == ZAState::ENTRY && To == ZAState::OFF) return; // TODO: Avoid setting up the save buffer if there's no transition to // LOCAL_SAVED. - if (From == ZAState::CALLER_DORMANT) { - assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() && - "CALLER_DORMANT state requires private ZA interface"); + if (From == ZAState::ENTRY) { assert(&MBB == &MBB.getParent()->front() && - "CALLER_DORMANT state only valid in entry block"); - emitNewZAPrologue(MBB, MBB.getFirstNonPHI()); + "ENTRY state only valid in entry block"); + emitSMEPrologue(MBB, MBB.getFirstNonPHI()); if (To == ZAState::ACTIVE) return; // Nothing more to do (ZA is active after the prologue). 
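The FromState helper defined above lets emitStateChange switch over (From, To) pairs: the static_assert guarantees each ZAState fits in four bits, the pair is packed into a single byte, and because to() is constexpr the packed values can serve directly as case labels. A self-contained sketch of the same packing technique, with an illustrative enum standing in for the real state list:

#include <cstdint>

enum class State : uint8_t { Entry, Active, LocalSaved, Off, NumStates };
static_assert(unsigned(State::NumStates) < 16, "each state must fit in 4 bits");

// Pack the source state into the high nibble and the destination state into the
// low nibble, giving every (From, To) pair a unique, switch-friendly constant.
constexpr uint8_t transition(State From, State To) {
  return uint8_t(unsigned(From) << 4 | unsigned(To));
}

static const char *describeTransition(State From, State To) {
  switch (transition(From, To)) {
  case transition(State::Active, State::LocalSaved):
    return "save ZA";
  case transition(State::LocalSaved, State::Active):
    return "restore ZA";
  default:
    return "unhandled transition";
  }
}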
@@ -874,17 +1053,67 @@ void MachineSMEABI::emitStateChange(EmitContext &Context, From = ZAState::ACTIVE; } - if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED) - emitZASave(Context, MBB, InsertPt, PhysLiveRegs); - else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE) - emitZARestore(Context, MBB, InsertPt, PhysLiveRegs); - else if (To == ZAState::OFF) { - assert(From != ZAState::CALLER_DORMANT && - "CALLER_DORMANT to OFF should have already been handled"); - assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() && - "Should not turn ZA off in agnostic ZA function"); - emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED); - } else { + SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs(); + bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface(); + bool HasZT0State = SMEFnAttrs.hasZT0State(); + bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState(); + + switch (transitionFrom(From).to(To)) { + // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED + case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED): + emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true); + break; + case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE): + emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false); + break; + + // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED + case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED): + case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED): + if (HasZT0State && From == ZAState::ACTIVE) + emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true); + if (HasZAState) + emitZASave(Context, MBB, InsertPt, PhysLiveRegs); + break; + + // This section handles: ACTIVE -> LOCAL_COMMITTED + case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED): + // TODO: We could support ZA state here, but this transition is currently + // only possible when we _don't_ have ZA state. + assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state."); + emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true); + emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false); + break; + + // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED) + case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF): + case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED): + // These transitions are a no-op.
+ break; + + // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED] + case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE): + case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED): + case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE): + if (HasZAState) + emitZARestore(Context, MBB, InsertPt, PhysLiveRegs); + else + emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true); + if (HasZT0State && To == ZAState::ACTIVE) + emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false); + break; + + // This section handles transitions to OFF (not previously covered) + case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF): + case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF): + case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF): + assert(SMEFnAttrs.hasPrivateZAInterface() && + "Did not expect to turn ZA off in shared/agnostic ZA function"); + emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED, + /*On=*/false); + break; + + default: dbgs() << "Error: Transition from " << getZAStateString(From) << " to " << getZAStateString(To) << '\n'; llvm_unreachable("Unimplemented state transition"); @@ -918,6 +1147,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) { getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles(); FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs); + + if (OptLevel != CodeGenOptLevel::None) { + // Propagate desired states forward, then backwards. Most of the propagation + // should be done in the forward step, and backwards propagation is then + // used to fill in the gaps. Note: Doing both in one step can give poor + // results. For example, consider this subgraph: + // + // ┌─────┐ + // ┌─┤ BB0 ◄───┐ + // │ └─┬───┘ │ + // │ ┌─▼───◄──┐│ + // │ │ BB1 │ ││ + // │ └─┬┬──┘ ││ + // │ │└─────┘│ + // │ ┌─▼───┐ │ + // │ │ BB2 ├───┘ + // │ └─┬───┘ + // │ ┌─▼───┐ + // └─► BB3 │ + // └─────┘ + // + // If: + // - "BB0" and "BB2" (outer loop) have no state preference + // - "BB1" (inner loop) desires the ACTIVE state on entry/exit + // - "BB3" desires the LOCAL_SAVED state on entry + // + // If we propagate forwards first, ACTIVE is propagated from BB1 to BB2, + // then from BB2 to BB0, which results in the inner and outer loops having + // the "ACTIVE" state. This avoids any state changes in the loops. + // + // If we propagate backwards first, we _could_ propagate LOCAL_SAVED from + // BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED + // in the outer loop.
+ for (bool Forwards : {true, false}) + propagateDesiredStates(FnInfo, Forwards); + } + SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo); EmitContext Context; @@ -941,4 +1207,6 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) { return true; } -FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); } +FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) { + return new MachineSMEABI(OptLevel); +} diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp index 79ceb2a..6bdad03 100644 --- a/llvm/lib/Target/AArch64/SMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// #include "AArch64.h" -#include "Utils/AArch64SMEAttributes.h" +#include "AArch64SMEAttributes.h" #include "llvm/ADT/StringRef.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 1664f4a..1f031f9 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -347,6 +347,11 @@ def SVELogicalImm16Pat : ComplexPattern<i32, 1, "SelectSVELogicalImm<MVT::i16>", def SVELogicalImm32Pat : ComplexPattern<i32, 1, "SelectSVELogicalImm<MVT::i32>", []>; def SVELogicalImm64Pat : ComplexPattern<i64, 1, "SelectSVELogicalImm<MVT::i64>", []>; +def SVELogicalFPImm16Pat : ComplexPattern<f16, 1, "SelectSVELogicalImm<MVT::i16>", []>; +def SVELogicalFPImm32Pat : ComplexPattern<f32, 1, "SelectSVELogicalImm<MVT::i32>", []>; +def SVELogicalFPImm64Pat : ComplexPattern<f64, 1, "SelectSVELogicalImm<MVT::i64>", []>; +def SVELogicalBFPImmPat : ComplexPattern<bf16, 1, "SelectSVELogicalImm<MVT::i16>", []>; + def SVELogicalImm8NotPat : ComplexPattern<i32, 1, "SelectSVELogicalImm<MVT::i8, true>", []>; def SVELogicalImm16NotPat : ComplexPattern<i32, 1, "SelectSVELogicalImm<MVT::i16, true>", []>; def SVELogicalImm32NotPat : ComplexPattern<i32, 1, "SelectSVELogicalImm<MVT::i32, true>", []>; @@ -2160,6 +2165,26 @@ multiclass sve_int_dup_mask_imm<string asm> { (!cast<Instruction>(NAME) i64:$imm)>; def : Pat<(nxv2i64 (splat_vector (i64 (SVELogicalImm64Pat i64:$imm)))), (!cast<Instruction>(NAME) i64:$imm)>; + + def : Pat<(nxv8f16 (splat_vector (f16 (SVELogicalFPImm16Pat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv4f16 (splat_vector (f16 (SVELogicalFPImm16Pat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv2f16 (splat_vector (f16 (SVELogicalFPImm16Pat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv4f32 (splat_vector (f32 (SVELogicalFPImm32Pat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv2f32 (splat_vector (f32 (SVELogicalFPImm32Pat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv2f64 (splat_vector (f64 (SVELogicalFPImm64Pat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + + def : Pat<(nxv8bf16 (splat_vector (bf16 (SVELogicalBFPImmPat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv4bf16 (splat_vector (bf16 (SVELogicalBFPImmPat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; + def : Pat<(nxv2bf16 (splat_vector (bf16 (SVELogicalBFPImmPat i64:$imm)))), + (!cast<Instruction>(NAME) i64:$imm)>; } //===----------------------------------------------------------------------===// @@ -2439,8 +2464,6 @@ multiclass 
sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op> def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>; def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>; - def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>; - def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>; } multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> { @@ -11143,6 +11166,12 @@ class sve2_fp8_mmla<bit opc, ZPRRegOp dst_ty, string mnemonic> let Uses = [FPMR, FPCR]; } +multiclass sve2_fp8_fmmla<bits<1> opc, ZPRRegOp zprty, string mnemonic, ValueType ResVT> { + def NAME : sve2_fp8_mmla<opc, zprty, mnemonic>; + def : Pat<(ResVT (int_aarch64_sve_fp8_fmmla ResVT:$acc, nxv16i8:$zn, nxv16i8:$zm)), + (!cast<Instruction>(NAME) $acc, $zn, $zm)>; +} + class sve_fp8_dot_indexed<bits<4> opc, ZPRRegOp dst_ty, Operand iop_ty, string mnemonic> : I<(outs dst_ty:$Zda), (ins dst_ty:$_Zda, ZPR8:$Zn, ZPR3b8:$Zm, iop_ty:$iop), mnemonic, "\t$Zda, $Zn, $Zm$iop", "", []>, Sched<[]> { diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp index 268a229..556d2c3 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.cpp @@ -146,6 +146,13 @@ namespace AArch64CMHPriorityHint { } // namespace llvm namespace llvm { +namespace AArch64TIndexHint { +#define GET_TINDEX_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64TIndexHint +} // namespace llvm + +namespace llvm { namespace AArch64SysReg { #define GET_SysRegsList_IMPL #include "AArch64GenSystemOperands.inc" @@ -186,11 +193,18 @@ std::string AArch64SysReg::genericRegisterString(uint32_t Bits) { } namespace llvm { - namespace AArch64TLBI { +namespace AArch64TLBI { #define GET_TLBITable_IMPL #include "AArch64GenSystemOperands.inc" - } -} +} // namespace AArch64TLBI +} // namespace llvm + +namespace llvm { +namespace AArch64PLBI { +#define GET_PLBITable_IMPL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64PLBI +} // namespace llvm namespace llvm { namespace AArch64TLBIP { diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h index 27812e9..0c98fdc7 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -695,6 +695,14 @@ struct CMHPriorityHint : SysAlias { #include "AArch64GenSystemOperands.inc" } // namespace AArch64CMHPriorityHint +namespace AArch64TIndexHint { +struct TIndex : SysAlias { + using SysAlias::SysAlias; +}; +#define GET_TINDEX_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64TIndexHint + namespace AArch64SME { enum ToggleCondition : unsigned { Always, @@ -853,6 +861,14 @@ struct GSB : SysAlias { #include "AArch64GenSystemOperands.inc" } // namespace AArch64GSB +namespace AArch64PLBI { +struct PLBI : SysAliasOptionalReg { + using SysAliasOptionalReg::SysAliasOptionalReg; +}; +#define GET_PLBITable_DECL +#include "AArch64GenSystemOperands.inc" +} // namespace AArch64PLBI + namespace AArch64II { /// Target Operand Flag enum. 
enum TOF { @@ -987,6 +1003,16 @@ AArch64StringToPACKeyID(StringRef Name) { return std::nullopt; } +inline static unsigned getBTIHintNum(bool CallTarget, bool JumpTarget) { + unsigned HintNum = 32; + if (CallTarget) + HintNum |= 2; + if (JumpTarget) + HintNum |= 4; + assert(HintNum != 32 && "No target kinds!"); + return HintNum; +} + namespace AArch64 { // The number of bits in a SVE register is architecturally defined // to be a multiple of this value. If <M x t> has this number of bits, diff --git a/llvm/lib/Target/AArch64/Utils/CMakeLists.txt b/llvm/lib/Target/AArch64/Utils/CMakeLists.txt index 6ff462c..33b15cc 100644 --- a/llvm/lib/Target/AArch64/Utils/CMakeLists.txt +++ b/llvm/lib/Target/AArch64/Utils/CMakeLists.txt @@ -1,10 +1,8 @@ add_llvm_component_library(LLVMAArch64Utils AArch64BaseInfo.cpp - AArch64SMEAttributes.cpp LINK_COMPONENTS Support - Core ADD_TO_COMPONENT AArch64 |
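For context on getBTIHintNum above: BTI occupies the HINT #32 slot of the hint space, with bit 1 selecting a call target and bit 2 a jump target, so "bti c", "bti j" and "bti jc" correspond to HINT #34, #36 and #38. The mirror below only restates that arithmetic as an illustrative compile-time sanity check; it is a sketch, not part of the patch:

// Mirrors getBTIHintNum: HINT #32 base, |2 for call targets, |4 for jump targets.
constexpr unsigned btiHintNum(bool CallTarget, bool JumpTarget) {
  return 32u | (CallTarget ? 2u : 0u) | (JumpTarget ? 4u : 0u);
}
static_assert(btiHintNum(/*CallTarget=*/true, /*JumpTarget=*/false) == 34, "bti c");
static_assert(btiHintNum(/*CallTarget=*/false, /*JumpTarget=*/true) == 36, "bti j");
static_assert(btiHintNum(/*CallTarget=*/true, /*JumpTarget=*/true) == 38, "bti jc");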
