Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 377
1 file changed, 260 insertions(+), 117 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2f1a7ad..b88978a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -15,6 +15,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXISelDAGToDAG.h"
+#include "NVPTXSelectionDAGInfo.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
@@ -305,7 +306,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
uint64_t StartingOffset = 0) {
SmallVector<EVT, 16> TempVTs;
SmallVector<uint64_t, 16> TempOffsets;
- ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
+ ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets,
+ StartingOffset);
for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
@@ -512,7 +514,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
const NVPTXSubtarget &STI)
- : TargetLowering(TM), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
+ : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
// always lower memset, memcpy, and memmove intrinsics to load/store
// instructions, rather
// than generating calls to memset, memcpy, or memmove.
@@ -711,8 +713,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
Custom);
}
- setOperationAction(ISD::BSWAP, MVT::i16, Expand);
-
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);
@@ -766,10 +766,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Register custom handling for illegal type loads/stores. We'll try to custom
// lower almost all illegal types, and the lowering logic will discard cases
// we can't handle.
- setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
+ setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::i256, MVT::f128},
+ Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+ Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -1104,97 +1106,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// * MVT::Other - internal.addrspace.wrap
setOperationAction(ISD::INTRINSIC_WO_CHAIN,
{MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
-}
-const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
-
-#define MAKE_CASE(V) \
- case V: \
- return #V;
-
- switch ((NVPTXISD::NodeType)Opcode) {
- case NVPTXISD::FIRST_NUMBER:
- break;
-
- MAKE_CASE(NVPTXISD::ATOMIC_CMP_SWAP_B128)
- MAKE_CASE(NVPTXISD::ATOMIC_SWAP_B128)
- MAKE_CASE(NVPTXISD::RET_GLUE)
- MAKE_CASE(NVPTXISD::DeclareArrayParam)
- MAKE_CASE(NVPTXISD::DeclareScalarParam)
- MAKE_CASE(NVPTXISD::CALL)
- MAKE_CASE(NVPTXISD::MoveParam)
- MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
- MAKE_CASE(NVPTXISD::BUILD_VECTOR)
- MAKE_CASE(NVPTXISD::CallPrototype)
- MAKE_CASE(NVPTXISD::ProxyReg)
- MAKE_CASE(NVPTXISD::LoadV2)
- MAKE_CASE(NVPTXISD::LoadV4)
- MAKE_CASE(NVPTXISD::LoadV8)
- MAKE_CASE(NVPTXISD::LDUV2)
- MAKE_CASE(NVPTXISD::LDUV4)
- MAKE_CASE(NVPTXISD::StoreV2)
- MAKE_CASE(NVPTXISD::StoreV4)
- MAKE_CASE(NVPTXISD::StoreV8)
- MAKE_CASE(NVPTXISD::FSHL_CLAMP)
- MAKE_CASE(NVPTXISD::FSHR_CLAMP)
- MAKE_CASE(NVPTXISD::BFI)
- MAKE_CASE(NVPTXISD::PRMT)
- MAKE_CASE(NVPTXISD::FCOPYSIGN)
- MAKE_CASE(NVPTXISD::FMAXNUM3)
- MAKE_CASE(NVPTXISD::FMINNUM3)
- MAKE_CASE(NVPTXISD::FMAXIMUM3)
- MAKE_CASE(NVPTXISD::FMINIMUM3)
- MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
- MAKE_CASE(NVPTXISD::STACKRESTORE)
- MAKE_CASE(NVPTXISD::STACKSAVE)
- MAKE_CASE(NVPTXISD::SETP_F16X2)
- MAKE_CASE(NVPTXISD::SETP_BF16X2)
- MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
- MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
- MAKE_CASE(NVPTXISD::BrxEnd)
- MAKE_CASE(NVPTXISD::BrxItem)
- MAKE_CASE(NVPTXISD::BrxStart)
- MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED)
- MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X)
- MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y)
- MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
- MAKE_CASE(
- NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
- MAKE_CASE(
- NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
- MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
- MAKE_CASE(
- NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
- MAKE_CASE(
- NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
- MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
- MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
- MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
- MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
- MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
- }
- return nullptr;
-
-#undef MAKE_CASE
+ // Custom lowering for bswap
+ setOperationAction(ISD::BSWAP, {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
+ Custom);
}
TargetLoweringBase::LegalizeTypeAction
@@ -2031,7 +1946,7 @@ static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
}
/// Get 3-input scalar reduction opcode
-static std::optional<NVPTXISD::NodeType>
+static std::optional<unsigned>
getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
switch (ReductionOpcode) {
case ISD::VECREDUCE_FMAX:
@@ -2659,6 +2574,44 @@ static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
return Tcgen05StNode;
}
+static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
+ SDLoc DL(Op);
+ SDValue Src = Op.getOperand(0);
+ EVT VT = Op.getValueType();
+
+ switch (VT.getSimpleVT().SimpleTy) {
+ case MVT::i16: {
+ SDValue Extended = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Src);
+ SDValue Swapped =
+ getPRMT(Extended, DAG.getConstant(0, DL, MVT::i32), 0x7701, DL, DAG);
+ return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Swapped);
+ }
+ case MVT::i32: {
+ return getPRMT(Src, DAG.getConstant(0, DL, MVT::i32), 0x0123, DL, DAG);
+ }
+ case MVT::v2i16: {
+ SDValue Converted = DAG.getBitcast(MVT::i32, Src);
+ SDValue Swapped =
+ getPRMT(Converted, DAG.getConstant(0, DL, MVT::i32), 0x2301, DL, DAG);
+ return DAG.getNode(ISD::BITCAST, DL, MVT::v2i16, Swapped);
+ }
+ case MVT::i64: {
+ SDValue UnpackSrc =
+ DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, Src);
+ SDValue SwappedLow =
+ getPRMT(UnpackSrc.getValue(0), DAG.getConstant(0, DL, MVT::i32), 0x0123,
+ DL, DAG);
+ SDValue SwappedHigh =
+ getPRMT(UnpackSrc.getValue(1), DAG.getConstant(0, DL, MVT::i32), 0x0123,
+ DL, DAG);
+ return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64,
+ {SwappedHigh, SwappedLow});
+ }
+ default:
+ llvm_unreachable("unsupported type for bswap");
+ }
+}
+
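For reference, a minimal standalone model of the prmt.b32 selectors used in lowerBSWAP above (assuming default-mode semantics; the helper below is an illustrative sketch, not part of the patch):

#include <cstdint>
#include <cstdio>

// Sketch of PTX prmt.b32 in default mode: selector nibble i picks result byte
// i from the 8-byte pool {a, b}, where a supplies bytes 0-3 and b bytes 4-7;
// a nibble with bit 3 set replicates the sign bit of the selected byte instead.
static uint32_t prmt(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = ((uint64_t)b << 32) | a;
  uint32_t r = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t nib = (sel >> (4 * i)) & 0xF;
    uint8_t byte = (pool >> (8 * (nib & 0x7))) & 0xFF;
    if (nib & 0x8)
      byte = (byte & 0x80) ? 0xFF : 0x00;
    r |= (uint32_t)byte << (8 * i);
  }
  return r;
}

int main() {
  printf("%08x\n", prmt(0x11223344, 0, 0x0123)); // 44332211: i32 bswap
  printf("%08x\n", prmt(0x11223344, 0, 0x2301)); // 22114433: per-lane v2i16 bswap
  printf("%08x\n", prmt(0x00003344, 0, 0x7701)); // 00004433: i16 bswap in the low half
  return 0;
}

The i64 case reuses the 0x0123 selector on each 32-bit half and then swaps the halves when rebuilding the 64-bit value.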
static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
switch (IID) {
case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
@@ -2930,7 +2883,7 @@ static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
using NVPTX::PTXCvtMode::CvtMode;
auto [OpCode, RetTy, CvtModeFlag] =
- [&]() -> std::tuple<NVPTXISD::NodeType, MVT::SimpleValueType, uint32_t> {
+ [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
switch (IntrinsicID) {
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
@@ -3181,6 +3134,86 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
return Or;
}
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Val = N->getOperand(1);
+ SDValue BasePtr = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Mask = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ValVT = Val.getValueType();
+ MemSDNode *MemSD = cast<MemSDNode>(N);
+ assert(ValVT.isVector() && "Masked vector store must have vector type");
+ assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+ "Unexpected alignment for masked store");
+
+ unsigned Opcode = 0;
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unexpected masked vector store type");
+ case MVT::v4i64:
+ case MVT::v4f64: {
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ }
+ case MVT::v8i32:
+ case MVT::v8f32: {
+ Opcode = NVPTXISD::StoreV8;
+ break;
+ }
+ }
+
+ SmallVector<SDValue, 8> Ops;
+
+ // Construct the new SDNode. First operand is the chain.
+ Ops.push_back(Chain);
+
+ // The next N operands are the values to store. Encode the mask into the
+ // values using the sentinel register 0 to represent a masked-off element.
+ assert(Mask.getValueType().isVector() &&
+ Mask.getValueType().getVectorElementType() == MVT::i1 &&
+ "Mask must be a vector of i1");
+ assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+ "Mask expected to be a BUILD_VECTOR");
+ assert(Mask.getValueType().getVectorNumElements() ==
+ ValVT.getVectorNumElements() &&
+ "Mask size must be the same as the vector size");
+ for (auto [I, Op] : enumerate(Mask->ops())) {
+ // Mask elements must be constants.
+ if (Op.getNode()->getAsZExtVal() == 0) {
+ // Append a sentinel register 0 to the Ops vector to represent a masked-off
+ // element; this will be handled in TableGen.
+ Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+ ValVT.getVectorElementType()));
+ } else {
+ // Extract the element from the vector to store
+ SDValue ExtVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+ Val, DAG.getIntPtrConstant(I, DL));
+ Ops.push_back(ExtVal);
+ }
+ }
+
+ // Next, the pointer operand.
+ Ops.push_back(BasePtr);
+
+ // Finally, the offset operand. We expect this to always be undef, and it will
+ // be ignored in lowering, but to mirror the handling of the other vector
+ // store instructions we include it in the new SDNode.
+ assert(Offset.getOpcode() == ISD::UNDEF &&
+ "Offset operand expected to be undef");
+ Ops.push_back(Offset);
+
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+ return NewSt;
+}
+
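As a hedged illustration of the sentinel encoding above (mask values invented for the example), a masked store of a v8f32 value %v with mask <1,1,0,1,1,1,1,1> would be rebuilt roughly as

  NVPTXISD::StoreV8 ch,
      extractelt %v, 0, extractelt %v, 1, Register:f32 $noreg,
      extractelt %v, 3, extractelt %v, 4, extractelt %v, 5,
      extractelt %v, 6, extractelt %v, 7,
      %ptr, undef

where the $noreg sentinel stands in for the masked-off lane and is matched later in TableGen.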
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -3217,8 +3250,16 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE: {
+ assert(STI.has256BitVectorLoadStore(
+ cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+ "Masked store vector not supported on subtarget.");
+ return lowerMSTORE(Op, DAG);
+ }
case ISD::LOAD:
return LowerLOAD(Op, DAG);
+ case ISD::MLOAD:
+ return LowerMLOAD(Op, DAG);
case ISD::SHL_PARTS:
return LowerShiftLeftParts(Op, DAG);
case ISD::SRA_PARTS:
@@ -3282,7 +3323,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return lowerCTLZCTPOP(Op, DAG);
case ISD::FREM:
return lowerFREM(Op, DAG);
-
+ case ISD::BSWAP:
+ return lowerBSWAP(Op, DAG);
default:
llvm_unreachable("Custom lowering not defined for operation");
}
@@ -3313,7 +3355,7 @@ SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
// Generate BrxEnd nodes
SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
IdV, Chain.getValue(1)};
- SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
+ SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, MVT::Other, EndOps);
return BrxEnd;
}
@@ -3410,10 +3452,62 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachinePointerInfo(SV));
}
+static std::pair<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
+ const NVPTXSubtarget &STI) {
+ SDValue Chain = N->getOperand(0);
+ SDValue BasePtr = N->getOperand(1);
+ SDValue Mask = N->getOperand(3);
+ [[maybe_unused]] SDValue Passthru = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ResVT = N->getValueType(0);
+ assert(ResVT.isVector() && "Masked vector load must have vector type");
+ // While we only expect poison passthru vectors as an input to the backend,
+ // when the legalization framework splits a poison vector in half, it creates
+ // two undef vectors, so we can technically expect those too.
+ assert((Passthru.getOpcode() == ISD::POISON ||
+ Passthru.getOpcode() == ISD::UNDEF) &&
+ "Passthru operand expected to be poison or undef");
+
+ // Extract the mask and convert it to a uint32_t representing the used bytes
+ // of the entire vector load
+ uint32_t UsedBytesMask = 0;
+ uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+ assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+ uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+ uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+ for (SDValue Op : reverse(Mask->ops())) {
+ // We only need this shift on iterations after the first, but on the first
+ // iteration UsedBytesMask is 0, so the shift is a no-op.
+ UsedBytesMask <<= ElementSizeInBytes;
+
+ // Mask elements must be constants.
+ if (Op->getAsZExtVal() != 0)
+ UsedBytesMask |= ElementMask;
+ }
+
+ assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+ "Unexpected masked load with elements masked all on or all off");
+
+ // Create a new load sd node to be handled normally by ReplaceLoadVector.
+ MemSDNode *NewLD = cast<MemSDNode>(
+ DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+ // If our subtarget does not support the used bytes mask pragma, "drop" the
+ // mask by setting it to UINT32_MAX
+ if (!STI.hasUsedBytesMaskPragma())
+ UsedBytesMask = UINT32_MAX;
+
+ return {NewLD, UsedBytesMask};
+}
+
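The byte-mask construction above can be illustrated with a small standalone sketch (assumed helper, not part of the patch): each i1 mask element expands to ElementSizeInBytes set or clear bits, with element 0 occupying the least significant bits.

#include <cstdint>
#include <cstdio>
#include <vector>

// Expand a per-element i1 mask into a per-byte mask for the whole vector load.
static uint32_t usedBytesMask(const std::vector<bool> &Mask,
                              uint32_t ElementSizeInBytes) {
  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
  uint32_t Result = 0;
  for (auto It = Mask.rbegin(); It != Mask.rend(); ++It) {
    Result <<= ElementSizeInBytes; // no-op on the first iteration
    if (*It)
      Result |= ElementMask;
  }
  return Result;
}

int main() {
  // A v8i32 masked load with lanes {0, 1, 2, 5} active uses bytes 0-11 and
  // 20-23 of the 32-byte access.
  printf("0x%08x\n", usedBytesMask({1, 1, 1, 0, 0, 1, 0, 0}, 4)); // 0x00f00fff
  return 0;
}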
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
static std::optional<std::pair<SDValue, SDValue>>
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
- LoadSDNode *LD = cast<LoadSDNode>(N);
+ MemSDNode *LD = cast<MemSDNode>(N);
const EVT ResVT = LD->getValueType(0);
const EVT MemVT = LD->getMemoryVT();
@@ -3440,6 +3534,12 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
return std::nullopt;
}
+ // If we have a masked load, convert it to a normal load now
+ std::optional<uint32_t> UsedBytesMask = std::nullopt;
+ if (LD->getOpcode() == ISD::MLOAD)
+ std::tie(LD, UsedBytesMask) =
+ convertMLOADToLoadWithUsedBytesMask(LD, DAG, STI);
+
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
// loaded type to i16 and propagate the "real" type as the memory type.
@@ -3468,9 +3568,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
// Copy regular operands
SmallVector<SDValue, 8> OtherOps(LD->ops());
+ OtherOps.push_back(
+ DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
// The select routine does not have access to the LoadSDNode instance, so
// pass along the extension information
- OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+ OtherOps.push_back(
+ DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
LD->getMemOperand());
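Under the usual LoadSDNode operand layout (chain, base pointer, offset), the operand list handed to getMemIntrinsicNode here therefore ends up roughly as

  { chain, base, offset, used-bytes mask (UINT32_MAX when unmasked), extension type }

with the MLOAD path supplying the mask computed by convertMLOADToLoadWithUsedBytesMask.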
@@ -3558,6 +3662,42 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
llvm_unreachable("Unexpected custom lowering for load");
}
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+ // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
+ // masked loads of these types and have to handle them here.
+ // v2f32 also needs to be handled here if the subtarget has f32x2
+ // instructions, making it legal.
+ //
+ // Note: misaligned masked loads should never reach this point
+ // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+ // will validate alignment. Therefore, we do not need to handle them
+ // specially here.
+ EVT VT = Op.getValueType();
+ if (NVPTX::isPackedVectorTy(VT)) {
+ auto Result = convertMLOADToLoadWithUsedBytesMask(
+ cast<MemSDNode>(Op.getNode()), DAG, STI);
+ MemSDNode *LD = std::get<0>(Result);
+ uint32_t UsedBytesMask = std::get<1>(Result);
+
+ SDLoc DL(LD);
+
+ // Copy regular operands
+ SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+ OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+ // We do not currently lower extending loads, but pass the extension type
+ // anyway, as later handling expects it.
+ OtherOps.push_back(
+ DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+ SDValue NewLD =
+ DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
+ LD->getMemoryVT(), LD->getMemOperand());
+ return NewLD;
+ }
+ return SDValue();
+}
+
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
const NVPTXSubtarget &STI) {
MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -3944,9 +4084,10 @@ void NVPTXTargetLowering::LowerAsmOperandForConstraint(
// because we need the information that is only available in the "Value" type
// of destination
// pointer. In particular, the address space information.
-bool NVPTXTargetLowering::getTgtMemIntrinsic(
- IntrinsicInfo &Info, const CallInst &I,
- MachineFunction &MF, unsigned Intrinsic) const {
+bool NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
+ const CallBase &I,
+ MachineFunction &MF,
+ unsigned Intrinsic) const {
switch (Intrinsic) {
default:
return false;
@@ -5420,8 +5561,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
if (!NVPTX::isPackedVectorTy(ElementVT) || ElementVT == MVT::v4i8)
return SDValue();
- SmallVector<SDNode *> DeadCopyToRegs;
-
// Check whether all outputs are either used by an extractelt or are
// glue/chain nodes
if (!all_of(N->uses(), [&](SDUse &U) {
@@ -5458,7 +5597,7 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(LD);
// the new opcode after we double the number of operands
- NVPTXISD::NodeType Opcode;
+ unsigned Opcode;
SmallVector<SDValue> Operands(LD->ops());
unsigned OldNumOutputs; // non-glue, non-chain outputs
switch (LD->getOpcode()) {
@@ -5468,6 +5607,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
// ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
// here.
Opcode = NVPTXISD::LoadV2;
+ // Append a "full" used-bytes-mask operand right before the extension type
+ // operand, signifying that all bytes are used.
+ Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
@@ -5476,9 +5618,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
Opcode = NVPTXISD::LoadV4;
break;
case NVPTXISD::LoadV4:
- // V8 is only supported for f32. Don't forget, we're not changing the load
- // size here. This is already a 256-bit load.
- if (ElementVT != MVT::v2f32)
+ // V8 is only supported for f32/i32. Don't forget, we're not changing the
+ // load size here. This is already a 256-bit load.
+ if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
return SDValue();
OldNumOutputs = 4;
Opcode = NVPTXISD::LoadV8;
@@ -5541,7 +5683,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
auto *ST = cast<MemSDNode>(N);
// The new opcode after we double the number of operands.
- NVPTXISD::NodeType Opcode;
+ unsigned Opcode;
switch (N->getOpcode()) {
case ISD::STORE:
// Any packed type is legal, so the legalizer will not have lowered
@@ -5553,9 +5695,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
- // V8 is only supported for f32. Don't forget, we're not changing the store
- // size here. This is already a 256-bit store.
- if (ElementVT != MVT::v2f32)
+ // V8 is only supported for f32/i32. Don't forget, we're not changing the
+ // store size here. This is already a 256-bit store.
+ if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
return SDValue();
Opcode = NVPTXISD::StoreV8;
break;
@@ -5676,7 +5818,7 @@ static SDValue PerformFADDCombine(SDNode *N,
}
/// Get 3-input version of a 2-input min/max opcode
-static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
+static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
switch (MinMax2Opcode) {
case ISD::FMAXNUM:
case ISD::FMAXIMUMNUM:
@@ -5707,7 +5849,7 @@ static SDValue PerformFMinMaxCombine(SDNode *N,
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
unsigned MinMaxOp2 = N->getOpcode();
- NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
+ unsigned MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
// (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
@@ -6706,6 +6848,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
ReplaceBITCAST(N, DAG, Results);
return;
case ISD::LOAD:
+ case ISD::MLOAD:
replaceLoadVector(N, DAG, Results, STI);
return;
case ISD::INTRINSIC_W_CHAIN: