Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 177
1 file changed, 139 insertions, 38 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a6f8f47..3ad2905 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(Op, MVT::v8bf16, Expand);
   }
 
+  // For bf16, fpextend is custom lowered to be optionally expanded into shifts.
+  setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
+  setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+  setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom);
+  setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
+  setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
+  setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom);
+
   auto LegalizeNarrowFP = [this](MVT ScalarVT) {
     for (auto Op : {
              ISD::SETCC,
@@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(Op, MVT::f16, Legal);
   }
 
-  // Strict conversion to a larger type is legal
-  for (auto VT : {MVT::f32, MVT::f64})
-    setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
-
   setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
 
   setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
@@ -1183,6 +1187,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setMaxDivRemBitWidthSupported(128);
 
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
+  if (Subtarget->hasSME())
+    setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i1, Custom);
 
   if (Subtarget->isNeonAvailable()) {
     // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
@@ -2669,6 +2675,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::CSINC)
     MAKE_CASE(AArch64ISD::THREAD_POINTER)
     MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
+    MAKE_CASE(AArch64ISD::TLSDESC_AUTH_CALLSEQ)
     MAKE_CASE(AArch64ISD::PROBED_ALLOCA)
     MAKE_CASE(AArch64ISD::ABDS_PRED)
     MAKE_CASE(AArch64ISD::ABDU_PRED)
@@ -4495,6 +4502,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPExtendToSVE(Op, DAG);
 
+  bool IsStrict = Op->isStrictFPOpcode();
+  SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
+  EVT Op0VT = Op0.getValueType();
+  if (VT == MVT::f64) {
+    // FP32->FP64 and FP16->FP64 extends are legal.
+    if (Op0VT == MVT::f32 || Op0VT == MVT::f16)
+      return Op;
+    // Split bf16->f64 extends into two fpextends via f32.
+    if (Op0VT == MVT::bf16 && IsStrict) {
+      SDValue Ext1 =
+          DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other},
+                      {Op0, Op.getOperand(0)});
+      return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other},
+                         {Ext1, Ext1.getValue(1)});
+    }
+    if (Op0VT == MVT::bf16)
+      return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT,
+                         DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0));
+    return SDValue();
+  }
+
+  if (VT.getScalarType() == MVT::f32) {
+    // FP16->FP32 extends are legal for f32 and v4f32.
+    if (Op0VT.getScalarType() == MVT::f16)
+      return Op;
+    if (Op0VT.getScalarType() == MVT::bf16) {
+      SDLoc DL(Op);
+      EVT IVT = VT.changeTypeToInteger();
+      if (!Op0VT.isVector()) {
+        Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0);
+        IVT = MVT::v4i32;
+      }
+
+      EVT Op0IVT = Op0.getValueType().changeTypeToInteger();
+      SDValue Ext =
+          DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0));
+      SDValue Shift =
+          DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT));
+      if (!Op0VT.isVector())
+        Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift,
+                            DAG.getConstant(0, DL, MVT::i64));
+      Shift = DAG.getBitcast(VT, Shift);
+      return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL)
+                      : Shift;
+    }
+    return SDValue();
+  }
+
   assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
   return SDValue();
 }
@@ -7342,6 +7397,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::STRICT_FP_ROUND:
     return LowerFP_ROUND(Op, DAG);
   case ISD::FP_EXTEND:
+  case ISD::STRICT_FP_EXTEND:
     return LowerFP_EXTEND(Op, DAG);
   case ISD::FRAMEADDR:
     return LowerFRAMEADDR(Op, DAG);
@@ -10123,8 +10179,11 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
   SDValue Chain = DAG.getEntryNode();
   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
 
-  Chain =
-      DAG.getNode(AArch64ISD::TLSDESC_CALLSEQ, DL, NodeTys, {Chain, SymAddr});
+  unsigned Opcode =
+      DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT()
+          ? AArch64ISD::TLSDESC_AUTH_CALLSEQ
+          : AArch64ISD::TLSDESC_CALLSEQ;
+  Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr});
   SDValue Glue = Chain.getValue(1);
 
   return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
@@ -10136,8 +10195,12 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
   assert(Subtarget->isTargetELF() && "This function expects an ELF target");
 
   const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
+  AArch64FunctionInfo *MFI =
+      DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
 
-  TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal());
+  TLSModel::Model Model = MFI->hasELFSignedGOT()
+                              ? TLSModel::GeneralDynamic
+                              : getTargetMachine().getTLSModel(GA->getGlobal());
 
   if (!EnableAArch64ELFLocalDynamicTLSGeneration) {
     if (Model == TLSModel::LocalDynamic)
@@ -10174,8 +10237,6 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
   // calculation.
 
   // These accesses will need deduplicating if there's more than one.
-  AArch64FunctionInfo *MFI =
-      DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
   MFI->incNumLocalDynamicTLSAccesses();
 
   // The call needs a relocation too for linker relaxation. It doesn't make
@@ -18424,7 +18485,7 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
   EVT VT = A.getValueType();
   SDValue Op0 = A.getOperand(0);
   SDValue Op1 = A.getOperand(1);
-  if (Op0.getOpcode() != Op0.getOpcode() ||
+  if (Op0.getOpcode() != Op1.getOpcode() ||
       (Op0.getOpcode() != ISD::ZERO_EXTEND &&
        Op0.getOpcode() != ISD::SIGN_EXTEND))
     return SDValue();
@@ -21981,21 +22042,35 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   SDLoc DL(N);
 
   SDValue Op2 = N->getOperand(2);
-  if (Op2->getOpcode() != ISD::MUL ||
-      !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
-      !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
-    return SDValue();
+  unsigned Op2Opcode = Op2->getOpcode();
+  SDValue MulOpLHS, MulOpRHS;
+  bool MulOpLHSIsSigned, MulOpRHSIsSigned;
+  if (ISD::isExtOpcode(Op2Opcode)) {
+    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
+    MulOpLHS = Op2->getOperand(0);
+    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
+  } else if (Op2Opcode == ISD::MUL) {
+    SDValue ExtMulOpLHS = Op2->getOperand(0);
+    SDValue ExtMulOpRHS = Op2->getOperand(1);
+
+    unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+    unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+    if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+        !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+      return SDValue();
 
-  SDValue Acc = N->getOperand(1);
-  SDValue Mul = N->getOperand(2);
-  SDValue ExtMulOpLHS = Mul->getOperand(0);
-  SDValue ExtMulOpRHS = Mul->getOperand(1);
+    MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+    MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+
+    MulOpLHS = ExtMulOpLHS->getOperand(0);
+    MulOpRHS = ExtMulOpRHS->getOperand(0);
 
-  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
-  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
-  if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+    if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+      return SDValue();
+  } else
     return SDValue();
 
+  SDValue Acc = N->getOperand(1);
   EVT ReducedVT = N->getValueType(0);
   EVT MulSrcVT = MulOpLHS.getValueType();
@@ -22009,8 +22084,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
       !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
     return SDValue();
 
-  bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
-  bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
   // If the extensions are mixed, we should lower it to a usdot instead
   unsigned Opcode = 0;
   if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
@@ -22026,10 +22099,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     // USDOT expects the signed operand to be last
     if (!MulOpRHSIsSigned)
       std::swap(MulOpLHS, MulOpRHS);
-  } else if (MulOpLHSIsSigned)
-    Opcode = AArch64ISD::SDOT;
-  else
-    Opcode = AArch64ISD::UDOT;
+  } else
+    Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
   // product followed by a zero / sign extension
@@ -27413,6 +27484,15 @@ void AArch64TargetLowering::ReplaceNodeResults(
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::aarch64_sme_in_streaming_mode: {
+      SDLoc DL(N);
+      SDValue Chain = DAG.getEntryNode();
+      SDValue RuntimePStateSM =
+          getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
+      Results.push_back(
+          DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, RuntimePStateSM));
+      return;
+    }
    case Intrinsic::experimental_vector_match:
    case Intrinsic::get_active_lane_mask: {
      if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
@@ -29648,9 +29728,16 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
   if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
     unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
+
+    if (Operation == ComplexDeinterleavingOperation::CDot)
+      return ScalarWidth == 32 || ScalarWidth == 64;
     return 8 <= ScalarWidth && ScalarWidth <= 64;
   }
 
+  // CDot is not supported outside of the scalable SVE case handled above.
+  if (Operation == ComplexDeinterleavingOperation::CDot)
+    return false;
+
   return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
          ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
 }
@@ -29660,6 +29747,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
     ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
     Value *Accumulator) const {
   VectorType *Ty = cast<VectorType>(InputA->getType());
+  if (Accumulator == nullptr)
+    Accumulator = Constant::getNullValue(Ty);
   bool IsScalable = Ty->isScalableTy();
   bool IsInt = Ty->getElementType()->isIntegerTy();
@@ -29671,6 +29760,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
 
   if (TyWidth > 128) {
     int Stride = Ty->getElementCount().getKnownMinValue() / 2;
+    int AccStride = cast<VectorType>(Accumulator->getType())
+                        ->getElementCount()
+                        .getKnownMinValue() /
+                    2;
     auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
     auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
     auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29680,25 +29773,26 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
         B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
     Value *LowerSplitAcc = nullptr;
     Value *UpperSplitAcc = nullptr;
-    if (Accumulator) {
-      LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
-      UpperSplitAcc =
-          B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
-    }
+    Type *FullTy = Ty;
+    FullTy = Accumulator->getType();
+    auto *HalfAccTy = VectorType::getHalfElementsVectorType(
+        cast<VectorType>(Accumulator->getType()));
+    LowerSplitAcc =
+        B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
+    UpperSplitAcc =
+        B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
     auto *LowerSplitInt = createComplexDeinterleavingIR(
         B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
    auto *UpperSplitInt = createComplexDeinterleavingIR(
        B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
-    auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
-                                        B.getInt64(0));
-    return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
+    auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
+                                        LowerSplitInt, B.getInt64(0));
+    return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
+                                B.getInt64(AccStride));
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
-    if (Accumulator == nullptr)
-      Accumulator = Constant::getNullValue(Ty);
-
     if (IsScalable) {
       if (IsInt)
         return B.CreateIntrinsic(
@@ -29750,6 +29844,13 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
     return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
   }
 
+  if (OperationType == ComplexDeinterleavingOperation::CDot && IsInt &&
+      IsScalable) {
+    return B.CreateIntrinsic(
+        Intrinsic::aarch64_sve_cdot, Accumulator->getType(),
+        {Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
+  }
+
   return nullptr;
 }
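
The LowerFP_EXTEND change above works because bf16 is exactly the high half of an IEEE-754 binary32 value: the sign bit, the 8 exponent bits, and the top 7 mantissa bits. Extending is therefore a pure bit move rather than arithmetic, which is what the ISD::ANY_EXTEND, shift-left-by-16, and bitcast sequence implements. A minimal standalone C++ sketch of the same transform (BF16ToF32 is a made-up name for illustration, not part of the patch):

#include <cstdint>
#include <cstring>

float BF16ToF32(uint16_t Bf16Bits) {
  // Widen the raw bf16 bits and shift them into the top half of an
  // i32; this mirrors the any_extend + shl #16 the lowering emits.
  uint32_t F32Bits = static_cast<uint32_t>(Bf16Bits) << 16;
  float F;
  std::memcpy(&F, &F32Bits, sizeof F); // the DAG-level bitcast
  return F;
}

The bf16 to f64 path goes through f32 first for the same reason: the bf16 to f32 step is exact, so composing it with the already-legal f32 to f64 extend loses nothing, in both the strict and non-strict variants.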
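The tryLowerPartialReductionToDot changes accept two shapes of operand 2: the existing mul(ext(a), ext(b)) pattern and now a bare ext(a), which is treated as a multiplication by one (hence the DAG.getConstant(1, ...) materialized as MulOpRHS). A scalar model of the v16i8 to v4i32 dot product these forms are lowered to (SDotV16I8 is an illustrative name, not an LLVM API; UDOT is the same with unsigned extensions, USDOT the mixed-sign case):

#include <cstdint>

void SDotV16I8(int32_t Acc[4], const int8_t A[16], const int8_t B[16]) {
  // Each 32-bit accumulator lane sums the products of four adjacent
  // sign-extended byte pairs, matching AArch64ISD::SDOT semantics.
  for (int Lane = 0; Lane < 4; ++Lane)
    for (int I = 0; I < 4; ++I)
      Acc[Lane] += int32_t(A[4 * Lane + I]) * int32_t(B[4 * Lane + I]);
}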
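In createComplexDeinterleavingIR, the over-128-bit splitting path now tracks two strides because a CDot accumulator has fewer, wider elements than the inputs (for example nxv16i8 inputs against an nxv4i32 accumulator), so each value's halves must be extracted at its own element offset. A rough sketch of that indexing, under the simplifying assumption that vectors are plain arrays (ExtractHalf is hypothetical):

#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> ExtractHalf(const std::vector<T> &V, bool Upper) {
  // Half the element count plays the role of Stride (for the inputs)
  // or AccStride (for the accumulator); the upper half starts there.
  size_t Stride = V.size() / 2;
  size_t Begin = Upper ? Stride : 0;
  return std::vector<T>(V.begin() + Begin, V.begin() + Begin + Stride);
}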