diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 93 |
1 files changed, 88 insertions, 5 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ef3e8c8..7b49754 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3101,6 +3101,83 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, return BB; } +// Helper function to find the instruction that defined a virtual register. +// If unable to find such instruction, returns nullptr. +static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI, + Register Reg) { + while (Reg.isVirtual()) { + MachineInstr *DefMI = MRI.getVRegDef(Reg); + assert(DefMI && "Virtual register definition not found"); + unsigned Opcode = DefMI->getOpcode(); + + if (Opcode == AArch64::COPY) { + Reg = DefMI->getOperand(1).getReg(); + // Vreg is defined by copying from physreg. + if (Reg.isPhysical()) + return DefMI; + continue; + } + if (Opcode == AArch64::SUBREG_TO_REG) { + Reg = DefMI->getOperand(2).getReg(); + continue; + } + + return DefMI; + } + return nullptr; +} + +void AArch64TargetLowering::fixupPtrauthDiscriminator( + MachineInstr &MI, MachineBasicBlock *BB, MachineOperand &IntDiscOp, + MachineOperand &AddrDiscOp, const TargetRegisterClass *AddrDiscRC) const { + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + const DebugLoc &DL = MI.getDebugLoc(); + + Register AddrDisc = AddrDiscOp.getReg(); + int64_t IntDisc = IntDiscOp.getImm(); + assert(IntDisc == 0 && "Blend components are already expanded"); + + const MachineInstr *DiscMI = stripVRegCopies(MRI, AddrDisc); + if (DiscMI) { + switch (DiscMI->getOpcode()) { + case AArch64::MOVKXi: + // blend(addr, imm) which is lowered as "MOVK addr, #imm, #48". + // #imm should be an immediate and not a global symbol, for example. + if (DiscMI->getOperand(2).isImm() && + DiscMI->getOperand(3).getImm() == 48) { + AddrDisc = DiscMI->getOperand(1).getReg(); + IntDisc = DiscMI->getOperand(2).getImm(); + } + break; + case AArch64::MOVi32imm: + case AArch64::MOVi64imm: + // Small immediate integer constant passed via VReg. + if (DiscMI->getOperand(1).isImm() && + isUInt<16>(DiscMI->getOperand(1).getImm())) { + AddrDisc = AArch64::NoRegister; + IntDisc = DiscMI->getOperand(1).getImm(); + } + break; + } + } + + // For uniformity, always use NoRegister, as XZR is not necessarily contained + // in the requested register class. + if (AddrDisc == AArch64::XZR) + AddrDisc = AArch64::NoRegister; + + // Make sure AddrDisc operand respects the register class imposed by MI. + if (AddrDisc && MRI.getRegClass(AddrDisc) != AddrDiscRC) { + Register TmpReg = MRI.createVirtualRegister(AddrDiscRC); + BuildMI(*BB, MI, DL, TII->get(AArch64::COPY), TmpReg).addReg(AddrDisc); + AddrDisc = TmpReg; + } + + AddrDiscOp.setReg(AddrDisc); + IntDiscOp.setImm(IntDisc); +} + MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( MachineInstr &MI, MachineBasicBlock *BB) const { @@ -3199,6 +3276,11 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true); case AArch64::MOVT_TIZ_PSEUDO: return EmitZTInstr(MI, BB, AArch64::MOVT_TIZ, /*Op0IsDef=*/true); + + case AArch64::PAC: + fixupPtrauthDiscriminator(MI, BB, MI.getOperand(3), MI.getOperand(4), + &AArch64::GPR64noipRegClass); + return BB; } } @@ -6814,7 +6896,8 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64)); SDValue Result = DAG.getMemIntrinsicNode( AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other), - {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()}, + {StoreNode->getChain(), DAG.getBitcast(MVT::v2i64, Lo), + DAG.getBitcast(MVT::v2i64, Hi), StoreNode->getBasePtr()}, StoreNode->getMemoryVT(), StoreNode->getMemOperand()); return Result; } @@ -27911,16 +27994,16 @@ void AArch64TargetLowering::ReplaceNodeResults( MemVT.getScalarSizeInBits() == 32u || MemVT.getScalarSizeInBits() == 64u)) { + EVT HalfVT = MemVT.getHalfNumVectorElementsVT(*DAG.getContext()); SDValue Result = DAG.getMemIntrinsicNode( AArch64ISD::LDNP, SDLoc(N), - DAG.getVTList({MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), - MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), - MVT::Other}), + DAG.getVTList({MVT::v2i64, MVT::v2i64, MVT::Other}), {LoadNode->getChain(), LoadNode->getBasePtr()}, LoadNode->getMemoryVT(), LoadNode->getMemOperand()); SDValue Pair = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), MemVT, - Result.getValue(0), Result.getValue(1)); + DAG.getBitcast(HalfVT, Result.getValue(0)), + DAG.getBitcast(HalfVT, Result.getValue(1))); Results.append({Pair, Result.getValue(2) /* Chain */}); return; } |