aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp953
1 files changed, 436 insertions, 517 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f2c2f46..15f45a1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -382,6 +382,54 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
}
+// We return an EVT that can hold N VTs
+// If the VT is a vector, the resulting EVT is a flat vector with the same
+// element type as VT's element type.
+static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
+ if (N == 1)
+ return VT;
+
+ return VT.isVector() ? EVT::getVectorVT(C, VT.getScalarType(),
+ VT.getVectorNumElements() * N)
+ : EVT::getVectorVT(C, VT, N);
+}
+
+static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
+ const SDLoc &dl, SelectionDAG &DAG) {
+ if (V.getValueType() == VT) {
+ assert(I == 0 && "Index must be 0 for scalar value");
+ return V;
+ }
+
+ if (!VT.isVector())
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, V,
+ DAG.getVectorIdxConstant(I, dl));
+
+ return DAG.getNode(
+ ISD::EXTRACT_SUBVECTOR, dl, VT, V,
+ DAG.getVectorIdxConstant(I * VT.getVectorNumElements(), dl));
+}
+
+template <typename T>
+static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
+ SelectionDAG &DAG, T GetElement) {
+ if (N == 1)
+ return GetElement(0);
+
+ SmallVector<SDValue, 8> Values;
+ for (const unsigned I : llvm::seq(N)) {
+ SDValue Val = GetElement(I);
+ if (Val.getValueType().isVector())
+ DAG.ExtractVectorElements(Val, Values);
+ else
+ Values.push_back(Val);
+ }
+
+ EVT VT = EVT::getVectorVT(*DAG.getContext(), Values[0].getValueType(),
+ Values.size());
+ return DAG.getBuildVector(VT, dl, Values);
+}
+
/// PromoteScalarIntegerPTX
/// Used to make sure the arguments/returns are suitable for passing
/// and promote them to a larger size if they're not.
@@ -420,9 +468,10 @@ static EVT promoteScalarIntegerPTX(const EVT VT) {
// parameter starting at index Idx using a single vectorized op of
// size AccessSize. If so, it returns the number of param pieces
// covered by the vector op. Otherwise, it returns 1.
-static unsigned CanMergeParamLoadStoresStartingAt(
+template <typename T>
+static unsigned canMergeParamLoadStoresStartingAt(
unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
- const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
+ const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
// Can't vectorize if param alignment is not sufficient.
if (ParamAlignment < AccessSize)
@@ -472,10 +521,11 @@ static unsigned CanMergeParamLoadStoresStartingAt(
// of the same size as ValueVTs indicating how each piece should be
// loaded/stored (i.e. as a scalar, or as part of a vector
// load/store).
+template <typename T>
static SmallVector<unsigned, 16>
VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
- const SmallVectorImpl<uint64_t> &Offsets,
- Align ParamAlignment, bool IsVAArg = false) {
+ const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
+ bool IsVAArg = false) {
// Set vector size to match ValueVTs and mark all elements as
// scalars by default.
@@ -486,7 +536,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
const auto GetNumElts = [&](unsigned I) -> unsigned {
for (const unsigned AccessSize : {16, 8, 4, 2}) {
- const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+ const unsigned NumElts = canMergeParamLoadStoresStartingAt(
I, AccessSize, ValueVTs, Offsets, ParamAlignment);
assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
"Unexpected vectorization size");
@@ -843,7 +893,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE});
+ ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -952,10 +1002,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// promoted to f32. v2f16 is expanded to f16, which is then promoted
// to f32.
for (const auto &Op :
- {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
+ {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
setOperationAction(Op, MVT::f16, Promote);
setOperationAction(Op, MVT::f32, Legal);
- setOperationAction(Op, MVT::f64, Legal);
+ // only div/rem/sqrt are legal for f64
+ if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
+ setOperationAction(Op, MVT::f64, Legal);
+ }
setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand);
setOperationAction(Op, MVT::bf16, Promote);
AddPromotedToType(Op, MVT::bf16, MVT::f32);
@@ -1072,12 +1125,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::DeclareArrayParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
MAKE_CASE(NVPTXISD::CALL)
- MAKE_CASE(NVPTXISD::LoadParam)
- MAKE_CASE(NVPTXISD::LoadParamV2)
- MAKE_CASE(NVPTXISD::LoadParamV4)
- MAKE_CASE(NVPTXISD::StoreParam)
- MAKE_CASE(NVPTXISD::StoreParamV2)
- MAKE_CASE(NVPTXISD::StoreParamV4)
MAKE_CASE(NVPTXISD::MoveParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
@@ -1315,105 +1362,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
return DL.getABITypeAlign(Ty);
}
-static bool adjustElementType(EVT &ElementType) {
- switch (ElementType.getSimpleVT().SimpleTy) {
- default:
- return false;
- case MVT::f16:
- case MVT::bf16:
- ElementType = MVT::i16;
- return true;
- case MVT::f32:
- case MVT::v2f16:
- case MVT::v2bf16:
- ElementType = MVT::i32;
- return true;
- case MVT::f64:
- ElementType = MVT::i64;
- return true;
- }
-}
-
-// Use byte-store when the param address of the argument value is unaligned.
-// This may happen when the return value is a field of a packed structure.
-//
-// This is called in LowerCall() when passing the param values.
-static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
- uint64_t Offset, EVT ElementType,
- SDValue StVal, SDValue &InGlue,
- unsigned ArgID, const SDLoc &dl) {
- // Bit logic only works on integer types
- if (adjustElementType(ElementType))
- StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
-
- // Store each byte
- SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
- // Shift the byte to the last byte position
- SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
- DAG.getConstant(i * 8, dl, MVT::i32));
- SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
- DAG.getConstant(Offset + i, dl, MVT::i32),
- ShiftVal, InGlue};
- // Trunc store only the last byte by using
- // st.param.b8
- // The register type can be larger than b8.
- Chain = DAG.getMemIntrinsicNode(
- NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
- MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
- InGlue = Chain.getValue(1);
- }
- return Chain;
-}
-
-// Use byte-load when the param adress of the returned value is unaligned.
-// This may happen when the returned value is a field of a packed structure.
-static SDValue
-LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
- EVT ElementType, SDValue &InGlue,
- SmallVectorImpl<SDValue> &TempProxyRegOps,
- const SDLoc &dl) {
- // Bit logic only works on integer types
- EVT MergedType = ElementType;
- adjustElementType(MergedType);
-
- // Load each byte and construct the whole value. Initial value to 0
- SDValue RetVal = DAG.getConstant(0, dl, MergedType);
- // LoadParamMemI8 loads into i16 register only
- SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
- for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
- SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(Offset + i, dl, MVT::i32),
- InGlue};
- // This will be selected to LoadParamMemI8
- SDValue LdVal =
- DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
- MVT::i8, MachinePointerInfo(), Align(1));
- SDValue TmpLdVal = LdVal.getValue(0);
- Chain = LdVal.getValue(1);
- InGlue = LdVal.getValue(2);
-
- TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
- TmpLdVal.getSimpleValueType(), TmpLdVal);
- TempProxyRegOps.push_back(TmpLdVal);
-
- SDValue CMask = DAG.getConstant(255, dl, MergedType);
- SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
- // Need to extend the i16 register to the whole width.
- TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
- // Mask off the high bits. Leave only the lower 8bits.
- // Do this because we are using loadparam.b8.
- TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
- // Shift and merge
- TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
- RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
- }
- if (ElementType != MergedType)
- RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
- return RetVal;
-}
-
static bool shouldConvertToIndirectCall(const CallBase *CB,
const GlobalAddressSDNode *Func) {
if (!Func)
@@ -1480,19 +1428,48 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SelectionDAG &DAG = CLI.DAG;
SDLoc dl = CLI.DL;
- SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
- SDValue Chain = CLI.Chain;
+ const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
SDValue Callee = CLI.Callee;
- bool &isTailCall = CLI.IsTailCall;
ArgListTy &Args = CLI.getArgs();
Type *RetTy = CLI.RetTy;
const CallBase *CB = CLI.CB;
const DataLayout &DL = DAG.getDataLayout();
+ LLVMContext &Ctx = *DAG.getContext();
const auto GetI32 = [&](const unsigned I) {
return DAG.getConstant(I, dl, MVT::i32);
};
+ const unsigned UniqueCallSite = GlobalUniqueCallSite++;
+ const SDValue CallChain = CLI.Chain;
+ const SDValue StartChain =
+ DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
+ SDValue DeclareGlue = StartChain.getValue(1);
+
+ SmallVector<SDValue, 16> CallPrereqs{StartChain};
+
+ const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
+ // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
+ // loaded/stored using i16, so it's handled here as well.
+ const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
+ SDValue Declare =
+ DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+ {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
+ CallPrereqs.push_back(Declare);
+ DeclareGlue = Declare.getValue(1);
+ return Declare;
+ };
+
+ const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
+ unsigned Size) {
+ SDValue Declare = DAG.getNode(
+ NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+ {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
+ CallPrereqs.push_back(Declare);
+ DeclareGlue = Declare.getValue(1);
+ return Declare;
+ };
+
// Variadic arguments.
//
// Normally, for each argument, we declare a param scalar or a param
@@ -1508,15 +1485,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
//
// After all vararg is processed, 'VAOffset' holds the size of the
// vararg byte array.
+ assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
+ "Non-VarArg function with extra arguments");
- SDValue VADeclareParam; // vararg byte array
const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
- unsigned VAOffset = 0; // current offset in the param array
+ unsigned VAOffset = 0; // current offset in the param array
- const unsigned UniqueCallSite = GlobalUniqueCallSite++;
- SDValue TempChain = Chain;
- Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
- SDValue InGlue = Chain.getValue(1);
+ const SDValue VADeclareParam =
+ CLI.Args.size() > FirstVAArg
+ ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
+ Align(STI.getMaxRequiredAlignment()), 0)
+ : SDValue();
// Args.size() and Outs.size() need not match.
// Outs.size() will be larger
@@ -1548,15 +1527,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const SDValue ParamSymbol =
getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
-
assert((!IsByVal || Arg.IndirectType) &&
"byval arg must have indirect type");
Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
- ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
- assert(VTs.size() == Offsets.size() && "Size mismatch");
- assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
const Align ArgAlign = [&]() {
if (IsByVal) {
@@ -1564,202 +1537,152 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// so we don't need to worry whether it's naturally aligned or not.
// See TargetLowering::LowerCallTo().
const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- const Align ByValAlign = getFunctionByValParamAlign(
- CB->getCalledFunction(), ETy, InitialAlign, DL);
- if (IsVAArg)
- VAOffset = alignTo(VAOffset, ByValAlign);
- return ByValAlign;
+ return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+ InitialAlign, DL);
}
return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
}();
- const unsigned TypeSize = DL.getTypeAllocSize(ETy);
- assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+ const unsigned TySize = DL.getTypeAllocSize(ETy);
+ assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
"type size mismatch");
- const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
- if (IsVAArg) {
- if (ArgI == FirstVAArg) {
- VADeclareParam = DAG.getNode(
- NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
- {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
- GetI32(0), InGlue});
- return VADeclareParam;
- }
- return std::nullopt;
- }
- if (IsByVal || shouldPassAsArray(Arg.Ty)) {
- // declare .param .align <align> .b8 .param<n>[<size>];
- return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
- {MVT::Other, MVT::Glue},
- {Chain, ParamSymbol, GetI32(ArgAlign.value()),
- GetI32(TypeSize), InGlue});
- }
- assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
- // declare .param .b<size> .param<n>;
-
- // PTX ABI requires integral types to be at least 32 bits in
- // size. FP16 is loaded/stored using i16, so it's handled
- // here as well.
- const unsigned PromotedSize =
- (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
- ? promoteScalarArgumentSize(TypeSize * 8)
- : TypeSize * 8;
-
- return DAG.getNode(NVPTXISD::DeclareScalarParam, dl,
- {MVT::Other, MVT::Glue},
- {Chain, ParamSymbol, GetI32(PromotedSize), InGlue});
- }();
- if (ArgDeclare) {
- Chain = ArgDeclare->getValue(0);
- InGlue = ArgDeclare->getValue(1);
- }
+ const SDValue ArgDeclare = [&]() {
+ if (IsVAArg)
+ return VADeclareParam;
- // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
- // than 32-bits are sign extended or zero extended, depending on
- // whether they are signed or unsigned types. This case applies
- // only to scalar parameters and not to aggregate values.
- const bool ExtendIntegerParam =
- Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+ if (IsByVal || shouldPassAsArray(Arg.Ty))
+ return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
- const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
- const Align PartAlign) {
- SDValue StVal;
- if (IsByVal) {
- SDValue Ptr = ArgOutVals[0];
- auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
- SDValue SrcAddr =
- DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
-
- StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
- } else {
- StVal = ArgOutVals[I];
+ assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
+ assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
+ "Only int and float types are supported as non-array arguments");
- auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType());
- if (PromotedVT != StVal.getValueType()) {
- StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT,
- StVal);
- }
- }
+ return MakeDeclareScalarParam(ParamSymbol, TySize);
+ }();
- if (ExtendIntegerParam) {
- assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
- // zext/sext to i32
- StVal =
- DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, MVT::i32, StVal);
- } else if (EltVT.getSizeInBits() < 16) {
- // Use 16-bit registers for small stores as it's the
- // smallest general purpose register size supported by NVPTX.
- StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
+ if (IsByVal) {
+ assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+ SDValue SrcPtr = ArgOutVals[0];
+ const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
+ const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+
+ if (IsVAArg)
+ VAOffset = alignTo(VAOffset, ArgAlign);
+
+ SmallVector<EVT, 4> ValueVTs, MemVTs;
+ SmallVector<TypeSize, 4> Offsets;
+ ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
+
+ unsigned J = 0;
+ const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+ for (const unsigned NumElts : VI) {
+ EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
+ Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
+ SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+ SDValue SrcLoad =
+ DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
+
+ TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+ Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+ SDValue ParamAddr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
+ SDValue StoreParam =
+ DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+ CallPrereqs.push_back(StoreParam);
+
+ J += NumElts;
}
- return StVal;
- };
-
- const auto VectorInfo =
- VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-
- unsigned J = 0;
- for (const unsigned NumElts : VectorInfo) {
- const int CurOffset = Offsets[J];
- EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
- const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
-
- // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
- // scalar store. In such cases, fall back to byte stores.
- if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
-
- SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
- Chain = LowerUnalignedStoreParam(DAG, Chain,
- CurOffset + (IsByVal ? VAOffset : 0),
- EltVT, StVal, InGlue, ArgI, dl);
+ if (IsVAArg)
+ VAOffset += TySize;
+ } else {
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+ assert(VTs.size() == Offsets.size() && "Size mismatch");
+ assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+ // than 32-bits are sign extended or zero extended, depending on
+ // whether they are signed or unsigned types. This case applies
+ // only to scalar parameters and not to aggregate values.
+ const bool ExtendIntegerParam =
+ Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+ const auto GetStoredValue = [&](const unsigned I) {
+ SDValue StVal = ArgOutVals[I];
+ assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+ StVal.getValueType() &&
+ "OutVal type should always be legal");
+
+ const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+ const EVT StoreVT =
+ ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+ return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+ };
+
+ unsigned J = 0;
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+ for (const unsigned NumElts : VI) {
+ const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+ unsigned Offset;
+ if (IsVAArg) {
+ // TODO: We may need to support vector types that can be passed
+ // as scalars in variadic arguments.
+ assert(NumElts == 1 &&
+ "Vectorization should be disabled for vaargs.");
+
+ // Align each part of the variadic argument to their type.
+ VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+ Offset = VAOffset;
+
+ const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+ VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+ } else {
+ assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+ Offset = Offsets[J];
+ }
- // LowerUnalignedStoreParam took care of inserting the necessary nodes
- // into the SDAG, so just move on to the next element.
- J++;
- continue;
- }
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
- if (IsVAArg && !IsByVal)
- // Align each part of the variadic argument to their type.
- VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+ const MaybeAlign CurrentAlign = ExtendIntegerParam
+ ? MaybeAlign(std::nullopt)
+ : commonAlignment(ArgAlign, Offset);
- assert((IsVAArg || VAOffset == 0) &&
- "VAOffset must be 0 for non-VA args");
- SmallVector<SDValue, 6> StoreOperands{
- Chain, GetI32(IsVAArg ? FirstVAArg : ArgI),
- GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))};
+ SDValue Val =
+ getBuildVectorizedValue(NumElts, dl, DAG, [&](unsigned K) {
+ return GetStoredValue(J + K);
+ });
- // Record the values to store.
- for (const unsigned K : llvm::seq(NumElts))
- StoreOperands.push_back(GetStoredValue(J + K, EltVT, PartAlign));
- StoreOperands.push_back(InGlue);
+ SDValue StoreParam =
+ DAG.getStore(ArgDeclare, dl, Val, Ptr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+ CallPrereqs.push_back(StoreParam);
- NVPTXISD::NodeType Op;
- switch (NumElts) {
- case 1:
- Op = NVPTXISD::StoreParam;
- break;
- case 2:
- Op = NVPTXISD::StoreParamV2;
- break;
- case 4:
- Op = NVPTXISD::StoreParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
- // Adjust type of the store op if we've extended the scalar
- // return value.
- EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
-
- Chain = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
- TheStoreType, MachinePointerInfo(), PartAlign,
- MachineMemOperand::MOStore);
- InGlue = Chain.getValue(1);
-
- // TODO: We may need to support vector types that can be passed
- // as scalars in variadic arguments.
- if (IsVAArg && !IsByVal) {
- assert(NumElts == 1 &&
- "Vectorization is expected to be disabled for variadics.");
- VAOffset +=
- DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
+ J += NumElts;
}
-
- J += NumElts;
}
- if (IsVAArg && IsByVal)
- VAOffset += TypeSize;
}
- GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
-
// Handle Result
if (!Ins.empty()) {
- const SDValue RetDeclare = [&]() {
- const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
- const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
- if (shouldPassAsArray(RetTy)) {
- const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
- return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
- {MVT::Other, MVT::Glue},
- {Chain, RetSymbol, GetI32(RetAlign.value()),
- GetI32(ResultSize / 8), InGlue});
- }
- const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize);
- return DAG.getNode(
- NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
- {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue});
- }();
- Chain = RetDeclare.getValue(0);
- InGlue = RetDeclare.getValue(1);
+ const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+ const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
+ if (shouldPassAsArray(RetTy)) {
+ const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
+ } else {
+ MakeDeclareScalarParam(RetSymbol, ResultSize);
+ }
}
- const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
// Set the size of the vararg param byte array if the callee is a variadic
// function and the variadic part is not empty.
- if (HasVAArgs) {
+ if (VADeclareParam) {
SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
VADeclareParam.getOperand(1),
VADeclareParam.getOperand(2), GetI32(VAOffset),
@@ -1768,6 +1691,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
VADeclareParam->getVTList(), DeclareParamOps);
}
+ const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
// If the type of the callsite does not match that of the function, convert
// the callsite to an indirect call.
const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
@@ -1797,57 +1721,39 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// instruction.
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
+ const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
std::string Proto =
getPrototype(DL, RetTy, Args, CLI.Outs,
HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
- Chain = DAG.getNode(
- NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
- {Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue});
- InGlue = Chain.getValue(1);
- }
-
- if (ConvertToIndirectCall) {
- // Copy the function ptr to a ptx register and use the register to call the
- // function.
- const MVT DestVT = Callee.getValueType().getSimpleVT();
- MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
- auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
- Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
+ const SDValue PrototypeDeclare = DAG.getNode(
+ NVPTXISD::CallPrototype, dl, MVT::Other,
+ {StartChain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32)});
+ CallPrereqs.push_back(PrototypeDeclare);
}
const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
const unsigned NumArgs =
std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
- /// NumParams, Callee, Proto, InGlue)
- Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
- GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
- GetI32(Proto), InGlue});
- InGlue = Chain.getValue(1);
-
+ /// NumParams, Callee, Proto)
+ const SDValue CallToken = DAG.getTokenFactor(dl, CallPrereqs);
+ const SDValue Call = DAG.getNode(
+ NVPTXISD::CALL, dl, MVT::Other,
+ {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
+ GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
+
+ SmallVector<SDValue, 16> LoadChains{Call};
SmallVector<SDValue, 16> ProxyRegOps;
- // An item of the vector is filled if the element does not need a ProxyReg
- // operation on it and should be added to InVals as is. ProxyRegOps and
- // ProxyRegTruncates contain empty/none items at the same index.
- SmallVector<SDValue, 16> RetElts;
- // A temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
- // to use the values of `LoadParam`s and to be replaced later then
- // `CALLSEQ_END` is added.
- SmallVector<SDValue, 16> TempProxyRegOps;
-
- // Generate loads from param memory/moves from registers for result
if (!Ins.empty()) {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
+ ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -1855,106 +1761,49 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
- for (const unsigned VectorizedSize : VectorInfo) {
- EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]);
- EVT EltType = Ins[I].VT;
- const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
-
- if (TheLoadType != VTs[I])
- EltType = TheLoadType;
-
- if (ExtendIntegerRetVal) {
- TheLoadType = MVT::i32;
- EltType = MVT::i32;
- } else if (TheLoadType.getSizeInBits() < 16) {
- EltType = MVT::i16;
- }
-
- // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
- // scalar load. In such cases, fall back to byte loads.
- if (VectorizedSize == 1 && RetTy->isAggregateType() &&
- EltAlign < DAG.getEVTAlign(TheLoadType)) {
- SDValue Ret = LowerUnalignedLoadRetParam(
- DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
- ProxyRegOps.push_back(SDValue());
- RetElts.resize(I);
- RetElts.push_back(Ret);
-
- I++;
- continue;
- }
-
- SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType);
- LoadVTs.append({MVT::Other, MVT::Glue});
-
- NVPTXISD::NodeType Op;
- switch (VectorizedSize) {
- case 1:
- Op = NVPTXISD::LoadParam;
- break;
- case 2:
- Op = NVPTXISD::LoadParamV2;
- break;
- case 4:
- Op = NVPTXISD::LoadParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
-
- SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue};
- SDValue RetVal = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
- MachinePointerInfo(), EltAlign, MachineMemOperand::MOLoad);
-
- for (const unsigned J : llvm::seq(VectorizedSize)) {
- ProxyRegOps.push_back(RetVal.getValue(J));
- }
-
- Chain = RetVal.getValue(VectorizedSize);
- InGlue = RetVal.getValue(VectorizedSize + 1);
-
- I += VectorizedSize;
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+ for (const unsigned NumElts : VI) {
+ const MaybeAlign CurrentAlign =
+ ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
+ : commonAlignment(RetAlign, Offsets[I]);
+
+ const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+ const EVT LoadVT =
+ ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+ const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
+
+ SDValue R =
+ DAG.getLoad(VecVT, dl, Call, Ptr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+
+ LoadChains.push_back(R.getValue(1));
+ for (const unsigned J : llvm::seq(NumElts))
+ ProxyRegOps.push_back(getExtractVectorizedValue(R, J, LoadVT, dl, DAG));
+ I += NumElts;
}
}
- Chain =
- DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
- InGlue = Chain.getValue(1);
+ const SDValue EndToken = DAG.getTokenFactor(dl, LoadChains);
+ const SDValue CallEnd = DAG.getCALLSEQ_END(EndToken, UniqueCallSite,
+ UniqueCallSite + 1, SDValue(), dl);
// Append ProxyReg instructions to the chain to make sure that `callseq_end`
// will not get lost. Otherwise, during libcalls expansion, the nodes can become
// dangling.
- for (const unsigned I : llvm::seq(ProxyRegOps.size())) {
- if (I < RetElts.size() && RetElts[I]) {
- InVals.push_back(RetElts[I]);
- continue;
- }
-
- SDValue Ret =
- DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(),
- {Chain, ProxyRegOps[I]});
-
- const EVT ExpectedVT = Ins[I].VT;
- if (!Ret.getValueType().bitsEq(ExpectedVT)) {
- Ret = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Ret);
- }
+ for (const auto [I, Reg] : llvm::enumerate(ProxyRegOps)) {
+ SDValue Proxy =
+ DAG.getNode(NVPTXISD::ProxyReg, dl, Reg.getValueType(), {CallEnd, Reg});
+ SDValue Ret = correctParamType(Proxy, Ins[I].VT, Ins[I].Flags, DAG, dl);
InVals.push_back(Ret);
}
- for (SDValue &T : TempProxyRegOps) {
- SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(),
- {Chain, T.getOperand(0)});
- DAG.ReplaceAllUsesWith(T, Repl);
- DAG.RemoveDeadNode(T.getNode());
- }
-
- // set isTailCall to false for now, until we figure out how to express
+ // set IsTailCall to false for now, until we figure out how to express
// tail call optimization in PTX
- isTailCall = false;
- return Chain;
+ CLI.IsTailCall = false;
+ return CallEnd;
}
SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
@@ -3407,11 +3256,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
- MachineFunction &MF = DAG.getMachineFunction();
const DataLayout &DL = DAG.getDataLayout();
auto PtrVT = getPointerTy(DAG.getDataLayout());
- const Function *F = &MF.getFunction();
+ const Function &F = DAG.getMachineFunction().getFunction();
SDValue Root = DAG.getRoot();
SmallVector<SDValue, 16> OutChains;
@@ -3427,7 +3275,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// See similar issue in LowerCall.
auto AllIns = ArrayRef(Ins);
- for (const auto &Arg : F->args()) {
+ for (const auto &Arg : F.args()) {
const auto ArgIns = AllIns.take_while(
[&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
AllIns = AllIns.drop_front(ArgIns.size());
@@ -3467,7 +3315,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
SDValue P;
- if (isKernelFunction(*F)) {
+ if (isKernelFunction(F)) {
P = ArgSymbol;
P.getNode()->setIROrder(Arg.getArgNo() + 1);
} else {
@@ -3485,43 +3333,27 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(VTs.size() == Offsets.size() && "Size mismatch");
const Align ArgAlign = getFunctionArgumentAlignment(
- F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
+ &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
unsigned I = 0;
- for (const unsigned NumElts : VectorInfo) {
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+ for (const unsigned NumElts : VI) {
// i1 is loaded/stored as i8
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
- // If the element is a packed type (ex. v2f16, v4i8, etc) holding
- // multiple elements.
- const unsigned PackingAmt =
- LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
-
- const EVT VecVT =
- NumElts == 1
- ? LoadVT
- : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
- NumElts * PackingAmt);
+ const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
- const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
+ const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
SDValue P =
DAG.getLoad(VecVT, dl, Root, VecAddr,
MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
MachineMemOperand::MODereferenceable |
MachineMemOperand::MOInvariant);
- if (P.getNode())
- P.getNode()->setIROrder(Arg.getArgNo() + 1);
+ P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
- SDValue Elt =
- NumElts == 1
- ? P
- : DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
- : ISD::EXTRACT_VECTOR_ELT,
- dl, LoadVT, P,
- DAG.getVectorIdxConstant(J * PackingAmt, dl));
+ SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
Elt = correctParamType(Elt, ArgIns[I + J].VT, ArgIns[I + J].Flags,
DAG, dl);
@@ -3544,9 +3376,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
const SDLoc &dl, SelectionDAG &DAG) const {
- const MachineFunction &MF = DAG.getMachineFunction();
- const Function &F = MF.getFunction();
- Type *RetTy = MF.getFunction().getReturnType();
+ const Function &F = DAG.getMachineFunction().getFunction();
+ Type *RetTy = F.getReturnType();
if (RetTy->isVoidTy()) {
assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
@@ -3554,10 +3385,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
}
const DataLayout &DL = DAG.getDataLayout();
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
- assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
+ const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
+ const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
@@ -3565,6 +3395,11 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+ assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
const auto GetRetVal = [&](unsigned I) -> SDValue {
SDValue RetVal = OutVals[I];
assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
@@ -3577,33 +3412,16 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
return correctParamType(RetVal, StoreVT, Outs[I].Flags, DAG, dl);
};
- const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
- const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
- for (const unsigned NumElts : VectorInfo) {
+ const auto VI = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+ for (const unsigned NumElts : VI) {
const MaybeAlign CurrentAlign = ExtendIntegerRetVal
? MaybeAlign(std::nullopt)
: commonAlignment(RetAlign, Offsets[I]);
- SDValue Val;
- if (NumElts == 1) {
- Val = GetRetVal(I);
- } else {
- SmallVector<SDValue, 4> StoreVals;
- for (const unsigned J : llvm::seq(NumElts)) {
- SDValue ValJ = GetRetVal(I + J);
- if (ValJ.getValueType().isVector())
- DAG.ExtractVectorElements(ValJ, StoreVals);
- else
- StoreVals.push_back(ValJ);
- }
-
- EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
- StoreVals.size());
- Val = DAG.getBuildVector(VT, dl, StoreVals);
- }
+ SDValue Val = getBuildVectorizedValue(
+ NumElts, dl, DAG, [&](unsigned K) { return GetRetVal(I + K); });
- const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
SDValue Ptr =
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
@@ -5097,7 +4915,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return SDValue();
auto *LD = cast<MemSDNode>(N);
- EVT MemVT = LD->getMemoryVT();
SDLoc DL(LD);
// the new opcode after we double the number of operands
@@ -5114,10 +4931,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
- case NVPTXISD::LoadParamV2:
- OldNumOutputs = 2;
- Opcode = NVPTXISD::LoadParamV4;
- break;
case NVPTXISD::LoadV2:
OldNumOutputs = 2;
Opcode = NVPTXISD::LoadV4;
@@ -5142,9 +4955,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
// Create the new load
- SDValue NewLoad =
- DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
- Operands, MemVT, LD->getMemOperand());
+ SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
+ Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(),
+ LD->getMemOperand());
// Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
// the outputs the same. These nodes will be optimized away in later
@@ -5186,7 +4999,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
return SDValue();
auto *ST = cast<MemSDNode>(N);
- EVT MemVT = ElementVT.getVectorElementType();
// The new opcode after we double the number of operands.
NVPTXISD::NodeType Opcode;
@@ -5195,17 +5007,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
// Any packed type is legal, so the legalizer will not have lowered
// ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
// it here.
- MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV2;
break;
- case NVPTXISD::StoreParam:
- Opcode = NVPTXISD::StoreParamV2;
- break;
- case NVPTXISD::StoreParamV2:
- Opcode = NVPTXISD::StoreParamV4;
- break;
case NVPTXISD::StoreV2:
- MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
@@ -5215,7 +5019,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
return SDValue();
Opcode = NVPTXISD::StoreV8;
break;
- case NVPTXISD::StoreParamV4:
case NVPTXISD::StoreV8:
// PTX doesn't support the next doubling of operands
return SDValue();
@@ -5257,19 +5060,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
// Now we replace the store
return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
- MemVT, ST->getMemOperand());
-}
-
-static SDValue PerformStoreCombineHelper(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- unsigned Front, unsigned Back) {
- if (all_of(N->ops().drop_front(Front).drop_back(Back),
- [](const SDUse &U) { return U.get()->isUndef(); }))
- // Operand 0 is the previous value in the chain. Cannot return EntryToken
- // as the previous value will become unused and eliminated later.
- return N->getOperand(0);
-
- return combinePackingMovIntoStore(N, DCI, Front, Back);
+ ST->getMemoryVT(), ST->getMemOperand());
}
static SDValue PerformStoreCombine(SDNode *N,
@@ -5277,13 +5068,6 @@ static SDValue PerformStoreCombine(SDNode *N,
return combinePackingMovIntoStore(N, DCI, 1, 2);
}
-static SDValue PerformStoreParamCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
- // Operands from the 3rd to the 2nd last one are the values to be stored.
- // {Chain, ArgID, Offset, Val, Glue}
- return PerformStoreCombineHelper(N, DCI, 3, 1);
-}
-
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
@@ -5429,6 +5213,42 @@ static SDValue PerformREMCombine(SDNode *N,
return SDValue();
}
+// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
+static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue Op = N->getOperand(0);
+ if (!Op.hasOneUse())
+ return SDValue();
+ EVT ToVT = N->getValueType(0);
+ EVT FromVT = Op.getValueType();
+ if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
+ (ToVT == MVT::i64 && FromVT == MVT::i32)))
+ return SDValue();
+ if (!(Op.getOpcode() == ISD::MUL ||
+ (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1)))))
+ return SDValue();
+
+ SDLoc DL(N);
+ unsigned ExtOpcode = N->getOpcode();
+ unsigned Opcode = 0;
+ if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
+ Opcode = NVPTXISD::MUL_WIDE_SIGNED;
+ else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
+ Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
+ else
+ return SDValue();
+ SDValue RHS = Op.getOperand(1);
+ if (Op.getOpcode() == ISD::SHL) {
+ const auto ShiftAmt = Op.getConstantOperandVal(1);
+ const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt;
+ RHS = DCI.DAG.getConstant(MulVal, DL, ToVT);
+ }
+ return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS);
+}
+
enum OperandSignedness {
Signed = 0,
Unsigned,
@@ -5939,6 +5759,86 @@ static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
N->getConstantOperandAPInt(2),
N->getConstantOperandVal(3)),
SDLoc(N), N->getValueType(0));
+ return SDValue();
+}
+
+// During call lowering we wrap the return values in a ProxyReg node which
+// depend on the chain value produced by the completed call. This ensures that
+// the full call is emitted in cases where libcalls are used to legalize
+// operations. To improve the functioning of other DAG combines we pull all
+// operations we can through one of these nodes, ensuring that the ProxyReg
+// directly wraps a load. That is:
+//
+// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
+//
+static SDValue sinkProxyReg(SDValue R, SDValue Chain,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ switch (R.getOpcode()) {
+ case ISD::TRUNCATE:
+ case ISD::ANY_EXTEND:
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ case ISD::BITCAST: {
+ if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
+ return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), V);
+ return SDValue();
+ }
+ case ISD::SHL:
+ case ISD::SRL:
+ case ISD::SRA:
+ case ISD::OR: {
+ if (SDValue A = sinkProxyReg(R.getOperand(0), Chain, DCI))
+ if (SDValue B = sinkProxyReg(R.getOperand(1), Chain, DCI))
+ return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), A, B);
+ return SDValue();
+ }
+ case ISD::Constant:
+ return R;
+ case ISD::LOAD:
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4: {
+ return DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(R), R.getValueType(),
+ {Chain, R});
+ }
+ case ISD::BUILD_VECTOR: {
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ SmallVector<SDValue, 16> Ops;
+ for (auto &Op : R->ops()) {
+ SDValue V = sinkProxyReg(Op, Chain, DCI);
+ if (!V)
+ return SDValue();
+ Ops.push_back(V);
+ }
+ return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops);
+ }
+ case ISD::EXTRACT_VECTOR_ELT: {
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
+ return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R),
+ R.getValueType(), V, R.getOperand(1));
+ return SDValue();
+ }
+ default:
+ return SDValue();
+ }
+}
+
+static SDValue combineProxyReg(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Reg = N->getOperand(1);
+
+ // If the ProxyReg is not wrapping a load, try to pull the operations through
+ // the ProxyReg.
+ if (Reg.getOpcode() != ISD::LOAD) {
+ if (SDValue V = sinkProxyReg(Reg, Chain, DCI))
+ return V;
+ }
return SDValue();
}
@@ -5955,6 +5855,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return combineADDRSPACECAST(N, DCI);
case ISD::AND:
return PerformANDCombine(N, DCI);
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ return combineMulWide(N, DCI, OptLevel);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
@@ -5962,7 +5865,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::LOAD:
- case NVPTXISD::LoadParamV2:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
return combineUnpackingMovIntoLoad(N, DCI);
@@ -5970,6 +5872,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformMULCombine(N, DCI, OptLevel);
case NVPTXISD::PRMT:
return combinePRMT(N, DCI, OptLevel);
+ case NVPTXISD::ProxyReg:
+ return combineProxyReg(N, DCI);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case ISD::SHL:
@@ -5977,10 +5881,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SREM:
case ISD::UREM:
return PerformREMCombine(N, DCI, OptLevel);
- case NVPTXISD::StoreParam:
- case NVPTXISD::StoreParamV2:
- case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N, DCI);
case ISD::STORE:
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
@@ -6329,6 +6229,22 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}
+static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
+ const TargetLowering &TLI,
+ SmallVectorImpl<SDValue> &Results) {
+ SDValue Chain = N->getOperand(0);
+ SDValue Reg = N->getOperand(1);
+
+ MVT VT = TLI.getRegisterType(*DAG.getContext(), Reg.getValueType());
+
+ SDValue NewReg = DAG.getAnyExtOrTrunc(Reg, SDLoc(N), VT);
+ SDValue NewProxy =
+ DAG.getNode(NVPTXISD::ProxyReg, SDLoc(N), VT, {Chain, NewReg});
+ SDValue Res = DAG.getAnyExtOrTrunc(NewProxy, SDLoc(N), N->getValueType(0));
+
+ Results.push_back(Res);
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
@@ -6346,6 +6262,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
+ case NVPTXISD::ProxyReg:
+ replaceProxyReg(N, DAG, *this, Results);
+ return;
}
}