aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--llvm/lib/Target/AArch64/AArch64ISelLowering.cpp93
1 files changed, 88 insertions, 5 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ef3e8c8..7b49754 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3101,6 +3101,83 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
return BB;
}
+// Helper function to find the instruction that defined a virtual register.
+// If unable to find such instruction, returns nullptr.
+static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
+ Register Reg) {
+ while (Reg.isVirtual()) {
+ MachineInstr *DefMI = MRI.getVRegDef(Reg);
+ assert(DefMI && "Virtual register definition not found");
+ unsigned Opcode = DefMI->getOpcode();
+
+ if (Opcode == AArch64::COPY) {
+ Reg = DefMI->getOperand(1).getReg();
+ // Vreg is defined by copying from physreg.
+ if (Reg.isPhysical())
+ return DefMI;
+ continue;
+ }
+ if (Opcode == AArch64::SUBREG_TO_REG) {
+ Reg = DefMI->getOperand(2).getReg();
+ continue;
+ }
+
+ return DefMI;
+ }
+ return nullptr;
+}
+
+void AArch64TargetLowering::fixupPtrauthDiscriminator(
+ MachineInstr &MI, MachineBasicBlock *BB, MachineOperand &IntDiscOp,
+ MachineOperand &AddrDiscOp, const TargetRegisterClass *AddrDiscRC) const {
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+ MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ const DebugLoc &DL = MI.getDebugLoc();
+
+ Register AddrDisc = AddrDiscOp.getReg();
+ int64_t IntDisc = IntDiscOp.getImm();
+ assert(IntDisc == 0 && "Blend components are already expanded");
+
+ const MachineInstr *DiscMI = stripVRegCopies(MRI, AddrDisc);
+ if (DiscMI) {
+ switch (DiscMI->getOpcode()) {
+ case AArch64::MOVKXi:
+ // blend(addr, imm) which is lowered as "MOVK addr, #imm, #48".
+ // #imm should be an immediate and not a global symbol, for example.
+ if (DiscMI->getOperand(2).isImm() &&
+ DiscMI->getOperand(3).getImm() == 48) {
+ AddrDisc = DiscMI->getOperand(1).getReg();
+ IntDisc = DiscMI->getOperand(2).getImm();
+ }
+ break;
+ case AArch64::MOVi32imm:
+ case AArch64::MOVi64imm:
+ // Small immediate integer constant passed via VReg.
+ if (DiscMI->getOperand(1).isImm() &&
+ isUInt<16>(DiscMI->getOperand(1).getImm())) {
+ AddrDisc = AArch64::NoRegister;
+ IntDisc = DiscMI->getOperand(1).getImm();
+ }
+ break;
+ }
+ }
+
+ // For uniformity, always use NoRegister, as XZR is not necessarily contained
+ // in the requested register class.
+ if (AddrDisc == AArch64::XZR)
+ AddrDisc = AArch64::NoRegister;
+
+ // Make sure AddrDisc operand respects the register class imposed by MI.
+ if (AddrDisc && MRI.getRegClass(AddrDisc) != AddrDiscRC) {
+ Register TmpReg = MRI.createVirtualRegister(AddrDiscRC);
+ BuildMI(*BB, MI, DL, TII->get(AArch64::COPY), TmpReg).addReg(AddrDisc);
+ AddrDisc = TmpReg;
+ }
+
+ AddrDiscOp.setReg(AddrDisc);
+ IntDiscOp.setImm(IntDisc);
+}
+
MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {
@@ -3199,6 +3276,11 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true);
case AArch64::MOVT_TIZ_PSEUDO:
return EmitZTInstr(MI, BB, AArch64::MOVT_TIZ, /*Op0IsDef=*/true);
+
+ case AArch64::PAC:
+ fixupPtrauthDiscriminator(MI, BB, MI.getOperand(3), MI.getOperand(4),
+ &AArch64::GPR64noipRegClass);
+ return BB;
}
}
@@ -6814,7 +6896,8 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64));
SDValue Result = DAG.getMemIntrinsicNode(
AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other),
- {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
+ {StoreNode->getChain(), DAG.getBitcast(MVT::v2i64, Lo),
+ DAG.getBitcast(MVT::v2i64, Hi), StoreNode->getBasePtr()},
StoreNode->getMemoryVT(), StoreNode->getMemOperand());
return Result;
}
@@ -27911,16 +27994,16 @@ void AArch64TargetLowering::ReplaceNodeResults(
MemVT.getScalarSizeInBits() == 32u ||
MemVT.getScalarSizeInBits() == 64u)) {
+ EVT HalfVT = MemVT.getHalfNumVectorElementsVT(*DAG.getContext());
SDValue Result = DAG.getMemIntrinsicNode(
AArch64ISD::LDNP, SDLoc(N),
- DAG.getVTList({MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
- MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
- MVT::Other}),
+ DAG.getVTList({MVT::v2i64, MVT::v2i64, MVT::Other}),
{LoadNode->getChain(), LoadNode->getBasePtr()},
LoadNode->getMemoryVT(), LoadNode->getMemOperand());
SDValue Pair = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), MemVT,
- Result.getValue(0), Result.getValue(1));
+ DAG.getBitcast(HalfVT, Result.getValue(0)),
+ DAG.getBitcast(HalfVT, Result.getValue(1)));
Results.append({Pair, Result.getValue(2) /* Chain */});
return;
}