Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 953
1 file changed, 436 insertions(+), 517 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index f2c2f46..15f45a1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -382,6 +382,54 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, } } +// We return an EVT that can hold N VTs +// If the VT is a vector, the resulting EVT is a flat vector with the same +// element type as VT's element type. +static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) { + if (N == 1) + return VT; + + return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(), + VT.getVectorNumElements() * N) + : EVT::getVectorVT(C, VT, N); +} + +static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT, + const SDLoc &dl, SelectionDAG &DAG) { + if (V.getValueType() == VT) { + assert(I == 0 && "Index must be 0 for scalar value"); + return V; + } + + if (!VT.isVector()) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V, + DAG.getVectorIdxConstant(I, dl)); + + return DAG.getNode( + ISD::EXTRACT_SUBVECTOR, dl, VT, V, + DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl)); +} + +template <typename T> +static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl, + SelectionDAG &DAG, T GetElement) { + if (N == 1) + return GetElement(0); + + SmallVector<SDValue, 8> Values; + for (const unsigned I : llvm::seq(N)) { + SDValue Val = GetElement(I); + if (Val.getValueType().isVector()) + DAG.ExtractVectorElements(Val, Values); + else + Values.push_back(Val); + } + + EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(), + Values.size()); + return DAG.getBuildVector(VT, dl, Values); +} + /// PromoteScalarIntegerPTX /// Used to make sure the arguments/returns are suitable for passing /// and promote them to a larger size if they're not. @@ -420,9 +468,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) { // parameter starting at index Idx using a single vectorized op of // size AccessSize. If so, it returns the number of param pieces // covered by the vector op. Otherwise, it returns 1. -static unsigned CanMergeParamLoadStoresStartingAt( +template <typename T> +static unsigned canMergeParamLoadStoresStartingAt( unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs, - const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) { + const SmallVectorImpl<T> &Offsets, Align ParamAlignment) { // Can't vectorize if param alignment is not sufficient. if (ParamAlignment < AccessSize) @@ -472,10 +521,11 @@ static unsigned CanMergeParamLoadStoresStartingAt( // of the same size as ValueVTs indicating how each piece should be // loaded/stored (i.e. as a scalar, or as part of a vector // load/store). +template <typename T> static SmallVector<unsigned, 16> VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, - const SmallVectorImpl<uint64_t> &Offsets, - Align ParamAlignment, bool IsVAArg = false) { + const SmallVectorImpl<T> &Offsets, Align ParamAlignment, + bool IsVAArg = false) { // Set vector size to match ValueVTs and mark all elements as // scalars by default. 
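// [Editor's sketch -- illustration, not part of the patch] The sizing rule
// implemented by getVectorizedVT() in the hunk above, modeled as
// self-contained C++. ModelVT is a hypothetical stand-in for EVT, used only
// for this example: scalars become vectors of N elements, and vectors stay
// flat, preserving their element type while the element count scales by N.

#include <cassert>

struct ModelVT {
  unsigned EltCount; // 0 models a scalar VT; >= 1 models a vector VT
};

// Mirrors getVectorizedVT(): an EVT wide enough to hold N copies of VT.
ModelVT vectorizedVT(ModelVT VT, unsigned N) {
  if (N == 1)
    return VT; // a single piece keeps its type unchanged
  return ModelVT{(VT.EltCount ? VT.EltCount : 1) * N};
}

int main() {
  assert(vectorizedVT({0}, 4).EltCount == 4); // f32   x 4 -> v4f32
  assert(vectorizedVT({2}, 2).EltCount == 4); // v2f16 x 2 -> v4f16 (flat)
  assert(vectorizedVT({0}, 1).EltCount == 0); // N == 1    -> VT unchanged
}

// getExtractVectorizedValue() inverts this: piece I of a flat vector is
// recovered with EXTRACT_SUBVECTOR at index I * NumElements (for vector
// pieces) or EXTRACT_VECTOR_ELT at index I (for scalar pieces).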
@@ -486,7 +536,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, const auto GetNumElts = [&](unsigned I) -> unsigned { for (const unsigned AccessSize : {16, 8, 4, 2}) { - const unsigned NumElts = CanMergeParamLoadStoresStartingAt( + const unsigned NumElts = canMergeParamLoadStoresStartingAt( I, AccessSize, ValueVTs, Offsets, ParamAlignment); assert((NumElts == 1 || NumElts == 2 || NumElts == 4) && "Unexpected vectorization size"); @@ -843,7 +893,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, - ISD::STORE}); + ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. @@ -952,10 +1002,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // promoted to f32. v2f16 is expanded to f16, which is then promoted // to f32. for (const auto &Op : - {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) { + {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) { setOperationAction(Op, MVT::f16, Promote); setOperationAction(Op, MVT::f32, Legal); - setOperationAction(Op, MVT::f64, Legal); + // only div/rem/sqrt are legal for f64 + if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) { + setOperationAction(Op, MVT::f64, Legal); + } setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand); setOperationAction(Op, MVT::bf16, Promote); AddPromotedToType(Op, MVT::bf16, MVT::f32); @@ -1072,12 +1125,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(NVPTXISD::DeclareArrayParam) MAKE_CASE(NVPTXISD::DeclareScalarParam) MAKE_CASE(NVPTXISD::CALL) - MAKE_CASE(NVPTXISD::LoadParam) - MAKE_CASE(NVPTXISD::LoadParamV2) - MAKE_CASE(NVPTXISD::LoadParamV4) - MAKE_CASE(NVPTXISD::StoreParam) - MAKE_CASE(NVPTXISD::StoreParamV2) - MAKE_CASE(NVPTXISD::StoreParamV4) MAKE_CASE(NVPTXISD::MoveParam) MAKE_CASE(NVPTXISD::UNPACK_VECTOR) MAKE_CASE(NVPTXISD::BUILD_VECTOR) @@ -1315,105 +1362,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty, return DL.getABITypeAlign(Ty); } -static bool adjustElementType(EVT &ElementType) { - switch (ElementType.getSimpleVT().SimpleTy) { - default: - return false; - case MVT::f16: - case MVT::bf16: - ElementType = MVT::i16; - return true; - case MVT::f32: - case MVT::v2f16: - case MVT::v2bf16: - ElementType = MVT::i32; - return true; - case MVT::f64: - ElementType = MVT::i64; - return true; - } -} - -// Use byte-store when the param address of the argument value is unaligned. -// This may happen when the return value is a field of a packed structure. -// -// This is called in LowerCall() when passing the param values. 
-static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain, - uint64_t Offset, EVT ElementType, - SDValue StVal, SDValue &InGlue, - unsigned ArgID, const SDLoc &dl) { - // Bit logic only works on integer types - if (adjustElementType(ElementType)) - StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal); - - // Store each byte - SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue); - for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { - // Shift the byte to the last byte position - SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal, - DAG.getConstant(i * 8, dl, MVT::i32)); - SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32), - DAG.getConstant(Offset + i, dl, MVT::i32), - ShiftVal, InGlue}; - // Trunc store only the last byte by using - // st.param.b8 - // The register type can be larger than b8. - Chain = DAG.getMemIntrinsicNode( - NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8, - MachinePointerInfo(), Align(1), MachineMemOperand::MOStore); - InGlue = Chain.getValue(1); - } - return Chain; -} - -// Use byte-load when the param adress of the returned value is unaligned. -// This may happen when the returned value is a field of a packed structure. -static SDValue -LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset, - EVT ElementType, SDValue &InGlue, - SmallVectorImpl<SDValue> &TempProxyRegOps, - const SDLoc &dl) { - // Bit logic only works on integer types - EVT MergedType = ElementType; - adjustElementType(MergedType); - - // Load each byte and construct the whole value. Initial value to 0 - SDValue RetVal = DAG.getConstant(0, dl, MergedType); - // LoadParamMemI8 loads into i16 register only - SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue); - for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { - SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32), - DAG.getConstant(Offset + i, dl, MVT::i32), - InGlue}; - // This will be selected to LoadParamMemI8 - SDValue LdVal = - DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands, - MVT::i8, MachinePointerInfo(), Align(1)); - SDValue TmpLdVal = LdVal.getValue(0); - Chain = LdVal.getValue(1); - InGlue = LdVal.getValue(2); - - TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl, - TmpLdVal.getSimpleValueType(), TmpLdVal); - TempProxyRegOps.push_back(TmpLdVal); - - SDValue CMask = DAG.getConstant(255, dl, MergedType); - SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32); - // Need to extend the i16 register to the whole width. - TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal); - // Mask off the high bits. Leave only the lower 8bits. - // Do this because we are using loadparam.b8. 
- TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask); - // Shift and merge - TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift); - RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal); - } - if (ElementType != MergedType) - RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal); - - return RetVal; -} - static bool shouldConvertToIndirectCall(const CallBase *CB, const GlobalAddressSDNode *Func) { if (!Func) @@ -1480,19 +1428,48 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SelectionDAG &DAG = CLI.DAG; SDLoc dl = CLI.DL; - SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; - SDValue Chain = CLI.Chain; + const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; SDValue Callee = CLI.Callee; - bool &isTailCall = CLI.IsTailCall; ArgListTy &Args = CLI.getArgs(); Type *RetTy = CLI.RetTy; const CallBase *CB = CLI.CB; const DataLayout &DL = DAG.getDataLayout(); + LLVMContext &Ctx = *DAG.getContext(); const auto GetI32 = [&](const unsigned I) { return DAG.getConstant(I, dl, MVT::i32); }; + const unsigned UniqueCallSite = GlobalUniqueCallSite++; + const SDValue CallChain = CLI.Chain; + const SDValue StartChain = + DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl); + SDValue DeclareGlue = StartChain.getValue(1); + + SmallVector<SDValue, 16> CallPrereqs{StartChain}; + + const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) { + // PTX ABI requires integral types to be at least 32 bits in size. FP16 is + // loaded/stored using i16, so it's handled here as well. + const unsigned SizeBits = promoteScalarArgumentSize(Size * 8); + SDValue Declare = + DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue}, + {StartChain, Symbol, GetI32(SizeBits), DeclareGlue}); + CallPrereqs.push_back(Declare); + DeclareGlue = Declare.getValue(1); + return Declare; + }; + + const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align, + unsigned Size) { + SDValue Declare = DAG.getNode( + NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue}, + {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue}); + CallPrereqs.push_back(Declare); + DeclareGlue = Declare.getValue(1); + return Declare; + }; + // Variadic arguments. // // Normally, for each argument, we declare a param scalar or a param @@ -1508,15 +1485,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // // After all vararg is processed, 'VAOffset' holds the size of the // vararg byte array. + assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) && + "Non-VarArg function with extra arguments"); - SDValue VADeclareParam; // vararg byte array const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic - unsigned VAOffset = 0; // current offset in the param array + unsigned VAOffset = 0; // current offset in the param array - const unsigned UniqueCallSite = GlobalUniqueCallSite++; - SDValue TempChain = Chain; - Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl); - SDValue InGlue = Chain.getValue(1); + const SDValue VADeclareParam = + CLI.Args.size() > FirstVAArg + ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32), + Align(STI.getMaxRequiredAlignment()), 0) + : SDValue(); // Args.size() and Outs.size() need not match. // Outs.size() will be larger @@ -1548,15 +1527,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const SDValue ParamSymbol = getCallParamSymbol(DAG, IsVAArg ? 
FirstVAArg : ArgI, MVT::i32); - SmallVector<EVT, 16> VTs; - SmallVector<uint64_t, 16> Offsets; - assert((!IsByVal || Arg.IndirectType) && "byval arg must have indirect type"); Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty); - ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset); - assert(VTs.size() == Offsets.size() && "Size mismatch"); - assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch"); const Align ArgAlign = [&]() { if (IsByVal) { @@ -1564,202 +1537,152 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // so we don't need to worry whether it's naturally aligned or not. // See TargetLowering::LowerCallTo(). const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign(); - const Align ByValAlign = getFunctionByValParamAlign( - CB->getCalledFunction(), ETy, InitialAlign, DL); - if (IsVAArg) - VAOffset = alignTo(VAOffset, ByValAlign); - return ByValAlign; + return getFunctionByValParamAlign(CB->getCalledFunction(), ETy, + InitialAlign, DL); } return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL); }(); - const unsigned TypeSize = DL.getTypeAllocSize(ETy); - assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) && + const unsigned TySize = DL.getTypeAllocSize(ETy); + assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) && "type size mismatch"); - const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> { - if (IsVAArg) { - if (ArgI == FirstVAArg) { - VADeclareParam = DAG.getNode( - NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue}, - {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()), - GetI32(0), InGlue}); - return VADeclareParam; - } - return std::nullopt; - } - if (IsByVal || shouldPassAsArray(Arg.Ty)) { - // declare .param .align <align> .b8 .param<n>[<size>]; - return DAG.getNode(NVPTXISD::DeclareArrayParam, dl, - {MVT::Other, MVT::Glue}, - {Chain, ParamSymbol, GetI32(ArgAlign.value()), - GetI32(TypeSize), InGlue}); - } - assert(ArgOuts.size() == 1 && "We must pass only one value as non-array"); - // declare .param .b<size> .param<n>; - - // PTX ABI requires integral types to be at least 32 bits in - // size. FP16 is loaded/stored using i16, so it's handled - // here as well. - const unsigned PromotedSize = - (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) - ? promoteScalarArgumentSize(TypeSize * 8) - : TypeSize * 8; - - return DAG.getNode(NVPTXISD::DeclareScalarParam, dl, - {MVT::Other, MVT::Glue}, - {Chain, ParamSymbol, GetI32(PromotedSize), InGlue}); - }(); - if (ArgDeclare) { - Chain = ArgDeclare->getValue(0); - InGlue = ArgDeclare->getValue(1); - } + const SDValue ArgDeclare = [&]() { + if (IsVAArg) + return VADeclareParam; - // PTX Interoperability Guide 3.3(A): [Integer] Values shorter - // than 32-bits are sign extended or zero extended, depending on - // whether they are signed or unsigned types. This case applies - // only to scalar parameters and not to aggregate values. 
- const bool ExtendIntegerParam = - Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32; + if (IsByVal || shouldPassAsArray(Arg.Ty)) + return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize); - const auto GetStoredValue = [&](const unsigned I, EVT EltVT, - const Align PartAlign) { - SDValue StVal; - if (IsByVal) { - SDValue Ptr = ArgOutVals[0]; - auto MPI = refinePtrAS(Ptr, DAG, DL, *this); - SDValue SrcAddr = - DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I])); - - StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign); - } else { - StVal = ArgOutVals[I]; + assert(ArgOuts.size() == 1 && "We must pass only one value as non-array"); + assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) && + "Only int and float types are supported as non-array arguments"); - auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType()); - if (PromotedVT != StVal.getValueType()) { - StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT, - StVal); - } - } + return MakeDeclareScalarParam(ParamSymbol, TySize); + }(); - if (ExtendIntegerParam) { - assert(VTs.size() == 1 && "Scalar can't have multiple parts."); - // zext/sext to i32 - StVal = - DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, MVT::i32, StVal); - } else if (EltVT.getSizeInBits() < 16) { - // Use 16-bit registers for small stores as it's the - // smallest general purpose register size supported by NVPTX. - StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); + if (IsByVal) { + assert(ArgOutVals.size() == 1 && "We must pass only one value as byval"); + SDValue SrcPtr = ArgOutVals[0]; + const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this); + const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign(); + + if (IsVAArg) + VAOffset = alignTo(VAOffset, ArgAlign); + + SmallVector<EVT, 4> ValueVTs, MemVTs; + SmallVector<TypeSize, 4> Offsets; + ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets); + + unsigned J = 0; + const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg); + for (const unsigned NumElts : VI) { + EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx); + Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]); + SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]); + SDValue SrcLoad = + DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign); + + TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset); + Align ParamAlign = commonAlignment(ArgAlign, ParamOffset); + SDValue ParamAddr = + DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset); + SDValue StoreParam = + DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign); + CallPrereqs.push_back(StoreParam); + + J += NumElts; } - return StVal; - }; - - const auto VectorInfo = - VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); - - unsigned J = 0; - for (const unsigned NumElts : VectorInfo) { - const int CurOffset = Offsets[J]; - EVT EltVT = promoteScalarIntegerPTX(VTs[J]); - const Align PartAlign = commonAlignment(ArgAlign, CurOffset); - - // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a - // scalar store. In such cases, fall back to byte stores. - if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) { - - SDValue StVal = GetStoredValue(J, EltVT, PartAlign); - Chain = LowerUnalignedStoreParam(DAG, Chain, - CurOffset + (IsByVal ? 
VAOffset : 0), - EltVT, StVal, InGlue, ArgI, dl); + if (IsVAArg) + VAOffset += TySize; + } else { + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset); + assert(VTs.size() == Offsets.size() && "Size mismatch"); + assert(VTs.size() == ArgOuts.size() && "Size mismatch"); + + // PTX Interoperability Guide 3.3(A): [Integer] Values shorter + // than 32-bits are sign extended or zero extended, depending on + // whether they are signed or unsigned types. This case applies + // only to scalar parameters and not to aggregate values. + const bool ExtendIntegerParam = + Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32; + + const auto GetStoredValue = [&](const unsigned I) { + SDValue StVal = ArgOutVals[I]; + assert(promoteScalarIntegerPTX(StVal.getValueType()) == + StVal.getValueType() && + "OutVal type should always be legal"); + + const EVT VTI = promoteScalarIntegerPTX(VTs[I]); + const EVT StoreVT = + ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); + + return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl); + }; + + unsigned J = 0; + const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); + for (const unsigned NumElts : VI) { + const EVT EltVT = promoteScalarIntegerPTX(VTs[J]); + + unsigned Offset; + if (IsVAArg) { + // TODO: We may need to support vector types that can be passed + // as scalars in variadic arguments. + assert(NumElts == 1 && + "Vectorization should be disabled for vaargs."); + + // Align each part of the variadic argument to their type. + VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT)); + Offset = VAOffset; + + const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; + VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx)); + } else { + assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args"); + Offset = Offsets[J]; + } - // LowerUnalignedStoreParam took care of inserting the necessary nodes - // into the SDAG, so just move on to the next element. - J++; - continue; - } + SDValue Ptr = + DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset)); - if (IsVAArg && !IsByVal) - // Align each part of the variadic argument to their type. - VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT)); + const MaybeAlign CurrentAlign = ExtendIntegerParam + ? MaybeAlign(std::nullopt) + : commonAlignment(ArgAlign, Offset); - assert((IsVAArg || VAOffset == 0) && - "VAOffset must be 0 for non-VA args"); - SmallVector<SDValue, 6> StoreOperands{ - Chain, GetI32(IsVAArg ? FirstVAArg : ArgI), - GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))}; + SDValue Val = + getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) { + return GetStoredValue(J + K); + }); - // Record the values to store. - for (const unsigned K : llvm::seq(NumElts)) - StoreOperands.push_back(GetStoredValue(J + K, EltVT, PartAlign)); - StoreOperands.push_back(InGlue); + SDValue StoreParam = + DAG.getStore(ArgDeclare, dl, Val, Ptr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); + CallPrereqs.push_back(StoreParam); - NVPTXISD::NodeType Op; - switch (NumElts) { - case 1: - Op = NVPTXISD::StoreParam; - break; - case 2: - Op = NVPTXISD::StoreParamV2; - break; - case 4: - Op = NVPTXISD::StoreParamV4; - break; - default: - llvm_unreachable("Invalid vector info."); - } - // Adjust type of the store op if we've extended the scalar - // return value. - EVT TheStoreType = ExtendIntegerParam ? 
MVT::i32 : EltVT; - - Chain = DAG.getMemIntrinsicNode( - Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands, - TheStoreType, MachinePointerInfo(), PartAlign, - MachineMemOperand::MOStore); - InGlue = Chain.getValue(1); - - // TODO: We may need to support vector types that can be passed - // as scalars in variadic arguments. - if (IsVAArg && !IsByVal) { - assert(NumElts == 1 && - "Vectorization is expected to be disabled for variadics."); - VAOffset += - DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext())); + J += NumElts; } - - J += NumElts; } - if (IsVAArg && IsByVal) - VAOffset += TypeSize; } - GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); - // Handle Result if (!Ins.empty()) { - const SDValue RetDeclare = [&]() { - const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); - const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy); - if (shouldPassAsArray(RetTy)) { - const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); - return DAG.getNode(NVPTXISD::DeclareArrayParam, dl, - {MVT::Other, MVT::Glue}, - {Chain, RetSymbol, GetI32(RetAlign.value()), - GetI32(ResultSize / 8), InGlue}); - } - const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize); - return DAG.getNode( - NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue}, - {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue}); - }(); - Chain = RetDeclare.getValue(0); - InGlue = RetDeclare.getValue(1); + const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); + const unsigned ResultSize = DL.getTypeAllocSize(RetTy); + if (shouldPassAsArray(RetTy)) { + const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); + MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize); + } else { + MakeDeclareScalarParam(RetSymbol, ResultSize); + } } - const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs); // Set the size of the vararg param byte array if the callee is a variadic // function and the variadic part is not empty. - if (HasVAArgs) { + if (VADeclareParam) { SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0), VADeclareParam.getOperand(1), VADeclareParam.getOperand(2), GetI32(VAOffset), @@ -1768,6 +1691,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, VADeclareParam->getVTList(), DeclareParamOps); } + const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); // If the type of the callsite does not match that of the function, convert // the callsite to an indirect call. const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func); @@ -1797,57 +1721,39 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // instruction. // The prototype is embedded in a string and put as the operand for a // CallPrototype SDNode which will print out to the value of the string. + const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs); std::string Proto = getPrototype(DL, RetTy, Args, CLI.Outs, HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB, UniqueCallSite); const char *ProtoStr = nvTM->getStrPool().save(Proto).data(); - Chain = DAG.getNode( - NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue}, - {Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue}); - InGlue = Chain.getValue(1); - } - - if (ConvertToIndirectCall) { - // Copy the function ptr to a ptx register and use the register to call the - // function. 
- const MVT DestVT = Callee.getValueType().getSimpleVT(); - MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT)); - auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee); - Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT); + const SDValue PrototypeDeclare = DAG.getNode( + NVPTXISD::CallPrototype, dl, MVT::Other, + {StartChain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32)}); + CallPrereqs.push_back(PrototypeDeclare); } const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0; const unsigned NumArgs = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size()); /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, - /// NumParams, Callee, Proto, InGlue) - Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue}, - {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall), - GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, - GetI32(Proto), InGlue}); - InGlue = Chain.getValue(1); - + /// NumParams, Callee, Proto) + const SDValue CallToken = DAG.getTokenFactor(dl, CallPrereqs); + const SDValue Call = DAG.getNode( + NVPTXISD::CALL, dl, MVT::Other, + {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall), + GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)}); + + SmallVector<SDValue, 16> LoadChains{Call}; SmallVector<SDValue, 16> ProxyRegOps; - // An item of the vector is filled if the element does not need a ProxyReg - // operation on it and should be added to InVals as is. ProxyRegOps and - // ProxyRegTruncates contain empty/none items at the same index. - SmallVector<SDValue, 16> RetElts; - // A temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()` - // to use the values of `LoadParam`s and to be replaced later then - // `CALLSEQ_END` is added. - SmallVector<SDValue, 16> TempProxyRegOps; - - // Generate loads from param memory/moves from registers for result if (!Ins.empty()) { SmallVector<EVT, 16> VTs; SmallVector<uint64_t, 16> Offsets; - ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0); + ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); assert(VTs.size() == Ins.size() && "Bad value decomposition"); const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); + const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than // 32-bits are sign extended or zero extended, depending on whether @@ -1855,106 +1761,49 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const bool ExtendIntegerRetVal = RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; - const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); unsigned I = 0; - for (const unsigned VectorizedSize : VectorInfo) { - EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]); - EVT EltType = Ins[I].VT; - const Align EltAlign = commonAlignment(RetAlign, Offsets[I]); - - if (TheLoadType != VTs[I]) - EltType = TheLoadType; - - if (ExtendIntegerRetVal) { - TheLoadType = MVT::i32; - EltType = MVT::i32; - } else if (TheLoadType.getSizeInBits() < 16) { - EltType = MVT::i16; - } - - // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a - // scalar load. In such cases, fall back to byte loads. 
- if (VectorizedSize == 1 && RetTy->isAggregateType() && - EltAlign < DAG.getEVTAlign(TheLoadType)) { - SDValue Ret = LowerUnalignedLoadRetParam( - DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl); - ProxyRegOps.push_back(SDValue()); - RetElts.resize(I); - RetElts.push_back(Ret); - - I++; - continue; - } - - SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType); - LoadVTs.append({MVT::Other, MVT::Glue}); - - NVPTXISD::NodeType Op; - switch (VectorizedSize) { - case 1: - Op = NVPTXISD::LoadParam; - break; - case 2: - Op = NVPTXISD::LoadParamV2; - break; - case 4: - Op = NVPTXISD::LoadParamV4; - break; - default: - llvm_unreachable("Invalid vector info."); - } - - SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue}; - SDValue RetVal = DAG.getMemIntrinsicNode( - Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType, - MachinePointerInfo(), EltAlign, MachineMemOperand::MOLoad); - - for (const unsigned J : llvm::seq(VectorizedSize)) { - ProxyRegOps.push_back(RetVal.getValue(J)); - } - - Chain = RetVal.getValue(VectorizedSize); - InGlue = RetVal.getValue(VectorizedSize + 1); - - I += VectorizedSize; + const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign); + for (const unsigned NumElts : VI) { + const MaybeAlign CurrentAlign = + ExtendIntegerRetVal ? MaybeAlign(std::nullopt) + : commonAlignment(RetAlign, Offsets[I]); + + const EVT VTI = promoteScalarIntegerPTX(VTs[I]); + const EVT LoadVT = + ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); + const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx); + SDValue Ptr = + DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I])); + + SDValue R = + DAG.getLoad(VecVT, dl, Call, Ptr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); + + LoadChains.push_back(R.getValue(1)); + for (const unsigned J : llvm::seq(NumElts)) + ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG)); + I += NumElts; } } - Chain = - DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl); - InGlue = Chain.getValue(1); + const SDValue EndToken = DAG.getTokenFactor(dl, LoadChains); + const SDValue CallEnd = DAG.getCALLSEQ_END(EndToken, UniqueCallSite, + UniqueCallSite + 1, SDValue(), dl); // Append ProxyReg instructions to the chain to make sure that `callseq_end` // will not get lost. Otherwise, during libcalls expansion, the nodes can become // dangling. 
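// [Editor's sketch -- simplified DAG shape, not part of the patch] For a call
// returning a single f32, the chains built in the new lowering above line up
// roughly as:
//
//   t0: callseq_start
//   t1: DeclareScalarParam(t0, retval0, 32)        ; pushed on CallPrereqs
//   t2: NVPTXISD::CALL(TokenFactor(t0, t1, ...))   ; CallToken
//   t3: load<f32>(t2, retval0+0)                   ; chained after the call
//   t4: callseq_end(TokenFactor(t2, t3.chain))     ; EndToken over LoadChains
//   t5: ProxyReg(t4, t3)                           ; value handed to InVals
//
// Every use of the return value goes through t5 and therefore keeps t4 -- and
// with it the whole call sequence -- alive, even when the call is introduced
// late by libcall legalization.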
- for (const unsigned I : llvm::seq(ProxyRegOps.size())) { - if (I < RetElts.size() && RetElts[I]) { - InVals.push_back(RetElts[I]); - continue; - } - - SDValue Ret = - DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(), - {Chain, ProxyRegOps[I]}); - - const EVT ExpectedVT = Ins[I].VT; - if (!Ret.getValueType().bitsEq(ExpectedVT)) { - Ret = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Ret); - } + for (const auto [I, Reg] : llvm::enumerate(ProxyRegOps)) { + SDValue Proxy = + DAG.getNode(NVPTXISD::ProxyReg, dl, Reg.getValueType(), {CallEnd, Reg}); + SDValue Ret = correctParamType(Proxy, Ins[I].VT, Ins[I].Flags, DAG, dl); InVals.push_back(Ret); } - for (SDValue &T : TempProxyRegOps) { - SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(), - {Chain, T.getOperand(0)}); - DAG.ReplaceAllUsesWith(T, Repl); - DAG.RemoveDeadNode(T.getNode()); - } - - // set isTailCall to false for now, until we figure out how to express + // set IsTailCall to false for now, until we figure out how to express // tail call optimization in PTX - isTailCall = false; - return Chain; + CLI.IsTailCall = false; + return CallEnd; } SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, @@ -3407,11 +3256,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const { - MachineFunction &MF = DAG.getMachineFunction(); const DataLayout &DL = DAG.getDataLayout(); auto PtrVT = getPointerTy(DAG.getDataLayout()); - const Function *F = &MF.getFunction(); + const Function &F = DAG.getMachineFunction().getFunction(); SDValue Root = DAG.getRoot(); SmallVector<SDValue, 16> OutChains; @@ -3427,7 +3275,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // See similar issue in LowerCall. auto AllIns = ArrayRef(Ins); - for (const auto &Arg : F->args()) { + for (const auto &Arg : F.args()) { const auto ArgIns = AllIns.take_while( [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); }); AllIns = AllIns.drop_front(ArgIns.size()); @@ -3467,7 +3315,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer"); SDValue P; - if (isKernelFunction(*F)) { + if (isKernelFunction(F)) { P = ArgSymbol; P.getNode()->setIROrder(Arg.getArgNo() + 1); } else { @@ -3485,43 +3333,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( assert(VTs.size() == Offsets.size() && "Size mismatch"); const Align ArgAlign = getFunctionArgumentAlignment( - F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); + &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); - const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); unsigned I = 0; - for (const unsigned NumElts : VectorInfo) { + const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); + for (const unsigned NumElts : VI) { // i1 is loaded/stored as i8 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I]; - // If the element is a packed type (ex. v2f16, v4i8, etc) holding - // multiple elements. - const unsigned PackingAmt = - LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1; - - const EVT VecVT = - NumElts == 1 - ? 
LoadVT - : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(), - NumElts * PackingAmt); + const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext()); SDValue VecAddr = DAG.getObjectPtrOffset( dl, ArgSymbol, TypeSize::getFixed(Offsets[I])); - const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]); + const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]); SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr, MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign, MachineMemOperand::MODereferenceable | MachineMemOperand::MOInvariant); - if (P.getNode()) - P.getNode()->setIROrder(Arg.getArgNo() + 1); + P.getNode()->setIROrder(Arg.getArgNo() + 1); for (const unsigned J : llvm::seq(NumElts)) { - SDValue Elt = - NumElts == 1 - ? P - : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR - : ISD::EXTRACT_VECTOR_ELT, - dl, LoadVT, P, - DAG.getVectorIdxConstant(J * PackingAmt, dl)); + SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG); Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags, DAG, dl); @@ -3544,9 +3376,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, const SmallVectorImpl<ISD::OutputArg> &Outs, const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl, SelectionDAG &DAG) const { - const MachineFunction &MF = DAG.getMachineFunction(); - const Function &F = MF.getFunction(); - Type *RetTy = MF.getFunction().getReturnType(); + const Function &F = DAG.getMachineFunction().getFunction(); + Type *RetTy = F.getReturnType(); if (RetTy->isVoidTy()) { assert(OutVals.empty() && Outs.empty() && "Return value expected for void"); @@ -3554,10 +3385,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, } const DataLayout &DL = DAG.getDataLayout(); - SmallVector<EVT, 16> VTs; - SmallVector<uint64_t, 16> Offsets; - ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); - assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); + + const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32); + const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL); // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than // 32-bits are sign extended or zero extended, depending on whether @@ -3565,6 +3395,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, const bool ExtendIntegerRetVal = RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); + assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); + const auto GetRetVal = [&](unsigned I) -> SDValue { SDValue RetVal = OutVals[I]; assert(promoteScalarIntegerPTX(RetVal.getValueType()) == @@ -3577,33 +3412,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl); }; - const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL); - const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); unsigned I = 0; - for (const unsigned NumElts : VectorInfo) { + const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign); + for (const unsigned NumElts : VI) { const MaybeAlign CurrentAlign = ExtendIntegerRetVal ? 
MaybeAlign(std::nullopt) : commonAlignment(RetAlign, Offsets[I]); - SDValue Val; - if (NumElts == 1) { - Val = GetRetVal(I); - } else { - SmallVector<SDValue, 4> StoreVals; - for (const unsigned J : llvm::seq(NumElts)) { - SDValue ValJ = GetRetVal(I + J); - if (ValJ.getValueType().isVector()) - DAG.ExtractVectorElements(ValJ, StoreVals); - else - StoreVals.push_back(ValJ); - } - - EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(), - StoreVals.size()); - Val = DAG.getBuildVector(VT, dl, StoreVals); - } + SDValue Val = getBuildVectorizedValue( + NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); }); - const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32); SDValue Ptr = DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I])); @@ -5097,7 +4915,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { return SDValue(); auto *LD = cast<MemSDNode>(N); - EVT MemVT = LD->getMemoryVT(); SDLoc DL(LD); // the new opcode after we double the number of operands @@ -5114,10 +4931,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { Operands.push_back(DCI.DAG.getIntPtrConstant( cast<LoadSDNode>(LD)->getExtensionType(), DL)); break; - case NVPTXISD::LoadParamV2: - OldNumOutputs = 2; - Opcode = NVPTXISD::LoadParamV4; - break; case NVPTXISD::LoadV2: OldNumOutputs = 2; Opcode = NVPTXISD::LoadV4; @@ -5142,9 +4955,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end()); // Create the new load - SDValue NewLoad = - DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs), - Operands, MemVT, LD->getMemOperand()); + SDValue NewLoad = DCI.DAG.getMemIntrinsicNode( + Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(), + LD->getMemOperand()); // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep // the outputs the same. These nodes will be optimized away in later @@ -5186,7 +4999,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N, return SDValue(); auto *ST = cast<MemSDNode>(N); - EVT MemVT = ElementVT.getVectorElementType(); // The new opcode after we double the number of operands. NVPTXISD::NodeType Opcode; @@ -5195,17 +5007,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N, // Any packed type is legal, so the legalizer will not have lowered // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do // it here. 
- MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV2; break; - case NVPTXISD::StoreParam: - Opcode = NVPTXISD::StoreParamV2; - break; - case NVPTXISD::StoreParamV2: - Opcode = NVPTXISD::StoreParamV4; - break; case NVPTXISD::StoreV2: - MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV4; break; case NVPTXISD::StoreV4: @@ -5215,7 +5019,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N, return SDValue(); Opcode = NVPTXISD::StoreV8; break; - case NVPTXISD::StoreParamV4: case NVPTXISD::StoreV8: // PTX doesn't support the next doubling of operands return SDValue(); @@ -5257,19 +5060,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N, // Now we replace the store return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands, - MemVT, ST->getMemOperand()); -} - -static SDValue PerformStoreCombineHelper(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - unsigned Front, unsigned Back) { - if (all_of(N->ops().drop_front(Front).drop_back(Back), - [](const SDUse &U) { return U.get()->isUndef(); })) - // Operand 0 is the previous value in the chain. Cannot return EntryToken - // as the previous value will become unused and eliminated later. - return N->getOperand(0); - - return combinePackingMovIntoStore(N, DCI, Front, Back); + ST->getMemoryVT(), ST->getMemOperand()); } static SDValue PerformStoreCombine(SDNode *N, @@ -5277,13 +5068,6 @@ static SDValue PerformStoreCombine(SDNode *N, return combinePackingMovIntoStore(N, DCI, 1, 2); } -static SDValue PerformStoreParamCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { - // Operands from the 3rd to the 2nd last one are the values to be stored. - // {Chain, ArgID, Offset, Val, Glue} - return PerformStoreCombineHelper(N, DCI, 3, 1); -} - /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD. 
/// static SDValue PerformADDCombine(SDNode *N, @@ -5429,6 +5213,42 @@ static SDValue PerformREMCombine(SDNode *N, return SDValue(); } +// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y) +static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + CodeGenOptLevel OptLevel) { + if (OptLevel == CodeGenOptLevel::None) + return SDValue(); + + SDValue Op = N->getOperand(0); + if (!Op.hasOneUse()) + return SDValue(); + EVT ToVT = N->getValueType(0); + EVT FromVT = Op.getValueType(); + if (!((ToVT == MVT::i32 && FromVT == MVT::i16) || + (ToVT == MVT::i64 && FromVT == MVT::i32))) + return SDValue(); + if (!(Op.getOpcode() == ISD::MUL || + (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1))))) + return SDValue(); + + SDLoc DL(N); + unsigned ExtOpcode = N->getOpcode(); + unsigned Opcode = 0; + if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap()) + Opcode = NVPTXISD::MUL_WIDE_SIGNED; + else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap()) + Opcode = NVPTXISD::MUL_WIDE_UNSIGNED; + else + return SDValue(); + SDValue RHS = Op.getOperand(1); + if (Op.getOpcode() == ISD::SHL) { + const auto ShiftAmt = Op.getConstantOperandVal(1); + const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt; + RHS = DCI.DAG.getConstant(MulVal, DL, ToVT); + } + return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS); +} + enum OperandSignedness { Signed = 0, Unsigned, @@ -5939,6 +5759,86 @@ static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, N->getConstantOperandAPInt(2), N->getConstantOperandVal(3)), SDLoc(N), N->getValueType(0)); + return SDValue(); +} + +// During call lowering we wrap the return values in a ProxyReg node which +// depend on the chain value produced by the completed call. This ensures that +// the full call is emitted in cases where libcalls are used to legalize +// operations. To improve the functioning of other DAG combines we pull all +// operations we can through one of these nodes, ensuring that the ProxyReg +// directly wraps a load. 
That is: +// +// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0))) +// +static SDValue sinkProxyReg(SDValue R, SDValue Chain, + TargetLowering::DAGCombinerInfo &DCI) { + switch (R.getOpcode()) { + case ISD::TRUNCATE: + case ISD::ANY_EXTEND: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + case ISD::BITCAST: { + if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI)) + return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), V); + return SDValue(); + } + case ISD::SHL: + case ISD::SRL: + case ISD::SRA: + case ISD::OR: { + if (SDValue A = sinkProxyReg(R.getOperand(0), Chain, DCI)) + if (SDValue B = sinkProxyReg(R.getOperand(1), Chain, DCI)) + return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), A, B); + return SDValue(); + } + case ISD::Constant: + return R; + case ISD::LOAD: + case NVPTXISD::LoadV2: + case NVPTXISD::LoadV4: { + return DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(R), R.getValueType(), + {Chain, R}); + } + case ISD::BUILD_VECTOR: { + if (DCI.isBeforeLegalize()) + return SDValue(); + + SmallVector<SDValue, 16> Ops; + for (auto &Op : R->ops()) { + SDValue V = sinkProxyReg(Op, Chain, DCI); + if (!V) + return SDValue(); + Ops.push_back(V); + } + return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops); + } + case ISD::EXTRACT_VECTOR_ELT: { + if (DCI.isBeforeLegalize()) + return SDValue(); + + if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI)) + return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R), + R.getValueType(), V, R.getOperand(1)); + return SDValue(); + } + default: + return SDValue(); + } +} + +static SDValue combineProxyReg(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + + SDValue Chain = N->getOperand(0); + SDValue Reg = N->getOperand(1); + + // If the ProxyReg is not wrapping a load, try to pull the operations through + // the ProxyReg. 
+ if (Reg.getOpcode() != ISD::LOAD) { + if (SDValue V = sinkProxyReg(Reg, Chain, DCI)) + return V; + } return SDValue(); } @@ -5955,6 +5855,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return combineADDRSPACECAST(N, DCI); case ISD::AND: return PerformANDCombine(N, DCI); + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + return combineMulWide(N, DCI, OptLevel); case ISD::BUILD_VECTOR: return PerformBUILD_VECTORCombine(N, DCI); case ISD::EXTRACT_VECTOR_ELT: @@ -5962,7 +5865,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, case ISD::FADD: return PerformFADDCombine(N, DCI, OptLevel); case ISD::LOAD: - case NVPTXISD::LoadParamV2: case NVPTXISD::LoadV2: case NVPTXISD::LoadV4: return combineUnpackingMovIntoLoad(N, DCI); @@ -5970,6 +5872,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return PerformMULCombine(N, DCI, OptLevel); case NVPTXISD::PRMT: return combinePRMT(N, DCI, OptLevel); + case NVPTXISD::ProxyReg: + return combineProxyReg(N, DCI); case ISD::SETCC: return PerformSETCCCombine(N, DCI, STI.getSmVersion()); case ISD::SHL: @@ -5977,10 +5881,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, case ISD::SREM: case ISD::UREM: return PerformREMCombine(N, DCI, OptLevel); - case NVPTXISD::StoreParam: - case NVPTXISD::StoreParamV2: - case NVPTXISD::StoreParamV4: - return PerformStoreParamCombine(N, DCI); case ISD::STORE: case NVPTXISD::StoreV2: case NVPTXISD::StoreV4: @@ -6329,6 +6229,22 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, Results.push_back(NewValue.getValue(3)); } +static void replaceProxyReg(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI, + SmallVectorImpl<SDValue> &Results) { + SDValue Chain = N->getOperand(0); + SDValue Reg = N->getOperand(1); + + MVT VT = TLI.getRegisterType(*DAG.getContext(), Reg.getValueType()); + + SDValue NewReg = DAG.getAnyExtOrTrunc(Reg, SDLoc(N), VT); + SDValue NewProxy = + DAG.getNode(NVPTXISD::ProxyReg, SDLoc(N), VT, {Chain, NewReg}); + SDValue Res = DAG.getAnyExtOrTrunc(NewProxy, SDLoc(N), N->getValueType(0)); + + Results.push_back(Res); +} + void NVPTXTargetLowering::ReplaceNodeResults( SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { switch (N->getOpcode()) { @@ -6346,6 +6262,9 @@ void NVPTXTargetLowering::ReplaceNodeResults( case ISD::CopyFromReg: ReplaceCopyFromReg_128(N, DAG, Results); return; + case NVPTXISD::ProxyReg: + replaceProxyReg(N, DAG, *this, Results); + return; } } |
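// [Editor's sketch -- illustration, not part of the patch] Why combineMulWide
// (added above) requires the nuw/nsw flag before folding
// (zext/sext (mul x, y)) into mul.wide: the fold is only sound when the
// narrow multiply cannot wrap. A standalone model of the unsigned case:

#include <cassert>
#include <cstdint>

// What PTX mul.wide.u32 computes: widen both operands first, then multiply.
uint64_t mulWideU32(uint32_t A, uint32_t B) {
  return uint64_t(A) * uint64_t(B);
}

int main() {
  // No unsigned wrap: (zext (mul x, y)) equals mul.wide(x, y).
  uint32_t A = 70000, B = 60000; // product 4'200'000'000 fits in 32 bits
  assert(uint64_t(A * B) == mulWideU32(A, B));

  // If the narrow multiply wraps, the two differ -- which is why the combine
  // bails out unless the ISD::MUL node carries nuw (or nsw for sign_extend).
  uint32_t C = 1u << 20, D = 1u << 13; // 2^33 wraps to 0 in 32 bits
  assert(uint64_t(C * D) != mulWideU32(C, D));
}

// The shl-by-constant form is handled the same way: the combine rewrites the
// shift amount into a (1 << amount) multiplier before emitting mul.wide.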