Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp  177
1 file changed, 139 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a6f8f47..3ad2905 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::v8bf16, Expand);
}
+ // For bf16, fpextend is custom lowered so it can optionally be expanded into shifts.
+ setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
+ setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+ setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom);
+ setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
+ setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
+ setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom);
+
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
for (auto Op : {
ISD::SETCC,
@@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::f16, Legal);
}
- // Strict conversion to a larger type is legal
- for (auto VT : {MVT::f32, MVT::f64})
- setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
-
setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
@@ -1183,6 +1187,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setMaxDivRemBitWidthSupported(128);
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
+ if (Subtarget->hasSME())
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i1, Custom);
if (Subtarget->isNeonAvailable()) {
// FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
@@ -2669,6 +2675,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CSINC)
MAKE_CASE(AArch64ISD::THREAD_POINTER)
MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
+ MAKE_CASE(AArch64ISD::TLSDESC_AUTH_CALLSEQ)
MAKE_CASE(AArch64ISD::PROBED_ALLOCA)
MAKE_CASE(AArch64ISD::ABDS_PRED)
MAKE_CASE(AArch64ISD::ABDU_PRED)
@@ -4495,6 +4502,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthFPExtendToSVE(Op, DAG);
+ bool IsStrict = Op->isStrictFPOpcode();
+ SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
+ EVT Op0VT = Op0.getValueType();
+ if (VT == MVT::f64) {
+ // f32->f64 and f16->f64 extends are legal.
+ if (Op0VT == MVT::f32 || Op0VT == MVT::f16)
+ return Op;
+ // Split bf16->f64 extends into two fpextends.
+ if (Op0VT == MVT::bf16 && IsStrict) {
+ SDValue Ext1 =
+ DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other},
+ {Op0, Op.getOperand(0)});
+ return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other},
+ {Ext1, Ext1.getValue(1)});
+ }
+ if (Op0VT == MVT::bf16)
+ return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT,
+ DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0));
+ return SDValue();
+ }
+
+ if (VT.getScalarType() == MVT::f32) {
+ // f16->f32 extends are legal for f32 and v4f32.
+ if (Op0VT.getScalarType() == MVT::f16)
+ return Op;
+ if (Op0VT.getScalarType() == MVT::bf16) {
+ SDLoc DL(Op);
+ EVT IVT = VT.changeTypeToInteger();
+ if (!Op0VT.isVector()) {
+ Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0);
+ IVT = MVT::v4i32;
+ }
+
+ EVT Op0IVT = Op0.getValueType().changeTypeToInteger();
+ SDValue Ext =
+ DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0));
+ SDValue Shift =
+ DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT));
+ if (!Op0VT.isVector())
+ Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift,
+ DAG.getConstant(0, DL, MVT::i64));
+ Shift = DAG.getBitcast(VT, Shift);
+ return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL)
+ : Shift;
+ }
+ return SDValue();
+ }
+
assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
return SDValue();
}
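
The bf16 cases above rely on the fact that bf16 is exactly the high half of an
IEEE-754 binary32 value, so widening bf16->f32 is a pure bit operation: an
any-extend, a left shift by 16, and a bitcast. A minimal standalone sketch of
that trick (plain C++, not LLVM code):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float bf16_to_f32(uint16_t Bits) {
      uint32_t Wide = uint32_t(Bits) << 16; // the ISD::SHL by 16 above
      float F;
      std::memcpy(&F, &Wide, sizeof F);     // the ISD::BITCAST above
      return F;
    }

    int main() {
      std::printf("%f\n", bf16_to_f32(0x3F80)); // 0x3F80 is 1.0 in bf16
      std::printf("%f\n", bf16_to_f32(0xC040)); // 0xC040 is -3.0 in bf16
    }
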
@@ -7342,6 +7397,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STRICT_FP_ROUND:
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
+ case ISD::STRICT_FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
case ISD::FRAMEADDR:
return LowerFRAMEADDR(Op, DAG);
@@ -10123,8 +10179,11 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
SDValue Chain = DAG.getEntryNode();
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
- Chain =
- DAG.getNode(AArch64ISD::TLSDESC_CALLSEQ, DL, NodeTys, {Chain, SymAddr});
+ unsigned Opcode =
+ DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT()
+ ? AArch64ISD::TLSDESC_AUTH_CALLSEQ
+ : AArch64ISD::TLSDESC_CALLSEQ;
+ Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr});
SDValue Glue = Chain.getValue(1);
return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
@@ -10136,8 +10195,12 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
assert(Subtarget->isTargetELF() && "This function expects an ELF target");
const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
+ AArch64FunctionInfo *MFI =
+ DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
- TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal());
+ TLSModel::Model Model = MFI->hasELFSignedGOT()
+ ? TLSModel::GeneralDynamic
+ : getTargetMachine().getTLSModel(GA->getGlobal());
if (!EnableAArch64ELFLocalDynamicTLSGeneration) {
if (Model == TLSModel::LocalDynamic)
@@ -10174,8 +10237,6 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
// calculation.
// These accesses will need deduplicating if there's more than one.
- AArch64FunctionInfo *MFI =
- DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
MFI->incNumLocalDynamicTLSAccesses();
// The call needs a relocation too for linker relaxation. It doesn't make
@@ -18424,7 +18485,7 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
EVT VT = A.getValueType();
SDValue Op0 = A.getOperand(0);
SDValue Op1 = A.getOperand(1);
- if (Op0.getOpcode() != Op0.getOpcode() ||
+ if (Op0.getOpcode() != Op1.getOpcode() ||
(Op0.getOpcode() != ISD::ZERO_EXTEND &&
Op0.getOpcode() != ISD::SIGN_EXTEND))
return SDValue();
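
The one-character fix above matters because the old condition compared Op0's
opcode against itself, which is always false, so a zext/sext mismatch between
the two addends was never rejected. A toy reduction of the bug (hypothetical
names, plain C++):

    #include <cassert>

    enum Opcode { ZERO_EXTEND, SIGN_EXTEND };

    // Old check: Op0 != Op0, constant false, so mismatches slipped through.
    // Fixed check compares the two operands' opcodes, as the diff does.
    static bool opcodesMismatch(Opcode Op0, Opcode Op1) { return Op0 != Op1; }

    int main() {
      assert(opcodesMismatch(ZERO_EXTEND, SIGN_EXTEND));
      assert(!opcodesMismatch(SIGN_EXTEND, SIGN_EXTEND));
    }
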
@@ -21981,21 +22042,35 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);
SDValue Op2 = N->getOperand(2);
- if (Op2->getOpcode() != ISD::MUL ||
- !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
- !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
- return SDValue();
+ unsigned Op2Opcode = Op2->getOpcode();
+ SDValue MulOpLHS, MulOpRHS;
+ bool MulOpLHSIsSigned, MulOpRHSIsSigned;
+ if (ISD::isExtOpcode(Op2Opcode)) {
+ MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
+ MulOpLHS = Op2->getOperand(0);
+ MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
+ } else if (Op2Opcode == ISD::MUL) {
+ SDValue ExtMulOpLHS = Op2->getOperand(0);
+ SDValue ExtMulOpRHS = Op2->getOperand(1);
+
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
- SDValue Acc = N->getOperand(1);
- SDValue Mul = N->getOperand(2);
- SDValue ExtMulOpLHS = Mul->getOperand(0);
- SDValue ExtMulOpRHS = Mul->getOperand(1);
+ MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+ MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+
+ MulOpLHS = ExtMulOpLHS->getOperand(0);
+ MulOpRHS = ExtMulOpRHS->getOperand(0);
- SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
- SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
- if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+ if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+ return SDValue();
+ } else
return SDValue();
+ SDValue Acc = N->getOperand(1);
EVT ReducedVT = N->getValueType(0);
EVT MulSrcVT = MulOpLHS.getValueType();
@@ -22009,8 +22084,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();
- bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
- bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
@@ -22026,10 +22099,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// USDOT expects the signed operand to be last
if (!MulOpRHSIsSigned)
std::swap(MulOpLHS, MulOpRHS);
- } else if (MulOpLHSIsSigned)
- Opcode = AArch64ISD::SDOT;
- else
- Opcode = AArch64ISD::UDOT;
+ } else
+ Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
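
For reference, a scalar model of what the SDOT/UDOT lowering above computes
for the v16i8 -> v4i32 case (plain C++, not the intrinsic itself): each i32
lane accumulates four adjacent byte products, and the new bare-extend path is
equivalent to multiplying by a splat of 1.

    #include <array>
    #include <cstdint>
    #include <cstdio>

    using V16i8 = std::array<int8_t, 16>;
    using V4i32 = std::array<int32_t, 4>;

    static V4i32 sdot(V4i32 Acc, const V16i8 &A, const V16i8 &B) {
      for (int Lane = 0; Lane < 4; ++Lane)
        for (int I = 0; I < 4; ++I)
          Acc[Lane] += int32_t(A[4 * Lane + I]) * int32_t(B[4 * Lane + I]);
      return Acc;
    }

    int main() {
      V16i8 A{}, Ones;
      Ones.fill(1); // a bare sign_extend acts like a multiply by splat(1)
      A[0] = -3;
      A[5] = 7;
      V4i32 R = sdot(V4i32{}, A, Ones);
      std::printf("%d %d\n", R[0], R[1]); // -3 7
    }
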
@@ -27413,6 +27484,15 @@ void AArch64TargetLowering::ReplaceNodeResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
return;
}
+ case Intrinsic::aarch64_sme_in_streaming_mode: {
+ SDLoc DL(N);
+ SDValue Chain = DAG.getEntryNode();
+ SDValue RuntimePStateSM =
+ getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
+ Results.push_back(
+ DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, RuntimePStateSM));
+ return;
+ }
case Intrinsic::experimental_vector_match:
case Intrinsic::get_active_lane_mask: {
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
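
The in_streaming_mode case above exists because i1 is not a legal result
type; the node is replaced by a truncate of the runtime PSTATE.SM query. A
minimal model (plain C++, with a hypothetical stand-in for
getRuntimePStateSM):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in: the runtime query returns a wide integer whose
    // bit 0 is PSTATE.SM.
    static uint64_t runtimePStateSM() { return 1; }

    int main() {
      bool InStreamingMode = runtimePStateSM() & 1; // models the i1 TRUNCATE
      std::printf("in streaming mode: %d\n", InStreamingMode);
    }
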
@@ -29648,9 +29728,16 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
+
+ if (Operation == ComplexDeinterleavingOperation::CDot)
+ return ScalarWidth == 32 || ScalarWidth == 64;
return 8 <= ScalarWidth && ScalarWidth <= 64;
}
+ // CDot is not supported outside of scalable/SVE types.
+ if (Operation == ComplexDeinterleavingOperation::CDot)
+ return false;
+
return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
}
@@ -29660,6 +29747,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
Value *Accumulator) const {
VectorType *Ty = cast<VectorType>(InputA->getType());
+ if (Accumulator == nullptr)
+ Accumulator = Constant::getNullValue(Ty);
bool IsScalable = Ty->isScalableTy();
bool IsInt = Ty->getElementType()->isIntegerTy();
@@ -29671,6 +29760,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
if (TyWidth > 128) {
int Stride = Ty->getElementCount().getKnownMinValue() / 2;
+ int AccStride = cast<VectorType>(Accumulator->getType())
+ ->getElementCount()
+ .getKnownMinValue() /
+ 2;
auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29680,25 +29773,26 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
Value *LowerSplitAcc = nullptr;
Value *UpperSplitAcc = nullptr;
- if (Accumulator) {
- LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
- UpperSplitAcc =
- B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
- }
+ Type *FullTy = Accumulator->getType();
+ auto *HalfAccTy = VectorType::getHalfElementsVectorType(
+ cast<VectorType>(Accumulator->getType()));
+ LowerSplitAcc =
+ B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
+ UpperSplitAcc =
+ B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
auto *LowerSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
auto *UpperSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
- auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
- B.getInt64(0));
- return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
+ auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
+ LowerSplitInt, B.getInt64(0));
+ return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
+ B.getInt64(AccStride));
}
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
- if (Accumulator == nullptr)
- Accumulator = Constant::getNullValue(Ty);
-
if (IsScalable) {
if (IsInt)
return B.CreateIntrinsic(
@@ -29750,6 +29844,13 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
}
+ if (OperationType == ComplexDeinterleavingOperation::CDot && IsInt &&
+ IsScalable) {
+ return B.CreateIntrinsic(
+ Intrinsic::aarch64_sve_cdot, Accumulator->getType(),
+ {Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
+ }
+
return nullptr;
}
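
An illustrative scalar model of the complex integer dot product the CDot case
above lowers to via Intrinsic::aarch64_sve_cdot (plain C++, not the Arm
pseudocode; it assumes the rotation operand, built above as Rotation * 90,
acts as a multiplication of the second operand by i per quarter turn):

    #include <cstdint>
    #include <cstdio>

    struct Cplx { int64_t Re, Im; };

    static Cplx mul(Cplx A, Cplx B) {
      return {A.Re * B.Re - A.Im * B.Im, A.Re * B.Im + A.Im * B.Re};
    }

    static Cplx complexDot(Cplx Acc, Cplx A, Cplx B, int RotationDegrees) {
      for (int D = 0; D < RotationDegrees; D += 90)
        B = mul(B, {0, 1}); // one 90-degree rotation
      Cplx P = mul(A, B);
      return {Acc.Re + P.Re, Acc.Im + P.Im};
    }

    int main() {
      Cplx R = complexDot({0, 0}, {1, 2}, {3, 4}, 90);
      // (1+2i) * ((3+4i) * i) = -10 - 5i
      std::printf("%lld%+lldi\n", (long long)R.Re, (long long)R.Im);
    }
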