Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r-- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 257
1 file changed, 237 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7d2fe78..66a1010 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -47,6 +47,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
+#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
@@ -59,6 +60,7 @@
#include <cmath>
#include <cstdint>
#include <iterator>
+#include <optional>
#include <sstream>
#include <string>
#include <utility>
@@ -1529,6 +1531,105 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
return DL.getABITypeAlign(Ty);
}
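+// Map a floating-point element type to the integer type of the same bit
+// width, so that the byte-wise shift/mask logic below can operate on it.
+// Returns true if the type was changed.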
+static bool adjustElementType(EVT &ElementType) {
+ switch (ElementType.getSimpleVT().SimpleTy) {
+ default:
+ return false;
+ case MVT::f16:
+ case MVT::bf16:
+ ElementType = MVT::i16;
+ return true;
+ case MVT::f32:
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ ElementType = MVT::i32;
+ return true;
+ case MVT::f64:
+ ElementType = MVT::i64;
+ return true;
+ }
+}
+
+// Use byte-stores when the param address of the argument value is unaligned.
+// This may happen when the argument value is a field of a packed structure.
+//
+// This is called in LowerCall() when passing the param values.
+static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
+ uint64_t Offset, EVT ElementType,
+ SDValue StVal, SDValue &InGlue,
+ unsigned ArgID, const SDLoc &dl) {
+ // Bit logic only works on integer types
+ if (adjustElementType(ElementType))
+ StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
+
+ // Store each byte
+ SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
+ // Shift byte i down into the least significant byte position
+ SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
+ DAG.getConstant(i * 8, dl, MVT::i32));
+ SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
+ DAG.getConstant(Offset + i, dl, MVT::i32),
+ ShiftVal, InGlue};
+ // A truncating store writes only the least significant byte via
+ // st.param.b8, even though the register type may be wider than b8.
+ Chain = DAG.getMemIntrinsicNode(
+ NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
+ MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
+ InGlue = Chain.getValue(1);
+ }
+ return Chain;
+}
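
Conceptually, the loop above performs the following byte decomposition (a
minimal sketch; `storeParamByte` is a hypothetical stand-in for the
StoreParam node that selects to st.param.b8):

    #include <cstdint>
    #include <cstring>

    void storeParamByte(uint64_t Offset, uint8_t Byte); // hypothetical sink

    void storeUnalignedF32Param(uint64_t Offset, float Val) {
      uint32_t Bits;
      std::memcpy(&Bits, &Val, sizeof(Bits)); // the ISD::BITCAST f32 -> i32
      for (unsigned I = 0; I < 4; ++I)        // one st.param.b8 per byte
        storeParamByte(Offset + I, uint8_t(Bits >> (I * 8))); // SRL + trunc
    }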
+
+// Use byte-loads when the param address of the returned value is unaligned.
+// This may happen when the returned value is a field of a packed structure.
+static SDValue
+LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
+ EVT ElementType, SDValue &InGlue,
+ SmallVectorImpl<SDValue> &TempProxyRegOps,
+ const SDLoc &dl) {
+ // Bit logic only works on integer types
+ EVT MergedType = ElementType;
+ adjustElementType(MergedType);
+
+ // Load each byte and construct the whole value. Initialize the result to 0.
+ SDValue RetVal = DAG.getConstant(0, dl, MergedType);
+ // LoadParamMemI8 loads into an i16 register only.
+ SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
+ for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
+ SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
+ DAG.getConstant(Offset + i, dl, MVT::i32),
+ InGlue};
+ // This will be selected to LoadParamMemI8
+ SDValue LdVal =
+ DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
+ MVT::i8, MachinePointerInfo(), Align(1));
+ SDValue TmpLdVal = LdVal.getValue(0);
+ Chain = LdVal.getValue(1);
+ InGlue = LdVal.getValue(2);
+
+ TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
+ TmpLdVal.getSimpleValueType(), TmpLdVal);
+ TempProxyRegOps.push_back(TmpLdVal);
+
+ SDValue CMask = DAG.getConstant(255, dl, MergedType);
+ SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
+ // Need to extend the i16 register to the whole width.
+ TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
+ // Mask off the high bits and keep only the low 8 bits, since the byte
+ // was loaded with ld.param.b8 (LoadParamMemI8).
+ TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
+ // Shift and merge
+ TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
+ RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
+ }
+ if (ElementType != MergedType)
+ RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
+
+ return RetVal;
+}
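
In scalar terms, the reconstruction above is equivalent to the following
sketch (`loadParamByte` is a hypothetical stand-in for the LoadParam/ProxyReg
sequence that selects to ld.param.b8):

    #include <cstdint>
    #include <cstring>

    uint8_t loadParamByte(uint64_t Offset); // hypothetical source

    float loadUnalignedF32Ret(uint64_t Offset) {
      uint32_t Bits = 0;
      for (unsigned I = 0; I < 4; ++I) // zext + mask + shl + or per byte
        Bits |= uint32_t(loadParamByte(Offset + I)) << (I * 8);
      float Val;
      std::memcpy(&Val, &Bits, sizeof(Val)); // the final ISD::BITCAST to f32
      return Val;
    }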
+
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
@@ -1680,17 +1781,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (NeedAlign)
PartAlign = commonAlignment(ArgAlign, CurOffset);
- // New store.
- if (VectorInfo[j] & PVF_FIRST) {
- assert(StoreOperands.empty() && "Unfinished preceding store.");
- StoreOperands.push_back(Chain);
- StoreOperands.push_back(
- DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
- StoreOperands.push_back(DAG.getConstant(
- IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
- dl, MVT::i32));
- }
-
SDValue StVal = OutVals[OIdx];
MVT PromotedVT;
@@ -1723,6 +1813,35 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
}
+ // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
+ // scalar store. In such cases, fall back to byte stores.
+ if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
+ PartAlign.value() <
+ DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
+ assert(StoreOperands.empty() && "Unfinished preceding store.");
+ Chain = LowerUnalignedStoreParam(
+ DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
+ StVal, InGlue, ParamCount, dl);
+
+ // LowerUnalignedStoreParam took care of inserting the necessary nodes
+ // into the SDAG, so just move on to the next element.
+ if (!IsByVal)
+ ++OIdx;
+ continue;
+ }
+
+ // New store.
+ if (VectorInfo[j] & PVF_FIRST) {
+ assert(StoreOperands.empty() && "Unfinished preceding store.");
+ StoreOperands.push_back(Chain);
+ StoreOperands.push_back(
+ DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
+
+ StoreOperands.push_back(DAG.getConstant(
+ IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
+ dl, MVT::i32));
+ }
+
// Record the value to store.
StoreOperands.push_back(StVal);
@@ -1923,6 +2042,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVector<SDValue, 16> ProxyRegOps;
SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
+ // An entry is filled when the corresponding element does not need a
+ // ProxyReg operation and should be added to InVals as is. ProxyRegOps and
+ // ProxyRegTruncates hold empty/none placeholders at the same index.
+ SmallVector<SDValue, 16> RetElts;
+ // Temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
+ // to consume the values of the `LoadParam`s; they are replaced later, once
+ // `CALLSEQ_END` has been added.
+ SmallVector<SDValue, 16> TempProxyRegOps;
// Generate loads from param memory/moves from registers for result
if (Ins.size() > 0) {
@@ -1966,6 +2093,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
EltType = MVT::i16;
}
+ // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
+ // scalar load. In such cases, fall back to byte loads.
+ if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
+ EltAlign < DL.getABITypeAlign(
+ TheLoadType.getTypeForEVT(*DAG.getContext()))) {
+ assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
+ SDValue Ret = LowerUnalignedLoadRetParam(
+ DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
+ ProxyRegOps.push_back(SDValue());
+ ProxyRegTruncates.push_back(std::optional<MVT>());
+ RetElts.resize(i);
+ RetElts.push_back(Ret);
+
+ continue;
+ }
+
// Record index of the very first element of the vector.
if (VectorInfo[i] & PVF_FIRST) {
assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
@@ -2028,6 +2171,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// will not get lost. Otherwise, during libcalls expansion, the nodes can become
// dangling.
for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
+ if (i < RetElts.size() && RetElts[i]) {
+ InVals.push_back(RetElts[i]);
+ continue;
+ }
+
SDValue Ret = DAG.getNode(
NVPTXISD::ProxyReg, dl,
DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
@@ -2044,6 +2192,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
InVals.push_back(Ret);
}
+ for (SDValue &T : TempProxyRegOps) {
+ SDValue Repl = DAG.getNode(
+ NVPTXISD::ProxyReg, dl,
+ DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
+ {Chain, T.getOperand(0), InGlue});
+ DAG.ReplaceAllUsesWith(T, Repl);
+ DAG.RemoveDeadNode(T.getNode());
+
+ Chain = Repl.getValue(1);
+ InGlue = Repl.getValue(2);
+ }
+
// set isTailCall to false for now, until we figure out how to express
// tail call optimization in PTX
isTailCall = false;
@@ -3045,9 +3205,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
Value *srcValue = Constant::getNullValue(PointerType::get(
EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
+
+ const MaybeAlign PartAlign = [&]() -> MaybeAlign {
+ if (aggregateIsPacked)
+ return Align(1);
+ if (NumElts != 1)
+ return std::nullopt;
+ Align PartAlign =
+ (Offsets[parti] == 0 && PAL.getParamAlignment(i))
+ ? PAL.getParamAlignment(i).value()
+ : DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
+ return commonAlignment(PartAlign, Offsets[parti]);
+ }();
SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
- MachinePointerInfo(srcValue),
- MaybeAlign(aggregateIsPacked ? 1 : 0),
+ MachinePointerInfo(srcValue), PartAlign,
MachineMemOperand::MODereferenceable |
MachineMemOperand::MOInvariant);
if (P.getNode())
@@ -3113,6 +3284,33 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
return Chain;
}
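
As a concrete reading of the PartAlign lambda added to LowerFormalArguments
above (the numbers are illustrative, not from the patch): a scalar part of a
non-packed aggregate with ABI alignment 8 at byte offset 4 gets

    Align PartAlign = commonAlignment(Align(8), /*Offset=*/4); // -> Align(4)

while a packed aggregate short-circuits to Align(1), so its load is emitted
with byte alignment.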
+// Use byte-stores when the param address of the return value is unaligned.
+// This may happen when the return value is a field of a packed structure.
+static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
+ uint64_t Offset, EVT ElementType,
+ SDValue RetVal, const SDLoc &dl) {
+ // Bit logic only works on integer types
+ if (adjustElementType(ElementType))
+ RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
+
+ // Store each byte
+ for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
+ // Shift byte i down into the least significant byte position
+ SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
+ DAG.getConstant(i * 8, dl, MVT::i32));
+ SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
+ ShiftVal};
+ // A truncating store writes only the least significant byte via
+ // st.param.b8, even though the register type may be wider than b8.
+ Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
+ DAG.getVTList(MVT::Other), StoreOperands,
+ MVT::i8, MachinePointerInfo(), std::nullopt,
+ MachineMemOperand::MOStore);
+ }
+ return Chain;
+}
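
For context, a sketch of a return type that exercises this byte-store path
(an illustrative type, not taken from the patch or its tests; the alignment
arithmetic is spelled out after the LowerReturn hunk below):

    struct __attribute__((packed)) PackedRet {
      char Tag;  // byte offset 0
      float Val; // byte offset 1 inside an align-1 (packed) aggregate
    };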
+
SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
bool isVarArg,
@@ -3162,13 +3360,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
SmallVector<SDValue, 6> StoreOperands;
for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
- // New load/store. Record chain and offset operands.
- if (VectorInfo[i] & PVF_FIRST) {
- assert(StoreOperands.empty() && "Orphaned operand list.");
- StoreOperands.push_back(Chain);
- StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
- }
-
SDValue OutVal = OutVals[i];
SDValue RetVal = PromotedOutVals[i];
@@ -3182,6 +3373,32 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
}
+ // If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
+ // for a scalar store. In such cases, fall back to byte stores.
+ if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
+ EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
+ Align ElementTypeAlign =
+ DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
+ Align ElementAlign =
+ commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
+ if (ElementAlign < ElementTypeAlign) {
+ assert(StoreOperands.empty() && "Orphaned operand list.");
+ Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
+ RetVal, dl);
+
+ // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
+ // into the graph, so just move on to the next element.
+ continue;
+ }
+ }
+
+ // New load/store. Record chain and offset operands.
+ if (VectorInfo[i] & PVF_FIRST) {
+ assert(StoreOperands.empty() && "Orphaned operand list.");
+ StoreOperands.push_back(Chain);
+ StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
+ }
+
// Record the value to return.
StoreOperands.push_back(RetVal);
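
Plugging the PackedRet sketch above into this check (values illustrative):
for the float field, VTs[i] is f32 and Offsets[i] is 1, so

    Align ElementTypeAlign = Align(4);                 // ABI alignment of f32
    Align ElementAlign = commonAlignment(Align(1), 1); // packed RetTy -> 1

and ElementAlign < ElementTypeAlign routes the value through
LowerUnalignedStoreRet, emitting four st.param.b8 stores to func_retval0 in
place of a single 32-bit st.param.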