Diffstat (limited to 'llvm/lib/Target/AArch64')
-rw-r--r-- llvm/lib/Target/AArch64/AArch64.h | 2
-rw-r--r-- llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp | 48
-rw-r--r-- llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp | 34
-rw-r--r-- llvm/lib/Target/AArch64/AArch64BranchTargets.cpp | 46
-rw-r--r-- llvm/lib/Target/AArch64/AArch64CallingConvention.cpp | 4
-rw-r--r-- llvm/lib/Target/AArch64/AArch64CallingConvention.h | 45
-rw-r--r-- llvm/lib/Target/AArch64/AArch64Combine.td | 10
-rw-r--r-- llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp | 79
-rw-r--r-- llvm/lib/Target/AArch64/AArch64FastISel.cpp | 26
-rw-r--r-- llvm/lib/Target/AArch64/AArch64Features.td | 28
-rw-r--r-- llvm/lib/Target/AArch64/AArch64FrameLowering.cpp | 241
-rw-r--r-- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp | 57
-rw-r--r-- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 768
-rw-r--r-- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 27
-rw-r--r-- llvm/lib/Target/AArch64/AArch64InstrFormats.td | 229
-rw-r--r-- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 550
-rw-r--r-- llvm/lib/Target/AArch64/AArch64InstrInfo.h | 13
-rw-r--r-- llvm/lib/Target/AArch64/AArch64InstrInfo.td | 180
-rw-r--r-- llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp | 3
-rw-r--r-- llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp | 64
-rw-r--r-- llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h | 48
-rw-r--r-- llvm/lib/Target/AArch64/AArch64MacroFusion.cpp | 42
-rw-r--r-- llvm/lib/Target/AArch64/AArch64Processors.td | 120
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td | 55
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 38
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SchedA320.td | 2
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SchedA510.td | 4
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td | 2
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td | 2
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td | 2
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td | 2
-rw-r--r-- llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp | 27
-rw-r--r-- llvm/lib/Target/AArch64/AArch64StackTagging.cpp | 3
-rw-r--r-- llvm/lib/Target/AArch64/AArch64Subtarget.h | 17
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetMachine.cpp | 25
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetMachine.h | 4
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 363
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 20
-rw-r--r-- llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp | 38
-rw-r--r-- llvm/lib/Target/AArch64/CMakeLists.txt | 1
-rw-r--r-- llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp | 413
-rw-r--r-- llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp | 8
-rw-r--r-- llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp | 125
-rw-r--r-- llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp | 28
-rw-r--r-- llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp | 10
-rw-r--r-- llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp | 48
-rw-r--r-- llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h | 18
-rw-r--r-- llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp | 19
-rw-r--r-- llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp | 10
-rw-r--r-- llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp | 16
-rw-r--r-- llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp | 14
-rw-r--r-- llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h | 2
-rw-r--r-- llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp | 14
-rw-r--r-- llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 696
-rw-r--r-- llvm/lib/Target/AArch64/SMEABIPass.cpp | 31
-rw-r--r-- llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp | 33
-rw-r--r-- llvm/lib/Target/AArch64/SVEInstrFormats.td | 54
-rw-r--r-- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 48
-rw-r--r-- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h | 32
59 files changed, 3250 insertions, 1638 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 5496ebd..8d0ff41 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -60,6 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
FunctionPass *createAArch64CollectLOHPass();
FunctionPass *createSMEABIPass();
FunctionPass *createSMEPeepholeOptPass();
+FunctionPass *createMachineSMEABIPass();
ModulePass *createSVEIntrinsicOptsPass();
InstructionSelector *
createAArch64InstructionSelector(const AArch64TargetMachine &,
@@ -111,6 +112,7 @@ void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
void initializeLDTLSCleanupPass(PassRegistry&);
void initializeSMEABIPass(PassRegistry &);
void initializeSMEPeepholeOptPass(PassRegistry &);
+void initializeMachineSMEABIPass(PassRegistry &);
void initializeSVEIntrinsicOptsPass(PassRegistry &);
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
} // end namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
index e8d3161..1169f26 100644
--- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -316,6 +316,12 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
ThunkArgTranslation::PointerIndirection};
};
+ if (T->isHalfTy()) {
+ // Prefix with `llvm` since MSVC doesn't specify `_Float16`
+ Out << "__llvm_h__";
+ return direct(T);
+ }
+
if (T->isFloatTy()) {
Out << "f";
return direct(T);
@@ -327,8 +333,8 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
}
if (T->isFloatingPointTy()) {
- report_fatal_error(
- "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
+ report_fatal_error("Only 16, 32, and 64 bit floating points are supported "
+ "for ARM64EC thunks");
}
auto &DL = M->getDataLayout();
@@ -342,8 +348,16 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
uint64_t ElementCnt = T->getArrayNumElements();
uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
- if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
- Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
+ if (ElementTy->isHalfTy() || ElementTy->isFloatTy() ||
+ ElementTy->isDoubleTy()) {
+ if (ElementTy->isHalfTy())
+ // Prefix with `llvm` since MSVC doesn't specify `_Float16`
+ Out << "__llvm_H__";
+ else if (ElementTy->isFloatTy())
+ Out << "F";
+ else if (ElementTy->isDoubleTy())
+ Out << "D";
+ Out << TotalSizeBytes;
if (Alignment.value() >= 16 && !Ret)
Out << "a" << Alignment.value();
if (TotalSizeBytes <= 8) {
@@ -355,8 +369,9 @@ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
return pointerIndirection(T);
}
} else if (T->isFloatingPointTy()) {
- report_fatal_error("Only 32 and 64 bit floating points are supported for "
- "ARM64EC thunks");
+ report_fatal_error(
+ "Only 16, 32, and 64 bit floating points are supported "
+ "for ARM64EC thunks");
}
}
@@ -597,6 +612,14 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
return Thunk;
}
+std::optional<std::string> getArm64ECMangledFunctionName(GlobalValue &GV) {
+ if (!GV.hasName()) {
+ GV.setName("__unnamed");
+ }
+
+ return llvm::getArm64ECMangledFunctionName(GV.getName());
+}
+
// Builds the "guest exit thunk", a helper to call a function which may or may
// not be an exit thunk. (We optimistically assume non-dllimport function
// declarations refer to functions defined in AArch64 code; if the linker
@@ -608,7 +631,7 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
getThunkType(F->getFunctionType(), F->getAttributes(),
Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
ArgTranslations);
- auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
+ auto MangledName = getArm64ECMangledFunctionName(*F);
assert(MangledName && "Can't guest exit to function that's already native");
std::string ThunkName = *MangledName;
if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
@@ -727,9 +750,6 @@ AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
// Lower an indirect call with inline code.
void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
- assert(CB->getModule()->getTargetTriple().isOSWindows() &&
- "Only applicable for Windows targets");
-
IRBuilder<> B(CB);
Value *CalledOperand = CB->getCalledOperand();
@@ -790,7 +810,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
if (!F)
continue;
if (std::optional<std::string> MangledName =
- getArm64ECMangledFunctionName(A.getName().str())) {
+ getArm64ECMangledFunctionName(A)) {
F->addMetadata("arm64ec_unmangled_name",
*MDNode::get(M->getContext(),
MDString::get(M->getContext(), A.getName())));
@@ -807,7 +827,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
cast<GlobalValue>(F.getPersonalityFn()->stripPointerCasts());
if (PersFn->getValueType() && PersFn->getValueType()->isFunctionTy()) {
if (std::optional<std::string> MangledName =
- getArm64ECMangledFunctionName(PersFn->getName().str())) {
+ getArm64ECMangledFunctionName(*PersFn)) {
PersFn->setName(MangledName.value());
}
}
@@ -821,7 +841,7 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
// Rename hybrid patchable functions and change callers to use a global
// alias instead.
if (std::optional<std::string> MangledName =
- getArm64ECMangledFunctionName(F.getName().str())) {
+ getArm64ECMangledFunctionName(F)) {
std::string OrigName(F.getName());
F.setName(MangledName.value() + HybridPatchableTargetSuffix);
@@ -927,7 +947,7 @@ bool AArch64Arm64ECCallLowering::processFunction(
// FIXME: Handle functions with weak linkage?
if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
if (std::optional<std::string> MangledName =
- getArm64ECMangledFunctionName(F.getName().str())) {
+ getArm64ECMangledFunctionName(F)) {
F.addMetadata("arm64ec_unmangled_name",
*MDNode::get(M->getContext(),
MDString::get(M->getContext(), F.getName())));
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index c52487a..da34430 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -1829,8 +1829,8 @@ void AArch64AsmPrinter::emitMOVK(Register Dest, uint64_t Imm, unsigned Shift) {
void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
Register DestReg = MI.getOperand(0).getReg();
- if (STI->hasZeroCycleZeroingFP() && !STI->hasZeroCycleZeroingFPWorkaround() &&
- STI->isNeonAvailable()) {
+ if (STI->hasZeroCycleZeroingFPR64() &&
+ !STI->hasZeroCycleZeroingFPWorkaround() && STI->isNeonAvailable()) {
// Convert H/S register to corresponding D register
if (AArch64::H0 <= DestReg && DestReg <= AArch64::H31)
DestReg = AArch64::D0 + (DestReg - AArch64::H0);
@@ -2229,13 +2229,24 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
if (BrTarget == AddrDisc)
report_fatal_error("Branch target is signed with its own value");
- // If we are printing BLRA pseudo instruction, then x16 and x17 are
- // implicit-def'ed by the MI and AddrDisc is not used as any other input, so
- // try to save one MOV by setting MayUseAddrAsScratch.
+ // If we are printing BLRA pseudo, try to save one MOV by making use of the
+ // fact that x16 and x17 are described as clobbered by the MI instruction and
+ // AddrDisc is not used as any other input.
+ //
+ // Back in the day, emitPtrauthDiscriminator was restricted to only returning
+ // either x16 or x17, meaning the returned register is always among the
+ // implicit-def'ed registers of BLRA pseudo. Now this property can be violated
+ // if isX16X17Safer predicate is false, thus manually check if AddrDisc is
+ // among x16 and x17 to prevent clobbering unexpected registers.
+ //
// Unlike BLRA, BRA pseudo is used to perform computed goto, and thus not
// declared as clobbering x16/x17.
+ //
+ // FIXME: Make use of `killed` flags and register masks instead.
+ bool AddrDiscIsImplicitDef =
+ IsCall && (AddrDisc == AArch64::X16 || AddrDisc == AArch64::X17);
Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17,
- /*MayUseAddrAsScratch=*/IsCall);
+ AddrDiscIsImplicitDef);
bool IsZeroDisc = DiscReg == AArch64::XZR;
unsigned Opc;
@@ -2862,7 +2873,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
MCInst TmpInst;
TmpInst.setOpcode(AArch64::MOVIv16b_ns);
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
- TmpInst.addOperand(MCOperand::createImm(MI->getOperand(1).getImm()));
+ TmpInst.addOperand(MCOperand::createImm(0));
EmitToStreamer(*OutStreamer, TmpInst);
return;
}
@@ -2968,8 +2979,15 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
// See the comments in emitPtrauthBranch.
if (Callee == AddrDisc)
report_fatal_error("Call target is signed with its own value");
+
+ // After isX16X17Safer predicate was introduced, emitPtrauthDiscriminator is
+ // no longer restricted to only reusing AddrDisc when it is X16 or X17
+ // (which are implicit-def'ed by AUTH_TCRETURN pseudos), thus impose this
+ // restriction manually not to clobber an unexpected register.
+ bool AddrDiscIsImplicitDef =
+ AddrDisc == AArch64::X16 || AddrDisc == AArch64::X17;
Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, ScratchReg,
- /*MayUseAddrAsScratch=*/true);
+ AddrDiscIsImplicitDef);
const bool IsZero = DiscReg == AArch64::XZR;
const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
diff --git a/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp b/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp
index 3436dc9..137ff89 100644
--- a/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp
+++ b/llvm/lib/Target/AArch64/AArch64BranchTargets.cpp
@@ -30,6 +30,14 @@ using namespace llvm;
#define AARCH64_BRANCH_TARGETS_NAME "AArch64 Branch Targets"
namespace {
+// BTI HINT encoding: base (32) plus 'c' (2) and/or 'j' (4).
+enum : unsigned {
+ BTIBase = 32, // Base immediate for BTI HINT
+ BTIC = 1u << 1, // 2
+ BTIJ = 1u << 2, // 4
+ BTIMask = BTIC | BTIJ,
+};
+
class AArch64BranchTargets : public MachineFunctionPass {
public:
static char ID;
@@ -42,6 +50,7 @@ private:
void addBTI(MachineBasicBlock &MBB, bool CouldCall, bool CouldJump,
bool NeedsWinCFI);
};
+
} // end anonymous namespace
char AArch64BranchTargets::ID = 0;
@@ -62,9 +71,8 @@ bool AArch64BranchTargets::runOnMachineFunction(MachineFunction &MF) {
if (!MF.getInfo<AArch64FunctionInfo>()->branchTargetEnforcement())
return false;
- LLVM_DEBUG(
- dbgs() << "********** AArch64 Branch Targets **********\n"
- << "********** Function: " << MF.getName() << '\n');
+ LLVM_DEBUG(dbgs() << "********** AArch64 Branch Targets **********\n"
+ << "********** Function: " << MF.getName() << '\n');
const Function &F = MF.getFunction();
// LLVM does not consider basic blocks which are the targets of jump tables
@@ -103,6 +111,12 @@ bool AArch64BranchTargets::runOnMachineFunction(MachineFunction &MF) {
JumpTableTargets.count(&MBB))
CouldJump = true;
+ if (MBB.isEHPad()) {
+ if (HasWinCFI && (MBB.isEHFuncletEntry() || MBB.isCleanupFuncletEntry()))
+ CouldCall = true;
+ else
+ CouldJump = true;
+ }
if (CouldCall || CouldJump) {
addBTI(MBB, CouldCall, CouldJump, HasWinCFI);
MadeChange = true;
@@ -130,7 +144,12 @@ void AArch64BranchTargets::addBTI(MachineBasicBlock &MBB, bool CouldCall,
auto MBBI = MBB.begin();
- // Skip the meta instructions, those will be removed anyway.
+ // If the block starts with EH_LABEL(s), skip them first.
+ while (MBBI != MBB.end() && MBBI->isEHLabel()) {
+ ++MBBI;
+ }
+
+ // Skip meta/CFI/etc. (and EMITBKEY) to reach the first executable insn.
for (; MBBI != MBB.end() &&
(MBBI->isMetaInstruction() || MBBI->getOpcode() == AArch64::EMITBKEY);
++MBBI)
@@ -138,16 +157,21 @@ void AArch64BranchTargets::addBTI(MachineBasicBlock &MBB, bool CouldCall,
// SCTLR_EL1.BT[01] is set to 0 by default which means
// PACI[AB]SP are implicitly BTI C so no BTI C instruction is needed there.
- if (MBBI != MBB.end() && HintNum == 34 &&
+ if (MBBI != MBB.end() && ((HintNum & BTIMask) == BTIC) &&
(MBBI->getOpcode() == AArch64::PACIASP ||
MBBI->getOpcode() == AArch64::PACIBSP))
return;
- if (HasWinCFI && MBBI->getFlag(MachineInstr::FrameSetup)) {
- BuildMI(MBB, MBB.begin(), MBB.findDebugLoc(MBB.begin()),
- TII->get(AArch64::SEH_Nop));
+ // Insert BTI exactly at the first executable instruction.
+ const DebugLoc DL = MBB.findDebugLoc(MBBI);
+ MachineInstr *BTI = BuildMI(MBB, MBBI, DL, TII->get(AArch64::HINT))
+ .addImm(HintNum)
+ .getInstr();
+
+ // WinEH: put .seh_nop after BTI when the first real insn is FrameSetup.
+ if (HasWinCFI && MBBI != MBB.end() &&
+ MBBI->getFlag(MachineInstr::FrameSetup)) {
+ auto AfterBTI = std::next(MachineBasicBlock::iterator(BTI));
+ BuildMI(MBB, AfterBTI, DL, TII->get(AArch64::SEH_Nop));
}
- BuildMI(MBB, MBB.begin(), MBB.findDebugLoc(MBB.begin()),
- TII->get(AArch64::HINT))
- .addImm(HintNum);
}
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp b/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp
index 787a1a8..cc46159 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp
@@ -75,8 +75,10 @@ static bool finishStackBlock(SmallVectorImpl<CCValAssign> &PendingMembers,
auto &It = PendingMembers[0];
CCAssignFn *AssignFn =
TLI->CCAssignFnForCall(State.getCallingConv(), /*IsVarArg=*/false);
+ // FIXME: Get the correct original type.
+ Type *OrigTy = EVT(It.getValVT()).getTypeForEVT(State.getContext());
if (AssignFn(It.getValNo(), It.getValVT(), It.getValVT(), CCValAssign::Full,
- ArgFlags, State))
+ ArgFlags, OrigTy, State))
llvm_unreachable("Call operand has unhandled type");
// Return the flags to how they were before.
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.h b/llvm/lib/Target/AArch64/AArch64CallingConvention.h
index 63185a9..7105fa6 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.h
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.h
@@ -18,52 +18,63 @@
namespace llvm {
bool CC_AArch64_AAPCS(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
- CCState &State);
+ Type *OrigTy, CCState &State);
bool CC_AArch64_Arm64EC_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_Arm64EC_Thunk(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_Arm64EC_Thunk_Native(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_DarwinPCS_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_DarwinPCS(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_DarwinPCS_ILP32_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT,
- CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ CCValAssign::LocInfo LocInfo,
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_Win64PCS(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
- CCState &State);
+ Type *OrigTy, CCState &State);
bool CC_AArch64_Win64_VarArg(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_Win64_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_Arm64EC_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool CC_AArch64_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
- CCState &State);
+ Type *OrigTy, CCState &State);
bool CC_AArch64_Preserve_None(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool RetCC_AArch64_AAPCS(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
- CCState &State);
+ Type *OrigTy, CCState &State);
bool RetCC_AArch64_Arm64EC_Thunk(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State);
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
+ CCState &State);
bool RetCC_AArch64_Arm64EC_CFGuard_Check(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags,
+ ISD::ArgFlagsTy ArgFlags, Type *OrigTy,
CCState &State);
} // namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index 99f0af5..5f499e5 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -351,9 +351,10 @@ def AArch64PostLegalizerLowering
// Post-legalization combines which are primarily optimizations.
def AArch64PostLegalizerCombiner
: GICombiner<"AArch64PostLegalizerCombinerImpl",
- [copy_prop, cast_of_cast_combines, buildvector_of_truncate,
- integer_of_truncate, mutate_anyext_to_zext,
- combines_for_extload, combine_indexed_load_store, sext_trunc_sextload,
+ [copy_prop, cast_of_cast_combines,
+ buildvector_of_truncate, integer_of_truncate,
+ mutate_anyext_to_zext, combines_for_extload,
+ combine_indexed_load_store, sext_trunc_sextload,
hoist_logic_op_with_same_opcode_hands,
redundant_and, xor_of_and_with_same_reg,
extractvecelt_pairwise_add, redundant_or,
@@ -367,5 +368,6 @@ def AArch64PostLegalizerCombiner
select_to_minmax, or_to_bsp, combine_concat_vector,
commute_constant_to_rhs, extract_vec_elt_combines,
push_freeze_to_prevent_poison_from_propagating,
- combine_mul_cmlt, combine_use_vector_truncate, extmultomull]> {
+ combine_mul_cmlt, combine_use_vector_truncate,
+ extmultomull, truncsat_combines]> {
}
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 201bfe0..57dcd68 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -92,8 +92,9 @@ private:
bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI);
bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
- MachineBasicBlock *expandRestoreZA(MachineBasicBlock &MBB,
- MachineBasicBlock::iterator MBBI);
+ MachineBasicBlock *
+ expandCommitOrRestoreZASave(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI);
MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
};
@@ -528,6 +529,11 @@ bool AArch64ExpandPseudo::expand_DestructiveOp(
UseRev = true;
}
break;
+ case AArch64::Destructive2xRegImmUnpred:
+ // EXT_ZZI_CONSTRUCTIVE Zd, Zs, Imm
+ // ==> MOVPRFX Zd Zs; EXT_ZZI Zd, Zd, Zs, Imm
+ std::tie(DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 1, 2);
+ break;
default:
llvm_unreachable("Unsupported Destructive Operand type");
}
@@ -548,6 +554,7 @@ bool AArch64ExpandPseudo::expand_DestructiveOp(
break;
case AArch64::DestructiveUnaryPassthru:
case AArch64::DestructiveBinaryImm:
+ case AArch64::Destructive2xRegImmUnpred:
DOPRegIsUnique = true;
break;
case AArch64::DestructiveTernaryCommWithRev:
@@ -674,6 +681,11 @@ bool AArch64ExpandPseudo::expand_DestructiveOp(
.add(MI.getOperand(SrcIdx))
.add(MI.getOperand(Src2Idx));
break;
+ case AArch64::Destructive2xRegImmUnpred:
+ DOP.addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState)
+ .add(MI.getOperand(SrcIdx))
+ .add(MI.getOperand(Src2Idx));
+ break;
}
if (PRFX) {
@@ -979,10 +991,15 @@ bool AArch64ExpandPseudo::expandStoreSwiftAsyncContext(
return true;
}
-MachineBasicBlock *
-AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB,
- MachineBasicBlock::iterator MBBI) {
+static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
+
+MachineBasicBlock *AArch64ExpandPseudo::expandCommitOrRestoreZASave(
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
MachineInstr &MI = *MBBI;
+ bool IsRestoreZA = MI.getOpcode() == AArch64::RestoreZAPseudo;
+ assert((MI.getOpcode() == AArch64::RestoreZAPseudo ||
+ MI.getOpcode() == AArch64::CommitZASavePseudo) &&
+ "Expected ZA commit or restore");
assert((std::next(MBBI) != MBB.end() ||
MI.getParent()->successors().begin() !=
MI.getParent()->successors().end()) &&
@@ -990,21 +1007,23 @@ AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB,
// Compare TPIDR2_EL0 value against 0.
DebugLoc DL = MI.getDebugLoc();
- MachineInstrBuilder Cbz = BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX))
- .add(MI.getOperand(0));
+ MachineInstrBuilder Branch =
+ BuildMI(MBB, MBBI, DL,
+ TII->get(IsRestoreZA ? AArch64::CBZX : AArch64::CBNZX))
+ .add(MI.getOperand(0));
// Split MBB and create two new blocks:
// - MBB now contains all instructions before RestoreZAPseudo.
- // - SMBB contains the RestoreZAPseudo instruction only.
- // - EndBB contains all instructions after RestoreZAPseudo.
+ // - SMBB contains the [Commit|RestoreZA]Pseudo instruction only.
+ // - EndBB contains all instructions after [Commit|RestoreZA]Pseudo.
MachineInstr &PrevMI = *std::prev(MBBI);
MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
? *SMBB->successors().begin()
: SMBB->splitAt(MI, /*UpdateLiveIns*/ true);
- // Add the SMBB label to the TB[N]Z instruction & create a branch to EndBB.
- Cbz.addMBB(SMBB);
+ // Add the SMBB label to the CB[N]Z instruction & create a branch to EndBB.
+ Branch.addMBB(SMBB);
BuildMI(&MBB, DL, TII->get(AArch64::B))
.addMBB(EndBB);
MBB.addSuccessor(EndBB);
@@ -1012,11 +1031,30 @@ AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB,
// Replace the pseudo with a call (BL).
MachineInstrBuilder MIB =
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL));
- MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
+ // Copy operands (mainly the regmask) from the pseudo.
for (unsigned I = 2; I < MI.getNumOperands(); ++I)
MIB.add(MI.getOperand(I));
- BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
+ if (IsRestoreZA) {
+ // Mark the TPIDR2 block pointer (X0) as an implicit use.
+ MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
+ } else /*CommitZA*/ {
+ [[maybe_unused]] auto *TRI =
+ MBB.getParent()->getSubtarget().getRegisterInfo();
+ // Clear TPIDR2_EL0.
+ BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::MSR))
+ .addImm(AArch64SysReg::TPIDR2_EL0)
+ .addReg(AArch64::XZR);
+ bool ZeroZA = MI.getOperand(1).getImm() != 0;
+ if (ZeroZA) {
+ assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!");
+ BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::ZERO_M))
+ .addImm(ZERO_ALL_ZA_MASK)
+ .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+ }
+ }
+
+ BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
MI.eraseFromParent();
return EndBB;
}
@@ -1236,14 +1274,20 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
.add(MI.getOperand(3));
transferImpOps(MI, I, I);
} else {
+ unsigned RegState =
+ getRenamableRegState(MI.getOperand(1).isRenamable()) |
+ getKillRegState(
+ MI.getOperand(1).isKill() &&
+ MI.getOperand(1).getReg() != MI.getOperand(2).getReg() &&
+ MI.getOperand(1).getReg() != MI.getOperand(3).getReg());
BuildMI(MBB, MBBI, MI.getDebugLoc(),
TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::ORRv8i8
: AArch64::ORRv16i8))
.addReg(DstReg,
RegState::Define |
getRenamableRegState(MI.getOperand(0).isRenamable()))
- .add(MI.getOperand(1))
- .add(MI.getOperand(1));
+ .addReg(MI.getOperand(1).getReg(), RegState)
+ .addReg(MI.getOperand(1).getReg(), RegState);
auto I2 =
BuildMI(MBB, MBBI, MI.getDebugLoc(),
TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BSLv8i8
@@ -1629,8 +1673,9 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
return expandCALL_BTI(MBB, MBBI);
case AArch64::StoreSwiftAsyncContext:
return expandStoreSwiftAsyncContext(MBB, MBBI);
+ case AArch64::CommitZASavePseudo:
case AArch64::RestoreZAPseudo: {
- auto *NewMBB = expandRestoreZA(MBB, MBBI);
+ auto *NewMBB = expandCommitOrRestoreZASave(MBB, MBBI);
if (NewMBB != &MBB)
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
@@ -1641,6 +1686,8 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
}
+ case AArch64::InOutZAUsePseudo:
+ case AArch64::RequiresZASavePseudo:
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
case AArch64::COALESCER_BARRIER_FPR64:
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 9d74bb5..cf34498 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -267,7 +267,7 @@ private:
private:
CCAssignFn *CCAssignFnForCall(CallingConv::ID CC) const;
bool processCallArgs(CallLoweringInfo &CLI, SmallVectorImpl<MVT> &ArgVTs,
- unsigned &NumBytes);
+ SmallVectorImpl<Type *> &OrigTys, unsigned &NumBytes);
bool finishCall(CallLoweringInfo &CLI, unsigned NumBytes);
public:
@@ -3011,11 +3011,13 @@ bool AArch64FastISel::fastLowerArguments() {
bool AArch64FastISel::processCallArgs(CallLoweringInfo &CLI,
SmallVectorImpl<MVT> &OutVTs,
+ SmallVectorImpl<Type *> &OrigTys,
unsigned &NumBytes) {
CallingConv::ID CC = CLI.CallConv;
SmallVector<CCValAssign, 16> ArgLocs;
CCState CCInfo(CC, false, *FuncInfo.MF, ArgLocs, *Context);
- CCInfo.AnalyzeCallOperands(OutVTs, CLI.OutFlags, CCAssignFnForCall(CC));
+ CCInfo.AnalyzeCallOperands(OutVTs, CLI.OutFlags, OrigTys,
+ CCAssignFnForCall(CC));
// Get a count of how many bytes are to be pushed on the stack.
NumBytes = CCInfo.getStackSize();
@@ -3194,6 +3196,7 @@ bool AArch64FastISel::fastLowerCall(CallLoweringInfo &CLI) {
// Set up the argument vectors.
SmallVector<MVT, 16> OutVTs;
+ SmallVector<Type *, 16> OrigTys;
OutVTs.reserve(CLI.OutVals.size());
for (auto *Val : CLI.OutVals) {
@@ -3207,6 +3210,7 @@ bool AArch64FastISel::fastLowerCall(CallLoweringInfo &CLI) {
return false;
OutVTs.push_back(VT);
+ OrigTys.push_back(Val->getType());
}
Address Addr;
@@ -3222,7 +3226,7 @@ bool AArch64FastISel::fastLowerCall(CallLoweringInfo &CLI) {
// Handle the arguments now that we've gotten them.
unsigned NumBytes;
- if (!processCallArgs(CLI, OutVTs, NumBytes))
+ if (!processCallArgs(CLI, OutVTs, OrigTys, NumBytes))
return false;
const AArch64RegisterInfo *RegInfo = Subtarget->getRegisterInfo();
@@ -3574,12 +3578,8 @@ bool AArch64FastISel::fastLowerIntrinsicCall(const IntrinsicInst *II) {
Args.reserve(II->arg_size());
// Populate the argument list.
- for (auto &Arg : II->args()) {
- ArgListEntry Entry;
- Entry.Val = Arg;
- Entry.Ty = Arg->getType();
- Args.push_back(Entry);
- }
+ for (auto &Arg : II->args())
+ Args.emplace_back(Arg);
CallLoweringInfo CLI;
MCContext &Ctx = MF->getContext();
@@ -4870,12 +4870,8 @@ bool AArch64FastISel::selectFRem(const Instruction *I) {
Args.reserve(I->getNumOperands());
// Populate the argument list.
- for (auto &Arg : I->operands()) {
- ArgListEntry Entry;
- Entry.Val = Arg;
- Entry.Ty = Arg->getType();
- Args.push_back(Entry);
- }
+ for (auto &Arg : I->operands())
+ Args.emplace_back(Arg);
CallLoweringInfo CLI;
MCContext &Ctx = MF->getContext();
diff --git a/llvm/lib/Target/AArch64/AArch64Features.td b/llvm/lib/Target/AArch64/AArch64Features.td
index c1c1f0a..6904e09 100644
--- a/llvm/lib/Target/AArch64/AArch64Features.td
+++ b/llvm/lib/Target/AArch64/AArch64Features.td
@@ -621,25 +621,27 @@ def FeatureZCRegMoveGPR64 : SubtargetFeature<"zcm-gpr64", "HasZeroCycleRegMoveGP
def FeatureZCRegMoveGPR32 : SubtargetFeature<"zcm-gpr32", "HasZeroCycleRegMoveGPR32", "true",
"Has zero-cycle register moves for GPR32 registers">;
+def FeatureZCRegMoveFPR128 : SubtargetFeature<"zcm-fpr128", "HasZeroCycleRegMoveFPR128", "true",
+ "Has zero-cycle register moves for FPR128 registers">;
+
def FeatureZCRegMoveFPR64 : SubtargetFeature<"zcm-fpr64", "HasZeroCycleRegMoveFPR64", "true",
"Has zero-cycle register moves for FPR64 registers">;
def FeatureZCRegMoveFPR32 : SubtargetFeature<"zcm-fpr32", "HasZeroCycleRegMoveFPR32", "true",
"Has zero-cycle register moves for FPR32 registers">;
-def FeatureZCZeroingGP : SubtargetFeature<"zcz-gp", "HasZeroCycleZeroingGP", "true",
- "Has zero-cycle zeroing instructions for generic registers">;
+def FeatureZCZeroingGPR64 : SubtargetFeature<"zcz-gpr64", "HasZeroCycleZeroingGPR64", "true",
+ "Has zero-cycle zeroing instructions for GPR64 registers">;
+
+def FeatureZCZeroingGPR32 : SubtargetFeature<"zcz-gpr32", "HasZeroCycleZeroingGPR32", "true",
+ "Has zero-cycle zeroing instructions for GPR32 registers">;
// It is generally beneficial to rewrite "fmov s0, wzr" to "movi d0, #0".
// as movi is more efficient across all cores. Newer cores can eliminate
// fmovs early and there is no difference with movi, but this not true for
// all implementations.
-def FeatureNoZCZeroingFP : SubtargetFeature<"no-zcz-fp", "HasZeroCycleZeroingFP", "false",
- "Has no zero-cycle zeroing instructions for FP registers">;
-
-def FeatureZCZeroing : SubtargetFeature<"zcz", "HasZeroCycleZeroing", "true",
- "Has zero-cycle zeroing instructions",
- [FeatureZCZeroingGP]>;
+def FeatureNoZCZeroingFPR64 : SubtargetFeature<"no-zcz-fpr64", "HasZeroCycleZeroingFPR64", "false",
+ "Has no zero-cycle zeroing instructions for FPR64 registers">;
/// ... but the floating-point version doesn't quite work in rare cases on older
/// CPUs.
@@ -730,9 +732,13 @@ def FeatureFuseArithmeticLogic : SubtargetFeature<
"fuse-arith-logic", "HasFuseArithmeticLogic", "true",
"CPU fuses arithmetic and logic operations">;
-def FeatureFuseCCSelect : SubtargetFeature<
- "fuse-csel", "HasFuseCCSelect", "true",
- "CPU fuses conditional select operations">;
+def FeatureFuseCmpCSel : SubtargetFeature<
+ "fuse-csel", "HasFuseCmpCSel", "true",
+ "CPU can fuse CMP and CSEL operations">;
+
+def FeatureFuseCmpCSet : SubtargetFeature<
+ "fuse-cset", "HasFuseCmpCSet", "true",
+ "CPU can fuse CMP and CSET operations">;
def FeatureFuseCryptoEOR : SubtargetFeature<
"fuse-crypto-eor", "HasFuseCryptoEOR", "true",
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 885f2a9..7725fa4 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -338,7 +338,7 @@ static bool requiresSaveVG(const MachineFunction &MF);
// Conservatively, returns true if the function is likely to have an SVE vectors
// on the stack. This function is safe to be called before callee-saves or
// object offsets have been determined.
-static bool isLikelyToHaveSVEStack(MachineFunction &MF) {
+static bool isLikelyToHaveSVEStack(const MachineFunction &MF) {
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
if (AFI->isSVECC())
return true;
@@ -532,6 +532,7 @@ bool AArch64FrameLowering::canUseRedZone(const MachineFunction &MF) const {
bool AArch64FrameLowering::hasFPImpl(const MachineFunction &MF) const {
const MachineFrameInfo &MFI = MF.getFrameInfo();
const TargetRegisterInfo *RegInfo = MF.getSubtarget().getRegisterInfo();
+ const AArch64FunctionInfo &AFI = *MF.getInfo<AArch64FunctionInfo>();
// Win64 EH requires a frame pointer if funclets are present, as the locals
// are accessed off the frame pointer in both the parent function and the
@@ -545,6 +546,29 @@ bool AArch64FrameLowering::hasFPImpl(const MachineFunction &MF) const {
MFI.hasStackMap() || MFI.hasPatchPoint() ||
RegInfo->hasStackRealignment(MF))
return true;
+
+ // If we:
+ //
+ // 1. Have streaming mode changes
+ // OR:
+ // 2. Have a streaming body with SVE stack objects
+ //
+ // Then the value of VG restored when unwinding to this function may not match
+ // the value of VG used to set up the stack.
+ //
+ // This is a problem as the CFA can be described with an expression of the
+ // form: CFA = SP + NumBytes + VG * NumScalableBytes.
+ //
+ // If the value of VG used in that expression does not match the value used to
+ // set up the stack, an incorrect address for the CFA will be computed, and
+ // unwinding will fail.
+ //
+ // We work around this issue by ensuring the frame-pointer can describe the
+ // CFA in either of these cases.
+ if (AFI.needsDwarfUnwindInfo(MF) &&
+ ((requiresSaveVG(MF) || AFI.getSMEFnAttrs().hasStreamingBody()) &&
+ (!AFI.hasCalculatedStackSizeSVE() || AFI.getStackSizeSVE() > 0)))
+ return true;
// With large callframes around we may need to use FP to access the scavenging
// emergency spillslot.
//
@@ -663,10 +687,6 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
MachineFunction &MF = *MBB.getParent();
MachineFrameInfo &MFI = MF.getFrameInfo();
- AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
- SMEAttrs Attrs = AFI->getSMEFnAttrs();
- bool LocallyStreaming =
- Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();
const std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo();
if (CSI.empty())
@@ -680,14 +700,6 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
assert(!Info.isSpilledToReg() && "Spilling to registers not implemented");
int64_t Offset = MFI.getObjectOffset(FrameIdx) - getOffsetOfLocalArea();
-
- // The location of VG will be emitted before each streaming-mode change in
- // the function. Only locally-streaming functions require emitting the
- // non-streaming VG location here.
- if ((LocallyStreaming && FrameIdx == AFI->getStreamingVGIdx()) ||
- (!LocallyStreaming && Info.getReg() == AArch64::VG))
- continue;
-
CFIBuilder.buildOffset(Info.getReg(), Offset);
}
}
@@ -707,8 +719,16 @@ void AArch64FrameLowering::emitCalleeSavedSVELocations(
AArch64FunctionInfo &AFI = *MF.getInfo<AArch64FunctionInfo>();
CFIInstBuilder CFIBuilder(MBB, MBBI, MachineInstr::FrameSetup);
+ std::optional<int64_t> IncomingVGOffsetFromDefCFA;
+ if (requiresSaveVG(MF)) {
+ auto IncomingVG = *find_if(
+ reverse(CSI), [](auto &Info) { return Info.getReg() == AArch64::VG; });
+ IncomingVGOffsetFromDefCFA =
+ MFI.getObjectOffset(IncomingVG.getFrameIdx()) - getOffsetOfLocalArea();
+ }
+
for (const auto &Info : CSI) {
- if (!(MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector))
+ if (MFI.getStackID(Info.getFrameIdx()) != TargetStackID::ScalableVector)
continue;
// Not all unwinders may know about SVE registers, so assume the lowest
@@ -722,7 +742,8 @@ void AArch64FrameLowering::emitCalleeSavedSVELocations(
StackOffset::getScalable(MFI.getObjectOffset(Info.getFrameIdx())) -
StackOffset::getFixed(AFI.getCalleeSavedStackSize(MFI));
- CFIBuilder.insertCFIInst(createCFAOffset(TRI, Reg, Offset));
+ CFIBuilder.insertCFIInst(
+ createCFAOffset(TRI, Reg, Offset, IncomingVGOffsetFromDefCFA));
}
}
@@ -783,9 +804,6 @@ static void emitCalleeSavedRestores(MachineBasicBlock &MBB,
!static_cast<const AArch64RegisterInfo &>(TRI).regNeedsCFI(Reg, Reg))
continue;
- if (!Info.isRestored())
- continue;
-
CFIBuilder.buildRestore(Info.getReg());
}
}
@@ -1465,34 +1483,35 @@ bool requiresGetVGCall(MachineFunction &MF) {
static bool requiresSaveVG(const MachineFunction &MF) {
const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+ if (!AFI->needsDwarfUnwindInfo(MF) || !AFI->hasStreamingModeChanges())
+ return false;
// For Darwin platforms we don't save VG for non-SVE functions, even if SME
// is enabled with streaming mode changes.
- if (!AFI->hasStreamingModeChanges())
- return false;
auto &ST = MF.getSubtarget<AArch64Subtarget>();
if (ST.isTargetDarwin())
return ST.hasSVE();
return true;
}
-bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
+static bool matchLibcall(const TargetLowering &TLI, const MachineOperand &MO,
+ RTLIB::Libcall LC) {
+ return MO.isSymbol() &&
+ StringRef(TLI.getLibcallName(LC)) == MO.getSymbolName();
+}
+
+bool isVGInstruction(MachineBasicBlock::iterator MBBI,
+ const TargetLowering &TLI) {
unsigned Opc = MBBI->getOpcode();
- if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI ||
- Opc == AArch64::UBFMXri)
+ if (Opc == AArch64::CNTD_XPiI)
return true;
- if (requiresGetVGCall(*MBBI->getMF())) {
- if (Opc == AArch64::ORRXrr)
- return true;
+ if (!requiresGetVGCall(*MBBI->getMF()))
+ return false;
- if (Opc == AArch64::BL) {
- auto Op1 = MBBI->getOperand(0);
- return Op1.isSymbol() &&
- (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
- }
- }
+ if (Opc == AArch64::BL)
+ return matchLibcall(TLI, MBBI->getOperand(0), RTLIB::SMEABI_GET_CURRENT_VG);
- return false;
+ return Opc == TargetOpcode::COPY;
}
// Convert callee-save register save/restore instruction to do stack pointer
@@ -1507,13 +1526,14 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(
unsigned NewOpc;
// If the function contains streaming mode changes, we expect instructions
- // to calculate the value of VG before spilling. For locally-streaming
- // functions, we need to do this for both the streaming and non-streaming
- // vector length. Move past these instructions if necessary.
+ // to calculate the value of VG before spilling. Move past these instructions
+ // if necessary.
MachineFunction &MF = *MBB.getParent();
- if (requiresSaveVG(MF))
- while (isVGInstruction(MBBI))
+ if (requiresSaveVG(MF)) {
+ auto &TLI = *MF.getSubtarget().getTargetLowering();
+ while (isVGInstruction(MBBI, TLI))
++MBBI;
+ }
switch (MBBI->getOpcode()) {
default:
@@ -2097,11 +2117,12 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
// Move past the saves of the callee-saved registers, fixing up the offsets
// and pre-inc if we decided to combine the callee-save and local stack
// pointer bump above.
+ auto &TLI = *MF.getSubtarget().getTargetLowering();
while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) &&
!IsSVECalleeSave(MBBI)) {
if (CombineSPBump &&
// Only fix-up frame-setup load/store instructions.
- (!requiresSaveVG(MF) || !isVGInstruction(MBBI)))
+ (!requiresSaveVG(MF) || !isVGInstruction(MBBI, TLI)))
fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(),
NeedsWinCFI, &HasWinCFI);
++MBBI;
@@ -3468,8 +3489,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
MachineFunction &MF = *MBB.getParent();
+ auto &TLI = *MF.getSubtarget<AArch64Subtarget>().getTargetLowering();
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
- AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
bool NeedsWinCFI = needsWinCFI(MF);
DebugLoc DL;
SmallVector<RegPairInfo, 8> RegPairs;
@@ -3538,59 +3559,44 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
}
unsigned X0Scratch = AArch64::NoRegister;
+ auto RestoreX0 = make_scope_exit([&] {
+ if (X0Scratch != AArch64::NoRegister)
+ BuildMI(MBB, MI, DL, TII.get(TargetOpcode::COPY), AArch64::X0)
+ .addReg(X0Scratch)
+ .setMIFlag(MachineInstr::FrameSetup);
+ });
+
if (Reg1 == AArch64::VG) {
// Find an available register to store value of VG to.
Reg1 = findScratchNonCalleeSaveRegister(&MBB, true);
assert(Reg1 != AArch64::NoRegister);
- SMEAttrs Attrs = AFI->getSMEFnAttrs();
-
- if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
- AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
- // For locally-streaming functions, we need to store both the streaming
- // & non-streaming VG. Spill the streaming value first.
- BuildMI(MBB, MI, DL, TII.get(AArch64::RDSVLI_XI), Reg1)
- .addImm(1)
- .setMIFlag(MachineInstr::FrameSetup);
- BuildMI(MBB, MI, DL, TII.get(AArch64::UBFMXri), Reg1)
- .addReg(Reg1)
- .addImm(3)
- .addImm(63)
- .setMIFlag(MachineInstr::FrameSetup);
-
- AFI->setStreamingVGIdx(RPI.FrameIdx);
- } else if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) {
+ if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) {
BuildMI(MBB, MI, DL, TII.get(AArch64::CNTD_XPiI), Reg1)
.addImm(31)
.addImm(1)
.setMIFlag(MachineInstr::FrameSetup);
- AFI->setVGIdx(RPI.FrameIdx);
} else {
const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
- if (llvm::any_of(
- MBB.liveins(),
- [&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) {
- return STI.getRegisterInfo()->isSuperOrSubRegisterEq(
- AArch64::X0, LiveIn.PhysReg);
- }))
+ if (any_of(MBB.liveins(),
+ [&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) {
+ return STI.getRegisterInfo()->isSuperOrSubRegisterEq(
+ AArch64::X0, LiveIn.PhysReg);
+ })) {
X0Scratch = Reg1;
-
- if (X0Scratch != AArch64::NoRegister)
- BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), Reg1)
- .addReg(AArch64::XZR)
- .addReg(AArch64::X0, RegState::Undef)
- .addReg(AArch64::X0, RegState::Implicit)
+ BuildMI(MBB, MI, DL, TII.get(TargetOpcode::COPY), X0Scratch)
+ .addReg(AArch64::X0)
.setMIFlag(MachineInstr::FrameSetup);
+ }
- const uint32_t *RegMask = TRI->getCallPreservedMask(
- MF,
- CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
+ RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG;
+ const uint32_t *RegMask =
+ TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC));
BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
- .addExternalSymbol("__arm_get_current_vg")
+ .addExternalSymbol(TLI.getLibcallName(LC))
.addRegMask(RegMask)
.addReg(AArch64::X0, RegState::ImplicitDefine)
.setMIFlag(MachineInstr::FrameSetup);
Reg1 = AArch64::X0;
- AFI->setVGIdx(RPI.FrameIdx);
}
}
@@ -3685,13 +3691,6 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
if (RPI.isPaired())
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
}
-
- if (X0Scratch != AArch64::NoRegister)
- BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), AArch64::X0)
- .addReg(AArch64::XZR)
- .addReg(X0Scratch, RegState::Undef)
- .addReg(X0Scratch, RegState::Implicit)
- .setMIFlag(MachineInstr::FrameSetup);
}
return true;
}
@@ -4070,15 +4069,8 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
// Increase the callee-saved stack size if the function has streaming mode
// changes, as we will need to spill the value of the VG register.
- // For locally streaming functions, we spill both the streaming and
- // non-streaming VG value.
- SMEAttrs Attrs = AFI->getSMEFnAttrs();
- if (requiresSaveVG(MF)) {
- if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
- CSStackSize += 16;
- else
- CSStackSize += 8;
- }
+ if (requiresSaveVG(MF))
+ CSStackSize += 8;
// Determine if a Hazard slot should be used, and increase the CSStackSize by
// StackHazardSize if so.
@@ -4229,29 +4221,13 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
// Insert VG into the list of CSRs, immediately before LR if saved.
if (requiresSaveVG(MF)) {
- std::vector<CalleeSavedInfo> VGSaves;
- SMEAttrs Attrs = AFI->getSMEFnAttrs();
-
- auto VGInfo = CalleeSavedInfo(AArch64::VG);
- VGInfo.setRestored(false);
- VGSaves.push_back(VGInfo);
-
- // Add VG again if the function is locally-streaming, as we will spill two
- // values.
- if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
- VGSaves.push_back(VGInfo);
-
- bool InsertBeforeLR = false;
-
- for (unsigned I = 0; I < CSI.size(); I++)
- if (CSI[I].getReg() == AArch64::LR) {
- InsertBeforeLR = true;
- CSI.insert(CSI.begin() + I, VGSaves.begin(), VGSaves.end());
- break;
- }
-
- if (!InsertBeforeLR)
- llvm::append_range(CSI, VGSaves);
+ CalleeSavedInfo VGInfo(AArch64::VG);
+ auto It =
+ find_if(CSI, [](auto &Info) { return Info.getReg() == AArch64::LR; });
+ if (It != CSI.end())
+ CSI.insert(It, VGInfo);
+ else
+ CSI.push_back(VGInfo);
}
Register LastReg = 0;
@@ -5254,46 +5230,11 @@ MachineBasicBlock::iterator tryMergeAdjacentSTG(MachineBasicBlock::iterator II,
}
} // namespace
-static void emitVGSaveRestore(MachineBasicBlock::iterator II,
- const AArch64FrameLowering *TFI) {
- MachineInstr &MI = *II;
- MachineBasicBlock *MBB = MI.getParent();
- MachineFunction *MF = MBB->getParent();
-
- if (MI.getOpcode() != AArch64::VGSavePseudo &&
- MI.getOpcode() != AArch64::VGRestorePseudo)
- return;
-
- auto *AFI = MF->getInfo<AArch64FunctionInfo>();
- SMEAttrs FuncAttrs = AFI->getSMEFnAttrs();
- bool LocallyStreaming =
- FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
-
- int64_t VGFrameIdx =
- LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
- assert(VGFrameIdx != std::numeric_limits<int>::max() &&
- "Expected FrameIdx for VG");
-
- CFIInstBuilder CFIBuilder(*MBB, II, MachineInstr::NoFlags);
- if (MI.getOpcode() == AArch64::VGSavePseudo) {
- const MachineFrameInfo &MFI = MF->getFrameInfo();
- int64_t Offset =
- MFI.getObjectOffset(VGFrameIdx) - TFI->getOffsetOfLocalArea();
- CFIBuilder.buildOffset(AArch64::VG, Offset);
- } else {
- CFIBuilder.buildRestore(AArch64::VG);
- }
-
- MI.eraseFromParent();
-}
-
void AArch64FrameLowering::processFunctionBeforeFrameIndicesReplaced(
MachineFunction &MF, RegScavenger *RS = nullptr) const {
for (auto &BB : MF)
for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();) {
- if (requiresSaveVG(MF))
- emitVGSaveRestore(II++, this);
- else if (StackTaggingMergeSetTag)
+ if (StackTaggingMergeSetTag)
II = tryMergeAdjacentSTG(II, this, RS);
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index ad42f4b..6fdc981 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -246,9 +246,9 @@ public:
return false;
}
- template<MVT::SimpleValueType VT>
+ template <MVT::SimpleValueType VT, bool Negate>
bool SelectSVEAddSubImm(SDValue N, SDValue &Imm, SDValue &Shift) {
- return SelectSVEAddSubImm(N, VT, Imm, Shift);
+ return SelectSVEAddSubImm(N, VT, Imm, Shift, Negate);
}
template <MVT::SimpleValueType VT, bool Negate>
@@ -489,7 +489,8 @@ private:
bool SelectCMP_SWAP(SDNode *N);
- bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
+ bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift,
+ bool Negate);
bool SelectSVEAddSubSSatImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift,
bool Negate);
bool SelectSVECpyDupImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -4227,35 +4228,36 @@ bool AArch64DAGToDAGISel::SelectCMP_SWAP(SDNode *N) {
}
bool AArch64DAGToDAGISel::SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm,
- SDValue &Shift) {
+ SDValue &Shift, bool Negate) {
if (!isa<ConstantSDNode>(N))
return false;
SDLoc DL(N);
- uint64_t Val = cast<ConstantSDNode>(N)
- ->getAPIntValue()
- .trunc(VT.getFixedSizeInBits())
- .getZExtValue();
+ APInt Val =
+ cast<ConstantSDNode>(N)->getAPIntValue().trunc(VT.getFixedSizeInBits());
+
+ if (Negate)
+ Val = -Val;
switch (VT.SimpleTy) {
case MVT::i8:
// All immediates are supported.
Shift = CurDAG->getTargetConstant(0, DL, MVT::i32);
- Imm = CurDAG->getTargetConstant(Val, DL, MVT::i32);
+ Imm = CurDAG->getTargetConstant(Val.getZExtValue(), DL, MVT::i32);
return true;
case MVT::i16:
case MVT::i32:
case MVT::i64:
// Support 8bit unsigned immediates.
- if (Val <= 255) {
+ if ((Val & ~0xff) == 0) {
Shift = CurDAG->getTargetConstant(0, DL, MVT::i32);
- Imm = CurDAG->getTargetConstant(Val, DL, MVT::i32);
+ Imm = CurDAG->getTargetConstant(Val.getZExtValue(), DL, MVT::i32);
return true;
}
// Support 16bit unsigned immediates that are a multiple of 256.
- if (Val <= 65280 && Val % 256 == 0) {
+ if ((Val & ~0xff00) == 0) {
Shift = CurDAG->getTargetConstant(8, DL, MVT::i32);
- Imm = CurDAG->getTargetConstant(Val >> 8, DL, MVT::i32);
+ Imm = CurDAG->getTargetConstant(Val.lshr(8).getZExtValue(), DL, MVT::i32);
return true;
}
break;
@@ -7617,16 +7619,29 @@ bool AArch64DAGToDAGISel::SelectAnyPredicate(SDValue N) {
bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
SDValue &Base, SDValue &Offset,
unsigned Scale) {
- // Try to untangle an ADD node into a 'reg + offset'
- if (CurDAG->isBaseWithConstantOffset(N))
- if (auto C = dyn_cast<ConstantSDNode>(N.getOperand(1))) {
+ auto MatchConstantOffset = [&](SDValue CN) -> SDValue {
+ if (auto *C = dyn_cast<ConstantSDNode>(CN)) {
int64_t ImmOff = C->getSExtValue();
- if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0))) {
- Base = N.getOperand(0);
- Offset = CurDAG->getTargetConstant(ImmOff / Scale, SDLoc(N), MVT::i64);
- return true;
- }
+ if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0)))
+ return CurDAG->getTargetConstant(ImmOff / Scale, SDLoc(N), MVT::i64);
+ }
+ return SDValue();
+ };
+
+ if (SDValue C = MatchConstantOffset(N)) {
+ Base = CurDAG->getConstant(0, SDLoc(N), MVT::i32);
+ Offset = C;
+ return true;
+ }
+
+ // Try to untangle an ADD node into a 'reg + offset'
+ if (CurDAG->isBaseWithConstantOffset(N)) {
+ if (SDValue C = MatchConstantOffset(N.getOperand(1))) {
+ Base = N.getOperand(0);
+ Offset = C;
+ return true;
}
+ }
// By default, just match reg + 0.
Base = N;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2b6ea86..b7011e0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17,6 +17,7 @@
#include "AArch64PerfectShuffle.h"
#include "AArch64RegisterInfo.h"
#include "AArch64Subtarget.h"
+#include "AArch64TargetMachine.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "Utils/AArch64BaseInfo.h"
#include "Utils/AArch64SMEAttributes.h"
@@ -1120,7 +1121,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal);
- // We combine OR nodes for bitfield operations.
+ // We combine OR nodes for ccmp operations.
setTargetDAGCombine(ISD::OR);
// Try to create BICs for vector ANDs.
setTargetDAGCombine(ISD::AND);
@@ -1769,7 +1770,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
- if (Subtarget->hasSVEB16B16()) {
+ if (Subtarget->hasSVEB16B16() &&
+ Subtarget->isNonStreamingSVEorSME2Available()) {
setOperationAction(ISD::FADD, VT, Legal);
setOperationAction(ISD::FMA, VT, Custom);
setOperationAction(ISD::FMAXIMUM, VT, Custom);
@@ -1791,7 +1793,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
}
- if (!Subtarget->hasSVEB16B16()) {
+ if (!Subtarget->hasSVEB16B16() ||
+ !Subtarget->isNonStreamingSVEorSME2Available()) {
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
@@ -1998,6 +2001,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::f16, Promote);
}
+const AArch64TargetMachine &AArch64TargetLowering::getTM() const {
+ return static_cast<const AArch64TargetMachine &>(getTargetMachine());
+}
+
void AArch64TargetLowering::addTypeForNEON(MVT VT) {
assert(VT.isVector() && "VT should be a vector type");
@@ -2578,6 +2585,30 @@ void AArch64TargetLowering::computeKnownBitsForTargetNode(
Known = Known.intersectWith(Known2);
break;
}
+ case AArch64ISD::CSNEG:
+ case AArch64ISD::CSINC:
+ case AArch64ISD::CSINV: {
+ KnownBits KnownOp0 = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+ KnownBits KnownOp1 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
+
+ // The result is either:
+ // CSINC: KnownOp0 or KnownOp1 + 1
+ // CSINV: KnownOp0 or ~KnownOp1
+ // CSNEG: KnownOp0 or KnownOp1 * -1
+ if (Op.getOpcode() == AArch64ISD::CSINC)
+ KnownOp1 = KnownBits::add(
+ KnownOp1,
+ KnownBits::makeConstant(APInt(Op.getScalarValueSizeInBits(), 1)));
+ else if (Op.getOpcode() == AArch64ISD::CSINV)
+ std::swap(KnownOp1.Zero, KnownOp1.One);
+ else if (Op.getOpcode() == AArch64ISD::CSNEG)
+ KnownOp1 =
+ KnownBits::mul(KnownOp1, KnownBits::makeConstant(APInt::getAllOnes(
+ Op.getScalarValueSizeInBits())));
+
+ Known = KnownOp0.intersectWith(KnownOp1);
+ break;
+ }
case AArch64ISD::BICi: {
// Compute the bit cleared value.
APInt Mask =
@@ -2977,21 +3008,20 @@ AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI,
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
if (TPIDR2.Uses > 0) {
+ // Note: This case just needs to do `SVL << 48`. It is not implemented as we
+ // generally don't support big-endian SVE/SME.
+ if (!Subtarget->isLittleEndian())
+ reportFatalInternalError(
+ "TPIDR2 block initialization is not supported on big-endian targets");
+
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
- // Store the buffer pointer to the TPIDR2 stack object.
- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRXui))
+ // Store buffer pointer and num_za_save_slices.
+ // Bytes 10-15 are implicitly zeroed.
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STPXi))
.addReg(MI.getOperand(0).getReg())
+ .addReg(MI.getOperand(1).getReg())
.addFrameIndex(TPIDR2.FrameIndex)
.addImm(0);
- // Set the reserved bytes (10-15) to zero
- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRHHui))
- .addReg(AArch64::WZR)
- .addFrameIndex(TPIDR2.FrameIndex)
- .addImm(5);
- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRWui))
- .addReg(AArch64::WZR)
- .addFrameIndex(TPIDR2.FrameIndex)
- .addImm(3);
} else
MFI.RemoveStackObject(TPIDR2.FrameIndex);
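
For context, the 16-byte object initialised here follows the SME lazy-save TPIDR2 block layout: the buffer pointer in bytes 0-7, num_za_save_slices in bytes 8-9, and reserved bytes 10-15 that must be zero, which is why one STP of the pointer plus the zero-extended slice count is enough. A layout sketch (field names are illustrative, not taken from the patch):

    #include <cstdint>

    struct TPIDR2Block {
      uint64_t za_save_buffer;     // bytes 0-7: pointer to the lazy-save buffer
      uint16_t num_za_save_slices; // bytes 8-9: written from RDSVL #1
      uint8_t  reserved[6];        // bytes 10-15: must read as zero
    };

    static_assert(sizeof(TPIDR2Block) == 16, "the pseudo allocates a 16-byte object");

    int main() { return sizeof(TPIDR2Block) == 16 ? 0 : 1; }
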
@@ -3083,13 +3113,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
if (FuncInfo->isSMESaveBufferUsed()) {
+ RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
- .addExternalSymbol("__arm_sme_state_size")
+ .addExternalSymbol(getLibcallName(LC))
.addReg(AArch64::X0, RegState::ImplicitDefine)
- .addRegMask(TRI->getCallPreservedMask(
- *MF, CallingConv::
- AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+ .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
MI.getOperand(0).getReg())
.addReg(AArch64::X0);
@@ -3101,6 +3130,30 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
return BB;
}
+MachineBasicBlock *
+AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI,
+ MachineBasicBlock *BB) const {
+ MachineFunction *MF = BB->getParent();
+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+ Register ResultReg = MI.getOperand(0).getReg();
+ if (FuncInfo->isPStateSMRegUsed()) {
+ RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+ .addExternalSymbol(getLibcallName(LC))
+ .addReg(AArch64::X0, RegState::ImplicitDefine)
+ .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg)
+ .addReg(AArch64::X0);
+ } else {
+ assert(MI.getMF()->getRegInfo().use_empty(ResultReg) &&
+ "Expected no users of the entry pstate.sm!");
+ }
+ MI.eraseFromParent();
+ return BB;
+}
+
// Helper function to find the instruction that defined a virtual register.
// If unable to find such instruction, returns nullptr.
static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
@@ -3216,6 +3269,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitAllocateSMESaveBuffer(MI, BB);
case AArch64::GetSMESaveSize:
return EmitGetSMESaveSize(MI, BB);
+ case AArch64::EntryPStateSM:
+ return EmitEntryPStateSM(MI, BB);
case AArch64::F128CSEL:
return EmitF128CSEL(MI, BB);
case TargetOpcode::STATEPOINT:
@@ -3320,7 +3375,8 @@ static bool isZerosVector(const SDNode *N) {
/// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
/// CC
-static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
+static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC,
+ SDValue RHS = {}) {
switch (CC) {
default:
llvm_unreachable("Unknown condition code!");
@@ -3331,9 +3387,9 @@ static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
case ISD::SETGT:
return AArch64CC::GT;
case ISD::SETGE:
- return AArch64CC::GE;
+ return (RHS && isNullConstant(RHS)) ? AArch64CC::PL : AArch64CC::GE;
case ISD::SETLT:
- return AArch64CC::LT;
+ return (RHS && isNullConstant(RHS)) ? AArch64CC::MI : AArch64CC::LT;
case ISD::SETLE:
return AArch64CC::LE;
case ISD::SETUGT:
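
The GE to PL and LT to MI relaxation works because a compare against zero cannot set the signed-overflow flag, so the sign flag alone decides the outcome. A quick exhaustive illustration over 8-bit values (standalone, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int V = -128; V <= 127; ++V) {
        int8_t X = static_cast<int8_t>(V);
        bool N = (static_cast<uint8_t>(X) >> 7) & 1; // N flag of a compare with #0
        assert((X >= 0) == !N); // GE against zero is exactly PL (N clear)
        assert((X < 0) == N);   // LT against zero is exactly MI (N set)
      }
      return 0;
    }
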
@@ -3492,6 +3548,13 @@ bool isLegalCmpImmed(APInt C) {
return isLegalArithImmed(C.abs().getZExtValue());
}
+unsigned numberOfInstrToLoadImm(APInt C) {
+ uint64_t Imm = C.getZExtValue();
+ SmallVector<AArch64_IMM::ImmInsnModel> Insn;
+ AArch64_IMM::expandMOVImm(Imm, 32, Insn);
+ return Insn.size();
+}
+
static bool isSafeSignedCMN(SDValue Op, SelectionDAG &DAG) {
// 0 - INT_MIN sign wraps, so no signed wrap means cmn is safe.
if (Op->getFlags().hasNoSignedWrap())
@@ -3782,7 +3845,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
SDLoc DL(Val);
// Determine OutCC and handle FP special case.
if (isInteger) {
- OutCC = changeIntCCToAArch64CC(CC);
+ OutCC = changeIntCCToAArch64CC(CC, RHS);
} else {
assert(LHS.getValueType().isFloatingPoint());
AArch64CC::CondCode ExtraCC;
@@ -3961,6 +4024,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
// CC has already been adjusted.
RHS = DAG.getConstant(0, DL, VT);
} else if (!isLegalCmpImmed(C)) {
+ unsigned NumImmForC = numberOfInstrToLoadImm(C);
// Constant does not fit, try adjusting it by one?
switch (CC) {
default:
@@ -3969,43 +4033,49 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
case ISD::SETGE:
if (!C.isMinSignedValue()) {
APInt CMinusOne = C - 1;
- if (isLegalCmpImmed(CMinusOne)) {
+ if (isLegalCmpImmed(CMinusOne) ||
+ (NumImmForC > numberOfInstrToLoadImm(CMinusOne))) {
CC = (CC == ISD::SETLT) ? ISD::SETLE : ISD::SETGT;
RHS = DAG.getConstant(CMinusOne, DL, VT);
}
}
break;
case ISD::SETULT:
- case ISD::SETUGE:
- if (!C.isZero()) {
- APInt CMinusOne = C - 1;
- if (isLegalCmpImmed(CMinusOne)) {
- CC = (CC == ISD::SETULT) ? ISD::SETULE : ISD::SETUGT;
- RHS = DAG.getConstant(CMinusOne, DL, VT);
- }
+ case ISD::SETUGE: {
+      // C cannot be 0 here, since 0 is a legal compare immediate.
+ assert(!C.isZero() && "C should not be zero here");
+ APInt CMinusOne = C - 1;
+ if (isLegalCmpImmed(CMinusOne) ||
+ (NumImmForC > numberOfInstrToLoadImm(CMinusOne))) {
+ CC = (CC == ISD::SETULT) ? ISD::SETULE : ISD::SETUGT;
+ RHS = DAG.getConstant(CMinusOne, DL, VT);
}
break;
+ }
case ISD::SETLE:
case ISD::SETGT:
if (!C.isMaxSignedValue()) {
APInt CPlusOne = C + 1;
- if (isLegalCmpImmed(CPlusOne)) {
+ if (isLegalCmpImmed(CPlusOne) ||
+ (NumImmForC > numberOfInstrToLoadImm(CPlusOne))) {
CC = (CC == ISD::SETLE) ? ISD::SETLT : ISD::SETGE;
RHS = DAG.getConstant(CPlusOne, DL, VT);
}
}
break;
case ISD::SETULE:
- case ISD::SETUGT:
+ case ISD::SETUGT: {
if (!C.isAllOnes()) {
APInt CPlusOne = C + 1;
- if (isLegalCmpImmed(CPlusOne)) {
+ if (isLegalCmpImmed(CPlusOne) ||
+ (NumImmForC > numberOfInstrToLoadImm(CPlusOne))) {
CC = (CC == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE;
RHS = DAG.getConstant(CPlusOne, DL, VT);
}
}
break;
}
+ }
}
}
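
Besides the existing C plus/minus one encodability check, the new numberOfInstrToLoadImm heuristic accepts the adjusted constant whenever it is strictly cheaper to materialise. For example, `x u< 0x1000001` is rewritten as `x u<= 0x1000000`, trading a MOVZ+MOVK pair for a single MOVZ. A rough standalone cost model of that idea (it ignores MOVN and logical-immediate forms, unlike AArch64_IMM::expandMOVImm):

    #include <cstdint>
    #include <cstdio>

    // Count the non-zero 16-bit chunks a MOVZ/MOVK sequence would need. This is
    // only a rough model: the real expansion also considers MOVN and logical
    // immediates, so its counts can be lower.
    static unsigned movChunks(uint32_t Imm) {
      unsigned N = 0;
      for (int Shift = 0; Shift < 32; Shift += 16)
        if ((Imm >> Shift) & 0xFFFF)
          ++N;
      return N ? N : 1; // materialising 0 still takes one instruction
    }

    int main() {
      // x u< 0x1000001 -> x u<= 0x1000000: the adjusted constant loads in one
      // instruction instead of two.
      std::printf("%u vs %u\n", movChunks(0x1000001u), movChunks(0x1000000u));
      return 0;
    }
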
@@ -4079,7 +4149,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
if (!Cmp) {
Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
- AArch64CC = changeIntCCToAArch64CC(CC);
+ AArch64CC = changeIntCCToAArch64CC(CC, RHS);
}
AArch64cc = getCondCode(DAG, AArch64CC);
return Cmp;
@@ -4865,6 +4935,18 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
if (DstWidth < SatWidth)
return SDValue();
+ if (SrcVT == MVT::f16 && SatVT == MVT::i16 && DstVT == MVT::i32) {
+ if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) {
+ SDValue CVTf32 =
+ DAG.getNode(AArch64ISD::FCVTZS_HALF, DL, MVT::f32, SrcVal);
+ SDValue Bitcast = DAG.getBitcast(DstVT, CVTf32);
+ return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, DstVT, Bitcast,
+ DAG.getValueType(SatVT));
+ }
+ SDValue CVTf32 = DAG.getNode(AArch64ISD::FCVTZU_HALF, DL, MVT::f32, SrcVal);
+ return DAG.getBitcast(DstVT, CVTf32);
+ }
+
SDValue NativeCvt =
DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal, DAG.getValueType(DstVT));
SDValue Sat;
@@ -5174,13 +5256,7 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
ArgListTy Args;
- ArgListEntry Entry;
-
- Entry.Node = Arg;
- Entry.Ty = ArgTy;
- Entry.IsSExt = false;
- Entry.IsZExt = false;
- Args.push_back(Entry);
+ Args.emplace_back(Arg, ArgTy);
RTLIB::Libcall LC = ArgVT == MVT::f64 ? RTLIB::SINCOS_STRET_F64
: RTLIB::SINCOS_STRET_F32;
@@ -5711,15 +5787,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
SDValue Chain, SDLoc DL,
EVT VT) const {
- SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
+ RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
+ SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
getPointerTy(DAG.getDataLayout()));
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
Type *RetTy = StructType::get(Int64Ty, Int64Ty);
TargetLowering::CallLoweringInfo CLI(DAG);
ArgListTy Args;
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
- CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
- RetTy, Callee, std::move(Args));
+ getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
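
The trailing AND with 1 relies on the __arm_sme_state convention the lowering already assumed: the routine returns two 64-bit values, and bit 0 of the first reflects PSTATE.SM at the call site. A small sketch of that shape (struct and helper names are assumptions for illustration):

    #include <cstdint>
    #include <cstdio>

    struct SMEStateResult {
      uint64_t X0; // bit 0 reflects PSTATE.SM at the call site
      uint64_t X1;
    };

    static bool inStreamingMode(const SMEStateResult &S) { return S.X0 & 1; }

    int main() {
      SMEStateResult S{1, 0}; // example: PSTATE.SM was set
      std::printf("streaming=%d\n", inStreamingMode(S) ? 1 : 0);
      return 0;
    }
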
@@ -7886,8 +7962,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
else if (ActualMVT == MVT::i16)
ValVT = MVT::i16;
}
- bool Res =
- AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo);
+ bool Res = AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags,
+ Ins[i].OrigTy, CCInfo);
assert(!Res && "Call operand has unhandled type");
(void)Res;
}
@@ -8132,19 +8208,26 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
}
assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
+ if (Attrs.hasStreamingCompatibleInterface()) {
+ SDValue EntryPStateSM =
+ DAG.getNode(AArch64ISD::ENTRY_PSTATE_SM, DL,
+ DAG.getVTList(MVT::i64, MVT::Other), {Chain});
+
+ // Copy the value to a virtual register, and save that in FuncInfo.
+ Register EntryPStateSMReg =
+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+ Chain = DAG.getCopyToReg(EntryPStateSM.getValue(1), DL, EntryPStateSMReg,
+ EntryPStateSM);
+ FuncInfo->setPStateSMReg(EntryPStateSMReg);
+ }
+
// Insert the SMSTART if this is a locally streaming function and
// make sure it is Glued to the last CopyFromReg value.
if (IsLocallyStreaming) {
- SDValue PStateSM;
- if (Attrs.hasStreamingCompatibleInterface()) {
- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
- Register Reg = MF.getRegInfo().createVirtualRegister(
- getRegClassFor(PStateSM.getValueType().getSimpleVT()));
- FuncInfo->setPStateSMReg(Reg);
- Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
+ if (Attrs.hasStreamingCompatibleInterface())
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
- AArch64SME::IfCallerIsNonStreaming, PStateSM);
- } else
+ AArch64SME::IfCallerIsNonStreaming);
+ else
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
AArch64SME::Always);
@@ -8244,53 +8327,57 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
if (Subtarget->hasCustomCallingConv())
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
- // Create a 16 Byte TPIDR2 object. The dynamic buffer
- // will be expanded and stored in the static object later using a pseudonode.
- if (Attrs.hasZAState()) {
- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
- TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
- SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
- DAG.getConstant(1, DL, MVT::i32));
-
- SDValue Buffer;
- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
- Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
- DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
- } else {
- SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
- Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
- DAG.getVTList(MVT::i64, MVT::Other),
- {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
- MFI.CreateVariableSizedObject(Align(16), nullptr);
- }
- Chain = DAG.getNode(
- AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
- } else if (Attrs.hasAgnosticZAInterface()) {
- // Call __arm_sme_state_size().
- SDValue BufferSize =
- DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
- DAG.getVTList(MVT::i64, MVT::Other), Chain);
- Chain = BufferSize.getValue(1);
-
- SDValue Buffer;
- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
- Buffer =
- DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
- DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
- } else {
- // Allocate space dynamically.
- Buffer = DAG.getNode(
- ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
- {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
- MFI.CreateVariableSizedObject(Align(16), nullptr);
+ if (!getTM().useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
+ // Old SME ABI lowering (deprecated):
+ // Create a 16 Byte TPIDR2 object. The dynamic buffer
+ // will be expanded and stored in the static object later using a
+ // pseudonode.
+ if (Attrs.hasZAState()) {
+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+ TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+ DAG.getConstant(1, DL, MVT::i32));
+ SDValue Buffer;
+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+ Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
+ } else {
+ SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+ Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
+ DAG.getVTList(MVT::i64, MVT::Other),
+ {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
+ }
+ SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+ DAG.getConstant(1, DL, MVT::i32));
+ Chain = DAG.getNode(
+ AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0),
+ /*Num save slices*/ NumZaSaveSlices});
+ } else if (Attrs.hasAgnosticZAInterface()) {
+ // Call __arm_sme_state_size().
+ SDValue BufferSize =
+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
+ Chain = BufferSize.getValue(1);
+ SDValue Buffer;
+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+ Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+ DAG.getVTList(MVT::i64, MVT::Other),
+ {Chain, BufferSize});
+ } else {
+ // Allocate space dynamically.
+ Buffer = DAG.getNode(
+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
+ }
+ // Copy the value to a virtual register, and save that in FuncInfo.
+ Register BufferPtr =
+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}
-
- // Copy the value to a virtual register, and save that in FuncInfo.
- Register BufferPtr =
- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
- FuncInfo->setSMESaveBufferAddr(BufferPtr);
- Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}
if (CallConv == CallingConv::PreserveNone) {
@@ -8307,6 +8394,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
}
}
+ if (getTM().useNewSMEABILowering()) {
+ // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
+ if (Attrs.isNewZT0())
+ Chain = DAG.getNode(
+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+ DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
+ DAG.getTargetConstant(0, DL, MVT::i32));
+ }
+
return Chain;
}
@@ -8537,7 +8633,7 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
if (IsCalleeWin64) {
UseVarArgCC = true;
} else {
- UseVarArgCC = !Outs[i].IsFixed;
+ UseVarArgCC = ArgFlags.isVarArg();
}
}
@@ -8557,19 +8653,20 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
// FIXME: CCAssignFnForCall should be called once, for the call and not per
// argument. This logic should exactly mirror LowerFormalArguments.
CCAssignFn *AssignFn = TLI.CCAssignFnForCall(CalleeCC, UseVarArgCC);
- bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
+ bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags,
+ Outs[i].OrigTy, CCInfo);
assert(!Res && "Call operand has unhandled type");
(void)Res;
}
}
static SMECallAttrs
-getSMECallAttrs(const Function &Caller,
+getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI,
const TargetLowering::CallLoweringInfo &CLI) {
if (CLI.CB)
- return SMECallAttrs(*CLI.CB);
+ return SMECallAttrs(*CLI.CB, &TLI);
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
- return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
+ return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
}
@@ -8591,7 +8688,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
// SME Streaming functions are not eligible for TCO as they may require
// the streaming mode or ZA to be restored after returning from the call.
- SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState() ||
CallAttrs.caller().hasStreamingBody())
@@ -8834,8 +8931,7 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
bool Enable, SDValue Chain,
SDValue InGlue,
- unsigned Condition,
- SDValue PStateSM) const {
+ unsigned Condition) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);
@@ -8847,9 +8943,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
SmallVector<SDValue> Ops = {Chain, MSROp};
unsigned Opcode;
if (Condition != AArch64SME::Always) {
+ FuncInfo->setPStateSMRegUsed(true);
+ Register PStateReg = FuncInfo->getPStateSMReg();
+ assert(PStateReg.isValid() && "PStateSM Register is invalid");
+ SDValue PStateSM =
+ DAG.getCopyFromReg(Chain, DL, PStateReg, MVT::i64, InGlue);
+ // Use chain and glue from the CopyFromReg.
+ Ops[0] = PStateSM.getValue(1);
+ InGlue = PStateSM.getValue(2);
SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
- assert(PStateSM && "PStateSM should be defined");
Ops.push_back(ConditionOp);
Ops.push_back(PStateSM);
} else {
@@ -8871,22 +8974,19 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setSMESaveBufferUsed();
-
TargetLowering::ArgListTy Args;
- TargetLowering::ArgListEntry Entry;
- Entry.Ty = PointerType::getUnqual(*DAG.getContext());
- Entry.Node =
- DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
- Args.push_back(Entry);
-
- SDValue Callee =
- DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
- TLI.getPointerTy(DAG.getDataLayout()));
+ Args.emplace_back(
+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
+ PointerType::getUnqual(*DAG.getContext()));
+
+ RTLIB::Libcall LC =
+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
+ SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
+ TLI.getPointerTy(DAG.getDataLayout()));
auto *RetTy = Type::getVoidTy(*DAG.getContext());
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
- CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
- Callee, std::move(Args));
+ TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
return TLI.LowerCallTo(CLI).second;
}
@@ -8982,7 +9082,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
unsigned NumArgs = Outs.size();
for (unsigned i = 0; i != NumArgs; ++i) {
- if (!Outs[i].IsFixed && Outs[i].VT.isScalableVector())
+ if (Outs[i].Flags.isVarArg() && Outs[i].VT.isScalableVector())
report_fatal_error("Passing SVE types to variadic functions is "
"currently not supported");
}
@@ -9014,14 +9114,28 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
CallConv = CallingConv::AArch64_SVE_VectorCall;
}
+ // Determine whether we need any streaming mode changes.
+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
+ bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
+ bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
+ auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
+ // TODO: Handle agnostic ZA functions.
+ if (!UseNewSMEABILowering || IsAgnosticZAFunction)
+ return std::nullopt;
+ if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
+ return std::nullopt;
+ return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
+ : AArch64ISD::INOUT_ZA_USE;
+ }();
+
if (IsTailCall) {
// Check if it's really possible to do a tail call.
IsTailCall = isEligibleForTailCallOptimization(CLI);
// A sibling call is one where we're under the usual C ABI and not planning
// to change that but can still do a tail call:
- if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
- CallConv != CallingConv::SwiftTail)
+ if (!ZAMarkerNode && !TailCallOpt && IsTailCall &&
+ CallConv != CallingConv::Tail && CallConv != CallingConv::SwiftTail)
IsSibCall = true;
if (IsTailCall)
@@ -9073,9 +9187,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
}
- // Determine whether we need any streaming mode changes.
- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
-
auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9089,22 +9200,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return R;
};
- bool RequiresLazySave = CallAttrs.requiresLazySave();
+ bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
- const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
- MachinePointerInfo MPI =
- MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex);
+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
TPIDR2.FrameIndex,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
- SDValue NumZaSaveSlicesAddr =
- DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
- DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
- SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
- DAG.getConstant(1, DL, MVT::i32));
- Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
- MPI, MVT::i16);
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
@@ -9124,15 +9226,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
/*IsSave=*/true);
}
- SDValue PStateSM;
bool RequiresSMChange = CallAttrs.requiresSMChange();
if (RequiresSMChange) {
- if (CallAttrs.caller().hasStreamingInterfaceOrBody())
- PStateSM = DAG.getConstant(1, DL, MVT::i64);
- else if (CallAttrs.caller().hasNonStreamingInterface())
- PStateSM = DAG.getConstant(0, DL, MVT::i64);
- else
- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
OptimizationRemarkEmitter ORE(&MF.getFunction());
ORE.emit([&]() {
auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -9171,10 +9266,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
- // Adjust the stack pointer for the new arguments...
+ // Adjust the stack pointer for the new arguments... and mark ZA uses.
// These operations are automatically eliminated by the prolog/epilog pass
- if (!IsSibCall)
+ assert((!IsSibCall || !ZAMarkerNode) && "ZA markers require CALLSEQ_START");
+ if (!IsSibCall) {
Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
+ if (ZAMarkerNode) {
+      // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to; simply
+      // using a chain can result in incorrect scheduling. The markers refer
+      // to the position just before the CALLSEQ_START (though they are placed
+      // after it, as CALLSEQ_START has no in-glue).
+ Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
+ {Chain, Chain.getValue(1)});
+ }
+ }
SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
getPointerTy(DAG.getDataLayout()));
@@ -9441,17 +9546,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue InGlue;
if (RequiresSMChange) {
- if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
- Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
- DAG.getVTList(MVT::Other, MVT::Glue), Chain);
- InGlue = Chain.getValue(1);
- }
-
- SDValue NewChain = changeStreamingMode(
- DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
- getSMToggleCondition(CallAttrs), PStateSM);
- Chain = NewChain.getValue(0);
- InGlue = NewChain.getValue(1);
+ Chain =
+ changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
+ Chain, InGlue, getSMToggleCondition(CallAttrs));
+ InGlue = Chain.getValue(1);
}
// Build a sequence of copy-to-reg nodes chained together with token chain
@@ -9633,20 +9731,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
InGlue = Result.getValue(Result->getNumValues() - 1);
if (RequiresSMChange) {
- assert(PStateSM && "Expected a PStateSM to be set");
Result = changeStreamingMode(
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
- getSMToggleCondition(CallAttrs), PStateSM);
-
- if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
- InGlue = Result.getValue(1);
- Result =
- DAG.getNode(AArch64ISD::VG_RESTORE, DL,
- DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
- }
+ getSMToggleCondition(CallAttrs));
}
- if (CallAttrs.requiresEnablingZAAfterCall())
+ if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9659,15 +9749,15 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresLazySave) {
// Conditionally restore the lazy save using a pseudo node.
+ RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(
- TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+ TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
- "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+ getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
SDValue TPIDR2_EL0 = DAG.getNode(
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
// Copy the address of the TPIDR2 block into X0 before 'calling' the
// RESTORE_ZA pseudo.
SDValue Glue;
@@ -9679,7 +9769,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
{Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
RestoreRoutine, RegMask, Result.getValue(1)});
-
// Finally reset the TPIDR2_EL0 register to 0.
Result = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
@@ -9802,14 +9891,11 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
- if (FuncAttrs.hasStreamingCompatibleInterface()) {
- Register Reg = FuncInfo->getPStateSMReg();
- assert(Reg.isValid() && "PStateSM Register is invalid");
- SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
+ if (FuncAttrs.hasStreamingCompatibleInterface())
Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
/*Glue*/ SDValue(),
- AArch64SME::IfCallerIsNonStreaming, PStateSM);
- } else
+ AArch64SME::IfCallerIsNonStreaming);
+ else
Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
/*Glue*/ SDValue(), AArch64SME::Always);
Glue = Chain.getValue(1);
@@ -11390,13 +11476,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(
// select_cc lhs, rhs, sub(rhs, lhs), sub(lhs, rhs), cc ->
// select_cc lhs, rhs, neg(sub(lhs, rhs)), sub(lhs, rhs), cc
// The second forms can be matched into subs+cneg.
+ // NOTE: Drop poison generating flags from the negated operand to avoid
+ // inadvertently propagating poison after the canonicalisation.
if (TVal.getOpcode() == ISD::SUB && FVal.getOpcode() == ISD::SUB) {
if (TVal.getOperand(0) == LHS && TVal.getOperand(1) == RHS &&
- FVal.getOperand(0) == RHS && FVal.getOperand(1) == LHS)
+ FVal.getOperand(0) == RHS && FVal.getOperand(1) == LHS) {
+ TVal->dropFlags(SDNodeFlags::PoisonGeneratingFlags);
FVal = DAG.getNegative(TVal, DL, TVal.getValueType());
- else if (TVal.getOperand(0) == RHS && TVal.getOperand(1) == LHS &&
- FVal.getOperand(0) == LHS && FVal.getOperand(1) == RHS)
+ } else if (TVal.getOperand(0) == RHS && TVal.getOperand(1) == LHS &&
+ FVal.getOperand(0) == LHS && FVal.getOperand(1) == RHS) {
+ FVal->dropFlags(SDNodeFlags::PoisonGeneratingFlags);
TVal = DAG.getNegative(FVal, DL, FVal.getValueType());
+ }
}
unsigned Opcode = AArch64ISD::CSEL;
@@ -13477,7 +13568,7 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
// Look for the first non-undef element.
const int *FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
- // Benefit form APInt to handle overflow when calculating expected element.
+ // Benefit from APInt to handle overflow when calculating expected element.
unsigned NumElts = VT.getVectorNumElements();
unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1, /*isSigned=*/false,
@@ -13485,7 +13576,7 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
// The following shuffle indices must be the successive elements after the
// first real element.
bool FoundWrongElt = std::any_of(FirstRealElt + 1, M.end(), [&](int Elt) {
- return Elt != ExpectedElt++ && Elt != -1;
+ return Elt != ExpectedElt++ && Elt >= 0;
});
if (FoundWrongElt)
return false;
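
Switching the test from `Elt != -1` to `Elt >= 0` treats every negative mask entry as "don't care", not just the -1 sentinel. A simplified standalone model of the check, using a plain modulo in place of the APInt wrap-around (the function name is illustrative):

    #include <cstdio>
    #include <vector>

    // After the first defined element, every defined element must continue the
    // sequence, wrapping modulo twice the vector length; negative entries are
    // ignored.
    static bool looksLikeExtMask(const std::vector<int> &M, unsigned NumElts) {
      unsigned I = 0;
      while (I < M.size() && M[I] < 0)
        ++I;
      if (I == M.size())
        return false;
      unsigned Expected = static_cast<unsigned>(M[I]) + 1;
      for (unsigned J = I + 1; J < M.size(); ++J, ++Expected)
        if (M[J] >= 0 && static_cast<unsigned>(M[J]) != Expected % (2 * NumElts))
          return false;
      return true;
    }

    int main() {
      std::printf("%d\n", looksLikeExtMask({2, -1, 4, 5}, 4)); // 1: an EXT by two
      return 0;
    }
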
@@ -14737,12 +14828,107 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
return ResultSLI;
}
+static SDValue tryLowerToBSL(SDValue N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+  assert(VT.isVector() && "Expected vector type in tryLowerToBSL");
+ SDLoc DL(N);
+ const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+
+ if (VT.isScalableVector() && !Subtarget.hasSVE2())
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ if (N0.getOpcode() != ISD::AND)
+ return SDValue();
+
+ SDValue N1 = N->getOperand(1);
+ if (N1.getOpcode() != ISD::AND)
+ return SDValue();
+
+ // InstCombine does (not (neg a)) => (add a -1).
+ // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
+ // Loop over all combinations of AND operands.
+ for (int i = 1; i >= 0; --i) {
+ for (int j = 1; j >= 0; --j) {
+ SDValue O0 = N0->getOperand(i);
+ SDValue O1 = N1->getOperand(j);
+ SDValue Sub, Add, SubSibling, AddSibling;
+
+ // Find a SUB and an ADD operand, one from each AND.
+ if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
+ Sub = O0;
+ Add = O1;
+ SubSibling = N0->getOperand(1 - i);
+ AddSibling = N1->getOperand(1 - j);
+ } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
+ Add = O0;
+ Sub = O1;
+ AddSibling = N0->getOperand(1 - i);
+ SubSibling = N1->getOperand(1 - j);
+ } else
+ continue;
+
+ if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode()))
+ continue;
+
+ // Constant ones is always righthand operand of the Add.
+ if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode()))
+ continue;
+
+ if (Sub.getOperand(1) != Add.getOperand(0))
+ continue;
+
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
+ }
+ }
+
+ // (or (and a b) (and (not a) c)) => (bsl a b c)
+ // We only have to look for constant vectors here since the general, variable
+ // case can be handled in TableGen.
+ unsigned Bits = VT.getScalarSizeInBits();
+ for (int i = 1; i >= 0; --i)
+ for (int j = 1; j >= 0; --j) {
+ APInt Val1, Val2;
+
+ if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) &&
+ ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) &&
+ ~Val1.trunc(Bits) == Val2.trunc(Bits)) {
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
+ N0->getOperand(1 - i), N1->getOperand(1 - j));
+ }
+ BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
+ BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
+ if (!BVN0 || !BVN1)
+ continue;
+
+ bool FoundMatch = true;
+ for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) {
+ ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
+ ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
+ if (!CN0 || !CN1 ||
+ CN0->getAPIntValue().trunc(Bits) !=
+ ~CN1->getAsAPIntVal().trunc(Bits)) {
+ FoundMatch = false;
+ break;
+ }
+ }
+ if (FoundMatch)
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
+ N0->getOperand(1 - i), N1->getOperand(1 - j));
+ }
+
+ return SDValue();
+}
+
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
SelectionDAG &DAG) const {
if (useSVEForFixedLengthVectorVT(Op.getValueType(),
!Subtarget->isNeonAvailable()))
return LowerToScalableOp(Op, DAG);
+ if (SDValue Res = tryLowerToBSL(Op, DAG))
+ return Res;
+
// Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2))
if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG))
return Res;
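
The BSP node produced here implements the bitwise-select identity (a & b) | (~a & c): bits of b where a is set, bits of c elsewhere. The constant-splat paths merely recognise two masks that are bitwise complements of each other, now compared after truncation to the element width. A scalar illustration (standalone, not part of the patch):

    #include <cassert>
    #include <cstdint>

    // Bitwise select: bits of B where the mask A is set, bits of C where it is not.
    static uint32_t bsl(uint32_t A, uint32_t B, uint32_t C) {
      return (A & B) | (~A & C);
    }

    int main() {
      uint32_t Val1 = 0x00ff00ffu; // splat constant in the first AND
      uint32_t Val2 = ~Val1;       // the second AND's constant is its complement
      uint32_t B = 0x12345678u, C = 0x9abcdef0u;
      // (b & Val1) | (c & Val2) is exactly a bitwise select on Val1.
      assert(((B & Val1) | (C & Val2)) == bsl(Val1, B, C));
      return 0;
    }
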
@@ -15772,6 +15958,7 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
isREVMask(M, EltSize, NumElts, 32) ||
isREVMask(M, EltSize, NumElts, 16) ||
isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
+ isSingletonEXTMask(M, VT, DummyUnsigned) ||
isTRNMask(M, NumElts, DummyUnsigned) ||
isUZPMask(M, NumElts, DummyUnsigned) ||
isZIPMask(M, NumElts, DummyUnsigned) ||
@@ -16284,9 +16471,8 @@ AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(SDValue Op,
Chain = SP.getValue(1);
SP = DAG.getNode(ISD::SUB, DL, MVT::i64, SP, Size);
if (Align)
- SP =
- DAG.getNode(ISD::AND, DL, VT, SP.getValue(0),
- DAG.getSignedConstant(-(uint64_t)Align->value(), DL, VT));
+ SP = DAG.getNode(ISD::AND, DL, VT, SP.getValue(0),
+ DAG.getSignedConstant(-Align->value(), DL, VT));
Chain = DAG.getCopyToReg(Chain, DL, AArch64::SP, SP);
SDValue Ops[2] = {SP, Chain};
return DAG.getMergeValues(Ops, DL);
@@ -16323,7 +16509,7 @@ AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(SDValue Op,
SP = DAG.getNode(ISD::SUB, DL, MVT::i64, SP, Size);
if (Align)
SP = DAG.getNode(ISD::AND, DL, VT, SP.getValue(0),
- DAG.getSignedConstant(-(uint64_t)Align->value(), DL, VT));
+ DAG.getSignedConstant(-Align->value(), DL, VT));
Chain = DAG.getCopyToReg(Chain, DL, AArch64::SP, SP);
Chain = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), DL);
@@ -16351,7 +16537,7 @@ AArch64TargetLowering::LowerInlineDYNAMIC_STACKALLOC(SDValue Op,
SP = DAG.getNode(ISD::SUB, DL, MVT::i64, SP, Size);
if (Align)
SP = DAG.getNode(ISD::AND, DL, VT, SP.getValue(0),
- DAG.getSignedConstant(-(uint64_t)Align->value(), DL, VT));
+ DAG.getSignedConstant(-Align->value(), DL, VT));
// Set the real SP to the new value with a probing loop.
Chain = DAG.getNode(AArch64ISD::PROBED_ALLOCA, DL, MVT::Other, Chain, SP);
@@ -17254,7 +17440,7 @@ static Function *getStructuredStoreFunction(Module *M, unsigned Factor,
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
bool AArch64TargetLowering::lowerInterleavedLoad(
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
- ArrayRef<unsigned> Indices, unsigned Factor) const {
+ ArrayRef<unsigned> Indices, unsigned Factor, const APInt &GapMask) const {
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
"Invalid interleave factor");
assert(!Shuffles.empty() && "Empty shufflevector input");
@@ -17264,7 +17450,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
auto *LI = dyn_cast<LoadInst>(Load);
if (!LI)
return false;
- assert(!Mask && "Unexpected mask on a load");
+ assert(!Mask && GapMask.popcount() == Factor && "Unexpected mask on a load");
const DataLayout &DL = LI->getDataLayout();
@@ -17442,14 +17628,16 @@ bool hasNearbyPairedStore(Iter It, Iter End, Value *Ptr, const DataLayout &DL) {
bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store,
Value *LaneMask,
ShuffleVectorInst *SVI,
- unsigned Factor) const {
+ unsigned Factor,
+ const APInt &GapMask) const {
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
"Invalid interleave factor");
auto *SI = dyn_cast<StoreInst>(Store);
if (!SI)
return false;
- assert(!LaneMask && "Unexpected mask on store");
+ assert(!LaneMask && GapMask.popcount() == Factor &&
+ "Unexpected mask on store");
auto *VecTy = cast<FixedVectorType>(SVI->getType());
assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store");
@@ -17987,7 +18175,8 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
case MVT::f64:
return true;
case MVT::bf16:
- return VT.isScalableVector() && Subtarget->hasSVEB16B16();
+ return VT.isScalableVector() && Subtarget->hasSVEB16B16() &&
+ Subtarget->isNonStreamingSVEorSME2Available();
default:
break;
}
@@ -18151,7 +18340,7 @@ bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, BitSize))
return true;
- if ((int64_t)Val < 0)
+ if (Val < 0)
Val = ~Val;
if (BitSize == 32)
Val &= (1LL << 32) - 1;
@@ -19414,106 +19603,6 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
return FixConv;
}
-static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
- const AArch64TargetLowering &TLI) {
- EVT VT = N->getValueType(0);
- SelectionDAG &DAG = DCI.DAG;
- SDLoc DL(N);
- const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
-
- if (!VT.isVector())
- return SDValue();
-
- if (VT.isScalableVector() && !Subtarget.hasSVE2())
- return SDValue();
-
- if (VT.isFixedLengthVector() &&
- (!Subtarget.isNeonAvailable() || TLI.useSVEForFixedLengthVectorVT(VT)))
- return SDValue();
-
- SDValue N0 = N->getOperand(0);
- if (N0.getOpcode() != ISD::AND)
- return SDValue();
-
- SDValue N1 = N->getOperand(1);
- if (N1.getOpcode() != ISD::AND)
- return SDValue();
-
- // InstCombine does (not (neg a)) => (add a -1).
- // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
- // Loop over all combinations of AND operands.
- for (int i = 1; i >= 0; --i) {
- for (int j = 1; j >= 0; --j) {
- SDValue O0 = N0->getOperand(i);
- SDValue O1 = N1->getOperand(j);
- SDValue Sub, Add, SubSibling, AddSibling;
-
- // Find a SUB and an ADD operand, one from each AND.
- if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
- Sub = O0;
- Add = O1;
- SubSibling = N0->getOperand(1 - i);
- AddSibling = N1->getOperand(1 - j);
- } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
- Add = O0;
- Sub = O1;
- AddSibling = N0->getOperand(1 - i);
- SubSibling = N1->getOperand(1 - j);
- } else
- continue;
-
- if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode()))
- continue;
-
- // Constant ones is always righthand operand of the Add.
- if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode()))
- continue;
-
- if (Sub.getOperand(1) != Add.getOperand(0))
- continue;
-
- return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
- }
- }
-
- // (or (and a b) (and (not a) c)) => (bsl a b c)
- // We only have to look for constant vectors here since the general, variable
- // case can be handled in TableGen.
- unsigned Bits = VT.getScalarSizeInBits();
- uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1);
- for (int i = 1; i >= 0; --i)
- for (int j = 1; j >= 0; --j) {
- APInt Val1, Val2;
-
- if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) &&
- ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) &&
- (BitMask & ~Val1.getZExtValue()) == Val2.getZExtValue()) {
- return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
- N0->getOperand(1 - i), N1->getOperand(1 - j));
- }
- BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
- BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
- if (!BVN0 || !BVN1)
- continue;
-
- bool FoundMatch = true;
- for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) {
- ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
- ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
- if (!CN0 || !CN1 ||
- CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) {
- FoundMatch = false;
- break;
- }
- }
- if (FoundMatch)
- return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
- N0->getOperand(1 - i), N1->getOperand(1 - j));
- }
-
- return SDValue();
-}
-
// Given a tree of and/or(csel(0, 1, cc0), csel(0, 1, cc1)), we may be able to
// convert to csel(ccmp(.., cc0)), depending on cc1:
@@ -19595,17 +19684,10 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget,
const AArch64TargetLowering &TLI) {
SelectionDAG &DAG = DCI.DAG;
- EVT VT = N->getValueType(0);
if (SDValue R = performANDORCSELCombine(N, DAG))
return R;
- if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
- return SDValue();
-
- if (SDValue Res = tryCombineToBSL(N, DCI, TLI))
- return Res;
-
return SDValue();
}
@@ -22107,6 +22189,17 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
Zero);
}
+static SDValue tryCombineNeonFcvtFP16ToI16(SDNode *N, unsigned Opcode,
+ SelectionDAG &DAG) {
+ if (N->getValueType(0) != MVT::i16)
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue CVT = DAG.getNode(Opcode, DL, MVT::f32, N->getOperand(1));
+ SDValue Bitcast = DAG.getBitcast(MVT::i32, CVT);
+ return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Bitcast);
+}
+
// If a merged operation has no inactive lanes we can relax it to a predicated
// or unpredicated operation, which potentially allows better isel (perhaps
// using immediate forms) or relaxing register reuse requirements.
@@ -22360,6 +22453,26 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_neon_uabd:
return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_neon_fcvtzs:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTZS_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtzu:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTZU_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtas:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTAS_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtau:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTAU_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtms:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTMS_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtmu:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTMU_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtns:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTNS_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtnu:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTNU_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtps:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTPS_HALF, DAG);
+ case Intrinsic::aarch64_neon_fcvtpu:
+ return tryCombineNeonFcvtFP16ToI16(N, AArch64ISD::FCVTPU_HALF, DAG);
case Intrinsic::aarch64_crc32b:
case Intrinsic::aarch64_crc32cb:
return tryCombineCRC32(0xff, N, DAG);
@@ -25450,6 +25563,29 @@ static SDValue performCSELCombine(SDNode *N,
}
}
+ // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
+ // use overflow flags, to avoid the comparison with zero. In case of success,
+ // this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
+ // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
+ // nodes with their SUBS equivalent as is already done for other flag-setting
+ // operators, in which case doing the replacement here becomes redundant.
+ if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
+ isNullConstant(Cond.getOperand(1))) {
+ SDValue Sub = Cond.getOperand(0);
+ AArch64CC::CondCode CC =
+ static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
+ if (Sub.getOpcode() == ISD::SUB &&
+ (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
+ CC == AArch64CC::PL)) {
+ SDLoc DL(N);
+ SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
+ Sub.getOperand(0), Sub.getOperand(1));
+ DCI.CombineTo(Sub.getNode(), Subs);
+ DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
+ return SDValue(N, 0);
+ }
+ }
+
// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
return CondLast;
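
This fold is restricted to EQ, NE, MI and PL because those conditions read only the Z and N flags, which are functions of the subtraction result itself, so SUBS(x, y) sets them exactly as SUBS(x - y, 0) would; conditions that involve the carry or overflow flags are left alone, since those can differ between the two forms. An exhaustive 8-bit check of the Z/N claim (standalone illustration):

    #include <cassert>
    #include <cstdint>

    struct NZ { bool N, Z; };

    // N and Z flags of an 8-bit subtraction.
    static NZ flagsOfSub(uint8_t A, uint8_t B) {
      uint8_t R = static_cast<uint8_t>(A - B);
      return {static_cast<bool>(R >> 7), R == 0};
    }

    int main() {
      for (int X = 0; X < 256; ++X)
        for (int Y = 0; Y < 256; ++Y) {
          NZ Direct = flagsOfSub(static_cast<uint8_t>(X), static_cast<uint8_t>(Y));
          uint8_t Diff = static_cast<uint8_t>(X - Y);
          NZ ViaZero = flagsOfSub(Diff, 0); // compare the SUB result with zero
          assert(Direct.N == ViaZero.N && Direct.Z == ViaZero.Z);
        }
      return 0;
    }
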
@@ -28166,6 +28302,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
case Intrinsic::aarch64_sme_in_streaming_mode: {
SDLoc DL(N);
SDValue Chain = DAG.getEntryNode();
+
SDValue RuntimePStateSM =
getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
Results.push_back(
@@ -28609,14 +28746,20 @@ Value *AArch64TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const {
void AArch64TargetLowering::insertSSPDeclarations(Module &M) const {
// MSVC CRT provides functionalities for stack protection.
- if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) {
+ RTLIB::LibcallImpl SecurityCheckCookieLibcall =
+ getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE);
+
+ RTLIB::LibcallImpl SecurityCookieVar =
+ getLibcallImpl(RTLIB::STACK_CHECK_GUARD);
+ if (SecurityCheckCookieLibcall != RTLIB::Unsupported &&
+ SecurityCookieVar != RTLIB::Unsupported) {
// MSVC CRT has a global variable holding security cookie.
- M.getOrInsertGlobal("__security_cookie",
+ M.getOrInsertGlobal(getLibcallImplName(SecurityCookieVar),
PointerType::getUnqual(M.getContext()));
// MSVC CRT has a function to validate security cookie.
FunctionCallee SecurityCheckCookie =
- M.getOrInsertFunction(Subtarget->getSecurityCheckCookieName(),
+ M.getOrInsertFunction(getLibcallImplName(SecurityCheckCookieLibcall),
Type::getVoidTy(M.getContext()),
PointerType::getUnqual(M.getContext()));
if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) {
@@ -28628,17 +28771,12 @@ void AArch64TargetLowering::insertSSPDeclarations(Module &M) const {
TargetLowering::insertSSPDeclarations(M);
}
-Value *AArch64TargetLowering::getSDagStackGuard(const Module &M) const {
- // MSVC CRT has a global variable holding security cookie.
- if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
- return M.getGlobalVariable("__security_cookie");
- return TargetLowering::getSDagStackGuard(M);
-}
-
Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const {
// MSVC CRT has a function to validate security cookie.
- if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
- return M.getFunction(Subtarget->getSecurityCheckCookieName());
+ RTLIB::LibcallImpl SecurityCheckCookieLibcall =
+ getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE);
+ if (SecurityCheckCookieLibcall != RTLIB::Unsupported)
+ return M.getFunction(getLibcallImplName(SecurityCheckCookieLibcall));
return TargetLowering::getSSPStackGuardCheck(M);
}
@@ -28972,7 +29110,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
// Checks to allow the use of SME instructions
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
- auto CallAttrs = SMECallAttrs(*Base);
+ auto CallAttrs = SMECallAttrs(*Base, this);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState())
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ea63edd8..4673836 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -23,6 +23,8 @@
namespace llvm {
+class AArch64TargetMachine;
+
namespace AArch64 {
/// Possible values of current rounding mode, which is specified in bits
/// 23:22 of FPCR.
@@ -64,6 +66,8 @@ public:
explicit AArch64TargetLowering(const TargetMachine &TM,
const AArch64Subtarget &STI);
+ const AArch64TargetMachine &getTM() const;
+
/// Control the following reassociation of operands: (op (op x, c1), y) -> (op
/// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
@@ -173,6 +177,10 @@ public:
MachineBasicBlock *EmitZTInstr(MachineInstr &MI, MachineBasicBlock *BB,
unsigned Opcode, bool Op0IsDef) const;
MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const;
+
+ // Note: The following group of functions are only used as part of the old SME
+ // ABI lowering. They will be removed once -aarch64-new-sme-abi=true is the
+ // default.
MachineBasicBlock *EmitInitTPIDR2Object(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
@@ -181,6 +189,8 @@ public:
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
MachineBasicBlock *BB) const;
+ MachineBasicBlock *EmitEntryPStateSM(MachineInstr &MI,
+ MachineBasicBlock *BB) const;
/// Replace (0, vreg) discriminator components with the operands of blend
/// or with (immediate, NoRegister) when possible.
@@ -220,11 +230,11 @@ public:
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
ArrayRef<ShuffleVectorInst *> Shuffles,
- ArrayRef<unsigned> Indices,
- unsigned Factor) const override;
+ ArrayRef<unsigned> Indices, unsigned Factor,
+ const APInt &GapMask) const override;
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
- ShuffleVectorInst *SVI,
- unsigned Factor) const override;
+ ShuffleVectorInst *SVI, unsigned Factor,
+ const APInt &GapMask) const override;
bool lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask,
IntrinsicInst *DI) const override;
@@ -344,7 +354,6 @@ public:
Value *getIRStackGuard(IRBuilderBase &IRB) const override;
void insertSSPDeclarations(Module &M) const override;
- Value *getSDagStackGuard(const Module &M) const override;
Function *getSSPStackGuardCheck(const Module &M) const override;
/// If the target has a standard location for the unsafe stack pointer,
@@ -523,8 +532,8 @@ public:
/// node. \p Condition should be one of the enum values from
/// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
- SDValue Chain, SDValue InGlue, unsigned Condition,
- SDValue PStateSM = SDValue()) const;
+ SDValue Chain, SDValue InGlue,
+ unsigned Condition) const;
bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
@@ -887,6 +896,10 @@ private:
bool shouldScalarizeBinop(SDValue VecOp) const override {
return VecOp.getOpcode() == ISD::SETCC;
}
+
+ bool hasMultipleConditionRegisters(EVT VT) const override {
+ return VT.isScalableVector();
+ }
};
namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index ba7cbcc..feff590 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -36,7 +36,12 @@ def DestructiveBinary : DestructiveInstTypeEnum<5>;
def DestructiveBinaryComm : DestructiveInstTypeEnum<6>;
def DestructiveBinaryCommWithRev : DestructiveInstTypeEnum<7>;
def DestructiveTernaryCommWithRev : DestructiveInstTypeEnum<8>;
-def DestructiveUnaryPassthru : DestructiveInstTypeEnum<9>;
+
+// 3 inputs unpredicated (reg1, reg2, imm).
+// Can be MOVPRFX'd iff reg1 == reg2.
+def Destructive2xRegImmUnpred : DestructiveInstTypeEnum<9>;
+
+def DestructiveUnaryPassthru : DestructiveInstTypeEnum<10>;
class FalseLanesEnum<bits<2> val> {
bits<2> Value = val;
@@ -3027,8 +3032,12 @@ class BaseAddSubEReg64<bit isSub, bit setFlags, RegisterClass dstRegtype,
// Aliases for register+register add/subtract.
class AddSubRegAlias<string asm, Instruction inst, RegisterClass dstRegtype,
- RegisterClass src1Regtype, RegisterClass src2Regtype,
- int shiftExt>
+ RegisterClass src1Regtype, dag src2>
+ : InstAlias<asm#"\t$dst, $src1, $src2",
+ (inst dstRegtype:$dst, src1Regtype:$src1, src2)>;
+class AddSubRegAlias64<string asm, Instruction inst, RegisterClass dstRegtype,
+ RegisterClass src1Regtype, RegisterClass src2Regtype,
+ int shiftExt>
: InstAlias<asm#"\t$dst, $src1, $src2",
(inst dstRegtype:$dst, src1Regtype:$src1, src2Regtype:$src2,
shiftExt)>;
@@ -3096,22 +3105,22 @@ multiclass AddSub<bit isSub, string mnemonic, string alias,
// Register/register aliases with no shift when SP is not used.
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"),
- GPR32, GPR32, GPR32, 0>;
+ GPR32, GPR32, (arith_shifted_reg32 GPR32:$src2, 0)>;
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"),
- GPR64, GPR64, GPR64, 0>;
+ GPR64, GPR64, (arith_shifted_reg64 GPR64:$src2, 0)>;
// Register/register aliases with no shift when either the destination or
// first source register is SP.
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrx"),
- GPR32sponly, GPR32sp, GPR32, 16>; // UXTW #0
+ GPR32sponly, GPR32sp,
+ (arith_extended_reg32_i32 GPR32:$src2, 16)>; // UXTW #0
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrx"),
- GPR32sp, GPR32sponly, GPR32, 16>; // UXTW #0
- def : AddSubRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Xrx64"),
- GPR64sponly, GPR64sp, GPR64, 24>; // UXTX #0
- def : AddSubRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Xrx64"),
- GPR64sp, GPR64sponly, GPR64, 24>; // UXTX #0
+ GPR32sp, GPR32sponly,
+ (arith_extended_reg32_i32 GPR32:$src2, 16)>; // UXTW #0
+ def : AddSubRegAlias64<mnemonic, !cast<Instruction>(NAME#"Xrx64"),
+ GPR64sponly, GPR64sp, GPR64, 24>; // UXTX #0
+ def : AddSubRegAlias64<mnemonic, !cast<Instruction>(NAME#"Xrx64"),
+ GPR64sp, GPR64sponly, GPR64, 24>; // UXTX #0
}
multiclass AddSubS<bit isSub, string mnemonic, SDNode OpNode, string cmp,
@@ -3175,15 +3184,19 @@ multiclass AddSubS<bit isSub, string mnemonic, SDNode OpNode, string cmp,
def : InstAlias<cmp#"\t$src, $imm", (!cast<Instruction>(NAME#"Xri")
XZR, GPR64sp:$src, addsub_shifted_imm64:$imm), 5>;
def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Wrx")
- WZR, GPR32sp:$src1, GPR32:$src2, arith_extend:$sh), 4>;
+ WZR, GPR32sp:$src1,
+ (arith_extended_reg32_i32 GPR32:$src2, arith_extend:$sh)), 4>;
def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Xrx")
- XZR, GPR64sp:$src1, GPR32:$src2, arith_extend:$sh), 4>;
+ XZR, GPR64sp:$src1,
+ (arith_extended_reg32_i64 GPR32:$src2, arith_extend:$sh)), 4>;
def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Xrx64")
XZR, GPR64sp:$src1, GPR64:$src2, arith_extendlsl64:$sh), 4>;
def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Wrs")
- WZR, GPR32:$src1, GPR32:$src2, arith_shift32:$sh), 4>;
+ WZR, GPR32:$src1,
+ (arith_shifted_reg32 GPR32:$src2, arith_shift32:$sh)), 4>;
def : InstAlias<cmp#"\t$src1, $src2$sh", (!cast<Instruction>(NAME#"Xrs")
- XZR, GPR64:$src1, GPR64:$src2, arith_shift64:$sh), 4>;
+ XZR, GPR64:$src1,
+ (arith_shifted_reg64 GPR64:$src2, arith_shift64:$sh)), 4>;
// Support negative immediates, e.g. cmp Rn, -imm -> cmn Rn, imm
def : InstSubst<cmpAlias#"\t$src, $imm", (!cast<Instruction>(NAME#"Wri")
@@ -3193,27 +3206,28 @@ multiclass AddSubS<bit isSub, string mnemonic, SDNode OpNode, string cmp,
// Compare shorthands
def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Wrs")
- WZR, GPR32:$src1, GPR32:$src2, 0), 5>;
+ WZR, GPR32:$src1, (arith_shifted_reg32 GPR32:$src2, 0)), 5>;
def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Xrs")
- XZR, GPR64:$src1, GPR64:$src2, 0), 5>;
+ XZR, GPR64:$src1, (arith_shifted_reg64 GPR64:$src2, 0)), 5>;
def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Wrx")
- WZR, GPR32sponly:$src1, GPR32:$src2, 16), 5>;
+ WZR, GPR32sponly:$src1,
+ (arith_extended_reg32_i32 GPR32:$src2, 16)), 5>;
def : InstAlias<cmp#"\t$src1, $src2", (!cast<Instruction>(NAME#"Xrx64")
XZR, GPR64sponly:$src1, GPR64:$src2, 24), 5>;
// Register/register aliases with no shift when SP is not used.
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"),
- GPR32, GPR32, GPR32, 0>;
+ GPR32, GPR32, (arith_shifted_reg32 GPR32:$src2, 0)>;
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"),
- GPR64, GPR64, GPR64, 0>;
+ GPR64, GPR64, (arith_shifted_reg64 GPR64:$src2, 0)>;
// Register/register aliases with no shift when the first source register
// is SP.
def : AddSubRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrx"),
- GPR32, GPR32sponly, GPR32, 16>; // UXTW #0
- def : AddSubRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Xrx64"),
- GPR64, GPR64sponly, GPR64, 24>; // UXTX #0
+ GPR32, GPR32sponly,
+ (arith_extended_reg32_i32 GPR32:$src2, 16)>; // UXTW #0
+ def : AddSubRegAlias64<mnemonic, !cast<Instruction>(NAME#"Xrx64"),
+ GPR64, GPR64sponly, GPR64, 24>; // UXTX #0
}
class AddSubG<bit isSub, string asm_inst, SDPatternOperator OpNode>
@@ -3398,9 +3412,10 @@ class BaseLogicalSReg<bits<2> opc, bit N, RegisterClass regtype,
}
// Aliases for register+register logical instructions.
-class LogicalRegAlias<string asm, Instruction inst, RegisterClass regtype>
+class LogicalRegAlias<string asm, Instruction inst, RegisterClass regtype,
+ dag op2>
: InstAlias<asm#"\t$dst, $src1, $src2",
- (inst regtype:$dst, regtype:$src1, regtype:$src2, 0)>;
+ (inst regtype:$dst, regtype:$src1, op2)>;
multiclass LogicalImm<bits<2> opc, string mnemonic, SDNode OpNode,
string Alias> {
@@ -3472,10 +3487,10 @@ multiclass LogicalReg<bits<2> opc, bit N, string mnemonic,
let Inst{31} = 1;
}
- def : LogicalRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Wrs"), GPR32>;
- def : LogicalRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Xrs"), GPR64>;
+ def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"),
+ GPR32, (logical_shifted_reg32 GPR32:$src2, 0)>;
+ def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"),
+ GPR64, (logical_shifted_reg64 GPR64:$src2, 0)>;
}
// Split from LogicalReg to allow setting NZCV Defs
@@ -3495,10 +3510,10 @@ multiclass LogicalRegS<bits<2> opc, bit N, string mnemonic,
}
} // Defs = [NZCV]
- def : LogicalRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Wrs"), GPR32>;
- def : LogicalRegAlias<mnemonic,
- !cast<Instruction>(NAME#"Xrs"), GPR64>;
+ def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Wrs"),
+ GPR32, (logical_shifted_reg32 GPR32:$src2, 0)>;
+ def : LogicalRegAlias<mnemonic, !cast<Instruction>(NAME#"Xrs"),
+ GPR64, (logical_shifted_reg64 GPR64:$src2, 0)>;
}
//---
@@ -3986,9 +4001,10 @@ class LoadStore8RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins,
let Inst{4-0} = Rt;
}
-class ROInstAlias<string asm, DAGOperand regtype, Instruction INST>
+class ROInstAlias<string asm, DAGOperand regtype, Instruction INST,
+ ro_extend ext>
: InstAlias<asm # "\t$Rt, [$Rn, $Rm]",
- (INST regtype:$Rt, GPR64sp:$Rn, GPR64:$Rm, 0, 0)>;
+ (INST regtype:$Rt, GPR64sp:$Rn, GPR64:$Rm, (ext 0, 0))>;
multiclass Load8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
string asm, ValueType Ty, SDPatternOperator loadop> {
@@ -4014,7 +4030,7 @@ multiclass Load8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend8>;
}
multiclass Store8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
@@ -4039,7 +4055,7 @@ multiclass Store8RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend8>;
}
class LoadStore16RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins,
@@ -4086,7 +4102,7 @@ multiclass Load16RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend16>;
}
multiclass Store16RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
@@ -4111,7 +4127,7 @@ multiclass Store16RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend16>;
}
class LoadStore32RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins,
@@ -4158,7 +4174,7 @@ multiclass Load32RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend32>;
}
multiclass Store32RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
@@ -4183,7 +4199,7 @@ multiclass Store32RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend32>;
}
class LoadStore64RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins,
@@ -4230,7 +4246,7 @@ multiclass Load64RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend64>;
}
multiclass Store64RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
@@ -4255,7 +4271,7 @@ multiclass Store64RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend64>;
}
class LoadStore128RO<bits<2> sz, bit V, bits<2> opc, string asm, dag ins,
@@ -4302,7 +4318,7 @@ multiclass Load128RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend128>;
}
multiclass Store128RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
@@ -4323,7 +4339,7 @@ multiclass Store128RO<bits<2> sz, bit V, bits<2> opc, DAGOperand regtype,
let Inst{13} = 0b1;
}
- def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX")>;
+ def : ROInstAlias<asm, regtype, !cast<Instruction>(NAME # "roX"), ro_Xextend128>;
}
let mayLoad = 0, mayStore = 0, hasSideEffects = 1 in
@@ -4372,9 +4388,7 @@ multiclass PrefetchRO<bits<2> sz, bit V, bits<2> opc, string asm> {
let Inst{13} = 0b1;
}
- def : InstAlias<"prfm $Rt, [$Rn, $Rm]",
- (!cast<Instruction>(NAME # "roX") prfop:$Rt,
- GPR64sp:$Rn, GPR64:$Rm, 0, 0)>;
+ def : ROInstAlias<"prfm", prfop, !cast<Instruction>(NAME # "roX"), ro_Xextend64>;
}
//---
@@ -6484,7 +6498,9 @@ class BaseSIMDThreeSameVectorDot<bit Q, bit U, bits<2> sz, bits<4> opc, string a
(OpNode (AccumType RegType:$Rd),
(InputType RegType:$Rn),
(InputType RegType:$Rm)))]> {
- let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # "}");
+
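+  // The "{...|...}" AsmString syntax selects between assembler syntax
+  // variants: the first form suffixes each operand with the vector
+  // arrangement, the second appends it to the mnemonic instead.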
+ let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 #
+ "|" # kind1 # "\t$Rd, $Rn, $Rm}");
}
multiclass SIMDThreeSameVectorDot<bit U, bit Mixed, string asm, SDPatternOperator OpNode> {
@@ -6507,7 +6523,8 @@ class BaseSIMDThreeSameVectorFML<bit Q, bit U, bit b13, bits<3> size, string asm
(OpNode (AccumType RegType:$Rd),
(InputType RegType:$Rn),
(InputType RegType:$Rm)))]> {
- let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 # "}");
+ let AsmString = !strconcat(asm, "{\t$Rd" # kind1 # ", $Rn" # kind2 # ", $Rm" # kind2 #
+ "|" # kind1 # "\t$Rd, $Rn, $Rm}");
let Inst{13} = b13;
}
@@ -7359,7 +7376,9 @@ multiclass SIMDDifferentThreeVectorBD<bit U, bits<4> opc, string asm,
[(set (v8i16 V128:$Rd), (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm)))]>;
def v16i8 : BaseSIMDDifferentThreeVector<U, 0b001, opc,
V128, V128, V128,
- asm#"2", ".8h", ".16b", ".16b", []>;
+ asm#"2", ".8h", ".16b", ".16b",
+ [(set (v8i16 V128:$Rd), (OpNode (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn))),
+ (v8i8 (extract_high_v16i8 (v16i8 V128:$Rm)))))]>;
let Predicates = [HasAES] in {
def v1i64 : BaseSIMDDifferentThreeVector<U, 0b110, opc,
V128, V64, V64,
@@ -7371,10 +7390,6 @@ multiclass SIMDDifferentThreeVectorBD<bit U, bits<4> opc, string asm,
[(set (v16i8 V128:$Rd), (OpNode (extract_high_v2i64 (v2i64 V128:$Rn)),
(extract_high_v2i64 (v2i64 V128:$Rm))))]>;
}
-
- def : Pat<(v8i16 (OpNode (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn))),
- (v8i8 (extract_high_v16i8 (v16i8 V128:$Rm))))),
- (!cast<Instruction>(NAME#"v16i8") V128:$Rn, V128:$Rm)>;
}
multiclass SIMDLongThreeVectorHS<bit U, bits<4> opc, string asm,
@@ -7399,87 +7414,7 @@ multiclass SIMDLongThreeVectorHS<bit U, bits<4> opc, string asm,
(extract_high_v4i32 (v4i32 V128:$Rm))))]>;
}
-multiclass SIMDLongThreeVectorBHSabdl<bit U, bits<4> opc, string asm,
- SDPatternOperator OpNode = null_frag> {
- def v8i8_v8i16 : BaseSIMDDifferentThreeVector<U, 0b000, opc,
- V128, V64, V64,
- asm, ".8h", ".8b", ".8b",
- [(set (v8i16 V128:$Rd),
- (zext (v8i8 (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm)))))]>;
- def v16i8_v8i16 : BaseSIMDDifferentThreeVector<U, 0b001, opc,
- V128, V128, V128,
- asm#"2", ".8h", ".16b", ".16b",
- [(set (v8i16 V128:$Rd),
- (zext (v8i8 (OpNode (extract_high_v16i8 (v16i8 V128:$Rn)),
- (extract_high_v16i8 (v16i8 V128:$Rm))))))]>;
- def v4i16_v4i32 : BaseSIMDDifferentThreeVector<U, 0b010, opc,
- V128, V64, V64,
- asm, ".4s", ".4h", ".4h",
- [(set (v4i32 V128:$Rd),
- (zext (v4i16 (OpNode (v4i16 V64:$Rn), (v4i16 V64:$Rm)))))]>;
- def v8i16_v4i32 : BaseSIMDDifferentThreeVector<U, 0b011, opc,
- V128, V128, V128,
- asm#"2", ".4s", ".8h", ".8h",
- [(set (v4i32 V128:$Rd),
- (zext (v4i16 (OpNode (extract_high_v8i16 (v8i16 V128:$Rn)),
- (extract_high_v8i16 (v8i16 V128:$Rm))))))]>;
- def v2i32_v2i64 : BaseSIMDDifferentThreeVector<U, 0b100, opc,
- V128, V64, V64,
- asm, ".2d", ".2s", ".2s",
- [(set (v2i64 V128:$Rd),
- (zext (v2i32 (OpNode (v2i32 V64:$Rn), (v2i32 V64:$Rm)))))]>;
- def v4i32_v2i64 : BaseSIMDDifferentThreeVector<U, 0b101, opc,
- V128, V128, V128,
- asm#"2", ".2d", ".4s", ".4s",
- [(set (v2i64 V128:$Rd),
- (zext (v2i32 (OpNode (extract_high_v4i32 (v4i32 V128:$Rn)),
- (extract_high_v4i32 (v4i32 V128:$Rm))))))]>;
-}
-
-multiclass SIMDLongThreeVectorTiedBHSabal<bit U, bits<4> opc,
- string asm,
- SDPatternOperator OpNode> {
- def v8i8_v8i16 : BaseSIMDDifferentThreeVectorTied<U, 0b000, opc,
- V128, V64, V64,
- asm, ".8h", ".8b", ".8b",
- [(set (v8i16 V128:$dst),
- (add (v8i16 V128:$Rd),
- (zext (v8i8 (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm))))))]>;
- def v16i8_v8i16 : BaseSIMDDifferentThreeVectorTied<U, 0b001, opc,
- V128, V128, V128,
- asm#"2", ".8h", ".16b", ".16b",
- [(set (v8i16 V128:$dst),
- (add (v8i16 V128:$Rd),
- (zext (v8i8 (OpNode (extract_high_v16i8 (v16i8 V128:$Rn)),
- (extract_high_v16i8 (v16i8 V128:$Rm)))))))]>;
- def v4i16_v4i32 : BaseSIMDDifferentThreeVectorTied<U, 0b010, opc,
- V128, V64, V64,
- asm, ".4s", ".4h", ".4h",
- [(set (v4i32 V128:$dst),
- (add (v4i32 V128:$Rd),
- (zext (v4i16 (OpNode (v4i16 V64:$Rn), (v4i16 V64:$Rm))))))]>;
- def v8i16_v4i32 : BaseSIMDDifferentThreeVectorTied<U, 0b011, opc,
- V128, V128, V128,
- asm#"2", ".4s", ".8h", ".8h",
- [(set (v4i32 V128:$dst),
- (add (v4i32 V128:$Rd),
- (zext (v4i16 (OpNode (extract_high_v8i16 (v8i16 V128:$Rn)),
- (extract_high_v8i16 (v8i16 V128:$Rm)))))))]>;
- def v2i32_v2i64 : BaseSIMDDifferentThreeVectorTied<U, 0b100, opc,
- V128, V64, V64,
- asm, ".2d", ".2s", ".2s",
- [(set (v2i64 V128:$dst),
- (add (v2i64 V128:$Rd),
- (zext (v2i32 (OpNode (v2i32 V64:$Rn), (v2i32 V64:$Rm))))))]>;
- def v4i32_v2i64 : BaseSIMDDifferentThreeVectorTied<U, 0b101, opc,
- V128, V128, V128,
- asm#"2", ".2d", ".4s", ".4s",
- [(set (v2i64 V128:$dst),
- (add (v2i64 V128:$Rd),
- (zext (v2i32 (OpNode (extract_high_v4i32 (v4i32 V128:$Rn)),
- (extract_high_v4i32 (v4i32 V128:$Rm)))))))]>;
-}
-
+let isCommutable = 1 in
multiclass SIMDLongThreeVectorBHS<bit U, bits<4> opc, string asm,
SDPatternOperator OpNode = null_frag> {
def v8i8_v8i16 : BaseSIMDDifferentThreeVector<U, 0b000, opc,
@@ -8986,7 +8921,8 @@ class BaseSIMDThreeSameVectorBFDot<bit Q, bit U, string asm, string kind1,
(InputType RegType:$Rm)))]> {
let AsmString = !strconcat(asm,
"{\t$Rd" # kind1 # ", $Rn" # kind2 #
- ", $Rm" # kind2 # "}");
+ ", $Rm" # kind2 #
+ "|" # kind1 # "\t$Rd, $Rn, $Rm}");
}
multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
@@ -9032,7 +8968,7 @@ class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
[(set (v4f32 V128:$dst), (OpNode (v4f32 V128:$Rd),
(v8bf16 V128:$Rn),
(v8bf16 V128:$Rm)))]> {
- let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h}");
+ let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h|.4s\t$Rd, $Rn, $Rm}");
}
let mayRaiseFPException = 1, Uses = [FPCR] in
@@ -9071,8 +9007,7 @@ class SIMDThreeSameVectorBF16MatrixMul<string asm>
(int_aarch64_neon_bfmmla (v4f32 V128:$Rd),
(v8bf16 V128:$Rn),
(v8bf16 V128:$Rm)))]> {
- let AsmString = !strconcat(asm, "{\t$Rd", ".4s", ", $Rn", ".8h",
- ", $Rm", ".8h", "}");
+ let AsmString = !strconcat(asm, "{\t$Rd.4s, $Rn.8h, $Rm.8h|.4s\t$Rd, $Rn, $Rm}");
}
let mayRaiseFPException = 1, Uses = [FPCR] in
@@ -9143,7 +9078,7 @@ class SIMDThreeSameVectorMatMul<bit B, bit U, string asm, SDPatternOperator OpNo
[(set (v4i32 V128:$dst), (OpNode (v4i32 V128:$Rd),
(v16i8 V128:$Rn),
(v16i8 V128:$Rm)))]> {
- let AsmString = asm # "{\t$Rd.4s, $Rn.16b, $Rm.16b}";
+ let AsmString = asm # "{\t$Rd.4s, $Rn.16b, $Rm.16b|.4s\t$Rd, $Rn, $Rm}";
}
//----------------------------------------------------------------------------
@@ -12561,7 +12496,7 @@ multiclass STOPregister<string asm, string instr> {
let Predicates = [HasLSUI] in
class BaseSTOPregisterLSUI<string asm, RegisterClass OP, Register Reg,
Instruction inst> :
- InstAlias<asm # "\t$Rs, [$Rn]", (inst Reg, OP:$Rs, GPR64sp:$Rn), 0>;
+ InstAlias<asm # "\t$Rs, [$Rn]", (inst Reg, OP:$Rs, GPR64sp:$Rn)>;
multiclass STOPregisterLSUI<string asm, string instr> {
def : BaseSTOPregisterLSUI<asm # "l", GPR32, WZR,
@@ -13344,8 +13279,8 @@ multiclass AtomicFPStore<bit R, bits<3> op0, string asm> {
class BaseSIMDThreeSameVectorFP8MatrixMul<string asm, bits<2> size, string kind>
: BaseSIMDThreeSameVectorTied<1, 1, {size, 0}, 0b11101,
V128, asm, ".16b", []> {
- let AsmString = !strconcat(asm, "{\t$Rd", kind, ", $Rn", ".16b",
- ", $Rm", ".16b", "}");
+ let AsmString = !strconcat(asm, "{\t$Rd", kind, ", $Rn.16b, $Rm.16b",
+ "|", kind, "\t$Rd, $Rn, $Rm}");
}
multiclass SIMDThreeSameVectorFP8MatrixMul<string asm>{
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 59d4fd2..3ce7829 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -20,7 +20,9 @@
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/CodeGen/CFIInstBuilder.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
@@ -83,6 +85,11 @@ static cl::opt<unsigned>
BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
cl::desc("Restrict range of B instructions (DEBUG)"));
+static cl::opt<unsigned> GatherOptSearchLimit(
+ "aarch64-search-limit", cl::Hidden, cl::init(2048),
+ cl::desc("Restrict range of instructions to search for the "
+ "machine-combiner gather pattern optimization"));
+
AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
: AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
AArch64::CATCHRET),
@@ -5068,7 +5075,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
}
- } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGP()) {
+ } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) {
BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
@@ -5078,8 +5085,13 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
// Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
MCRegister DestRegX = TRI->getMatchingSuperReg(
DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
- MCRegister SrcRegX = TRI->getMatchingSuperReg(
- SrcReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
+ assert(DestRegX.isValid() && "Destination super-reg not valid");
+ MCRegister SrcRegX =
+ SrcReg == AArch64::WZR
+ ? AArch64::XZR
+ : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ assert(SrcRegX.isValid() && "Source super-reg not valid");
// This instruction is reading and writing X registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
// that we are reading an undefined value from SrcRegX, but a proper
@@ -5190,7 +5202,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
.addReg(SrcReg, getKillRegState(KillSrc))
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
- } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGP()) {
+ } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) {
BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
@@ -5306,15 +5318,49 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (AArch64::FPR64RegClass.contains(DestReg) &&
AArch64::FPR64RegClass.contains(SrcReg)) {
- BuildMI(MBB, I, DL, get(AArch64::FMOVDr), DestReg)
- .addReg(SrcReg, getKillRegState(KillSrc));
+ if (Subtarget.hasZeroCycleRegMoveFPR128() &&
+ !Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
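+      // Widen the copy to a 16-byte ORR: on this subtarget only the 128-bit
+      // FP register move is zero-cycle (FPR128 but not FPR64/FPR32).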
+ const TargetRegisterInfo *TRI = &getRegisterInfo();
+ MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::dsub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::dsub,
+ &AArch64::FPR128RegClass);
+ // This instruction is reading and writing Q registers. This may upset
+ // the register scavenger and machine verifier, so we need to indicate
+ // that we are reading an undefined value from SrcRegQ, but a proper
+ // value from SrcReg.
+ BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
+ } else {
+ BuildMI(MBB, I, DL, get(AArch64::FMOVDr), DestReg)
+ .addReg(SrcReg, getKillRegState(KillSrc));
+ }
return;
}
if (AArch64::FPR32RegClass.contains(DestReg) &&
AArch64::FPR32RegClass.contains(SrcReg)) {
- if (Subtarget.hasZeroCycleRegMoveFPR64() &&
- !Subtarget.hasZeroCycleRegMoveFPR32()) {
+ if (Subtarget.hasZeroCycleRegMoveFPR128() &&
+ !Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
+ const TargetRegisterInfo *TRI = &getRegisterInfo();
+ MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::ssub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::ssub,
+ &AArch64::FPR128RegClass);
+ // This instruction is reading and writing Q registers. This may upset
+ // the register scavenger and machine verifier, so we need to indicate
+ // that we are reading an undefined value from SrcRegQ, but a proper
+ // value from SrcReg.
+ BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
+ } else if (Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32()) {
const TargetRegisterInfo *TRI = &getRegisterInfo();
MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::ssub,
&AArch64::FPR64RegClass);
@@ -5336,8 +5382,24 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (AArch64::FPR16RegClass.contains(DestReg) &&
AArch64::FPR16RegClass.contains(SrcReg)) {
- if (Subtarget.hasZeroCycleRegMoveFPR64() &&
- !Subtarget.hasZeroCycleRegMoveFPR32()) {
+ if (Subtarget.hasZeroCycleRegMoveFPR128() &&
+ !Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
+ const TargetRegisterInfo *TRI = &getRegisterInfo();
+ MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::hsub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::hsub,
+ &AArch64::FPR128RegClass);
+ // This instruction is reading and writing Q registers. This may upset
+ // the register scavenger and machine verifier, so we need to indicate
+ // that we are reading an undefined value from SrcRegQ, but a proper
+ // value from SrcReg.
+ BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
+ } else if (Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32()) {
const TargetRegisterInfo *TRI = &getRegisterInfo();
MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::hsub,
&AArch64::FPR64RegClass);
@@ -5363,8 +5425,24 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
if (AArch64::FPR8RegClass.contains(DestReg) &&
AArch64::FPR8RegClass.contains(SrcReg)) {
- if (Subtarget.hasZeroCycleRegMoveFPR64() &&
- !Subtarget.hasZeroCycleRegMoveFPR32()) {
+ if (Subtarget.hasZeroCycleRegMoveFPR128() &&
+ !Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32() && Subtarget.isNeonAvailable()) {
+ const TargetRegisterInfo *TRI = &getRegisterInfo();
+ MCRegister DestRegQ = TRI->getMatchingSuperReg(DestReg, AArch64::bsub,
+ &AArch64::FPR128RegClass);
+ MCRegister SrcRegQ = TRI->getMatchingSuperReg(SrcReg, AArch64::bsub,
+ &AArch64::FPR128RegClass);
+ // This instruction is reading and writing Q registers. This may upset
+ // the register scavenger and machine verifier, so we need to indicate
+ // that we are reading an undefined value from SrcRegQ, but a proper
+ // value from SrcReg.
+ BuildMI(MBB, I, DL, get(AArch64::ORRv16i8), DestRegQ)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcRegQ, RegState::Undef)
+ .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
+ } else if (Subtarget.hasZeroCycleRegMoveFPR64() &&
+ !Subtarget.hasZeroCycleRegMoveFPR32()) {
const TargetRegisterInfo *TRI = &getRegisterInfo();
MCRegister DestRegD = TRI->getMatchingSuperReg(DestReg, AArch64::bsub,
&AArch64::FPR64RegClass);
@@ -5861,33 +5939,53 @@ void AArch64InstrInfo::decomposeStackOffsetForFrameOffsets(
}
}
-// Convenience function to create a DWARF expression for
-// Expr + NumBytes + NumVGScaledBytes * AArch64::VG
-static void appendVGScaledOffsetExpr(SmallVectorImpl<char> &Expr, int NumBytes,
- int NumVGScaledBytes, unsigned VG,
- llvm::raw_string_ostream &Comment) {
- uint8_t buffer[16];
-
- if (NumBytes) {
+// Convenience function to create a DWARF expression for: Constant `Operation`.
+// This helper emits compact sequences for common cases. For example, for `-15
+// DW_OP_plus`, this helper would create DW_OP_lit15 DW_OP_minus.
+static void appendConstantExpr(SmallVectorImpl<char> &Expr, int64_t Constant,
+ dwarf::LocationAtom Operation) {
+ if (Operation == dwarf::DW_OP_plus && Constant < 0 && -Constant <= 31) {
+ // -Constant (1 to 31)
+ Expr.push_back(dwarf::DW_OP_lit0 - Constant);
+ Operation = dwarf::DW_OP_minus;
+ } else if (Constant >= 0 && Constant <= 31) {
+ // Literal value 0 to 31
+ Expr.push_back(dwarf::DW_OP_lit0 + Constant);
+ } else {
+ // Signed constant
Expr.push_back(dwarf::DW_OP_consts);
- Expr.append(buffer, buffer + encodeSLEB128(NumBytes, buffer));
- Expr.push_back((uint8_t)dwarf::DW_OP_plus);
- Comment << (NumBytes < 0 ? " - " : " + ") << std::abs(NumBytes);
+ appendLEB128<LEB128Sign::Signed>(Expr, Constant);
}
+ return Expr.push_back(Operation);
+}
- if (NumVGScaledBytes) {
- Expr.push_back((uint8_t)dwarf::DW_OP_consts);
- Expr.append(buffer, buffer + encodeSLEB128(NumVGScaledBytes, buffer));
-
- Expr.push_back((uint8_t)dwarf::DW_OP_bregx);
- Expr.append(buffer, buffer + encodeULEB128(VG, buffer));
- Expr.push_back(0);
+// Convenience function to create a DWARF expression for a register.
+static void appendReadRegExpr(SmallVectorImpl<char> &Expr, unsigned RegNum) {
+ Expr.push_back((char)dwarf::DW_OP_bregx);
+ appendLEB128<LEB128Sign::Unsigned>(Expr, RegNum);
+ Expr.push_back(0);
+}
- Expr.push_back((uint8_t)dwarf::DW_OP_mul);
- Expr.push_back((uint8_t)dwarf::DW_OP_plus);
+// Convenience function to create a DWARF expression for loading a register from
+// a CFA offset.
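+// For illustration: appendLoadRegExpr(Expr, 16) appends
+// DW_OP_dup, DW_OP_lit16, DW_OP_plus, DW_OP_deref.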
+static void appendLoadRegExpr(SmallVectorImpl<char> &Expr,
+ int64_t OffsetFromDefCFA) {
+ // This assumes the top of the DWARF stack contains the CFA.
+ Expr.push_back(dwarf::DW_OP_dup);
+ // Add the offset to the register.
+ appendConstantExpr(Expr, OffsetFromDefCFA, dwarf::DW_OP_plus);
+ // Dereference the address (loads a 64 bit value)..
+ Expr.push_back(dwarf::DW_OP_deref);
+}
- Comment << (NumVGScaledBytes < 0 ? " - " : " + ")
- << std::abs(NumVGScaledBytes) << " * VG";
+// Convenience function to create a comment for
+// (+/-) NumBytes (* RegScale)?
+static void appendOffsetComment(int NumBytes, llvm::raw_string_ostream &Comment,
+ StringRef RegScale = {}) {
+ if (NumBytes) {
+ Comment << (NumBytes < 0 ? " - " : " + ") << std::abs(NumBytes);
+ if (!RegScale.empty())
+ Comment << ' ' << RegScale;
}
}
@@ -5909,19 +6007,26 @@ static MCCFIInstruction createDefCFAExpression(const TargetRegisterInfo &TRI,
else
Comment << printReg(Reg, &TRI);
- // Build up the expression (Reg + NumBytes + NumVGScaledBytes * AArch64::VG)
+ // Build up the expression (Reg + NumBytes + VG * NumVGScaledBytes)
SmallString<64> Expr;
unsigned DwarfReg = TRI.getDwarfRegNum(Reg, true);
- Expr.push_back((uint8_t)(dwarf::DW_OP_breg0 + DwarfReg));
- Expr.push_back(0);
- appendVGScaledOffsetExpr(Expr, NumBytes, NumVGScaledBytes,
- TRI.getDwarfRegNum(AArch64::VG, true), Comment);
+ assert(DwarfReg <= 31 && "DwarfReg out of bounds (0..31)");
+ // Reg + NumBytes
+ Expr.push_back(dwarf::DW_OP_breg0 + DwarfReg);
+ appendLEB128<LEB128Sign::Signed>(Expr, NumBytes);
+ appendOffsetComment(NumBytes, Comment);
+ if (NumVGScaledBytes) {
+ // + VG * NumVGScaledBytes
+ appendOffsetComment(NumVGScaledBytes, Comment, "* VG");
+ appendReadRegExpr(Expr, TRI.getDwarfRegNum(AArch64::VG, true));
+ appendConstantExpr(Expr, NumVGScaledBytes, dwarf::DW_OP_mul);
+ Expr.push_back(dwarf::DW_OP_plus);
+ }
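+ // For illustration, a def-CFA of SP + 16 + VG * 8 produces the expression:
+ // DW_OP_breg31 +16, DW_OP_bregx VG 0, DW_OP_lit8, DW_OP_mul, DW_OP_plus.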
// Wrap this into DW_CFA_def_cfa.
SmallString<64> DefCfaExpr;
DefCfaExpr.push_back(dwarf::DW_CFA_def_cfa_expression);
- uint8_t buffer[16];
- DefCfaExpr.append(buffer, buffer + encodeULEB128(Expr.size(), buffer));
+ appendLEB128<LEB128Sign::Unsigned>(DefCfaExpr, Expr.size());
DefCfaExpr.append(Expr.str());
return MCCFIInstruction::createEscape(nullptr, DefCfaExpr.str(), SMLoc(),
Comment.str());
@@ -5941,9 +6046,10 @@ MCCFIInstruction llvm::createDefCFA(const TargetRegisterInfo &TRI,
return MCCFIInstruction::cfiDefCfa(nullptr, DwarfReg, (int)Offset.getFixed());
}
-MCCFIInstruction llvm::createCFAOffset(const TargetRegisterInfo &TRI,
- unsigned Reg,
- const StackOffset &OffsetFromDefCFA) {
+MCCFIInstruction
+llvm::createCFAOffset(const TargetRegisterInfo &TRI, unsigned Reg,
+ const StackOffset &OffsetFromDefCFA,
+ std::optional<int64_t> IncomingVGOffsetFromDefCFA) {
int64_t NumBytes, NumVGScaledBytes;
AArch64InstrInfo::decomposeStackOffsetForDwarfOffsets(
OffsetFromDefCFA, NumBytes, NumVGScaledBytes);
@@ -5958,17 +6064,32 @@ MCCFIInstruction llvm::createCFAOffset(const TargetRegisterInfo &TRI,
llvm::raw_string_ostream Comment(CommentBuffer);
Comment << printReg(Reg, &TRI) << " @ cfa";
- // Build up expression (NumBytes + NumVGScaledBytes * AArch64::VG)
+ // Build up expression (CFA + VG * NumVGScaledBytes + NumBytes)
+ assert(NumVGScaledBytes && "Expected scalable offset");
SmallString<64> OffsetExpr;
- appendVGScaledOffsetExpr(OffsetExpr, NumBytes, NumVGScaledBytes,
- TRI.getDwarfRegNum(AArch64::VG, true), Comment);
+ // + VG * NumVGScaledBytes
+ StringRef VGRegScale;
+ if (IncomingVGOffsetFromDefCFA) {
+ appendLoadRegExpr(OffsetExpr, *IncomingVGOffsetFromDefCFA);
+ VGRegScale = "* IncomingVG";
+ } else {
+ appendReadRegExpr(OffsetExpr, TRI.getDwarfRegNum(AArch64::VG, true));
+ VGRegScale = "* VG";
+ }
+ appendConstantExpr(OffsetExpr, NumVGScaledBytes, dwarf::DW_OP_mul);
+ appendOffsetComment(NumVGScaledBytes, Comment, VGRegScale);
+ OffsetExpr.push_back(dwarf::DW_OP_plus);
+ if (NumBytes) {
+ // + NumBytes
+ appendOffsetComment(NumBytes, Comment);
+ appendConstantExpr(OffsetExpr, NumBytes, dwarf::DW_OP_plus);
+ }
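+ // For illustration (offsets made up): for "reg @ cfa - 8 * IncomingVG - 16"
+ // with the incoming VG saved at CFA+8, this emits: DW_OP_dup, DW_OP_lit8,
+ // DW_OP_plus, DW_OP_deref, DW_OP_consts -8, DW_OP_mul, DW_OP_plus,
+ // DW_OP_lit16, DW_OP_minus.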
// Wrap this into DW_CFA_expression
SmallString<64> CfaExpr;
CfaExpr.push_back(dwarf::DW_CFA_expression);
- uint8_t buffer[16];
- CfaExpr.append(buffer, buffer + encodeULEB128(DwarfReg, buffer));
- CfaExpr.append(buffer, buffer + encodeULEB128(OffsetExpr.size(), buffer));
+ appendLEB128<LEB128Sign::Unsigned>(CfaExpr, DwarfReg);
+ appendLEB128<LEB128Sign::Unsigned>(CfaExpr, OffsetExpr.size());
CfaExpr.append(OffsetExpr.str());
return MCCFIInstruction::createEscape(nullptr, CfaExpr.str(), SMLoc(),
@@ -6597,7 +6718,7 @@ static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO,
if (MO.isReg() && MO.getReg().isVirtual())
MI = MRI.getUniqueVRegDef(MO.getReg());
// And it needs to be in the trace (otherwise, it won't have a depth).
- if (!MI || MI->getParent() != &MBB || (unsigned)MI->getOpcode() != CombineOpc)
+ if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
return false;
// Must only used by the user we combine with.
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
@@ -7389,11 +7510,319 @@ static bool getMiscPatterns(MachineInstr &Root,
return false;
}
+/// Check if the given instruction forms a gather load pattern that can be
+/// optimized for better Memory-Level Parallelism (MLP). This function
+/// identifies chains of NEON lane load instructions that load data from
+/// different memory addresses into individual lanes of a 128-bit vector
+/// register, then attempts to split the pattern into parallel loads to break
+/// the serial dependency between instructions.
+///
+/// Pattern Matched:
+/// Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) ->
+/// LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root)
+///
+/// Transformed Into:
+/// Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64
+/// to combine the results, enabling better memory-level parallelism.
+///
+/// Supported Element Types:
+/// - 32-bit elements (LD1i32, 4 lanes total)
+/// - 16-bit elements (LD1i16, 8 lanes total)
+/// - 8-bit elements (LD1i8, 16 lanes total)
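+///
+/// Illustrative sketch for the 32-bit case (virtual register and pointer
+/// names are made up for the example):
+///   %s0 = <32-bit scalar load from %p0>
+///   %q0 = SUBREG_TO_REG 0, %s0, ssub
+///   %q1 = LD1i32 %q0, 1, %p1
+///   %q2 = LD1i32 %q1, 2, %p2
+///   %q3 = LD1i32 %q2, 3, %p3        ; Root
+/// is rewritten (assuming no intervening aliasing stores) roughly as:
+///   %a1 = LD1i32 %q0, 1, %p1
+///   %t  = LDRSui %p2, 0
+///   %b0 = SUBREG_TO_REG 0, %t, ssub
+///   %b1 = LD1i32 %b0, 1, %p3
+///   %q3 = ZIP1v2i64 %a1, %b1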
+static bool getGatherLanePattern(MachineInstr &Root,
+ SmallVectorImpl<unsigned> &Patterns,
+ unsigned LoadLaneOpCode, unsigned NumLanes) {
+ const MachineFunction *MF = Root.getMF();
+
+ // Early exit if optimizing for size.
+ if (MF->getFunction().hasMinSize())
+ return false;
+
+ const MachineRegisterInfo &MRI = MF->getRegInfo();
+ const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+
+ // The root of the pattern must load into the last lane of the vector.
+ if (Root.getOperand(2).getImm() != NumLanes - 1)
+ return false;
+
+ // Check that we have loads into all lanes except lane 0.
+ // For each load we also want to check that:
+ // 1. It has a single non-debug use (since we will be replacing the virtual
+ //    register)
+ // 2. The addressing mode only uses a single pointer operand
+ auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+ auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
+ SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
+ SmallVector<const MachineInstr *, 16> LoadInstrs;
+ while (!RemainingLanes.empty() && CurrInstr &&
+ CurrInstr->getOpcode() == LoadLaneOpCode &&
+ MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
+ CurrInstr->getNumOperands() == 4) {
+ RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
+ LoadInstrs.push_back(CurrInstr);
+ CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+ }
+
+ // Check that we have found a match for lanes N-1..1.
+ if (!RemainingLanes.empty())
+ return false;
+
+ // Match the SUBREG_TO_REG sequence.
+ if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
+ return false;
+
+ // Verify that the subreg to reg loads an integer into the first lane.
+ auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
+ unsigned SingleLaneSizeInBits = 128 / NumLanes;
+ if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
+ return false;
+
+ // Verify that it also has a single non-debug use.
+ if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
+ return false;
+
+ LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg));
+
+ // If there is any chance of aliasing, do not apply the pattern.
+ // Walk backward through the MBB starting from Root.
+ // Exit early if we've encountered all load instructions or hit the search
+ // limit.
+ auto MBBItr = Root.getIterator();
+ unsigned RemainingSteps = GatherOptSearchLimit;
+ SmallPtrSet<const MachineInstr *, 16> RemainingLoadInstrs;
+ RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end());
+ const MachineBasicBlock *MBB = Root.getParent();
+
+ for (; MBBItr != MBB->begin() && RemainingSteps > 0 &&
+ !RemainingLoadInstrs.empty();
+ --MBBItr, --RemainingSteps) {
+ const MachineInstr &CurrInstr = *MBBItr;
+
+ // Remove this instruction from remaining loads if it's one we're tracking.
+ RemainingLoadInstrs.erase(&CurrInstr);
+
+ // Check for potential aliasing with any of the load instructions to
+ // optimize.
+ if (CurrInstr.isLoadFoldBarrier())
+ return false;
+ }
+
+ // If we hit the search limit without finding all load instructions,
+ // don't match the pattern.
+ if (RemainingSteps == 0 && !RemainingLoadInstrs.empty())
+ return false;
+
+ switch (NumLanes) {
+ case 4:
+ Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
+ break;
+ case 8:
+ Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
+ break;
+ case 16:
+ Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
+ break;
+ default:
+ llvm_unreachable("Got bad number of lanes for gather pattern.");
+ }
+
+ return true;
+}
+
+/// Search for patterns of LD instructions we can optimize.
+static bool getLoadPatterns(MachineInstr &Root,
+ SmallVectorImpl<unsigned> &Patterns) {
+
+ // The pattern searches for loads into single lanes.
+ switch (Root.getOpcode()) {
+ case AArch64::LD1i32:
+ return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4);
+ case AArch64::LD1i16:
+ return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8);
+ case AArch64::LD1i8:
+ return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16);
+ default:
+ return false;
+ }
+}
+
+/// Generate optimized instruction sequence for gather load patterns to improve
+/// Memory-Level Parallelism (MLP). This function transforms a chain of
+/// sequential NEON lane loads into parallel vector loads that can execute
+/// concurrently.
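+///
+/// The rewritten sequence loads lanes [0, NumLanes/2) into one 128-bit
+/// register and lanes [NumLanes/2, NumLanes) into a second (each chain
+/// starts from a scalar load that zeroes the upper lanes), then combines
+/// the two halves with ZIP1v2i64.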
+static void
+generateGatherLanePattern(MachineInstr &Root,
+ SmallVectorImpl<MachineInstr *> &InsInstrs,
+ SmallVectorImpl<MachineInstr *> &DelInstrs,
+ DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+ unsigned Pattern, unsigned NumLanes) {
+ MachineFunction &MF = *Root.getParent()->getParent();
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+ // Gather the initial load instructions to build the pattern.
+ SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
+ MachineInstr *CurrInstr = &Root;
+ for (unsigned i = 0; i < NumLanes - 1; ++i) {
+ LoadToLaneInstrs.push_back(CurrInstr);
+ CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+ }
+
+ // Sort the load instructions according to the lane.
+ llvm::sort(LoadToLaneInstrs,
+ [](const MachineInstr *A, const MachineInstr *B) {
+ return A->getOperand(2).getImm() > B->getOperand(2).getImm();
+ });
+
+ MachineInstr *SubregToReg = CurrInstr;
+ LoadToLaneInstrs.push_back(
+ MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+ auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);
+
+ const TargetRegisterClass *FPR128RegClass =
+ MRI.getRegClass(Root.getOperand(0).getReg());
+
+ // Helper lambda to create a LD1 instruction.
+ auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr,
+ Register SrcRegister, unsigned Lane,
+ Register OffsetRegister,
+ bool OffsetRegisterKillState) {
+ auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
+ MachineInstrBuilder LoadIndexIntoRegister =
+ BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
+ NewRegister)
+ .addReg(SrcRegister)
+ .addImm(Lane)
+ .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState));
+ InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
+ InsInstrs.push_back(LoadIndexIntoRegister);
+ return NewRegister;
+ };
+
+ // Helper to create a load instruction based on the NumLanes in the NEON
+ // register we are rewriting.
+ auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg,
+ Register OffsetReg,
+ bool KillState) -> MachineInstrBuilder {
+ unsigned Opcode;
+ switch (NumLanes) {
+ case 4:
+ Opcode = AArch64::LDRSui;
+ break;
+ case 8:
+ Opcode = AArch64::LDRHui;
+ break;
+ case 16:
+ Opcode = AArch64::LDRBui;
+ break;
+ default:
+ llvm_unreachable(
+ "Got unsupported number of lanes in machine-combiner gather pattern");
+ }
+ // Immediate offset load
+ return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+ .addReg(OffsetReg)
+ .addImm(0);
+ };
+
+ // Load the remaining lanes into register 0.
+ auto LanesToLoadToReg0 =
+ llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
+ LoadToLaneInstrsAscending.begin() + NumLanes / 2);
+ Register PrevReg = SubregToReg->getOperand(0).getReg();
+ for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
+ const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+ PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+ OffsetRegOperand.getReg(),
+ OffsetRegOperand.isKill());
+ DelInstrs.push_back(LoadInstr);
+ }
+ Register LastLoadReg0 = PrevReg;
+
+ // First load into register 1. Perform an integer load to zero out the upper
+ // lanes in a single instruction.
+ MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin();
+ MachineInstr *OriginalSplitLoad =
+ *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
+ Register DestRegForMiddleIndex = MRI.createVirtualRegister(
+ MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
+
+ const MachineOperand &OriginalSplitToLoadOffsetOperand =
+ OriginalSplitLoad->getOperand(3);
+ MachineInstrBuilder MiddleIndexLoadInstr =
+ CreateLDRInstruction(NumLanes, DestRegForMiddleIndex,
+ OriginalSplitToLoadOffsetOperand.getReg(),
+ OriginalSplitToLoadOffsetOperand.isKill());
+
+ InstrIdxForVirtReg.insert(
+ std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+ InsInstrs.push_back(MiddleIndexLoadInstr);
+ DelInstrs.push_back(OriginalSplitLoad);
+
+ // Subreg To Reg instruction for register 1.
+ Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
+ unsigned SubregType;
+ switch (NumLanes) {
+ case 4:
+ SubregType = AArch64::ssub;
+ break;
+ case 8:
+ SubregType = AArch64::hsub;
+ break;
+ case 16:
+ SubregType = AArch64::bsub;
+ break;
+ default:
+ llvm_unreachable(
+ "Got invalid NumLanes for machine-combiner gather pattern");
+ }
+
+ auto SubRegToRegInstr =
+ BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+ DestRegForSubregToReg)
+ .addImm(0)
+ .addReg(DestRegForMiddleIndex, getKillRegState(true))
+ .addImm(SubregType);
+ InstrIdxForVirtReg.insert(
+ std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+ InsInstrs.push_back(SubRegToRegInstr);
+
+ // Load remaining lanes into register 1.
+ auto LanesToLoadToReg1 =
+ llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
+ LoadToLaneInstrsAscending.end());
+ PrevReg = SubRegToRegInstr->getOperand(0).getReg();
+ for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
+ const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+ PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+ OffsetRegOperand.getReg(),
+ OffsetRegOperand.isKill());
+
+ // Do not add the last instruction (the Root) to DelInstrs - it is removed later.
+ if (Index == NumLanes / 2 - 2) {
+ break;
+ }
+ DelInstrs.push_back(LoadInstr);
+ }
+ Register LastLoadReg1 = PrevReg;
+
+ // Create the final zip instruction to combine the results.
+ MachineInstrBuilder ZipInstr =
+ BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
+ Root.getOperand(0).getReg())
+ .addReg(LastLoadReg0)
+ .addReg(LastLoadReg1);
+ InsInstrs.push_back(ZipInstr);
+}
+
CombinerObjective
AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
switch (Pattern) {
case AArch64MachineCombinerPattern::SUBADD_OP1:
case AArch64MachineCombinerPattern::SUBADD_OP2:
+ case AArch64MachineCombinerPattern::GATHER_LANE_i32:
+ case AArch64MachineCombinerPattern::GATHER_LANE_i16:
+ case AArch64MachineCombinerPattern::GATHER_LANE_i8:
return CombinerObjective::MustReduceDepth;
default:
return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7423,6 +7852,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
if (getMiscPatterns(Root, Patterns))
return true;
+ // Load patterns
+ if (getLoadPatterns(Root, Patterns))
+ return true;
+
return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
DoRegPressureReduce);
}
@@ -8678,6 +9111,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
break;
}
+ case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
+ generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+ Pattern, 4);
+ break;
+ }
+ case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
+ generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+ Pattern, 8);
+ break;
+ }
+ case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
+ generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+ Pattern, 16);
+ break;
+ }
} // end switch (Pattern)
// Record MUL and ADD/SUB for deletion
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 7c255da..179574a 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
FMULv8i16_indexed_OP2,
FNMADD,
+
+ GATHER_LANE_i32,
+ GATHER_LANE_i16,
+ GATHER_LANE_i8
};
class AArch64InstrInfo final : public AArch64GenInstrInfo {
const AArch64RegisterInfo RI;
@@ -642,8 +646,10 @@ bool isNZCVTouchedInInstructionRange(const MachineInstr &DefMI,
MCCFIInstruction createDefCFA(const TargetRegisterInfo &TRI, unsigned FrameReg,
unsigned Reg, const StackOffset &Offset,
bool LastAdjustmentWasScalable = true);
-MCCFIInstruction createCFAOffset(const TargetRegisterInfo &MRI, unsigned Reg,
- const StackOffset &OffsetFromDefCFA);
+MCCFIInstruction
+createCFAOffset(const TargetRegisterInfo &MRI, unsigned Reg,
+ const StackOffset &OffsetFromDefCFA,
+ std::optional<int64_t> IncomingVGOffsetFromDefCFA);
/// emitFrameOffset - Emit instructions as needed to set DestReg to SrcReg
/// plus Offset. This is intended to be used from within the prolog/epilog
@@ -820,7 +826,8 @@ enum DestructiveInstType {
DestructiveBinaryComm = TSFLAG_DESTRUCTIVE_INST_TYPE(0x6),
DestructiveBinaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x7),
DestructiveTernaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x8),
- DestructiveUnaryPassthru = TSFLAG_DESTRUCTIVE_INST_TYPE(0x9),
+ Destructive2xRegImmUnpred = TSFLAG_DESTRUCTIVE_INST_TYPE(0x9),
+ DestructiveUnaryPassthru = TSFLAG_DESTRUCTIVE_INST_TYPE(0xa),
};
enum FalseLaneType {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ac31236..ce40e20 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -143,7 +143,7 @@ def HasFuseAES : Predicate<"Subtarget->hasFuseAES()">,
"fuse-aes">;
def HasSVE : Predicate<"Subtarget->isSVEAvailable()">,
AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
-def HasSVEB16B16 : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEB16B16()">,
+def HasSVEB16B16 : Predicate<"Subtarget->hasSVEB16B16()">,
AssemblerPredicateWithAll<(all_of FeatureSVEB16B16), "sve-b16b16">;
def HasSVE2 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2), "sve2">;
@@ -248,6 +248,10 @@ def HasSVE_or_SME
: Predicate<"Subtarget->isSVEorStreamingSVEAvailable()">,
AssemblerPredicateWithAll<(any_of FeatureSVE, FeatureSME),
"sve or sme">;
+def HasNonStreamingSVE_or_SME2
+ : Predicate<"Subtarget->isNonStreamingSVEorSME2Available()">,
+ AssemblerPredicateWithAll<(any_of FeatureSVE, FeatureSME2),
+ "sve or sme2">;
def HasNonStreamingSVE_or_SME2p1
: Predicate<"Subtarget->isSVEAvailable() ||"
"(Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSME2p1())">,
@@ -985,6 +989,17 @@ def AArch64fcvtxnv: PatFrags<(ops node:$Rn),
[(int_aarch64_neon_fcvtxn node:$Rn),
(AArch64fcvtxn_n node:$Rn)]>;
+def AArch64fcvtzs_half : SDNode<"AArch64ISD::FCVTZS_HALF", SDTFPExtendOp>;
+def AArch64fcvtzu_half : SDNode<"AArch64ISD::FCVTZU_HALF", SDTFPExtendOp>;
+def AArch64fcvtas_half : SDNode<"AArch64ISD::FCVTAS_HALF", SDTFPExtendOp>;
+def AArch64fcvtau_half : SDNode<"AArch64ISD::FCVTAU_HALF", SDTFPExtendOp>;
+def AArch64fcvtms_half : SDNode<"AArch64ISD::FCVTMS_HALF", SDTFPExtendOp>;
+def AArch64fcvtmu_half : SDNode<"AArch64ISD::FCVTMU_HALF", SDTFPExtendOp>;
+def AArch64fcvtns_half : SDNode<"AArch64ISD::FCVTNS_HALF", SDTFPExtendOp>;
+def AArch64fcvtnu_half : SDNode<"AArch64ISD::FCVTNU_HALF", SDTFPExtendOp>;
+def AArch64fcvtps_half : SDNode<"AArch64ISD::FCVTPS_HALF", SDTFPExtendOp>;
+def AArch64fcvtpu_half : SDNode<"AArch64ISD::FCVTPU_HALF", SDTFPExtendOp>;
+
//def Aarch64softf32tobf16v8: SDNode<"AArch64ISD::", SDTFPRoundOp>;
// Vector immediate ops
@@ -2151,7 +2166,7 @@ let Predicates = [HasPAuth] in {
i64imm:$Disc, GPR64:$AddrDisc),
[], "$AuthVal = $Val">, Sched<[WriteI, ReadI]> {
let isCodeGenOnly = 1;
- let hasSideEffects = 0;
+ let hasSideEffects = 1;
let mayStore = 0;
let mayLoad = 0;
let Size = 32;
@@ -2656,13 +2671,17 @@ defm ADD : AddSub<0, "add", "sub", add>;
defm SUB : AddSub<1, "sub", "add">;
def : InstAlias<"mov $dst, $src",
- (ADDWri GPR32sponly:$dst, GPR32sp:$src, 0, 0)>;
+ (ADDWri GPR32sponly:$dst, GPR32sp:$src,
+ (addsub_shifted_imm32 0, 0))>;
def : InstAlias<"mov $dst, $src",
- (ADDWri GPR32sp:$dst, GPR32sponly:$src, 0, 0)>;
+ (ADDWri GPR32sp:$dst, GPR32sponly:$src,
+ (addsub_shifted_imm32 0, 0))>;
def : InstAlias<"mov $dst, $src",
- (ADDXri GPR64sponly:$dst, GPR64sp:$src, 0, 0)>;
+ (ADDXri GPR64sponly:$dst, GPR64sp:$src,
+ (addsub_shifted_imm64 0, 0))>;
def : InstAlias<"mov $dst, $src",
- (ADDXri GPR64sp:$dst, GPR64sponly:$src, 0, 0)>;
+ (ADDXri GPR64sp:$dst, GPR64sponly:$src,
+ (addsub_shifted_imm64 0, 0))>;
defm ADDS : AddSubS<0, "adds", AArch64add_flag, "cmn", "subs", "cmp">;
defm SUBS : AddSubS<1, "subs", AArch64sub_flag, "cmp", "adds", "cmn">;
@@ -2722,19 +2741,31 @@ def : Pat<(AArch64sub_flag GPR64:$Rn, neg_addsub_shifted_imm64:$imm),
(ADDSXri GPR64:$Rn, neg_addsub_shifted_imm64:$imm)>;
}
-def : InstAlias<"neg $dst, $src", (SUBWrs GPR32:$dst, WZR, GPR32:$src, 0), 3>;
-def : InstAlias<"neg $dst, $src", (SUBXrs GPR64:$dst, XZR, GPR64:$src, 0), 3>;
+def : InstAlias<"neg $dst, $src",
+ (SUBWrs GPR32:$dst, WZR,
+ (arith_shifted_reg32 GPR32:$src, 0)), 3>;
+def : InstAlias<"neg $dst, $src",
+ (SUBXrs GPR64:$dst, XZR,
+ (arith_shifted_reg64 GPR64:$src, 0)), 3>;
def : InstAlias<"neg $dst, $src$shift",
- (SUBWrs GPR32:$dst, WZR, GPR32:$src, arith_shift32:$shift), 2>;
+ (SUBWrs GPR32:$dst, WZR,
+ (arith_shifted_reg32 GPR32:$src, arith_shift32:$shift)), 2>;
def : InstAlias<"neg $dst, $src$shift",
- (SUBXrs GPR64:$dst, XZR, GPR64:$src, arith_shift64:$shift), 2>;
-
-def : InstAlias<"negs $dst, $src", (SUBSWrs GPR32:$dst, WZR, GPR32:$src, 0), 3>;
-def : InstAlias<"negs $dst, $src", (SUBSXrs GPR64:$dst, XZR, GPR64:$src, 0), 3>;
+ (SUBXrs GPR64:$dst, XZR,
+ (arith_shifted_reg64 GPR64:$src, arith_shift64:$shift)), 2>;
+
+def : InstAlias<"negs $dst, $src",
+ (SUBSWrs GPR32:$dst, WZR,
+ (arith_shifted_reg32 GPR32:$src, 0)), 3>;
+def : InstAlias<"negs $dst, $src",
+ (SUBSXrs GPR64:$dst, XZR,
+ (arith_shifted_reg64 GPR64:$src, 0)), 3>;
def : InstAlias<"negs $dst, $src$shift",
- (SUBSWrs GPR32:$dst, WZR, GPR32:$src, arith_shift32:$shift), 2>;
+ (SUBSWrs GPR32:$dst, WZR,
+ (arith_shifted_reg32 GPR32:$src, arith_shift32:$shift)), 2>;
def : InstAlias<"negs $dst, $src$shift",
- (SUBSXrs GPR64:$dst, XZR, GPR64:$src, arith_shift64:$shift), 2>;
+ (SUBSXrs GPR64:$dst, XZR,
+ (arith_shifted_reg64 GPR64:$src, arith_shift64:$shift)), 2>;
// Unsigned/Signed divide
@@ -3161,16 +3192,26 @@ defm ORN : LogicalReg<0b01, 1, "orn",
BinOpFrag<(or node:$LHS, (not node:$RHS))>>;
defm ORR : LogicalReg<0b01, 0, "orr", or>;
-def : InstAlias<"mov $dst, $src", (ORRWrs GPR32:$dst, WZR, GPR32:$src, 0), 2>;
-def : InstAlias<"mov $dst, $src", (ORRXrs GPR64:$dst, XZR, GPR64:$src, 0), 2>;
-
-def : InstAlias<"mvn $Wd, $Wm", (ORNWrs GPR32:$Wd, WZR, GPR32:$Wm, 0), 3>;
-def : InstAlias<"mvn $Xd, $Xm", (ORNXrs GPR64:$Xd, XZR, GPR64:$Xm, 0), 3>;
+def : InstAlias<"mov $dst, $src",
+ (ORRWrs GPR32:$dst, WZR,
+ (logical_shifted_reg32 GPR32:$src, 0)), 2>;
+def : InstAlias<"mov $dst, $src",
+ (ORRXrs GPR64:$dst, XZR,
+ (logical_shifted_reg64 GPR64:$src, 0)), 2>;
+
+def : InstAlias<"mvn $Wd, $Wm",
+ (ORNWrs GPR32:$Wd, WZR,
+ (logical_shifted_reg32 GPR32:$Wm, 0)), 3>;
+def : InstAlias<"mvn $Xd, $Xm",
+ (ORNXrs GPR64:$Xd, XZR,
+ (logical_shifted_reg64 GPR64:$Xm, 0)), 3>;
def : InstAlias<"mvn $Wd, $Wm$sh",
- (ORNWrs GPR32:$Wd, WZR, GPR32:$Wm, logical_shift32:$sh), 2>;
+ (ORNWrs GPR32:$Wd, WZR,
+ (logical_shifted_reg32 GPR32:$Wm, logical_shift32:$sh)), 2>;
def : InstAlias<"mvn $Xd, $Xm$sh",
- (ORNXrs GPR64:$Xd, XZR, GPR64:$Xm, logical_shift64:$sh), 2>;
+ (ORNXrs GPR64:$Xd, XZR,
+ (logical_shifted_reg64 GPR64:$Xm, logical_shift64:$sh)), 2>;
def : InstAlias<"tst $src1, $src2",
(ANDSWri WZR, GPR32:$src1, logical_imm32:$src2), 2>;
@@ -3178,14 +3219,18 @@ def : InstAlias<"tst $src1, $src2",
(ANDSXri XZR, GPR64:$src1, logical_imm64:$src2), 2>;
def : InstAlias<"tst $src1, $src2",
- (ANDSWrs WZR, GPR32:$src1, GPR32:$src2, 0), 3>;
+ (ANDSWrs WZR, GPR32:$src1,
+ (logical_shifted_reg32 GPR32:$src2, 0)), 3>;
def : InstAlias<"tst $src1, $src2",
- (ANDSXrs XZR, GPR64:$src1, GPR64:$src2, 0), 3>;
+ (ANDSXrs XZR, GPR64:$src1,
+ (logical_shifted_reg64 GPR64:$src2, 0)), 3>;
def : InstAlias<"tst $src1, $src2$sh",
- (ANDSWrs WZR, GPR32:$src1, GPR32:$src2, logical_shift32:$sh), 2>;
+ (ANDSWrs WZR, GPR32:$src1,
+ (logical_shifted_reg32 GPR32:$src2, logical_shift32:$sh)), 2>;
def : InstAlias<"tst $src1, $src2$sh",
- (ANDSXrs XZR, GPR64:$src1, GPR64:$src2, logical_shift64:$sh), 2>;
+ (ANDSXrs XZR, GPR64:$src1,
+ (logical_shifted_reg64 GPR64:$src2, logical_shift64:$sh)), 2>;
def : Pat<(not GPR32:$Wm), (ORNWrr WZR, GPR32:$Wm)>;
@@ -5707,27 +5752,6 @@ let Predicates = [HasFullFP16] in {
// Advanced SIMD two vector instructions.
//===----------------------------------------------------------------------===//
-defm UABDL : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl", abdu>;
-// Match UABDL in log2-shuffle patterns.
-def : Pat<(abs (v8i16 (sub (zext (v8i8 V64:$opA)),
- (zext (v8i8 V64:$opB))))),
- (UABDLv8i8_v8i16 V64:$opA, V64:$opB)>;
-def : Pat<(abs (v8i16 (sub (zext (extract_high_v16i8 (v16i8 V128:$opA))),
- (zext (extract_high_v16i8 (v16i8 V128:$opB)))))),
- (UABDLv16i8_v8i16 V128:$opA, V128:$opB)>;
-def : Pat<(abs (v4i32 (sub (zext (v4i16 V64:$opA)),
- (zext (v4i16 V64:$opB))))),
- (UABDLv4i16_v4i32 V64:$opA, V64:$opB)>;
-def : Pat<(abs (v4i32 (sub (zext (extract_high_v8i16 (v8i16 V128:$opA))),
- (zext (extract_high_v8i16 (v8i16 V128:$opB)))))),
- (UABDLv8i16_v4i32 V128:$opA, V128:$opB)>;
-def : Pat<(abs (v2i64 (sub (zext (v2i32 V64:$opA)),
- (zext (v2i32 V64:$opB))))),
- (UABDLv2i32_v2i64 V64:$opA, V64:$opB)>;
-def : Pat<(abs (v2i64 (sub (zext (extract_high_v4i32 (v4i32 V128:$opA))),
- (zext (extract_high_v4i32 (v4i32 V128:$opB)))))),
- (UABDLv4i32_v2i64 V128:$opA, V128:$opB)>;
-
defm ABS : SIMDTwoVectorBHSD<0, 0b01011, "abs", abs>;
defm CLS : SIMDTwoVectorBHS<0, 0b00100, "cls", int_aarch64_neon_cls>;
defm CLZ : SIMDTwoVectorBHS<1, 0b00100, "clz", ctlz>;
@@ -6055,6 +6079,7 @@ defm MLA : SIMDThreeSameVectorBHSTied<0, 0b10010, "mla", null_frag>;
defm MLS : SIMDThreeSameVectorBHSTied<1, 0b10010, "mls", null_frag>;
defm MUL : SIMDThreeSameVectorBHS<0, 0b10011, "mul", mul>;
+let isCommutable = 1 in
defm PMUL : SIMDThreeSameVectorB<1, 0b10011, "pmul", int_aarch64_neon_pmul>;
defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba",
TriOpFrag<(add node:$LHS, (abds node:$MHS, node:$RHS))> >;
@@ -6552,9 +6577,33 @@ defm UQXTN : SIMDTwoScalarMixedBHS<1, 0b10100, "uqxtn", int_aarch64_neon_scalar
defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd",
int_aarch64_neon_usqadd>;
+// f16 -> s16 conversions
+let Predicates = [HasFullFP16] in {
+ def : Pat<(i16(fp_to_sint_sat_gi f16:$Rn)), (FCVTZSv1f16 f16:$Rn)>;
+ def : Pat<(i16(fp_to_uint_sat_gi f16:$Rn)), (FCVTZUv1f16 f16:$Rn)>;
+}
+
def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))),
(CMLTv1i64rz V64:$Rn)>;
+// f16 -> i16 conversions leave the bit pattern in an f32
+class F16ToI16ScalarPat<SDNode cvt_isd, BaseSIMDTwoScalar instr>
+ : Pat<(f32 (cvt_isd (f16 FPR16:$Rn))),
+ (f32 (SUBREG_TO_REG (i64 0), (instr FPR16:$Rn), hsub))>;
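+// (The FCVT*v1f16 instructions produce their 16-bit integer result in an
+// FPR16 register; SUBREG_TO_REG re-interprets it as the low 16 bits of an
+// FPR32, so the i16 bit pattern comes back as the f32 these nodes produce.)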
+
+let Predicates = [HasFullFP16] in {
+def : F16ToI16ScalarPat<AArch64fcvtzs_half, FCVTZSv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtzu_half, FCVTZUv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtas_half, FCVTASv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtau_half, FCVTAUv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtms_half, FCVTMSv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtmu_half, FCVTMUv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtns_half, FCVTNSv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtnu_half, FCVTNUv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtps_half, FCVTPSv1f16>;
+def : F16ToI16ScalarPat<AArch64fcvtpu_half, FCVTPUv1f16>;
+}
+
// Round FP64 to BF16.
let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in
def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))),
@@ -6802,40 +6851,47 @@ def : Pat <(f64 (uint_to_fp (i32
// Advanced SIMD three different-sized vector instructions.
//===----------------------------------------------------------------------===//
-defm ADDHN : SIMDNarrowThreeVectorBHS<0,0b0100,"addhn", int_aarch64_neon_addhn>;
-defm SUBHN : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn>;
-defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>;
-defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>;
-defm PMULL : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>;
-defm SABAL : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal", abds>;
-defm SABDL : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl", abds>;
+defm ADDHN : SIMDNarrowThreeVectorBHS<0,0b0100,"addhn", int_aarch64_neon_addhn>;
+defm SUBHN : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn>;
+defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>;
+defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>;
+let isCommutable = 1 in
+defm PMULL : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>;
+defm SABAL : SIMDLongThreeVectorTiedBHS<0,0b0101,"sabal",
+ TriOpFrag<(add node:$LHS, (zext (abds node:$MHS, node:$RHS)))>>;
+defm SABDL : SIMDLongThreeVectorBHS<0, 0b0111, "sabdl",
+ BinOpFrag<(zext (abds node:$LHS, node:$RHS))>>;
defm SADDL : SIMDLongThreeVectorBHS< 0, 0b0000, "saddl",
- BinOpFrag<(add (sext node:$LHS), (sext node:$RHS))>>;
+ BinOpFrag<(add (sext node:$LHS), (sext node:$RHS))>>;
defm SADDW : SIMDWideThreeVectorBHS< 0, 0b0001, "saddw",
BinOpFrag<(add node:$LHS, (sext node:$RHS))>>;
defm SMLAL : SIMDLongThreeVectorTiedBHS<0, 0b1000, "smlal",
- TriOpFrag<(add node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>;
+ TriOpFrag<(add node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>;
defm SMLSL : SIMDLongThreeVectorTiedBHS<0, 0b1010, "smlsl",
- TriOpFrag<(sub node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>;
+ TriOpFrag<(sub node:$LHS, (AArch64smull node:$MHS, node:$RHS))>>;
defm SMULL : SIMDLongThreeVectorBHS<0, 0b1100, "smull", AArch64smull>;
defm SQDMLAL : SIMDLongThreeVectorSQDMLXTiedHS<0, 0b1001, "sqdmlal", saddsat>;
defm SQDMLSL : SIMDLongThreeVectorSQDMLXTiedHS<0, 0b1011, "sqdmlsl", ssubsat>;
-defm SQDMULL : SIMDLongThreeVectorHS<0, 0b1101, "sqdmull",
- int_aarch64_neon_sqdmull>;
+defm SQDMULL : SIMDLongThreeVectorHS<0, 0b1101, "sqdmull", int_aarch64_neon_sqdmull>;
+let isCommutable = 0 in
defm SSUBL : SIMDLongThreeVectorBHS<0, 0b0010, "ssubl",
BinOpFrag<(sub (sext node:$LHS), (sext node:$RHS))>>;
defm SSUBW : SIMDWideThreeVectorBHS<0, 0b0011, "ssubw",
BinOpFrag<(sub node:$LHS, (sext node:$RHS))>>;
-defm UABAL : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal", abdu>;
+defm UABAL : SIMDLongThreeVectorTiedBHS<1, 0b0101, "uabal",
+ TriOpFrag<(add node:$LHS, (zext (abdu node:$MHS, node:$RHS)))>>;
+defm UABDL : SIMDLongThreeVectorBHS<1, 0b0111, "uabdl",
+ BinOpFrag<(zext (abdu node:$LHS, node:$RHS))>>;
defm UADDL : SIMDLongThreeVectorBHS<1, 0b0000, "uaddl",
BinOpFrag<(add (zanyext node:$LHS), (zanyext node:$RHS))>>;
defm UADDW : SIMDWideThreeVectorBHS<1, 0b0001, "uaddw",
BinOpFrag<(add node:$LHS, (zanyext node:$RHS))>>;
defm UMLAL : SIMDLongThreeVectorTiedBHS<1, 0b1000, "umlal",
- TriOpFrag<(add node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>;
+ TriOpFrag<(add node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>;
defm UMLSL : SIMDLongThreeVectorTiedBHS<1, 0b1010, "umlsl",
- TriOpFrag<(sub node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>;
+ TriOpFrag<(sub node:$LHS, (AArch64umull node:$MHS, node:$RHS))>>;
defm UMULL : SIMDLongThreeVectorBHS<1, 0b1100, "umull", AArch64umull>;
+let isCommutable = 0 in
defm USUBL : SIMDLongThreeVectorBHS<1, 0b0010, "usubl",
BinOpFrag<(sub (zanyext node:$LHS), (zanyext node:$RHS))>>;
defm USUBW : SIMDWideThreeVectorBHS< 1, 0b0011, "usubw",
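For reference, the rewritten SABAL/UABAL and SABDL/UABDL patterns above now build directly on the generic abds/abdu nodes with an explicit zext to the wide element type. A minimal scalar sketch of one signed lane (illustrative only; this helper is not part of LLVM) shows why that zext is lossless:

#include <cstdint>

// Scalar model of one SABAL lane: acc + zext(abds(a, b)). The absolute
// difference of two i16 values always fits in 16 unsigned bits, so
// zero-extending it to the wide element type preserves the value.
static int32_t sabal_lane(int32_t acc, int16_t a, int16_t b) {
  int32_t d = int32_t(a) - int32_t(b);         // widen first so INT16_MIN is safe
  uint16_t absdiff = uint16_t(d < 0 ? -d : d); // abds
  return acc + int32_t(absdiff);               // add of the zero-extended difference
}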
diff --git a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp
index 782d62a7..e69fa32 100644
--- a/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp
+++ b/llvm/lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp
@@ -1193,7 +1193,8 @@ AArch64LoadStoreOpt::mergePairedInsns(MachineBasicBlock::iterator I,
// USE kill %w1 ; need to clear kill flag when moving STRWui downwards
// STRW %w0
Register Reg = getLdStRegOp(*I).getReg();
- for (MachineInstr &MI : make_range(std::next(I), Paired))
+ for (MachineInstr &MI :
+ make_range(std::next(I->getIterator()), Paired->getIterator()))
MI.clearRegisterKills(Reg, TRI);
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
index b97d622..fd4ef2a 100644
--- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
@@ -8,8 +8,8 @@
//
// This pass performs the following peephole optimizations at the MIR level.
//
-// 1. MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri
-// MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri
+// 1. MOVi32imm + (ANDS?|EOR|ORR)Wrr ==> (AND|EOR|ORR)Wri + (ANDS?|EOR|ORR)Wri
+// MOVi64imm + (ANDS?|EOR|ORR)Xrr ==> (AND|EOR|ORR)Xri + (ANDS?|EOR|ORR)Xri
//
// 2. MOVi32imm + ADDWrr ==> ADDWRi + ADDWRi
// MOVi64imm + ADDXrr ==> ADDXri + ADDXri
@@ -128,6 +128,7 @@ struct AArch64MIPeepholeOpt : public MachineFunctionPass {
// Strategy used to split logical immediate bitmasks.
enum class SplitStrategy {
Intersect,
+ Disjoint,
};
template <typename T>
bool trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
@@ -163,6 +164,7 @@ INITIALIZE_PASS(AArch64MIPeepholeOpt, "aarch64-mi-peephole-opt",
template <typename T>
static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) {
T UImm = static_cast<T>(Imm);
+ assert(UImm && (UImm != ~static_cast<T>(0)) && "Invalid immediate!");
// The bitmask immediate consists of consecutive ones. Let's say there is
// constant 0b00000000001000000000010000000000 which does not consist of
@@ -191,18 +193,47 @@ static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) {
}
template <typename T>
+static bool splitDisjointBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc,
+ T &Imm2Enc) {
+ assert(Imm && (Imm != ~static_cast<T>(0)) && "Invalid immediate!");
+
+ // Try to split a bitmask of the form 0b00000000011000000000011110000000 into
+ // two disjoint masks such as 0b00000000011000000000000000000000 and
+ // 0b00000000000000000000011110000000 where the inclusive/exclusive OR of the
+ // new masks matches the original mask.
+ unsigned LowestBitSet = llvm::countr_zero(Imm);
+ unsigned LowestGapBitUnset =
+ LowestBitSet + llvm::countr_one(Imm >> LowestBitSet);
+
+ // Create a mask for the least significant group of consecutive ones.
+ assert(LowestGapBitUnset < sizeof(T) * CHAR_BIT && "Undefined behaviour!");
+ T NewImm1 = (static_cast<T>(1) << LowestGapBitUnset) -
+ (static_cast<T>(1) << LowestBitSet);
+ // Create a disjoint mask for the remaining ones.
+ T NewImm2 = Imm & ~NewImm1;
+
+ // Do not split if NewImm2 is not a valid bitmask immediate.
+ if (!AArch64_AM::isLogicalImmediate(NewImm2, RegSize))
+ return false;
+
+ Imm1Enc = AArch64_AM::encodeLogicalImmediate(NewImm1, RegSize);
+ Imm2Enc = AArch64_AM::encodeLogicalImmediate(NewImm2, RegSize);
+ return true;
+}
+
+template <typename T>
bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
SplitStrategy Strategy,
unsigned OtherOpc) {
- // Try below transformation.
+ // Try the following transformations.
//
- // MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri
- // MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri
+ // MOVi32imm + (ANDS?|EOR|ORR)Wrr ==> (AND|EOR|ORR)Wri + (ANDS?|EOR|ORR)Wri
+ // MOVi64imm + (ANDS?|EOR|ORR)Xrr ==> (AND|EOR|ORR)Xri + (ANDS?|EOR|ORR)Xri
//
// The mov pseudo instruction could be expanded to multiple mov instructions
// later. Let's try to split the constant operand of mov instruction into two
- // bitmask immediates. It makes only two AND instructions instead of multiple
- // mov + and instructions.
+ // bitmask immediates based on the given split strategy. This needs only two
+ // logical instructions instead of multiple mov + logical instructions.
return splitTwoPartImm<T>(
MI,
@@ -224,6 +255,9 @@ bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
case SplitStrategy::Intersect:
SplitSucc = splitBitmaskImm(Imm, RegSize, Imm0, Imm1);
break;
+ case SplitStrategy::Disjoint:
+ SplitSucc = splitDisjointBitmaskImm(Imm, RegSize, Imm0, Imm1);
+ break;
}
if (SplitSucc)
return std::make_pair(Opc, !OtherOpc ? Opc : OtherOpc);
@@ -889,6 +923,22 @@ bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
Changed |= trySplitLogicalImm<uint64_t>(
AArch64::ANDXri, MI, SplitStrategy::Intersect, AArch64::ANDSXri);
break;
+ case AArch64::EORWrr:
+ Changed |= trySplitLogicalImm<uint32_t>(AArch64::EORWri, MI,
+ SplitStrategy::Disjoint);
+ break;
+ case AArch64::EORXrr:
+ Changed |= trySplitLogicalImm<uint64_t>(AArch64::EORXri, MI,
+ SplitStrategy::Disjoint);
+ break;
+ case AArch64::ORRWrr:
+ Changed |= trySplitLogicalImm<uint32_t>(AArch64::ORRWri, MI,
+ SplitStrategy::Disjoint);
+ break;
+ case AArch64::ORRXrr:
+ Changed |= trySplitLogicalImm<uint64_t>(AArch64::ORRXri, MI,
+ SplitStrategy::Disjoint);
+ break;
case AArch64::ORRWrs:
Changed |= visitORR(MI);
break;
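For reference, the SplitStrategy::Disjoint case added above can be sketched as a standalone 64-bit function (simplified; the real pass is templated over the register size and additionally rejects the split when the second mask is not a valid AArch64 logical immediate):

#include <bit>      // C++20 countr_zero/countr_one
#include <cassert>
#include <cstdint>

struct SplitMasks { uint64_t Lo, Hi; };

// Peel off the least significant run of contiguous ones as one mask and keep
// the remaining bits as a second, disjoint mask. Lo | Hi == Imm and
// Lo & Hi == 0, so either ORR or EOR with the second mask rebuilds Imm, which
// is why the same strategy serves both the ORR*rr and EOR*rr peepholes.
static SplitMasks splitDisjoint(uint64_t Imm) {
  assert(Imm != 0 && Imm != ~0ULL && "Invalid immediate!");
  unsigned LowestBitSet = std::countr_zero(Imm);
  unsigned RunEnd = LowestBitSet + std::countr_one(Imm >> LowestBitSet);
  assert(RunEnd < 64 && "lowest run must not reach the top bit");
  uint64_t Lo = (1ULL << RunEnd) - (1ULL << LowestBitSet);
  return {Lo, Imm & ~Lo};
}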
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index 800787c..1fde87e 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -213,9 +213,6 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
/// or return type
bool IsSVECC = false;
- /// The frame-index for the TPIDR2 object used for lazy saves.
- TPIDR2Object TPIDR2;
-
/// Whether this function changes streaming mode within the function.
bool HasStreamingModeChanges = false;
@@ -231,25 +228,26 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;
- // Holds a pointer to a buffer that is large enough to represent
- // all SME ZA state and any additional state required by the
- // __arm_sme_save/restore support routines.
- Register SMESaveBufferAddr = MCRegister::NoRegister;
-
- // true if SMESaveBufferAddr is used.
- bool SMESaveBufferUsed = false;
+ // true if PStateSMReg is used.
+ bool PStateSMRegUsed = false;
// Has the PNReg used to build PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;
- // The stack slots where VG values are stored to.
- int64_t VGIdx = std::numeric_limits<int>::max();
- int64_t StreamingVGIdx = std::numeric_limits<int>::max();
-
// Holds the SME function attributes (streaming mode, ZA/ZT0 state).
SMEAttrs SMEFnAttrs;
+ // Note: The following properties are only used for the old SME ABI lowering:
+ /// The frame-index for the TPIDR2 object used for lazy saves.
+ TPIDR2Object TPIDR2;
+ // Holds a pointer to a buffer that is large enough to represent
+ // all SME ZA state and any additional state required by the
+ // __arm_sme_save/restore support routines.
+ Register SMESaveBufferAddr = MCRegister::NoRegister;
+ // true if SMESaveBufferAddr is used.
+ bool SMESaveBufferUsed = false;
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -258,6 +256,13 @@ public:
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
const override;
+ // Old SME ABI lowering state getters/setters:
+ Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
+ void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
+ unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; };
+ void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
+ TPIDR2Object &getTPIDR2Obj() { return TPIDR2; }
+
void setPredicateRegForFillSpill(unsigned Reg) {
PredicateRegForFillSpill = Reg;
}
@@ -265,26 +270,15 @@ public:
return PredicateRegForFillSpill;
}
- Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
- void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
-
- unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; };
- void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
-
Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
- int64_t getVGIdx() const { return VGIdx; };
- void setVGIdx(unsigned Idx) { VGIdx = Idx; };
-
- int64_t getStreamingVGIdx() const { return StreamingVGIdx; };
- void setStreamingVGIdx(unsigned FrameIdx) { StreamingVGIdx = FrameIdx; };
+ unsigned isPStateSMRegUsed() const { return PStateSMRegUsed; };
+ void setPStateSMRegUsed(bool Used = true) { PStateSMRegUsed = Used; };
bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };
- TPIDR2Object &getTPIDR2Obj() { return TPIDR2; }
-
void initializeBaseYamlFields(const yaml::AArch64FunctionInfo &YamlMFI);
unsigned getBytesInStackArgArea() const { return BytesInStackArgArea; }
diff --git a/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp b/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp
index ff7a0d1..f4a7f77 100644
--- a/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MacroFusion.cpp
@@ -237,8 +237,8 @@ static bool isAddressLdStPair(const MachineInstr *FirstMI,
}
/// Compare and conditional select.
-static bool isCCSelectPair(const MachineInstr *FirstMI,
- const MachineInstr &SecondMI) {
+static bool isCmpCSelPair(const MachineInstr *FirstMI,
+ const MachineInstr &SecondMI) {
// 32 bits
if (SecondMI.getOpcode() == AArch64::CSELWr) {
// Assume the 1st instr to be a wildcard if it is unspecified.
@@ -279,6 +279,40 @@ static bool isCCSelectPair(const MachineInstr *FirstMI,
return false;
}
+/// Compare and cset.
+static bool isCmpCSetPair(const MachineInstr *FirstMI,
+ const MachineInstr &SecondMI) {
+ if ((SecondMI.getOpcode() == AArch64::CSINCWr &&
+ SecondMI.getOperand(1).getReg() == AArch64::WZR &&
+ SecondMI.getOperand(2).getReg() == AArch64::WZR) ||
+ (SecondMI.getOpcode() == AArch64::CSINCXr &&
+ SecondMI.getOperand(1).getReg() == AArch64::XZR &&
+ SecondMI.getOperand(2).getReg() == AArch64::XZR)) {
+ // Assume the 1st instr to be a wildcard if it is unspecified.
+ if (FirstMI == nullptr)
+ return true;
+
+ if (FirstMI->definesRegister(AArch64::WZR, /*TRI=*/nullptr) ||
+ FirstMI->definesRegister(AArch64::XZR, /*TRI=*/nullptr))
+ switch (FirstMI->getOpcode()) {
+ case AArch64::SUBSWrs:
+ case AArch64::SUBSXrs:
+ return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
+ case AArch64::SUBSWrx:
+ case AArch64::SUBSXrx:
+ case AArch64::SUBSXrx64:
+ return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
+ case AArch64::SUBSWri:
+ case AArch64::SUBSWrr:
+ case AArch64::SUBSXri:
+ case AArch64::SUBSXrr:
+ return true;
+ }
+ }
+
+ return false;
+}
+
// Arithmetic and logic.
static bool isArithmeticLogicPair(const MachineInstr *FirstMI,
const MachineInstr &SecondMI) {
@@ -465,7 +499,9 @@ static bool shouldScheduleAdjacent(const TargetInstrInfo &TII,
return true;
if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI))
return true;
- if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI))
+ if (ST.hasFuseCmpCSel() && isCmpCSelPair(FirstMI, SecondMI))
+ return true;
+ if (ST.hasFuseCmpCSet() && isCmpCSetPair(FirstMI, SecondMI))
return true;
if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI))
return true;
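The new fusion check above matches the CSINC Wd, WZR, WZR / CSINC Xd, XZR, XZR forms, which are what the cset alias expands to, immediately after a flag-setting SUBS (i.e. a cmp). A small source-level illustration of code that typically lowers to such a pair:

// Typically compiles to:
//   cmp  w0, w1        // SUBSWrr with the result discarded (wzr destination)
//   cset w0, lt        // alias of csinc w0, wzr, wzr, ge
// FeatureFuseCmpCSet asks the scheduler to keep the two instructions adjacent.
bool isLess(int a, int b) { return a < b; }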
diff --git a/llvm/lib/Target/AArch64/AArch64Processors.td b/llvm/lib/Target/AArch64/AArch64Processors.td
index adc984a..d5f4e91 100644
--- a/llvm/lib/Target/AArch64/AArch64Processors.td
+++ b/llvm/lib/Target/AArch64/AArch64Processors.td
@@ -22,7 +22,8 @@ def TuneA320 : SubtargetFeature<"a320", "ARMProcFamily", "CortexA320",
FeatureFuseAES,
FeatureFuseAdrpAdd,
FeaturePostRAScheduler,
- FeatureUseWzrToVecMove]>;
+ FeatureUseWzrToVecMove,
+ FeatureUseFixedOverScalableIfEqualCost]>;
def TuneA53 : SubtargetFeature<"a53", "ARMProcFamily", "CortexA53",
"Cortex-A53 ARM processors", [
@@ -45,7 +46,8 @@ def TuneA510 : SubtargetFeature<"a510", "ARMProcFamily", "CortexA510",
FeatureFuseAES,
FeatureFuseAdrpAdd,
FeaturePostRAScheduler,
- FeatureUseWzrToVecMove
+ FeatureUseWzrToVecMove,
+ FeatureUseFixedOverScalableIfEqualCost
]>;
def TuneA520 : SubtargetFeature<"a520", "ARMProcFamily", "CortexA520",
@@ -53,7 +55,8 @@ def TuneA520 : SubtargetFeature<"a520", "ARMProcFamily", "CortexA520",
FeatureFuseAES,
FeatureFuseAdrpAdd,
FeaturePostRAScheduler,
- FeatureUseWzrToVecMove]>;
+ FeatureUseWzrToVecMove,
+ FeatureUseFixedOverScalableIfEqualCost]>;
def TuneA520AE : SubtargetFeature<"a520ae", "ARMProcFamily", "CortexA520",
"Cortex-A520AE ARM processors", [
@@ -131,6 +134,8 @@ def TuneA78 : SubtargetFeature<"a78", "ARMProcFamily", "CortexA78",
FeatureCmpBccFusion,
FeatureFuseAES,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureAddrLSLSlow14,
FeatureALULSLFast,
FeaturePostRAScheduler,
@@ -143,6 +148,8 @@ def TuneA78AE : SubtargetFeature<"a78ae", "ARMProcFamily",
FeatureCmpBccFusion,
FeatureFuseAES,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureAddrLSLSlow14,
FeatureALULSLFast,
FeaturePostRAScheduler,
@@ -155,6 +162,8 @@ def TuneA78C : SubtargetFeature<"a78c", "ARMProcFamily",
FeatureCmpBccFusion,
FeatureFuseAES,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureAddrLSLSlow14,
FeatureALULSLFast,
FeaturePostRAScheduler,
@@ -166,6 +175,8 @@ def TuneA710 : SubtargetFeature<"a710", "ARMProcFamily", "CortexA710",
FeatureCmpBccFusion,
FeatureFuseAES,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureALULSLFast,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
@@ -178,6 +189,8 @@ def TuneA715 : SubtargetFeature<"a715", "ARMProcFamily", "CortexA715",
FeatureCmpBccFusion,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureEnableSelectOptimize,
FeaturePredictableSelectIsExpensive]>;
@@ -188,6 +201,8 @@ def TuneA720 : SubtargetFeature<"a720", "ARMProcFamily", "CortexA720",
FeatureCmpBccFusion,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureEnableSelectOptimize,
FeaturePredictableSelectIsExpensive]>;
@@ -198,6 +213,8 @@ def TuneA720AE : SubtargetFeature<"a720ae", "ARMProcFamily", "CortexA720",
FeatureCmpBccFusion,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureEnableSelectOptimize,
FeaturePredictableSelectIsExpensive]>;
@@ -209,6 +226,8 @@ def TuneA725 : SubtargetFeature<"cortex-a725", "ARMProcFamily",
FeatureCmpBccFusion,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureEnableSelectOptimize,
FeaturePredictableSelectIsExpensive]>;
@@ -259,6 +278,8 @@ def TuneX4 : SubtargetFeature<"cortex-x4", "ARMProcFamily", "CortexX4",
"Cortex-X4 ARM processors", [
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureFuseAES,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
@@ -270,6 +291,8 @@ def TuneX925 : SubtargetFeature<"cortex-x925", "ARMProcFamily",
"CortexX925", "Cortex-X925 ARM processors",[
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureFuseAES,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
@@ -318,8 +341,9 @@ def TuneAppleA7 : SubtargetFeature<"apple-a7", "ARMProcFamily", "AppleA7",
FeatureFuseAES, FeatureFuseCryptoEOR,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing,
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64,
FeatureZCZeroingFPWorkaround]>;
def TuneAppleA10 : SubtargetFeature<"apple-a10", "ARMProcFamily", "AppleA10",
@@ -332,8 +356,9 @@ def TuneAppleA10 : SubtargetFeature<"apple-a10", "ARMProcFamily", "AppleA10",
FeatureFuseCryptoEOR,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA11 : SubtargetFeature<"apple-a11", "ARMProcFamily", "AppleA11",
"Apple A11", [
@@ -345,8 +370,9 @@ def TuneAppleA11 : SubtargetFeature<"apple-a11", "ARMProcFamily", "AppleA11",
FeatureFuseCryptoEOR,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA12 : SubtargetFeature<"apple-a12", "ARMProcFamily", "AppleA12",
"Apple A12", [
@@ -358,8 +384,9 @@ def TuneAppleA12 : SubtargetFeature<"apple-a12", "ARMProcFamily", "AppleA12",
FeatureFuseCryptoEOR,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA13 : SubtargetFeature<"apple-a13", "ARMProcFamily", "AppleA13",
"Apple A13", [
@@ -371,8 +398,9 @@ def TuneAppleA13 : SubtargetFeature<"apple-a13", "ARMProcFamily", "AppleA13",
FeatureFuseCryptoEOR,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA14 : SubtargetFeature<"apple-a14", "ARMProcFamily", "AppleA14",
"Apple A14", [
@@ -384,13 +412,14 @@ def TuneAppleA14 : SubtargetFeature<"apple-a14", "ARMProcFamily", "AppleA14",
FeatureFuseAddress,
FeatureFuseAES,
FeatureFuseArithmeticLogic,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseCryptoEOR,
FeatureFuseLiterals,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA15 : SubtargetFeature<"apple-a15", "ARMProcFamily", "AppleA15",
"Apple A15", [
@@ -402,13 +431,14 @@ def TuneAppleA15 : SubtargetFeature<"apple-a15", "ARMProcFamily", "AppleA15",
FeatureFuseAdrpAdd,
FeatureFuseAES,
FeatureFuseArithmeticLogic,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseCryptoEOR,
FeatureFuseLiterals,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA16 : SubtargetFeature<"apple-a16", "ARMProcFamily", "AppleA16",
"Apple A16", [
@@ -420,13 +450,14 @@ def TuneAppleA16 : SubtargetFeature<"apple-a16", "ARMProcFamily", "AppleA16",
FeatureFuseAdrpAdd,
FeatureFuseAES,
FeatureFuseArithmeticLogic,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseCryptoEOR,
FeatureFuseLiterals,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleA17 : SubtargetFeature<"apple-a17", "ARMProcFamily", "AppleA17",
"Apple A17", [
@@ -438,13 +469,14 @@ def TuneAppleA17 : SubtargetFeature<"apple-a17", "ARMProcFamily", "AppleA17",
FeatureFuseAdrpAdd,
FeatureFuseAES,
FeatureFuseArithmeticLogic,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseCryptoEOR,
FeatureFuseLiterals,
FeatureStorePairSuppress,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneAppleM4 : SubtargetFeature<"apple-m4", "ARMProcFamily", "AppleM4",
"Apple M4", [
@@ -456,13 +488,13 @@ def TuneAppleM4 : SubtargetFeature<"apple-m4", "ARMProcFamily", "AppleM4",
FeatureFuseAdrpAdd,
FeatureFuseAES,
FeatureFuseArithmeticLogic,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseCryptoEOR,
FeatureFuseLiterals,
FeatureZCRegMoveGPR64,
- FeatureZCRegMoveFPR64,
- FeatureZCZeroing
- ]>;
+ FeatureZCRegMoveFPR128,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneExynosM3 : SubtargetFeature<"exynosm3", "ARMProcFamily", "ExynosM3",
"Samsung Exynos-M3 processors",
@@ -470,7 +502,7 @@ def TuneExynosM3 : SubtargetFeature<"exynosm3", "ARMProcFamily", "ExynosM3",
FeatureForce32BitJumpTables,
FeatureFuseAddress,
FeatureFuseAES,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseAdrpAdd,
FeatureFuseLiterals,
FeatureStorePairSuppress,
@@ -488,19 +520,21 @@ def TuneExynosM4 : SubtargetFeature<"exynosm4", "ARMProcFamily", "ExynosM3",
FeatureFuseAddress,
FeatureFuseAES,
FeatureFuseArithmeticLogic,
- FeatureFuseCCSelect,
+ FeatureFuseCmpCSel,
FeatureFuseAdrpAdd,
FeatureFuseLiterals,
FeatureStorePairSuppress,
FeatureALULSLFast,
FeaturePostRAScheduler,
- FeatureZCZeroing]>;
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64]>;
def TuneKryo : SubtargetFeature<"kryo", "ARMProcFamily", "Kryo",
"Qualcomm Kryo processors", [
FeaturePostRAScheduler,
FeaturePredictableSelectIsExpensive,
- FeatureZCZeroing,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64,
FeatureALULSLFast,
FeatureStorePairSuppress]>;
@@ -508,7 +542,8 @@ def TuneFalkor : SubtargetFeature<"falkor", "ARMProcFamily", "Falkor",
"Qualcomm Falkor processors", [
FeaturePostRAScheduler,
FeaturePredictableSelectIsExpensive,
- FeatureZCZeroing,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64,
FeatureStorePairSuppress,
FeatureALULSLFast,
FeatureSlowSTRQro]>;
@@ -533,6 +568,8 @@ def TuneNeoverseN2 : SubtargetFeature<"neoversen2", "ARMProcFamily", "NeoverseN2
"Neoverse N2 ARM processors", [
FeatureFuseAES,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureALULSLFast,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
@@ -544,6 +581,8 @@ def TuneNeoverseN3 : SubtargetFeature<"neoversen3", "ARMProcFamily", "NeoverseN3
FeaturePostRAScheduler,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureEnableSelectOptimize,
FeaturePredictableSelectIsExpensive]>;
@@ -560,6 +599,8 @@ def TuneNeoverseV1 : SubtargetFeature<"neoversev1", "ARMProcFamily", "NeoverseV1
"Neoverse V1 ARM processors", [
FeatureFuseAES,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureAddrLSLSlow14,
FeatureALULSLFast,
FeaturePostRAScheduler,
@@ -572,6 +613,8 @@ def TuneNeoverseV2 : SubtargetFeature<"neoversev2", "ARMProcFamily", "NeoverseV2
FeatureFuseAES,
FeatureCmpBccFusion,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeatureALULSLFast,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
@@ -585,6 +628,8 @@ def TuneNeoverseV3 : SubtargetFeature<"neoversev3", "ARMProcFamily", "NeoverseV3
FeatureFuseAES,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
FeatureAvoidLDAPUR,
@@ -595,6 +640,8 @@ def TuneNeoverseV3AE : SubtargetFeature<"neoversev3AE", "ARMProcFamily", "Neover
FeatureFuseAES,
FeatureALULSLFast,
FeatureFuseAdrpAdd,
+ FeatureFuseCmpCSel,
+ FeatureFuseCmpCSet,
FeaturePostRAScheduler,
FeatureEnableSelectOptimize,
FeatureAvoidLDAPUR,
@@ -604,7 +651,8 @@ def TuneSaphira : SubtargetFeature<"saphira", "ARMProcFamily", "Saphira",
"Qualcomm Saphira processors", [
FeaturePostRAScheduler,
FeaturePredictableSelectIsExpensive,
- FeatureZCZeroing,
+ FeatureZCZeroingGPR32,
+ FeatureZCZeroingGPR64,
FeatureStorePairSuppress,
FeatureALULSLFast]>;
@@ -756,7 +804,6 @@ def ProcessorFeatures {
FeatureSB, FeaturePAuth, FeatureSSBS, FeatureSVE, FeatureSVE2,
FeatureComplxNum, FeatureCRC, FeatureDotProd,
FeatureFPARMv8,FeatureFullFP16, FeatureJS, FeatureLSE,
- FeatureUseFixedOverScalableIfEqualCost,
FeatureRAS, FeatureRCPC, FeatureRDM, FeatureFPAC];
list<SubtargetFeature> A520 = [HasV9_2aOps, FeaturePerfMon, FeatureAM,
FeatureMTE, FeatureETE, FeatureSVEBitPerm,
@@ -766,7 +813,6 @@ def ProcessorFeatures {
FeatureSVE, FeatureSVE2, FeatureBF16, FeatureComplxNum, FeatureCRC,
FeatureFPARMv8, FeatureFullFP16, FeatureMatMulInt8, FeatureJS,
FeatureNEON, FeatureLSE, FeatureRAS, FeatureRCPC, FeatureRDM,
- FeatureUseFixedOverScalableIfEqualCost,
FeatureDotProd, FeatureFPAC];
list<SubtargetFeature> A520AE = [HasV9_2aOps, FeaturePerfMon, FeatureAM,
FeatureMTE, FeatureETE, FeatureSVEBitPerm,
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index db27ca9..0d8cb3a 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -39,11 +39,18 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
def AArch64CoalescerBarrier
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;
-def AArch64VGSave : SDNode<"AArch64ISD::VG_SAVE", SDTypeProfile<0, 0, []>,
- [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64EntryPStateSM
+ : SDNode<"AArch64ISD::ENTRY_PSTATE_SM", SDTypeProfile<1, 0,
+ [SDTCisInt<0>]>, [SDNPHasChain, SDNPSideEffect]>;
-def AArch64VGRestore : SDNode<"AArch64ISD::VG_RESTORE", SDTypeProfile<0, 0, []>,
- [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;
+let usesCustomInserter = 1 in {
+ def EntryPStateSM : Pseudo<(outs GPR64:$is_streaming), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 (AArch64EntryPStateSM)), (EntryPStateSM)>;
+
+//===----------------------------------------------------------------------===//
+// Old SME ABI lowering ISD nodes/pseudos (deprecated)
+//===----------------------------------------------------------------------===//
def AArch64AllocateZABuffer : SDNode<"AArch64ISD::ALLOCATE_ZA_BUFFER", SDTypeProfile<1, 1,
[SDTCisInt<0>, SDTCisInt<1>]>,
@@ -54,10 +61,10 @@ let usesCustomInserter = 1, Defs = [SP], Uses = [SP] in {
def : Pat<(i64 (AArch64AllocateZABuffer GPR64:$size)),
(AllocateZABuffer $size)>;
-def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 1,
- [SDTCisInt<0>]>, [SDNPHasChain, SDNPMayStore]>;
+def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain, SDNPMayStore]>;
let usesCustomInserter = 1 in {
- def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
+ def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer, GPR64:$save_slices), [(AArch64InitTPIDR2Obj GPR64:$buffer, GPR64:$save_slices)]>, Sched<[WriteI]> {}
}
// Nodes to allocate a save buffer for SME.
@@ -78,6 +85,30 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
(AllocateSMESaveBuffer $size)>;
//===----------------------------------------------------------------------===//
+// New SME ABI lowering ISD nodes/pseudos (-aarch64-new-sme-abi)
+//===----------------------------------------------------------------------===//
+
+let hasSideEffects = 1, isMeta = 1 in {
+ def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
+ def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
+}
+
+def CommitZASavePseudo
+ : Pseudo<(outs),
+ (ins GPR64:$tpidr2_el0, i1imm:$zero_za, i64imm:$commit_routine, variable_ops), []>,
+ Sched<[]>;
+
+def AArch64_inout_za_use
+ : SDNode<"AArch64ISD::INOUT_ZA_USE", SDTypeProfile<0, 0,[]>,
+ [SDNPHasChain, SDNPInGlue]>;
+def : Pat<(AArch64_inout_za_use), (InOutZAUsePseudo)>;
+
+def AArch64_requires_za_save
+ : SDNode<"AArch64ISD::REQUIRES_ZA_SAVE", SDTypeProfile<0, 0,[]>,
+ [SDNPHasChain, SDNPInGlue]>;
+def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
+
+//===----------------------------------------------------------------------===//
// Instruction naming conventions.
//===----------------------------------------------------------------------===//
@@ -325,16 +356,6 @@ def : Pat<(AArch64_smstart (i32 svcr_op:$pstate)),
def : Pat<(AArch64_smstop (i32 svcr_op:$pstate)),
(MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-
-// Pseudo to insert cfi_offset/cfi_restore instructions. Used to save or restore
-// the streaming value of VG around streaming-mode changes in locally-streaming
-// functions.
-def VGSavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
-def : Pat<(AArch64VGSave), (VGSavePseudo)>;
-
-def VGRestorePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
-def : Pat<(AArch64VGRestore), (VGRestorePseudo)>;
-
//===----------------------------------------------------------------------===//
// SME2 Instructions
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 0c4b4f4..bc65af2 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -650,7 +650,7 @@ let Predicates = [HasSVE_or_SME, UseExperimentalZeroingPseudos] in {
let Predicates = [HasSVE_or_SME] in {
defm ADD_ZI : sve_int_arith_imm0<0b000, "add", add>;
- defm SUB_ZI : sve_int_arith_imm0<0b001, "sub", sub>;
+ defm SUB_ZI : sve_int_arith_imm0<0b001, "sub", sub, add>;
defm SUBR_ZI : sve_int_arith_imm0<0b011, "subr", AArch64subr>;
defm SQADD_ZI : sve_int_arith_imm0_ssat<0b100, "sqadd", saddsat, ssubsat>;
defm UQADD_ZI : sve_int_arith_imm0<0b101, "uqadd", uaddsat>;
@@ -1021,7 +1021,9 @@ let Predicates = [HasNonStreamingSVE_or_SME2p2] in {
let Predicates = [HasSVE_or_SME] in {
defm INSR_ZR : sve_int_perm_insrs<"insr", AArch64insr>;
defm INSR_ZV : sve_int_perm_insrv<"insr", AArch64insr>;
- defm EXT_ZZI : sve_int_perm_extract_i<"ext", AArch64ext>;
+ defm EXT_ZZI : sve_int_perm_extract_i<"ext", AArch64ext, "EXT_ZZI_CONSTRUCTIVE">;
+
+ def EXT_ZZI_CONSTRUCTIVE : UnpredRegImmPseudo<ZPR8, imm0_255>;
defm RBIT_ZPmZ : sve_int_perm_rev_rbit<"rbit", AArch64rbit_mt>;
defm REVB_ZPmZ : sve_int_perm_rev_revb<"revb", AArch64revb_mt>;
@@ -2131,21 +2133,37 @@ let Predicates = [HasSVE_or_SME] in {
(LASTB_VPZ_D (PTRUE_D 31), ZPR:$Z1), dsub))>;
// Splice with lane bigger or equal to 0
- foreach VT = [nxv16i8] in
+ foreach VT = [nxv16i8] in {
def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_255 i32:$index)))),
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
+ let AddedComplexity = 1 in
+ def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_255 i32:$index)))),
+ (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>;
+ }
- foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in
+ foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in {
def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_127 i32:$index)))),
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
+ let AddedComplexity = 1 in
+ def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_127 i32:$index)))),
+ (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>;
+ }
- foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in
+ foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in {
def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_63 i32:$index)))),
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
+ let AddedComplexity = 1 in
+ def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_63 i32:$index)))),
+ (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>;
+ }
- foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in
+ foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in {
def : Pat<(VT (vector_splice VT:$Z1, VT:$Z2, (i64 (sve_ext_imm_0_31 i32:$index)))),
(EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>;
+ let AddedComplexity = 1 in
+ def : Pat<(VT (vector_splice VT:$Z1, VT:$Z1, (i64 (sve_ext_imm_0_31 i32:$index)))),
+ (EXT_ZZI_CONSTRUCTIVE ZPR:$Z1, imm0_255:$index)>;
+ }
defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>;
defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>;
@@ -4390,7 +4408,7 @@ def : InstAlias<"pfalse\t$Pd", (PFALSE PPRorPNR8:$Pd), 0>;
// Non-widening BFloat16 to BFloat16 instructions
//===----------------------------------------------------------------------===//
-let Predicates = [HasSVEB16B16] in {
+let Predicates = [HasSVEB16B16, HasNonStreamingSVE_or_SME2] in {
defm BFADD_ZZZ : sve_fp_3op_u_zd_bfloat<0b000, "bfadd", AArch64fadd>;
defm BFSUB_ZZZ : sve_fp_3op_u_zd_bfloat<0b001, "bfsub", AArch64fsub>;
defm BFMUL_ZZZ : sve_fp_3op_u_zd_bfloat<0b010, "bfmul", AArch64fmul>;
@@ -4423,9 +4441,9 @@ defm BFMLS_ZZZI : sve_fp_fma_by_indexed_elem_bfloat<"bfmls", 0b11, AArch64fmlsid
defm BFMUL_ZZZI : sve_fp_fmul_by_indexed_elem_bfloat<"bfmul", AArch64fmulidx>;
defm BFCLAMP_ZZZ : sve_fp_clamp_bfloat<"bfclamp", AArch64fclamp>;
-} // End HasSVEB16B16
+} // End HasSVEB16B16, HasNonStreamingSVE_or_SME2
-let Predicates = [HasSVEB16B16, UseExperimentalZeroingPseudos] in {
+let Predicates = [HasSVEB16B16, HasNonStreamingSVE_or_SME2, UseExperimentalZeroingPseudos] in {
defm BFADD_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fadd>;
defm BFSUB_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fsub>;
defm BFMUL_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmul>;
@@ -4433,7 +4451,7 @@ defm BFMAXNM_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmaxnm>;
defm BFMINNM_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fminnm>;
defm BFMIN_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmin>;
defm BFMAX_ZPZZ : sve_fp_2op_p_zds_zeroing_bfloat<int_aarch64_sve_fmax>;
-} // HasSVEB16B16, UseExperimentalZeroingPseudos
+} // HasSVEB16B16, HasNonStreamingSVE_or_SME2, UseExperimentalZeroingPseudos
let Predicates = [HasSVEBFSCALE] in {
def BFSCALE_ZPZZ : sve_fp_2op_p_zds_bfscale<0b1001, "bfscale", DestructiveBinary>;
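The AddedComplexity patterns introduced earlier in this file's changes handle the case where both vector_splice operands are the same register: the operation is then a pure rotate by idx lanes, so it can be selected to the new constructive EXT pseudo instead of the tied, destructive EXT_ZZI form. A scalar model of that special case (illustrative only):

#include <array>
#include <cstddef>
#include <cstdint>

// vector_splice(v, v, idx): lanes idx..N-1 of v followed by lanes 0..idx-1,
// i.e. a rotate of v by idx lanes.
template <std::size_t N>
std::array<uint8_t, N> splice_same(const std::array<uint8_t, N> &v,
                                   std::size_t idx) {
  std::array<uint8_t, N> out{};
  for (std::size_t i = 0; i < N; ++i)
    out[i] = v[(idx + i) % N];
  return out;
}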
diff --git a/llvm/lib/Target/AArch64/AArch64SchedA320.td b/llvm/lib/Target/AArch64/AArch64SchedA320.td
index 89ed1338..5ec95c7 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedA320.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedA320.td
@@ -847,7 +847,7 @@ def : InstRW<[CortexA320Write<3, CortexA320UnitVALU>], (instregex "^[SU]XTB_ZPmZ
"^[SU]XTW_ZPmZ_[D]")>;
// Extract
-def : InstRW<[CortexA320Write<3, CortexA320UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_B)>;
+def : InstRW<[CortexA320Write<3, CortexA320UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>;
// Extract narrow saturating
def : InstRW<[CortexA320Write<4, CortexA320UnitVALU>], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]",
diff --git a/llvm/lib/Target/AArch64/AArch64SchedA510.td b/llvm/lib/Target/AArch64/AArch64SchedA510.td
index 9456878..356e3fa 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedA510.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedA510.td
@@ -825,7 +825,7 @@ def : InstRW<[CortexA510Write<3, CortexA510UnitVALU>], (instregex "^[SU]XTB_ZPmZ
"^[SU]XTW_ZPmZ_[D]")>;
// Extract
-def : InstRW<[CortexA510Write<3, CortexA510UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_B)>;
+def : InstRW<[CortexA510Write<3, CortexA510UnitVALU>], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>;
// Extract narrow saturating
def : InstRW<[CortexA510Write<4, CortexA510UnitVALU>], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]",
@@ -1016,7 +1016,7 @@ def : InstRW<[CortexA510MCWrite<16, 13, CortexA510UnitVALU>], (instrs FADDA_VPZ_
def : InstRW<[CortexA510MCWrite<8, 5, CortexA510UnitVALU>], (instrs FADDA_VPZ_D)>;
// Floating point compare
-def : InstRW<[CortexA510Write<4, CortexA510UnitVALU>], (instregex "^FACG[ET]_PPzZZ_[HSD]",
+def : InstRW<[CortexA510MCWrite<4, 2, CortexA510UnitVALU>], (instregex "^FACG[ET]_PPzZZ_[HSD]",
"^FCM(EQ|GE|GT|NE)_PPzZ[0Z]_[HSD]",
"^FCM(LE|LT)_PPzZ0_[HSD]",
"^FCMUO_PPzZZ_[HSD]")>;
diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td
index 91a7079..e798222 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN2.td
@@ -1785,7 +1785,7 @@ def : InstRW<[N2Write_2c_1V1], (instregex "^[SU]XTB_ZPmZ_[HSD]",
"^[SU]XTW_ZPmZ_[D]")>;
// Extract
-def : InstRW<[N2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_B)>;
+def : InstRW<[N2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>;
// Extract narrow saturating
def : InstRW<[N2Write_4c_1V1], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]$",
diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td
index ecfb124..e44d40f 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseN3.td
@@ -1757,7 +1757,7 @@ def : InstRW<[N3Write_2c_1V], (instregex "^[SU]XTB_ZPmZ_[HSD]",
"^[SU]XTW_ZPmZ_[D]")>;
// Extract
-def : InstRW<[N3Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_B)>;
+def : InstRW<[N3Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>;
// Extract narrow saturating
def : InstRW<[N3Write_4c_1V1], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]$",
diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td
index 3686654..44625a2 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV1.td
@@ -1575,7 +1575,7 @@ def : InstRW<[V1Write_2c_1V1], (instregex "^[SU]XTB_ZPmZ_[HSD]",
"^[SU]XTW_ZPmZ_[D]")>;
// Extract
-def : InstRW<[V1Write_2c_1V01], (instrs EXT_ZZI)>;
+def : InstRW<[V1Write_2c_1V01], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE)>;
// Extract/insert operation, SIMD and FP scalar form
def : InstRW<[V1Write_3c_1V1], (instregex "^LAST[AB]_VPZ_[BHSD]$",
diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td
index b2c3da0..6261220 100644
--- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td
+++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td
@@ -2272,7 +2272,7 @@ def : InstRW<[V2Write_2c_1V13], (instregex "^[SU]XTB_ZPmZ_[HSD]",
"^[SU]XTW_ZPmZ_[D]")>;
// Extract
-def : InstRW<[V2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_B)>;
+def : InstRW<[V2Write_2c_1V], (instrs EXT_ZZI, EXT_ZZI_CONSTRUCTIVE, EXT_ZZI_B)>;
// Extract narrow saturating
def : InstRW<[V2Write_4c_1V13], (instregex "^[SU]QXTN[BT]_ZZ_[BHS]",
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 8a5b5ba..d3b1aa6 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -182,37 +182,25 @@ SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
const AArch64TargetLowering *TLI = STI.getTargetLowering();
- TargetLowering::ArgListEntry DstEntry;
- DstEntry.Ty = PointerType::getUnqual(*DAG.getContext());
- DstEntry.Node = Dst;
TargetLowering::ArgListTy Args;
- Args.push_back(DstEntry);
+ Args.emplace_back(Dst, PointerType::getUnqual(*DAG.getContext()));
RTLIB::Libcall NewLC;
switch (LC) {
case RTLIB::MEMCPY: {
NewLC = RTLIB::SC_MEMCPY;
- TargetLowering::ArgListEntry Entry;
- Entry.Ty = PointerType::getUnqual(*DAG.getContext());
- Entry.Node = Src;
- Args.push_back(Entry);
+ Args.emplace_back(Src, PointerType::getUnqual(*DAG.getContext()));
break;
}
case RTLIB::MEMMOVE: {
NewLC = RTLIB::SC_MEMMOVE;
- TargetLowering::ArgListEntry Entry;
- Entry.Ty = PointerType::getUnqual(*DAG.getContext());
- Entry.Node = Src;
- Args.push_back(Entry);
+ Args.emplace_back(Src, PointerType::getUnqual(*DAG.getContext()));
break;
}
case RTLIB::MEMSET: {
NewLC = RTLIB::SC_MEMSET;
- TargetLowering::ArgListEntry Entry;
- Entry.Ty = Type::getInt32Ty(*DAG.getContext());
- Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
- Entry.Node = Src;
- Args.push_back(Entry);
+ Args.emplace_back(DAG.getZExtOrTrunc(Src, DL, MVT::i32),
+ Type::getInt32Ty(*DAG.getContext()));
break;
}
default:
@@ -221,10 +209,7 @@ SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout());
SDValue Symbol = DAG.getExternalSymbol(TLI->getLibcallName(NewLC), PointerVT);
- TargetLowering::ArgListEntry SizeEntry;
- SizeEntry.Node = Size;
- SizeEntry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
- Args.push_back(SizeEntry);
+ Args.emplace_back(Size, DAG.getDataLayout().getIntPtrType(*DAG.getContext()));
TargetLowering::CallLoweringInfo CLI(DAG);
PointerType *RetTy = PointerType::getUnqual(*DAG.getContext());
diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
index f136a184..a67bd42 100644
--- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
+++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
@@ -585,8 +585,7 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
ClMaxLifetimes);
if (StandardLifetime) {
IntrinsicInst *Start = Info.LifetimeStart[0];
- uint64_t Size =
- cast<ConstantInt>(Start->getArgOperand(0))->getZExtValue();
+ uint64_t Size = *Info.AI->getAllocationSize(*DL);
Size = alignTo(Size, kTagGranuleSize);
tagAlloca(AI, Start->getNextNode(), TagPCall, Size);
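The stack-tagging change above derives the tagged size from the alloca itself rather than from the lifetime.start argument. A minimal sketch of that computation (assuming the usual 16-byte MTE tag granule; this helper is not part of the pass):

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>
#include <optional>

static uint64_t taggedAllocaSize(const llvm::AllocaInst &AI,
                                 const llvm::DataLayout &DL,
                                 uint64_t GranuleSize = 16) {
  // getAllocationSize is std::nullopt for variable-length allocas and scalable
  // for scalable-vector allocas; stack tagging only handles fixed sizes here.
  std::optional<llvm::TypeSize> Size = AI.getAllocationSize(DL);
  assert(Size && !Size->isScalable() && "expected a statically sized alloca");
  return llvm::alignTo(Size->getFixedValue(), GranuleSize);
}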
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 061ed61..671df35 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -212,6 +212,13 @@ public:
return hasSVE() || isStreamingSVEAvailable();
}
+ /// Returns true if the target has access to either the full range of SVE
+ /// instructions, or the streaming-compatible subset of SVE instructions
+ /// available to SME2.
+ bool isNonStreamingSVEorSME2Available() const {
+ return isSVEAvailable() || (isSVEorStreamingSVEAvailable() && hasSME2());
+ }
+
unsigned getMinVectorRegisterBitWidth() const {
// Don't assume any minimum vector size when PSTATE.SM may not be 0, because
// we don't yet support streaming-compatible codegen support that we trust
@@ -239,8 +246,8 @@ public:
/// Return true if the CPU supports any kind of instruction fusion.
bool hasFusion() const {
return hasArithmeticBccFusion() || hasArithmeticCbzFusion() ||
- hasFuseAES() || hasFuseArithmeticLogic() || hasFuseCCSelect() ||
- hasFuseAdrpAdd() || hasFuseLiterals();
+ hasFuseAES() || hasFuseArithmeticLogic() || hasFuseCmpCSel() ||
+ hasFuseCmpCSet() || hasFuseAdrpAdd() || hasFuseLiterals();
}
unsigned getEpilogueVectorizationMinVF() const {
@@ -451,12 +458,6 @@ public:
return "__chkstk";
}
- const char* getSecurityCheckCookieName() const {
- if (isWindowsArm64EC())
- return "#__security_check_cookie_arm64ec";
- return "__security_check_cookie";
- }
-
/// Choose a method of checking LR before performing a tail call.
AArch64PAuth::AuthCheckMethod
getAuthenticatedLRCheckMethod(const MachineFunction &MF) const;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 95eab16..e67bd58 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -224,6 +224,11 @@ static cl::opt<bool>
cl::desc("Enable Machine Pipeliner for AArch64"),
cl::init(false), cl::Hidden);
+static cl::opt<bool>
+ EnableNewSMEABILowering("aarch64-new-sme-abi",
+ cl::desc("Enable new lowering for the SME ABI"),
+ cl::init(false), cl::Hidden);
+
extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
LLVMInitializeAArch64Target() {
// Register the target.
@@ -263,6 +268,7 @@ LLVMInitializeAArch64Target() {
initializeLDTLSCleanupPass(PR);
initializeKCFIPass(PR);
initializeSMEABIPass(PR);
+ initializeMachineSMEABIPass(PR);
initializeSMEPeepholeOptPass(PR);
initializeSVEIntrinsicOptsPass(PR);
initializeAArch64SpeculationHardeningPass(PR);
@@ -367,7 +373,8 @@ AArch64TargetMachine::AArch64TargetMachine(const Target &T, const Triple &TT,
computeDefaultCPU(TT, CPU), FS, Options,
getEffectiveRelocModel(TT, RM),
getEffectiveAArch64CodeModel(TT, CM, JIT), OL),
- TLOF(createTLOF(getTargetTriple())), isLittle(LittleEndian) {
+ TLOF(createTLOF(getTargetTriple())), isLittle(LittleEndian),
+ UseNewSMEABILowering(EnableNewSMEABILowering) {
initAsmInfo();
if (TT.isOSBinFormatMachO()) {
@@ -668,10 +675,12 @@ void AArch64PassConfig::addIRPasses() {
addPass(createInterleavedAccessPass());
}
- // Expand any functions marked with SME attributes which require special
- // changes for the calling convention or that require the lazy-saving
- // mechanism specified in the SME ABI.
- addPass(createSMEABIPass());
+ if (!EnableNewSMEABILowering) {
+ // Expand any functions marked with SME attributes which require special
+ // changes for the calling convention or that require the lazy-saving
+ // mechanism specified in the SME ABI.
+ addPass(createSMEABIPass());
+ }
// Add Control Flow Guard checks.
if (TM->getTargetTriple().isOSWindows()) {
@@ -782,6 +791,9 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
}
void AArch64PassConfig::addMachineSSAOptimization() {
+ if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None)
+ addPass(createMachineSMEABIPass());
+
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
addPass(createSMEPeepholeOptPass());
@@ -812,6 +824,9 @@ bool AArch64PassConfig::addILPOpts() {
}
void AArch64PassConfig::addPreRegAlloc() {
+ if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering)
+ addPass(createMachineSMEABIPass());
+
// Change dead register definitions to refer to the zero register.
if (TM->getOptLevel() != CodeGenOptLevel::None &&
EnableDeadRegisterElimination)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.h b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
index b9e522d..0dd5d95 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
@@ -79,8 +79,12 @@ public:
size_t clearLinkerOptimizationHints(
const SmallPtrSetImpl<MachineInstr *> &MIs) const override;
+ /// Returns true if the new SME ABI lowering should be used.
+ bool useNewSMEABILowering() const { return UseNewSMEABILowering; }
+
private:
bool isLittle;
+ bool UseNewSMEABILowering;
};
// AArch64 little endian target machine.
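The new lowering is gated on the -aarch64-new-sme-abi flag added in AArch64TargetMachine.cpp and surfaced through useNewSMEABILowering(). A hypothetical usage sketch (the helper name is made up) of how target code can query it from a machine pass:

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/MachineFunction.h"

// The cast assumes the pass only ever runs on AArch64 machine functions.
static bool usesNewSMEABI(const llvm::MachineFunction &MF) {
  const auto &TM =
      static_cast<const llvm::AArch64TargetMachine &>(MF.getTarget());
  return TM.useNewSMEABILowering();
}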
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e1adc0b..922da10 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -220,20 +220,17 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
"enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
-static bool isSMEABIRoutineCall(const CallInst &CI) {
+static bool isSMEABIRoutineCall(const CallInst &CI,
+ const AArch64TargetLowering &TLI) {
const auto *F = CI.getCalledFunction();
- return F && StringSwitch<bool>(F->getName())
- .Case("__arm_sme_state", true)
- .Case("__arm_tpidr2_save", true)
- .Case("__arm_tpidr2_restore", true)
- .Case("__arm_za_disable", true)
- .Default(false);
+ return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
}
/// Returns true if the function has explicit operations that can only be
/// lowered using incompatible instructions for the selected mode. This also
/// returns true if the function F may use or modify ZA state.
-static bool hasPossibleIncompatibleOps(const Function *F) {
+static bool hasPossibleIncompatibleOps(const Function *F,
+ const AArch64TargetLowering &TLI) {
for (const BasicBlock &BB : *F) {
for (const Instruction &I : BB) {
// Be conservative for now and assume that any call to inline asm or to
@@ -242,7 +239,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
// all native LLVM instructions can be lowered to compatible instructions.
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
- isSMEABIRoutineCall(cast<CallInst>(I))))
+ isSMEABIRoutineCall(cast<CallInst>(I), TLI)))
return true;
}
}
@@ -290,7 +287,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState()) {
- if (hasPossibleIncompatibleOps(Callee))
+ if (hasPossibleIncompatibleOps(Callee, *getTLI()))
return false;
}
@@ -357,7 +354,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
// change only once and avoid inlining of G into F.
SMEAttrs FAttrs(*F);
- SMECallAttrs CallAttrs(Call);
+ SMECallAttrs CallAttrs(Call, getTLI());
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
@@ -554,7 +551,17 @@ static bool isUnpackedVectorVT(EVT VecVT) {
VecVT.getSizeInBits().getKnownMinValue() < AArch64::SVEBitsPerBlock;
}
-static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) {
+static InstructionCost getHistogramCost(const AArch64Subtarget *ST,
+ const IntrinsicCostAttributes &ICA) {
+ // We need to know at least the number of elements in the vector of buckets
+ // and the size of each element to update.
+ if (ICA.getArgTypes().size() < 2)
+ return InstructionCost::getInvalid();
+
+ // We are only interested in costing the SVE2 hardware instruction.
+ if (!ST->hasSVE2())
+ return InstructionCost::getInvalid();
+
Type *BucketPtrsTy = ICA.getArgTypes()[0]; // Type of vector of pointers
Type *EltTy = ICA.getArgTypes()[1]; // Type of bucket elements
unsigned TotalHistCnts = 1;
@@ -579,9 +586,11 @@ static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) {
unsigned NaturalVectorWidth = AArch64::SVEBitsPerBlock / LegalEltSize;
TotalHistCnts = EC / NaturalVectorWidth;
+
+ return InstructionCost(BaseHistCntCost * TotalHistCnts);
}
- return InstructionCost(BaseHistCntCost * TotalHistCnts);
+ return InstructionCost::getInvalid();
}
InstructionCost
@@ -597,10 +606,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
return InstructionCost::getInvalid();
switch (ICA.getID()) {
- case Intrinsic::experimental_vector_histogram_add:
- if (!ST->hasSVE2())
- return InstructionCost::getInvalid();
- return getHistogramCost(ICA);
+ case Intrinsic::experimental_vector_histogram_add: {
+ InstructionCost HistCost = getHistogramCost(ST, ICA);
+ // If the cost isn't valid, we may still be able to scalarize
+ if (HistCost.isValid())
+ return HistCost;
+ break;
+ }
case Intrinsic::umin:
case Intrinsic::umax:
case Intrinsic::smin:
@@ -631,6 +643,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
return LT.first * Instrs;
+
+ TypeSize TS = getDataLayout().getTypeSizeInBits(RetTy);
+ uint64_t VectorSize = TS.getKnownMinValue();
+
+ if (ST->isSVEAvailable() && VectorSize >= 128 && isPowerOf2_64(VectorSize))
+ return LT.first * Instrs;
+
break;
}
case Intrinsic::abs: {
@@ -651,6 +670,16 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
return LT.first;
break;
}
+ case Intrinsic::fma:
+ case Intrinsic::fmuladd: {
+ // Given an fma or fmuladd, cost it the same as an fmul instruction, since
+ // their costs are usually the same. TODO: Add fp16 and bf16 expansion costs.
+ Type *EltTy = RetTy->getScalarType();
+ if (EltTy->isFloatTy() || EltTy->isDoubleTy() ||
+ (EltTy->isHalfTy() && ST->hasFullFP16()))
+ return getArithmeticInstrCost(Instruction::FMul, RetTy, CostKind);
+ break;
+ }
case Intrinsic::stepvector: {
InstructionCost Cost = 1; // Cost of the `index' instruction
auto LT = getTypeLegalizationCost(RetTy);
@@ -2072,6 +2101,20 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
: std::nullopt;
}
+static std::optional<Instruction *>
+instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts,
+ const AArch64Subtarget *ST) {
+ if (!ST->isStreaming())
+ return std::nullopt;
+
+ // In streaming mode, the aarch64_sme_cnts* intrinsics are equivalent to the
+ // aarch64_sve_cnt* intrinsics with SVEPredPattern::all.
+ Value *Cnt = IC.Builder.CreateElementCount(
+ II.getType(), ElementCount::getScalable(NumElts));
+ Cnt->takeName(&II);
+ return IC.replaceInstUsesWith(II, Cnt);
+}
+
static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
IntrinsicInst &II) {
Value *PgVal = II.getArgOperand(0);
@@ -2781,6 +2824,14 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
return instCombineSVECntElts(IC, II, 8);
case Intrinsic::aarch64_sve_cntb:
return instCombineSVECntElts(IC, II, 16);
+ case Intrinsic::aarch64_sme_cntsd:
+ return instCombineSMECntsElts(IC, II, 2, ST);
+ case Intrinsic::aarch64_sme_cntsw:
+ return instCombineSMECntsElts(IC, II, 4, ST);
+ case Intrinsic::aarch64_sme_cntsh:
+ return instCombineSMECntsElts(IC, II, 8, ST);
+ case Intrinsic::aarch64_sme_cntsb:
+ return instCombineSMECntsElts(IC, II, 16, ST);
case Intrinsic::aarch64_sve_ptest_any:
case Intrinsic::aarch64_sve_ptest_first:
case Intrinsic::aarch64_sve_ptest_last:
@@ -3092,6 +3143,13 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
return AdjustCost(
BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
+ // For the moment we do not have lowering for SVE1-only fptrunc f64->bf16 as
+ // we use fcvtx under SVE2. Give them invalid costs.
+ if (!ST->hasSVE2() && !ST->isStreamingSVEAvailable() &&
+ ISD == ISD::FP_ROUND && SrcTy.isScalableVector() &&
+ DstTy.getScalarType() == MVT::bf16 && SrcTy.getScalarType() == MVT::f64)
+ return InstructionCost::getInvalid();
+
static const TypeConversionCostTblEntry BF16Tbl[] = {
{ISD::FP_ROUND, MVT::bf16, MVT::f32, 1}, // bfcvt
{ISD::FP_ROUND, MVT::bf16, MVT::f64, 1}, // bfcvt
@@ -3100,6 +3158,12 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
{ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2}, // bfcvtn+fcvtn
{ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3}, // fcvtn+fcvtl2+bfcvtn
{ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+bfcvtn
+ {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f32, 1}, // bfcvt
+ {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f32, 1}, // bfcvt
+ {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f32, 3}, // bfcvt+bfcvt+uzp1
+ {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f64, 2}, // fcvtx+bfcvt
+ {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f64, 5}, // 2*fcvtx+2*bfcvt+uzp1
+ {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f64, 11}, // 4*fcvt+4*bfcvt+3*uzp
};
if (ST->hasBF16())
@@ -3508,11 +3572,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
{ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1},
{ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3},
+ // Truncate from nxvmf32 to nxvmbf16.
+ {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f32, 8},
+ {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f32, 8},
+ {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f32, 17},
+
// Truncate from nxvmf64 to nxvmf16.
{ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1},
{ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3},
{ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7},
+ // Truncate from nxvmf64 to nxvmbf16.
+ {ISD::FP_ROUND, MVT::nxv2bf16, MVT::nxv2f64, 9},
+ {ISD::FP_ROUND, MVT::nxv4bf16, MVT::nxv4f64, 19},
+ {ISD::FP_ROUND, MVT::nxv8bf16, MVT::nxv8f64, 39},
+
// Truncate from nxvmf64 to nxvmf32.
{ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1},
{ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3},
@@ -3523,11 +3597,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
{ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
{ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
+ // Extend from nxvmbf16 to nxvmf32.
+ {ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2bf16, 1}, // lsl
+ {ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4bf16, 1}, // lsl
+ {ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8bf16, 4}, // unpck+unpck+lsl+lsl
+
// Extend from nxvmf16 to nxvmf64.
{ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
{ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
{ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
+ // Extend from nxvmbf16 to nxvmf64.
+ {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2bf16, 2}, // lsl+fcvt
+ {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4bf16, 6}, // 2*unpck+2*lsl+2*fcvt
+ {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8bf16, 14}, // 6*unpck+4*lsl+4*fcvt
+
// Extend from nxvmf32 to nxvmf64.
{ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
{ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
@@ -3928,6 +4012,24 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
return getVectorInstrCostHelper(I.getOpcode(), Val, CostKind, Index, &I);
}
+InstructionCost
+AArch64TTIImpl::getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val,
+ TTI::TargetCostKind CostKind,
+ unsigned Index) const {
+ if (isa<FixedVectorType>(Val))
+ return BaseT::getIndexedVectorInstrCostFromEnd(Opcode, Val, CostKind,
+ Index);
+
+ // This typically requires both while and lastb instructions in order
+ // to extract the last element. If this is in a loop the while
+ // instruction can at least be hoisted out, although it will consume a
+ // predicate register. The cost should be more expensive than the base
+ // extract cost, which is 2 for most CPUs.
+ return CostKind == TTI::TCK_CodeSize
+ ? 2
+ : ST->getVectorInsertExtractBaseCost() + 1;
+}
+
InstructionCost AArch64TTIImpl::getScalarizationOverhead(
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
TTI::TargetCostKind CostKind, bool ForPoisonSrc,
@@ -3942,6 +4044,27 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
return DemandedElts.popcount() * (Insert + Extract) * VecInstCost;
}
+std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost(
+ Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
+ TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
+ std::function<InstructionCost(Type *)> InstCost) const {
+ if (!Ty->getScalarType()->isHalfTy() && !Ty->getScalarType()->isBFloatTy())
+ return std::nullopt;
+ if (Ty->getScalarType()->isHalfTy() && ST->hasFullFP16())
+ return std::nullopt;
+
+ Type *PromotedTy = Ty->getWithNewType(Type::getFloatTy(Ty->getContext()));
+ InstructionCost Cost = getCastInstrCost(Instruction::FPExt, PromotedTy, Ty,
+ TTI::CastContextHint::None, CostKind);
+ if (!Op1Info.isConstant() && !Op2Info.isConstant())
+ Cost *= 2;
+ Cost += InstCost(PromotedTy);
+ if (IncludeTrunc)
+ Cost += getCastInstrCost(Instruction::FPTrunc, Ty, PromotedTy,
+ TTI::CastContextHint::None, CostKind);
+ return Cost;
+}
+
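A minimal standalone sketch of the cost arithmetic that getFP16BF16PromoteCost models: one fpext per non-constant operand, the operation performed on the promoted f32 type, and optionally the fptrunc back. The flat cost values and helper names here are assumptions for illustration, not LLVM's real cost tables:

#include <cstdio>
#include <functional>
#include <optional>

struct FakeCosts {
  int FPExt = 1;   // assumed cost of fpext f16/bf16 -> f32
  int FPTrunc = 1; // assumed cost of fptrunc f32 -> f16/bf16
};

std::optional<int> promoteCost(bool IsHalfOrBF16, bool HasNativeSupport,
                               bool Op1Const, bool Op2Const, bool IncludeTrunc,
                               const FakeCosts &C,
                               const std::function<int()> &PromotedOpCost) {
  if (!IsHalfOrBF16 || HasNativeSupport)
    return std::nullopt;    // native operation, no promotion needed
  int Cost = C.FPExt;
  if (!Op1Const && !Op2Const)
    Cost *= 2;              // both operands need an fpext
  Cost += PromotedOpCost(); // the operation on the promoted f32 type
  if (IncludeTrunc)
    Cost += C.FPTrunc;      // truncate the f32 result back
  return Cost;
}

int main() {
  FakeCosts C;
  // e.g. a half-precision fadd without FullFP16: fpext+fpext+fadd+fptrunc.
  auto Cost = promoteCost(true, false, false, false, true, C, [] { return 1; });
  std::printf("promoted fadd cost: %d\n", Cost.value_or(-1)); // prints 4
  return 0;
}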
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
@@ -3964,6 +4087,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
int ISD = TLI->InstructionOpcodeToISD(Opcode);
+ // Increase the cost for half and bfloat types if not architecturally
+ // supported.
+ if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL ||
+ ISD == ISD::FDIV || ISD == ISD::FREM)
+ if (auto PromotedCost = getFP16BF16PromoteCost(
+ Ty, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/true,
+ [&](Type *PromotedTy) {
+ return getArithmeticInstrCost(Opcode, PromotedTy, CostKind,
+ Op1Info, Op2Info);
+ }))
+ return *PromotedCost;
+
switch (ISD) {
default:
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -4232,11 +4367,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
[[fallthrough]];
case ISD::FADD:
case ISD::FSUB:
- // Increase the cost for half and bfloat types if not architecturally
- // supported.
- if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
- (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
- return 2 * LT.first;
if (!Ty->getScalarType()->isFP128Ty())
return LT.first;
[[fallthrough]];
@@ -4260,8 +4390,9 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
}
InstructionCost
-AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
- const SCEV *Ptr) const {
+AArch64TTIImpl::getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE,
+ const SCEV *Ptr,
+ TTI::TargetCostKind CostKind) const {
// Address computations in vectorized code with non-consecutive addresses will
// likely result in more instructions compared to scalar code where the
// computation can more often be merged into the index mode. The resulting
@@ -4269,7 +4400,7 @@ AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
int MaxMergeDistance = 64;
- if (Ty->isVectorTy() && SE &&
+ if (PtrTy->isVectorTy() && SE &&
!BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
return NumVectorInstToHideOverhead;
@@ -4282,10 +4413,9 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
TTI::OperandValueInfo Op2Info, const Instruction *I) const {
- int ISD = TLI->InstructionOpcodeToISD(Opcode);
// We don't lower some vector selects well that are wider than the register
// width. TODO: Improve this with different cost kinds.
- if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
+ if (isa<FixedVectorType>(ValTy) && Opcode == Instruction::Select) {
// We would need this many instructions to hide the scalarization happening.
const int AmortizationCost = 20;
@@ -4315,55 +4445,68 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
return LT.first;
}
- static const TypeConversionCostTblEntry
- VectorSelectTbl[] = {
- { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
- { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
- { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
- { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
- { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
- { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
- { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
- { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
- { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
- { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
- { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
- };
+ static const TypeConversionCostTblEntry VectorSelectTbl[] = {
+ {Instruction::Select, MVT::v2i1, MVT::v2f32, 2},
+ {Instruction::Select, MVT::v2i1, MVT::v2f64, 2},
+ {Instruction::Select, MVT::v4i1, MVT::v4f32, 2},
+ {Instruction::Select, MVT::v4i1, MVT::v4f16, 2},
+ {Instruction::Select, MVT::v8i1, MVT::v8f16, 2},
+ {Instruction::Select, MVT::v16i1, MVT::v16i16, 16},
+ {Instruction::Select, MVT::v8i1, MVT::v8i32, 8},
+ {Instruction::Select, MVT::v16i1, MVT::v16i32, 16},
+ {Instruction::Select, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost},
+ {Instruction::Select, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost},
+ {Instruction::Select, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost}};
EVT SelCondTy = TLI->getValueType(DL, CondTy);
EVT SelValTy = TLI->getValueType(DL, ValTy);
if (SelCondTy.isSimple() && SelValTy.isSimple()) {
- if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
+ if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, Opcode,
SelCondTy.getSimpleVT(),
SelValTy.getSimpleVT()))
return Entry->Cost;
}
}
- if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
- Type *ValScalarTy = ValTy->getScalarType();
- if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) ||
- ValScalarTy->isBFloatTy()) {
- auto *ValVTy = cast<FixedVectorType>(ValTy);
-
- // Without dedicated instructions we promote [b]f16 compares to f32.
- auto *PromotedTy =
- VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy);
-
- InstructionCost Cost = 0;
- // Promote operands to float vectors.
- Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
- TTI::CastContextHint::None, CostKind);
- // Compare float vectors.
- Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
- Op1Info, Op2Info);
- // During codegen we'll truncate the vector result from i32 to i16.
- Cost +=
- getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy),
- VectorType::getInteger(PromotedTy),
- TTI::CastContextHint::None, CostKind);
- return Cost;
- }
+ if (Opcode == Instruction::FCmp) {
+ if (auto PromotedCost = getFP16BF16PromoteCost(
+ ValTy, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/false,
+ [&](Type *PromotedTy) {
+ InstructionCost Cost =
+ getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred,
+ CostKind, Op1Info, Op2Info);
+ if (isa<VectorType>(PromotedTy))
+ Cost += getCastInstrCost(
+ Instruction::Trunc,
+ VectorType::getInteger(cast<VectorType>(ValTy)),
+ VectorType::getInteger(cast<VectorType>(PromotedTy)),
+ TTI::CastContextHint::None, CostKind);
+ return Cost;
+ }))
+ return *PromotedCost;
+
+ auto LT = getTypeLegalizationCost(ValTy);
+ // Model unknown fp compares as a libcall.
+ if (LT.second.getScalarType() != MVT::f64 &&
+ LT.second.getScalarType() != MVT::f32 &&
+ LT.second.getScalarType() != MVT::f16)
+ return LT.first * getCallInstrCost(/*Function*/ nullptr, ValTy,
+ {ValTy, ValTy}, CostKind);
+
+ // Some comparison operators require expanding to multiple compares + or.
+ unsigned Factor = 1;
+ if (!CondTy->isVectorTy() &&
+ (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ))
+ Factor = 2; // fcmp with 2 selects
+ else if (isa<FixedVectorType>(ValTy) &&
+ (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ ||
+ VecPred == FCmpInst::FCMP_ORD || VecPred == FCmpInst::FCMP_UNO))
+ Factor = 3; // fcmxx+fcmyy+or
+ else if (isa<ScalableVectorType>(ValTy) &&
+ (VecPred == FCmpInst::FCMP_ONE || VecPred == FCmpInst::FCMP_UEQ))
+ Factor = 3; // fcmxx+fcmyy+or
+
+ return Factor * (CostKind == TTI::TCK_Latency ? 2 : LT.first);
}
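A standalone restatement of the fcmp expansion factors used above, with stand-in enums for the predicate and the vector kind; this mirrors the branches in the new code as a sketch, not a definitive lowering:

#include <cstdio>

enum class Pred { OEQ, ONE, UEQ, ORD, UNO, OLT };
enum class Kind { Scalar, FixedVector, ScalableVector };

unsigned fcmpFactor(Pred P, Kind K) {
  bool OneUeq = (P == Pred::ONE || P == Pred::UEQ);
  bool OrdUno = (P == Pred::ORD || P == Pred::UNO);
  if (K == Kind::Scalar && OneUeq)
    return 2; // fcmp plus two selects
  if (K == Kind::FixedVector && (OneUeq || OrdUno))
    return 3; // fcmXX + fcmYY + orr
  if (K == Kind::ScalableVector && OneUeq)
    return 3; // two predicated compares + orr
  return 1;   // a single compare
}

int main() {
  std::printf("%u %u %u\n", fcmpFactor(Pred::UEQ, Kind::Scalar),
              fcmpFactor(Pred::ORD, Kind::FixedVector),
              fcmpFactor(Pred::OLT, Kind::ScalableVector)); // 2 3 1
  return 0;
}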
// Treat the icmp in icmp(and, 0) or icmp(and, -1/1) when it can be folded to
@@ -4371,7 +4514,7 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
// comparison is not unsigned. FIXME: Enable for non-throughput cost kinds
// providing it will not cause performance regressions.
if (CostKind == TTI::TCK_RecipThroughput && ValTy->isIntegerTy() &&
- ISD == ISD::SETCC && I && !CmpInst::isUnsigned(VecPred) &&
+ Opcode == Instruction::ICmp && I && !CmpInst::isUnsigned(VecPred) &&
TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
match(I->getOperand(0), m_And(m_Value(), m_Value()))) {
if (match(I->getOperand(1), m_Zero()))
@@ -4809,32 +4952,18 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
// Limit to loops with trip counts that are cheap to expand.
UP.SCEVExpansionBudget = 1;
- // Try to unroll small, single block loops, if they have load/store
- // dependencies, to expose more parallel memory access streams.
+ // Try to unroll small loops, consisting of a few blocks within a low size
+ // budget, if they have load/store dependencies (to expose more parallel
+ // memory access streams) or if they do little work inside a block (i.e. a
+ // load -> X -> store pattern).
BasicBlock *Header = L->getHeader();
- if (Header == L->getLoopLatch()) {
+ BasicBlock *Latch = L->getLoopLatch();
+ if (Header == Latch) {
// Estimate the size of the loop.
unsigned Size;
- if (!isLoopSizeWithinBudget(L, TTI, 8, &Size))
+ unsigned Width = 10;
+ if (!isLoopSizeWithinBudget(L, TTI, Width, &Size))
return;
- SmallPtrSet<Value *, 8> LoadedValues;
- SmallVector<StoreInst *> Stores;
- for (auto *BB : L->blocks()) {
- for (auto &I : *BB) {
- Value *Ptr = getLoadStorePointerOperand(&I);
- if (!Ptr)
- continue;
- const SCEV *PtrSCEV = SE.getSCEV(Ptr);
- if (SE.isLoopInvariant(PtrSCEV, L))
- continue;
- if (isa<LoadInst>(&I))
- LoadedValues.insert(&I);
- else
- Stores.push_back(cast<StoreInst>(&I));
- }
- }
-
// Try to find an unroll count that maximizes the use of the instruction
// window, i.e. trying to fetch as many instructions per cycle as possible.
unsigned MaxInstsPerLine = 16;
@@ -4853,8 +4982,32 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
UC++;
}
- if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) {
- return LoadedValues.contains(SI->getOperand(0));
+ if (BestUC == 1)
+ return;
+
+ SmallPtrSet<Value *, 8> LoadedValuesPlus;
+ SmallVector<StoreInst *> Stores;
+ for (auto *BB : L->blocks()) {
+ for (auto &I : *BB) {
+ Value *Ptr = getLoadStorePointerOperand(&I);
+ if (!Ptr)
+ continue;
+ const SCEV *PtrSCEV = SE.getSCEV(Ptr);
+ if (SE.isLoopInvariant(PtrSCEV, L))
+ continue;
+ if (isa<LoadInst>(&I)) {
+ LoadedValuesPlus.insert(&I);
+ // Also include the first in-loop users of loaded values.
+ for (auto *U : I.users())
+ if (L->contains(cast<Instruction>(U)))
+ LoadedValuesPlus.insert(U);
+ } else
+ Stores.push_back(cast<StoreInst>(&I));
+ }
+ }
+
+ if (none_of(Stores, [&LoadedValuesPlus](StoreInst *SI) {
+ return LoadedValuesPlus.contains(SI->getOperand(0));
}))
return;
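A toy model of the load -> X -> store profitability check above, using integer value IDs in place of LLVM instructions; the loop contents are invented for the example:

#include <cstdio>
#include <set>
#include <vector>

struct Op { char Kind; int Def; std::vector<int> Uses; }; // 'L'oad, 'S'tore, other

int main() {
  // for (...) { a = load p[i]; b = a + 1; store b -> q[i]; }
  std::vector<Op> Loop = {{'L', 1, {}}, {'X', 2, {1}}, {'S', 0, {2}}};
  std::set<int> LoadedValuesPlus;
  for (const Op &I : Loop)
    if (I.Kind == 'L')
      LoadedValuesPlus.insert(I.Def);
  for (const Op &I : Loop)
    if (I.Kind != 'L' && I.Kind != 'S')
      for (int U : I.Uses)
        if (LoadedValuesPlus.count(U))
          LoadedValuesPlus.insert(I.Def); // first in-loop user of a load
  bool Profitable = false;
  for (const Op &I : Loop)
    if (I.Kind == 'S' && LoadedValuesPlus.count(I.Uses[0]))
      Profitable = true; // some store consumes a load or its first user
  std::printf("runtime unroll candidate: %s\n", Profitable ? "yes" : "no");
  return 0;
}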
@@ -4866,7 +5019,6 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
// Try to runtime-unroll loops with early-continues depending on loop-varying
// loads; this helps with branch-prediction for the early-continues.
auto *Term = dyn_cast<BranchInst>(Header->getTerminator());
- auto *Latch = L->getLoopLatch();
SmallVector<BasicBlock *> Preds(predecessors(Latch));
if (!Term || !Term->isConditional() || Preds.size() == 1 ||
!llvm::is_contained(Preds, Header) ||
@@ -5102,6 +5254,8 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
return false;
switch (RdxDesc.getRecurrenceKind()) {
+ case RecurKind::Sub:
+ case RecurKind::AddChainWithSubs:
case RecurKind::Add:
case RecurKind::FAdd:
case RecurKind::And:
@@ -5332,13 +5486,14 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
}
InstructionCost
-AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
- VectorType *VecTy,
+AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
+ Type *ResTy, VectorType *VecTy,
TTI::TargetCostKind CostKind) const {
EVT VecVT = TLI->getValueType(DL, VecTy);
EVT ResVT = TLI->getValueType(DL, ResTy);
- if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) {
+ if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple() &&
+ RedOpcode == Instruction::Add) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
// The legal cases with dotprod are
@@ -5349,7 +5504,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return LT.first + 2;
}
- return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
+ return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, VecTy,
+ CostKind);
}
InstructionCost
@@ -6235,10 +6391,17 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
}
}
- auto ShouldSinkCondition = [](Value *Cond) -> bool {
+ auto ShouldSinkCondition = [](Value *Cond,
+ SmallVectorImpl<Use *> &Ops) -> bool {
+ if (!isa<IntrinsicInst>(Cond))
+ return false;
auto *II = dyn_cast<IntrinsicInst>(Cond);
- return II && II->getIntrinsicID() == Intrinsic::vector_reduce_or &&
- isa<ScalableVectorType>(II->getOperand(0)->getType());
+ if (II->getIntrinsicID() != Intrinsic::vector_reduce_or ||
+ !isa<ScalableVectorType>(II->getOperand(0)->getType()))
+ return false;
+ if (isa<CmpInst>(II->getOperand(0)))
+ Ops.push_back(&II->getOperandUse(0));
+ return true;
};
switch (I->getOpcode()) {
@@ -6254,7 +6417,7 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
}
break;
case Instruction::Select: {
- if (!ShouldSinkCondition(I->getOperand(0)))
+ if (!ShouldSinkCondition(I->getOperand(0), Ops))
return false;
Ops.push_back(&I->getOperandUse(0));
@@ -6264,7 +6427,7 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
if (cast<BranchInst>(I)->isUnconditional())
return false;
- if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition()))
+ if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition(), Ops))
return false;
Ops.push_back(&I->getOperandUse(0));
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 7f45177..b994ca7 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -221,6 +221,11 @@ public:
unsigned Index) const override;
InstructionCost
+ getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val,
+ TTI::TargetCostKind CostKind,
+ unsigned Index) const override;
+
+ InstructionCost
getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF,
TTI::TargetCostKind CostKind) const override;
@@ -238,8 +243,9 @@ public:
ArrayRef<const Value *> Args = {},
const Instruction *CxtI = nullptr) const override;
- InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
- const SCEV *Ptr) const override;
+ InstructionCost
+ getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE, const SCEV *Ptr,
+ TTI::TargetCostKind CostKind) const override;
InstructionCost getCmpSelInstrCost(
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
@@ -435,6 +441,14 @@ public:
bool preferPredicatedReductionSelect() const override { return ST->hasSVE(); }
+ /// FP16 and BF16 operations are lowered to fptrunc(op(fpext, fpext)) if the
+ /// required architecture features are not present.
+ std::optional<InstructionCost>
+ getFP16BF16PromoteCost(Type *Ty, TTI::TargetCostKind CostKind,
+ TTI::OperandValueInfo Op1Info,
+ TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
+ std::function<InstructionCost(Type *)> InstCost) const;
+
InstructionCost
getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
std::optional<FastMathFlags> FMF,
@@ -446,7 +460,7 @@ public:
TTI::TargetCostKind CostKind) const override;
InstructionCost getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
InstructionCost
diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
index 1ca61f5..3641e22 100644
--- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
+++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
@@ -7909,9 +7909,11 @@ bool AArch64AsmParser::parseDirectiveSEHSavePReg(SMLoc L) {
}
bool AArch64AsmParser::parseDirectiveAeabiSubSectionHeader(SMLoc L) {
- // Expecting 3 AsmToken::Identifier after '.aeabi_subsection', a name and 2
- // parameters, e.g.: .aeabi_subsection (1)aeabi_feature_and_bits, (2)optional,
- // (3)uleb128 separated by 2 commas.
+ // Handle parsing of .aeabi_subsection directives:
+ // - On first declaration of a subsection, expect exactly three identifiers
+ // after `.aeabi_subsection`: the subsection name and two parameters.
+ // - When switching to an existing subsection, it is valid to provide only
+ // the subsection name, or the name together with the two parameters.
MCAsmParser &Parser = getParser();
// Consume the name (subsection name)
@@ -7925,16 +7927,38 @@ bool AArch64AsmParser::parseDirectiveAeabiSubSectionHeader(SMLoc L) {
return true;
}
Parser.Lex();
- // consume a comma
+
+ std::unique_ptr<MCELFStreamer::AttributeSubSection> SubsectionExists =
+ getTargetStreamer().getAttributesSubsectionByName(SubsectionName);
+ // Check whether only the subsection name was provided.
+ // If so, the user is trying to switch to a subsection that should have been
+ // declared before.
+ if (Parser.getTok().is(llvm::AsmToken::EndOfStatement)) {
+ if (SubsectionExists) {
+ getTargetStreamer().emitAttributesSubsection(
+ SubsectionName,
+ static_cast<AArch64BuildAttributes::SubsectionOptional>(
+ SubsectionExists->IsOptional),
+ static_cast<AArch64BuildAttributes::SubsectionType>(
+ SubsectionExists->ParameterType));
+ return false;
+ }
+ // If the subsection does not exist, report an error.
+ else {
+ Error(Parser.getTok().getLoc(),
+ "Could not switch to subsection '" + SubsectionName +
+ "' using subsection name, subsection has not been defined");
+ return true;
+ }
+ }
+
+ // Otherwise, expecting 2 more parameters: consume a comma
// parseComma() return *false* on success, and call Lex(), no need to call
// Lex() again.
if (Parser.parseComma()) {
return true;
}
- std::unique_ptr<MCELFStreamer::AttributeSubSection> SubsectionExists =
- getTargetStreamer().getAttributesSubsectionByName(SubsectionName);
-
// Consume the first parameter (optionality parameter)
AArch64BuildAttributes::SubsectionOptional IsOptional;
// options: optional/required
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 66136a4..803943f 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -89,6 +89,7 @@ add_llvm_target(AArch64CodeGen
SMEABIPass.cpp
SMEPeepholeOpt.cpp
SVEIntrinsicOpts.cpp
+ MachineSMEABIPass.cpp
AArch64SIMDInstrOpt.cpp
DEPENDS
diff --git a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
index ae984be..23e46b8 100644
--- a/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
+++ b/llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
@@ -15,6 +15,7 @@
#include "MCTargetDesc/AArch64MCTargetDesc.h"
#include "TargetInfo/AArch64TargetInfo.h"
#include "Utils/AArch64BaseInfo.h"
+#include "llvm/MC/MCDecoder.h"
#include "llvm/MC/MCDecoderOps.h"
#include "llvm/MC/MCDisassembler/MCRelocationInfo.h"
#include "llvm/MC/MCInst.h"
@@ -27,314 +28,21 @@
#include <memory>
using namespace llvm;
+using namespace llvm::MCD;
#define DEBUG_TYPE "aarch64-disassembler"
// Pull DecodeStatus and its enum values into the global namespace.
using DecodeStatus = MCDisassembler::DecodeStatus;
-// Forward declare these because the autogenerated code will reference them.
-// Definitions are further down.
-template <unsigned RegClassID, unsigned FirstReg, unsigned NumRegsInClass>
-static DecodeStatus DecodeSimpleRegisterClass(MCInst &Inst, unsigned RegNo,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeGPR64x8ClassRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address,
- const MCDisassembler *Decoder);
-template <unsigned Min, unsigned Max>
-static DecodeStatus DecodeZPRMul2_MinMax(MCInst &Inst, unsigned RegNo,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeZK(MCInst &Inst, unsigned RegNo, uint64_t Address,
- const MCDisassembler *Decoder);
-template <unsigned Min, unsigned Max>
-static DecodeStatus DecodeZPR2Mul2RegisterClass(MCInst &Inst, unsigned RegNo,
- uint64_t Address,
- const void *Decoder);
-static DecodeStatus DecodeZPR4Mul4RegisterClass(MCInst &Inst, unsigned RegNo,
- uint64_t Address,
- const void *Decoder);
-template <unsigned NumBitsForTile>
-static DecodeStatus DecodeMatrixTile(MCInst &Inst, unsigned RegNo,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeMatrixTileListRegisterClass(MCInst &Inst, unsigned RegMask,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodePPR2Mul2RegisterClass(MCInst &Inst, unsigned RegNo,
- uint64_t Address,
- const void *Decoder);
-
-static DecodeStatus DecodeFixedPointScaleImm32(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeFixedPointScaleImm64(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodePCRelLabel16(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodePCRelLabel19(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodePCRelLabel9(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeMemExtend(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeMRSSystemRegister(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeMSRSystemRegister(MCInst &Inst, unsigned Imm,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeThreeAddrSRegInstruction(MCInst &Inst, uint32_t insn, uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeMoveImmInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeUnsignedLdStInstruction(MCInst &Inst, uint32_t insn, uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeSignedLdStInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeExclusiveLdStInstruction(MCInst &Inst, uint32_t insn, uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodePairLdStInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeAuthLoadInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeAddSubERegInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeLogicalImmInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeModImmInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeModImmTiedInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeAdrInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeAddSubImmShift(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeUnconditionalBranch(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeSystemPStateImm0_15Instruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeSystemPStateImm0_1Instruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeTestAndBranch(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-
-static DecodeStatus DecodeFMOVLaneInstruction(MCInst &Inst, unsigned Insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR64Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR64ImmNarrow(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR32Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR32ImmNarrow(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR16Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR16ImmNarrow(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftR8Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftL64Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftL32Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftL16Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeVecShiftL8Imm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeWSeqPairsClassRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeXSeqPairsClassRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeSyspXzrInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus
-DecodeSVELogicalImmInstruction(MCInst &Inst, uint32_t insn, uint64_t Address,
- const MCDisassembler *Decoder);
template <int Bits>
static DecodeStatus DecodeSImm(MCInst &Inst, uint64_t Imm, uint64_t Address,
const MCDisassembler *Decoder);
-template <int ElementWidth>
-static DecodeStatus DecodeImm8OptLsl(MCInst &Inst, unsigned Imm, uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeSVEIncDecImm(MCInst &Inst, unsigned Imm,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeSVCROp(MCInst &Inst, unsigned Imm, uint64_t Address,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeCPYMemOpInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodeSETMemOpInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Addr,
- const MCDisassembler *Decoder);
-static DecodeStatus DecodePRFMRegInstruction(MCInst &Inst, uint32_t insn,
- uint64_t Address,
- const MCDisassembler *Decoder);
-
-#include "AArch64GenDisassemblerTables.inc"
-#include "AArch64GenInstrInfo.inc"
#define Success MCDisassembler::Success
#define Fail MCDisassembler::Fail
#define SoftFail MCDisassembler::SoftFail
-static MCDisassembler *createAArch64Disassembler(const Target &T,
- const MCSubtargetInfo &STI,
- MCContext &Ctx) {
-
- return new AArch64Disassembler(STI, Ctx, T.createMCInstrInfo());
-}
-
-DecodeStatus AArch64Disassembler::getInstruction(MCInst &MI, uint64_t &Size,
- ArrayRef<uint8_t> Bytes,
- uint64_t Address,
- raw_ostream &CS) const {
- CommentStream = &CS;
-
- Size = 0;
- // We want to read exactly 4 bytes of data.
- if (Bytes.size() < 4)
- return Fail;
- Size = 4;
-
- // Encoded as a small-endian 32-bit word in the stream.
- uint32_t Insn =
- (Bytes[3] << 24) | (Bytes[2] << 16) | (Bytes[1] << 8) | (Bytes[0] << 0);
-
- const uint8_t *Tables[] = {DecoderTable32, DecoderTableFallback32};
-
- for (const auto *Table : Tables) {
- DecodeStatus Result =
- decodeInstruction(Table, MI, Insn, Address, this, STI);
-
- const MCInstrDesc &Desc = MCII->get(MI.getOpcode());
-
- // For Scalable Matrix Extension (SME) instructions that have an implicit
- // operand for the accumulator (ZA) or implicit immediate zero which isn't
- // encoded, manually insert operand.
- for (unsigned i = 0; i < Desc.getNumOperands(); i++) {
- if (Desc.operands()[i].OperandType == MCOI::OPERAND_REGISTER) {
- switch (Desc.operands()[i].RegClass) {
- default:
- break;
- case AArch64::MPRRegClassID:
- MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZA));
- break;
- case AArch64::MPR8RegClassID:
- MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZAB0));
- break;
- case AArch64::ZTRRegClassID:
- MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZT0));
- break;
- }
- } else if (Desc.operands()[i].OperandType ==
- AArch64::OPERAND_IMPLICIT_IMM_0) {
- MI.insert(MI.begin() + i, MCOperand::createImm(0));
- }
- }
-
- if (MI.getOpcode() == AArch64::LDR_ZA ||
- MI.getOpcode() == AArch64::STR_ZA) {
- // Spill and fill instructions have a single immediate used for both
- // the vector select offset and optional memory offset. Replicate
- // the decoded immediate.
- const MCOperand &Imm4Op = MI.getOperand(2);
- assert(Imm4Op.isImm() && "Unexpected operand type!");
- MI.addOperand(Imm4Op);
- }
-
- if (Result != MCDisassembler::Fail)
- return Result;
- }
-
- return MCDisassembler::Fail;
-}
-
-uint64_t AArch64Disassembler::suggestBytesToSkip(ArrayRef<uint8_t> Bytes,
- uint64_t Address) const {
- // AArch64 instructions are always 4 bytes wide, so there's no point
- // in skipping any smaller number of bytes if an instruction can't
- // be decoded.
- return 4;
-}
-
-static MCSymbolizer *
-createAArch64ExternalSymbolizer(const Triple &TT, LLVMOpInfoCallback GetOpInfo,
- LLVMSymbolLookupCallback SymbolLookUp,
- void *DisInfo, MCContext *Ctx,
- std::unique_ptr<MCRelocationInfo> &&RelInfo) {
- return new AArch64ExternalSymbolizer(*Ctx, std::move(RelInfo), GetOpInfo,
- SymbolLookUp, DisInfo);
-}
-
-extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
-LLVMInitializeAArch64Disassembler() {
- TargetRegistry::RegisterMCDisassembler(getTheAArch64leTarget(),
- createAArch64Disassembler);
- TargetRegistry::RegisterMCDisassembler(getTheAArch64beTarget(),
- createAArch64Disassembler);
- TargetRegistry::RegisterMCSymbolizer(getTheAArch64leTarget(),
- createAArch64ExternalSymbolizer);
- TargetRegistry::RegisterMCSymbolizer(getTheAArch64beTarget(),
- createAArch64ExternalSymbolizer);
- TargetRegistry::RegisterMCDisassembler(getTheAArch64_32Target(),
- createAArch64Disassembler);
- TargetRegistry::RegisterMCSymbolizer(getTheAArch64_32Target(),
- createAArch64ExternalSymbolizer);
-
- TargetRegistry::RegisterMCDisassembler(getTheARM64Target(),
- createAArch64Disassembler);
- TargetRegistry::RegisterMCSymbolizer(getTheARM64Target(),
- createAArch64ExternalSymbolizer);
- TargetRegistry::RegisterMCDisassembler(getTheARM64_32Target(),
- createAArch64Disassembler);
- TargetRegistry::RegisterMCSymbolizer(getTheARM64_32Target(),
- createAArch64ExternalSymbolizer);
-}
-
template <unsigned RegClassID, unsigned FirstReg, unsigned NumRegsInClass>
static DecodeStatus DecodeSimpleRegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Address,
@@ -1854,3 +1562,120 @@ static DecodeStatus DecodePRFMRegInstruction(MCInst &Inst, uint32_t insn,
return Success;
}
+
+#include "AArch64GenDisassemblerTables.inc"
+#include "AArch64GenInstrInfo.inc"
+
+static MCDisassembler *createAArch64Disassembler(const Target &T,
+ const MCSubtargetInfo &STI,
+ MCContext &Ctx) {
+
+ return new AArch64Disassembler(STI, Ctx, T.createMCInstrInfo());
+}
+
+DecodeStatus AArch64Disassembler::getInstruction(MCInst &MI, uint64_t &Size,
+ ArrayRef<uint8_t> Bytes,
+ uint64_t Address,
+ raw_ostream &CS) const {
+ CommentStream = &CS;
+
+ Size = 0;
+ // We want to read exactly 4 bytes of data.
+ if (Bytes.size() < 4)
+ return Fail;
+ Size = 4;
+
+ // Encoded as a little-endian 32-bit word in the stream.
+ uint32_t Insn =
+ (Bytes[3] << 24) | (Bytes[2] << 16) | (Bytes[1] << 8) | (Bytes[0] << 0);
+
+ const uint8_t *Tables[] = {DecoderTable32, DecoderTableFallback32};
+
+ for (const auto *Table : Tables) {
+ DecodeStatus Result =
+ decodeInstruction(Table, MI, Insn, Address, this, STI);
+
+ const MCInstrDesc &Desc = MCII->get(MI.getOpcode());
+
+ // For Scalable Matrix Extension (SME) instructions that have an implicit
+ // operand for the accumulator (ZA) or implicit immediate zero which isn't
+ // encoded, manually insert operand.
+ for (unsigned i = 0; i < Desc.getNumOperands(); i++) {
+ if (Desc.operands()[i].OperandType == MCOI::OPERAND_REGISTER) {
+ switch (Desc.operands()[i].RegClass) {
+ default:
+ break;
+ case AArch64::MPRRegClassID:
+ MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZA));
+ break;
+ case AArch64::MPR8RegClassID:
+ MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZAB0));
+ break;
+ case AArch64::ZTRRegClassID:
+ MI.insert(MI.begin() + i, MCOperand::createReg(AArch64::ZT0));
+ break;
+ }
+ } else if (Desc.operands()[i].OperandType ==
+ AArch64::OPERAND_IMPLICIT_IMM_0) {
+ MI.insert(MI.begin() + i, MCOperand::createImm(0));
+ }
+ }
+
+ if (MI.getOpcode() == AArch64::LDR_ZA ||
+ MI.getOpcode() == AArch64::STR_ZA) {
+ // Spill and fill instructions have a single immediate used for both
+ // the vector select offset and optional memory offset. Replicate
+ // the decoded immediate.
+ const MCOperand &Imm4Op = MI.getOperand(2);
+ assert(Imm4Op.isImm() && "Unexpected operand type!");
+ MI.addOperand(Imm4Op);
+ }
+
+ if (Result != MCDisassembler::Fail)
+ return Result;
+ }
+
+ return MCDisassembler::Fail;
+}
+
+uint64_t AArch64Disassembler::suggestBytesToSkip(ArrayRef<uint8_t> Bytes,
+ uint64_t Address) const {
+ // AArch64 instructions are always 4 bytes wide, so there's no point
+ // in skipping any smaller number of bytes if an instruction can't
+ // be decoded.
+ return 4;
+}
+
+static MCSymbolizer *
+createAArch64ExternalSymbolizer(const Triple &TT, LLVMOpInfoCallback GetOpInfo,
+ LLVMSymbolLookupCallback SymbolLookUp,
+ void *DisInfo, MCContext *Ctx,
+ std::unique_ptr<MCRelocationInfo> &&RelInfo) {
+ return new AArch64ExternalSymbolizer(*Ctx, std::move(RelInfo), GetOpInfo,
+ SymbolLookUp, DisInfo);
+}
+
+extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void
+LLVMInitializeAArch64Disassembler() {
+ TargetRegistry::RegisterMCDisassembler(getTheAArch64leTarget(),
+ createAArch64Disassembler);
+ TargetRegistry::RegisterMCDisassembler(getTheAArch64beTarget(),
+ createAArch64Disassembler);
+ TargetRegistry::RegisterMCSymbolizer(getTheAArch64leTarget(),
+ createAArch64ExternalSymbolizer);
+ TargetRegistry::RegisterMCSymbolizer(getTheAArch64beTarget(),
+ createAArch64ExternalSymbolizer);
+ TargetRegistry::RegisterMCDisassembler(getTheAArch64_32Target(),
+ createAArch64Disassembler);
+ TargetRegistry::RegisterMCSymbolizer(getTheAArch64_32Target(),
+ createAArch64ExternalSymbolizer);
+
+ TargetRegistry::RegisterMCDisassembler(getTheARM64Target(),
+ createAArch64Disassembler);
+ TargetRegistry::RegisterMCSymbolizer(getTheARM64Target(),
+ createAArch64ExternalSymbolizer);
+ TargetRegistry::RegisterMCDisassembler(getTheARM64_32Target(),
+ createAArch64Disassembler);
+ TargetRegistry::RegisterMCSymbolizer(getTheARM64_32Target(),
+ createAArch64ExternalSymbolizer);
+}
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 010d0aaa..79bef76 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -125,12 +125,12 @@ struct AArch64OutgoingValueAssigner
bool UseVarArgsCCForFixed = IsCalleeWin && State.isVarArg();
bool Res;
- if (Info.IsFixed && !UseVarArgsCCForFixed) {
+ if (!Flags.isVarArg() && !UseVarArgsCCForFixed) {
if (!IsReturn)
applyStackPassedSmallTypeDAGHack(OrigVT, ValVT, LocVT);
- Res = AssignFn(ValNo, ValVT, LocVT, LocInfo, Flags, State);
+ Res = AssignFn(ValNo, ValVT, LocVT, LocInfo, Flags, Info.Ty, State);
} else
- Res = AssignFnVarArg(ValNo, ValVT, LocVT, LocInfo, Flags, State);
+ Res = AssignFnVarArg(ValNo, ValVT, LocVT, LocInfo, Flags, Info.Ty, State);
StackSize = State.getStackSize();
return Res;
@@ -361,7 +361,7 @@ struct OutgoingArgHandler : public CallLowering::OutgoingValueHandler {
unsigned MaxSize = MemTy.getSizeInBytes() * 8;
// For varargs, we always want to extend them to 8 bytes, in which case
// we disable setting a max.
- if (!Arg.IsFixed)
+ if (Arg.Flags[0].isVarArg())
MaxSize = 0;
Register ValVReg = Arg.Regs[RegIndex];
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index d905692..0bceb32 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -1102,82 +1102,6 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
return true;
}
-static unsigned selectFPConvOpc(unsigned GenericOpc, LLT DstTy, LLT SrcTy) {
- if (!DstTy.isScalar() || !SrcTy.isScalar())
- return GenericOpc;
-
- const unsigned DstSize = DstTy.getSizeInBits();
- const unsigned SrcSize = SrcTy.getSizeInBits();
-
- switch (DstSize) {
- case 32:
- switch (SrcSize) {
- case 32:
- switch (GenericOpc) {
- case TargetOpcode::G_SITOFP:
- return AArch64::SCVTFUWSri;
- case TargetOpcode::G_UITOFP:
- return AArch64::UCVTFUWSri;
- case TargetOpcode::G_FPTOSI:
- return AArch64::FCVTZSUWSr;
- case TargetOpcode::G_FPTOUI:
- return AArch64::FCVTZUUWSr;
- default:
- return GenericOpc;
- }
- case 64:
- switch (GenericOpc) {
- case TargetOpcode::G_SITOFP:
- return AArch64::SCVTFUXSri;
- case TargetOpcode::G_UITOFP:
- return AArch64::UCVTFUXSri;
- case TargetOpcode::G_FPTOSI:
- return AArch64::FCVTZSUWDr;
- case TargetOpcode::G_FPTOUI:
- return AArch64::FCVTZUUWDr;
- default:
- return GenericOpc;
- }
- default:
- return GenericOpc;
- }
- case 64:
- switch (SrcSize) {
- case 32:
- switch (GenericOpc) {
- case TargetOpcode::G_SITOFP:
- return AArch64::SCVTFUWDri;
- case TargetOpcode::G_UITOFP:
- return AArch64::UCVTFUWDri;
- case TargetOpcode::G_FPTOSI:
- return AArch64::FCVTZSUXSr;
- case TargetOpcode::G_FPTOUI:
- return AArch64::FCVTZUUXSr;
- default:
- return GenericOpc;
- }
- case 64:
- switch (GenericOpc) {
- case TargetOpcode::G_SITOFP:
- return AArch64::SCVTFUXDri;
- case TargetOpcode::G_UITOFP:
- return AArch64::UCVTFUXDri;
- case TargetOpcode::G_FPTOSI:
- return AArch64::FCVTZSUXDr;
- case TargetOpcode::G_FPTOUI:
- return AArch64::FCVTZUUXDr;
- default:
- return GenericOpc;
- }
- default:
- return GenericOpc;
- }
- default:
- return GenericOpc;
- };
- return GenericOpc;
-}
-
MachineInstr *
AArch64InstructionSelector::emitSelect(Register Dst, Register True,
Register False, AArch64CC::CondCode CC,
@@ -1349,7 +1273,9 @@ AArch64InstructionSelector::emitSelect(Register Dst, Register True,
return &*SelectInst;
}
-static AArch64CC::CondCode changeICMPPredToAArch64CC(CmpInst::Predicate P) {
+static AArch64CC::CondCode
+changeICMPPredToAArch64CC(CmpInst::Predicate P, Register RHS = {},
+ MachineRegisterInfo *MRI = nullptr) {
switch (P) {
default:
llvm_unreachable("Unknown condition code!");
@@ -1360,8 +1286,18 @@ static AArch64CC::CondCode changeICMPPredToAArch64CC(CmpInst::Predicate P) {
case CmpInst::ICMP_SGT:
return AArch64CC::GT;
case CmpInst::ICMP_SGE:
+ if (RHS && MRI) {
+ auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, *MRI);
+ if (ValAndVReg && ValAndVReg->Value == 0)
+ return AArch64CC::PL;
+ }
return AArch64CC::GE;
case CmpInst::ICMP_SLT:
+ if (RHS && MRI) {
+ auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, *MRI);
+ if (ValAndVReg && ValAndVReg->Value == 0)
+ return AArch64CC::MI;
+ }
return AArch64CC::LT;
case CmpInst::ICMP_SLE:
return AArch64CC::LE;
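For context on the new PL/MI mapping: after a compare against zero the overflow flag cannot be set and the N flag equals the sign bit, so "x s>= 0" can be tested with PL and "x s< 0" with MI. A hedged sketch with a stand-in condition-code enum (not AArch64CC):

#include <cstdint>
#include <cstdio>

enum CondCode { GE, LT, PL, MI };

// With RHS == 0, "subs xzr, x, #0" cannot overflow (V is clear) and N is the
// sign bit of x, so the N-only conditions PL/MI are sufficient.
CondCode lowerSignedCmp(bool IsSGE, int64_t RHS) {
  if (RHS == 0)
    return IsSGE ? PL : MI;
  return IsSGE ? GE : LT; // general case needs both N and V
}

int main() {
  const char *Names[] = {"GE", "LT", "PL", "MI"};
  std::printf("x s>= 0 -> %s, x s< 5 -> %s\n",
              Names[lowerSignedCmp(true, 0)], Names[lowerSignedCmp(false, 5)]);
  return 0;
}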
@@ -1697,7 +1633,7 @@ bool AArch64InstructionSelector::selectCompareBranchFedByFCmp(
emitFPCompare(FCmp.getOperand(2).getReg(), FCmp.getOperand(3).getReg(), MIB,
Pred);
AArch64CC::CondCode CC1, CC2;
- changeFCMPPredToAArch64CC(static_cast<CmpInst::Predicate>(Pred), CC1, CC2);
+ changeFCMPPredToAArch64CC(Pred, CC1, CC2);
MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC1).addMBB(DestMBB);
if (CC2 != AArch64CC::AL)
@@ -1813,7 +1749,8 @@ bool AArch64InstructionSelector::selectCompareBranchFedByICmp(
auto &PredOp = ICmp.getOperand(1);
emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB);
const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(
- static_cast<CmpInst::Predicate>(PredOp.getPredicate()));
+ static_cast<CmpInst::Predicate>(PredOp.getPredicate()),
+ ICmp.getOperand(3).getReg(), MIB.getMRI());
MIB.buildInstr(AArch64::Bcc, {}, {}).addImm(CC).addMBB(DestMBB);
I.eraseFromParent();
return true;
@@ -2510,8 +2447,8 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) {
emitIntegerCompare(/*LHS=*/Cmp->getOperand(2),
/*RHS=*/Cmp->getOperand(3), PredOp, MIB);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
- const AArch64CC::CondCode InvCC =
- changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
+ const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC(
+ CmpInst::getInversePredicate(Pred), Cmp->getOperand(3).getReg(), &MRI);
emitCSINC(/*Dst=*/AddDst, /*Src =*/AddLHS, /*Src2=*/AddLHS, InvCC, MIB);
I.eraseFromParent();
return true;
@@ -3511,23 +3448,6 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
return true;
}
- case TargetOpcode::G_SITOFP:
- case TargetOpcode::G_UITOFP:
- case TargetOpcode::G_FPTOSI:
- case TargetOpcode::G_FPTOUI: {
- const LLT DstTy = MRI.getType(I.getOperand(0).getReg()),
- SrcTy = MRI.getType(I.getOperand(1).getReg());
- const unsigned NewOpc = selectFPConvOpc(Opcode, DstTy, SrcTy);
- if (NewOpc == Opcode)
- return false;
-
- I.setDesc(TII.get(NewOpc));
- constrainSelectedInstRegOperands(I, TII, TRI, RBI);
- I.setFlags(MachineInstr::NoFPExcept);
-
- return true;
- }
-
case TargetOpcode::G_FREEZE:
return selectCopy(I, TII, MRI, TRI, RBI);
@@ -3577,8 +3497,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
auto &PredOp = I.getOperand(1);
emitIntegerCompare(I.getOperand(2), I.getOperand(3), PredOp, MIB);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
- const AArch64CC::CondCode InvCC =
- changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
+ const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC(
+ CmpInst::getInversePredicate(Pred), I.getOperand(3).getReg(), &MRI);
emitCSINC(/*Dst=*/I.getOperand(0).getReg(), /*Src1=*/AArch64::WZR,
/*Src2=*/AArch64::WZR, InvCC, MIB);
I.eraseFromParent();
@@ -4931,7 +4851,7 @@ MachineInstr *AArch64InstructionSelector::emitConjunctionRec(
if (Negate)
CC = CmpInst::getInversePredicate(CC);
if (isa<GICmp>(Cmp)) {
- OutCC = changeICMPPredToAArch64CC(CC);
+ OutCC = changeICMPPredToAArch64CC(CC, RHS, MIB.getMRI());
} else {
// Handle special FP cases.
AArch64CC::CondCode ExtraCC;
@@ -5101,7 +5021,8 @@ bool AArch64InstructionSelector::tryOptSelect(GSelect &I) {
emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3), PredOp,
MIB);
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
- CondCode = changeICMPPredToAArch64CC(Pred);
+ CondCode =
+ changeICMPPredToAArch64CC(Pred, CondDef->getOperand(3).getReg(), &MRI);
} else {
// Get the condition code for the select.
auto Pred =
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index e0e1af7..82391f13 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -797,6 +797,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.clampMinNumElements(0, s16, 4)
.alwaysLegal();
+ getActionDefinitionsBuilder({G_TRUNC_SSAT_S, G_TRUNC_SSAT_U, G_TRUNC_USAT_U})
+ .legalFor({{v8s8, v8s16}, {v4s16, v4s32}, {v2s32, v2s64}});
+
getActionDefinitionsBuilder(G_SEXT_INREG)
.legalFor({s32, s64})
.legalFor(PackedVectorAllTypeList)
@@ -876,8 +879,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{v2s32, v2s32},
{v4s32, v4s32},
{v2s64, v2s64}})
- .legalFor(HasFP16,
- {{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
+ .legalFor(
+ HasFP16,
+ {{s16, s16}, {s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
// Handle types larger than i64 by scalarizing/lowering.
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
@@ -965,6 +969,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{s128, s64}});
// Control-flow
+ getActionDefinitionsBuilder(G_BR).alwaysLegal();
getActionDefinitionsBuilder(G_BRCOND)
.legalFor({s32})
.clampScalar(0, s32, s32);
@@ -1146,7 +1151,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.clampMaxNumElements(1, s32, 4)
.clampMaxNumElements(1, s16, 8)
.clampMaxNumElements(1, s8, 16)
- .clampMaxNumElements(1, p0, 2);
+ .clampMaxNumElements(1, p0, 2)
+ .scalarizeIf(scalarOrEltWiderThan(1, 64), 1);
getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
.legalIf(
@@ -1161,7 +1167,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.clampNumElements(0, v4s16, v8s16)
.clampNumElements(0, v2s32, v4s32)
.clampMaxNumElements(0, s64, 2)
- .clampMaxNumElements(0, p0, 2);
+ .clampMaxNumElements(0, p0, 2)
+ .scalarizeIf(scalarOrEltWiderThan(0, 64), 0);
getActionDefinitionsBuilder(G_BUILD_VECTOR)
.legalFor({{v8s8, s8},
@@ -1255,6 +1262,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
getActionDefinitionsBuilder(G_BRJT).legalFor({{p0, s64}});
+ getActionDefinitionsBuilder({G_TRAP, G_DEBUGTRAP, G_UBSANTRAP}).alwaysLegal();
+
getActionDefinitionsBuilder(G_DYN_STACKALLOC).custom();
getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).lower();
@@ -1644,6 +1653,11 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
MachineIRBuilder &MIB = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIB.getMRI();
+ auto LowerUnaryOp = [&MI, &MIB](unsigned Opcode) {
+ MIB.buildInstr(Opcode, {MI.getOperand(0)}, {MI.getOperand(2)});
+ MI.eraseFromParent();
+ return true;
+ };
auto LowerBinOp = [&MI, &MIB](unsigned Opcode) {
MIB.buildInstr(Opcode, {MI.getOperand(0)},
{MI.getOperand(2), MI.getOperand(3)});
@@ -1838,6 +1852,12 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
return LowerTriOp(AArch64::G_UDOT);
case Intrinsic::aarch64_neon_sdot:
return LowerTriOp(AArch64::G_SDOT);
+ case Intrinsic::aarch64_neon_sqxtn:
+ return LowerUnaryOp(TargetOpcode::G_TRUNC_SSAT_S);
+ case Intrinsic::aarch64_neon_sqxtun:
+ return LowerUnaryOp(TargetOpcode::G_TRUNC_SSAT_U);
+ case Intrinsic::aarch64_neon_uqxtn:
+ return LowerUnaryOp(TargetOpcode::G_TRUNC_USAT_U);
case Intrinsic::vector_reverse:
// TODO: Add support for vector_reverse
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
index 3ba08c8..6025f1c 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -614,8 +614,7 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
// x uge c => x ugt c - 1
//
// When c is not zero.
- if (C == 0)
- return std::nullopt;
+ assert(C != 0 && "C should not be zero here!");
P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
C -= 1;
break;
@@ -656,14 +655,13 @@ tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
if (isLegalArithImmed(C))
return {{C, P}};
- auto IsMaterializableInSingleInstruction = [=](uint64_t Imm) {
+ auto NumberOfInstrToLoadImm = [=](uint64_t Imm) {
SmallVector<AArch64_IMM::ImmInsnModel> Insn;
AArch64_IMM::expandMOVImm(Imm, 32, Insn);
- return Insn.size() == 1;
+ return Insn.size();
};
- if (!IsMaterializableInSingleInstruction(OriginalC) &&
- IsMaterializableInSingleInstruction(C))
+ if (NumberOfInstrToLoadImm(OriginalC) > NumberOfInstrToLoadImm(C))
return {{C, P}};
return std::nullopt;
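The revised heuristic above only switches to the adjusted immediate when it is strictly cheaper to materialize. A standalone sketch, with a toy instruction-count function standing in for AArch64_IMM::expandMOVImm:

#include <cstdint>
#include <cstdio>

// Pretend cost: immediates that are a 16-bit chunk at a 16-bit-aligned shift
// take one MOV; everything else is assumed to take two (MOVZ+MOVK). The real
// counting lives in expandMOVImm.
int numInstrToLoadImm(uint64_t Imm) {
  for (int Shift = 0; Shift < 64; Shift += 16)
    if ((Imm & ~(0xffffULL << Shift)) == 0)
      return 1;
  return 2;
}

bool shouldUseAdjusted(uint64_t OriginalC, uint64_t AdjustedC) {
  return numInstrToLoadImm(OriginalC) > numInstrToLoadImm(AdjustedC);
}

int main() {
  // 0x10000 and 0xffff both need one instruction, so keep the original
  // predicate; 0x10001 vs 0x10000 favours the adjusted immediate.
  std::printf("%d %d\n", shouldUseAdjusted(0x10000, 0xffff),
              shouldUseAdjusted(0x10001, 0x10000)); // prints: 0 1
  return 0;
}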
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 31954e7..cf391c4 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -13,6 +13,7 @@
#include "AArch64RegisterBankInfo.h"
#include "AArch64RegisterInfo.h"
+#include "AArch64Subtarget.h"
#include "MCTargetDesc/AArch64MCTargetDesc.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -492,7 +493,7 @@ static bool isFPIntrinsic(const MachineRegisterInfo &MRI,
bool AArch64RegisterBankInfo::isPHIWithFPConstraints(
const MachineInstr &MI, const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI, const unsigned Depth) const {
+ const AArch64RegisterInfo &TRI, const unsigned Depth) const {
if (!MI.isPHI() || Depth > MaxFPRSearchDepth)
return false;
@@ -506,7 +507,7 @@ bool AArch64RegisterBankInfo::isPHIWithFPConstraints(
bool AArch64RegisterBankInfo::hasFPConstraints(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI,
+ const AArch64RegisterInfo &TRI,
unsigned Depth) const {
unsigned Op = MI.getOpcode();
if (Op == TargetOpcode::G_INTRINSIC && isFPIntrinsic(MRI, MI))
@@ -544,7 +545,7 @@ bool AArch64RegisterBankInfo::hasFPConstraints(const MachineInstr &MI,
bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI,
+ const AArch64RegisterInfo &TRI,
unsigned Depth) const {
switch (MI.getOpcode()) {
case TargetOpcode::G_FPTOSI:
@@ -582,7 +583,7 @@ bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
bool AArch64RegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI,
+ const AArch64RegisterInfo &TRI,
unsigned Depth) const {
switch (MI.getOpcode()) {
case AArch64::G_DUP:
@@ -618,6 +619,19 @@ bool AArch64RegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
return hasFPConstraints(MI, MRI, TRI, Depth);
}
+bool AArch64RegisterBankInfo::prefersFPUse(const MachineInstr &MI,
+ const MachineRegisterInfo &MRI,
+ const AArch64RegisterInfo &TRI,
+ unsigned Depth) const {
+ switch (MI.getOpcode()) {
+ case TargetOpcode::G_SITOFP:
+ case TargetOpcode::G_UITOFP:
+ return MRI.getType(MI.getOperand(0).getReg()).getSizeInBits() ==
+ MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
+ }
+ return onlyDefinesFP(MI, MRI, TRI, Depth);
+}
+
bool AArch64RegisterBankInfo::isLoadFromFPType(const MachineInstr &MI) const {
// GMemOperation because we also want to match indexed loads.
auto *MemOp = cast<GMemOperation>(&MI);
@@ -671,8 +685,8 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
const MachineFunction &MF = *MI.getParent()->getParent();
const MachineRegisterInfo &MRI = MF.getRegInfo();
- const TargetSubtargetInfo &STI = MF.getSubtarget();
- const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
+ const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+ const AArch64RegisterInfo &TRI = *STI.getRegisterInfo();
switch (Opc) {
// G_{F|S|U}REM are not listed because they are not legal.
@@ -826,16 +840,28 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
// Integer to FP conversions don't necessarily happen between GPR -> FPR
// regbanks. They can also be done within an FPR register.
Register SrcReg = MI.getOperand(1).getReg();
- if (getRegBank(SrcReg, MRI, TRI) == &AArch64::FPRRegBank)
+ if (getRegBank(SrcReg, MRI, TRI) == &AArch64::FPRRegBank &&
+ MRI.getType(SrcReg).getSizeInBits() ==
+ MRI.getType(MI.getOperand(0).getReg()).getSizeInBits())
OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR};
else
OpRegBankIdx = {PMI_FirstFPR, PMI_FirstGPR};
break;
}
+ case TargetOpcode::G_FPTOSI_SAT:
+ case TargetOpcode::G_FPTOUI_SAT: {
+ LLT DstType = MRI.getType(MI.getOperand(0).getReg());
+ if (DstType.isVector())
+ break;
+ if (DstType == LLT::scalar(16)) {
+ OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR};
+ break;
+ }
+ OpRegBankIdx = {PMI_FirstGPR, PMI_FirstFPR};
+ break;
+ }
case TargetOpcode::G_FPTOSI:
case TargetOpcode::G_FPTOUI:
- case TargetOpcode::G_FPTOSI_SAT:
- case TargetOpcode::G_FPTOUI_SAT:
case TargetOpcode::G_INTRINSIC_LRINT:
case TargetOpcode::G_INTRINSIC_LLRINT:
if (MRI.getType(MI.getOperand(0).getReg()).isVector())
@@ -895,13 +921,13 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
// instruction.
//
// Int->FP conversion operations are also captured in
- // onlyDefinesFP().
+ // prefersFPUse().
if (isPHIWithFPConstraints(UseMI, MRI, TRI))
return true;
return onlyUsesFP(UseMI, MRI, TRI) ||
- onlyDefinesFP(UseMI, MRI, TRI);
+ prefersFPUse(UseMI, MRI, TRI);
}))
OpRegBankIdx[0] = PMI_FirstFPR;
break;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
index 3abbc1b..b2de031 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
@@ -22,6 +22,7 @@
namespace llvm {
class TargetRegisterInfo;
+class AArch64RegisterInfo;
class AArch64GenRegisterBankInfo : public RegisterBankInfo {
protected:
@@ -123,21 +124,26 @@ class AArch64RegisterBankInfo final : public AArch64GenRegisterBankInfo {
/// \returns true if \p MI is a PHI that its def is used by
/// any instruction that onlyUsesFP.
bool isPHIWithFPConstraints(const MachineInstr &MI,
- const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI,
- unsigned Depth = 0) const;
+ const MachineRegisterInfo &MRI,
+ const AArch64RegisterInfo &TRI,
+ unsigned Depth = 0) const;
/// \returns true if \p MI only uses and defines FPRs.
bool hasFPConstraints(const MachineInstr &MI, const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI, unsigned Depth = 0) const;
+ const AArch64RegisterInfo &TRI,
+ unsigned Depth = 0) const;
/// \returns true if \p MI only uses FPRs.
bool onlyUsesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI, unsigned Depth = 0) const;
+ const AArch64RegisterInfo &TRI, unsigned Depth = 0) const;
/// \returns true if \p MI only defines FPRs.
bool onlyDefinesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI,
- const TargetRegisterInfo &TRI, unsigned Depth = 0) const;
+ const AArch64RegisterInfo &TRI, unsigned Depth = 0) const;
+
+  /// \returns true if \p MI can take both FPR and GPR uses, but prefers FPR.
+ bool prefersFPUse(const MachineInstr &MI, const MachineRegisterInfo &MRI,
+ const AArch64RegisterInfo &TRI, unsigned Depth = 0) const;
/// \returns true if the load \p MI is likely loading from a floating-point
/// type.
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp
index 7618a57..a388216 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp
@@ -40,6 +40,7 @@ protected:
bool IsPCRel) const override;
bool needsRelocateWithSymbol(const MCValue &, unsigned Type) const override;
bool isNonILP32reloc(const MCFixup &Fixup, AArch64::Specifier RefKind) const;
+ void sortRelocs(std::vector<ELFRelocationEntry> &Relocs) override;
bool IsILP32;
};
@@ -96,8 +97,8 @@ unsigned AArch64ELFObjectWriter::getRelocType(const MCFixup &Fixup,
case AArch64::S_TPREL:
case AArch64::S_TLSDESC:
case AArch64::S_TLSDESC_AUTH:
- if (auto *SA = Target.getAddSym())
- cast<MCSymbolELF>(SA)->setType(ELF::STT_TLS);
+ if (auto *SA = const_cast<MCSymbol *>(Target.getAddSym()))
+ static_cast<MCSymbolELF *>(SA)->setType(ELF::STT_TLS);
break;
default:
break;
@@ -488,7 +489,8 @@ bool AArch64ELFObjectWriter::needsRelocateWithSymbol(const MCValue &Val,
// this global needs to be tagged. In addition, the linker needs to know
// whether to emit a special addend when relocating `end` symbols, and this
// can only be determined by the attributes of the symbol itself.
- if (Val.getAddSym() && cast<MCSymbolELF>(Val.getAddSym())->isMemtag())
+ if (Val.getAddSym() &&
+ static_cast<const MCSymbolELF *>(Val.getAddSym())->isMemtag())
return true;
if ((Val.getSpecifier() & AArch64::S_GOT) == AArch64::S_GOT)
@@ -497,6 +499,17 @@ bool AArch64ELFObjectWriter::needsRelocateWithSymbol(const MCValue &Val,
Val.getSpecifier());
}
+void AArch64ELFObjectWriter::sortRelocs(
+ std::vector<ELFRelocationEntry> &Relocs) {
+ // PATCHINST relocations should be applied last because they may overwrite the
+ // whole instruction and so should take precedence over other relocations that
+ // modify operands of the original instruction.
+ std::stable_partition(Relocs.begin(), Relocs.end(),
+ [](const ELFRelocationEntry &R) {
+ return R.Type != ELF::R_AARCH64_PATCHINST;
+ });
+}
+
std::unique_ptr<MCObjectTargetWriter>
llvm::createAArch64ELFObjectWriter(uint8_t OSABI, bool IsILP32) {
return std::make_unique<AArch64ELFObjectWriter>(OSABI, IsILP32);
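For illustration, a minimal standalone sketch of the ordering guarantee the sortRelocs override above relies on: std::stable_partition keeps the relative order within each group while moving the "patch" entries to the end. The Reloc struct below is a stand-in, not the real ELFRelocationEntry type.

#include <algorithm>
#include <cstdio>
#include <vector>

struct Reloc {
  unsigned Offset;
  bool IsPatchInst; // stands in for Type == ELF::R_AARCH64_PATCHINST
};

int main() {
  std::vector<Reloc> Relocs = {
      {0x0, true}, {0x4, false}, {0x8, true}, {0xc, false}};
  // Non-PATCHINST relocations first (in their original order), PATCHINST last.
  std::stable_partition(Relocs.begin(), Relocs.end(),
                        [](const Reloc &R) { return !R.IsPatchInst; });
  for (const Reloc &R : Relocs)
    std::printf("offset 0x%x patch=%d\n", R.Offset, R.IsPatchInst);
  // Prints 0x4, 0xc first, then 0x0, 0x8.
}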
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp
index 6257e99..917dbdf 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp
@@ -35,7 +35,6 @@
#include "llvm/MC/MCTargetOptions.h"
#include "llvm/MC/MCWinCOFFStreamer.h"
#include "llvm/Support/AArch64BuildAttributes.h"
-#include "llvm/Support/Casting.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/raw_ostream.h"
@@ -418,7 +417,8 @@ private:
}
MCSymbol *emitMappingSymbol(StringRef Name) {
- auto *Symbol = cast<MCSymbolELF>(getContext().createLocalSymbol(Name));
+ auto *Symbol =
+ static_cast<MCSymbolELF *>(getContext().createLocalSymbol(Name));
emitLabel(Symbol);
return Symbol;
}
@@ -455,7 +455,7 @@ void AArch64TargetELFStreamer::emitInst(uint32_t Inst) {
void AArch64TargetELFStreamer::emitDirectiveVariantPCS(MCSymbol *Symbol) {
getStreamer().getAssembler().registerSymbol(*Symbol);
- cast<MCSymbolELF>(Symbol)->setOther(ELF::STO_AARCH64_VARIANT_PCS);
+ static_cast<MCSymbolELF *>(Symbol)->setOther(ELF::STO_AARCH64_VARIANT_PCS);
}
void AArch64TargetELFStreamer::finish() {
@@ -541,7 +541,7 @@ void AArch64TargetELFStreamer::finish() {
MCSectionELF *MemtagSec = nullptr;
for (const MCSymbol &Symbol : Asm.symbols()) {
- const auto &Sym = cast<MCSymbolELF>(Symbol);
+ auto &Sym = static_cast<const MCSymbolELF &>(Symbol);
if (Sym.isMemtag()) {
MemtagSec = Ctx.getELFSection(".memtag.globals.static",
ELF::SHT_AARCH64_MEMTAG_GLOBALS_STATIC, 0);
@@ -556,7 +556,7 @@ void AArch64TargetELFStreamer::finish() {
S.switchSection(MemtagSec);
const auto *Zero = MCConstantExpr::create(0, Ctx);
for (const MCSymbol &Symbol : Asm.symbols()) {
- const auto &Sym = cast<MCSymbolELF>(Symbol);
+ auto &Sym = static_cast<const MCSymbolELF &>(Symbol);
if (!Sym.isMemtag())
continue;
auto *SRE = MCSymbolRefExpr::create(&Sym, Ctx);
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
index 3c8b571..54b58e9 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
@@ -1017,14 +1017,22 @@ bool AArch64InstPrinter::printSysAlias(const MCInst *MI,
else
return false;
+ StringRef Reg = getRegisterName(MI->getOperand(4).getReg());
+ bool NotXZR = Reg != "xzr";
+
+  // If a mandatory register operand is not specified in the TableGen
+  // description (i.e. no register operand should be present), and the
+  // register value is not xzr/x31, then disassemble to a SYS alias instead.
+ if (NotXZR && !NeedsReg)
+ return false;
+
std::string Str = Ins + Name;
llvm::transform(Str, Str.begin(), ::tolower);
O << '\t' << Str;
- if (NeedsReg) {
- O << ", ";
- printRegName(O, MI->getOperand(4).getReg());
- }
+
+ if (NeedsReg)
+ O << ", " << Reg;
return true;
}
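For illustration, the bail-out added above amounts to the following predicate (a hypothetical helper, not part of the printer): the alias form is only printed when the encoded register is consistent with what the alias accepts; otherwise the caller falls back to the generic SYS form.

#include <cassert>

// Hypothetical summary of the check in printSysAlias: an alias that takes no
// register operand cannot faithfully represent a non-xzr register field, so
// such encodings fall back to the generic SYS form.
static bool shouldPrintSysAlias(bool AliasNeedsReg, bool RegIsXZR) {
  if (!AliasNeedsReg && !RegIsXZR)
    return false;
  return true;
}

int main() {
  assert(!shouldPrintSysAlias(/*AliasNeedsReg=*/false, /*RegIsXZR=*/false));
  assert(shouldPrintSysAlias(/*AliasNeedsReg=*/true, /*RegIsXZR=*/false));
  assert(shouldPrintSysAlias(/*AliasNeedsReg=*/false, /*RegIsXZR=*/true));
}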
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp
index 828c5c5..2b5cf34 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp
@@ -53,9 +53,9 @@ const MCAsmInfo::AtSpecifier MachOAtSpecifiers[] = {
{AArch64::S_MACHO_TLVPPAGEOFF, "TLVPPAGEOFF"},
};
-StringRef AArch64::getSpecifierName(const MCSpecifierExpr &Expr) {
+StringRef AArch64::getSpecifierName(AArch64::Specifier S) {
// clang-format off
- switch (static_cast<uint32_t>(Expr.getSpecifier())) {
+ switch (static_cast<uint32_t>(S)) {
case AArch64::S_CALL: return "";
case AArch64::S_LO12: return ":lo12:";
case AArch64::S_ABS_G3: return ":abs_g3:";
@@ -124,7 +124,7 @@ static bool evaluate(const MCSpecifierExpr &Expr, MCValue &Res,
if (!Expr.getSubExpr()->evaluateAsRelocatable(Res, Asm))
return false;
Res.setSpecifier(Expr.getSpecifier());
- return true;
+ return !Res.getSubSym();
}
AArch64MCAsmInfoDarwin::AArch64MCAsmInfoDarwin(bool IsILP32) {
@@ -183,7 +183,7 @@ void AArch64MCAsmInfoDarwin::printSpecifierExpr(
raw_ostream &OS, const MCSpecifierExpr &Expr) const {
if (auto *AE = dyn_cast<AArch64AuthMCExpr>(&Expr))
return AE->print(OS, this);
- OS << AArch64::getSpecifierName(Expr);
+ OS << AArch64::getSpecifierName(Expr.getSpecifier());
printExpr(OS, *Expr.getSubExpr());
}
@@ -232,7 +232,7 @@ void AArch64MCAsmInfoELF::printSpecifierExpr(
raw_ostream &OS, const MCSpecifierExpr &Expr) const {
if (auto *AE = dyn_cast<AArch64AuthMCExpr>(&Expr))
return AE->print(OS, this);
- OS << AArch64::getSpecifierName(Expr);
+ OS << AArch64::getSpecifierName(Expr.getSpecifier());
printExpr(OS, *Expr.getSubExpr());
}
@@ -262,7 +262,7 @@ AArch64MCAsmInfoMicrosoftCOFF::AArch64MCAsmInfoMicrosoftCOFF() {
void AArch64MCAsmInfoMicrosoftCOFF::printSpecifierExpr(
raw_ostream &OS, const MCSpecifierExpr &Expr) const {
- OS << AArch64::getSpecifierName(Expr);
+ OS << AArch64::getSpecifierName(Expr.getSpecifier());
printExpr(OS, *Expr.getSubExpr());
}
@@ -292,7 +292,7 @@ AArch64MCAsmInfoGNUCOFF::AArch64MCAsmInfoGNUCOFF() {
void AArch64MCAsmInfoGNUCOFF::printSpecifierExpr(
raw_ostream &OS, const MCSpecifierExpr &Expr) const {
- OS << AArch64::getSpecifierName(Expr);
+ OS << AArch64::getSpecifierName(Expr.getSpecifier());
printExpr(OS, *Expr.getSubExpr());
}
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h
index c28e925..0dfa61b 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h
@@ -181,7 +181,7 @@ enum {
/// Return the string representation of the ELF relocation specifier
/// (e.g. ":got:", ":lo12:").
-StringRef getSpecifierName(const MCSpecifierExpr &Expr);
+StringRef getSpecifierName(Specifier S);
inline Specifier getSymbolLoc(Specifier S) {
return static_cast<Specifier>(S & AArch64::S_SymLocBits);
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp
index a53b676..5fe9993 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64WinCOFFObjectWriter.cpp
@@ -73,9 +73,10 @@ unsigned AArch64WinCOFFObjectWriter::getRelocType(
// Supported
break;
default:
- Ctx.reportError(Fixup.getLoc(), "relocation specifier " +
- AArch64::getSpecifierName(*A64E) +
- " unsupported on COFF targets");
+ Ctx.reportError(Fixup.getLoc(),
+ "relocation specifier " +
+ AArch64::getSpecifierName(A64E->getSpecifier()) +
+ " unsupported on COFF targets");
return COFF::IMAGE_REL_ARM64_ABSOLUTE; // Dummy return value
}
}
@@ -83,9 +84,10 @@ unsigned AArch64WinCOFFObjectWriter::getRelocType(
switch (FixupKind) {
default: {
if (auto *A64E = dyn_cast<MCSpecifierExpr>(Expr)) {
- Ctx.reportError(Fixup.getLoc(), "relocation specifier " +
- AArch64::getSpecifierName(*A64E) +
- " unsupported on COFF targets");
+ Ctx.reportError(Fixup.getLoc(),
+ "relocation specifier " +
+ AArch64::getSpecifierName(A64E->getSpecifier()) +
+ " unsupported on COFF targets");
} else {
MCFixupKindInfo Info = MAB.getFixupKindInfo(Fixup.getKind());
Ctx.reportError(Fixup.getLoc(), Twine("relocation type ") + Info.Name +
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
new file mode 100644
index 0000000..5dfaa891
--- /dev/null
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -0,0 +1,696 @@
+//===- MachineSMEABIPass.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the SME ABI requirements for ZA state. This includes
+// implementing the lazy ZA state save schemes around calls.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass works by collecting instructions that require ZA to be in a
+// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
+// transitions to ensure ZA is in the required state before instructions. State
+// transitions represent actions such as setting up or restoring a lazy save.
+// Certain points within a function may also have predefined states independent
+// of any instructions, for example, a "shared_za" function is always entered
+// and exited in the "ACTIVE" state.
+//
+// To handle ZA state across control flow, we make use of edge bundling. This
+// assigns each block an "incoming" and "outgoing" edge bundle (representing
+// incoming and outgoing edges). Initially, these are unique to each block;
+// then, in the process of forming bundles, the outgoing bundle of a block is
+// joined with the incoming bundle of all successors. The result is that each
+// bundle can be assigned a single ZA state, which ensures the state required by
+// all of a block's successors is the same, and that each basic block will
+// always be entered with the same ZA state. This eliminates the need for
+// splitting edges to insert state transitions or "phi" nodes for ZA states.
+//
+// See below for a simple example of edge bundling.
+//
+// The following shows a conditionally executed basic block (BB1):
+//
+// if (cond)
+// BB1
+// BB2
+//
+// Initial Bundles Joined Bundles
+//
+// ┌──0──┐ ┌──0──┐
+// │ BB0 │ │ BB0 │
+// └──1──┘ └──1──┘
+// ├───────┐ ├───────┐
+// ▼ │ ▼ │
+// ┌──2──┐ │ ─────► ┌──1──┐ │
+// │ BB1 │ ▼ │ BB1 │ ▼
+// └──3──┘ ┌──4──┐ └──1──┘ ┌──1──┐
+// └───►4 BB2 │ └───►1 BB2 │
+// └──5──┘ └──2──┘
+//
+// On the left are the initial per-block bundles, and on the right are the
+// joined bundles (which are the result of the EdgeBundles analysis).
+
+#include "AArch64InstrInfo.h"
+#include "AArch64MachineFunctionInfo.h"
+#include "AArch64Subtarget.h"
+#include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/ADT/BitmaskEnum.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/EdgeBundles.h"
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-machine-sme-abi"
+
+namespace {
+
+enum ZAState {
+ // Any/unknown state (not valid)
+ ANY = 0,
+
+ // ZA is in use and active (i.e. within the accumulator)
+ ACTIVE,
+
+ // A ZA save has been set up or committed (i.e. ZA is dormant or off)
+ LOCAL_SAVED,
+
+ // ZA is off or a lazy save has been set up by the caller
+ CALLER_DORMANT,
+
+ // ZA is off
+ OFF,
+
+ // The number of ZA states (not a valid state)
+ NUM_ZA_STATE
+};
+
+/// A bitmask enum to record live physical registers that the "emit*" routines
+/// may need to preserve. Note: This only tracks registers we may clobber.
+enum LiveRegs : uint8_t {
+ None = 0,
+ NZCV = 1 << 0,
+ W0 = 1 << 1,
+ W0_HI = 1 << 2,
+ X0 = W0 | W0_HI,
+ LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
+};
+
+/// Holds the virtual registers live physical registers have been saved to.
+struct PhysRegSave {
+ LiveRegs PhysLiveRegs;
+ Register StatusFlags = AArch64::NoRegister;
+ Register X0Save = AArch64::NoRegister;
+};
+
+static bool isLegalEdgeBundleZAState(ZAState State) {
+ switch (State) {
+ case ZAState::ACTIVE:
+ case ZAState::LOCAL_SAVED:
+ return true;
+ default:
+ return false;
+ }
+}
+struct TPIDR2State {
+ int FrameIndex = -1;
+};
+
+StringRef getZAStateString(ZAState State) {
+#define MAKE_CASE(V) \
+ case V: \
+ return #V;
+ switch (State) {
+ MAKE_CASE(ZAState::ANY)
+ MAKE_CASE(ZAState::ACTIVE)
+ MAKE_CASE(ZAState::LOCAL_SAVED)
+ MAKE_CASE(ZAState::CALLER_DORMANT)
+ MAKE_CASE(ZAState::OFF)
+ default:
+ llvm_unreachable("Unexpected ZAState");
+ }
+#undef MAKE_CASE
+}
+
+static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
+ const MachineOperand &MO) {
+ if (!MO.isReg() || !MO.getReg().isPhysical())
+ return false;
+ return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) {
+ return AArch64::MPR128RegClass.contains(SR) ||
+ AArch64::ZTRRegClass.contains(SR);
+ });
+}
+
+/// Returns the required ZA state needed before \p MI and an iterator pointing
+/// to where any code required to change the ZA state should be inserted.
+static std::pair<ZAState, MachineBasicBlock::iterator>
+getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
+ bool ZAOffAtReturn) {
+ MachineBasicBlock::iterator InsertPt(MI);
+
+ if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
+ return {ZAState::ACTIVE, std::prev(InsertPt)};
+
+ if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
+ return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
+
+ if (MI.isReturn())
+ return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
+
+ for (auto &MO : MI.operands()) {
+ if (isZAorZTRegOp(TRI, MO))
+ return {ZAState::ACTIVE, InsertPt};
+ }
+
+ return {ZAState::ANY, InsertPt};
+}
+
+struct MachineSMEABI : public MachineFunctionPass {
+ inline static char ID = 0;
+
+ MachineSMEABI() : MachineFunctionPass(ID) {}
+
+ bool runOnMachineFunction(MachineFunction &MF) override;
+
+ StringRef getPassName() const override { return "Machine SME ABI pass"; }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ AU.addRequired<EdgeBundlesWrapperLegacy>();
+ AU.addPreservedID(MachineLoopInfoID);
+ AU.addPreservedID(MachineDominatorsID);
+ MachineFunctionPass::getAnalysisUsage(AU);
+ }
+
+ /// Collects the needed ZA state (and live registers) before each instruction
+ /// within the machine function.
+ void collectNeededZAStates(SMEAttrs);
+
+ /// Assigns each edge bundle a ZA state based on the needed states of blocks
+ /// that have incoming or outgoing edges in that bundle.
+ void assignBundleZAStates();
+
+ /// Inserts code to handle changes between ZA states within the function.
+ /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
+ void insertStateChanges();
+
+ // Emission routines for private and shared ZA functions (using lazy saves).
+ void emitNewZAPrologue(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI);
+ void emitRestoreLazySave(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI,
+ LiveRegs PhysLiveRegs);
+ void emitSetupLazySave(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI);
+ void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI);
+ void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+ bool ClearTPIDR2);
+
+ void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+ ZAState From, ZAState To, LiveRegs PhysLiveRegs);
+
+ /// Save live physical registers to virtual registers.
+ PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, DebugLoc DL);
+ /// Restore physical registers from a save of their previous values.
+ void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, DebugLoc DL);
+
+ /// Get or create a TPIDR2 block in this function.
+ TPIDR2State getTPIDR2Block();
+
+private:
+ /// Contains the needed ZA state (and live registers) at an instruction.
+ struct InstInfo {
+ ZAState NeededState{ZAState::ANY};
+ MachineBasicBlock::iterator InsertPt;
+ LiveRegs PhysLiveRegs = LiveRegs::None;
+ };
+
+ /// Contains the needed ZA state for each instruction in a block.
+ /// Instructions that do not require a ZA state are not recorded.
+ struct BlockInfo {
+ ZAState FixedEntryState{ZAState::ANY};
+ SmallVector<InstInfo> Insts;
+ LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
+ };
+
+ // All pass state that must be cleared between functions.
+ struct PassState {
+ SmallVector<BlockInfo> Blocks;
+ SmallVector<ZAState> BundleStates;
+ std::optional<TPIDR2State> TPIDR2Block;
+ } State;
+
+ MachineFunction *MF = nullptr;
+ EdgeBundles *Bundles = nullptr;
+ const AArch64Subtarget *Subtarget = nullptr;
+ const AArch64RegisterInfo *TRI = nullptr;
+ const TargetInstrInfo *TII = nullptr;
+ MachineRegisterInfo *MRI = nullptr;
+};
+
+void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
+ assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
+ "Expected function to have ZA/ZT0 state!");
+
+ State.Blocks.resize(MF->getNumBlockIDs());
+ for (MachineBasicBlock &MBB : *MF) {
+ BlockInfo &Block = State.Blocks[MBB.getNumber()];
+ if (MBB.isEntryBlock()) {
+ // Entry block:
+ Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
+ ? ZAState::CALLER_DORMANT
+ : ZAState::ACTIVE;
+ } else if (MBB.isEHPad()) {
+ // EH entry block:
+ Block.FixedEntryState = ZAState::LOCAL_SAVED;
+ }
+
+ LiveRegUnits LiveUnits(*TRI);
+ LiveUnits.addLiveOuts(MBB);
+
+ auto GetPhysLiveRegs = [&] {
+ LiveRegs PhysLiveRegs = LiveRegs::None;
+ if (!LiveUnits.available(AArch64::NZCV))
+ PhysLiveRegs |= LiveRegs::NZCV;
+      // We have to track W0 and X0 separately, as otherwise we may attempt to
+      // preserve the full X0 when only W0 was defined.
+ if (!LiveUnits.available(AArch64::W0))
+ PhysLiveRegs |= LiveRegs::W0;
+ if (!LiveUnits.available(AArch64::W0_HI))
+ PhysLiveRegs |= LiveRegs::W0_HI;
+ return PhysLiveRegs;
+ };
+
+ Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
+ auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
+ for (MachineInstr &MI : reverse(MBB)) {
+ MachineBasicBlock::iterator MBBI(MI);
+ LiveUnits.stepBackward(MI);
+ LiveRegs PhysLiveRegs = GetPhysLiveRegs();
+ auto [NeededState, InsertPt] = getZAStateBeforeInst(
+ *TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
+ assert((InsertPt == MBBI ||
+ InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
+ "Unexpected state change insertion point!");
+ // TODO: Do something to avoid state changes where NZCV is live.
+ if (MBBI == FirstTerminatorInsertPt)
+ Block.PhysLiveRegsAtExit = PhysLiveRegs;
+ if (NeededState != ZAState::ANY)
+ Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
+ }
+
+ // Reverse vector (as we had to iterate backwards for liveness).
+ std::reverse(Block.Insts.begin(), Block.Insts.end());
+ }
+}
+
+void MachineSMEABI::assignBundleZAStates() {
+ State.BundleStates.resize(Bundles->getNumBundles());
+ for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
+ LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');
+
+ // Attempt to assign a ZA state for this bundle that minimizes state
+ // transitions. Edges within loops are given a higher weight as we assume
+ // they will be executed more than once.
+ // TODO: We should propagate desired incoming/outgoing states through blocks
+ // that have the "ANY" state first to make better global decisions.
+ int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
+ for (unsigned BlockID : Bundles->getBlocks(I)) {
+ LLVM_DEBUG(dbgs() << "- bb." << BlockID);
+
+ const BlockInfo &Block = State.Blocks[BlockID];
+ if (Block.Insts.empty()) {
+ LLVM_DEBUG(dbgs() << " (no state preference)\n");
+ continue;
+ }
+ bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
+ bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
+
+ ZAState DesiredIncomingState = Block.Insts.front().NeededState;
+ if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
+ EdgeStateCounts[DesiredIncomingState]++;
+ LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
+ << getZAStateString(DesiredIncomingState));
+ }
+ ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
+ if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
+ EdgeStateCounts[DesiredOutgoingState]++;
+ LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
+ << getZAStateString(DesiredOutgoingState));
+ }
+ LLVM_DEBUG(dbgs() << '\n');
+ }
+
+ ZAState BundleState =
+ ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
+
+ // Force ZA to be active in bundles that don't have a preferred state.
+ // TODO: Something better here (to avoid extra mode switches).
+ if (BundleState == ZAState::ANY)
+ BundleState = ZAState::ACTIVE;
+
+ LLVM_DEBUG({
+ dbgs() << "Chosen ZA state: " << getZAStateString(BundleState) << '\n'
+ << "Edge counts:";
+ for (auto [State, Count] : enumerate(EdgeStateCounts))
+ dbgs() << " " << getZAStateString(ZAState(State)) << ": " << Count;
+ dbgs() << "\n\n";
+ });
+
+ State.BundleStates[I] = BundleState;
+ }
+}
+
+void MachineSMEABI::insertStateChanges() {
+ for (MachineBasicBlock &MBB : *MF) {
+ const BlockInfo &Block = State.Blocks[MBB.getNumber()];
+ ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(),
+ /*Out=*/false)];
+
+ ZAState CurrentState = Block.FixedEntryState;
+ if (CurrentState == ZAState::ANY)
+ CurrentState = InState;
+
+ for (auto &Inst : Block.Insts) {
+ if (CurrentState != Inst.NeededState)
+ emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
+ Inst.PhysLiveRegs);
+ CurrentState = Inst.NeededState;
+ }
+
+ if (MBB.succ_empty())
+ continue;
+
+ ZAState OutState =
+ State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
+ if (CurrentState != OutState)
+ emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
+ Block.PhysLiveRegsAtExit);
+ }
+}
+
+TPIDR2State MachineSMEABI::getTPIDR2Block() {
+ if (State.TPIDR2Block)
+ return *State.TPIDR2Block;
+ MachineFrameInfo &MFI = MF->getFrameInfo();
+ State.TPIDR2Block = TPIDR2State{MFI.CreateStackObject(16, Align(16), false)};
+ return *State.TPIDR2Block;
+}
+
+static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI) {
+ if (MBBI != MBB.end())
+ return MBBI->getDebugLoc();
+ return DebugLoc();
+}
+
+void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI) {
+ DebugLoc DL = getDebugLoc(MBB, MBBI);
+
+ // Get pointer to TPIDR2 block.
+ Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
+ Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
+ .addFrameIndex(getTPIDR2Block().FrameIndex)
+ .addImm(0)
+ .addImm(0);
+ BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
+ .addReg(TPIDR2);
+ // Set TPIDR2_EL0 to point to TPIDR2 block.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
+ .addImm(AArch64SysReg::TPIDR2_EL0)
+ .addReg(TPIDR2Ptr);
+}
+
+PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
+ MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI,
+ DebugLoc DL) {
+ PhysRegSave RegSave{PhysLiveRegs};
+ if (PhysLiveRegs & LiveRegs::NZCV) {
+ RegSave.StatusFlags = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), RegSave.StatusFlags)
+ .addImm(AArch64SysReg::NZCV)
+ .addReg(AArch64::NZCV, RegState::Implicit);
+ }
+ // Note: Preserving X0 is "free" as this is before register allocation, so
+ // the register allocator is still able to optimize these copies.
+ if (PhysLiveRegs & LiveRegs::W0) {
+ RegSave.X0Save = MRI->createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI
+ ? &AArch64::GPR64RegClass
+ : &AArch64::GPR32RegClass);
+ BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), RegSave.X0Save)
+ .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
+ }
+ return RegSave;
+}
+
+void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave,
+ MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI,
+ DebugLoc DL) {
+ if (RegSave.StatusFlags != AArch64::NoRegister)
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
+ .addImm(AArch64SysReg::NZCV)
+ .addReg(RegSave.StatusFlags)
+ .addReg(AArch64::NZCV, RegState::ImplicitDefine);
+
+ if (RegSave.X0Save != AArch64::NoRegister)
+ BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY),
+ RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
+ .addReg(RegSave.X0Save);
+}
+
+void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI,
+ LiveRegs PhysLiveRegs) {
+ auto *TLI = Subtarget->getTargetLowering();
+ DebugLoc DL = getDebugLoc(MBB, MBBI);
+ Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+ Register TPIDR2 = AArch64::X0;
+
+ // TODO: Emit these within the restore MBB to prevent unnecessary saves.
+ PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
+
+ // Enable ZA.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
+ .addImm(AArch64SVCR::SVCRZA)
+ .addImm(1);
+ // Get current TPIDR2_EL0.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), TPIDR2EL0)
+ .addImm(AArch64SysReg::TPIDR2_EL0);
+ // Get pointer to TPIDR2 block.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
+ .addFrameIndex(getTPIDR2Block().FrameIndex)
+ .addImm(0)
+ .addImm(0);
+ // (Conditionally) restore ZA state.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::RestoreZAPseudo))
+ .addReg(TPIDR2EL0)
+ .addReg(TPIDR2)
+ .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_RESTORE))
+ .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+ // Zero TPIDR2_EL0.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
+ .addImm(AArch64SysReg::TPIDR2_EL0)
+ .addReg(AArch64::XZR);
+
+ restorePhyRegSave(RegSave, MBB, MBBI, DL);
+}
+
+void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI,
+ bool ClearTPIDR2) {
+ DebugLoc DL = getDebugLoc(MBB, MBBI);
+
+ if (ClearTPIDR2)
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
+ .addImm(AArch64SysReg::TPIDR2_EL0)
+ .addReg(AArch64::XZR);
+
+ // Disable ZA.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
+ .addImm(AArch64SVCR::SVCRZA)
+ .addImm(0);
+}
+
+void MachineSMEABI::emitAllocateLazySaveBuffer(
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
+ MachineFrameInfo &MFI = MF->getFrameInfo();
+
+ DebugLoc DL = getDebugLoc(MBB, MBBI);
+ Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+ Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+ Register Buffer = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+
+ // Calculate SVL.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::RDSVLI_XI), SVL).addImm(1);
+
+ // 1. Allocate the lazy save buffer.
+ {
+    // TODO: This function grows the stack with a subtraction, which doesn't
+    // work on Windows. Some refactoring to share the functionality in
+    // LowerWindowsDYNAMIC_STACKALLOC will be required once the Windows ABI
+    // supports SME.
+ assert(!Subtarget->isTargetWindows() &&
+ "Lazy ZA save is not yet supported on Windows");
+ // Get original stack pointer.
+ BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), SP)
+ .addReg(AArch64::SP);
+    // Allocate a lazy-save buffer object of the required size
+    // (normally SVL * SVL bytes).
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSUBXrrr), Buffer)
+ .addReg(SVL)
+ .addReg(SVL)
+ .addReg(SP);
+ BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), AArch64::SP)
+ .addReg(Buffer);
+ // We have just allocated a variable sized object, tell this to PEI.
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
+ }
+
+ // 2. Setup the TPIDR2 block.
+ {
+    // Note: On big-endian targets this would just need to store `SVL << 48`;
+    // it is not implemented as we generally don't support big-endian SVE/SME.
+ if (!Subtarget->isLittleEndian())
+ reportFatalInternalError(
+ "TPIDR2 block initialization is not supported on big-endian targets");
+
+ // Store buffer pointer and num_za_save_slices.
+ // Bytes 10-15 are implicitly zeroed.
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
+ .addReg(Buffer)
+ .addReg(SVL)
+ .addFrameIndex(getTPIDR2Block().FrameIndex)
+ .addImm(0);
+ }
+}
+
+void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI) {
+ auto *TLI = Subtarget->getTargetLowering();
+ DebugLoc DL = getDebugLoc(MBB, MBBI);
+
+ // Get current TPIDR2_EL0.
+ Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
+ .addReg(TPIDR2EL0, RegState::Define)
+ .addImm(AArch64SysReg::TPIDR2_EL0);
+ // If TPIDR2_EL0 is non-zero, commit the lazy save.
+ // NOTE: Functions that only use ZT0 don't need to zero ZA.
+ bool ZeroZA =
+ MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs().hasZAState();
+ auto CommitZASave =
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
+ .addReg(TPIDR2EL0)
+ .addImm(ZeroZA ? 1 : 0)
+ .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
+ .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+ if (ZeroZA)
+ CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+ // Enable ZA (as ZA could have previously been in the OFF state).
+ BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
+ .addImm(AArch64SVCR::SVCRZA)
+ .addImm(1);
+}
+
+void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator InsertPt,
+ ZAState From, ZAState To,
+ LiveRegs PhysLiveRegs) {
+
+ // ZA not used.
+ if (From == ZAState::ANY || To == ZAState::ANY)
+ return;
+
+ // If we're exiting from the CALLER_DORMANT state that means this new ZA
+ // function did not touch ZA (so ZA was never turned on).
+ if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF)
+ return;
+
+ // TODO: Avoid setting up the save buffer if there's no transition to
+ // LOCAL_SAVED.
+ if (From == ZAState::CALLER_DORMANT) {
+ assert(MBB.getParent()
+ ->getInfo<AArch64FunctionInfo>()
+ ->getSMEFnAttrs()
+ .hasPrivateZAInterface() &&
+ "CALLER_DORMANT state requires private ZA interface");
+ assert(&MBB == &MBB.getParent()->front() &&
+ "CALLER_DORMANT state only valid in entry block");
+ emitNewZAPrologue(MBB, MBB.getFirstNonPHI());
+ if (To == ZAState::ACTIVE)
+ return; // Nothing more to do (ZA is active after the prologue).
+
+ // Note: "emitNewZAPrologue" zeros ZA, so we may need to setup a lazy save
+ // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this
+ // case by changing the placement of the zero instruction.
+ From = ZAState::ACTIVE;
+ }
+
+ if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
+ emitSetupLazySave(MBB, InsertPt);
+ else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
+ emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
+ else if (To == ZAState::OFF) {
+ assert(From != ZAState::CALLER_DORMANT &&
+ "CALLER_DORMANT to OFF should have already been handled");
+ emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
+ } else {
+ dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
+ << getZAStateString(To) << '\n';
+ llvm_unreachable("Unimplemented state transition");
+ }
+}
+
+} // end anonymous namespace
+
+INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
+ false, false)
+
+bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
+ if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
+ return false;
+
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
+ SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
+ if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
+ return false;
+
+ assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
+
+ // Reset pass state.
+ State = PassState{};
+ this->MF = &MF;
+ Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
+ Subtarget = &MF.getSubtarget<AArch64Subtarget>();
+ TII = Subtarget->getInstrInfo();
+ TRI = Subtarget->getRegisterInfo();
+ MRI = &MF.getRegInfo();
+
+ collectNeededZAStates(SMEFnAttrs);
+ assignBundleZAStates();
+ insertStateChanges();
+
+ // Allocate save buffer (if needed).
+ if (State.TPIDR2Block) {
+ MachineBasicBlock &EntryBlock = MF.front();
+ emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
+ }
+
+ return true;
+}
+
+FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); }
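The file header above explains how each block's incoming and outgoing edge bundles are joined across CFG edges so that every block is entered in a single ZA state. A minimal standalone sketch of that joining (a simplified union-find; the real pass consumes the EdgeBundles analysis rather than implementing this itself), reproducing the BB0/BB1/BB2 example from the comment:

#include <cstdio>
#include <numeric>
#include <vector>

struct Bundles {
  // Each block gets an "in" node (2*BB) and an "out" node (2*BB + 1); every
  // CFG edge unions the predecessor's out node with the successor's in node,
  // so all nodes in one bundle can share a single ZA state.
  std::vector<unsigned> Parent;
  Bundles(unsigned NumBlocks) : Parent(2 * NumBlocks) {
    std::iota(Parent.begin(), Parent.end(), 0u);
  }
  unsigned find(unsigned N) {
    while (Parent[N] != N)
      N = Parent[N] = Parent[Parent[N]];
    return N;
  }
  void addEdge(unsigned Pred, unsigned Succ) {
    Parent[find(2 * Pred + 1)] = find(2 * Succ);
  }
  unsigned getBundle(unsigned BB, bool Out) { return find(2 * BB + Out); }
};

int main() {
  // CFG from the comment: BB0 -> BB1, BB0 -> BB2, BB1 -> BB2.
  Bundles B(3);
  B.addEdge(0, 1);
  B.addEdge(0, 2);
  B.addEdge(1, 2);
  // BB0's outgoing bundle, both of BB1's bundles, and BB2's incoming bundle
  // all coincide, matching the "Joined Bundles" diagram.
  std::printf("%u %u %u %u\n", B.getBundle(0, true), B.getBundle(1, false),
              B.getBundle(1, true), B.getBundle(2, false));
}

The lazy-save setup in emitAllocateLazySaveBuffer fills in a 16-byte TPIDR2 block; a rough sketch of its layout follows (field names are descriptive, not taken from the sources). The single STPXi above writes the buffer pointer and, via a 64-bit store of SVL, the slice count plus the implicitly zeroed bytes 10-15.

#include <cstdint>

struct TPIDR2Block {
  uint64_t ZASaveBuffer;    // bytes 0-7: pointer to the lazy-save buffer
  uint16_t NumZASaveSlices; // bytes 8-9: number of ZA slices to save
  uint8_t Reserved[6];      // bytes 10-15: must be zero
};
static_assert(sizeof(TPIDR2Block) == 16, "TPIDR2 block is 16 bytes");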
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 4af4d49..2008516 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -15,11 +15,16 @@
#include "AArch64.h"
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/RuntimeLibcalls.h"
+#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Cloning.h"
using namespace llvm;
@@ -33,9 +38,13 @@ struct SMEABI : public FunctionPass {
bool runOnFunction(Function &F) override;
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetPassConfig>();
+ }
+
private:
bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder,
- SMEAttrs FnAttrs);
+ SMEAttrs FnAttrs, const TargetLowering &TLI);
};
} // end anonymous namespace
@@ -51,14 +60,16 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
//===----------------------------------------------------------------------===//
// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
-void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
+void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, const TargetLowering &TLI,
+ bool ZT0IsUndef = false) {
auto &Ctx = M->getContext();
auto *TPIDR2SaveTy =
FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
auto Attrs =
AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
+ RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_SAVE;
FunctionCallee Callee =
- M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
+ M->getOrInsertFunction(TLI.getLibcallName(LC), TPIDR2SaveTy, Attrs);
CallInst *Call = Builder.CreateCall(Callee);
// If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
@@ -67,8 +78,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
if (ZT0IsUndef)
Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
- Call->setCallingConv(
- CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
+ Call->setCallingConv(TLI.getLibcallCallingConv(LC));
// A save to TPIDR2 should be followed by clearing TPIDR2_EL0.
Function *WriteIntr =
@@ -98,7 +108,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
/// interface if it does not share ZA or ZT0.
///
bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
- IRBuilder<> &Builder, SMEAttrs FnAttrs) {
+ IRBuilder<> &Builder, SMEAttrs FnAttrs,
+ const TargetLowering &TLI) {
LLVMContext &Context = F->getContext();
BasicBlock *OrigBB = &F->getEntryBlock();
Builder.SetInsertPoint(&OrigBB->front());
@@ -124,7 +135,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
- emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
+ emitTPIDR2Save(M, Builder, TLI, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
@@ -172,10 +183,14 @@ bool SMEABI::runOnFunction(Function &F) {
if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za"))
return false;
+ const TargetMachine &TM =
+ getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+ const TargetLowering &TLI = *TM.getSubtargetImpl(F)->getTargetLowering();
+
bool Changed = false;
SMEAttrs FnAttrs(F);
if (FnAttrs.isNewZA() || FnAttrs.isNewZT0())
- Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);
+ Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs, TLI);
return Changed;
}
diff --git a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
index bd28716..85cca1d 100644
--- a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
@@ -80,16 +80,10 @@ static bool isMatchingStartStopPair(const MachineInstr *MI1,
if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
return false;
- // This optimisation is unlikely to happen in practice for conditional
- // smstart/smstop pairs as the virtual registers for pstate.sm will always
- // be different.
- // TODO: For this optimisation to apply to conditional smstart/smstop,
- // this pass will need to do more work to remove redundant calls to
- // __arm_sme_state.
-
// Only consider conditional start/stop pairs which read the same register
- // holding the original value of pstate.sm, as some conditional start/stops
- // require the state on entry to the function.
+  // holding the original value of pstate.sm. This is somewhat overly
+  // conservative, as all conditional streaming mode changes only look at the
+  // state on entry to the function.
if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
Register Reg1 = MI1->getOperand(3).getReg();
Register Reg2 = MI2->getOperand(3).getReg();
@@ -134,13 +128,6 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(
bool Changed = false;
MachineInstr *Prev = nullptr;
- SmallVector<MachineInstr *, 4> ToBeRemoved;
-
- // Convenience function to reset the matching of a sequence.
- auto Reset = [&]() {
- Prev = nullptr;
- ToBeRemoved.clear();
- };
// Walk through instructions in the block trying to find pairs of smstart
// and smstop nodes that cancel each other out. We only permit a limited
@@ -162,14 +149,10 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(
// that we marked for deletion in between.
Prev->eraseFromParent();
MI.eraseFromParent();
- for (MachineInstr *TBR : ToBeRemoved)
- TBR->eraseFromParent();
- ToBeRemoved.clear();
Prev = nullptr;
Changed = true;
NumSMChangesRemoved += 2;
} else {
- Reset();
Prev = &MI;
}
continue;
@@ -185,7 +168,7 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(
// of streaming mode. If not, the algorithm should reset.
switch (MI.getOpcode()) {
default:
- Reset();
+ Prev = nullptr;
break;
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
@@ -199,7 +182,7 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(
// concrete example/test-case.
if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
isSVERegOp(TRI, MRI, MI.getOperand(1)))
- Reset();
+ Prev = nullptr;
break;
case AArch64::ADJCALLSTACKDOWN:
case AArch64::ADJCALLSTACKUP:
@@ -207,12 +190,6 @@ bool SMEPeepholeOpt::optimizeStartStopPairs(
case AArch64::ADDXri:
// We permit these as they don't generate SVE/NEON instructions.
break;
- case AArch64::VGRestorePseudo:
- case AArch64::VGSavePseudo:
- // When the smstart/smstop are removed, we should also remove
- // the pseudos that save/restore the VG value for CFI info.
- ToBeRemoved.push_back(&MI);
- break;
case AArch64::MSRpstatesvcrImm1:
case AArch64::MSRpstatePseudo:
llvm_unreachable("Should have been handled");
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index a0320f9..74e4a7f 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -315,10 +315,16 @@ def addsub_imm8_opt_lsl_i16 : imm8_opt_lsl<16, "uint16_t", SVEAddSubImmOperand16
def addsub_imm8_opt_lsl_i32 : imm8_opt_lsl<32, "uint32_t", SVEAddSubImmOperand32>;
def addsub_imm8_opt_lsl_i64 : imm8_opt_lsl<64, "uint64_t", SVEAddSubImmOperand64>;
-def SVEAddSubImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i8>", []>;
-def SVEAddSubImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i16>", []>;
-def SVEAddSubImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i32>", []>;
-def SVEAddSubImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubImm<MVT::i64>", []>;
+let Complexity = 1 in {
+def SVEAddSubImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i8, false>", []>;
+def SVEAddSubImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i16, false>", []>;
+def SVEAddSubImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i32, false>", []>;
+def SVEAddSubImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubImm<MVT::i64, false>", []>;
+
+def SVEAddSubNegImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i8, true>", []>;
+def SVEAddSubNegImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i16, true>", []>;
+def SVEAddSubNegImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubImm<MVT::i32, true>", []>;
+def SVEAddSubNegImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubImm<MVT::i64, true>", []>;
def SVEAddSubSSatNegImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i8, true>", []>;
def SVEAddSubSSatNegImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i16, true>", []>;
@@ -329,6 +335,7 @@ def SVEAddSubSSatPosImm8Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MV
def SVEAddSubSSatPosImm16Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i16, false>", []>;
def SVEAddSubSSatPosImm32Pat : ComplexPattern<i32, 2, "SelectSVEAddSubSSatImm<MVT::i32, false>", []>;
def SVEAddSubSSatPosImm64Pat : ComplexPattern<i64, 2, "SelectSVEAddSubSSatImm<MVT::i64, false>", []>;
+} // Complexity = 1
def SVECpyDupImm8Pat : ComplexPattern<i32, 2, "SelectSVECpyDupImm<MVT::i8>", []>;
def SVECpyDupImm16Pat : ComplexPattern<i32, 2, "SelectSVECpyDupImm<MVT::i16>", []>;
@@ -809,6 +816,11 @@ let hasNoSchedulingInfo = 1 in {
Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2, zprty:$Zs3), []> {
let FalseLanes = flags;
}
+
+ class UnpredRegImmPseudo<ZPRRegOp zprty, Operand immty>
+ : SVEPseudo2Instr<NAME, 0>,
+ Pseudo<(outs zprty:$Zd), (ins zprty:$Zs, immty:$imm), []> {
+ }
}
//
@@ -1885,13 +1897,14 @@ class sve_int_perm_extract_i<string asm>
let Inst{4-0} = Zdn;
let Constraints = "$Zdn = $_Zdn";
- let DestructiveInstType = DestructiveOther;
+ let DestructiveInstType = Destructive2xRegImmUnpred;
let ElementSize = ElementSizeNone;
let hasSideEffects = 0;
}
-multiclass sve_int_perm_extract_i<string asm, SDPatternOperator op> {
- def NAME : sve_int_perm_extract_i<asm>;
+multiclass sve_int_perm_extract_i<string asm, SDPatternOperator op, string Ps> {
+ def NAME : sve_int_perm_extract_i<asm>,
+ SVEPseudo2Instr<Ps, 1>;
def : SVE_3_Op_Imm_Pat<nxv16i8, op, nxv16i8, nxv16i8, i32, imm0_255,
!cast<Instruction>(NAME)>;
@@ -5148,11 +5161,14 @@ multiclass sve_int_dup_imm<string asm> {
(!cast<Instruction>(NAME # _D) ZPR64:$Zd, cpy_imm8_opt_lsl_i64:$imm), 1>;
def : InstAlias<"fmov $Zd, #0.0",
- (!cast<Instruction>(NAME # _H) ZPR16:$Zd, 0, 0), 1>;
+ (!cast<Instruction>(NAME # _H) ZPR16:$Zd,
+ (cpy_imm8_opt_lsl_i16 0, 0)), 1>;
def : InstAlias<"fmov $Zd, #0.0",
- (!cast<Instruction>(NAME # _S) ZPR32:$Zd, 0, 0), 1>;
+ (!cast<Instruction>(NAME # _S) ZPR32:$Zd,
+ (cpy_imm8_opt_lsl_i32 0, 0)), 1>;
def : InstAlias<"fmov $Zd, #0.0",
- (!cast<Instruction>(NAME # _D) ZPR64:$Zd, 0, 0), 1>;
+ (!cast<Instruction>(NAME # _D) ZPR64:$Zd,
+ (cpy_imm8_opt_lsl_i64 0, 0)), 1>;
}
class sve_int_dup_fpimm<bits<2> sz8_64, Operand fpimmtype,
@@ -5212,7 +5228,8 @@ class sve_int_arith_imm0<bits<2> sz8_64, bits<3> opc, string asm,
let hasSideEffects = 0;
}
-multiclass sve_int_arith_imm0<bits<3> opc, string asm, SDPatternOperator op> {
+multiclass sve_int_arith_imm0<bits<3> opc, string asm, SDPatternOperator op,
+ SDPatternOperator inv_op = null_frag> {
def _B : sve_int_arith_imm0<0b00, opc, asm, ZPR8, addsub_imm8_opt_lsl_i8>;
def _H : sve_int_arith_imm0<0b01, opc, asm, ZPR16, addsub_imm8_opt_lsl_i16>;
def _S : sve_int_arith_imm0<0b10, opc, asm, ZPR32, addsub_imm8_opt_lsl_i32>;
@@ -5222,6 +5239,12 @@ multiclass sve_int_arith_imm0<bits<3> opc, string asm, SDPatternOperator op> {
def : SVE_1_Op_Imm_OptLsl_Pat<nxv8i16, op, ZPR16, i32, SVEAddSubImm16Pat, !cast<Instruction>(NAME # _H)>;
def : SVE_1_Op_Imm_OptLsl_Pat<nxv4i32, op, ZPR32, i32, SVEAddSubImm32Pat, !cast<Instruction>(NAME # _S)>;
def : SVE_1_Op_Imm_OptLsl_Pat<nxv2i64, op, ZPR64, i64, SVEAddSubImm64Pat, !cast<Instruction>(NAME # _D)>;
+
+ // Extra patterns for add(x, splat(-ve)) -> sub(x, +ve). There is no i8
+ // pattern as all i8 constants can be handled by an add.
+ def : SVE_1_Op_Imm_OptLsl_Pat<nxv8i16, inv_op, ZPR16, i32, SVEAddSubNegImm16Pat, !cast<Instruction>(NAME # _H)>;
+ def : SVE_1_Op_Imm_OptLsl_Pat<nxv4i32, inv_op, ZPR32, i32, SVEAddSubNegImm32Pat, !cast<Instruction>(NAME # _S)>;
+ def : SVE_1_Op_Imm_OptLsl_Pat<nxv2i64, inv_op, ZPR64, i64, SVEAddSubNegImm64Pat, !cast<Instruction>(NAME # _D)>;
}
multiclass sve_int_arith_imm0_ssat<bits<3> opc, string asm, SDPatternOperator op,
@@ -5543,11 +5566,14 @@ multiclass sve_int_dup_imm_pred_merge<string asm, SDPatternOperator op> {
nxv2i64, nxv2i1, i64, SVECpyDupImm64Pat>;
def : InstAlias<"fmov $Zd, $Pg/m, #0.0",
- (!cast<Instruction>(NAME # _H) ZPR16:$Zd, PPRAny:$Pg, 0, 0), 0>;
+ (!cast<Instruction>(NAME # _H) ZPR16:$Zd, PPRAny:$Pg,
+ (cpy_imm8_opt_lsl_i16 0, 0)), 0>;
def : InstAlias<"fmov $Zd, $Pg/m, #0.0",
- (!cast<Instruction>(NAME # _S) ZPR32:$Zd, PPRAny:$Pg, 0, 0), 0>;
+ (!cast<Instruction>(NAME # _S) ZPR32:$Zd, PPRAny:$Pg,
+ (cpy_imm8_opt_lsl_i32 0, 0)), 0>;
def : InstAlias<"fmov $Zd, $Pg/m, #0.0",
- (!cast<Instruction>(NAME # _D) ZPR64:$Zd, PPRAny:$Pg, 0, 0), 0>;
+ (!cast<Instruction>(NAME # _D) ZPR64:$Zd, PPRAny:$Pg,
+ (cpy_imm8_opt_lsl_i64 0, 0)), 0>;
def : Pat<(vselect PPRAny:$Pg, (SVEDup0), (nxv8f16 ZPR:$Zd)),
(!cast<Instruction>(NAME # _H) $Zd, $Pg, 0, 0)>;
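The SVEAddSubNegImm patterns added above let an add of a splatted negative immediate be selected as a sub of the corresponding positive immediate (useful when the negative value cannot be encoded as an unsigned add immediate but its negation can be encoded for sub). The rewrite relies only on modular arithmetic on the lane type; a small standalone check of that identity for 16-bit lanes (the constant 5 is arbitrary, not taken from the patch):

#include <cassert>
#include <cstdint>

int main() {
  // Adding a negative constant modulo 2^16 is the same as subtracting its
  // absolute value, which is why add(x, splat(-C)) can be selected as
  // sub(x, C) on each lane.
  for (uint32_t X = 0; X <= 0xFFFF; ++X) {
    uint16_t Lane = static_cast<uint16_t>(X);
    uint16_t AddNeg = static_cast<uint16_t>(Lane + static_cast<uint16_t>(-5));
    uint16_t SubPos = static_cast<uint16_t>(Lane - 5);
    assert(AddNeg == SubPos);
  }
  return 0;
}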
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 271094f..dd6fa16 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -7,17 +7,14 @@
//===----------------------------------------------------------------------===//
#include "AArch64SMEAttributes.h"
+#include "AArch64ISelLowering.h"
#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/RuntimeLibcalls.h"
#include <cassert>
using namespace llvm;
-void SMEAttrs::set(unsigned M, bool Enable) {
- if (Enable)
- Bitmask |= M;
- else
- Bitmask &= ~M;
-
+void SMEAttrs::validate() const {
// Streaming Mode Attrs
assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
"SM_Enabled and SM_Compatible are mutually exclusive");
@@ -77,19 +74,36 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= encodeZT0State(StateValue::New);
}
-void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
+void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName,
+ const AArch64TargetLowering &TLI) {
+ RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName);
+ if (Impl == RTLIB::Unsupported)
+ return;
unsigned KnownAttrs = SMEAttrs::Normal;
- if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
- KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
- if (FuncName == "__arm_tpidr2_restore")
+ RTLIB::Libcall LC = RTLIB::RuntimeLibcallsInfo::getLibcallFromImpl(Impl);
+ switch (LC) {
+ case RTLIB::SMEABI_SME_STATE:
+ case RTLIB::SMEABI_TPIDR2_SAVE:
+ case RTLIB::SMEABI_GET_CURRENT_VG:
+ case RTLIB::SMEABI_SME_STATE_SIZE:
+ case RTLIB::SMEABI_SME_SAVE:
+ case RTLIB::SMEABI_SME_RESTORE:
+ KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
+ break;
+ case RTLIB::SMEABI_ZA_DISABLE:
+ case RTLIB::SMEABI_TPIDR2_RESTORE:
KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
SMEAttrs::SME_ABI_Routine;
- if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
- FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
+ break;
+ case RTLIB::SC_MEMCPY:
+ case RTLIB::SC_MEMMOVE:
+ case RTLIB::SC_MEMSET:
+ case RTLIB::SC_MEMCHR:
KnownAttrs |= SMEAttrs::SM_Compatible;
- if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
- FuncName == "__arm_sme_state_size")
- KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
+ break;
+ default:
+ break;
+ }
set(KnownAttrs);
}
@@ -110,11 +124,11 @@ bool SMECallAttrs::requiresSMChange() const {
return true;
}
-SMECallAttrs::SMECallAttrs(const CallBase &CB)
+SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI)
: CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal),
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
if (auto *CalledFunction = CB.getCalledFunction())
- CalledFn = SMEAttrs(*CalledFunction, SMEAttrs::InferAttrsFromName::Yes);
+ CalledFn = SMEAttrs(*CalledFunction, TLI);
// FIXME: We probably should not allow SME attributes on direct calls but
// clang duplicates streaming mode attributes at each callsite.
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index f1be0ecb..d26e3cd 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -13,6 +13,8 @@
namespace llvm {
+class AArch64TargetLowering;
+
class Function;
class CallBase;
class AttributeList;
@@ -48,19 +50,27 @@ public:
CallSiteFlags_Mask = ZT0_Undef
};
- enum class InferAttrsFromName { No, Yes };
-
SMEAttrs() = default;
SMEAttrs(unsigned Mask) { set(Mask); }
- SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No)
+ SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr)
: SMEAttrs(F.getAttributes()) {
- if (Infer == InferAttrsFromName::Yes)
- addKnownFunctionAttrs(F.getName());
+ if (TLI)
+ addKnownFunctionAttrs(F.getName(), *TLI);
}
SMEAttrs(const AttributeList &L);
- SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); };
+ SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) {
+ addKnownFunctionAttrs(FuncName, TLI);
+ };
- void set(unsigned M, bool Enable = true);
+ void set(unsigned M, bool Enable = true) {
+ if (Enable)
+ Bitmask |= M;
+ else
+ Bitmask &= ~M;
+#ifndef NDEBUG
+ validate();
+#endif
+ }
// Interfaces to query PSTATE.SM
bool hasStreamingBody() const { return Bitmask & SM_Body; }
@@ -146,7 +156,9 @@ public:
}
private:
- void addKnownFunctionAttrs(StringRef FuncName);
+ void addKnownFunctionAttrs(StringRef FuncName,
+ const AArch64TargetLowering &TLI);
+ void validate() const;
};
/// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has
@@ -163,7 +175,7 @@ public:
SMEAttrs Callsite = SMEAttrs::Normal)
: CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
- SMECallAttrs(const CallBase &CB);
+ SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI);
SMEAttrs &caller() { return CallerFn; }
SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }
@@ -194,7 +206,7 @@ public:
}
bool requiresEnablingZAAfterCall() const {
- return requiresLazySave() || requiresDisablingZABeforeCall();
+ return requiresDisablingZABeforeCall();
}
bool requiresPreservingAllZAState() const {