Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 170
1 file changed, 131 insertions, 39 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f026726..7b49754 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -164,6 +164,9 @@ static cl::opt<bool> UseFEATCPACodegen(
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;
 
+/// Value type used for NZCV flags.
+static constexpr MVT FlagsVT = MVT::i32;
+
 static const MCPhysReg GPRArgRegs[] = {AArch64::X0, AArch64::X1, AArch64::X2,
                                        AArch64::X3, AArch64::X4, AArch64::X5,
                                        AArch64::X6, AArch64::X7};
@@ -3098,6 +3101,83 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
   return BB;
 }
 
+// Helper function to find the instruction that defined a virtual register.
+// If unable to find such instruction, returns nullptr.
+static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
+                                           Register Reg) {
+  while (Reg.isVirtual()) {
+    MachineInstr *DefMI = MRI.getVRegDef(Reg);
+    assert(DefMI && "Virtual register definition not found");
+    unsigned Opcode = DefMI->getOpcode();
+
+    if (Opcode == AArch64::COPY) {
+      Reg = DefMI->getOperand(1).getReg();
+      // Vreg is defined by copying from physreg.
+      if (Reg.isPhysical())
+        return DefMI;
+      continue;
+    }
+    if (Opcode == AArch64::SUBREG_TO_REG) {
+      Reg = DefMI->getOperand(2).getReg();
+      continue;
+    }
+
+    return DefMI;
+  }
+  return nullptr;
+}
+
+void AArch64TargetLowering::fixupPtrauthDiscriminator(
+    MachineInstr &MI, MachineBasicBlock *BB, MachineOperand &IntDiscOp,
+    MachineOperand &AddrDiscOp, const TargetRegisterClass *AddrDiscRC) const {
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  const DebugLoc &DL = MI.getDebugLoc();
+
+  Register AddrDisc = AddrDiscOp.getReg();
+  int64_t IntDisc = IntDiscOp.getImm();
+  assert(IntDisc == 0 && "Blend components are already expanded");
+
+  const MachineInstr *DiscMI = stripVRegCopies(MRI, AddrDisc);
+  if (DiscMI) {
+    switch (DiscMI->getOpcode()) {
+    case AArch64::MOVKXi:
+      // blend(addr, imm) which is lowered as "MOVK addr, #imm, #48".
+      // #imm should be an immediate and not a global symbol, for example.
+      if (DiscMI->getOperand(2).isImm() &&
+          DiscMI->getOperand(3).getImm() == 48) {
+        AddrDisc = DiscMI->getOperand(1).getReg();
+        IntDisc = DiscMI->getOperand(2).getImm();
+      }
+      break;
+    case AArch64::MOVi32imm:
+    case AArch64::MOVi64imm:
+      // Small immediate integer constant passed via VReg.
+      if (DiscMI->getOperand(1).isImm() &&
+          isUInt<16>(DiscMI->getOperand(1).getImm())) {
+        AddrDisc = AArch64::NoRegister;
+        IntDisc = DiscMI->getOperand(1).getImm();
+      }
+      break;
+    }
+  }
+
+  // For uniformity, always use NoRegister, as XZR is not necessarily contained
+  // in the requested register class.
+  if (AddrDisc == AArch64::XZR)
+    AddrDisc = AArch64::NoRegister;
+
+  // Make sure AddrDisc operand respects the register class imposed by MI.
+  if (AddrDisc && MRI.getRegClass(AddrDisc) != AddrDiscRC) {
+    Register TmpReg = MRI.createVirtualRegister(AddrDiscRC);
+    BuildMI(*BB, MI, DL, TII->get(AArch64::COPY), TmpReg).addReg(AddrDisc);
+    AddrDisc = TmpReg;
+  }
+
+  AddrDiscOp.setReg(AddrDisc);
+  IntDiscOp.setImm(IntDisc);
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -3196,6 +3276,11 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true);
   case AArch64::MOVT_TIZ_PSEUDO:
     return EmitZTInstr(MI, BB, AArch64::MOVT_TIZ, /*Op0IsDef=*/true);
+
+  case AArch64::PAC:
+    fixupPtrauthDiscriminator(MI, BB, MI.getOperand(3), MI.getOperand(4),
+                              &AArch64::GPR64noipRegClass);
+    return BB;
   }
 }
 
@@ -3451,7 +3536,7 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
   }
   unsigned Opcode =
       IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
-  return DAG.getNode(Opcode, DL, {MVT::i32, MVT::Other}, {Chain, LHS, RHS});
+  return DAG.getNode(Opcode, DL, {FlagsVT, MVT::Other}, {Chain, LHS, RHS});
 }
 
 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
@@ -3465,7 +3550,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
       LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS);
       RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS);
     }
-    return DAG.getNode(AArch64ISD::FCMP, DL, MVT::i32, LHS, RHS);
+    return DAG.getNode(AArch64ISD::FCMP, DL, FlagsVT, LHS, RHS);
   }
 
   // The CMP instruction is just an alias for SUBS, and representing it as
@@ -3490,7 +3575,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
       // (a.k.a. ANDS) except that the flags are only guaranteed to work for one
       // of the signed comparisons.
       const SDValue ANDSNode =
-          DAG.getNode(AArch64ISD::ANDS, DL, DAG.getVTList(VT, MVT_CC),
+          DAG.getNode(AArch64ISD::ANDS, DL, DAG.getVTList(VT, FlagsVT),
                       LHS.getOperand(0), LHS.getOperand(1));
       // Replace all users of (and X, Y) with newly generated (ands X, Y)
       DAG.ReplaceAllUsesWith(LHS, ANDSNode);
@@ -3501,7 +3586,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     }
   }
 
-  return DAG.getNode(Opcode, DL, DAG.getVTList(VT, MVT_CC), LHS, RHS)
+  return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS)
       .getValue(1);
 }
 
@@ -3597,7 +3682,7 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
   AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC);
   unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC);
   SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32);
-  return DAG.getNode(Opcode, DL, MVT_CC, LHS, RHS, NZCVOp, Condition, CCOp);
+  return DAG.getNode(Opcode, DL, FlagsVT, LHS, RHS, NZCVOp, Condition, CCOp);
 }
 
 /// Returns true if @p Val is a tree of AND/OR/SETCC operations that can be
@@ -4036,7 +4121,7 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
     Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
 
     // Check that the result fits into a 32-bit integer.
-    SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
+    SDVTList VTs = DAG.getVTList(MVT::i64, FlagsVT);
     if (IsSigned) {
       // cmp xreg, wreg, sxtw
       SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
@@ -4059,12 +4144,12 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
                                    DAG.getConstant(63, DL, MVT::i64));
       // It is important that LowerBits is last, otherwise the arithmetic
       // shift will not be folded into the compare (SUBS).
-      SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
+      SDVTList VTs = DAG.getVTList(MVT::i64, FlagsVT);
       Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
                      .getValue(1);
     } else {
       SDValue UpperBits = DAG.getNode(ISD::MULHU, DL, MVT::i64, LHS, RHS);
-      SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
+      SDVTList VTs = DAG.getVTList(MVT::i64, FlagsVT);
       Overflow =
           DAG.getNode(AArch64ISD::SUBS, DL, VTs,
                       DAG.getConstant(0, DL, MVT::i64),
@@ -4075,7 +4160,7 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
   } // switch (...)
 
   if (Opc) {
-    SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
+    SDVTList VTs = DAG.getVTList(Op->getValueType(0), FlagsVT);
 
     // Emit the AArch64 operation with overflow check.
     Value = DAG.getNode(Opc, DL, VTs, LHS, RHS);
@@ -4177,7 +4262,7 @@ static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG, bool Invert) {
   SDValue Op0 = Invert ? DAG.getConstant(0, DL, VT) : Value;
   SDValue Op1 = Invert ? Value : DAG.getConstant(1, DL, VT);
   SDValue Cmp =
-      DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::Glue), Op0, Op1);
+      DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT), Op0, Op1);
   return Cmp.getValue(1);
 }
 
@@ -4220,16 +4305,15 @@ static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG,
   SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG, InvertCarry);
 
   SDLoc DL(Op);
-  SDVTList VTs = DAG.getVTList(VT0, VT1);
 
-  SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, MVT::Glue), OpLHS,
+  SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, FlagsVT), OpLHS,
                             OpRHS, OpCarryIn);
 
   SDValue OutFlag = IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
                              : carryFlagToValue(Sum.getValue(1), VT1, DAG,
                                                 InvertCarry);
 
-  return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Sum, OutFlag);
+  return DAG.getMergeValues({Sum, OutFlag}, DL);
 }
 
 static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
@@ -4254,8 +4338,7 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
   Overflow = DAG.getNode(AArch64ISD::CSEL, DL, MVT::i32, FVal, TVal, CCVal,
                          Overflow);
 
-  SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
-  return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Value, Overflow);
+  return DAG.getMergeValues({Value, Overflow}, DL);
 }
 
 // Prefetch operands are:
@@ -6813,7 +6896,8 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
                      DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64));
     SDValue Result = DAG.getMemIntrinsicNode(
         AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other),
-        {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
+        {StoreNode->getChain(), DAG.getBitcast(MVT::v2i64, Lo),
+         DAG.getBitcast(MVT::v2i64, Hi), StoreNode->getBasePtr()},
        StoreNode->getMemoryVT(), StoreNode->getMemOperand());
     return Result;
   }
@@ -7037,9 +7121,8 @@ SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
   SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                             Op.getOperand(0));
   // Generate SUBS & CSEL.
-  SDValue Cmp =
-      DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32),
-                  Op.getOperand(0), DAG.getConstant(0, DL, VT));
+  SDValue Cmp = DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
+                            Op.getOperand(0), DAG.getConstant(0, DL, VT));
   return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg,
                      DAG.getConstant(AArch64CC::PL, DL, MVT::i32),
                      Cmp.getValue(1));
@@ -11108,7 +11191,7 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
   SDValue Carry = Op.getOperand(2);
   // SBCS uses a carry not a borrow so the carry flag should be inverted first.
   SDValue InvCarry = valueToCarryFlag(Carry, DAG, true);
-  SDValue Cmp = DAG.getNode(AArch64ISD::SBCS, DL, DAG.getVTList(VT, MVT::Glue),
+  SDValue Cmp = DAG.getNode(AArch64ISD::SBCS, DL, DAG.getVTList(VT, FlagsVT),
                             LHS, RHS, InvCarry);
 
   EVT OpVT = Op.getValueType();
@@ -12441,10 +12524,10 @@ SDValue AArch64TargetLowering::LowerAsmOutputForConstraint(
 
   // Get NZCV register. Only update chain when copyfrom is glued.
   if (Glue.getNode()) {
-    Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32, Glue);
+    Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, FlagsVT, Glue);
     Chain = Glue.getValue(1);
   } else
-    Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32);
+    Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, FlagsVT);
 
   // Extract CC code.
   SDValue CC = getSETCC(Cond, Glue, DL, DAG);
@@ -17343,12 +17426,17 @@ bool hasNearbyPairedStore(Iter It, Iter End, Value *Ptr, const DataLayout &DL) {
 ///        %sub.v1 = shuffle <32 x i32> %v0, <32 x i32> v1, <32, 33, 34, 35>
 ///        %sub.v2 = shuffle <32 x i32> %v0, <32 x i32> v1, <16, 17, 18, 19>
 ///        call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
-bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
+bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store,
+                                                  Value *LaneMask,
                                                   ShuffleVectorInst *SVI,
                                                   unsigned Factor) const {
 
   assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
          "Invalid interleave factor");
+  auto *SI = dyn_cast<StoreInst>(Store);
+  if (!SI)
+    return false;
+  assert(!LaneMask && "Unexpected mask on store");
 
   auto *VecTy = cast<FixedVectorType>(SVI->getType());
   assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store");
@@ -18015,11 +18103,14 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
     unsigned ShlAmt = C2->getZExtValue();
     if (auto ShouldADD = *N->user_begin();
         ShouldADD->getOpcode() == ISD::ADD && ShouldADD->hasOneUse()) {
-      if (auto ShouldLOAD = dyn_cast<LoadSDNode>(*ShouldADD->user_begin())) {
-        unsigned ByteVT = ShouldLOAD->getMemoryVT().getSizeInBits() / 8;
-        if ((1ULL << ShlAmt) == ByteVT &&
-            isIndexedLoadLegal(ISD::PRE_INC, ShouldLOAD->getMemoryVT()))
-          return false;
+      if (auto Load = dyn_cast<LoadSDNode>(*ShouldADD->user_begin())) {
+        EVT MemVT = Load->getMemoryVT();
+
+        if (Load->getValueType(0).isScalableVector())
+          return (8ULL << ShlAmt) != MemVT.getScalarSizeInBits();
+
+        if (isIndexedLoadLegal(ISD::PRE_INC, MemVT))
+          return (8ULL << ShlAmt) != MemVT.getFixedSizeInBits();
       }
     }
   }
@@ -18588,7 +18679,7 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
     Created.push_back(And.getNode());
   } else {
     SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC);
-    SDVTList VTs = DAG.getVTList(VT, MVT::i32);
+    SDVTList VTs = DAG.getVTList(VT, FlagsVT);
 
     SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0);
     SDValue AndPos = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
@@ -19477,10 +19568,10 @@ static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
       // can select to CCMN to avoid the extra mov
       SDValue AbsOp1 =
           DAG.getConstant(Op1->getAPIntValue().abs(), DL, Op1->getValueType(0));
-      CCmp = DAG.getNode(AArch64ISD::CCMN, DL, MVT_CC, Cmp1.getOperand(0), AbsOp1,
-                         NZCVOp, Condition, Cmp0);
+      CCmp = DAG.getNode(AArch64ISD::CCMN, DL, FlagsVT, Cmp1.getOperand(0),
+                         AbsOp1, NZCVOp, Condition, Cmp0);
     } else {
-      CCmp = DAG.getNode(AArch64ISD::CCMP, DL, MVT_CC, Cmp1.getOperand(0),
+      CCmp = DAG.getNode(AArch64ISD::CCMP, DL, FlagsVT, Cmp1.getOperand(0),
                          Cmp1.getOperand(1), NZCVOp, Condition, Cmp0);
     }
     return DAG.getNode(AArch64ISD::CSEL, DL, VT, CSel0.getOperand(0),
@@ -25129,8 +25220,9 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
   if (!TReassocOp && !FReassocOp)
     return SDValue();
 
-  SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
-                               DAG.getVTList(VT, MVT_CC), CmpOpOther, SubsOp);
+  SDValue NewCmp =
+      DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
+                  DAG.getVTList(VT, FlagsVT), CmpOpOther, SubsOp);
 
   auto Reassociate = [&](SDValue ReassocOp, unsigned OpNum) {
     if (!ReassocOp)
@@ -27156,7 +27248,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
                                  : AArch64SysReg::RNDRRS);
     SDLoc DL(N);
     SDValue A = DAG.getNode(
-        AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::i32, MVT::Other),
+        AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, FlagsVT, MVT::Other),
        N->getOperand(0), DAG.getConstant(Register, DL, MVT::i32));
     SDValue B = DAG.getNode(
        AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
@@ -27902,16 +27994,16 @@ void AArch64TargetLowering::ReplaceNodeResults(
          MemVT.getScalarSizeInBits() == 32u ||
          MemVT.getScalarSizeInBits() == 64u)) {
 
+      EVT HalfVT = MemVT.getHalfNumVectorElementsVT(*DAG.getContext());
       SDValue Result = DAG.getMemIntrinsicNode(
           AArch64ISD::LDNP, SDLoc(N),
-          DAG.getVTList({MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
-                         MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
-                         MVT::Other}),
+          DAG.getVTList({MVT::v2i64, MVT::v2i64, MVT::Other}),
           {LoadNode->getChain(), LoadNode->getBasePtr()},
           LoadNode->getMemoryVT(), LoadNode->getMemOperand());
 
       SDValue Pair = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), MemVT,
-                                 Result.getValue(0), Result.getValue(1));
+                                 DAG.getBitcast(HalfVT, Result.getValue(0)),
+                                 DAG.getBitcast(HalfVT, Result.getValue(1)));
       Results.append({Pair, Result.getValue(2) /* Chain */});
       return;
     }
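Note on the new AArch64::PAC custom insertion: fixupPtrauthDiscriminator recognises a discriminator that was blended with "MOVK Xaddr, #imm, LSL #48", i.e. a 16-bit integer discriminator written into the top 16 bits of the address discriminator, and splits it back into the pseudo's separate address/integer operands. The standalone C++ sketch below (illustrative only, not part of the patch; blendDiscriminator and splitDiscriminator are hypothetical helpers, not LLVM APIs) models that bit layout:

// Standalone illustration (not LLVM code): bit layout of a blended ptrauth
// discriminator. "MOVK Xaddr, #imm, LSL #48" overwrites bits [63:48] of the
// address discriminator with the 16-bit integer discriminator, which is the
// shape fixupPtrauthDiscriminator splits back apart.
#include <cassert>
#include <cstdint>

// Hypothetical helper: combine an address discriminator and a 16-bit integer
// discriminator the way the MOVK-based blend does.
static uint64_t blendDiscriminator(uint64_t AddrDisc, uint16_t IntDisc) {
  return (AddrDisc & 0x0000FFFFFFFFFFFFULL) |
         (static_cast<uint64_t>(IntDisc) << 48);
}

// Hypothetical inverse: recover the two components, mirroring how the MOVKXi
// case reads the register operand and the immediate operand of the MOVK.
static void splitDiscriminator(uint64_t Blended, uint64_t &AddrDisc,
                               uint16_t &IntDisc) {
  AddrDisc = Blended & 0x0000FFFFFFFFFFFFULL;
  IntDisc = static_cast<uint16_t>(Blended >> 48);
}

int main() {
  const uint64_t AddrDisc = 0x0000123456789ABCULL;
  const uint16_t IntDisc = 0xC0DE;

  uint64_t RecoveredAddr = 0;
  uint16_t RecoveredInt = 0;
  splitDiscriminator(blendDiscriminator(AddrDisc, IntDisc), RecoveredAddr,
                     RecoveredInt);
  assert(RecoveredAddr == AddrDisc && RecoveredInt == IntDisc);
  return 0;
}

With the components recovered this way, the pseudo can carry the small immediate directly rather than keeping the MOVK alive solely to rebuild the blend.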
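Several of the MVT::i32/MVT_CC/MVT::Glue-to-FlagsVT substitutions land in getAArch64XALUOOp, which lowers the 32-bit overflow-checking multiplies by doing the multiply in 64 bits and then using SUBS (whose second result is the NZCV flags) to test whether the product still fits in 32 bits. A standalone C++ sketch of that fits-in-32-bits test (illustrative only, not LLVM code):

// Standalone illustration (not LLVM code) of the overflow test behind the
// 32-bit [su]mulo lowering: multiply in 64 bits, then check that the product
// still fits in 32 bits (the DAG performs this with SUBS and reads NZCV).
#include <cassert>
#include <cstdint>

static bool smul32Overflows(int32_t A, int32_t B) {
  const int64_t Prod = static_cast<int64_t>(A) * static_cast<int64_t>(B);
  // Equivalent to the "cmp xreg, wreg, sxtw" check: the 64-bit product must
  // equal its own sign-extended low 32 bits.
  return Prod < INT32_MIN || Prod > INT32_MAX;
}

static bool umul32Overflows(uint32_t A, uint32_t B) {
  const uint64_t Prod = static_cast<uint64_t>(A) * static_cast<uint64_t>(B);
  // Unsigned flavour: overflow iff the upper 32 bits of the product are set.
  return (Prod >> 32) != 0;
}

int main() {
  assert(smul32Overflows(INT32_MAX, 2));    // product exceeds INT32_MAX.
  assert(!smul32Overflows(46340, 46340));   // 46340^2 = 2147395600 fits.
  assert(umul32Overflows(UINT32_MAX, 2u));  // high half becomes non-zero.
  assert(!umul32Overflows(65536u, 65535u)); // 0xFFFF0000 still fits.
  return 0;
}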
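In shouldFoldConstantShiftPairToMask, the rewritten check now folds the shift pair to a mask unless (8ULL << ShlAmt) equals the relevant load width in bits (the scalar element size for scalable vectors, the full memory size otherwise), i.e. unless the shifted value is an element index already scaled to what the adjacent load accesses. A minimal standalone sketch of that scaling condition (illustrative only, not LLVM code):

// Standalone illustration (not LLVM code) of the scaling test: an
// (x << ShlAmt) feeding an ADD that feeds a load is only worth keeping
// unfolded when the shift amount matches the byte width of the accessed
// element, so the ADD can become a scaled/indexed address.
#include <cassert>

static bool shiftMatchesLoadWidth(unsigned ShlAmt, unsigned MemSizeInBits) {
  // Each index step covers 1 << ShlAmt bytes, i.e. 8 << ShlAmt bits.
  return (8ULL << ShlAmt) == MemSizeInBits;
}

int main() {
  assert(shiftMatchesLoadWidth(2, 32));  // x << 2 steps over 4-byte elements.
  assert(shiftMatchesLoadWidth(3, 64));  // x << 3 steps over 8-byte elements.
  assert(!shiftMatchesLoadWidth(2, 64)); // mismatch: fold the shift pair.
  return 0;
}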