Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp  |  1079
1 file changed, 808 insertions, 271 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 40e6400..30eb190 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16,11 +16,11 @@
#include "AArch64MachineFunctionInfo.h"
#include "AArch64PerfectShuffle.h"
#include "AArch64RegisterInfo.h"
+#include "AArch64SMEAttributes.h"
#include "AArch64Subtarget.h"
#include "AArch64TargetMachine.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "Utils/AArch64BaseInfo.h"
-#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
@@ -387,7 +387,7 @@ extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) {
AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
const AArch64Subtarget &STI)
- : TargetLowering(TM), Subtarget(&STI) {
+ : TargetLowering(TM, STI), Subtarget(&STI) {
// AArch64 doesn't have comparisons which set GPRs or setcc instructions, so
// we have to make something up. Arbitrarily, choose ZeroOrOne.
setBooleanContents(ZeroOrOneBooleanContent);
@@ -445,6 +445,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass);
+ // Add sve predicate as counter type
+ addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass);
+
// Add legal sve data types
addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass);
addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass);
@@ -473,15 +476,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
}
- if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
- addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass);
- setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1);
- setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1);
-
- setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom);
- setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand);
- }
-
// Compute derived properties from the register classes
computeRegisterProperties(Subtarget->getRegisterInfo());
@@ -536,7 +530,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FREM, MVT::f32, Expand);
setOperationAction(ISD::FREM, MVT::f64, Expand);
- setOperationAction(ISD::FREM, MVT::f80, Expand);
setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand);
@@ -1433,12 +1426,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BITCAST, MVT::v2i16, Custom);
setOperationAction(ISD::BITCAST, MVT::v4i8, Custom);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
- setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
// ADDP custom lowering
for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
@@ -1518,6 +1523,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
+
+ for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64})
+ setOperationAction(ISD::FMA, VT, Custom);
}
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -1585,6 +1593,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::AVGCEILS, VT, Custom);
setOperationAction(ISD::AVGCEILU, VT, Custom);
+ setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, VT, Custom);
+ setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom);
+ setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom);
+
if (!Subtarget->isLittleEndian())
setOperationAction(ISD::BITCAST, VT, Custom);
@@ -1609,6 +1621,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);
+ // Promote predicate as counter load/stores to standard predicates.
+ setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1);
+ setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1);
+
+ // Predicate as counter legalization actions.
+ setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom);
+ setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand);
+
for (auto VT :
{MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
@@ -1769,17 +1789,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
+ }
- if (Subtarget->hasSVEB16B16() &&
- Subtarget->isNonStreamingSVEorSME2Available()) {
- setOperationAction(ISD::FADD, VT, Legal);
+ if (Subtarget->hasSVEB16B16() &&
+ Subtarget->isNonStreamingSVEorSME2Available()) {
+ // Note: Use SVE for bfloat16 operations when +sve-b16b16 is available.
+ for (auto VT : {MVT::v4bf16, MVT::v8bf16, MVT::nxv2bf16, MVT::nxv4bf16,
+ MVT::nxv8bf16}) {
+ setOperationAction(ISD::FADD, VT, Custom);
setOperationAction(ISD::FMA, VT, Custom);
setOperationAction(ISD::FMAXIMUM, VT, Custom);
setOperationAction(ISD::FMAXNUM, VT, Custom);
setOperationAction(ISD::FMINIMUM, VT, Custom);
setOperationAction(ISD::FMINNUM, VT, Custom);
- setOperationAction(ISD::FMUL, VT, Legal);
- setOperationAction(ISD::FSUB, VT, Legal);
+ setOperationAction(ISD::FMUL, VT, Custom);
+ setOperationAction(ISD::FSUB, VT, Custom);
}
}
@@ -1795,22 +1819,37 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (!Subtarget->hasSVEB16B16() ||
!Subtarget->isNonStreamingSVEorSME2Available()) {
- for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
- ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
- setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
- setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
- setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
+ for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+ MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
+ setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
+ setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
+ setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
+ setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
+ setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
+ setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
+ setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
+
+ if (VT != MVT::nxv2bf16 && Subtarget->hasBF16())
+ setOperationAction(ISD::FMUL, VT, Custom);
+ else
+ setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
}
+
+ if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+ setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
- // NEON doesn't support integer divides, but SVE does
+ // A number of operations like MULH and integer divides are not supported by
+ // NEON but are available in SVE.
for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
setOperationAction(ISD::SDIV, VT, Custom);
setOperationAction(ISD::UDIV, VT, Custom);
+ setOperationAction(ISD::MULHS, VT, Custom);
+ setOperationAction(ISD::MULHU, VT, Custom);
}
// NEON doesn't support 64-bit vector integer muls, but SVE does.
@@ -1847,10 +1886,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CTLZ, MVT::v1i64, Custom);
setOperationAction(ISD::CTLZ, MVT::v2i64, Custom);
setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
- setOperationAction(ISD::MULHS, MVT::v1i64, Custom);
- setOperationAction(ISD::MULHS, MVT::v2i64, Custom);
- setOperationAction(ISD::MULHU, MVT::v1i64, Custom);
- setOperationAction(ISD::MULHU, MVT::v2i64, Custom);
setOperationAction(ISD::SMAX, MVT::v1i64, Custom);
setOperationAction(ISD::SMAX, MVT::v2i64, Custom);
setOperationAction(ISD::SMIN, MVT::v1i64, Custom);
@@ -1872,8 +1907,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
- setOperationAction(ISD::MULHS, VT, Custom);
- setOperationAction(ISD::MULHU, VT, Custom);
}
// Use SVE for vectors with more than 2 elements.
@@ -1916,6 +1949,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
}
+
+ // Handle floating-point partial reduction
+ if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
+ MVT::nxv8f16, Legal);
+ // We can use SVE2p1 fdot to emulate the fixed-length variant.
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::v4f32,
+ MVT::v8f16, Custom);
+ }
}
// Handle non-aliasing elements mask
@@ -1951,10 +1993,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
// We can lower types that have <vscale x {2|4}> elements to compact.
- for (auto VT :
- {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
- MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+ for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
+ MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
+ MVT::nxv4i32, MVT::nxv4f32}) {
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+ // Use a custom lowering for masked stores that could be a supported
+ // compressing store. Note: These types still use the normal (Legal)
+ // lowering for non-compressing masked stores.
+ setOperationAction(ISD::MSTORE, VT, Custom);
+ }
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
@@ -2283,6 +2330,11 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
}
+ if (Subtarget->hasSVE2p1() && VT.getVectorElementType() == MVT::f32) {
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT,
+ MVT::getVectorVT(MVT::f16, NumElts * 2), Custom);
+ }
+
// Lower fixed length vector operations to scalable equivalents.
setOperationAction(ISD::ABDS, VT, Default);
setOperationAction(ISD::ABDU, VT, Default);
@@ -2542,7 +2594,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant(
return false;
// Exit early if we demand all bits.
- if (DemandedBits.popcount() == Size)
+ if (DemandedBits.isAllOnes())
return false;
unsigned NewOpc;
@@ -3858,22 +3910,30 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
/// \param MustBeFirst Set to true if this subtree needs to be negated and we
/// cannot do the negation naturally. We are required to
/// emit the subtree first in this case.
+/// \param PreferFirst Set to true if processing this subtree first may
+/// result in more efficient code.
/// \param WillNegate Is true if are called when the result of this
/// subexpression must be negated. This happens when the
/// outer expression is an OR. We can use this fact to know
/// that we have a double negation (or (or ...) ...) that
/// can be implemented for free.
-static bool canEmitConjunction(const SDValue Val, bool &CanNegate,
- bool &MustBeFirst, bool WillNegate,
+static bool canEmitConjunction(SelectionDAG &DAG, const SDValue Val,
+ bool &CanNegate, bool &MustBeFirst,
+ bool &PreferFirst, bool WillNegate,
unsigned Depth = 0) {
if (!Val.hasOneUse())
return false;
unsigned Opcode = Val->getOpcode();
if (Opcode == ISD::SETCC) {
- if (Val->getOperand(0).getValueType() == MVT::f128)
+ EVT VT = Val->getOperand(0).getValueType();
+ if (VT == MVT::f128)
return false;
CanNegate = true;
MustBeFirst = false;
+ // Designate this operation as a preferred first operation if the result
+ // of a SUB operation can be reused.
+ PreferFirst = DAG.doesNodeExist(ISD::SUB, DAG.getVTList(VT),
+ {Val->getOperand(0), Val->getOperand(1)});
return true;
}
// Protect against exponential runtime and stack overflow.
@@ -3885,11 +3945,15 @@ static bool canEmitConjunction(const SDValue Val, bool &CanNegate,
SDValue O1 = Val->getOperand(1);
bool CanNegateL;
bool MustBeFirstL;
- if (!canEmitConjunction(O0, CanNegateL, MustBeFirstL, IsOR, Depth+1))
+ bool PreferFirstL;
+ if (!canEmitConjunction(DAG, O0, CanNegateL, MustBeFirstL, PreferFirstL,
+ IsOR, Depth + 1))
return false;
bool CanNegateR;
bool MustBeFirstR;
- if (!canEmitConjunction(O1, CanNegateR, MustBeFirstR, IsOR, Depth+1))
+ bool PreferFirstR;
+ if (!canEmitConjunction(DAG, O1, CanNegateR, MustBeFirstR, PreferFirstR,
+ IsOR, Depth + 1))
return false;
if (MustBeFirstL && MustBeFirstR)
@@ -3912,6 +3976,7 @@ static bool canEmitConjunction(const SDValue Val, bool &CanNegate,
CanNegate = false;
MustBeFirst = MustBeFirstL || MustBeFirstR;
}
+ PreferFirst = PreferFirstL || PreferFirstR;
return true;
}
return false;
@@ -3973,19 +4038,25 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
SDValue LHS = Val->getOperand(0);
bool CanNegateL;
bool MustBeFirstL;
- bool ValidL = canEmitConjunction(LHS, CanNegateL, MustBeFirstL, IsOR);
+ bool PreferFirstL;
+ bool ValidL = canEmitConjunction(DAG, LHS, CanNegateL, MustBeFirstL,
+ PreferFirstL, IsOR);
assert(ValidL && "Valid conjunction/disjunction tree");
(void)ValidL;
SDValue RHS = Val->getOperand(1);
bool CanNegateR;
bool MustBeFirstR;
- bool ValidR = canEmitConjunction(RHS, CanNegateR, MustBeFirstR, IsOR);
+ bool PreferFirstR;
+ bool ValidR = canEmitConjunction(DAG, RHS, CanNegateR, MustBeFirstR,
+ PreferFirstR, IsOR);
assert(ValidR && "Valid conjunction/disjunction tree");
(void)ValidR;
- // Swap sub-tree that must come first to the right side.
- if (MustBeFirstL) {
+ bool ShouldFirstL = PreferFirstL && !PreferFirstR && !MustBeFirstR;
+
+ // Swap sub-tree that must or should come first to the right side.
+ if (MustBeFirstL || ShouldFirstL) {
assert(!MustBeFirstR && "Valid conjunction/disjunction tree");
std::swap(LHS, RHS);
std::swap(CanNegateL, CanNegateR);
@@ -4041,7 +4112,9 @@ static SDValue emitConjunction(SelectionDAG &DAG, SDValue Val,
AArch64CC::CondCode &OutCC) {
bool DummyCanNegate;
bool DummyMustBeFirst;
- if (!canEmitConjunction(Val, DummyCanNegate, DummyMustBeFirst, false))
+ bool DummyPreferFirst;
+ if (!canEmitConjunction(DAG, Val, DummyCanNegate, DummyMustBeFirst,
+ DummyPreferFirst, false))
return SDValue();
return emitConjunctionRec(DAG, Val, OutCC, false, SDValue(), AArch64CC::AL);
@@ -4487,6 +4560,26 @@ static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG,
return DAG.getMergeValues({Sum, OutFlag}, DL);
}
+static SDValue lowerIntNeonIntrinsic(SDValue Op, unsigned Opcode,
+ SelectionDAG &DAG) {
+ SDLoc DL(Op);
+ auto getFloatVT = [](EVT VT) {
+ assert((VT == MVT::i32 || VT == MVT::i64) && "Unexpected VT");
+ return VT == MVT::i32 ? MVT::f32 : MVT::f64;
+ };
+ auto bitcastToFloat = [&](SDValue Val) {
+ return DAG.getBitcast(getFloatVT(Val.getValueType()), Val);
+ };
+ SmallVector<SDValue, 2> NewOps;
+ NewOps.reserve(Op.getNumOperands() - 1);
+
+ for (unsigned I = 1, E = Op.getNumOperands(); I < E; ++I)
+ NewOps.push_back(bitcastToFloat(Op.getOperand(I)));
+ EVT OrigVT = Op.getValueType();
+ SDValue OpNode = DAG.getNode(Opcode, DL, getFloatVT(OrigVT), NewOps);
+ return DAG.getBitcast(OrigVT, OpNode);
+}
+
static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
// Let legalize expand this if it isn't a legal type yet.
if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType()))
@@ -5544,9 +5637,10 @@ SDValue AArch64TargetLowering::LowerGET_ROUNDING(SDValue Op,
SDLoc DL(Op);
SDValue Chain = Op.getOperand(0);
- SDValue FPCR_64 = DAG.getNode(
- ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other},
- {Chain, DAG.getConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)});
+ SDValue FPCR_64 =
+ DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other},
+ {Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL,
+ MVT::i64)});
Chain = FPCR_64.getValue(1);
SDValue FPCR_32 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPCR_64);
SDValue FltRounds = DAG.getNode(ISD::ADD, DL, MVT::i32, FPCR_32,
@@ -5632,7 +5726,8 @@ SDValue AArch64TargetLowering::LowerSET_FPMODE(SDValue Op,
// Set new value of FPCR.
SDValue Ops2[] = {
- Chain, DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), FPCR};
+ Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
+ FPCR};
return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
}
@@ -5655,9 +5750,9 @@ SDValue AArch64TargetLowering::LowerRESET_FPMODE(SDValue Op,
DAG.getConstant(AArch64::ReservedFPControlBits, DL, MVT::i64));
// Set new value of FPCR.
- SDValue Ops2[] = {Chain,
- DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
- FPSCRMasked};
+ SDValue Ops2[] = {
+ Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
+ FPSCRMasked};
return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
}
@@ -5735,8 +5830,10 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
if (VT.is64BitVector()) {
if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
isNullConstant(N0.getOperand(1)) &&
+ N0.getOperand(0).getValueType().is128BitVector() &&
N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
- isNullConstant(N1.getOperand(1))) {
+ isNullConstant(N1.getOperand(1)) &&
+ N1.getOperand(0).getValueType().is128BitVector()) {
N0 = N0.getOperand(0);
N1 = N1.getOperand(0);
VT = N0.getValueType();
@@ -6329,26 +6426,46 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Op.getOperand(1).getValueType(),
Op.getOperand(1), Op.getOperand(2)));
return SDValue();
+ case Intrinsic::aarch64_neon_sqrshl:
+ if (Op.getValueType().isVector())
+ return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::SQRSHL, DAG);
+ case Intrinsic::aarch64_neon_sqshl:
+ if (Op.getValueType().isVector())
+ return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSHL, DAG);
+ case Intrinsic::aarch64_neon_uqrshl:
+ if (Op.getValueType().isVector())
+ return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::UQRSHL, DAG);
+ case Intrinsic::aarch64_neon_uqshl:
+ if (Op.getValueType().isVector())
+ return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSHL, DAG);
case Intrinsic::aarch64_neon_sqadd:
if (Op.getValueType().isVector())
return DAG.getNode(ISD::SADDSAT, DL, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
- return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::SQADD, DAG);
+
case Intrinsic::aarch64_neon_sqsub:
if (Op.getValueType().isVector())
return DAG.getNode(ISD::SSUBSAT, DL, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
- return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::SQSUB, DAG);
+
case Intrinsic::aarch64_neon_uqadd:
if (Op.getValueType().isVector())
return DAG.getNode(ISD::UADDSAT, DL, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
- return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::UQADD, DAG);
case Intrinsic::aarch64_neon_uqsub:
if (Op.getValueType().isVector())
return DAG.getNode(ISD::USUBSAT, DL, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
- return SDValue();
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::UQSUB, DAG);
+ case Intrinsic::aarch64_neon_sqdmulls_scalar:
+ return lowerIntNeonIntrinsic(Op, AArch64ISD::SQDMULL, DAG);
case Intrinsic::aarch64_sve_whilelt:
return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/true,
/*IsEqual=*/false);
@@ -6382,9 +6499,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_lastb:
return DAG.getNode(AArch64ISD::LASTB, DL, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2));
- case Intrinsic::aarch64_sve_rev:
- return DAG.getNode(ISD::VECTOR_REVERSE, DL, Op.getValueType(),
- Op.getOperand(1));
case Intrinsic::aarch64_sve_tbl:
return DAG.getNode(AArch64ISD::TBL, DL, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
@@ -6710,8 +6824,34 @@ bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
}
+/// Helper function to check if a small vector load can be optimized.
+static bool isEligibleForSmallVectorLoadOpt(LoadSDNode *LD,
+ const AArch64Subtarget &Subtarget) {
+ if (!Subtarget.isNeonAvailable())
+ return false;
+ if (LD->isVolatile())
+ return false;
+
+ EVT MemVT = LD->getMemoryVT();
+ if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8 && MemVT != MVT::v2i16)
+ return false;
+
+ Align Alignment = LD->getAlign();
+ Align RequiredAlignment = Align(MemVT.getStoreSize().getFixedValue());
+ if (Subtarget.requiresStrictAlign() && Alignment < RequiredAlignment)
+ return false;
+
+ return true;
+}
+
bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
EVT ExtVT = ExtVal.getValueType();
+ // Small, illegal vectors can be extended inreg.
+ if (auto *Load = dyn_cast<LoadSDNode>(ExtVal.getOperand(0))) {
+ if (ExtVT.isFixedLengthVector() && ExtVT.getStoreSizeInBits() <= 128 &&
+ isEligibleForSmallVectorLoadOpt(Load, *Subtarget))
+ return true;
+ }
if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors())
return false;
@@ -7170,12 +7310,86 @@ SDValue AArch64TargetLowering::LowerStore128(SDValue Op,
return Result;
}
+/// Helper function to optimize loads of extended small vectors.
+/// These patterns would otherwise get scalarized into inefficient sequences.
+static SDValue tryLowerSmallVectorExtLoad(LoadSDNode *Load, SelectionDAG &DAG) {
+ const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+ if (!isEligibleForSmallVectorLoadOpt(Load, Subtarget))
+ return SDValue();
+
+ EVT MemVT = Load->getMemoryVT();
+ EVT ResVT = Load->getValueType(0);
+ unsigned NumElts = ResVT.getVectorNumElements();
+ unsigned DstEltBits = ResVT.getScalarSizeInBits();
+ unsigned SrcEltBits = MemVT.getScalarSizeInBits();
+
+ unsigned ExtOpcode;
+ switch (Load->getExtensionType()) {
+ case ISD::EXTLOAD:
+ case ISD::ZEXTLOAD:
+ ExtOpcode = ISD::ZERO_EXTEND;
+ break;
+ case ISD::SEXTLOAD:
+ ExtOpcode = ISD::SIGN_EXTEND;
+ break;
+ case ISD::NON_EXTLOAD:
+ return SDValue();
+ }
+
+ SDLoc DL(Load);
+ SDValue Chain = Load->getChain();
+ SDValue BasePtr = Load->getBasePtr();
+ const MachinePointerInfo &PtrInfo = Load->getPointerInfo();
+ Align Alignment = Load->getAlign();
+
+ // Load the data as an FP scalar to avoid issues with integer loads.
+ unsigned LoadBits = MemVT.getStoreSizeInBits();
+ MVT ScalarLoadType = MVT::getFloatingPointVT(LoadBits);
+ SDValue ScalarLoad =
+ DAG.getLoad(ScalarLoadType, DL, Chain, BasePtr, PtrInfo, Alignment);
+
+ MVT ScalarToVecTy = MVT::getVectorVT(ScalarLoadType, 128 / LoadBits);
+ SDValue ScalarToVec =
+ DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarToVecTy, ScalarLoad);
+ MVT BitcastTy =
+ MVT::getVectorVT(MVT::getIntegerVT(SrcEltBits), 128 / SrcEltBits);
+ SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, BitcastTy, ScalarToVec);
+
+ SDValue Res = Bitcast;
+ unsigned CurrentEltBits = Res.getValueType().getScalarSizeInBits();
+ unsigned CurrentNumElts = Res.getValueType().getVectorNumElements();
+ while (CurrentEltBits < DstEltBits) {
+ if (Res.getValueSizeInBits() >= 128) {
+ CurrentNumElts = CurrentNumElts / 2;
+ MVT ExtractVT =
+ MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
+ Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Res,
+ DAG.getConstant(0, DL, MVT::i64));
+ }
+ CurrentEltBits = CurrentEltBits * 2;
+ MVT ExtVT =
+ MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
+ Res = DAG.getNode(ExtOpcode, DL, ExtVT, Res);
+ }
+
+ if (CurrentNumElts != NumElts) {
+ MVT FinalVT = MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), NumElts);
+ Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FinalVT, Res,
+ DAG.getConstant(0, DL, MVT::i64));
+ }
+
+ return DAG.getMergeValues({Res, ScalarLoad.getValue(1)}, DL);
+}
+
SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
assert(LoadNode && "Expected custom lowering of a load node");
+ if (SDValue Result = tryLowerSmallVectorExtLoad(LoadNode, DAG))
+ return Result;
+
if (LoadNode->getMemoryVT() == MVT::i64x8) {
SmallVector<SDValue, 8> Ops;
SDValue Base = LoadNode->getBasePtr();
@@ -7194,37 +7408,38 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
return DAG.getMergeValues({Loaded, Chain}, DL);
}
- // Custom lowering for extending v4i8 vector loads.
- EVT VT = Op->getValueType(0);
- assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
-
- if (LoadNode->getMemoryVT() != MVT::v4i8)
- return SDValue();
-
- // Avoid generating unaligned loads.
- if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4))
- return SDValue();
+ return SDValue();
+}
- unsigned ExtType;
- if (LoadNode->getExtensionType() == ISD::SEXTLOAD)
- ExtType = ISD::SIGN_EXTEND;
- else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD ||
- LoadNode->getExtensionType() == ISD::EXTLOAD)
- ExtType = ISD::ZERO_EXTEND;
- else
- return SDValue();
+// Convert to ContainerVT with no-op casts where possible.
+static SDValue convertToSVEContainerType(SDLoc DL, SDValue Vec, EVT ContainerVT,
+ SelectionDAG &DAG) {
+ EVT VecVT = Vec.getValueType();
+ if (VecVT.isFloatingPoint()) {
+ // Use no-op casts for floating-point types.
+ EVT PackedVT = getPackedSVEVectorVT(VecVT.getScalarType());
+ Vec = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedVT, Vec);
+ Vec = DAG.getNode(AArch64ISD::NVCAST, DL, ContainerVT, Vec);
+ } else {
+ // Extend integers (may not be a no-op).
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+ }
+ return Vec;
+}
- SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(),
- LoadNode->getBasePtr(), MachinePointerInfo());
- SDValue Chain = Load.getValue(1);
- SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load);
- SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec);
- SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC);
- Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext,
- DAG.getConstant(0, DL, MVT::i64));
- if (VT == MVT::v4i32)
- Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext);
- return DAG.getMergeValues({Ext, Chain}, DL);
+// Convert to VecVT with no-op casts where possible.
+static SDValue convertFromSVEContainerType(SDLoc DL, SDValue Vec, EVT VecVT,
+ SelectionDAG &DAG) {
+ if (VecVT.isFloatingPoint()) {
+ // Use no-op casts for floating-point types.
+ EVT PackedVT = getPackedSVEVectorVT(VecVT.getScalarType());
+ Vec = DAG.getNode(AArch64ISD::NVCAST, DL, PackedVT, Vec);
+ Vec = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VecVT, Vec);
+ } else {
+ // Truncate integers (may not be a no-op).
+ Vec = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Vec);
+ }
+ return Vec;
}
SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
@@ -7278,49 +7493,49 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
// Get legal type for compact instruction
EVT ContainerVT = getSVEContainerType(VecVT);
- EVT CastVT = VecVT.changeVectorElementTypeToInteger();
- // Convert to i32 or i64 for smaller types, as these are the only supported
+ // Convert to 32 or 64 bits for smaller types, as these are the only supported
// sizes for compact.
- if (ContainerVT != VecVT) {
- Vec = DAG.getBitcast(CastVT, Vec);
- Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
- }
+ Vec = convertToSVEContainerType(DL, Vec, ContainerVT, DAG);
SDValue Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+ Vec);
// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
SDValue Offset = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
- DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask,
+ Mask);
SDValue IndexMask = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MaskVT,
- DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
DAG.getConstant(0, DL, MVT::i64), Offset);
Compressed =
DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru);
}
+ // If we changed the element type before, we need to convert it back.
+ if (ElmtVT.isFloatingPoint())
+ Compressed = convertFromSVEContainerType(DL, Compressed, VecVT, DAG);
+
// Extracting from a legal SVE type before truncating produces better code.
if (IsFixedLength) {
- Compressed = DAG.getNode(
- ISD::EXTRACT_SUBVECTOR, DL,
- FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
- Compressed, DAG.getConstant(0, DL, MVT::i64));
- CastVT = FixedVecVT.changeVectorElementTypeToInteger();
+ EVT FixedSubVector = VecVT.isInteger()
+ ? FixedVecVT.changeVectorElementType(
+ ContainerVT.getVectorElementType())
+ : FixedVecVT;
+ Compressed = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedSubVector,
+ Compressed, DAG.getConstant(0, DL, MVT::i64));
VecVT = FixedVecVT;
}
- // If we changed the element type before, we need to convert it back.
- if (ContainerVT != VecVT) {
- Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
- Compressed = DAG.getBitcast(VecVT, Compressed);
- }
+ if (VecVT.isInteger())
+ Compressed = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Compressed);
return Compressed;
}
@@ -7428,10 +7643,10 @@ static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
DAG.getUNDEF(ExpVT), Exp, Zero);
SDValue VPg = getPTrue(DAG, DL, XVT.changeVectorElementType(MVT::i1),
AArch64SVEPredPattern::all);
- SDValue FScale =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XVT,
- DAG.getConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64),
- VPg, VX, VExp);
+ SDValue FScale = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, XVT,
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64), VPg,
+ VX, VExp);
SDValue Final =
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero);
if (X.getValueType() != XScalarTy)
@@ -7518,6 +7733,117 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
EndOfTrmp);
}
+SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ if (VT.getScalarType() != MVT::bf16 ||
+ (Subtarget->hasSVEB16B16() &&
+ Subtarget->isNonStreamingSVEorSME2Available()))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+
+ assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
+ assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+ "Unexpected FMUL VT");
+
+ auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
+ return [&, IID](EVT VT, auto... Ops) {
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(IID, DL, MVT::i32), Ops...);
+ };
+ };
+
+ auto Reinterpret = [&](SDValue Value, EVT VT) {
+ EVT SrcVT = Value.getValueType();
+ if (VT == SrcVT)
+ return Value;
+ if (SrcVT.isFixedLengthVector())
+ return convertToScalableVector(DAG, VT, Value);
+ if (VT.isFixedLengthVector())
+ return convertFromScalableVector(DAG, VT, Value);
+ return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
+ };
+
+ bool UseSVEBFMLAL = VT.isScalableVector();
+ auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
+ auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
+
+ // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+ // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+ auto BFMLALB =
+ MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb
+ : Intrinsic::aarch64_neon_bfmlalb);
+ auto BFMLALT =
+ MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt
+ : Intrinsic::aarch64_neon_bfmlalt);
+
+ EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32;
+ SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+ SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
+
+ // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
+ // instructions. These result in two f32 vectors, which can be converted back
+ // to bf16 with FCVT and FCVTNT.
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+
+ // All SVE intrinsics expect to operate on full bf16 vector types.
+ if (UseSVEBFMLAL) {
+ LHS = Reinterpret(LHS, MVT::nxv8bf16);
+ RHS = Reinterpret(RHS, MVT::nxv8bf16);
+ }
+
+ SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
+ SDValue BottomBF16 =
+ FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
+ // Note: nxv4bf16 only uses even lanes.
+ if (VT == MVT::nxv4bf16)
+ return Reinterpret(BottomBF16, VT);
+
+ SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
+ SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+ return Reinterpret(TopBF16, VT);
+}
+
+SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
+ SDValue OpA = Op->getOperand(0);
+ SDValue OpB = Op->getOperand(1);
+ SDValue OpC = Op->getOperand(2);
+ EVT VT = Op.getValueType();
+ SDLoc DL(Op);
+
+ assert(VT.isVector() && "Scalar fma lowering should be handled by patterns");
+
+ // Bail early if we're definitely not looking to merge FNEGs into the FMA.
+ if (VT != MVT::v8f16 && VT != MVT::v4f32 && VT != MVT::v2f64)
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+
+ if (OpC.getOpcode() != ISD::FNEG)
+ return useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())
+ ? LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED)
+ : Op; // Fallback to NEON lowering.
+
+ // Convert FMA/FNEG nodes to SVE to enable the following patterns:
+ // fma(a, b, neg(c)) -> fnmls(a, b, c)
+ // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+ // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+ auto ConvertToScalableFnegMt = [&](SDValue Op) {
+ if (Op.getOpcode() == ISD::FNEG)
+ Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
+ return convertToScalableVector(DAG, ContainerVT, Op);
+ };
+
+ OpA = ConvertToScalableFnegMt(OpA);
+ OpB = ConvertToScalableFnegMt(OpB);
+ OpC = ConvertToScalableFnegMt(OpC);
+
+ SDValue ScalableRes =
+ DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7592,9 +7918,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::FSUB:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
case ISD::FMUL:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+ return LowerFMUL(Op, DAG);
case ISD::FMA:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+ return LowerFMA(Op, DAG);
case ISD::FDIV:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
case ISD::FNEG:
@@ -7639,6 +7965,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerEXTRACT_VECTOR_ELT(Op, DAG);
case ISD::BUILD_VECTOR:
return LowerBUILD_VECTOR(Op, DAG);
+ case ISD::ANY_EXTEND_VECTOR_INREG:
+ case ISD::SIGN_EXTEND_VECTOR_INREG:
+ return LowerEXTEND_VECTOR_INREG(Op, DAG);
case ISD::ZERO_EXTEND_VECTOR_INREG:
return LowerZERO_EXTEND_VECTOR_INREG(Op, DAG);
case ISD::VECTOR_SHUFFLE:
@@ -7720,7 +8049,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STORE:
return LowerSTORE(Op, DAG);
case ISD::MSTORE:
- return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+ return LowerMSTORE(Op, DAG);
case ISD::MGATHER:
return LowerMGATHER(Op, DAG);
case ISD::MSCATTER:
@@ -7875,6 +8204,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -8094,7 +8424,7 @@ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
SDValue TPIDR2_EL0 = DAG.getNode(
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+ DAG.getTargetConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
// Copy the address of the TPIDR2 block into X0 before 'calling' the
// RESTORE_ZA pseudo.
SDValue Glue;
@@ -8109,7 +8439,7 @@ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
// Finally reset the TPIDR2_EL0 register to 0.
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+ DAG.getTargetConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
return Chain;
@@ -8426,7 +8756,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Subtarget->isWindowsArm64EC()) &&
"Indirect arguments should be scalable on most subtargets");
- uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
+ TypeSize PartSize = VA.getValVT().getStoreSize();
unsigned NumParts = 1;
if (Ins[i].Flags.isInConsecutiveRegs()) {
while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
@@ -8443,16 +8773,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
InVals.push_back(ArgValue);
NumParts--;
if (NumParts > 0) {
- SDValue BytesIncrement;
- if (PartLoad.isScalableVector()) {
- BytesIncrement = DAG.getVScale(
- DL, Ptr.getValueType(),
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
- } else {
- BytesIncrement = DAG.getConstant(
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
- Ptr.getValueType());
- }
+ SDValue BytesIncrement =
+ DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
ExtraArgLocs++;
@@ -8699,15 +9021,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
}
}
- if (getTM().useNewSMEABILowering()) {
- // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
- if (Attrs.isNewZT0())
- Chain = DAG.getNode(
- ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
- DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
- DAG.getTargetConstant(0, DL, MVT::i32));
- }
-
return Chain;
}
@@ -9430,6 +9743,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState())
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
+ else if (CallAttrs.requiresPreservingZT0())
+ ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE;
else if (CallAttrs.caller().hasZAState() ||
CallAttrs.caller().hasZT0State())
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
@@ -9517,7 +9832,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+ DAG.getTargetConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
TPIDR2ObjAddr);
OptimizationRemarkEmitter ORE(&MF.getFunction());
ORE.emit([&]() {
@@ -9549,7 +9864,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
- bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
+ bool ShouldPreserveZT0 =
+ !UseNewSMEABILowering && CallAttrs.requiresPreservingZT0();
// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
@@ -9562,7 +9878,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
- bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
+ bool DisableZA =
+ !UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall();
assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");
@@ -9581,8 +9898,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// using a chain can result in incorrect scheduling. The markers refer to
// the position just before the CALLSEQ_START (though occur after as
// CALLSEQ_START lacks in-glue).
- Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
- {Chain, Chain.getValue(1)});
+ Chain =
+ DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other, MVT::Glue),
+ {Chain, Chain.getValue(1)});
}
}
@@ -9663,8 +9981,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
assert((isScalable || Subtarget->isWindowsArm64EC()) &&
"Indirect arguments should be scalable on most subtargets");
- uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue();
- uint64_t PartSize = StoreSize;
+ TypeSize StoreSize = VA.getValVT().getStoreSize();
+ TypeSize PartSize = StoreSize;
unsigned NumParts = 1;
if (Outs[i].Flags.isInConsecutiveRegs()) {
while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
@@ -9675,7 +9993,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
MachineFrameInfo &MFI = MF.getFrameInfo();
- int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
+ int FI =
+ MFI.CreateStackObject(StoreSize.getKnownMinValue(), Alignment, false);
if (isScalable) {
bool IsPred = VA.getValVT() == MVT::aarch64svcount ||
VA.getValVT().getVectorElementType() == MVT::i1;
@@ -9696,16 +10015,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
NumParts--;
if (NumParts > 0) {
- SDValue BytesIncrement;
- if (isScalable) {
- BytesIncrement = DAG.getVScale(
- DL, Ptr.getValueType(),
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
- } else {
- BytesIncrement = DAG.getConstant(
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
- Ptr.getValueType());
- }
+ SDValue BytesIncrement =
+ DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
MPI = MachinePointerInfo(MPI.getAddrSpace());
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
@@ -9998,6 +10309,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (InGlue.getNode())
Ops.push_back(InGlue);
+ if (CLI.DeactivationSymbol)
+ Ops.push_back(DAG.getDeactivationSymbol(CLI.DeactivationSymbol));
+
// If we're doing a tail call, use a TC_RETURN here rather than an
// actual call instruction.
if (IsTailCall) {
@@ -10047,7 +10361,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
getSMToggleCondition(CallAttrs));
}
- if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
+ if (!UseNewSMEABILowering &&
+ (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()))
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -10587,16 +10902,41 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
const SDLoc &DL,
SelectionDAG &DAG) const {
EVT PtrVT = getPointerTy(DAG.getDataLayout());
+ auto &MF = DAG.getMachineFunction();
+ auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+ SDValue Glue;
SDValue Chain = DAG.getEntryNode();
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
+ SMECallAttrs TLSCallAttrs(FuncInfo->getSMEFnAttrs(), {}, SMEAttrs::Normal);
+ bool RequiresSMChange = TLSCallAttrs.requiresSMChange();
+
+ auto ChainAndGlue = [](SDValue Chain) -> std::pair<SDValue, SDValue> {
+ return {Chain, Chain.getValue(1)};
+ };
+
+ if (RequiresSMChange)
+ std::tie(Chain, Glue) =
+ ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue,
+ getSMToggleCondition(TLSCallAttrs)));
+
unsigned Opcode =
DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT()
? AArch64ISD::TLSDESC_AUTH_CALLSEQ
: AArch64ISD::TLSDESC_CALLSEQ;
- Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr});
- SDValue Glue = Chain.getValue(1);
+ SDValue Ops[] = {Chain, SymAddr, Glue};
+ std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode(
+ Opcode, DL, NodeTys, Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back()));
+
+ if (TLSCallAttrs.requiresLazySave())
+ std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode(
+ AArch64ISD::REQUIRES_ZA_SAVE, DL, NodeTys, {Chain, Chain.getValue(1)}));
+
+ if (RequiresSMChange)
+ std::tie(Chain, Glue) =
+ ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
+ getSMToggleCondition(TLSCallAttrs)));
return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
}
@@ -11505,7 +11845,12 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
}
if (LHS.getValueType().isInteger()) {
-
+ if (Subtarget->hasCSSC() && CC == ISD::SETNE && isNullConstant(RHS)) {
+ SDValue One = DAG.getConstant(1, DL, LHS.getValueType());
+ SDValue UMin = DAG.getNode(ISD::UMIN, DL, LHS.getValueType(), LHS, One);
+ SDValue Res = DAG.getZExtOrTrunc(UMin, DL, VT);
+ return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res;
+ }
simplifySetCCIntoEq(CC, LHS, RHS, DAG, DL);
SDValue CCVal;
@@ -13409,8 +13754,8 @@ SDValue ReconstructShuffleWithRuntimeMask(SDValue Op, SelectionDAG &DAG) {
return DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, VT,
- DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), SourceVec,
- MaskSourceVec);
+ DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32),
+ SourceVec, MaskSourceVec);
}
// Gather data to see if the operation can be modelled as a
@@ -14266,14 +14611,16 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V1Cst);
Shuffle = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
- DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
+ DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32),
+ V1Cst,
DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
} else {
if (IndexLen == 8) {
V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V2Cst);
Shuffle = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
- DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
+ DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32),
+ V1Cst,
DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
} else {
// FIXME: We cannot, for the moment, emit a TBL2 instruction because we
@@ -14284,8 +14631,8 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
// IndexLen));
Shuffle = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
- DAG.getConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), V1Cst,
- V2Cst,
+ DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32),
+ V1Cst, V2Cst,
DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
}
}
@@ -14453,6 +14800,40 @@ static SDValue tryToConvertShuffleOfTbl2ToTbl4(SDValue Op,
Tbl2->getOperand(1), Tbl2->getOperand(2), TBLMask});
}
+SDValue
+AArch64TargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ assert(VT.isScalableVector() && "Unexpected result type!");
+
+ bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
+ unsigned UnpackOpcode = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
+
+ // Repeatedly unpack Val until the result is of the desired type.
+ SDValue Val = Op.getOperand(0);
+ switch (Val.getSimpleValueType().SimpleTy) {
+ default:
+ return SDValue();
+ case MVT::nxv16i8:
+ Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv8i16, Val);
+ if (VT == MVT::nxv8i16)
+ break;
+ [[fallthrough]];
+ case MVT::nxv8i16:
+ Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv4i32, Val);
+ if (VT == MVT::nxv4i32)
+ break;
+ [[fallthrough]];
+ case MVT::nxv4i32:
+ Val = DAG.getNode(UnpackOpcode, DL, MVT::nxv2i64, Val);
+ assert(VT == MVT::nxv2i64 && "Unexpected result type!");
+ break;
+ }
+
+ return Val;
+}
+
// Baseline legalization for ZERO_EXTEND_VECTOR_INREG will blend-in zeros,
// but we don't have an appropriate instruction,
// so custom-lower it as ZIP1-with-zeros.
@@ -14461,6 +14842,10 @@ AArch64TargetLowering::LowerZERO_EXTEND_VECTOR_INREG(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
EVT VT = Op.getValueType();
+
+ if (VT.isScalableVector())
+ return LowerEXTEND_VECTOR_INREG(Op, DAG);
+
SDValue SrcOp = Op.getOperand(0);
EVT SrcVT = SrcOp.getValueType();
assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 &&
@@ -14570,17 +14955,20 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
}
unsigned WhichResult;
- if (isZIPMask(ShuffleMask, NumElts, WhichResult)) {
+ unsigned OperandOrder;
+ if (isZIPMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) {
unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
- return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2);
+ return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2,
+ OperandOrder == 0 ? V2 : V1);
}
if (isUZPMask(ShuffleMask, NumElts, WhichResult)) {
unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2);
}
- if (isTRNMask(ShuffleMask, NumElts, WhichResult)) {
+ if (isTRNMask(ShuffleMask, NumElts, WhichResult, OperandOrder)) {
unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
- return DAG.getNode(Opc, DL, V1.getValueType(), V1, V2);
+ return DAG.getNode(Opc, DL, V1.getValueType(), OperandOrder == 0 ? V1 : V2,
+ OperandOrder == 0 ? V2 : V1);
}
if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
@@ -16292,9 +16680,9 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
isREVMask(M, EltSize, NumElts, 16) ||
isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
isSingletonEXTMask(M, VT, DummyUnsigned) ||
- isTRNMask(M, NumElts, DummyUnsigned) ||
+ isTRNMask(M, NumElts, DummyUnsigned, DummyUnsigned) ||
isUZPMask(M, NumElts, DummyUnsigned) ||
- isZIPMask(M, NumElts, DummyUnsigned) ||
+ isZIPMask(M, NumElts, DummyUnsigned, DummyUnsigned) ||
isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
isUZP_v_undef_Mask(M, VT, DummyUnsigned) ||
isZIP_v_undef_Mask(M, VT, DummyUnsigned) ||
@@ -16438,10 +16826,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0),
DAG.getTargetConstant(Cnt, DL, MVT::i32));
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
- DAG.getConstant(Intrinsic::aarch64_neon_ushl, DL,
- MVT::i32),
- Op.getOperand(0), Op.getOperand(1));
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getTargetConstant(Intrinsic::aarch64_neon_ushl, DL, MVT::i32),
+ Op.getOperand(0), Op.getOperand(1));
case ISD::SRA:
case ISD::SRL:
if (VT.isScalableVector() &&
@@ -16943,7 +17331,7 @@ SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
template <unsigned NumVecs>
static bool
setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL,
- AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) {
+ AArch64TargetLowering::IntrinsicInfo &Info, const CallBase &CI) {
Info.opc = ISD::INTRINSIC_VOID;
// Retrieve EC from first vector argument.
const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType());
@@ -16968,7 +17356,7 @@ setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL,
/// MemIntrinsicNodes. The associated MachineMemOperands record the alignment
/// specified in the intrinsic calls.
bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
- const CallInst &I,
+ const CallBase &I,
MachineFunction &MF,
unsigned Intrinsic) const {
auto &DL = I.getDataLayout();
@@ -18537,7 +18925,7 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
case MVT::f64:
return true;
case MVT::bf16:
- return VT.isScalableVector() && Subtarget->hasSVEB16B16() &&
+ return VT.isScalableVector() && Subtarget->hasBF16() &&
Subtarget->isNonStreamingSVEorSME2Available();
default:
break;
@@ -18720,6 +19108,15 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
return (Index == 0 || Index == ResVT.getVectorMinNumElements());
}
+bool AArch64TargetLowering::shouldOptimizeMulOverflowWithZeroHighBits(
+ LLVMContext &Context, EVT VT) const {
+ if (getTypeAction(Context, VT) != TypeExpandInteger)
+ return false;
+
+ EVT LegalTy = EVT::getIntegerVT(Context, VT.getSizeInBits() / 2);
+ return getTypeAction(Context, LegalTy) == TargetLowering::TypeLegal;
+}
+
/// Turn vector tests of the signbit in the form of:
/// xor (sra X, elt_size(X)-1), -1
/// into:
@@ -19282,20 +19679,37 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
return CSNeg;
}
-static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) {
+static bool IsSVECntIntrinsic(SDValue S) {
switch(getIntrinsicID(S.getNode())) {
default:
break;
case Intrinsic::aarch64_sve_cntb:
- return 8;
case Intrinsic::aarch64_sve_cnth:
- return 16;
case Intrinsic::aarch64_sve_cntw:
- return 32;
case Intrinsic::aarch64_sve_cntd:
- return 64;
+ return true;
+ }
+ return false;
+}
+
+// Returns the maximum (scalable) value that can be returned by an SVE count
+// intrinsic. Returns std::nullopt if \p Op is not aarch64_sve_cnt*.
+static std::optional<ElementCount> getMaxValueForSVECntIntrinsic(SDValue Op) {
+ Intrinsic::ID IID = getIntrinsicID(Op.getNode());
+ if (IID == Intrinsic::aarch64_sve_cntp)
+ return Op.getOperand(1).getValueType().getVectorElementCount();
+ switch (IID) {
+ case Intrinsic::aarch64_sve_cntd:
+ return ElementCount::getScalable(2);
+ case Intrinsic::aarch64_sve_cntw:
+ return ElementCount::getScalable(4);
+ case Intrinsic::aarch64_sve_cnth:
+ return ElementCount::getScalable(8);
+ case Intrinsic::aarch64_sve_cntb:
+ return ElementCount::getScalable(16);
+ default:
+ return std::nullopt;
}
- return {};
}
/// Calculates what the pre-extend type is, based on the extension
@@ -19939,7 +20353,9 @@ static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG,
return Res;
EVT VT = N->getValueType(0);
- if (VT != MVT::f32 && VT != MVT::f64)
+ if (VT != MVT::f16 && VT != MVT::f32 && VT != MVT::f64)
+ return SDValue();
+ if (VT == MVT::f16 && !Subtarget->hasFullFP16())
return SDValue();
// Only optimize when the source and destination types have the same width.
@@ -20037,7 +20453,7 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
: Intrinsic::aarch64_neon_vcvtfp2fxu;
SDValue FixConv =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy,
- DAG.getConstant(IntrinsicOpcode, DL, MVT::i32),
+ DAG.getTargetConstant(IntrinsicOpcode, DL, MVT::i32),
Op->getOperand(0), DAG.getTargetConstant(C, DL, MVT::i32));
// We can handle smaller integers by generating an extra trunc.
if (IntBits < FloatBits)
@@ -21591,9 +22007,8 @@ static SDValue performBuildVectorCombine(SDNode *N,
SDValue LowLanesSrcVec = Elt0->getOperand(0)->getOperand(0);
if (LowLanesSrcVec.getValueType() == MVT::v2f64) {
SDValue HighLanes;
- if (Elt2->getOpcode() == ISD::UNDEF &&
- Elt3->getOpcode() == ISD::UNDEF) {
- HighLanes = DAG.getUNDEF(MVT::v2f32);
+ if (Elt2->isUndef() && Elt3->isUndef()) {
+ HighLanes = DAG.getPOISON(MVT::v2f32);
} else if (Elt2->getOpcode() == ISD::FP_ROUND &&
Elt3->getOpcode() == ISD::FP_ROUND &&
isa<ConstantSDNode>(Elt2->getOperand(1)) &&
@@ -22296,6 +22711,69 @@ static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift);
}
+// Attempt to combine the following patterns:
+// SUB x, (CSET LO, (CMP a, b)) -> SBC x, 0, (CMP a, b)
+// SUB (SUB x, y), (CSET LO, (CMP a, b)) -> SBC x, y, (CMP a, b)
+// The CSET may be preceded by a ZEXT.
+static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) {
+ if (N->getOpcode() != ISD::SUB)
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+
+ SDValue N1 = N->getOperand(1);
+ if (N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse())
+ N1 = N1.getOperand(0);
+ if (!N1.hasOneUse() || getCSETCondCode(N1) != AArch64CC::LO)
+ return SDValue();
+
+ SDValue Flags = N1.getOperand(3);
+ if (Flags.getOpcode() != AArch64ISD::SUBS)
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue N0 = N->getOperand(0);
+ if (N0->getOpcode() == ISD::SUB)
+ return DAG.getNode(AArch64ISD::SBC, DL, VT, N0.getOperand(0),
+ N0.getOperand(1), Flags);
+ return DAG.getNode(AArch64ISD::SBC, DL, VT, N0, DAG.getConstant(0, DL, VT),
+ Flags);
+}
+
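A minimal standalone sketch (plain C++, assuming __int128 support) of the scalar arithmetic behind this combine: the borrow of a low-word subtraction, materialised as (a < b), is subtracted from the high word, which is exactly what a single SBC computes.

// Standalone illustration of the pattern this combine targets, using a
// two-word (128-bit) subtraction decomposed into 64-bit halves.
#include <cassert>
#include <cstdint>

int main() {
  uint64_t XLo = 5, XHi = 7, YLo = 9, YHi = 2;

  // Written the way the DAG sees it:
  //   SUB (SUB xhi, yhi), (CSET LO, (CMP xlo, ylo))
  uint64_t Lo = XLo - YLo;
  uint64_t Hi = (XHi - YHi) - static_cast<uint64_t>(XLo < YLo);

  // Reference result using 128-bit arithmetic.
  unsigned __int128 X = (static_cast<unsigned __int128>(XHi) << 64) | XLo;
  unsigned __int128 Y = (static_cast<unsigned __int128>(YHi) << 64) | YLo;
  unsigned __int128 R = X - Y;
  assert(Lo == static_cast<uint64_t>(R));
  assert(Hi == static_cast<uint64_t>(R >> 64));
  return 0;
}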
+// add(trunc(ashr(A, C)), trunc(lshr(A, 2*BW-1))), with C >= BW, where BW is
+// the scalar width of the result type
+// ->
+// X = trunc(ashr(A, C)); add(X, lshr(X, BW-1))
+// The original converts into ashr+lshr+xtn+xtn+add. The rewritten form becomes
+// ashr+xtn+usra. The original has lower total latency due to more parallelism,
+// but uses more micro-ops and seems to be slower in practice.
+static SDValue performAddTruncShiftCombine(SDNode *N, SelectionDAG &DAG) {
+ using namespace llvm::SDPatternMatch;
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v2i32 && VT != MVT::v4i16 && VT != MVT::v8i8)
+ return SDValue();
+
+ SDValue AShr, LShr;
+ if (!sd_match(N, m_Add(m_Trunc(m_Value(AShr)), m_Trunc(m_Value(LShr)))))
+ return SDValue();
+ if (AShr.getOpcode() != AArch64ISD::VASHR)
+ std::swap(AShr, LShr);
+ if (AShr.getOpcode() != AArch64ISD::VASHR ||
+ LShr.getOpcode() != AArch64ISD::VLSHR ||
+ AShr.getOperand(0) != LShr.getOperand(0) ||
+ AShr.getConstantOperandVal(1) < VT.getScalarSizeInBits() ||
+ LShr.getConstantOperandVal(1) != VT.getScalarSizeInBits() * 2 - 1)
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, AShr);
+ SDValue Shift = DAG.getNode(
+ AArch64ISD::VLSHR, DL, VT, Trunc,
+ DAG.getTargetConstant(VT.getScalarSizeInBits() - 1, DL, MVT::i32));
+ return DAG.getNode(ISD::ADD, DL, VT, Trunc, Shift);
+}
+
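A minimal standalone sketch of the equivalence this combine relies on, for one 32-bit lane narrowed to 16 bits (BW = 16, C >= BW); it assumes the usual arithmetic behaviour of right-shifting negative values, which C++20 guarantees.

// Standalone illustration: both forms add A's sign bit to the truncated
// arithmetic shift, which after narrowing is a single USRA.
#include <cassert>
#include <cstdint>

int main() {
  for (int32_t A : {0x7FFF1234, -0x12345678, 0, INT32_MIN, INT32_MAX}) {
    const unsigned C = 17; // any shift amount >= 16 works
    // Original: add(trunc(ashr(A, C)), trunc(lshr(A, 31)))
    uint16_t Original = static_cast<uint16_t>(A >> C) +
                        static_cast<uint16_t>(static_cast<uint32_t>(A) >> 31);
    // Rewritten: X = trunc(ashr(A, C)); add(X, lshr(X, 15))
    uint16_t X = static_cast<uint16_t>(A >> C);
    uint16_t Rewritten = X + (X >> 15);
    assert(Original == Rewritten);
  }
  return 0;
}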
static SDValue performAddSubCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Try to change sum of two reductions.
@@ -22317,6 +22795,10 @@ static SDValue performAddSubCombine(SDNode *N,
return Val;
if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG))
return Val;
+ if (SDValue Val = performSubWithBorrowCombine(N, DCI.DAG))
+ return Val;
+ if (SDValue Val = performAddTruncShiftCombine(N, DCI.DAG))
+ return Val;
if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
return Val;
@@ -22968,11 +23450,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
return DAG.getNode(ISD::OR, SDLoc(N), N->getValueType(0), N->getOperand(2),
N->getOperand(3));
case Intrinsic::aarch64_sve_sabd_u:
- return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
- N->getOperand(2), N->getOperand(3));
+ if (SDValue V = convertMergedOpToPredOp(N, ISD::ABDS, DAG, true))
+ return V;
+ return DAG.getNode(AArch64ISD::ABDS_PRED, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sve_uabd_u:
- return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
- N->getOperand(2), N->getOperand(3));
+ if (SDValue V = convertMergedOpToPredOp(N, ISD::ABDU, DAG, true))
+ return V;
+ return DAG.getNode(AArch64ISD::ABDU_PRED, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2), N->getOperand(3));
case Intrinsic::aarch64_sve_sdiv_u:
return DAG.getNode(AArch64ISD::SDIV_PRED, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -23895,7 +24381,7 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
// uzp1(x, undef) -> concat(truncate(x), undef)
- if (Op1.getOpcode() == ISD::UNDEF) {
+ if (Op1.isUndef()) {
EVT BCVT = MVT::Other, HalfVT = MVT::Other;
switch (ResVT.getSimpleVT().SimpleTy) {
default:
@@ -26038,7 +26524,7 @@ static SDValue performCSELCombine(SDNode *N,
// CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
// CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
if (SDValue Folded = foldCSELofCTTZ(N, DAG))
- return Folded;
+ return Folded;
// CSEL a, b, cc, SUBS(x, y) -> CSEL a, b, swapped(cc), SUBS(y, x)
// if SUB(y, x) already exists and we can produce a swapped predicate for cc.
@@ -26063,29 +26549,6 @@ static SDValue performCSELCombine(SDNode *N,
}
}
- // CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
- // use overflow flags, to avoid the comparison with zero. In case of success,
- // this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
- // NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
- // nodes with their SUBS equivalent as is already done for other flag-setting
- // operators, in which case doing the replacement here becomes redundant.
- if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
- isNullConstant(Cond.getOperand(1))) {
- SDValue Sub = Cond.getOperand(0);
- AArch64CC::CondCode CC =
- static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
- if (Sub.getOpcode() == ISD::SUB &&
- (CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
- CC == AArch64CC::PL)) {
- SDLoc DL(N);
- SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
- Sub.getOperand(0), Sub.getOperand(1));
- DCI.CombineTo(Sub.getNode(), Subs);
- DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
- return SDValue(N, 0);
- }
- }
-
// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
return CondLast;
@@ -26364,8 +26827,7 @@ performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
SDValue L1 = LHS->getOperand(1);
SDValue L2 = LHS->getOperand(2);
- if (L0.getOpcode() == ISD::UNDEF && isNullConstant(L2) &&
- isSignExtInReg(L1)) {
+ if (L0.isUndef() && isNullConstant(L2) && isSignExtInReg(L1)) {
SDLoc DL(N);
SDValue Shl = L1.getOperand(0);
SDValue NewLHS = DAG.getNode(ISD::INSERT_SUBVECTOR, DL,
@@ -26629,22 +27091,25 @@ static SDValue performSelectCombine(SDNode *N,
assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) &&
"Scalar-SETCC feeding SELECT has unexpected result type!");
- // If NumMaskElts == 0, the comparison is larger than select result. The
- // largest real NEON comparison is 64-bits per lane, which means the result is
- // at most 32-bits and an illegal vector. Just bail out for now.
- EVT SrcVT = N0.getOperand(0).getValueType();
-
// Don't try to do this optimization when the setcc itself has i1 operands.
// There are no legal vectors of i1, so this would be pointless. v1f16 is
// ruled out to prevent the creation of setcc that need to be scalarized.
+ EVT SrcVT = N0.getOperand(0).getValueType();
if (SrcVT == MVT::i1 ||
(SrcVT.isFloatingPoint() && SrcVT.getSizeInBits() <= 16))
return SDValue();
- int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
+  // If NumMaskElts == 0, the comparison is larger than the select result. The
+  // largest real NEON comparison is 64 bits per lane, which means the result
+  // is at most 32 bits and an illegal vector. Just bail out for now.
+ unsigned NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
if (!ResVT.isVector() || NumMaskElts == 0)
return SDValue();
+ // Avoid creating vectors with excessive VFs before legalization.
+ if (DCI.isBeforeLegalize() && NumMaskElts != ResVT.getVectorNumElements())
+ return SDValue();
+
SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts);
EVT CCVT = SrcVT.changeVectorElementTypeToInteger();
@@ -27293,8 +27758,8 @@ static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG,
// ...and remap the intrinsic `aarch64_sve_prf<T>_gather_scalar_offset` to
// `aarch64_sve_prfb_gather_uxtw_index`.
SDLoc DL(N);
- Ops[1] = DAG.getConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, DL,
- MVT::i64);
+ Ops[1] = DAG.getTargetConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index,
+ DL, MVT::i64);
return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
}
@@ -28567,7 +29032,8 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
if ((Index != 0) && (Index != ResEC.getKnownMinValue()))
return;
- unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
+ unsigned Opcode = (Index == 0) ? (unsigned)ISD::ANY_EXTEND_VECTOR_INREG
+ : (unsigned)AArch64ISD::UUNPKHI;
EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
@@ -29294,12 +29760,26 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
AI->getOperation() == AtomicRMWInst::FMinimum))
return AtomicExpansionKind::None;
- // Nand is not supported in LSE.
// Leave 128 bits to LLSC or CmpXChg.
- if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128 &&
- !AI->isFloatingPointOperation()) {
- if (Subtarget->hasLSE())
- return AtomicExpansionKind::None;
+ if (Size < 128 && !AI->isFloatingPointOperation()) {
+ if (Subtarget->hasLSE()) {
+ // Nand is not supported in LSE.
+ switch (AI->getOperation()) {
+ case AtomicRMWInst::Xchg:
+ case AtomicRMWInst::Add:
+ case AtomicRMWInst::Sub:
+ case AtomicRMWInst::And:
+ case AtomicRMWInst::Or:
+ case AtomicRMWInst::Xor:
+ case AtomicRMWInst::Max:
+ case AtomicRMWInst::Min:
+ case AtomicRMWInst::UMax:
+ case AtomicRMWInst::UMin:
+ return AtomicExpansionKind::None;
+ default:
+ break;
+ }
+ }
if (Subtarget->outlineAtomics()) {
      // [U]Min/[U]Max RMW atomics are used in __sync_fetch_ libcalls so far.
// Don't outline them unless
@@ -29307,11 +29787,16 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
// (2) low level libgcc and compiler-rt support implemented by:
// min/max outline atomics helpers
- if (AI->getOperation() != AtomicRMWInst::Min &&
- AI->getOperation() != AtomicRMWInst::Max &&
- AI->getOperation() != AtomicRMWInst::UMin &&
- AI->getOperation() != AtomicRMWInst::UMax) {
+ switch (AI->getOperation()) {
+ case AtomicRMWInst::Xchg:
+ case AtomicRMWInst::Add:
+ case AtomicRMWInst::Sub:
+ case AtomicRMWInst::And:
+ case AtomicRMWInst::Or:
+ case AtomicRMWInst::Xor:
return AtomicExpansionKind::None;
+ default:
+ break;
}
}
}
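For operations outside the lists above (Nand being the canonical example), the access is expanded into a retry loop rather than kept as a single instruction. A rough user-level analogue of that expansion, using only std::atomic compare-and-swap; the helper name fetch_nand is illustrative, not an LLVM or libc API.

// Standalone sketch: with no native atomic NAND, the operation is emulated
// with a compare-and-swap loop, which is roughly what atomic expansion
// produces when this hook does not return AtomicExpansionKind::None.
#include <atomic>
#include <cassert>
#include <cstdint>

uint32_t fetch_nand(std::atomic<uint32_t> &A, uint32_t Operand) {
  uint32_t Old = A.load(std::memory_order_relaxed);
  while (!A.compare_exchange_weak(Old, ~(Old & Operand),
                                  std::memory_order_seq_cst,
                                  std::memory_order_relaxed)) {
    // Old is refreshed by compare_exchange_weak on failure; retry.
  }
  return Old;
}

int main() {
  std::atomic<uint32_t> V{0xF0F0F0F0u};
  uint32_t Prev = fetch_nand(V, 0xFFu);
  assert(Prev == 0xF0F0F0F0u);
  assert(V.load() == ~(0xF0F0F0F0u & 0xFFu));
  return 0;
}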
@@ -30118,6 +30603,43 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
Store->isTruncatingStore());
}
+SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ auto *Store = cast<MaskedStoreSDNode>(Op);
+ EVT VT = Store->getValue().getValueType();
+ if (VT.isFixedLengthVector())
+ return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+
+ if (!Store->isCompressingStore())
+ return SDValue();
+
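+  // Count the active lanes: zero-extend the predicate mask and sum the lanes.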
+ EVT MaskVT = Store->getMask().getValueType();
+ EVT MaskExtVT = getPromotedVTForPredicate(MaskVT);
+ EVT MaskReduceVT = MaskExtVT.getScalarType();
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+
+ SDValue MaskExt =
+ DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask());
+ SDValue CntActive =
+ DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt);
+ if (MaskReduceVT != MVT::i64)
+ CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive);
+
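+  // Pack the enabled lanes to the front of the vector and build a mask that
+  // covers the first CntActive lanes, so the store can be emitted as a
+  // regular (non-compressing) masked store below.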
+ SDValue CompressedValue =
+ DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
+ Store->getMask(), DAG.getPOISON(VT));
+ SDValue CompressedMask =
+ DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);
+
+ return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
+ Store->getBasePtr(), Store->getOffset(),
+ CompressedMask, Store->getMemoryVT(),
+ Store->getMemOperand(), Store->getAddressingMode(),
+ Store->isTruncatingStore(),
+ /*isCompressing=*/false);
+}
+
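A minimal standalone sketch (plain C++) of the semantics preserved by this lowering: a compressing masked store is equivalent to packing the enabled lanes to the front and then storing only the first popcount(mask) lanes contiguously.

// Standalone illustration of the lowering above: compress, then store the
// first CntActive lanes with an all-ones-prefix mask.
#include <array>
#include <cassert>
#include <cstddef>

int main() {
  constexpr std::size_t N = 8;
  std::array<int, N> Value = {10, 11, 12, 13, 14, 15, 16, 17};
  std::array<bool, N> Mask = {true, false, true, true, false, false, true, false};

  // Reference semantics of a compressing store: enabled lanes are written
  // contiguously starting at the base pointer.
  std::array<int, N> Expected{};
  std::size_t Out = 0;
  for (std::size_t I = 0; I < N; ++I)
    if (Mask[I])
      Expected[Out++] = Value[I];

  // Lowered form: VECTOR_COMPRESS, then a plain masked store whose mask
  // enables exactly the first CntActive lanes (GET_ACTIVE_LANE_MASK(0, Cnt)).
  std::size_t CntActive = Out; // VECREDUCE_ADD of the zero-extended mask
  std::array<int, N> Compressed{};
  std::size_t J = 0;
  for (std::size_t I = 0; I < N; ++I)
    if (Mask[I])
      Compressed[J++] = Value[I];
  std::array<int, N> Memory{};
  for (std::size_t I = 0; I < CntActive; ++I)
    Memory[I] = Compressed[I];

  assert(Memory == Expected);
  return 0;
}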
SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
SDValue Op, SelectionDAG &DAG) const {
auto *Store = cast<MaskedStoreSDNode>(Op);
@@ -30132,7 +30654,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
return DAG.getMaskedStore(
Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
Mask, Store->getMemoryVT(), Store->getMemOperand(),
- Store->getAddressingMode(), Store->isTruncatingStore());
+ Store->getAddressingMode(), Store->isTruncatingStore(),
+ Store->isCompressingStore());
}
SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
@@ -31159,10 +31682,10 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
SDValue Shuffle;
if (IsSingleOp)
- Shuffle =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
- DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
- Op1, SVEMask);
+ Shuffle = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32), Op1,
+ SVEMask);
else if (Subtarget.hasSVE2()) {
if (!MinMaxEqual) {
unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt;
@@ -31181,10 +31704,10 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
SVEMask = convertToScalableVector(
DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask);
}
- Shuffle =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
- DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
- Op1, Op2, SVEMask);
+ Shuffle = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32), Op1,
+ Op2, SVEMask);
}
Shuffle = convertFromScalableVector(DAG, VT, Shuffle);
return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
@@ -31266,15 +31789,23 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
}
unsigned WhichResult;
- if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
+ unsigned OperandOrder;
+ if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
+ OperandOrder) &&
WhichResult == 0)
return convertFromScalableVector(
- DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op2));
+ DAG, VT,
+ DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT,
+ OperandOrder == 0 ? Op1 : Op2,
+ OperandOrder == 0 ? Op2 : Op1));
- if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
+ if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
+ OperandOrder)) {
unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
- return convertFromScalableVector(
- DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
+ SDValue TRN =
+ DAG.getNode(Opc, DL, ContainerVT, OperandOrder == 0 ? Op1 : Op2,
+ OperandOrder == 0 ? Op2 : Op1);
+ return convertFromScalableVector(DAG, VT, TRN);
}
if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult == 0)
@@ -31314,10 +31845,14 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
return convertFromScalableVector(DAG, VT, Op);
}
- if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
+ if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult,
+ OperandOrder) &&
WhichResult != 0)
return convertFromScalableVector(
- DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op2));
+ DAG, VT,
+ DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT,
+ OperandOrder == 0 ? Op1 : Op2,
+ OperandOrder == 0 ? Op2 : Op1));
if (isUZPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
@@ -31344,8 +31879,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
unsigned SegmentElts = VT.getVectorNumElements() / Segments;
if (std::optional<unsigned> Lane =
isDUPQMask(ShuffleMask, Segments, SegmentElts)) {
- SDValue IID =
- DAG.getConstant(Intrinsic::aarch64_sve_dup_laneq, DL, MVT::i64);
+ SDValue IID = DAG.getTargetConstant(Intrinsic::aarch64_sve_dup_laneq,
+ DL, MVT::i64);
return convertFromScalableVector(
DAG, VT,
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
@@ -31492,22 +32027,24 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
return false;
}
case ISD::INTRINSIC_WO_CHAIN: {
- if (auto ElementSize = IsSVECntIntrinsic(Op)) {
- unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
- if (!MaxSVEVectorSizeInBits)
- MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
- unsigned MaxElements = MaxSVEVectorSizeInBits / *ElementSize;
- // The SVE count intrinsics don't support the multiplier immediate so we
- // don't have to account for that here. The value returned may be slightly
- // over the true required bits, as this is based on the "ALL" pattern. The
- // other patterns are also exposed by these intrinsics, but they all
- // return a value that's strictly less than "ALL".
- unsigned RequiredBits = llvm::bit_width(MaxElements);
- unsigned BitWidth = Known.Zero.getBitWidth();
- if (RequiredBits < BitWidth)
- Known.Zero.setHighBits(BitWidth - RequiredBits);
+ std::optional<ElementCount> MaxCount = getMaxValueForSVECntIntrinsic(Op);
+ if (!MaxCount)
return false;
- }
+ unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
+ if (!MaxSVEVectorSizeInBits)
+ MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
+ unsigned VscaleMax = MaxSVEVectorSizeInBits / 128;
+ unsigned MaxValue = MaxCount->getKnownMinValue() * VscaleMax;
+ // The SVE count intrinsics don't support the multiplier immediate so we
+ // don't have to account for that here. The value returned may be slightly
+ // over the true required bits, as this is based on the "ALL" pattern. The
+ // other patterns are also exposed by these intrinsics, but they all
+ // return a value that's strictly less than "ALL".
+ unsigned RequiredBits = llvm::bit_width(MaxValue);
+ unsigned BitWidth = Known.Zero.getBitWidth();
+ if (RequiredBits < BitWidth)
+ Known.Zero.setHighBits(BitWidth - RequiredBits);
+ return false;
}
}
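A worked instance of the arithmetic above (plain C++20, using std::bit_width as a stand-in for llvm::bit_width): with the 2048-bit architectural maximum vector length, CNTB can return at most 256, so all but the low 9 bits of its i64 result are known zero.

// Standalone worked example of the known-bits computation for CNTB.
#include <bit>
#include <cassert>

int main() {
  const unsigned MaxSVEVectorSizeInBits = 2048;            // architectural maximum
  const unsigned VscaleMax = MaxSVEVectorSizeInBits / 128;  // 16
  const unsigned KnownMinValue = 16;                        // CNTB: 16 bytes per 128 bits
  const unsigned MaxValue = KnownMinValue * VscaleMax;      // 256
  const unsigned RequiredBits = std::bit_width(MaxValue);   // 9
  const unsigned BitWidth = 64;                             // CNTB produces an i64
  assert(MaxValue == 256 && RequiredBits == 9);
  // Known.Zero.setHighBits(BitWidth - RequiredBits) marks the top 55 bits zero.
  assert(BitWidth - RequiredBits == 55);
  return 0;
}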