Diffstat (limited to 'llvm/lib/Target/AArch64')
24 files changed, 664 insertions, 288 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index c4b43e1..c52487a 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -176,6 +176,9 @@ public: std::optional<AArch64PACKey::ID> PACKey, uint64_t PACDisc, Register PACAddrDisc); + // Emit the sequence for PAC. + void emitPtrauthSign(const MachineInstr *MI); + // Emit the sequence to compute the discriminator. // // The returned register is either unmodified AddrDisc or ScratchReg. @@ -2175,6 +2178,37 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( OutStreamer->emitLabel(EndSym); } +void AArch64AsmPrinter::emitPtrauthSign(const MachineInstr *MI) { + Register Val = MI->getOperand(1).getReg(); + auto Key = (AArch64PACKey::ID)MI->getOperand(2).getImm(); + uint64_t Disc = MI->getOperand(3).getImm(); + Register AddrDisc = MI->getOperand(4).getReg(); + bool AddrDiscKilled = MI->getOperand(4).isKill(); + + // As long as at least one of Val and AddrDisc is in GPR64noip, a scratch + // register is available. + Register ScratchReg = Val == AArch64::X16 ? AArch64::X17 : AArch64::X16; + assert(ScratchReg != AddrDisc && + "Neither X16 nor X17 is available as a scratch register"); + + // Compute pac discriminator + assert(isUInt<16>(Disc)); + Register DiscReg = emitPtrauthDiscriminator( + Disc, AddrDisc, ScratchReg, /*MayUseAddrAsScratch=*/AddrDiscKilled); + bool IsZeroDisc = DiscReg == AArch64::XZR; + unsigned Opc = getPACOpcodeForKey(Key, IsZeroDisc); + + // paciza x16 ; if IsZeroDisc + // pacia x16, x17 ; if !IsZeroDisc + MCInst PACInst; + PACInst.setOpcode(Opc); + PACInst.addOperand(MCOperand::createReg(Val)); + PACInst.addOperand(MCOperand::createReg(Val)); + if (!IsZeroDisc) + PACInst.addOperand(MCOperand::createReg(DiscReg)); + EmitToStreamer(*OutStreamer, PACInst); +} + void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) { bool IsCall = MI->getOpcode() == AArch64::BLRA; unsigned BrTarget = MI->getOperand(0).getReg(); @@ -2890,6 +2924,10 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { MI->getOperand(4).getImm(), MI->getOperand(5).getReg()); return; + case AArch64::PAC: + emitPtrauthSign(MI); + return; + case AArch64::LOADauthptrstatic: LowerLOADauthptrstatic(*MI); return; diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td index ca09598..99f0af5 100644 --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -39,8 +39,8 @@ let Predicates = [HasDotProd] in { def ext_addv_to_udot_addv : GICombineRule< (defs root:$root, ext_addv_to_udot_addv_matchinfo:$matchinfo), (match (wip_match_opcode G_VECREDUCE_ADD):$root, - [{ return matchExtAddvToUdotAddv(*${root}, MRI, STI, ${matchinfo}); }]), - (apply [{ applyExtAddvToUdotAddv(*${root}, MRI, B, Observer, STI, ${matchinfo}); }]) + [{ return matchExtAddvToDotAddv(*${root}, MRI, STI, ${matchinfo}); }]), + (apply [{ applyExtAddvToDotAddv(*${root}, MRI, B, Observer, STI, ${matchinfo}); }]) >; } @@ -62,8 +62,10 @@ class push_opcode_through_ext<Instruction opcode, Instruction extOpcode> : GICom def push_sub_through_zext : push_opcode_through_ext<G_SUB, G_ZEXT>; def push_add_through_zext : push_opcode_through_ext<G_ADD, G_ZEXT>; +def push_mul_through_zext : push_opcode_through_ext<G_MUL, G_ZEXT>; def push_sub_through_sext : push_opcode_through_ext<G_SUB, G_SEXT>; def push_add_through_sext : push_opcode_through_ext<G_ADD, G_SEXT>; +def push_mul_through_sext : 
push_opcode_through_ext<G_MUL, G_SEXT>; def AArch64PreLegalizerCombiner: GICombiner< "AArch64PreLegalizerCombinerImpl", [all_combines, @@ -75,8 +77,10 @@ def AArch64PreLegalizerCombiner: GICombiner< ext_uaddv_to_uaddlv, push_sub_through_zext, push_add_through_zext, + push_mul_through_zext, push_sub_through_sext, - push_add_through_sext]> { + push_add_through_sext, + push_mul_through_sext]> { let CombineAllMethodName = "tryCombineAllImpl"; } diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp index 7de66cc..201bfe0 100644 --- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp +++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp @@ -598,6 +598,9 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( llvm_unreachable("Unsupported ElementSize"); } + // Preserve undef state until DOP's reg is defined. + unsigned DOPRegState = MI.getOperand(DOPIdx).isUndef() ? RegState::Undef : 0; + // // Create the destructive operation (if required) // @@ -616,10 +619,11 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfxZero)) .addReg(DstReg, RegState::Define) .addReg(MI.getOperand(PredIdx).getReg()) - .addReg(MI.getOperand(DOPIdx).getReg()); + .addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState); // After the movprfx, the destructive operand is same as Dst DOPIdx = 0; + DOPRegState = 0; // Create the additional LSL to zero the lanes when the DstReg is not // unique. Zeros the lanes in z0 that aren't active in p0 with sequence @@ -638,8 +642,9 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( assert(DOPRegIsUnique && "The destructive operand should be unique"); PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfx)) .addReg(DstReg, RegState::Define) - .addReg(MI.getOperand(DOPIdx).getReg()); + .addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState); DOPIdx = 0; + DOPRegState = 0; } // @@ -647,10 +652,11 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( // DOP = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode)) .addReg(DstReg, RegState::Define | getDeadRegState(DstIsDead)); + DOPRegState = DOPRegState | RegState::Kill; switch (DType) { case AArch64::DestructiveUnaryPassthru: - DOP.addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill) + DOP.addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState) .add(MI.getOperand(PredIdx)) .add(MI.getOperand(SrcIdx)); break; @@ -659,12 +665,12 @@ bool AArch64ExpandPseudo::expand_DestructiveOp( case AArch64::DestructiveBinaryComm: case AArch64::DestructiveBinaryCommWithRev: DOP.add(MI.getOperand(PredIdx)) - .addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill) - .add(MI.getOperand(SrcIdx)); + .addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState) + .add(MI.getOperand(SrcIdx)); break; case AArch64::DestructiveTernaryCommWithRev: DOP.add(MI.getOperand(PredIdx)) - .addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill) + .addReg(MI.getOperand(DOPIdx).getReg(), DOPRegState) .add(MI.getOperand(SrcIdx)) .add(MI.getOperand(Src2Idx)); break; @@ -1199,32 +1205,36 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, Register DstReg = MI.getOperand(0).getReg(); if (DstReg == MI.getOperand(3).getReg()) { // Expand to BIT - BuildMI(MBB, MBBI, MI.getDebugLoc(), - TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BITv8i8 - : AArch64::BITv16i8)) - .add(MI.getOperand(0)) - .add(MI.getOperand(3)) - .add(MI.getOperand(2)) - .add(MI.getOperand(1)); + auto I = BuildMI(MBB, MBBI, MI.getDebugLoc(), + TII->get(Opcode == AArch64::BSPv8i8 ? 
AArch64::BITv8i8 + : AArch64::BITv16i8)) + .add(MI.getOperand(0)) + .add(MI.getOperand(3)) + .add(MI.getOperand(2)) + .add(MI.getOperand(1)); + transferImpOps(MI, I, I); } else if (DstReg == MI.getOperand(2).getReg()) { // Expand to BIF - BuildMI(MBB, MBBI, MI.getDebugLoc(), - TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BIFv8i8 - : AArch64::BIFv16i8)) - .add(MI.getOperand(0)) - .add(MI.getOperand(2)) - .add(MI.getOperand(3)) - .add(MI.getOperand(1)); + auto I = BuildMI(MBB, MBBI, MI.getDebugLoc(), + TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BIFv8i8 + : AArch64::BIFv16i8)) + .add(MI.getOperand(0)) + .add(MI.getOperand(2)) + .add(MI.getOperand(3)) + .add(MI.getOperand(1)); + transferImpOps(MI, I, I); } else { // Expand to BSL, use additional move if required if (DstReg == MI.getOperand(1).getReg()) { - BuildMI(MBB, MBBI, MI.getDebugLoc(), - TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BSLv8i8 - : AArch64::BSLv16i8)) - .add(MI.getOperand(0)) - .add(MI.getOperand(1)) - .add(MI.getOperand(2)) - .add(MI.getOperand(3)); + auto I = + BuildMI(MBB, MBBI, MI.getDebugLoc(), + TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BSLv8i8 + : AArch64::BSLv16i8)) + .add(MI.getOperand(0)) + .add(MI.getOperand(1)) + .add(MI.getOperand(2)) + .add(MI.getOperand(3)); + transferImpOps(MI, I, I); } else { BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::ORRv8i8 @@ -1234,15 +1244,17 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, getRenamableRegState(MI.getOperand(0).isRenamable())) .add(MI.getOperand(1)) .add(MI.getOperand(1)); - BuildMI(MBB, MBBI, MI.getDebugLoc(), - TII->get(Opcode == AArch64::BSPv8i8 ? AArch64::BSLv8i8 - : AArch64::BSLv16i8)) - .add(MI.getOperand(0)) - .addReg(DstReg, - RegState::Kill | - getRenamableRegState(MI.getOperand(0).isRenamable())) - .add(MI.getOperand(2)) - .add(MI.getOperand(3)); + auto I2 = + BuildMI(MBB, MBBI, MI.getDebugLoc(), + TII->get(Opcode == AArch64::BSPv8i8 ? 
AArch64::BSLv8i8 + : AArch64::BSLv16i8)) + .add(MI.getOperand(0)) + .addReg(DstReg, + RegState::Kill | getRenamableRegState( + MI.getOperand(0).isRenamable())) + .add(MI.getOperand(2)) + .add(MI.getOperand(3)); + transferImpOps(MI, I2, I2); } } MI.eraseFromParent(); diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index eca7ca5..ad42f4b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -5296,7 +5296,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { } case Intrinsic::aarch64_sve_ld1_pn_x2: { if (VT == MVT::nxv16i8) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 2, 0, AArch64::LD1B_2Z_IMM_PSEUDO, AArch64::LD1B_2Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5307,7 +5307,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || VT == MVT::nxv8bf16) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 2, 1, AArch64::LD1H_2Z_IMM_PSEUDO, AArch64::LD1H_2Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5317,7 +5317,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 2, 2, AArch64::LD1W_2Z_IMM_PSEUDO, AArch64::LD1W_2Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5327,7 +5327,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 2, 3, AArch64::LD1D_2Z_IMM_PSEUDO, AArch64::LD1D_2Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5341,7 +5341,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { } case Intrinsic::aarch64_sve_ld1_pn_x4: { if (VT == MVT::nxv16i8) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 4, 0, AArch64::LD1B_4Z_IMM_PSEUDO, AArch64::LD1B_4Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5352,7 +5352,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || VT == MVT::nxv8bf16) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 4, 1, AArch64::LD1H_4Z_IMM_PSEUDO, AArch64::LD1H_4Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5362,7 +5362,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 4, 2, AArch64::LD1W_4Z_IMM_PSEUDO, AArch64::LD1W_4Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5372,7 +5372,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad( Node, 4, 3, AArch64::LD1D_4Z_IMM_PSEUDO, AArch64::LD1D_4Z_PSEUDO); else if (Subtarget->hasSVE2p1()) @@ -5386,7 +5386,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { } case Intrinsic::aarch64_sve_ldnt1_pn_x2: { if (VT == MVT::nxv16i8) { - 
if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 2, 0, AArch64::LDNT1B_2Z_IMM_PSEUDO, AArch64::LDNT1B_2Z_PSEUDO); @@ -5398,7 +5398,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || VT == MVT::nxv8bf16) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 2, 1, AArch64::LDNT1H_2Z_IMM_PSEUDO, AArch64::LDNT1H_2Z_PSEUDO); @@ -5409,7 +5409,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 2, 2, AArch64::LDNT1W_2Z_IMM_PSEUDO, AArch64::LDNT1W_2Z_PSEUDO); @@ -5420,7 +5420,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 2, 3, AArch64::LDNT1D_2Z_IMM_PSEUDO, AArch64::LDNT1D_2Z_PSEUDO); @@ -5435,7 +5435,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { } case Intrinsic::aarch64_sve_ldnt1_pn_x4: { if (VT == MVT::nxv16i8) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 4, 0, AArch64::LDNT1B_4Z_IMM_PSEUDO, AArch64::LDNT1B_4Z_PSEUDO); @@ -5447,7 +5447,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || VT == MVT::nxv8bf16) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 4, 1, AArch64::LDNT1H_4Z_IMM_PSEUDO, AArch64::LDNT1H_4Z_PSEUDO); @@ -5458,7 +5458,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 4, 2, AArch64::LDNT1W_4Z_IMM_PSEUDO, AArch64::LDNT1W_4Z_PSEUDO); @@ -5469,7 +5469,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { break; return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - if (Subtarget->hasSME2()) + if (Subtarget->hasSME2() && Subtarget->isStreaming()) SelectContiguousMultiVectorLoad(Node, 4, 3, AArch64::LDNT1D_4Z_IMM_PSEUDO, AArch64::LDNT1D_4Z_PSEUDO); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 4f13a14..4f6e3dd 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -164,6 +164,9 @@ static cl::opt<bool> UseFEATCPACodegen( /// Value type used for condition codes. static const MVT MVT_CC = MVT::i32; +/// Value type used for NZCV flags. +static constexpr MVT FlagsVT = MVT::i32; + static const MCPhysReg GPRArgRegs[] = {AArch64::X0, AArch64::X1, AArch64::X2, AArch64::X3, AArch64::X4, AArch64::X5, AArch64::X6, AArch64::X7}; @@ -3098,6 +3101,83 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, return BB; } +// Helper function to find the instruction that defined a virtual register. +// If unable to find such instruction, returns nullptr. 
+static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI, + Register Reg) { + while (Reg.isVirtual()) { + MachineInstr *DefMI = MRI.getVRegDef(Reg); + assert(DefMI && "Virtual register definition not found"); + unsigned Opcode = DefMI->getOpcode(); + + if (Opcode == AArch64::COPY) { + Reg = DefMI->getOperand(1).getReg(); + // Vreg is defined by copying from physreg. + if (Reg.isPhysical()) + return DefMI; + continue; + } + if (Opcode == AArch64::SUBREG_TO_REG) { + Reg = DefMI->getOperand(2).getReg(); + continue; + } + + return DefMI; + } + return nullptr; +} + +void AArch64TargetLowering::fixupPtrauthDiscriminator( + MachineInstr &MI, MachineBasicBlock *BB, MachineOperand &IntDiscOp, + MachineOperand &AddrDiscOp, const TargetRegisterClass *AddrDiscRC) const { + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + const DebugLoc &DL = MI.getDebugLoc(); + + Register AddrDisc = AddrDiscOp.getReg(); + int64_t IntDisc = IntDiscOp.getImm(); + assert(IntDisc == 0 && "Blend components are already expanded"); + + const MachineInstr *DiscMI = stripVRegCopies(MRI, AddrDisc); + if (DiscMI) { + switch (DiscMI->getOpcode()) { + case AArch64::MOVKXi: + // blend(addr, imm) which is lowered as "MOVK addr, #imm, #48". + // #imm should be an immediate and not a global symbol, for example. + if (DiscMI->getOperand(2).isImm() && + DiscMI->getOperand(3).getImm() == 48) { + AddrDisc = DiscMI->getOperand(1).getReg(); + IntDisc = DiscMI->getOperand(2).getImm(); + } + break; + case AArch64::MOVi32imm: + case AArch64::MOVi64imm: + // Small immediate integer constant passed via VReg. + if (DiscMI->getOperand(1).isImm() && + isUInt<16>(DiscMI->getOperand(1).getImm())) { + AddrDisc = AArch64::NoRegister; + IntDisc = DiscMI->getOperand(1).getImm(); + } + break; + } + } + + // For uniformity, always use NoRegister, as XZR is not necessarily contained + // in the requested register class. + if (AddrDisc == AArch64::XZR) + AddrDisc = AArch64::NoRegister; + + // Make sure AddrDisc operand respects the register class imposed by MI. + if (AddrDisc && MRI.getRegClass(AddrDisc) != AddrDiscRC) { + Register TmpReg = MRI.createVirtualRegister(AddrDiscRC); + BuildMI(*BB, MI, DL, TII->get(AArch64::COPY), TmpReg).addReg(AddrDisc); + AddrDisc = TmpReg; + } + + AddrDiscOp.setReg(AddrDisc); + IntDiscOp.setImm(IntDisc); +} + MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( MachineInstr &MI, MachineBasicBlock *BB) const { @@ -3196,6 +3276,11 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true); case AArch64::MOVT_TIZ_PSEUDO: return EmitZTInstr(MI, BB, AArch64::MOVT_TIZ, /*Op0IsDef=*/true); + + case AArch64::PAC: + fixupPtrauthDiscriminator(MI, BB, MI.getOperand(3), MI.getOperand(4), + &AArch64::GPR64noipRegClass); + return BB; } } @@ -3451,7 +3536,7 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL, } unsigned Opcode = IsSignaling ? 
AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP; - return DAG.getNode(Opcode, DL, {MVT::i32, MVT::Other}, {Chain, LHS, RHS}); + return DAG.getNode(Opcode, DL, {FlagsVT, MVT::Other}, {Chain, LHS, RHS}); } static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC, @@ -3465,7 +3550,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC, LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS); RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS); } - return DAG.getNode(AArch64ISD::FCMP, DL, MVT::i32, LHS, RHS); + return DAG.getNode(AArch64ISD::FCMP, DL, FlagsVT, LHS, RHS); } // The CMP instruction is just an alias for SUBS, and representing it as @@ -3490,7 +3575,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC, // (a.k.a. ANDS) except that the flags are only guaranteed to work for one // of the signed comparisons. const SDValue ANDSNode = - DAG.getNode(AArch64ISD::ANDS, DL, DAG.getVTList(VT, MVT_CC), + DAG.getNode(AArch64ISD::ANDS, DL, DAG.getVTList(VT, FlagsVT), LHS.getOperand(0), LHS.getOperand(1)); // Replace all users of (and X, Y) with newly generated (ands X, Y) DAG.ReplaceAllUsesWith(LHS, ANDSNode); @@ -3501,7 +3586,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC, } } - return DAG.getNode(Opcode, DL, DAG.getVTList(VT, MVT_CC), LHS, RHS) + return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS) .getValue(1); } @@ -3597,7 +3682,7 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS, AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC); unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC); SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32); - return DAG.getNode(Opcode, DL, MVT_CC, LHS, RHS, NZCVOp, Condition, CCOp); + return DAG.getNode(Opcode, DL, FlagsVT, LHS, RHS, NZCVOp, Condition, CCOp); } /// Returns true if @p Val is a tree of AND/OR/SETCC operations that can be @@ -4036,7 +4121,7 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) { Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul); // Check that the result fits into a 32-bit integer. - SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC); + SDVTList VTs = DAG.getVTList(MVT::i64, FlagsVT); if (IsSigned) { // cmp xreg, wreg, sxtw SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value); @@ -4059,12 +4144,12 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) { DAG.getConstant(63, DL, MVT::i64)); // It is important that LowerBits is last, otherwise the arithmetic // shift will not be folded into the compare (SUBS). - SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32); + SDVTList VTs = DAG.getVTList(MVT::i64, FlagsVT); Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits) .getValue(1); } else { SDValue UpperBits = DAG.getNode(ISD::MULHU, DL, MVT::i64, LHS, RHS); - SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32); + SDVTList VTs = DAG.getVTList(MVT::i64, FlagsVT); Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, DAG.getConstant(0, DL, MVT::i64), @@ -4075,7 +4160,7 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) { } // switch (...) if (Opc) { - SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32); + SDVTList VTs = DAG.getVTList(Op->getValueType(0), FlagsVT); // Emit the AArch64 operation with overflow check. 
Value = DAG.getNode(Opc, DL, VTs, LHS, RHS); @@ -4177,7 +4262,7 @@ static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG, bool Invert) { SDValue Op0 = Invert ? DAG.getConstant(0, DL, VT) : Value; SDValue Op1 = Invert ? Value : DAG.getConstant(1, DL, VT); SDValue Cmp = - DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::Glue), Op0, Op1); + DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT), Op0, Op1); return Cmp.getValue(1); } @@ -4220,16 +4305,15 @@ static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG, SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG, InvertCarry); SDLoc DL(Op); - SDVTList VTs = DAG.getVTList(VT0, VT1); - SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, MVT::Glue), OpLHS, + SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, FlagsVT), OpLHS, OpRHS, OpCarryIn); SDValue OutFlag = IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG) : carryFlagToValue(Sum.getValue(1), VT1, DAG, InvertCarry); - return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Sum, OutFlag); + return DAG.getMergeValues({Sum, OutFlag}, DL); } static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { @@ -4254,8 +4338,7 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { Overflow = DAG.getNode(AArch64ISD::CSEL, DL, MVT::i32, FVal, TVal, CCVal, Overflow); - SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); - return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Value, Overflow); + return DAG.getMergeValues({Value, Overflow}, DL); } // Prefetch operands are: @@ -6439,7 +6522,9 @@ bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { } } - return true; + EVT PreExtScalarVT = ExtVal->getOperand(0).getValueType().getScalarType(); + return PreExtScalarVT == MVT::i8 || PreExtScalarVT == MVT::i16 || + PreExtScalarVT == MVT::i32 || PreExtScalarVT == MVT::i64; } unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) { @@ -6811,7 +6896,8 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64)); SDValue Result = DAG.getMemIntrinsicNode( AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other), - {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()}, + {StoreNode->getChain(), DAG.getBitcast(MVT::v2i64, Lo), + DAG.getBitcast(MVT::v2i64, Hi), StoreNode->getBasePtr()}, StoreNode->getMemoryVT(), StoreNode->getMemOperand()); return Result; } @@ -7035,9 +7121,8 @@ SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const { SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Op.getOperand(0)); // Generate SUBS & CSEL. - SDValue Cmp = - DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32), - Op.getOperand(0), DAG.getConstant(0, DL, VT)); + SDValue Cmp = DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT), + Op.getOperand(0), DAG.getConstant(0, DL, VT)); return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg, DAG.getConstant(AArch64CC::PL, DL, MVT::i32), Cmp.getValue(1)); @@ -8867,6 +8952,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, bool &IsTailCall = CLI.IsTailCall; CallingConv::ID &CallConv = CLI.CallConv; bool IsVarArg = CLI.IsVarArg; + const CallBase *CB = CLI.CB; MachineFunction &MF = DAG.getMachineFunction(); MachineFunction::CallSiteInfo CSInfo; @@ -8906,6 +8992,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, *DAG.getContext()); RetCCInfo.AnalyzeCallResult(Ins, RetCC); + // Set type id for call site info. 
+ if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall()) + CSInfo = MachineFunction::CallSiteInfo(*CB); + // Check callee args/returns for SVE registers and set calling convention // accordingly. if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) { @@ -11106,7 +11196,7 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op, SDValue Carry = Op.getOperand(2); // SBCS uses a carry not a borrow so the carry flag should be inverted first. SDValue InvCarry = valueToCarryFlag(Carry, DAG, true); - SDValue Cmp = DAG.getNode(AArch64ISD::SBCS, DL, DAG.getVTList(VT, MVT::Glue), + SDValue Cmp = DAG.getNode(AArch64ISD::SBCS, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS, InvCarry); EVT OpVT = Op.getValueType(); @@ -11240,7 +11330,7 @@ static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal, SDValue AArch64TargetLowering::LowerSELECT_CC( ISD::CondCode CC, SDValue LHS, SDValue RHS, SDValue TVal, SDValue FVal, - iterator_range<SDNode::user_iterator> Users, bool HasNoNaNs, + iterator_range<SDNode::user_iterator> Users, SDNodeFlags Flags, const SDLoc &DL, SelectionDAG &DAG) const { // Handle f128 first, because it will result in a comparison of some RTLIB // call result against zero. @@ -11301,6 +11391,22 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( return DAG.getNode(ISD::AND, DL, VT, LHS, Shift); } + // Canonicalise absolute difference patterns: + // select_cc lhs, rhs, sub(lhs, rhs), sub(rhs, lhs), cc -> + // select_cc lhs, rhs, sub(lhs, rhs), neg(sub(lhs, rhs)), cc + // + // select_cc lhs, rhs, sub(rhs, lhs), sub(lhs, rhs), cc -> + // select_cc lhs, rhs, neg(sub(lhs, rhs)), sub(lhs, rhs), cc + // The second forms can be matched into subs+cneg. + if (TVal.getOpcode() == ISD::SUB && FVal.getOpcode() == ISD::SUB) { + if (TVal.getOperand(0) == LHS && TVal.getOperand(1) == RHS && + FVal.getOperand(0) == RHS && FVal.getOperand(1) == LHS) + FVal = DAG.getNegative(TVal, DL, TVal.getValueType()); + else if (TVal.getOperand(0) == RHS && TVal.getOperand(1) == LHS && + FVal.getOperand(0) == LHS && FVal.getOperand(1) == RHS) + TVal = DAG.getNegative(FVal, DL, FVal.getValueType()); + } + unsigned Opcode = AArch64ISD::CSEL; // If both the TVal and the FVal are constants, see if we can swap them in @@ -11438,7 +11544,7 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( return true; } })) { - bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs; + bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Flags.hasNoNaNs(); SDValue VectorCmp = emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, DL, DAG); if (VectorCmp) @@ -11452,7 +11558,7 @@ SDValue AArch64TargetLowering::LowerSELECT_CC( AArch64CC::CondCode CC1, CC2; changeFPCCToAArch64CC(CC, CC1, CC2); - if (DAG.getTarget().Options.UnsafeFPMath) { + if (Flags.hasNoSignedZeros()) { // Transform "a == 0.0 ? 0.0 : x" to "a == 0.0 ? a : x" and // "a != 0.0 ? x : 0.0" to "a != 0.0 ? x : a" to avoid materializing 0.0. 
ConstantFPSDNode *RHSVal = dyn_cast<ConstantFPSDNode>(RHS); @@ -11531,10 +11637,9 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op, SDValue RHS = Op.getOperand(1); SDValue TVal = Op.getOperand(2); SDValue FVal = Op.getOperand(3); - bool HasNoNans = Op->getFlags().hasNoNaNs(); + SDNodeFlags Flags = Op->getFlags(); SDLoc DL(Op); - return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNans, DL, - DAG); + return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), Flags, DL, DAG); } SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, @@ -11542,7 +11647,6 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, SDValue CCVal = Op->getOperand(0); SDValue TVal = Op->getOperand(1); SDValue FVal = Op->getOperand(2); - bool HasNoNans = Op->getFlags().hasNoNaNs(); SDLoc DL(Op); EVT Ty = Op.getValueType(); @@ -11609,8 +11713,8 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, DAG.getUNDEF(MVT::f32), FVal); } - SDValue Res = - LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNans, DL, DAG); + SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), + Op->getFlags(), DL, DAG); if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) { return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res); @@ -12207,7 +12311,9 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, SDLoc DL(Operand); EVT VT = Operand.getValueType(); - SDNodeFlags Flags = SDNodeFlags::AllowReassociation; + // Ensure nodes can be recognized by isAssociativeAndCommutative. + SDNodeFlags Flags = + SDNodeFlags::AllowReassociation | SDNodeFlags::NoSignedZeros; // Newton reciprocal square root iteration: E * 0.5 * (3 - X * E^2) // AArch64 reciprocal square root iteration instruction: 0.5 * (3 - M * N) @@ -12439,10 +12545,10 @@ SDValue AArch64TargetLowering::LowerAsmOutputForConstraint( // Get NZCV register. Only update chain when copyfrom is glued. if (Glue.getNode()) { - Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32, Glue); + Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, FlagsVT, Glue); Chain = Glue.getValue(1); } else - Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32); + Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, FlagsVT); // Extract CC code. 
SDValue CC = getSETCC(Cond, Glue, DL, DAG); @@ -16589,7 +16695,7 @@ bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const { return !(isFMAFasterThanFMulAndFAdd(*F, Ty) && isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) && (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath)); + I->getFastMathFlags().allowContract())); } // All 32-bit GPR operations implicitly zero the high-half of the corresponding @@ -17155,7 +17261,7 @@ static Function *getStructuredStoreFunction(Module *M, unsigned Factor, /// %vec0 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 0 /// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1 bool AArch64TargetLowering::lowerInterleavedLoad( - LoadInst *LI, ArrayRef<ShuffleVectorInst *> Shuffles, + Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles, ArrayRef<unsigned> Indices, unsigned Factor) const { assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && "Invalid interleave factor"); @@ -17163,6 +17269,11 @@ bool AArch64TargetLowering::lowerInterleavedLoad( assert(Shuffles.size() == Indices.size() && "Unmatched number of shufflevectors and indices"); + auto *LI = dyn_cast<LoadInst>(Load); + if (!LI) + return false; + assert(!Mask && "Unexpected mask on a load"); + const DataLayout &DL = LI->getDataLayout(); VectorType *VTy = Shuffles[0]->getType(); @@ -17336,12 +17447,17 @@ bool hasNearbyPairedStore(Iter It, Iter End, Value *Ptr, const DataLayout &DL) { /// %sub.v1 = shuffle <32 x i32> %v0, <32 x i32> v1, <32, 33, 34, 35> /// %sub.v2 = shuffle <32 x i32> %v0, <32 x i32> v1, <16, 17, 18, 19> /// call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr) -bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, +bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store, + Value *LaneMask, ShuffleVectorInst *SVI, unsigned Factor) const { assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && "Invalid interleave factor"); + auto *SI = dyn_cast<StoreInst>(Store); + if (!SI) + return false; + assert(!LaneMask && "Unexpected mask on store"); auto *VecTy = cast<FixedVectorType>(SVI->getType()); assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store"); @@ -17486,9 +17602,8 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, } bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad( - Instruction *Load, Value *Mask, - ArrayRef<Value *> DeinterleavedValues) const { - unsigned Factor = DeinterleavedValues.size(); + Instruction *Load, Value *Mask, IntrinsicInst *DI) const { + const unsigned Factor = getDeinterleaveIntrinsicFactor(DI->getIntrinsicID()); if (Factor != 2 && Factor != 4) { LLVM_DEBUG(dbgs() << "Matching ld2 and ld4 patterns failed\n"); return false; @@ -17498,9 +17613,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad( return false; assert(!Mask && "Unexpected mask on a load\n"); - Value *FirstActive = *llvm::find_if(DeinterleavedValues, - [](Value *V) { return V != nullptr; }); - VectorType *VTy = cast<VectorType>(FirstActive->getType()); + VectorType *VTy = getDeinterleavedVectorType(DI); const DataLayout &DL = LI->getModule()->getDataLayout(); bool UseScalable; @@ -17528,6 +17641,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad( Builder.CreateVectorSplat(LdTy->getElementCount(), Builder.getTrue()); Value *BaseAddr = LI->getPointerOperand(); + Value *Result = nullptr; if (NumLoads > 1) { // Create multiple legal small ldN. 
SmallVector<Value *, 4> ExtractedLdValues(Factor, PoisonValue::get(VTy)); @@ -17548,35 +17662,35 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad( } LLVM_DEBUG(dbgs() << "LdN4 res: "; LdN->dump()); } - // Replace output of deinterleave2 intrinsic by output of ldN2/ldN4 - for (unsigned J = 0; J < Factor; ++J) { - if (DeinterleavedValues[J]) - DeinterleavedValues[J]->replaceAllUsesWith(ExtractedLdValues[J]); - } + + // Merge the values from different factors. + Result = PoisonValue::get(DI->getType()); + for (unsigned J = 0; J < Factor; ++J) + Result = Builder.CreateInsertValue(Result, ExtractedLdValues[J], J); } else { - Value *Result; if (UseScalable) Result = Builder.CreateCall(LdNFunc, {Pred, BaseAddr}, "ldN"); else Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN"); - // Replace output of deinterleave2 intrinsic by output of ldN2/ldN4 - for (unsigned I = 0; I < Factor; I++) { - if (DeinterleavedValues[I]) { - Value *NewExtract = Builder.CreateExtractValue(Result, I); - DeinterleavedValues[I]->replaceAllUsesWith(NewExtract); - } - } } + + // Replace output of deinterleave2 intrinsic by output of ldN2/ldN4 + DI->replaceAllUsesWith(Result); return true; } bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore( - StoreInst *SI, ArrayRef<Value *> InterleavedValues) const { + Instruction *Store, Value *Mask, + ArrayRef<Value *> InterleavedValues) const { unsigned Factor = InterleavedValues.size(); if (Factor != 2 && Factor != 4) { LLVM_DEBUG(dbgs() << "Matching st2 and st4 patterns failed\n"); return false; } + StoreInst *SI = dyn_cast<StoreInst>(Store); + if (!SI) + return false; + assert(!Mask && "Unexpected mask on plain store"); VectorType *VTy = cast<VectorType>(InterleavedValues[0]->getType()); const DataLayout &DL = SI->getModule()->getDataLayout(); @@ -18010,11 +18124,14 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask( unsigned ShlAmt = C2->getZExtValue(); if (auto ShouldADD = *N->user_begin(); ShouldADD->getOpcode() == ISD::ADD && ShouldADD->hasOneUse()) { - if (auto ShouldLOAD = dyn_cast<LoadSDNode>(*ShouldADD->user_begin())) { - unsigned ByteVT = ShouldLOAD->getMemoryVT().getSizeInBits() / 8; - if ((1ULL << ShlAmt) == ByteVT && - isIndexedLoadLegal(ISD::PRE_INC, ShouldLOAD->getMemoryVT())) - return false; + if (auto Load = dyn_cast<LoadSDNode>(*ShouldADD->user_begin())) { + EVT MemVT = Load->getMemoryVT(); + + if (Load->getValueType(0).isScalableVector()) + return (8ULL << ShlAmt) != MemVT.getScalarSizeInBits(); + + if (isIndexedLoadLegal(ISD::PRE_INC, MemVT)) + return (8ULL << ShlAmt) != MemVT.getFixedSizeInBits(); } } } @@ -18583,7 +18700,7 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor, Created.push_back(And.getNode()); } else { SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC); - SDVTList VTs = DAG.getVTList(VT, MVT::i32); + SDVTList VTs = DAG.getVTList(VT, FlagsVT); SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0); SDValue AndPos = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne); @@ -19472,10 +19589,10 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) { // can select to CCMN to avoid the extra mov SDValue AbsOp1 = DAG.getConstant(Op1->getAPIntValue().abs(), DL, Op1->getValueType(0)); - CCmp = DAG.getNode(AArch64ISD::CCMN, DL, MVT_CC, Cmp1.getOperand(0), AbsOp1, - NZCVOp, Condition, Cmp0); + CCmp = DAG.getNode(AArch64ISD::CCMN, DL, FlagsVT, Cmp1.getOperand(0), + AbsOp1, NZCVOp, Condition, Cmp0); } else { - CCmp = DAG.getNode(AArch64ISD::CCMP, DL, MVT_CC, 
Cmp1.getOperand(0), + CCmp = DAG.getNode(AArch64ISD::CCMP, DL, FlagsVT, Cmp1.getOperand(0), Cmp1.getOperand(1), NZCVOp, Condition, Cmp0); } return DAG.getNode(AArch64ISD::CSEL, DL, VT, CSel0.getOperand(0), @@ -24016,6 +24133,60 @@ static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG, Store->getMemOperand()); } +// Combine store (fp_to_int X) to use vector semantics around the conversion +// when NEON is available. This allows us to store the in-vector result directly +// without transferring the result into a GPR in the process. +static SDValue combineStoreValueFPToInt(StoreSDNode *ST, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG, + const AArch64Subtarget *Subtarget) { + // Limit to post-legalization in order to avoid peeling truncating stores. + if (DCI.isBeforeLegalize()) + return SDValue(); + if (!Subtarget->isNeonAvailable()) + return SDValue(); + // Source operand is already a vector. + SDValue Value = ST->getValue(); + if (Value.getValueType().isVector()) + return SDValue(); + + // Look through potential assertions. + while (Value->isAssert()) + Value = Value.getOperand(0); + + if (Value.getOpcode() != ISD::FP_TO_SINT && + Value.getOpcode() != ISD::FP_TO_UINT) + return SDValue(); + if (!Value->hasOneUse()) + return SDValue(); + + SDValue FPSrc = Value.getOperand(0); + EVT SrcVT = FPSrc.getValueType(); + if (SrcVT != MVT::f32 && SrcVT != MVT::f64) + return SDValue(); + + // No support for assignments such as i64 = fp_to_sint i32 + EVT VT = Value.getSimpleValueType(); + if (VT != SrcVT.changeTypeToInteger()) + return SDValue(); + + // Create a 128-bit element vector to avoid widening. The floating point + // conversion is transformed into a single element conversion via a pattern. + unsigned NumElements = 128 / SrcVT.getFixedSizeInBits(); + EVT VecSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumElements); + EVT VecDstVT = VecSrcVT.changeTypeToInteger(); + SDLoc DL(ST); + SDValue VecFP = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecSrcVT, FPSrc); + SDValue VecConv = DAG.getNode(Value.getOpcode(), DL, VecDstVT, VecFP); + + SDValue Zero = DAG.getVectorIdxConstant(0, DL); + SDValue Extracted = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecConv, Zero); + + DCI.CombineTo(ST->getValue().getNode(), Extracted); + return SDValue(ST, 0); +} + bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) { return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) || (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) || @@ -24098,6 +24269,9 @@ static SDValue performSTORECombine(SDNode *N, const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDLoc DL(ST); + if (SDValue Res = combineStoreValueFPToInt(ST, DCI, DAG, Subtarget)) + return Res; + auto hasValidElementTypeForFPTruncStore = [](EVT VT) { EVT EltVT = VT.getVectorElementType(); return EltVT == MVT::f32 || EltVT == MVT::f64; @@ -25124,8 +25298,9 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) { if (!TReassocOp && !FReassocOp) return SDValue(); - SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode), - DAG.getVTList(VT, MVT_CC), CmpOpOther, SubsOp); + SDValue NewCmp = + DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode), + DAG.getVTList(VT, FlagsVT), CmpOpOther, SubsOp); auto Reassociate = [&](SDValue ReassocOp, unsigned OpNum) { if (!ReassocOp) @@ -26829,6 +27004,23 @@ static SDValue performSHLCombine(SDNode *N, return DAG.getNode(ISD::AND, DL, VT, NewShift, NewRHS); } +static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) { + unsigned IntrinsicID = 
N->getConstantOperandVal(1); + auto Register = + (IntrinsicID == Intrinsic::aarch64_rndr ? AArch64SysReg::RNDR + : AArch64SysReg::RNDRRS); + SDLoc DL(N); + SDValue A = DAG.getNode( + AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, FlagsVT, MVT::Other), + N->getOperand(0), DAG.getConstant(Register, DL, MVT::i32)); + SDValue B = DAG.getNode( + AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1)); + return DAG.getMergeValues( + {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); +} + SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -27144,22 +27336,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, case Intrinsic::aarch64_sve_st1_scatter_scalar_offset: return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_IMM_PRED); case Intrinsic::aarch64_rndr: - case Intrinsic::aarch64_rndrrs: { - unsigned IntrinsicID = N->getConstantOperandVal(1); - auto Register = - (IntrinsicID == Intrinsic::aarch64_rndr ? AArch64SysReg::RNDR - : AArch64SysReg::RNDRRS); - SDLoc DL(N); - SDValue A = DAG.getNode( - AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::i32, MVT::Other), - N->getOperand(0), DAG.getConstant(Register, DL, MVT::i32)); - SDValue B = DAG.getNode( - AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(0, DL, MVT::i32), - DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1)); - return DAG.getMergeValues( - {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); - } + case Intrinsic::aarch64_rndrrs: + return performRNDRCombine(N, DAG); case Intrinsic::aarch64_sme_ldr_zt: return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N), DAG.getVTList(MVT::Other), N->getOperand(0), @@ -27897,16 +28075,16 @@ void AArch64TargetLowering::ReplaceNodeResults( MemVT.getScalarSizeInBits() == 32u || MemVT.getScalarSizeInBits() == 64u)) { + EVT HalfVT = MemVT.getHalfNumVectorElementsVT(*DAG.getContext()); SDValue Result = DAG.getMemIntrinsicNode( AArch64ISD::LDNP, SDLoc(N), - DAG.getVTList({MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), - MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), - MVT::Other}), + DAG.getVTList({MVT::v2i64, MVT::v2i64, MVT::Other}), {LoadNode->getChain(), LoadNode->getBasePtr()}, LoadNode->getMemoryVT(), LoadNode->getMemOperand()); SDValue Pair = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), MemVT, - Result.getValue(0), Result.getValue(1)); + DAG.getBitcast(HalfVT, Result.getValue(0)), + DAG.getBitcast(HalfVT, Result.getValue(1))); Results.append({Pair, Result.getValue(2) /* Chain */}); return; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 6afb3c3..ea63edd8 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -182,6 +182,13 @@ public: MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI, MachineBasicBlock *BB) const; + /// Replace (0, vreg) discriminator components with the operands of blend + /// or with (immediate, NoRegister) when possible. 
+ void fixupPtrauthDiscriminator(MachineInstr &MI, MachineBasicBlock *BB, + MachineOperand &IntDiscOp, + MachineOperand &AddrDiscOp, + const TargetRegisterClass *AddrDiscRC) const; + MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *MBB) const override; @@ -211,19 +218,20 @@ public: unsigned getMaxSupportedInterleaveFactor() const override { return 4; } - bool lowerInterleavedLoad(LoadInst *LI, + bool lowerInterleavedLoad(Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles, ArrayRef<unsigned> Indices, unsigned Factor) const override; - bool lowerInterleavedStore(StoreInst *SI, ShuffleVectorInst *SVI, + bool lowerInterleavedStore(Instruction *Store, Value *Mask, + ShuffleVectorInst *SVI, unsigned Factor) const override; - bool lowerDeinterleaveIntrinsicToLoad( - Instruction *Load, Value *Mask, - ArrayRef<Value *> DeinterleaveValues) const override; + bool lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask, + IntrinsicInst *DI) const override; bool lowerInterleaveIntrinsicToStore( - StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override; + Instruction *Store, Value *Mask, + ArrayRef<Value *> InterleaveValues) const override; bool isLegalAddImmediate(int64_t) const override; bool isLegalAddScalableImmediate(int64_t) const override; @@ -654,7 +662,7 @@ private: SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS, SDValue TVal, SDValue FVal, iterator_range<SDNode::user_iterator> Users, - bool HasNoNans, const SDLoc &dl, + SDNodeFlags Flags, const SDLoc &dl, SelectionDAG &DAG) const; SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 996b0ed..59d4fd2 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -531,8 +531,9 @@ bool AArch64InstrInfo::analyzeBranchPredicate(MachineBasicBlock &MBB, MBP.LHS = LastInst->getOperand(0); MBP.RHS = MachineOperand::CreateImm(0); - MBP.Predicate = LastOpc == AArch64::CBNZX ? MachineBranchPredicate::PRED_NE - : MachineBranchPredicate::PRED_EQ; + MBP.Predicate = (LastOpc == AArch64::CBNZX || LastOpc == AArch64::CBNZW) + ? MachineBranchPredicate::PRED_NE + : MachineBranchPredicate::PRED_EQ; return false; } @@ -6573,10 +6574,8 @@ static bool isCombineInstrCandidateFP(const MachineInstr &Inst) { TargetOptions Options = Inst.getParent()->getParent()->getTarget().Options; // We can fuse FADD/FSUB with FMUL, if fusion is either allowed globally by // the target options or if FADD/FSUB has the contract fast-math flag. 
- return Options.UnsafeFPMath || - Options.AllowFPOpFusion == FPOpFusion::Fast || + return Options.AllowFPOpFusion == FPOpFusion::Fast || Inst.getFlag(MachineInstr::FmContract); - return true; } return false; } @@ -6679,9 +6678,8 @@ bool AArch64InstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst, case AArch64::FMUL_ZZZ_H: case AArch64::FMUL_ZZZ_S: case AArch64::FMUL_ZZZ_D: - return Inst.getParent()->getParent()->getTarget().Options.UnsafeFPMath || - (Inst.getFlag(MachineInstr::MIFlag::FmReassoc) && - Inst.getFlag(MachineInstr::MIFlag::FmNsz)); + return Inst.getFlag(MachineInstr::MIFlag::FmReassoc) && + Inst.getFlag(MachineInstr::MIFlag::FmNsz); // == Integer types == // -- Base instructions -- diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 6c46b18..251fd44 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -430,26 +430,27 @@ def UseWzrToVecMove : Predicate<"Subtarget->useWzrToVecMove()">; def SDTBinaryArithWithFlagsOut : SDTypeProfile<2, 2, [SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, - SDTCisInt<0>, SDTCisVT<1, i32>]>; + SDTCisInt<0>, + SDTCisVT<1, FlagsVT>]>; // SDTBinaryArithWithFlagsIn - RES1, FLAGS = op LHS, RHS, FLAGS def SDTBinaryArithWithFlagsIn : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, - SDTCisVT<3, i32>]>; + SDTCisVT<3, FlagsVT>]>; // SDTBinaryArithWithFlagsInOut - RES1, FLAGS = op LHS, RHS, FLAGS def SDTBinaryArithWithFlagsInOut : SDTypeProfile<2, 3, [SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisInt<0>, - SDTCisVT<1, i32>, - SDTCisVT<4, i32>]>; + SDTCisVT<1, FlagsVT>, + SDTCisVT<4, FlagsVT>]>; def SDT_AArch64Brcond : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisVT<1, i32>, - SDTCisVT<2, i32>]>; + SDTCisVT<2, FlagsVT>]>; def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>; def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, SDTCisVT<2, OtherVT>]>; @@ -458,22 +459,22 @@ def SDT_AArch64CSel : SDTypeProfile<1, 4, [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<3>, - SDTCisVT<4, i32>]>; + SDTCisVT<4, FlagsVT>]>; def SDT_AArch64CCMP : SDTypeProfile<1, 5, - [SDTCisVT<0, i32>, + [SDTCisVT<0, FlagsVT>, SDTCisInt<1>, SDTCisSameAs<1, 2>, SDTCisInt<3>, SDTCisInt<4>, SDTCisVT<5, i32>]>; def SDT_AArch64FCCMP : SDTypeProfile<1, 5, - [SDTCisVT<0, i32>, + [SDTCisVT<0, FlagsVT>, SDTCisFP<1>, SDTCisSameAs<1, 2>, SDTCisInt<3>, SDTCisInt<4>, SDTCisVT<5, i32>]>; -def SDT_AArch64FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, +def SDT_AArch64FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, FlagsVT>, SDTCisFP<1>, SDTCisSameAs<2, 1>]>; def SDT_AArch64Rev : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>]>; @@ -518,10 +519,10 @@ def SDT_AArch64uaddlp : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>; def SDT_AArch64ldp : SDTypeProfile<2, 1, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; def SDT_AArch64ldiapp : SDTypeProfile<2, 1, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; -def SDT_AArch64ldnp : SDTypeProfile<2, 1, [SDTCisVT<0, v4i32>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; +def SDT_AArch64ldnp : SDTypeProfile<2, 1, [SDTCisVT<0, v2i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; def SDT_AArch64stp : SDTypeProfile<0, 3, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; def SDT_AArch64stilp : SDTypeProfile<0, 3, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; -def SDT_AArch64stnp : SDTypeProfile<0, 3, [SDTCisVT<0, v4i32>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; +def 
SDT_AArch64stnp : SDTypeProfile<0, 3, [SDTCisVT<0, v2i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>; // Generates the general dynamic sequences, i.e. // adrp x0, :tlsdesc:var @@ -1053,13 +1054,6 @@ def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>; def AArch64uaddlv : SDNode<"AArch64ISD::UADDLV", SDT_AArch64uaddlp>; def AArch64saddlv : SDNode<"AArch64ISD::SADDLV", SDT_AArch64uaddlp>; -def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs), - [(abdu node:$lhs, node:$rhs), - (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>; -def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs), - [(abds node:$lhs, node:$rhs), - (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>; - // Add Pairwise of two vectors def AArch64addp_n : SDNode<"AArch64ISD::ADDP", SDT_AArch64Zip>; // Add Long Pairwise @@ -1131,10 +1125,10 @@ def AArch64probedalloca SDTypeProfile<0, 1, [SDTCisPtrTy<0>]>, [SDNPHasChain, SDNPMayStore]>; -// MRS, also sets the flags via a glue. +// MRS, also sets the flags. def AArch64mrs : SDNode<"AArch64ISD::MRS", SDTypeProfile<2, 1, [SDTCisVT<0, i64>, - SDTCisVT<1, i32>, + SDTCisVT<1, FlagsVT>, SDTCisVT<2, i32>]>, [SDNPHasChain]>; @@ -2039,7 +2033,7 @@ let Predicates = [HasPAuth] in { def DZB : SignAuthZero<prefix_z, 0b11, !strconcat(asm, "dzb"), op>; } - defm PAC : SignAuth<0b000, 0b010, "pac", int_ptrauth_sign>; + defm PAC : SignAuth<0b000, 0b010, "pac", null_frag>; defm AUT : SignAuth<0b001, 0b011, "aut", null_frag>; def XPACI : ClearAuth<0, "xpaci">; @@ -2159,6 +2153,26 @@ let Predicates = [HasPAuth] in { let Uses = []; } + // PAC pseudo instruction. In AsmPrinter, it is expanded into an actual PAC* + // instruction immediately preceded by the discriminator computation. + // This enforces the expected immediate modifier is used for signing, even + // if an attacker is able to substitute AddrDisc. + def PAC : Pseudo<(outs GPR64:$SignedVal), + (ins GPR64:$Val, i32imm:$Key, i64imm:$Disc, GPR64noip:$AddrDisc), + [], "$SignedVal = $Val">, Sched<[WriteI, ReadI]> { + let isCodeGenOnly = 1; + let hasSideEffects = 0; + let mayStore = 0; + let mayLoad = 0; + let Size = 12; + let Defs = [X16, X17]; + let usesCustomInserter = 1; + } + + // A standalone pattern is used, so that literal 0 can be passed as $Disc. + def : Pat<(int_ptrauth_sign GPR64:$Val, timm:$Key, GPR64noip:$AddrDisc), + (PAC GPR64:$Val, $Key, 0, GPR64noip:$AddrDisc)>; + // AUT and re-PAC a value, using different keys/data. // This directly manipulates x16/x17, which are the only registers that // certain OSs guarantee are safe to use for sensitive operations. 
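For illustration only (a sketch, not part of the patch; the source register x8 and the immediate 1234 are made-up example values): for a discriminator built as blend(addr, 1234) and key IA, the PAC pseudo above is expected to expand in the AsmPrinter to the discriminator computation in the scratch register followed by the keyed sign instruction, roughly

    mov   x17, x8              // move the address discriminator into the scratch register
    movk  x17, #1234, lsl #48  // insert the 16-bit immediate discriminator into bits 48..63
    pacia x16, x17             // opcode selected by getPACOpcodeForKey(Key, IsZeroDisc)

while a zero discriminator degenerates to a single paciza x16.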
@@ -3941,6 +3955,26 @@ defm LDRSW  : LoadUI<0b10, 0, 0b10, GPR64, uimm12s4, "ldrsw",
 def : Pat<(i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))),
       (SUBREG_TO_REG (i64 0), (LDRWui GPR64sp:$Rn, uimm12s4:$offset), sub_32)>;
 
+// load zero-extended i32, bitcast to f64
+def : Pat <(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))),
+           (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>;
+
+// load zero-extended i16, bitcast to f64
+def : Pat <(f64 (bitconvert (i64 (zextloadi16 (am_indexed32 GPR64sp:$Rn, uimm12s2:$offset))))),
+           (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>;
+
+// load zero-extended i8, bitcast to f64
+def : Pat <(f64 (bitconvert (i64 (zextloadi8 (am_indexed32 GPR64sp:$Rn, uimm12s1:$offset))))),
+           (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
+
+// load zero-extended i16, bitcast to f32
+def : Pat <(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))),
+           (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>;
+
+// load zero-extended i8, bitcast to f32
+def : Pat <(f32 (bitconvert (i32 (zextloadi8 (am_indexed16 GPR64sp:$Rn, uimm12s1:$offset))))),
+           (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
+
 // Pre-fetch.
 def PRFMui : PrefetchUI<0b11, 0, 0b10, "prfm",
                         [(AArch64Prefetch timm:$Rt,
@@ -5667,8 +5701,7 @@ let Predicates = [HasFullFP16] in {
 // Advanced SIMD two vector instructions.
 //===----------------------------------------------------------------------===//
 
-defm UABDL   : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl",
-                                          AArch64uabd>;
+defm UABDL   : SIMDLongThreeVectorBHSabdl<1, 0b0111, "uabdl", abdu>;
 // Match UABDL in log2-shuffle patterns.
 def : Pat<(abs (v8i16 (sub (zext (v8i8 V64:$opA)),
                            (zext (v8i8 V64:$opB))))),
@@ -6018,8 +6051,8 @@ defm MLS    : SIMDThreeSameVectorBHSTied<1, 0b10010, "mls", null_frag>;
 defm MUL    : SIMDThreeSameVectorBHS<0, 0b10011, "mul", mul>;
 defm PMUL   : SIMDThreeSameVectorB<1, 0b10011, "pmul", int_aarch64_neon_pmul>;
 defm SABA   : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba",
-      TriOpFrag<(add node:$LHS, (AArch64sabd node:$MHS, node:$RHS))> >;
-defm SABD   : SIMDThreeSameVectorBHS<0,0b01110,"sabd", AArch64sabd>;
+      TriOpFrag<(add node:$LHS, (abds node:$MHS, node:$RHS))> >;
+defm SABD   : SIMDThreeSameVectorBHS<0,0b01110,"sabd", abds>;
 defm SHADD  : SIMDThreeSameVectorBHS<0,0b00000,"shadd", avgfloors>;
 defm SHSUB  : SIMDThreeSameVectorBHS<0,0b00100,"shsub", int_aarch64_neon_shsub>;
 defm SMAXP  : SIMDThreeSameVectorBHS<0,0b10100,"smaxp", int_aarch64_neon_smaxp>;
@@ -6037,8 +6070,8 @@ defm SRSHL   : SIMDThreeSameVector<0,0b01010,"srshl", int_aarch64_neon_srshl>;
 defm SSHL   : SIMDThreeSameVector<0,0b01000,"sshl", int_aarch64_neon_sshl>;
 defm SUB    : SIMDThreeSameVector<1,0b10000,"sub", sub>;
 defm UABA   : SIMDThreeSameVectorBHSTied<1, 0b01111, "uaba",
-      TriOpFrag<(add node:$LHS, (AArch64uabd node:$MHS, node:$RHS))> >;
-defm UABD   : SIMDThreeSameVectorBHS<1,0b01110,"uabd", AArch64uabd>;
+      TriOpFrag<(add node:$LHS, (abdu node:$MHS, node:$RHS))> >;
+defm UABD   : SIMDThreeSameVectorBHS<1,0b01110,"uabd", abdu>;
 defm UHADD  : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", avgflooru>;
 defm UHSUB  : SIMDThreeSameVectorBHS<1,0b00100,"uhsub", int_aarch64_neon_uhsub>;
 defm UMAXP  : SIMDThreeSameVectorBHS<1,0b10100,"umaxp", int_aarch64_neon_umaxp>;
@@ -6635,6 +6668,15 @@ def : Pat<(f16 (any_uint_to_fp (i32 (any_fp_to_uint f16:$Rn)))),
           (UCVTFv1i16 (f16 (FCVTZUv1f16 f16:$Rn)))>;
 }
 
+def : Pat<(v4i32 (any_fp_to_sint (v4f32 (scalar_to_vector (f32 FPR32:$src))))),
+          (v4i32 (INSERT_SUBREG (IMPLICIT_DEF), (i32 (FCVTZSv1i32 (f32 FPR32:$src))), ssub))>;
+def : Pat<(v4i32 (any_fp_to_uint (v4f32 (scalar_to_vector (f32 FPR32:$src))))),
+          (v4i32 (INSERT_SUBREG (IMPLICIT_DEF), (i32 (FCVTZUv1i32 (f32 FPR32:$src))), ssub))>;
+def : Pat<(v2i64 (any_fp_to_sint (v2f64 (scalar_to_vector (f64 FPR64:$src))))),
+          (v2i64 (INSERT_SUBREG (IMPLICIT_DEF), (i64 (FCVTZSv1i64 (f64 FPR64:$src))), dsub))>;
+def : Pat<(v2i64 (any_fp_to_uint (v2f64 (scalar_to_vector (f64 FPR64:$src))))),
+          (v2i64 (INSERT_SUBREG (IMPLICIT_DEF), (i64 (FCVTZUv1i64 (f64 FPR64:$src))), dsub))>;
+
 // int -> float conversion of value in lane 0 of simd vector should use
 // correct cvtf variant to avoid costly fpr <-> gpr register transfers.
 def : Pat<(f32 (sint_to_fp (i32 (vector_extract (v4i32 FPR128:$Rn), (i64 0))))),
@@ -6759,10 +6801,8 @@ defm SUBHN  : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn>
 defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>;
 defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>;
 defm PMULL  : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>;
-defm SABAL  : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal",
-                                             AArch64sabd>;
-defm SABDL   : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl",
-                                          AArch64sabd>;
+defm SABAL  : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal", abds>;
+defm SABDL   : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl", abds>;
 defm SADDL   : SIMDLongThreeVectorBHS< 0, 0b0000, "saddl",
             BinOpFrag<(add (sext node:$LHS), (sext node:$RHS))>>;
 defm SADDW   : SIMDWideThreeVectorBHS< 0, 0b0001, "saddw",
@@ -6780,8 +6820,7 @@ defm SSUBL   : SIMDLongThreeVectorBHS<0, 0b0010, "ssubl",
                  BinOpFrag<(sub (sext node:$LHS), (sext node:$RHS))>>;
 defm SSUBW   : SIMDWideThreeVectorBHS<0, 0b0011, "ssubw",
                  BinOpFrag<(sub node:$LHS, (sext node:$RHS))>>;
-defm UABAL   : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal",
-                                              AArch64uabd>;
+defm UABAL   : SIMDLongThreeVectorTiedBHSabal<1, 0b0101, "uabal", abdu>;
 defm UADDL   : SIMDLongThreeVectorBHS<1, 0b0000, "uaddl",
                  BinOpFrag<(add (zanyext node:$LHS), (zanyext node:$RHS))>>;
 defm UADDW   : SIMDWideThreeVectorBHS<1, 0b0001, "uaddw",
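To show what the new zero-extended-load/bitconvert patterns above are for (a hand-written sketch, not output copied from the compiler): loading a u32, zero-extending it to 64 bits and reinterpreting the result as an f64 previously had to round-trip through a general-purpose register, whereas the new patterns let the load go straight into the FP/SIMD register file:

    ; before: integer load plus a cross-register-file move
    ldr  w8, [x0]
    fmov d0, x8
    ; after: a 32-bit FP load zero-extends into the full vector register
    ldr  s0, [x0]

The half-word and byte variants follow the same idea with ldr h0 / ldr b0 loads.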
diff --git a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
index 0ddd17c..b97d622 100644
--- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
@@ -8,11 +8,11 @@
 //
 // This pass performs below peephole optimizations on MIR level.
 //
-// 1. MOVi32imm + ANDWrr ==> ANDWri + ANDWri
-//    MOVi64imm + ANDXrr ==> ANDXri + ANDXri
+// 1. MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri
+//    MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri
 //
 // 2. MOVi32imm + ADDWrr ==> ADDWRi + ADDWRi
-//    MOVi64imm + ADDXrr ==> ANDXri + ANDXri
+//    MOVi64imm + ADDXrr ==> ADDXri + ADDXri
 //
 // 3. MOVi32imm + SUBWrr ==> SUBWRi + SUBWRi
 //    MOVi64imm + SUBXrr ==> SUBXri + SUBXri
@@ -125,8 +125,13 @@ struct AArch64MIPeepholeOpt : public MachineFunctionPass {
   template <typename T>
   bool visitADDSSUBS(OpcodePair PosOpcs, OpcodePair NegOpcs, MachineInstr &MI);
 
+  // Strategy used to split logical immediate bitmasks.
+  enum class SplitStrategy {
+    Intersect,
+  };
   template <typename T>
-  bool visitAND(unsigned Opc, MachineInstr &MI);
+  bool trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
+                          SplitStrategy Strategy, unsigned OtherOpc = 0);
   bool visitORR(MachineInstr &MI);
   bool visitCSEL(MachineInstr &MI);
   bool visitINSERT(MachineInstr &MI);
@@ -158,14 +163,6 @@ INITIALIZE_PASS(AArch64MIPeepholeOpt, "aarch64-mi-peephole-opt",
 template <typename T>
 static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) {
   T UImm = static_cast<T>(Imm);
-  if (AArch64_AM::isLogicalImmediate(UImm, RegSize))
-    return false;
-
-  // If this immediate can be handled by one instruction, do not split it.
-  SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
-  AArch64_IMM::expandMOVImm(UImm, RegSize, Insn);
-  if (Insn.size() == 1)
-    return false;
 
   // The bitmask immediate consists of consecutive ones. Let's say there is
   // constant 0b00000000001000000000010000000000 which does not consist of
@@ -194,12 +191,13 @@ static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) {
 }
 
 template <typename T>
-bool AArch64MIPeepholeOpt::visitAND(
-    unsigned Opc, MachineInstr &MI) {
+bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
+                                              SplitStrategy Strategy,
+                                              unsigned OtherOpc) {
   // Try below transformation.
   //
-  // MOVi32imm + ANDWrr ==> ANDWri + ANDWri
-  // MOVi64imm + ANDXrr ==> ANDXri + ANDXri
+  // MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri
+  // MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri
   //
   // The mov pseudo instruction could be expanded to multiple mov instructions
   // later. Let's try to split the constant operand of mov instruction into two
@@ -208,10 +206,27 @@ bool AArch64MIPeepholeOpt::visitAND(
 
   return splitTwoPartImm<T>(
       MI,
-      [Opc](T Imm, unsigned RegSize, T &Imm0,
-            T &Imm1) -> std::optional<OpcodePair> {
-        if (splitBitmaskImm(Imm, RegSize, Imm0, Imm1))
-          return std::make_pair(Opc, Opc);
+      [Opc, Strategy, OtherOpc](T Imm, unsigned RegSize, T &Imm0,
+                                T &Imm1) -> std::optional<OpcodePair> {
+        // If this immediate is already a suitable bitmask, don't split it.
+        // TODO: Should we just combine the two instructions in this case?
+        if (AArch64_AM::isLogicalImmediate(Imm, RegSize))
+          return std::nullopt;
+
+        // If this immediate can be handled by one instruction, don't split it.
+        SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
+        AArch64_IMM::expandMOVImm(Imm, RegSize, Insn);
+        if (Insn.size() == 1)
+          return std::nullopt;
+
+        bool SplitSucc = false;
+        switch (Strategy) {
+        case SplitStrategy::Intersect:
+          SplitSucc = splitBitmaskImm(Imm, RegSize, Imm0, Imm1);
+          break;
+        }
+        if (SplitSucc)
+          return std::make_pair(Opc, !OtherOpc ? Opc : OtherOpc);
         return std::nullopt;
       },
       [&TII = TII](MachineInstr &MI, OpcodePair Opcode, unsigned Imm0,
@@ -859,10 +874,20 @@ bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
         Changed |= visitINSERT(MI);
         break;
       case AArch64::ANDWrr:
-        Changed |= visitAND<uint32_t>(AArch64::ANDWri, MI);
+        Changed |= trySplitLogicalImm<uint32_t>(AArch64::ANDWri, MI,
+                                                SplitStrategy::Intersect);
        break;
       case AArch64::ANDXrr:
-        Changed |= visitAND<uint64_t>(AArch64::ANDXri, MI);
+        Changed |= trySplitLogicalImm<uint64_t>(AArch64::ANDXri, MI,
+                                                SplitStrategy::Intersect);
+        break;
+      case AArch64::ANDSWrr:
+        Changed |= trySplitLogicalImm<uint32_t>(
+            AArch64::ANDWri, MI, SplitStrategy::Intersect, AArch64::ANDSWri);
+        break;
+      case AArch64::ANDSXrr:
+        Changed |= trySplitLogicalImm<uint64_t>(
+            AArch64::ANDXri, MI, SplitStrategy::Intersect, AArch64::ANDSXri);
        break;
       case AArch64::ORRWrs:
         Changed |= visitORR(MI);
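A worked example of the Intersect strategy now also applied to the flag-setting forms (constants chosen to match the comment in splitBitmaskImm; the snippet itself is illustrative, not compiler output): 0x00200400 is not a valid logical immediate and needs more than one MOV instruction to materialize, but it is the intersection of two bitmasks that are valid immediates:

    ; before
    mov  w8, #0x00200400          ; expands to movz+movk
    ands w0, w0, w8
    ; after: two logical immediates whose AND equals the original constant
    and  w0, w0, #0x003ffc00      ; bits 10-21
    ands w0, w0, #0xffe007ff      ; bits 0-10 and 21-31

Keeping the flag-setting opcode for the final instruction preserves the NZCV result of the original ANDS, which is what the extra OtherOpc parameter is for.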
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 61bf87f..1a7609b 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -305,7 +305,8 @@ def GPR64pi48 : RegisterOperand<GPR64, "printPostIncOperand<48>">;
 def GPR64pi64 : RegisterOperand<GPR64, "printPostIncOperand<64>">;
 
 // Condition code regclass.
-def CCR : RegisterClass<"AArch64", [i32], 32, (add NZCV)> {
+defvar FlagsVT = i32;
+def CCR : RegisterClass<"AArch64", [FlagsVT], 32, (add NZCV)> {
   let CopyCost = -1; // Don't allow copying of status registers.
 
   // CCR is not allocatable.
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index bafb8d0..8a5b5ba 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -32,10 +32,29 @@ AArch64SelectionDAGInfo::AArch64SelectionDAGInfo()
 
 void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
                                                const SDNode *N) const {
+  SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
+
 #ifndef NDEBUG
+  // Some additional checks not yet implemented by verifyTargetNode.
+  constexpr MVT FlagsVT = MVT::i32;
   switch (N->getOpcode()) {
-  default:
-    return SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
+  case AArch64ISD::SUBS:
+    assert(N->getValueType(1) == FlagsVT);
+    break;
+  case AArch64ISD::ADC:
+  case AArch64ISD::SBC:
+    assert(N->getOperand(2).getValueType() == FlagsVT);
+    break;
+  case AArch64ISD::ADCS:
+  case AArch64ISD::SBCS:
+    assert(N->getValueType(1) == FlagsVT);
+    assert(N->getOperand(2).getValueType() == FlagsVT);
+    break;
+  case AArch64ISD::CSEL:
+  case AArch64ISD::CSINC:
+  case AArch64ISD::BRCOND:
+    assert(N->getOperand(3).getValueType() == FlagsVT);
+    break;
   case AArch64ISD::SADDWT:
   case AArch64ISD::SADDWB:
   case AArch64ISD::UADDWT:
diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
index 75c7dd9..f136a184 100644
--- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
+++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
@@ -581,7 +581,6 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
     // statement if return_twice functions are called.
     bool StandardLifetime =
        !SInfo.CallsReturnTwice &&
-        SInfo.UnrecognizedLifetimes.empty() &&
        memtag::isStandardLifetime(Info.LifetimeStart, Info.LifetimeEnd, DT, LI,
                                   ClMaxLifetimes);
     if (StandardLifetime) {
@@ -616,10 +615,5 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
       memtag::annotateDebugRecords(Info, Tag);
   }
 
-  // If we have instrumented at least one alloca, all unrecognized lifetime
-  // intrinsics have to go.
-  for (auto *I : SInfo.UnrecognizedLifetimes)
-    I->eraseFromParent();
-
   return true;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 2409cc8..0f4f012 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -534,7 +534,7 @@ unsigned AArch64Subtarget::classifyGlobalFunctionReference(
 }
 
 void AArch64Subtarget::overrideSchedPolicy(MachineSchedPolicy &Policy,
-                                           unsigned NumRegionInstrs) const {
+                                           const SchedRegion &Region) const {
   // LNT run (at least on Cyclone) showed reasonably significant gains for
   // bi-directional scheduling. 253.perlbmk.
   Policy.OnlyTopDown = false;
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 154db3c..061ed61 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -343,7 +343,8 @@ public:
   }
 
   void overrideSchedPolicy(MachineSchedPolicy &Policy,
-                           unsigned NumRegionInstrs) const override;
+                           const SchedRegion &Region) const override;
+
   void adjustSchedDependency(SUnit *Def, int DefOpIdx, SUnit *Use, int UseOpIdx,
                              SDep &Dep,
                              const TargetSchedModel *SchedModel) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetObjectFile.cpp b/llvm/lib/Target/AArch64/AArch64TargetObjectFile.cpp
index c218831..85de2d5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetObjectFile.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetObjectFile.cpp
@@ -36,7 +36,7 @@ void AArch64_ELFTargetObjectFile::Initialize(MCContext &Ctx,
   // SHF_AARCH64_PURECODE flag set if the "+execute-only" target feature is
   // present.
   if (TM.getMCSubtargetInfo()->hasFeature(AArch64::FeatureExecuteOnly)) {
-    auto *Text = cast<MCSectionELF>(TextSection);
+    auto *Text = static_cast<MCSectionELF *>(TextSection);
     Text->setFlags(Text->getFlags() | ELF::SHF_AARCH64_PURECODE);
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 90d3d92..18ca22f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -249,7 +249,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
   return false;
 }
 
-uint64_t AArch64TTIImpl::getFeatureMask(const Function &F) const {
+APInt AArch64TTIImpl::getFeatureMask(const Function &F) const {
   StringRef AttributeStr =
       isMultiversionedFunction(F) ? "fmv-features" : "target-features";
   StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString();
@@ -4905,14 +4905,17 @@ void AArch64TTIImpl::getUnrollingPreferences(
   // Disable partial & runtime unrolling on -Os.
   UP.PartialOptSizeThreshold = 0;
 
-  // No need to unroll auto-vectorized loops
-  if (findStringMetadataForLoop(L, "llvm.loop.isvectorized"))
-    return;
-
   // Scan the loop: don't unroll loops with calls as this could prevent
-  // inlining.
+  // inlining. Don't unroll auto-vectorized loops either, though do allow
+  // unrolling of the scalar remainder.
+  bool IsVectorized = getBooleanLoopAttribute(L, "llvm.loop.isvectorized");
   for (auto *BB : L->getBlocks()) {
     for (auto &I : *BB) {
+      // Both auto-vectorized loops and the scalar remainder have the
+      // isvectorized attribute, so differentiate between them by the presence
+      // of vector instructions.
+      if (IsVectorized && I.getType()->isVectorTy())
+        return;
       if (isa<CallBase>(I)) {
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
           if (const Function *F = cast<CallBase>(I).getCalledFunction())
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index b27eb2e..7f45177 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -89,7 +89,7 @@ public:
   unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
                                 unsigned DefaultCallPenalty) const override;
 
-  uint64_t getFeatureMask(const Function &F) const override;
+  APInt getFeatureMask(const Function &F) const override;
 
   bool isMultiversionedFunction(const Function &F) const override;
 
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp b/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp
index 0b79850..1a15075 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp
@@ -50,8 +50,10 @@ bool AArch64GISelUtils::isCMN(const MachineInstr *MaybeSub,
   //
   // %sub = G_SUB 0, %y
   // %cmp = G_ICMP eq/ne, %z, %sub
+  // or with signed comparisons with the no-signed-wrap flag set
   if (!MaybeSub || MaybeSub->getOpcode() != TargetOpcode::G_SUB ||
-      !CmpInst::isEquality(Pred))
+      (!CmpInst::isEquality(Pred) &&
+       !(CmpInst::isSigned(Pred) && MaybeSub->getFlag(MachineInstr::NoSWrap))))
     return false;
   auto MaybeZero =
       getIConstantVRegValWithLookThrough(MaybeSub->getOperand(1).getReg(), MRI);
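The no-signed-wrap condition added above matters because the CMN fold is only flag-equivalent for signed predicates when the negation cannot overflow. A hypothetical before/after (registers invented, the G_SUB carries nsw):

    ; %sub = G_SUB 0, %y (nsw)    ; %cmp = G_ICMP slt, %z, %sub
    neg  w8, w1                   ; w8 = -y
    cmp  w0, w8
    b.lt .Ltaken
    ; folded: compare z against -y without materializing the negation
    cmn  w0, w1
    b.lt .Ltaken

If y could be INT_MIN, the negation would wrap and the two sequences could disagree, which is exactly what the nsw requirement rules out.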
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 1381a9b..d905692 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -1810,7 +1810,7 @@ bool AArch64InstructionSelector::selectCompareBranchFedByICmp(
 
   // Couldn't optimize. Emit a compare + a Bcc.
   MachineBasicBlock *DestMBB = I.getOperand(1).getMBB();
-  auto PredOp = ICmp.getOperand(1);
+  auto &PredOp = ICmp.getOperand(1);
   emitIntegerCompare(ICmp.getOperand(2), ICmp.getOperand(3), PredOp, MIB);
   const AArch64CC::CondCode CC = changeICMPPredToAArch64CC(
       static_cast<CmpInst::Predicate>(PredOp.getPredicate()));
@@ -2506,12 +2506,12 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) {
       return false;
     }
     auto &PredOp = Cmp->getOperand(1);
-    auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
-    const AArch64CC::CondCode InvCC =
-        changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
     MIB.setInstrAndDebugLoc(I);
     emitIntegerCompare(/*LHS=*/Cmp->getOperand(2),
                        /*RHS=*/Cmp->getOperand(3), PredOp, MIB);
+    auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+    const AArch64CC::CondCode InvCC =
+        changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
     emitCSINC(/*Dst=*/AddDst, /*Src =*/AddLHS, /*Src2=*/AddLHS, InvCC, MIB);
     I.eraseFromParent();
     return true;
@@ -3574,10 +3574,11 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       return false;
     }
 
-    auto Pred = static_cast<CmpInst::Predicate>(I.getOperand(1).getPredicate());
+    auto &PredOp = I.getOperand(1);
+    emitIntegerCompare(I.getOperand(2), I.getOperand(3), PredOp, MIB);
+    auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
     const AArch64CC::CondCode InvCC =
         changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
-    emitIntegerCompare(I.getOperand(2), I.getOperand(3), I.getOperand(1), MIB);
     emitCSINC(/*Dst=*/I.getOperand(0).getReg(), /*Src1=*/AArch64::WZR,
               /*Src2=*/AArch64::WZR, InvCC, MIB);
     I.eraseFromParent();
@@ -5096,11 +5097,11 @@ bool AArch64InstructionSelector::tryOptSelect(GSelect &I) {
 
   AArch64CC::CondCode CondCode;
   if (CondOpc == TargetOpcode::G_ICMP) {
-    auto Pred =
-        static_cast<CmpInst::Predicate>(CondDef->getOperand(1).getPredicate());
+    auto &PredOp = CondDef->getOperand(1);
+    emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3), PredOp,
+                       MIB);
+    auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
     CondCode = changeICMPPredToAArch64CC(Pred);
-    emitIntegerCompare(CondDef->getOperand(2), CondDef->getOperand(3),
-                       CondDef->getOperand(1), MIB);
   } else {
     // Get the condition code for the select.
     auto Pred =
@@ -5148,29 +5149,37 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
   MachineInstr *LHSDef = getDefIgnoringCopies(LHS.getReg(), MRI);
   MachineInstr *RHSDef = getDefIgnoringCopies(RHS.getReg(), MRI);
   auto P = static_cast<CmpInst::Predicate>(Predicate.getPredicate());
+
   // Given this:
   //
   // x = G_SUB 0, y
-  // G_ICMP x, z
+  // G_ICMP z, x
   //
   // Produce this:
   //
-  // cmn y, z
-  if (isCMN(LHSDef, P, MRI))
-    return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
+  // cmn z, y
+  if (isCMN(RHSDef, P, MRI))
+    return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder);
 
-  // Same idea here, but with the RHS of the compare instead:
+  // Same idea here, but with the LHS of the compare instead:
   //
   // Given this:
   //
   // x = G_SUB 0, y
-  // G_ICMP z, x
+  // G_ICMP x, z
   //
   // Produce this:
   //
-  // cmn z, y
-  if (isCMN(RHSDef, P, MRI))
-    return emitCMN(LHS, RHSDef->getOperand(2), MIRBuilder);
+  // cmn y, z
+  //
+  // But be careful! We need to swap the predicate!
+  if (isCMN(LHSDef, P, MRI)) {
+    if (!CmpInst::isEquality(P)) {
+      P = CmpInst::getSwappedPredicate(P);
+      Predicate = MachineOperand::CreatePredicate(P);
+    }
+    return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
+  }
 
   // Given this:
   //
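To make the swapped-predicate case in tryFoldIntegerCompare above concrete (a sketch only, registers invented): when the negated value sits on the left-hand side of the compare, folding to cmn also swaps the operand order, so any non-equality predicate has to be swapped along with it:

    ; %x = G_SUB 0, %y (nsw)      ; %cmp = G_ICMP slt, %x, %z
    neg  w8, w1                   ; w8 = -y
    cmp  w8, w2                   ; -y < z ?
    b.lt .Ltaken
    ; folded: cmn y, z compares y against -z, so "lt" becomes "gt"
    cmn  w1, w2
    b.gt .Ltaken

Equality predicates are symmetric, which is why the code only swaps when !CmpInst::isEquality(P).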
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 473ba5e..e0e1af7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -287,6 +287,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .moreElementsToNextPow2(0)
       .lower();
 
+  getActionDefinitionsBuilder({G_ABDS, G_ABDU})
+      .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
+      .lower();
+
   getActionDefinitionsBuilder(
       {G_SADDE, G_SSUBE, G_UADDE, G_USUBE, G_SADDO, G_SSUBO, G_UADDO, G_USUBO})
       .legalFor({{s32, s32}, {s64, s32}})
@@ -1646,6 +1650,12 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
     MI.eraseFromParent();
     return true;
   };
+  auto LowerTriOp = [&MI, &MIB](unsigned Opcode) {
+    MIB.buildInstr(Opcode, {MI.getOperand(0)},
+                   {MI.getOperand(2), MI.getOperand(3), MI.getOperand(4)});
+    MI.eraseFromParent();
+    return true;
+  };
 
   Intrinsic::ID IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
   switch (IntrinsicID) {
@@ -1794,6 +1804,10 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
     return LowerBinOp(AArch64::G_SMULL);
   case Intrinsic::aarch64_neon_umull:
     return LowerBinOp(AArch64::G_UMULL);
+  case Intrinsic::aarch64_neon_sabd:
+    return LowerBinOp(TargetOpcode::G_ABDS);
+  case Intrinsic::aarch64_neon_uabd:
+    return LowerBinOp(TargetOpcode::G_ABDU);
   case Intrinsic::aarch64_neon_abs: {
     // Lower the intrinsic to G_ABS.
     MIB.buildInstr(TargetOpcode::G_ABS, {MI.getOperand(0)}, {MI.getOperand(2)});
@@ -1820,6 +1834,10 @@ bool AArch64LegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
       return LowerBinOp(TargetOpcode::G_USUBSAT);
     break;
   }
+  case Intrinsic::aarch64_neon_udot:
+    return LowerTriOp(AArch64::G_UDOT);
+  case Intrinsic::aarch64_neon_sdot:
+    return LowerTriOp(AArch64::G_SDOT);
 
   case Intrinsic::vector_reverse:
     // TODO: Add support for vector_reverse
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
index 1cd9453..8c10673 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
@@ -228,12 +228,13 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
 }
 
-// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
-// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
+// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add([us]dot(x, y))
+// Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add([us]dot(x, y))
+// Or vecreduce_add(ext(x)) -> vecreduce_add([us]dot(x, 1))
 // Similar to performVecReduceAddCombine in SelectionDAG
-bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
-                            const AArch64Subtarget &STI,
-                            std::tuple<Register, Register, bool> &MatchInfo) {
+bool matchExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           const AArch64Subtarget &STI,
+                           std::tuple<Register, Register, bool> &MatchInfo) {
   assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
          "Expected a G_VECREDUCE_ADD instruction");
   assert(STI.hasDotProd() && "Target should have Dot Product feature");
@@ -246,31 +247,57 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
     return false;
 
-  LLT SrcTy;
-  auto I1Opc = I1->getOpcode();
-  if (I1Opc == TargetOpcode::G_MUL) {
+  // Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT
+  // then the ext's must match the same opcode. It is set to the ext opcode on
+  // output.
+  auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
+                                    Register &Out2, unsigned &I1Opc) {
     // If result of this has more than 1 use, then there is no point in creating
-    // udot instruction
-    if (!MRI.hasOneNonDBGUse(MidReg))
+    // a dot instruction
+    if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
       return false;
     MachineInstr *ExtMI1 =
-        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+        getDefIgnoringCopies(MI->getOperand(1).getReg(), MRI);
     MachineInstr *ExtMI2 =
-        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+        getDefIgnoringCopies(MI->getOperand(2).getReg(), MRI);
     LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
     LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
     if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
       return false;
+    if ((I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) &&
+        I1Opc != ExtMI1->getOpcode())
+      return false;
+    Out1 = ExtMI1->getOperand(1).getReg();
+    Out2 = ExtMI2->getOperand(1).getReg();
     I1Opc = ExtMI1->getOpcode();
-    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
-    std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
-    std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
+    return true;
+  };
+
+  LLT SrcTy;
+  unsigned I1Opc = I1->getOpcode();
+  if (I1Opc == TargetOpcode::G_MUL) {
+    Register Out1, Out2;
+    if (!tryMatchingMulOfExt(I1, Out1, Out2, I1Opc))
+      return false;
+    SrcTy = MRI.getType(Out1);
+    std::get<0>(MatchInfo) = Out1;
+    std::get<1>(MatchInfo) = Out2;
   } else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
-    SrcTy = MRI.getType(I1->getOperand(1).getReg());
-    std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
-    std::get<1>(MatchInfo) = 0;
+    Register I1Op = I1->getOperand(1).getReg();
+    MachineInstr *M = getDefIgnoringCopies(I1Op, MRI);
+    Register Out1, Out2;
+    if (M->getOpcode() == TargetOpcode::G_MUL &&
+        tryMatchingMulOfExt(M, Out1, Out2, I1Opc)) {
+      SrcTy = MRI.getType(Out1);
+      std::get<0>(MatchInfo) = Out1;
+      std::get<1>(MatchInfo) = Out2;
+    } else {
+      SrcTy = MRI.getType(I1Op);
+      std::get<0>(MatchInfo) = I1Op;
+      std::get<1>(MatchInfo) = 0;
+    }
   } else {
     return false;
   }
@@ -288,11 +315,11 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   return true;
 }
 
-void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
-                            MachineIRBuilder &Builder,
-                            GISelChangeObserver &Observer,
-                            const AArch64Subtarget &STI,
-                            std::tuple<Register, Register, bool> &MatchInfo) {
+void applyExtAddvToDotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &Builder,
+                           GISelChangeObserver &Observer,
+                           const AArch64Subtarget &STI,
+                           std::tuple<Register, Register, bool> &MatchInfo) {
   assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
          "Expected a G_VECREDUCE_ADD instruction");
   assert(STI.hasDotProd() && "Target should have Dot Product feature");
@@ -553,15 +580,15 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
-// Pushes ADD/SUB through extend instructions to decrease the number of extend
-// instruction at the end by allowing selection of {s|u}addl sooner
-
+// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
+// extend instruction at the end by allowing selection of {s|u}addl sooner
 // i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
 bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                         Register DstReg, Register SrcReg1, Register SrcReg2) {
   assert((MI.getOpcode() == TargetOpcode::G_ADD ||
-          MI.getOpcode() == TargetOpcode::G_SUB) &&
-         "Expected a G_ADD or G_SUB instruction\n");
+          MI.getOpcode() == TargetOpcode::G_SUB ||
+          MI.getOpcode() == TargetOpcode::G_MUL) &&
+         "Expected a G_ADD, G_SUB or G_MUL instruction\n");
 
   // Deal with vector types only
   LLT DstTy = MRI.getType(DstReg);
@@ -594,9 +621,10 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildInstr(MI.getOpcode(), {MidTy}, {Ext1Reg, Ext2Reg}).getReg(0);
 
   // G_SUB has to sign-extend the result.
-  // G_ADD needs to sext from sext and can sext or zext from zext, so the
-  // original opcode is used.
+  // G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
+  // needs to use the original opcode so the original opcode is used for both.
-  if (MI.getOpcode() == TargetOpcode::G_ADD)
+  if (MI.getOpcode() == TargetOpcode::G_ADD ||
+      MI.getOpcode() == TargetOpcode::G_MUL)
     B.buildInstr(Opc, {DstReg}, {AddReg});
   else
     B.buildSExt(DstReg, AddReg);
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp
index 233f42b..6257e99 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFStreamer.cpp
@@ -523,7 +523,8 @@ void AArch64TargetELFStreamer::finish() {
   // mark it execute-only if it is empty and there is at least one
   // execute-only section in the object.
   if (any_of(Asm, [](const MCSection &Sec) {
-        return cast<MCSectionELF>(Sec).getFlags() & ELF::SHF_AARCH64_PURECODE;
+        return static_cast<const MCSectionELF &>(Sec).getFlags() &
+               ELF::SHF_AARCH64_PURECODE;
       })) {
     auto *Text =
         static_cast<MCSectionELF *>(Ctx.getObjectFileInfo()->getTextSection());
@@ -559,8 +560,7 @@ void AArch64TargetELFStreamer::finish() {
     if (!Sym.isMemtag())
       continue;
     auto *SRE = MCSymbolRefExpr::create(&Sym, Ctx);
-    (void)S.emitRelocDirective(*Zero, "BFD_RELOC_NONE", SRE, SMLoc(),
-                               *Ctx.getSubtargetInfo());
+    S.emitRelocDirective(*Zero, "BFD_RELOC_NONE", SRE);
   }
 }
 
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCExpr.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCExpr.cpp
index 3d4a14b..1a9bce5 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCExpr.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCExpr.cpp
@@ -9,8 +9,6 @@
 #include "AArch64MCAsmInfo.h"
 #include "llvm/MC/MCContext.h"
 #include "llvm/MC/MCStreamer.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/ErrorHandling.h"
 
 using namespace llvm;
 
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MachObjectWriter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MachObjectWriter.cpp
index 1ac340a..a22a17a 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MachObjectWriter.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MachObjectWriter.cpp
@@ -132,7 +132,8 @@ static bool canUseLocalRelocation(const MCSectionMachO &Section,
   // But only if they don't point to a few forbidden sections.
   if (!Symbol.isInSection())
     return true;
-  const MCSectionMachO &RefSec = cast<MCSectionMachO>(Symbol.getSection());
+  const MCSectionMachO &RefSec =
+      static_cast<MCSectionMachO &>(Symbol.getSection());
   if (RefSec.getType() == MachO::S_CSTRING_LITERALS)
     return false;