Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1079
1 file changed, 808 insertions, 271 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 40e6400..30eb190 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -16,11 +16,11 @@ #include "AArch64MachineFunctionInfo.h" #include "AArch64PerfectShuffle.h" #include "AArch64RegisterInfo.h" +#include "AArch64SMEAttributes.h" #include "AArch64Subtarget.h" #include "AArch64TargetMachine.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "Utils/AArch64BaseInfo.h" -#include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -387,7 +387,7 @@ extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) { AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, const AArch64Subtarget &STI) - : TargetLowering(TM), Subtarget(&STI) { + : TargetLowering(TM, STI), Subtarget(&STI) { // AArch64 doesn't have comparisons which set GPRs or setcc instructions, so // we have to make something up. Arbitrarily, choose ZeroOrOne. setBooleanContents(ZeroOrOneBooleanContent); @@ -445,6 +445,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass); + // Add sve predicate as counter type + addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass); + // Add legal sve data types addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass); @@ -473,15 +476,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, } } - if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) { - addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass); - setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1); - setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1); - - setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom); - setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand); - } - // Compute derived properties from the register classes computeRegisterProperties(Subtarget->getRegisterInfo()); @@ -536,7 +530,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FREM, MVT::f32, Expand); setOperationAction(ISD::FREM, MVT::f64, Expand); - setOperationAction(ISD::FREM, MVT::f80, Expand); setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand); @@ -1433,12 +1426,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::BITCAST, MVT::v2i16, Custom); setOperationAction(ISD::BITCAST, MVT::v4i8, Custom); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i8, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i8, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom); setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom); setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom); setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, 
Custom); setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i16, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i16, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i16, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i16, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i16, Custom); // ADDP custom lowering for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) @@ -1518,6 +1523,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32}) setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom); + + for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) + setOperationAction(ISD::FMA, VT, Custom); } if (Subtarget->isSVEorStreamingSVEAvailable()) { @@ -1585,6 +1593,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::AVGCEILS, VT, Custom); setOperationAction(ISD::AVGCEILU, VT, Custom); + setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, VT, Custom); + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom); + if (!Subtarget->isLittleEndian()) setOperationAction(ISD::BITCAST, VT, Custom); @@ -1609,6 +1621,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 }) setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal); + // Promote predicate as counter load/stores to standard predicates. + setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1); + setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1); + + // Predicate as counter legalization actions. + setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom); + setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand); + for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); @@ -1769,17 +1789,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); + } - if (Subtarget->hasSVEB16B16() && - Subtarget->isNonStreamingSVEorSME2Available()) { - setOperationAction(ISD::FADD, VT, Legal); + if (Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available()) { + // Note: Use SVE for bfloat16 operations when +sve-b16b16 is available. 
+ for (auto VT : {MVT::v4bf16, MVT::v8bf16, MVT::nxv2bf16, MVT::nxv4bf16, + MVT::nxv8bf16}) { + setOperationAction(ISD::FADD, VT, Custom); setOperationAction(ISD::FMA, VT, Custom); setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMAXNUM, VT, Custom); setOperationAction(ISD::FMINIMUM, VT, Custom); setOperationAction(ISD::FMINNUM, VT, Custom); - setOperationAction(ISD::FMUL, VT, Legal); - setOperationAction(ISD::FSUB, VT, Legal); + setOperationAction(ISD::FMUL, VT, Custom); + setOperationAction(ISD::FSUB, VT, Custom); } } @@ -1795,22 +1819,37 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (!Subtarget->hasSVEB16B16() || !Subtarget->isNonStreamingSVEorSME2Available()) { - for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM, - ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) { - setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32); - setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32); - setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32); + for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) { + MVT PromotedVT = VT.changeVectorElementType(MVT::f32); + setOperationPromotedToType(ISD::FADD, VT, PromotedVT); + setOperationPromotedToType(ISD::FMA, VT, PromotedVT); + setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT); + setOperationPromotedToType(ISD::FSUB, VT, PromotedVT); + + if (VT != MVT::nxv2bf16 && Subtarget->hasBF16()) + setOperationAction(ISD::FMUL, VT, Custom); + else + setOperationPromotedToType(ISD::FMUL, VT, PromotedVT); } + + if (Subtarget->hasBF16() && Subtarget->isNeonAvailable()) + setOperationAction(ISD::FMUL, MVT::v8bf16, Custom); } setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom); - // NEON doesn't support integer divides, but SVE does + // A number of operations like MULH and integer divides are not supported by + // NEON but are available in SVE. for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) { setOperationAction(ISD::SDIV, VT, Custom); setOperationAction(ISD::UDIV, VT, Custom); + setOperationAction(ISD::MULHS, VT, Custom); + setOperationAction(ISD::MULHU, VT, Custom); } // NEON doesn't support 64-bit vector integer muls, but SVE does. @@ -1847,10 +1886,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::CTLZ, MVT::v1i64, Custom); setOperationAction(ISD::CTLZ, MVT::v2i64, Custom); setOperationAction(ISD::CTTZ, MVT::v1i64, Custom); - setOperationAction(ISD::MULHS, MVT::v1i64, Custom); - setOperationAction(ISD::MULHS, MVT::v2i64, Custom); - setOperationAction(ISD::MULHU, MVT::v1i64, Custom); - setOperationAction(ISD::MULHU, MVT::v2i64, Custom); setOperationAction(ISD::SMAX, MVT::v1i64, Custom); setOperationAction(ISD::SMAX, MVT::v2i64, Custom); setOperationAction(ISD::SMIN, MVT::v1i64, Custom); @@ -1872,8 +1907,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_AND, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); - setOperationAction(ISD::MULHS, VT, Custom); - setOperationAction(ISD::MULHU, VT, Custom); } // Use SVE for vectors with more than 2 elements. 
@@ -1916,6 +1949,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal); setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal); } + + // Handle floating-point partial reduction + if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) { + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32, + MVT::nxv8f16, Legal); + // We can use SVE2p1 fdot to emulate the fixed-length variant. + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::v4f32, + MVT::v8f16, Custom); + } } // Handle non-aliasing elements mask @@ -1951,10 +1993,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); // We can lower types that have <vscale x {2|4}> elements to compact. - for (auto VT : - {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32, - MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32}) + for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, + MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, + MVT::nxv4i32, MVT::nxv4f32}) { setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom); + // Use a custom lowering for masked stores that could be a supported + // compressing store. Note: These types still use the normal (Legal) + // lowering for non-compressing masked stores. + setOperationAction(ISD::MSTORE, VT, Custom); + } // If we have SVE, we can use SVE logic for legal (or smaller than legal) // NEON vectors in the lowest bits of the SVE register. @@ -2283,6 +2330,11 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { MVT::getVectorVT(MVT::i8, NumElts * 8), Custom); } + if (Subtarget->hasSVE2p1() && VT.getVectorElementType() == MVT::f32) { + setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT, + MVT::getVectorVT(MVT::f16, NumElts * 2), Custom); + } + // Lower fixed length vector operations to scalable equivalents. setOperationAction(ISD::ABDS, VT, Default); setOperationAction(ISD::ABDU, VT, Default); @@ -2542,7 +2594,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant( return false; // Exit early if we demand all bits. - if (DemandedBits.popcount() == Size) + if (DemandedBits.isAllOnes()) return false; unsigned NewOpc; @@ -3858,22 +3910,30 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS, /// \param MustBeFirst Set to true if this subtree needs to be negated and we /// cannot do the negation naturally. We are required to /// emit the subtree first in this case. +/// \param PreferFirst Set to true if processing this subtree first may +/// result in more efficient code. /// \param WillNegate Is true if are called when the result of this /// subexpression must be negated. This happens when the /// outer expression is an OR. We can use this fact to know /// that we have a double negation (or (or ...) ...) that /// can be implemented for free. 
-static bool canEmitConjunction(const SDValue Val, bool &CanNegate, - bool &MustBeFirst, bool WillNegate, +static bool canEmitConjunction(SelectionDAG &DAG, const SDValue Val, + bool &CanNegate, bool &MustBeFirst, + bool &PreferFirst, bool WillNegate, unsigned Depth = 0) { if (!Val.hasOneUse()) return false; unsigned Opcode = Val->getOpcode(); if (Opcode == ISD::SETCC) { - if (Val->getOperand(0).getValueType() == MVT::f128) + EVT VT = Val->getOperand(0).getValueType(); + if (VT == MVT::f128) return false; CanNegate = true; MustBeFirst = false; + // Designate this operation as a preferred first operation if the result + // of a SUB operation can be reused. + PreferFirst = DAG.doesNodeExist(ISD::SUB, DAG.getVTList(VT), + {Val->getOperand(0), Val->getOperand(1)}); return true; } // Protect against exponential runtime and stack overflow. @@ -3885,11 +3945,15 @@ static bool canEmitConjunction(const SDValue Val, bool &CanNegate, SDValue O1 = Val->getOperand(1); bool CanNegateL; bool MustBeFirstL; - if (!canEmitConjunction(O0, CanNegateL, MustBeFirstL, IsOR, Depth+1)) + bool PreferFirstL; + if (!canEmitConjunction(DAG, O0, CanNegateL, MustBeFirstL, PreferFirstL, + IsOR, Depth + 1)) return false; bool CanNegateR; bool MustBeFirstR; - if (!canEmitConjunction(O1, CanNegateR, MustBeFirstR, IsOR, Depth+1)) + bool PreferFirstR; + if (!canEmitConjunction(DAG, O1, CanNegateR, MustBeFirstR, PreferFirstR, + IsOR, Depth + 1)) return false; if (MustBeFirstL && MustBeFirstR) @@ -3912,6 +3976,7 @@ static bool canEmitConjunction(const SDValue Val, bool &CanNegate, CanNegate = false; MustBeFirst = MustBeFirstL || MustBeFirstR; } + PreferFirst = PreferFirstL || PreferFirstR; return true; } return false; @@ -3973,19 +4038,25 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val, SDValue LHS = Val->getOperand(0); bool CanNegateL; bool MustBeFirstL; - bool ValidL = canEmitConjunction(LHS, CanNegateL, MustBeFirstL, IsOR); + bool PreferFirstL; + bool ValidL = canEmitConjunction(DAG, LHS, CanNegateL, MustBeFirstL, + PreferFirstL, IsOR); assert(ValidL && "Valid conjunction/disjunction tree"); (void)ValidL; SDValue RHS = Val->getOperand(1); bool CanNegateR; bool MustBeFirstR; - bool ValidR = canEmitConjunction(RHS, CanNegateR, MustBeFirstR, IsOR); + bool PreferFirstR; + bool ValidR = canEmitConjunction(DAG, RHS, CanNegateR, MustBeFirstR, + PreferFirstR, IsOR); assert(ValidR && "Valid conjunction/disjunction tree"); (void)ValidR; - // Swap sub-tree that must come first to the right side. - if (MustBeFirstL) { + bool ShouldFirstL = PreferFirstL && !PreferFirstR && !MustBeFirstR; + + // Swap sub-tree that must or should come first to the right side. 
+ if (MustBeFirstL || ShouldFirstL) { assert(!MustBeFirstR && "Valid conjunction/disjunction tree"); std::swap(LHS, RHS); std::swap(CanNegateL, CanNegateR); @@ -4041,7 +4112,9 @@ static SDValue emitConjunction(SelectionDAG &DAG, SDValue Val, AArch64CC::CondCode &OutCC) { bool DummyCanNegate; bool DummyMustBeFirst; - if (!canEmitConjunction(Val, DummyCanNegate, DummyMustBeFirst, false)) + bool DummyPreferFirst; + if (!canEmitConjunction(DAG, Val, DummyCanNegate, DummyMustBeFirst, + DummyPreferFirst, false)) return SDValue(); return emitConjunctionRec(DAG, Val, OutCC, false, SDValue(), AArch64CC::AL); @@ -4487,6 +4560,26 @@ static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG, return DAG.getMergeValues({Sum, OutFlag}, DL); } +static SDValue lowerIntNeonIntrinsic(SDValue Op, unsigned Opcode, + SelectionDAG &DAG) { + SDLoc DL(Op); + auto getFloatVT = [](EVT VT) { + assert((VT == MVT::i32 || VT == MVT::i64) && "Unexpected VT"); + return VT == MVT::i32 ? MVT::f32 : MVT::f64; + }; + auto bitcastToFloat = [&](SDValue Val) { + return DAG.getBitcast(getFloatVT(Val.getValueType()), Val); + }; + SmallVector<SDValue, 2> NewOps; + NewOps.reserve(Op.getNumOperands() - 1); + + for (unsigned I = 1, E = Op.getNumOperands(); I < E; ++I) + NewOps.push_back(bitcastToFloat(Op.getOperand(I))); + EVT OrigVT = Op.getValueType(); + SDValue OpNode = DAG.getNode(Opcode, DL, getFloatVT(OrigVT), NewOps); + return DAG.getBitcast(OrigVT, OpNode); +} + static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { // Let legalize expand this if it isn't a legal type yet. if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType())) @@ -5544,9 +5637,10 @@ SDValue AArch64TargetLowering::LowerGET_ROUNDING(SDValue Op, SDLoc DL(Op); SDValue Chain = Op.getOperand(0); - SDValue FPCR_64 = DAG.getNode( - ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, - {Chain, DAG.getConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)}); + SDValue FPCR_64 = + DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, + {Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, + MVT::i64)}); Chain = FPCR_64.getValue(1); SDValue FPCR_32 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPCR_64); SDValue FltRounds = DAG.getNode(ISD::ADD, DL, MVT::i32, FPCR_32, @@ -5632,7 +5726,8 @@ SDValue AArch64TargetLowering::LowerSET_FPMODE(SDValue Op, // Set new value of FPCR. SDValue Ops2[] = { - Chain, DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), FPCR}; + Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), + FPCR}; return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2); } @@ -5655,9 +5750,9 @@ SDValue AArch64TargetLowering::LowerRESET_FPMODE(SDValue Op, DAG.getConstant(AArch64::ReservedFPControlBits, DL, MVT::i64)); // Set new value of FPCR. 
- SDValue Ops2[] = {Chain, - DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), - FPSCRMasked}; + SDValue Ops2[] = { + Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), + FPSCRMasked}; return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2); } @@ -5735,8 +5830,10 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const { if (VT.is64BitVector()) { if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && isNullConstant(N0.getOperand(1)) && + N0.getOperand(0).getValueType().is128BitVector() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && - isNullConstant(N1.getOperand(1))) { + isNullConstant(N1.getOperand(1)) && + N1.getOperand(0).getValueType().is128BitVector()) { N0 = N0.getOperand(0); N1 = N1.getOperand(0); VT = N0.getValueType(); @@ -6329,26 +6426,46 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Op.getOperand(1).getValueType(), Op.getOperand(1), Op.getOperand(2))); return SDValue(); + case Intrinsic::aarch64_neon_sqrshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQRSHL, DAG); + case Intrinsic::aarch64_neon_sqshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSHL, DAG); + case Intrinsic::aarch64_neon_uqrshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQRSHL, DAG); + case Intrinsic::aarch64_neon_uqshl: + if (Op.getValueType().isVector()) + return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSHL, DAG); case Intrinsic::aarch64_neon_sqadd: if (Op.getValueType().isVector()) return DAG.getNode(ISD::SADDSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQADD, DAG); + case Intrinsic::aarch64_neon_sqsub: if (Op.getValueType().isVector()) return DAG.getNode(ISD::SSUBSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSUB, DAG); + case Intrinsic::aarch64_neon_uqadd: if (Op.getValueType().isVector()) return DAG.getNode(ISD::UADDSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQADD, DAG); case Intrinsic::aarch64_neon_uqsub: if (Op.getValueType().isVector()) return DAG.getNode(ISD::USUBSAT, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - return SDValue(); + return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSUB, DAG); + case Intrinsic::aarch64_neon_sqdmulls_scalar: + return lowerIntNeonIntrinsic(Op, AArch64ISD::SQDMULL, DAG); case Intrinsic::aarch64_sve_whilelt: return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/true, /*IsEqual=*/false); @@ -6382,9 +6499,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_lastb: return DAG.getNode(AArch64ISD::LASTB, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - case Intrinsic::aarch64_sve_rev: - return DAG.getNode(ISD::VECTOR_REVERSE, DL, Op.getValueType(), - Op.getOperand(1)); case Intrinsic::aarch64_sve_tbl: return DAG.getNode(AArch64ISD::TBL, DL, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); @@ -6710,8 +6824,34 @@ bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend, return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2; } +/// Helper function to check if a small vector load can be optimized. 
+static bool isEligibleForSmallVectorLoadOpt(LoadSDNode *LD, + const AArch64Subtarget &Subtarget) { + if (!Subtarget.isNeonAvailable()) + return false; + if (LD->isVolatile()) + return false; + + EVT MemVT = LD->getMemoryVT(); + if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8 && MemVT != MVT::v2i16) + return false; + + Align Alignment = LD->getAlign(); + Align RequiredAlignment = Align(MemVT.getStoreSize().getFixedValue()); + if (Subtarget.requiresStrictAlign() && Alignment < RequiredAlignment) + return false; + + return true; +} + bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { EVT ExtVT = ExtVal.getValueType(); + // Small, illegal vectors can be extended inreg. + if (auto *Load = dyn_cast<LoadSDNode>(ExtVal.getOperand(0))) { + if (ExtVT.isFixedLengthVector() && ExtVT.getStoreSizeInBits() <= 128 && + isEligibleForSmallVectorLoadOpt(Load, *Subtarget)) + return true; + } if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors()) return false; @@ -7170,12 +7310,86 @@ SDValue AArch64TargetLowering::LowerStore128(SDValue Op, return Result; } +/// Helper function to optimize loads of extended small vectors. +/// These patterns would otherwise get scalarized into inefficient sequences. +static SDValue tryLowerSmallVectorExtLoad(LoadSDNode *Load, SelectionDAG &DAG) { + const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); + if (!isEligibleForSmallVectorLoadOpt(Load, Subtarget)) + return SDValue(); + + EVT MemVT = Load->getMemoryVT(); + EVT ResVT = Load->getValueType(0); + unsigned NumElts = ResVT.getVectorNumElements(); + unsigned DstEltBits = ResVT.getScalarSizeInBits(); + unsigned SrcEltBits = MemVT.getScalarSizeInBits(); + + unsigned ExtOpcode; + switch (Load->getExtensionType()) { + case ISD::EXTLOAD: + case ISD::ZEXTLOAD: + ExtOpcode = ISD::ZERO_EXTEND; + break; + case ISD::SEXTLOAD: + ExtOpcode = ISD::SIGN_EXTEND; + break; + case ISD::NON_EXTLOAD: + return SDValue(); + } + + SDLoc DL(Load); + SDValue Chain = Load->getChain(); + SDValue BasePtr = Load->getBasePtr(); + const MachinePointerInfo &PtrInfo = Load->getPointerInfo(); + Align Alignment = Load->getAlign(); + + // Load the data as an FP scalar to avoid issues with integer loads. 
+ unsigned LoadBits = MemVT.getStoreSizeInBits(); + MVT ScalarLoadType = MVT::getFloatingPointVT(LoadBits); + SDValue ScalarLoad = + DAG.getLoad(ScalarLoadType, DL, Chain, BasePtr, PtrInfo, Alignment); + + MVT ScalarToVecTy = MVT::getVectorVT(ScalarLoadType, 128 / LoadBits); + SDValue ScalarToVec = + DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarToVecTy, ScalarLoad); + MVT BitcastTy = + MVT::getVectorVT(MVT::getIntegerVT(SrcEltBits), 128 / SrcEltBits); + SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, BitcastTy, ScalarToVec); + + SDValue Res = Bitcast; + unsigned CurrentEltBits = Res.getValueType().getScalarSizeInBits(); + unsigned CurrentNumElts = Res.getValueType().getVectorNumElements(); + while (CurrentEltBits < DstEltBits) { + if (Res.getValueSizeInBits() >= 128) { + CurrentNumElts = CurrentNumElts / 2; + MVT ExtractVT = + MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Res, + DAG.getConstant(0, DL, MVT::i64)); + } + CurrentEltBits = CurrentEltBits * 2; + MVT ExtVT = + MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts); + Res = DAG.getNode(ExtOpcode, DL, ExtVT, Res); + } + + if (CurrentNumElts != NumElts) { + MVT FinalVT = MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), NumElts); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FinalVT, Res, + DAG.getConstant(0, DL, MVT::i64)); + } + + return DAG.getMergeValues({Res, ScalarLoad.getValue(1)}, DL); +} + SDValue AArch64TargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); LoadSDNode *LoadNode = cast<LoadSDNode>(Op); assert(LoadNode && "Expected custom lowering of a load node"); + if (SDValue Result = tryLowerSmallVectorExtLoad(LoadNode, DAG)) + return Result; + if (LoadNode->getMemoryVT() == MVT::i64x8) { SmallVector<SDValue, 8> Ops; SDValue Base = LoadNode->getBasePtr(); @@ -7194,37 +7408,38 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op, return DAG.getMergeValues({Loaded, Chain}, DL); } - // Custom lowering for extending v4i8 vector loads. - EVT VT = Op->getValueType(0); - assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32"); - - if (LoadNode->getMemoryVT() != MVT::v4i8) - return SDValue(); - - // Avoid generating unaligned loads. - if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4)) - return SDValue(); + return SDValue(); +} - unsigned ExtType; - if (LoadNode->getExtensionType() == ISD::SEXTLOAD) - ExtType = ISD::SIGN_EXTEND; - else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD || - LoadNode->getExtensionType() == ISD::EXTLOAD) - ExtType = ISD::ZERO_EXTEND; - else - return SDValue(); +// Convert to ContainerVT with no-op casts where possible. +static SDValue convertToSVEContainerType(SDLoc DL, SDValue Vec, EVT ContainerVT, + SelectionDAG &DAG) { + EVT VecVT = Vec.getValueType(); + if (VecVT.isFloatingPoint()) { + // Use no-op casts for floating-point types. + EVT PackedVT = getPackedSVEVectorVT(VecVT.getScalarType()); + Vec = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedVT, Vec); + Vec = DAG.getNode(AArch64ISD::NVCAST, DL, ContainerVT, Vec); + } else { + // Extend integers (may not be a no-op). 
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec); + } + return Vec; +} - SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(), - LoadNode->getBasePtr(), MachinePointerInfo()); - SDValue Chain = Load.getValue(1); - SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load); - SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec); - SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC); - Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext, - DAG.getConstant(0, DL, MVT::i64)); - if (VT == MVT::v4i32) - Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext); - return DAG.getMergeValues({Ext, Chain}, DL); +// Convert to VecVT with no-op casts where possible. +static SDValue convertFromSVEContainerType(SDLoc DL, SDValue Vec, EVT VecVT, + SelectionDAG &DAG) { + if (VecVT.isFloatingPoint()) { + // Use no-op casts for floating-point types. + EVT PackedVT = getPackedSVEVectorVT(VecVT.getScalarType()); + Vec = DAG.getNode(AArch64ISD::NVCAST, DL, PackedVT, Vec); + Vec = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VecVT, Vec); + } else { + // Truncate integers (may not be a no-op). + Vec = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Vec); + } + return Vec; } SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op, @@ -7278,49 +7493,49 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op, // Get legal type for compact instruction EVT ContainerVT = getSVEContainerType(VecVT); - EVT CastVT = VecVT.changeVectorElementTypeToInteger(); - // Convert to i32 or i64 for smaller types, as these are the only supported + // Convert to 32 or 64 bits for smaller types, as these are the only supported // sizes for compact. - if (ContainerVT != VecVT) { - Vec = DAG.getBitcast(CastVT, Vec); - Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec); - } + Vec = convertToSVEContainerType(DL, Vec, ContainerVT, DAG); SDValue Compressed = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(), - DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec); + DAG.getTargetConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, + Vec); // compact fills with 0s, so if our passthru is all 0s, do nothing here. if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) { SDValue Offset = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, - DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask); + DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, + Mask); SDValue IndexMask = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, MaskVT, - DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64), + DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64), DAG.getConstant(0, DL, MVT::i64), Offset); Compressed = DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru); } + // If we changed the element type before, we need to convert it back. + if (ElmtVT.isFloatingPoint()) + Compressed = convertFromSVEContainerType(DL, Compressed, VecVT, DAG); + // Extracting from a legal SVE type before truncating produces better code. if (IsFixedLength) { - Compressed = DAG.getNode( - ISD::EXTRACT_SUBVECTOR, DL, - FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()), - Compressed, DAG.getConstant(0, DL, MVT::i64)); - CastVT = FixedVecVT.changeVectorElementTypeToInteger(); + EVT FixedSubVector = VecVT.isInteger() + ? 
FixedVecVT.changeVectorElementType( + ContainerVT.getVectorElementType()) + : FixedVecVT; + Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedSubVector, + Compressed, DAG.getConstant(0, DL, MVT::i64)); VecVT = FixedVecVT; } - // If we changed the element type before, we need to convert it back. - if (ContainerVT != VecVT) { - Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed); - Compressed = DAG.getBitcast(VecVT, Compressed); - } + if (VecVT.isInteger()) + Compressed = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Compressed); return Compressed; } @@ -7428,10 +7643,10 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) { DAG.getUNDEF(ExpVT), Exp, Zero); SDValue VPg = getPTrue(DAG, DL, XVT.changeVectorElementType(MVT::i1), AArch64SVEPredPattern::all); - SDValue FScale = - DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XVT, - DAG.getConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64), - VPg, VX, VExp); + SDValue FScale = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, XVT, + DAG.getTargetConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64), VPg, + VX, VExp); SDValue Final = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero); if (X.getValueType() != XScalarTy) @@ -7518,6 +7733,117 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op, EndOfTrmp); } +SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + if (VT.getScalarType() != MVT::bf16 || + (Subtarget->hasSVEB16B16() && + Subtarget->isNonStreamingSVEorSME2Available())) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + + assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering"); + assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) && + "Unexpected FMUL VT"); + + auto MakeGetIntrinsic = [&](Intrinsic::ID IID) { + return [&, IID](EVT VT, auto... Ops) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Ops...); + }; + }; + + auto Reinterpret = [&](SDValue Value, EVT VT) { + EVT SrcVT = Value.getValueType(); + if (VT == SrcVT) + return Value; + if (SrcVT.isFixedLengthVector()) + return convertToScalableVector(DAG, VT, Value); + if (VT.isFixedLengthVector()) + return convertFromScalableVector(DAG, VT, Value); + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value); + }; + + bool UseSVEBFMLAL = VT.isScalableVector(); + auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); + auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); + + // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant. + // This does not match BFCVTN[2], so we use SVE to convert back to bf16. + auto BFMLALB = + MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb + : Intrinsic::aarch64_neon_bfmlalb); + auto BFMLALT = + MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt + : Intrinsic::aarch64_neon_bfmlalt); + + EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32; + SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags()); + SDValue Pg = getPredicateForVector(DAG, DL, AccVT); + + // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom + // instructions. These result in two f32 vectors, which can be converted back + // to bf16 with FCVT and FCVTNT. + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + + // All SVE intrinsics expect to operate on full bf16 vector types. 
+ if (UseSVEBFMLAL) { + LHS = Reinterpret(LHS, MVT::nxv8bf16); + RHS = Reinterpret(RHS, MVT::nxv8bf16); + } + + SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32); + SDValue BottomBF16 = + FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32); + // Note: nxv4bf16 only uses even lanes. + if (VT == MVT::nxv4bf16) + return Reinterpret(BottomBF16, VT); + + SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32); + SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32); + return Reinterpret(TopBF16, VT); +} + +SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const { + SDValue OpA = Op->getOperand(0); + SDValue OpB = Op->getOperand(1); + SDValue OpC = Op->getOperand(2); + EVT VT = Op.getValueType(); + SDLoc DL(Op); + + assert(VT.isVector() && "Scalar fma lowering should be handled by patterns"); + + // Bail early if we're definitely not looking to merge FNEGs into the FMA. + if (VT != MVT::v8f16 && VT != MVT::v4f32 && VT != MVT::v2f64) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); + + if (OpC.getOpcode() != ISD::FNEG) + return useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) + ? LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED) + : Op; // Fallback to NEON lowering. + + // Convert FMA/FNEG nodes to SVE to enable the following patterns: + // fma(a, b, neg(c)) -> fnmls(a, b, c) + // fma(neg(a), b, neg(c)) -> fnmla(a, b, c) + // fma(a, neg(b), neg(c)) -> fnmla(a, b, c) + SDValue Pg = getPredicateForVector(DAG, DL, VT); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + + auto ConvertToScalableFnegMt = [&](SDValue Op) { + if (Op.getOpcode() == ISD::FNEG) + Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU); + return convertToScalableVector(DAG, ContainerVT, Op); + }; + + OpA = ConvertToScalableFnegMt(OpA); + OpB = ConvertToScalableFnegMt(OpB); + OpC = ConvertToScalableFnegMt(OpC); + + SDValue ScalableRes = + DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC); + return convertFromScalableVector(DAG, VT, ScalableRes); +} + SDValue AArch64TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { LLVM_DEBUG(dbgs() << "Custom lowering: "); @@ -7592,9 +7918,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::FSUB: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED); case ISD::FMUL: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); + return LowerFMUL(Op, DAG); case ISD::FMA: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); + return LowerFMA(Op, DAG); case ISD::FDIV: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED); case ISD::FNEG: @@ -7639,6 +7965,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerEXTRACT_VECTOR_ELT(Op, DAG); case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG); + case ISD::ANY_EXTEND_VECTOR_INREG: + case ISD::SIGN_EXTEND_VECTOR_INREG: + return LowerEXTEND_VECTOR_INREG(Op, DAG); case ISD::ZERO_EXTEND_VECTOR_INREG: return LowerZERO_EXTEND_VECTOR_INREG(Op, DAG); case ISD::VECTOR_SHUFFLE: @@ -7720,7 +8049,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::STORE: return LowerSTORE(Op, DAG); case ISD::MSTORE: - return LowerFixedLengthVectorMStoreToSVE(Op, DAG); + return LowerMSTORE(Op, DAG); case ISD::MGATHER: return LowerMGATHER(Op, DAG); case ISD::MSCATTER: @@ -7875,6 +8204,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_UMLA: case 
ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: return LowerPARTIAL_REDUCE_MLA(Op, DAG); } } @@ -8094,7 +8424,7 @@ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL, TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout())); SDValue TPIDR2_EL0 = DAG.getNode( ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); + DAG.getTargetConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); // Copy the address of the TPIDR2 block into X0 before 'calling' the // RESTORE_ZA pseudo. SDValue Glue; @@ -8109,7 +8439,7 @@ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL, // Finally reset the TPIDR2_EL0 register to 0. Chain = DAG.getNode( ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), + DAG.getTargetConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), DAG.getConstant(0, DL, MVT::i64)); TPIDR2.Uses++; return Chain; @@ -8426,7 +8756,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments( Subtarget->isWindowsArm64EC()) && "Indirect arguments should be scalable on most subtargets"); - uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue(); + TypeSize PartSize = VA.getValVT().getStoreSize(); unsigned NumParts = 1; if (Ins[i].Flags.isInConsecutiveRegs()) { while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast()) @@ -8443,16 +8773,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments( InVals.push_back(ArgValue); NumParts--; if (NumParts > 0) { - SDValue BytesIncrement; - if (PartLoad.isScalableVector()) { - BytesIncrement = DAG.getVScale( - DL, Ptr.getValueType(), - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize)); - } else { - BytesIncrement = DAG.getConstant( - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL, - Ptr.getValueType()); - } + SDValue BytesIncrement = + DAG.getTypeSize(DL, Ptr.getValueType(), PartSize); Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement, SDNodeFlags::NoUnsignedWrap); ExtraArgLocs++; @@ -8699,15 +9021,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments( } } - if (getTM().useNewSMEABILowering()) { - // Clear new ZT0 state. TODO: Move this to the SME ABI pass. 
- if (Attrs.isNewZT0()) - Chain = DAG.getNode( - ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32), - DAG.getTargetConstant(0, DL, MVT::i32)); - } - return Chain; } @@ -9430,6 +9743,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState()) ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE; + else if (CallAttrs.requiresPreservingZT0()) + ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE; else if (CallAttrs.caller().hasZAState() || CallAttrs.caller().hasZT0State()) ZAMarkerNode = AArch64ISD::INOUT_ZA_USE; @@ -9517,7 +9832,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); Chain = DAG.getNode( ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), + DAG.getTargetConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), TPIDR2ObjAddr); OptimizationRemarkEmitter ORE(&MF.getFunction()); ORE.emit([&]() { @@ -9549,7 +9864,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, SDValue ZTFrameIdx; MachineFrameInfo &MFI = MF.getFrameInfo(); - bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0(); + bool ShouldPreserveZT0 = + !UseNewSMEABILowering && CallAttrs.requiresPreservingZT0(); // If the caller has ZT0 state which will not be preserved by the callee, // spill ZT0 before the call. @@ -9562,7 +9878,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // If caller shares ZT0 but the callee is not shared ZA, we need to stop // PSTATE.ZA before the call if there is no lazy-save active. - bool DisableZA = CallAttrs.requiresDisablingZABeforeCall(); + bool DisableZA = + !UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall(); assert((!DisableZA || !RequiresLazySave) && "Lazy-save should have PSTATE.SM=1 on entry to the function"); @@ -9581,8 +9898,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // using a chain can result in incorrect scheduling. The markers refer to // the position just before the CALLSEQ_START (though occur after as // CALLSEQ_START lacks in-glue). 
- Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other), - {Chain, Chain.getValue(1)}); + Chain = + DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other, MVT::Glue), + {Chain, Chain.getValue(1)}); } } @@ -9663,8 +9981,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, assert((isScalable || Subtarget->isWindowsArm64EC()) && "Indirect arguments should be scalable on most subtargets"); - uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue(); - uint64_t PartSize = StoreSize; + TypeSize StoreSize = VA.getValVT().getStoreSize(); + TypeSize PartSize = StoreSize; unsigned NumParts = 1; if (Outs[i].Flags.isInConsecutiveRegs()) { while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast()) @@ -9675,7 +9993,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext()); Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty); MachineFrameInfo &MFI = MF.getFrameInfo(); - int FI = MFI.CreateStackObject(StoreSize, Alignment, false); + int FI = + MFI.CreateStackObject(StoreSize.getKnownMinValue(), Alignment, false); if (isScalable) { bool IsPred = VA.getValVT() == MVT::aarch64svcount || VA.getValVT().getVectorElementType() == MVT::i1; @@ -9696,16 +10015,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, NumParts--; if (NumParts > 0) { - SDValue BytesIncrement; - if (isScalable) { - BytesIncrement = DAG.getVScale( - DL, Ptr.getValueType(), - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize)); - } else { - BytesIncrement = DAG.getConstant( - APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL, - Ptr.getValueType()); - } + SDValue BytesIncrement = + DAG.getTypeSize(DL, Ptr.getValueType(), PartSize); MPI = MachinePointerInfo(MPI.getAddrSpace()); Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement, SDNodeFlags::NoUnsignedWrap); @@ -9998,6 +10309,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (InGlue.getNode()) Ops.push_back(InGlue); + if (CLI.DeactivationSymbol) + Ops.push_back(DAG.getDeactivationSymbol(CLI.DeactivationSymbol)); + // If we're doing a tall call, use a TC_RETURN here rather than an // actual call instruction. if (IsTailCall) { @@ -10047,7 +10361,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, getSMToggleCondition(CallAttrs)); } - if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()) + if (!UseNewSMEABILowering && + (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())) // Unconditionally resume ZA. Result = DAG.getNode( AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result, @@ -10587,16 +10902,41 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr, const SDLoc &DL, SelectionDAG &DAG) const { EVT PtrVT = getPointerTy(DAG.getDataLayout()); + auto &MF = DAG.getMachineFunction(); + auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); + SDValue Glue; SDValue Chain = DAG.getEntryNode(); SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue); + SMECallAttrs TLSCallAttrs(FuncInfo->getSMEFnAttrs(), {}, SMEAttrs::Normal); + bool RequiresSMChange = TLSCallAttrs.requiresSMChange(); + + auto ChainAndGlue = [](SDValue Chain) -> std::pair<SDValue, SDValue> { + return {Chain, Chain.getValue(1)}; + }; + + if (RequiresSMChange) + std::tie(Chain, Glue) = + ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue, + getSMToggleCondition(TLSCallAttrs))); + unsigned Opcode = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT() ? 
AArch64ISD::TLSDESC_AUTH_CALLSEQ : AArch64ISD::TLSDESC_CALLSEQ; - Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr}); - SDValue Glue = Chain.getValue(1); + SDValue Ops[] = {Chain, SymAddr, Glue}; + std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode( + Opcode, DL, NodeTys, Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back())); + + if (TLSCallAttrs.requiresLazySave()) + std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode( + AArch64ISD::REQUIRES_ZA_SAVE, DL, NodeTys, {Chain, Chain.getValue(1)})); + + if (RequiresSMChange) + std::tie(Chain, Glue) = + ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue, + getSMToggleCondition(TLSCallAttrs))); return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue); } @@ -11505,7 +11845,12 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { } if (LHS.getValueType().isInteger()) { - + if (Subtarget->hasCSSC() && CC == ISD::SETNE && isNullConstant(RHS)) { + SDValue One = DAG.getConstant(1, DL, LHS.getValueType()); + SDValue UMin = DAG.getNode(ISD::UMIN, DL, LHS.getValueType(), LHS, One); + SDValue Res = DAG.getZExtOrTrunc(UMin, DL, VT); + return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res; + } simplifySetCCIntoEq(CC, LHS, RHS, DAG, DL); SDValue CCVal; @@ -13409,8 +13754,8 @@ SDValue ReconstructShuffleWithRuntimeMask(SDValue Op, SelectionDAG &DAG) { return DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, VT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), SourceVec, - MaskSourceVec); + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), + SourceVec, MaskSourceVec); } // Gather data to see if the operation can be modelled as a @@ -14266,14 +14611,16 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask, V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V1Cst); Shuffle = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst, + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), + V1Cst, DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen))); } else { if (IndexLen == 8) { V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V2Cst); Shuffle = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst, + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), + V1Cst, DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen))); } else { // FIXME: We cannot, for the moment, emit a TBL2 instruction because we @@ -14284,8 +14631,8 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask, // IndexLen)); Shuffle = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, - DAG.getConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), V1Cst, - V2Cst, + DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), + V1Cst, V2Cst, DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen))); } } @@ -14453,6 +14800,40 @@ static SDValue tryToConvertShuffleOfTbl2ToTbl4(SDValue Op, Tbl2->getOperand(1), Tbl2->getOperand(2), TBLMask}); } +SDValue +AArch64TargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + assert(VT.isScalableVector() && "Unexpected result type!"); + + bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG; + unsigned UnpackOpcode = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO; + + // Repeatedly unpack Val until the result is of the desired type. 
+ SDValue Val = Op.getOperand(0); + switch (Val.getSimpleValueType().SimpleTy) { + default: + return SDValue(); + case MVT::nxv16i8: + Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv8i16, Val); + if (VT == MVT::nxv8i16) + break; + [[fallthrough]]; + case MVT::nxv8i16: + Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv4i32, Val); + if (VT == MVT::nxv4i32) + break; + [[fallthrough]]; + case MVT::nxv4i32: + Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv2i64, Val); + assert(VT == MVT::nxv2i64 && "Unexpected result type!"); + break; + } + + return Val; +} + // Baseline legalization for ZERO_EXTEND_VECTOR_INREG will blend-in zeros, // but we don't have an appropriate instruction, // so custom-lower it as ZIP1-with-zeros. @@ -14461,6 +14842,10 @@ AArch64TargetLowering::LowerZERO_EXTEND_VECTOR_INREG(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); EVT VT = Op.getValueType(); + + if (VT.isScalableVector()) + return LowerEXTEND_VECTOR_INREG(Op, DAG); + SDValue SrcOp = Op.getOperand(0); EVT SrcVT = SrcOp.getValueType(); assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 && @@ -14570,17 +14955,20 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, } unsigned WhichResult; - if (isZIPMask(ShuffleMask, NumElts, WhichResult)) { + unsigned OperandOrder; + if (isZIPMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2; - return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2); + return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2, + OperandOrder == 0 ? V2 : V1); } if (isUZPMask(ShuffleMask, NumElts, WhichResult)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2; return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2); } - if (isTRNMask(ShuffleMask, NumElts, WhichResult)) { + if (isTRNMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) { unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2; - return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2); + return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2, + OperandOrder == 0 ? 
V2 : V1); } if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) { @@ -16292,9 +16680,9 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { isREVMask(M, EltSize, NumElts, 16) || isEXTMask(M, VT, DummyBool, DummyUnsigned) || isSingletonEXTMask(M, VT, DummyUnsigned) || - isTRNMask(M, NumElts, DummyUnsigned) || + isTRNMask(M, NumElts, DummyUnsigned, DummyUnsigned) || isUZPMask(M, NumElts, DummyUnsigned) || - isZIPMask(M, NumElts, DummyUnsigned) || + isZIPMask(M, NumElts, DummyUnsigned, DummyUnsigned) || isTRN_v_undef_Mask(M, VT, DummyUnsigned) || isUZP_v_undef_Mask(M, VT, DummyUnsigned) || isZIP_v_undef_Mask(M, VT, DummyUnsigned) || @@ -16438,10 +16826,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op, if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0), DAG.getTargetConstant(Cnt, DL, MVT::i32)); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, - DAG.getConstant(Intrinsic::aarch64_neon_ushl, DL, - MVT::i32), - Op.getOperand(0), Op.getOperand(1)); + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getTargetConstant(Intrinsic::aarch64_neon_ushl, DL, MVT::i32), + Op.getOperand(0), Op.getOperand(1)); case ISD::SRA: case ISD::SRL: if (VT.isScalableVector() && @@ -16943,7 +17331,7 @@ SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op, template <unsigned NumVecs> static bool setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL, - AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) { + AArch64TargetLowering::IntrinsicInfo &Info, const CallBase &CI) { Info.opc = ISD::INTRINSIC_VOID; // Retrieve EC from first vector argument. const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType()); @@ -16968,7 +17356,7 @@ setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL, /// MemIntrinsicNodes. The associated MachineMemOperands record the alignment /// specified in the intrinsic calls. 
bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, - const CallInst &I, + const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const { auto &DL = I.getDataLayout(); @@ -18537,7 +18925,7 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd( case MVT::f64: return true; case MVT::bf16: - return VT.isScalableVector() && Subtarget->hasSVEB16B16() && + return VT.isScalableVector() && Subtarget->hasBF16() && Subtarget->isNonStreamingSVEorSME2Available(); default: break; @@ -18720,6 +19108,15 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, return (Index == 0 || Index == ResVT.getVectorMinNumElements()); } +bool AArch64TargetLowering::shouldOptimizeMulOverflowWithZeroHighBits( + LLVMContext &Context, EVT VT) const { + if (getTypeAction(Context, VT) != TypeExpandInteger) + return false; + + EVT LegalTy = EVT::getIntegerVT(Context, VT.getSizeInBits() / 2); + return getTypeAction(Context, LegalTy) == TargetLowering::TypeLegal; +} + /// Turn vector tests of the signbit in the form of: /// xor (sra X, elt_size(X)-1), -1 /// into: @@ -19282,20 +19679,37 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor, return CSNeg; } -static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) { +static bool IsSVECntIntrinsic(SDValue S) { switch(getIntrinsicID(S.getNode())) { default: break; case Intrinsic::aarch64_sve_cntb: - return 8; case Intrinsic::aarch64_sve_cnth: - return 16; case Intrinsic::aarch64_sve_cntw: - return 32; case Intrinsic::aarch64_sve_cntd: - return 64; + return true; + } + return false; +} + +// Returns the maximum (scalable) value that can be returned by an SVE count +// intrinsic. Returns std::nullopt if \p Op is not aarch64_sve_cnt*. +static std::optional<ElementCount> getMaxValueForSVECntIntrinsic(SDValue Op) { + Intrinsic::ID IID = getIntrinsicID(Op.getNode()); + if (IID == Intrinsic::aarch64_sve_cntp) + return Op.getOperand(1).getValueType().getVectorElementCount(); + switch (IID) { + case Intrinsic::aarch64_sve_cntd: + return ElementCount::getScalable(2); + case Intrinsic::aarch64_sve_cntw: + return ElementCount::getScalable(4); + case Intrinsic::aarch64_sve_cnth: + return ElementCount::getScalable(8); + case Intrinsic::aarch64_sve_cntb: + return ElementCount::getScalable(16); + default: + return std::nullopt; } - return {}; } /// Calculates what the pre-extend type is, based on the extension @@ -19939,7 +20353,9 @@ static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG, return Res; EVT VT = N->getValueType(0); - if (VT != MVT::f32 && VT != MVT::f64) + if (VT != MVT::f16 && VT != MVT::f32 && VT != MVT::f64) + return SDValue(); + if (VT == MVT::f16 && !Subtarget->hasFullFP16()) return SDValue(); // Only optimize when the source and destination types have the same width. @@ -20037,7 +20453,7 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG, : Intrinsic::aarch64_neon_vcvtfp2fxu; SDValue FixConv = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy, - DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), + DAG.getTargetConstant(IntrinsicOpcode, DL, MVT::i32), Op->getOperand(0), DAG.getTargetConstant(C, DL, MVT::i32)); // We can handle smaller integers by generating an extra trunc. 
if (IntBits < FloatBits) @@ -21591,9 +22007,8 @@ static SDValue performBuildVectorCombine(SDNode *N, SDValue LowLanesSrcVec = Elt0->getOperand(0)->getOperand(0); if (LowLanesSrcVec.getValueType() == MVT::v2f64) { SDValue HighLanes; - if (Elt2->getOpcode() == ISD::UNDEF && - Elt3->getOpcode() == ISD::UNDEF) { - HighLanes = DAG.getUNDEF(MVT::v2f32); + if (Elt2->isUndef() && Elt3->isUndef()) { + HighLanes = DAG.getPOISON(MVT::v2f32); } else if (Elt2->getOpcode() == ISD::FP_ROUND && Elt3->getOpcode() == ISD::FP_ROUND && isa<ConstantSDNode>(Elt2->getOperand(1)) && @@ -22296,6 +22711,69 @@ static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift); } +// Attempt to combine the following patterns: +// SUB x, (CSET LO, (CMP a, b)) -> SBC x, 0, (CMP a, b) +// SUB (SUB x, y), (CSET LO, (CMP a, b)) -> SBC x, y, (CMP a, b) +// The CSET may be preceded by a ZEXT. +static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() != ISD::SUB) + return SDValue(); + + EVT VT = N->getValueType(0); + if (VT != MVT::i32 && VT != MVT::i64) + return SDValue(); + + SDValue N1 = N->getOperand(1); + if (N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) + N1 = N1.getOperand(0); + if (!N1.hasOneUse() || getCSETCondCode(N1) != AArch64CC::LO) + return SDValue(); + + SDValue Flags = N1.getOperand(3); + if (Flags.getOpcode() != AArch64ISD::SUBS) + return SDValue(); + + SDLoc DL(N); + SDValue N0 = N->getOperand(0); + if (N0->getOpcode() == ISD::SUB) + return DAG.getNode(AArch64ISD::SBC, DL, VT, N0.getOperand(0), + N0.getOperand(1), Flags); + return DAG.getNode(AArch64ISD::SBC, DL, VT, N0, DAG.getConstant(0, DL, VT), + Flags); +} + +// add(trunc(ashr(A, C)), trunc(lshr(A, BW-1))), with C >= BW +// -> +// X = trunc(ashr(A, C)); add(x, lshr(X, BW-1) +// The original converts into ashr+lshr+xtn+xtn+add. The second becomes +// ashr+xtn+usra. The first form has less total latency due to more parallelism, +// but more micro-ops and seems to be slower in practice. +static SDValue performAddTruncShiftCombine(SDNode *N, SelectionDAG &DAG) { + using namespace llvm::SDPatternMatch; + EVT VT = N->getValueType(0); + if (VT != MVT::v2i32 && VT != MVT::v4i16 && VT != MVT::v8i8) + return SDValue(); + + SDValue AShr, LShr; + if (!sd_match(N, m_Add(m_Trunc(m_Value(AShr)), m_Trunc(m_Value(LShr))))) + return SDValue(); + if (AShr.getOpcode() != AArch64ISD::VASHR) + std::swap(AShr, LShr); + if (AShr.getOpcode() != AArch64ISD::VASHR || + LShr.getOpcode() != AArch64ISD::VLSHR || + AShr.getOperand(0) != LShr.getOperand(0) || + AShr.getConstantOperandVal(1) < VT.getScalarSizeInBits() || + LShr.getConstantOperandVal(1) != VT.getScalarSizeInBits() * 2 - 1) + return SDValue(); + + SDLoc DL(N); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, AShr); + SDValue Shift = DAG.getNode( + AArch64ISD::VLSHR, DL, VT, Trunc, + DAG.getTargetConstant(VT.getScalarSizeInBits() - 1, DL, MVT::i32)); + return DAG.getNode(ISD::ADD, DL, VT, Trunc, Shift); +} + static SDValue performAddSubCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { // Try to change sum of two reductions. 
@@ -22317,6 +22795,10 @@ static SDValue performAddSubCombine(SDNode *N,
     return Val;
   if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG))
     return Val;
+  if (SDValue Val = performSubWithBorrowCombine(N, DCI.DAG))
+    return Val;
+  if (SDValue Val = performAddTruncShiftCombine(N, DCI.DAG))
+    return Val;
   if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
     return Val;
@@ -22968,11 +23450,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
     return DAG.getNode(ISD::OR, SDLoc(N), N->getValueType(0), N->getOperand(2),
                        N->getOperand(3));
   case Intrinsic::aarch64_sve_sabd_u:
-    return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
-                       N->getOperand(2), N->getOperand(3));
+    if (SDValue V = convertMergedOpToPredOp(N, ISD::ABDS, DAG, true))
+      return V;
+    return DAG.getNode(AArch64ISD::ABDS_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_uabd_u:
-    return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
-                       N->getOperand(2), N->getOperand(3));
+    if (SDValue V = convertMergedOpToPredOp(N, ISD::ABDU, DAG, true))
+      return V;
+    return DAG.getNode(AArch64ISD::ABDU_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_sdiv_u:
     return DAG.getNode(AArch64ISD::SDIV_PRED, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -23895,7 +24381,7 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   // uzp1(x, undef) -> concat(truncate(x), undef)
-  if (Op1.getOpcode() == ISD::UNDEF) {
+  if (Op1.isUndef()) {
     EVT BCVT = MVT::Other, HalfVT = MVT::Other;
     switch (ResVT.getSimpleVT().SimpleTy) {
     default:
@@ -26038,7 +26524,7 @@ static SDValue performCSELCombine(SDNode *N,
   // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
   // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
   if (SDValue Folded = foldCSELofCTTZ(N, DAG))
-      return Folded;
+    return Folded;
 
   // CSEL a, b, cc, SUBS(x, y) -> CSEL a, b, swapped(cc), SUBS(y, x)
   // if SUB(y, x) already exists and we can produce a swapped predicate for cc.
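For reference, a scalar sketch of the per-lane absolute-difference operation that ISD::ABDS/ABDU and the predicated ABDS_PRED/ABDU_PRED nodes in the intrinsic combine above compute. The helper names are illustrative, and the detour through uint32_t just keeps the signed case free of overflow.

#include <cstdint>

int32_t abds(int32_t a, int32_t b) {
  return a > b ? int32_t(uint32_t(a) - uint32_t(b))
               : int32_t(uint32_t(b) - uint32_t(a));
}

uint32_t abdu(uint32_t a, uint32_t b) { return a > b ? a - b : b - a; }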
@@ -26063,29 +26549,6 @@ static SDValue performCSELCombine(SDNode *N,
     }
   }
 
-  // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
-  // use overflow flags, to avoid the comparison with zero. In case of success,
-  // this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
-  // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
-  // nodes with their SUBS equivalent as is already done for other flag-setting
-  // operators, in which case doing the replacement here becomes redundant.
-  if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
-      isNullConstant(Cond.getOperand(1))) {
-    SDValue Sub = Cond.getOperand(0);
-    AArch64CC::CondCode CC =
-        static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
-    if (Sub.getOpcode() == ISD::SUB &&
-        (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
-         CC == AArch64CC::PL)) {
-      SDLoc DL(N);
-      SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
-                                 Sub.getOperand(0), Sub.getOperand(1));
-      DCI.CombineTo(Sub.getNode(), Subs);
-      DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
-      return SDValue(N, 0);
-    }
-  }
-
   // CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
   if (SDValue CondLast = foldCSELofLASTB(N, DAG))
     return CondLast;
@@ -26364,8 +26827,7 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     SDValue L1 = LHS->getOperand(1);
     SDValue L2 = LHS->getOperand(2);
 
-    if (L0.getOpcode() == ISD::UNDEF && isNullConstant(L2) &&
-        isSignExtInReg(L1)) {
+    if (L0.isUndef() && isNullConstant(L2) && isSignExtInReg(L1)) {
       SDLoc DL(N);
       SDValue Shl = L1.getOperand(0);
       SDValue NewLHS = DAG.getNode(ISD::INSERT_SUBVECTOR, DL,
@@ -26629,22 +27091,25 @@ static SDValue performSelectCombine(SDNode *N,
   assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) &&
          "Scalar-SETCC feeding SELECT has unexpected result type!");
 
-  // If NumMaskElts == 0, the comparison is larger than select result. The
-  // largest real NEON comparison is 64-bits per lane, which means the result is
-  // at most 32-bits and an illegal vector. Just bail out for now.
-  EVT SrcVT = N0.getOperand(0).getValueType();
-
   // Don't try to do this optimization when the setcc itself has i1 operands.
   // There are no legal vectors of i1, so this would be pointless. v1f16 is
   // ruled out to prevent the creation of setcc that need to be scalarized.
+  EVT SrcVT = N0.getOperand(0).getValueType();
   if (SrcVT == MVT::i1 ||
       (SrcVT.isFloatingPoint() && SrcVT.getSizeInBits() <= 16))
     return SDValue();
 
-  int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
+  // If NumMaskElts == 0, the comparison is larger than select result. The
+  // largest real NEON comparison is 64-bits per lane, which means the result is
+  // at most 32-bits and an illegal vector. Just bail out for now.
+  unsigned NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
   if (!ResVT.isVector() || NumMaskElts == 0)
     return SDValue();
 
+  // Avoid creating vectors with excessive VFs before legalization.
+  if (DCI.isBeforeLegalize() && NumMaskElts != ResVT.getVectorNumElements())
+    return SDValue();
+
   SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts);
   EVT CCVT = SrcVT.changeVectorElementTypeToInteger();
@@ -27293,8 +27758,8 @@ static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG,
   // ...and remap the intrinsic `aarch64_sve_prf<T>_gather_scalar_offset` to
   // `aarch64_sve_prfb_gather_uxtw_index`.
   SDLoc DL(N);
-  Ops[1] = DAG.getConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, DL,
-                           MVT::i64);
+  Ops[1] = DAG.getTargetConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index,
+                                 DL, MVT::i64);
 
   return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
 }
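A small sketch of the mask-sizing arithmetic used by the select combine above; the numbers are illustrative and simply restate the NumMaskElts computation and the cases it now rejects.

#include <cstdio>

// One mask element per compare-source-sized chunk of the select result.
unsigned numMaskElts(unsigned ResultBits, unsigned CompareSrcBits) {
  return ResultBits / CompareSrcBits; // 0 means the compare is wider
}

int main() {
  std::printf("%u\n", numMaskElts(128, 64)); // 2: mask for an f64 setcc
  std::printf("%u\n", numMaskElts(128, 32)); // 4: mask for an f32 setcc
  std::printf("%u\n", numMaskElts(32, 64));  // 0: bail out, compare too wide
  return 0;
}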
@@ -28567,7 +29032,8 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
   if ((Index != 0) && (Index != ResEC.getKnownMinValue()))
     return;
 
-  unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
+  unsigned Opcode = (Index == 0) ? (unsigned)ISD::ANY_EXTEND_VECTOR_INREG
+                                 : (unsigned)AArch64ISD::UUNPKHI;
   EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
   SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
@@ -29294,12 +29760,26 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
        AI->getOperation() == AtomicRMWInst::FMinimum))
     return AtomicExpansionKind::None;
 
-  // Nand is not supported in LSE.
   // Leave 128 bits to LLSC or CmpXChg.
-  if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128 &&
-      !AI->isFloatingPointOperation()) {
-    if (Subtarget->hasLSE())
-      return AtomicExpansionKind::None;
+  if (Size < 128 && !AI->isFloatingPointOperation()) {
+    if (Subtarget->hasLSE()) {
+      // Nand is not supported in LSE.
+      switch (AI->getOperation()) {
+      case AtomicRMWInst::Xchg:
+      case AtomicRMWInst::Add:
+      case AtomicRMWInst::Sub:
+      case AtomicRMWInst::And:
+      case AtomicRMWInst::Or:
+      case AtomicRMWInst::Xor:
+      case AtomicRMWInst::Max:
+      case AtomicRMWInst::Min:
+      case AtomicRMWInst::UMax:
+      case AtomicRMWInst::UMin:
+        return AtomicExpansionKind::None;
+      default:
+        break;
+      }
+    }
     if (Subtarget->outlineAtomics()) {
       // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far.
       // Don't outline them unless
@@ -29307,11 +29787,16 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
       // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
      // (2) low level libgcc and compiler-rt support implemented by:
      //     min/max outline atomics helpers
-      if (AI->getOperation() != AtomicRMWInst::Min &&
-          AI->getOperation() != AtomicRMWInst::Max &&
-          AI->getOperation() != AtomicRMWInst::UMin &&
-          AI->getOperation() != AtomicRMWInst::UMax) {
+      switch (AI->getOperation()) {
+      case AtomicRMWInst::Xchg:
+      case AtomicRMWInst::Add:
+      case AtomicRMWInst::Sub:
+      case AtomicRMWInst::And:
+      case AtomicRMWInst::Or:
+      case AtomicRMWInst::Xor:
        return AtomicExpansionKind::None;
+      default:
+        break;
      }
    }
  }
@@ -30118,6 +30603,43 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
                       Store->isTruncatingStore());
 }
 
+SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  auto *Store = cast<MaskedStoreSDNode>(Op);
+  EVT VT = Store->getValue().getValueType();
+  if (VT.isFixedLengthVector())
+    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+
+  if (!Store->isCompressingStore())
+    return SDValue();
+
+  EVT MaskVT = Store->getMask().getValueType();
+  EVT MaskExtVT = getPromotedVTForPredicate(MaskVT);
+  EVT MaskReduceVT = MaskExtVT.getScalarType();
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+
+  SDValue MaskExt =
+      DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask());
+  SDValue CntActive =
+      DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt);
+  if (MaskReduceVT != MVT::i64)
+    CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive);
+
+  SDValue CompressedValue =
+      DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
+                  Store->getMask(), DAG.getPOISON(VT));
+  SDValue CompressedMask =
+      DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);
+
+  return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
+                            Store->getBasePtr(), Store->getOffset(),
+                            CompressedMask, Store->getMemoryVT(),
+                            Store->getMemOperand(), Store->getAddressingMode(),
+                            Store->isTruncatingStore(),
+                            /*isCompressing=*/false);
+}
+
 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
     SDValue Op, SelectionDAG &DAG) const {
   auto *Store = cast<MaskedStoreSDNode>(Op);
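For reference, a scalar sketch (not LLVM code) of the memory effect of a compressing masked store, which the LowerMSTORE path above reproduces by counting the active lanes, compressing the value with VECTOR_COMPRESS, and storing it under a GET_ACTIVE_LANE_MASK covering the first CntActive lanes.

#include <cstddef>
#include <cstdint>

// Active elements are packed together and written contiguously from the base
// pointer; only the first popcount(mask) slots are touched.
void compress_store(const int32_t *value, const bool *mask, size_t lanes,
                    int32_t *base) {
  size_t out = 0;
  for (size_t i = 0; i < lanes; ++i)
    if (mask[i])
      base[out++] = value[i];
}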
@@ -30132,7 +30654,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
   return DAG.getMaskedStore(
       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
       Mask, Store->getMemoryVT(), Store->getMemOperand(),
-      Store->getAddressingMode(), Store->isTruncatingStore());
+      Store->getAddressingMode(), Store->isTruncatingStore(),
+      Store->isCompressingStore());
 }
 
 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
@@ -31159,10 +31682,10 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
 
   SDValue Shuffle;
   if (IsSingleOp)
-    Shuffle =
-        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
-                    DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
-                    Op1, SVEMask);
+    Shuffle = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32), Op1,
+        SVEMask);
   else if (Subtarget.hasSVE2()) {
     if (!MinMaxEqual) {
       unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt;
@@ -31181,10 +31704,10 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
       SVEMask = convertToScalableVector(
           DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask);
     }
-    Shuffle =
-        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
-                    DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
-                    Op1, Op2, SVEMask);
+    Shuffle = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32), Op1,
+        Op2, SVEMask);
   }
   Shuffle = convertFromScalableVector(DAG, VT, Shuffle);
   return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
@@ -31266,15 +31789,23 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
   }
 
   unsigned WhichResult;
-  if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
+  unsigned OperandOrder;
+  if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
+                OperandOrder) &&
       WhichResult == 0)
     return convertFromScalableVector(
-        DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op2));
+        DAG, VT,
+        DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT,
+                    OperandOrder == 0 ? Op1 : Op2,
+                    OperandOrder == 0 ? Op2 : Op1));
 
-  if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
+  if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
+                OperandOrder)) {
     unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
-    return convertFromScalableVector(
-        DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
+    SDValue TRN =
+        DAG.getNode(Opc, DL, ContainerVT, OperandOrder == 0 ? Op1 : Op2,
+                    OperandOrder == 0 ? Op2 : Op1);
+    return convertFromScalableVector(DAG, VT, TRN);
   }
 
   if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult == 0)
@@ -31314,10 +31845,14 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
     return convertFromScalableVector(DAG, VT, Op);
   }
 
-  if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
+  if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
+                OperandOrder) &&
      WhichResult != 0)
    return convertFromScalableVector(
-        DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op2));
+        DAG, VT,
+        DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT,
+                    OperandOrder == 0 ? Op1 : Op2,
+                    OperandOrder == 0 ? Op2 : Op1));
 
   if (isUZPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
     unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
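A worked example (assuming 4-element vectors; sketch only) of why isZIPMask/isTRNMask now also report an operand order: a swapped interleave can still use a single ZIP1 once the two inputs are exchanged, which is exactly what the OperandOrder selection above does.

#include <array>

// shufflevector mask <0, 4, 1, 5> on (a, b)  ==  ZIP1 a, b -> { a0, b0, a1, b1 }
constexpr std::array<int, 4> Zip1AB = {0, 4, 1, 5};

// shufflevector mask <4, 0, 5, 1> on (a, b)  ==  ZIP1 b, a -> { b0, a0, b1, a1 }
constexpr std::array<int, 4> Zip1BA = {4, 0, 5, 1};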
@@ -31344,8 +31879,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
     unsigned SegmentElts = VT.getVectorNumElements() / Segments;
     if (std::optional<unsigned> Lane =
             isDUPQMask(ShuffleMask, Segments, SegmentElts)) {
-      SDValue IID =
-          DAG.getConstant(Intrinsic::aarch64_sve_dup_laneq, DL, MVT::i64);
+      SDValue IID = DAG.getTargetConstant(Intrinsic::aarch64_sve_dup_laneq,
+                                          DL, MVT::i64);
       return convertFromScalableVector(
           DAG, VT,
           DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
@@ -31492,22 +32027,24 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
     return false;
   }
   case ISD::INTRINSIC_WO_CHAIN: {
-    if (auto ElementSize = IsSVECntIntrinsic(Op)) {
-      unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
-      if (!MaxSVEVectorSizeInBits)
-        MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
-      unsigned MaxElements = MaxSVEVectorSizeInBits / *ElementSize;
-      // The SVE count intrinsics don't support the multiplier immediate so we
-      // don't have to account for that here. The value returned may be slightly
-      // over the true required bits, as this is based on the "ALL" pattern. The
-      // other patterns are also exposed by these intrinsics, but they all
-      // return a value that's strictly less than "ALL".
-      unsigned RequiredBits = llvm::bit_width(MaxElements);
-      unsigned BitWidth = Known.Zero.getBitWidth();
-      if (RequiredBits < BitWidth)
-        Known.Zero.setHighBits(BitWidth - RequiredBits);
+    std::optional<ElementCount> MaxCount = getMaxValueForSVECntIntrinsic(Op);
+    if (!MaxCount)
       return false;
-    }
+    unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
+    if (!MaxSVEVectorSizeInBits)
+      MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
+    unsigned VscaleMax = MaxSVEVectorSizeInBits / 128;
+    unsigned MaxValue = MaxCount->getKnownMinValue() * VscaleMax;
+    // The SVE count intrinsics don't support the multiplier immediate so we
+    // don't have to account for that here. The value returned may be slightly
+    // over the true required bits, as this is based on the "ALL" pattern. The
+    // other patterns are also exposed by these intrinsics, but they all
+    // return a value that's strictly less than "ALL".
+    unsigned RequiredBits = llvm::bit_width(MaxValue);
+    unsigned BitWidth = Known.Zero.getBitWidth();
+    if (RequiredBits < BitWidth)
+      Known.Zero.setHighBits(BitWidth - RequiredBits);
+    return false;
   }
   }
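A standalone sketch of the known-bits arithmetic above, assuming the 2048-bit architectural ceiling applies when no tighter SVE vector-length bound is known: vscale is then at most 2048/128 = 16, so CNTB (16 elements per 128-bit granule) can return at most 256, which needs 9 bits, and the remaining high bits of an i64 result are known zero.

#include <bit>
#include <cstdio>

int main() {
  const unsigned MaxSVEBits = 2048;            // architectural maximum VL
  const unsigned VscaleMax = MaxSVEBits / 128; // 16
  const unsigned MinElts[] = {2, 4, 8, 16};    // cntd, cntw, cnth, cntb
  for (unsigned E : MinElts) {
    unsigned MaxValue = E * VscaleMax;
    unsigned RequiredBits = std::bit_width(MaxValue);
    std::printf("max %3u -> %u low bits may be set, %u high bits known zero\n",
                MaxValue, RequiredBits, 64 - RequiredBits);
  }
  return 0;
}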
