Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
 -rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 377
 1 file changed, 260 insertions, 117 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2f1a7ad..b88978a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -15,6 +15,7 @@
 #include "MCTargetDesc/NVPTXBaseInfo.h"
 #include "NVPTX.h"
 #include "NVPTXISelDAGToDAG.h"
+#include "NVPTXSelectionDAGInfo.h"
 #include "NVPTXSubtarget.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXTargetObjectFile.h"
@@ -305,7 +306,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
                                uint64_t StartingOffset = 0) {
   SmallVector<EVT, 16> TempVTs;
   SmallVector<uint64_t, 16> TempOffsets;
-  ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
+  ComputeValueVTs(TLI, DL, Ty, TempVTs, /*MemVTs=*/nullptr, &TempOffsets,
+                  StartingOffset);
 
   for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
     MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
@@ -512,7 +514,7 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
 // NVPTXTargetLowering Constructor.
 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                                          const NVPTXSubtarget &STI)
-    : TargetLowering(TM), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
+    : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
   // always lower memset, memcpy, and memmove intrinsics to load/store
   // instructions, rather
   // then generating calls to memset, mempcy or memmove.
@@ -711,8 +713,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                        Custom);
   }
 
-  setOperationAction(ISD::BSWAP, MVT::i16, Expand);
-
   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
   setOperationAction(ISD::BRIND, MVT::Other, Expand);
 
@@ -766,10 +766,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Register custom handling for illegal type loads/stores. We'll try to custom
   // lower almost all illegal types and logic in the lowering will discard cases
   // we can't handle.
-  setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
+  setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::i256, MVT::f128},
+                     Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+                         Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -1104,97 +1106,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   //  * MVT::Other - internal.addrspace.wrap
   setOperationAction(ISD::INTRINSIC_WO_CHAIN,
                      {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
-}
 
-const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
-
-#define MAKE_CASE(V)                                                           \
-  case V:                                                                      \
-    return #V;
-
-  switch ((NVPTXISD::NodeType)Opcode) {
-  case NVPTXISD::FIRST_NUMBER:
-    break;
-
-    MAKE_CASE(NVPTXISD::ATOMIC_CMP_SWAP_B128)
-    MAKE_CASE(NVPTXISD::ATOMIC_SWAP_B128)
-    MAKE_CASE(NVPTXISD::RET_GLUE)
-    MAKE_CASE(NVPTXISD::DeclareArrayParam)
-    MAKE_CASE(NVPTXISD::DeclareScalarParam)
-    MAKE_CASE(NVPTXISD::CALL)
-    MAKE_CASE(NVPTXISD::MoveParam)
-    MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
-    MAKE_CASE(NVPTXISD::BUILD_VECTOR)
-    MAKE_CASE(NVPTXISD::CallPrototype)
-    MAKE_CASE(NVPTXISD::ProxyReg)
-    MAKE_CASE(NVPTXISD::LoadV2)
-    MAKE_CASE(NVPTXISD::LoadV4)
-    MAKE_CASE(NVPTXISD::LoadV8)
-    MAKE_CASE(NVPTXISD::LDUV2)
-    MAKE_CASE(NVPTXISD::LDUV4)
-    MAKE_CASE(NVPTXISD::StoreV2)
-    MAKE_CASE(NVPTXISD::StoreV4)
-    MAKE_CASE(NVPTXISD::StoreV8)
-    MAKE_CASE(NVPTXISD::FSHL_CLAMP)
-    MAKE_CASE(NVPTXISD::FSHR_CLAMP)
-    MAKE_CASE(NVPTXISD::BFI)
-    MAKE_CASE(NVPTXISD::PRMT)
-    MAKE_CASE(NVPTXISD::FCOPYSIGN)
-    MAKE_CASE(NVPTXISD::FMAXNUM3)
-    MAKE_CASE(NVPTXISD::FMINNUM3)
-    MAKE_CASE(NVPTXISD::FMAXIMUM3)
-    MAKE_CASE(NVPTXISD::FMINIMUM3)
-    MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
-    MAKE_CASE(NVPTXISD::STACKRESTORE)
-    MAKE_CASE(NVPTXISD::STACKSAVE)
-    MAKE_CASE(NVPTXISD::SETP_F16X2)
-    MAKE_CASE(NVPTXISD::SETP_BF16X2)
-    MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
-    MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
-    MAKE_CASE(NVPTXISD::BrxEnd)
-    MAKE_CASE(NVPTXISD::BrxItem)
-    MAKE_CASE(NVPTXISD::BrxStart)
-    MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED)
-    MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X)
-    MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y)
-    MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
-    MAKE_CASE(
-        NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
-    MAKE_CASE(
-        NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1)
-    MAKE_CASE(NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2)
-    MAKE_CASE(
-        NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
-    MAKE_CASE(
-        NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
-    MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
-    MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
-    MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
-    MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
-    MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
-  }
-  return nullptr;
-
-#undef MAKE_CASE
+  // Custom lowering for bswap
+  setOperationAction(ISD::BSWAP, {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
+                     Custom);
 }
 
 TargetLoweringBase::LegalizeTypeAction
@@ -2031,7 +1946,7 @@ static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
 }
 
 /// Get 3-input scalar reduction opcode
-static std::optional<NVPTXISD::NodeType>
+static std::optional<unsigned>
 getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
   switch (ReductionOpcode) {
   case ISD::VECREDUCE_FMAX:
@@ -2659,6 +2574,44 @@ static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
   return Tcgen05StNode;
 }
 
+static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
+  SDLoc DL(Op);
+  SDValue Src = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  switch (VT.getSimpleVT().SimpleTy) {
+  case MVT::i16: {
+    SDValue Extended = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Src);
+    SDValue Swapped =
+        getPRMT(Extended, DAG.getConstant(0, DL, MVT::i32), 0x7701, DL, DAG);
+    return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Swapped);
+  }
+  case MVT::i32: {
+    return getPRMT(Src, DAG.getConstant(0, DL, MVT::i32), 0x0123, DL, DAG);
+  }
+  case MVT::v2i16: {
+    SDValue Converted = DAG.getBitcast(MVT::i32, Src);
+    SDValue Swapped =
+        getPRMT(Converted, DAG.getConstant(0, DL, MVT::i32), 0x2301, DL, DAG);
+    return DAG.getNode(ISD::BITCAST, DL, MVT::v2i16, Swapped);
+  }
+  case MVT::i64: {
+    SDValue UnpackSrc =
+        DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, Src);
+    SDValue SwappedLow =
+        getPRMT(UnpackSrc.getValue(0), DAG.getConstant(0, DL, MVT::i32), 0x0123,
+                DL, DAG);
+    SDValue SwappedHigh =
+        getPRMT(UnpackSrc.getValue(1), DAG.getConstant(0, DL, MVT::i32), 0x0123,
+                DL, DAG);
+    return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64,
+                       {SwappedHigh, SwappedLow});
+  }
+  default:
+    llvm_unreachable("unsupported type for bswap");
+  }
+}
+
 static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
   switch (IID) {
   case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
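A note on the PRMT selectors used by lowerBSWAP above. The sketch below is not part of the patch; it emulates the default byte-select mode of the PTX prmt instruction (each result byte is chosen by one selector nibble from the eight-byte pool {b, a}, and the patch always passes b = 0) to show why selector 0x0123 reverses the bytes of an i32, 0x2301 swaps the bytes within each half of a v2i16, and 0x7701 swaps the low two bytes of an any-extended i16. The helper name prmt_default is invented for this sketch.

#include <cassert>
#include <cstdint>

// Emulate prmt.b32 in its default mode: result byte i is the pool byte named
// by the low three bits of selector nibble i, where the pool is {b, a} with a
// in bytes 0-3 and b in bytes 4-7. The sign-replication bit (nibble bit 3) is
// ignored here because none of the selectors in lowerBSWAP set it.
static uint32_t prmt_default(uint32_t a, uint32_t b, uint32_t selector) {
  uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t nibble = (selector >> (4 * i)) & 0xF;
    uint32_t byte = (pool >> (8 * (nibble & 0x7))) & 0xFF;
    result |= byte << (8 * i);
  }
  return result;
}

int main() {
  // i32 bswap: 0x0123 places byte 3 in lane 0, ..., byte 0 in lane 3.
  assert(prmt_default(0x11223344, 0, 0x0123) == 0x44332211);
  // v2i16 bswap (bitcast to i32): 0x2301 swaps bytes within each 16-bit half.
  assert(prmt_default(0x11223344, 0, 0x2301) == 0x22114433);
  // i16 bswap after any-extend to i32: 0x7701 swaps the low two bytes; the
  // upper lanes read zero bytes from b and are truncated away afterwards.
  assert((prmt_default(0x00003344, 0, 0x7701) & 0xFFFF) == 0x4433);
  return 0;
}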
type"); + assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) && + "Unexpected alignment for masked store"); + + unsigned Opcode = 0; + switch (ValVT.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unexpected masked vector store type"); + case MVT::v4i64: + case MVT::v4f64: { + Opcode = NVPTXISD::StoreV4; + break; + } + case MVT::v8i32: + case MVT::v8f32: { + Opcode = NVPTXISD::StoreV8; + break; + } + } + + SmallVector<SDValue, 8> Ops; + + // Construct the new SDNode. First operand is the chain. + Ops.push_back(Chain); + + // The next N operands are the values to store. Encode the mask into the + // values using the sentinel register 0 to represent a masked-off element. + assert(Mask.getValueType().isVector() && + Mask.getValueType().getVectorElementType() == MVT::i1 && + "Mask must be a vector of i1"); + assert(Mask.getOpcode() == ISD::BUILD_VECTOR && + "Mask expected to be a BUILD_VECTOR"); + assert(Mask.getValueType().getVectorNumElements() == + ValVT.getVectorNumElements() && + "Mask size must be the same as the vector size"); + for (auto [I, Op] : enumerate(Mask->ops())) { + // Mask elements must be constants. + if (Op.getNode()->getAsZExtVal() == 0) { + // Append a sentinel register 0 to the Ops vector to represent a masked + // off element, this will be handled in tablegen + Ops.push_back(DAG.getRegister(MCRegister::NoRegister, + ValVT.getVectorElementType())); + } else { + // Extract the element from the vector to store + SDValue ExtVal = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(), + Val, DAG.getIntPtrConstant(I, DL)); + Ops.push_back(ExtVal); + } + } + + // Next, the pointer operand. + Ops.push_back(BasePtr); + + // Finally, the offset operand. We expect this to always be undef, and it will + // be ignored in lowering, but to mirror the handling of the other vector + // store instructions we include it in the new SDNode. 
+ assert(Offset.getOpcode() == ISD::UNDEF && + "Offset operand expected to be undef"); + Ops.push_back(Offset); + + SDValue NewSt = + DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops, + MemSD->getMemoryVT(), MemSD->getMemOperand()); + + return NewSt; +} + SDValue NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { @@ -3217,8 +3250,16 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return LowerVECREDUCE(Op, DAG); case ISD::STORE: return LowerSTORE(Op, DAG); + case ISD::MSTORE: { + assert(STI.has256BitVectorLoadStore( + cast<MemSDNode>(Op.getNode())->getAddressSpace()) && + "Masked store vector not supported on subtarget."); + return lowerMSTORE(Op, DAG); + } case ISD::LOAD: return LowerLOAD(Op, DAG); + case ISD::MLOAD: + return LowerMLOAD(Op, DAG); case ISD::SHL_PARTS: return LowerShiftLeftParts(Op, DAG); case ISD::SRA_PARTS: @@ -3282,7 +3323,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return lowerCTLZCTPOP(Op, DAG); case ISD::FREM: return lowerFREM(Op, DAG); - + case ISD::BSWAP: + return lowerBSWAP(Op, DAG); default: llvm_unreachable("Custom lowering not defined for operation"); } @@ -3313,7 +3355,7 @@ SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const { // Generate BrxEnd nodes SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index, IdV, Chain.getValue(1)}; - SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps); + SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, MVT::Other, EndOps); return BrxEnd; } @@ -3410,10 +3452,62 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const { MachinePointerInfo(SV)); } +static std::pair<MemSDNode *, uint32_t> +convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG, + const NVPTXSubtarget &STI) { + SDValue Chain = N->getOperand(0); + SDValue BasePtr = N->getOperand(1); + SDValue Mask = N->getOperand(3); + [[maybe_unused]] SDValue Passthru = N->getOperand(4); + + SDLoc DL(N); + EVT ResVT = N->getValueType(0); + assert(ResVT.isVector() && "Masked vector load must have vector type"); + // While we only expect poison passthru vectors as an input to the backend, + // when the legalization framework splits a poison vector in half, it creates + // two undef vectors, so we can technically expect those too. + assert((Passthru.getOpcode() == ISD::POISON || + Passthru.getOpcode() == ISD::UNDEF) && + "Passthru operand expected to be poison or undef"); + + // Extract the mask and convert it to a uint32_t representing the used bytes + // of the entire vector load + uint32_t UsedBytesMask = 0; + uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits(); + assert(ElementSizeInBits % 8 == 0 && "Unexpected element size"); + uint32_t ElementSizeInBytes = ElementSizeInBits / 8; + uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u; + + for (SDValue Op : reverse(Mask->ops())) { + // We technically only want to do this shift for every + // iteration *but* the first, but in the first iteration UsedBytesMask is 0, + // so this shift is a no-op. + UsedBytesMask <<= ElementSizeInBytes; + + // Mask elements must be constants. + if (Op->getAsZExtVal() != 0) + UsedBytesMask |= ElementMask; + } + + assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX && + "Unexpected masked load with elements masked all on or all off"); + + // Create a new load sd node to be handled normally by ReplaceLoadVector. 
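For the masked-store path added in lowerMSTORE above, the following standalone sketch is not part of the patch; it uses plain strings in place of SDValues purely to make the operand layout of the resulting NVPTXISD::StoreV8 node concrete for a v8f32 store with the constant mask {1,1,0,0,1,1,1,1}. Masked-off lanes become the sentinel register 0, which the patch's comment says is matched later in TableGen.

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Constant i1 mask for a v8f32 masked store; lanes 2 and 3 are masked off.
  const bool Mask[8] = {true, true, false, false, true, true, true, true};

  // Operand order produced by lowerMSTORE for this case (NVPTXISD::StoreV8):
  // chain, one operand per lane (extracted value or sentinel), base pointer,
  // and the (undef) offset operand.
  std::vector<std::string> Ops;
  Ops.push_back("chain");
  for (int I = 0; I < 8; ++I)
    Ops.push_back(Mask[I] ? "extractelt(val, " + std::to_string(I) + ")"
                          : "sentinel register 0");
  Ops.push_back("baseptr");
  Ops.push_back("offset (undef)");

  for (const std::string &Op : Ops)
    std::cout << Op << "\n";
  return 0;
}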
@@ -3217,8 +3250,16 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
+  case ISD::MLOAD:
+    return LowerMLOAD(Op, DAG);
   case ISD::SHL_PARTS:
     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRA_PARTS:
@@ -3282,7 +3323,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return lowerCTLZCTPOP(Op, DAG);
   case ISD::FREM:
     return lowerFREM(Op, DAG);
-
+  case ISD::BSWAP:
+    return lowerBSWAP(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
@@ -3313,7 +3355,7 @@ SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
   // Generate BrxEnd nodes
   SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
                       IdV, Chain.getValue(1)};
-  SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
+  SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, MVT::Other, EndOps);
 
   return BrxEnd;
 }
@@ -3410,10 +3452,62 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+static std::pair<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
+                                    const NVPTXSubtarget &STI) {
+  SDValue Chain = N->getOperand(0);
+  SDValue BasePtr = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  [[maybe_unused]] SDValue Passthru = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // when the legalization framework splits a poison vector in half, it creates
+  // two undef vectors, so we can technically expect those too.
+  assert((Passthru.getOpcode() == ISD::POISON ||
+          Passthru.getOpcode() == ISD::UNDEF) &&
+         "Passthru operand expected to be poison or undef");
+
+  // Extract the mask and convert it to a uint32_t representing the used bytes
+  // of the entire vector load
+  uint32_t UsedBytesMask = 0;
+  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+  for (SDValue Op : reverse(Mask->ops())) {
+    // We technically only want to do this shift for every
+    // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
+    // so this shift is a no-op.
+    UsedBytesMask <<= ElementSizeInBytes;
+
+    // Mask elements must be constants.
+    if (Op->getAsZExtVal() != 0)
+      UsedBytesMask |= ElementMask;
+  }
+
+  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+         "Unexpected masked load with elements masked all on or all off");
+
+  // Create a new load sd node to be handled normally by ReplaceLoadVector.
+  MemSDNode *NewLD = cast<MemSDNode>(
+      DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+  // If our subtarget does not support the used bytes mask pragma, "drop" the
+  // mask by setting it to UINT32_MAX
+  if (!STI.hasUsedBytesMaskPragma())
+    UsedBytesMask = UINT32_MAX;
+
+  return {NewLD, UsedBytesMask};
+}
+
 /// replaceLoadVector - Convert vector loads into multi-output scalar loads.
 static std::optional<std::pair<SDValue, SDValue>>
 replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
-  LoadSDNode *LD = cast<LoadSDNode>(N);
+  MemSDNode *LD = cast<MemSDNode>(N);
 
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
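The used-bytes-mask construction in convertMLOADToLoadWithUsedBytesMask above can be restated outside of SelectionDAG. The sketch below is not from the patch (the helper name usedBytesMask and the sample masks are invented for illustration); it only shows the bit layout: one bit per byte of the loaded vector, with elements visited from the last one down to the first.

#include <cassert>
#include <cstdint>
#include <vector>

// Build a 32-bit used-bytes mask from a constant per-element i1 mask, the way
// convertMLOADToLoadWithUsedBytesMask does: each element contributes
// ElementSizeInBytes bits, starting from the highest element.
static uint32_t usedBytesMask(const std::vector<bool> &ElementMaskBits,
                              uint32_t ElementSizeInBytes) {
  const uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
  uint32_t Mask = 0;
  for (auto It = ElementMaskBits.rbegin(); It != ElementMaskBits.rend(); ++It) {
    Mask <<= ElementSizeInBytes; // no-op on the first iteration (Mask == 0)
    if (*It)
      Mask |= ElementMask;
  }
  return Mask;
}

int main() {
  // v8f32 (4-byte elements) with elements 2 and 3 masked off: bytes 8-15 of
  // the 32-byte load are unused, so those bits stay clear.
  assert(usedBytesMask({1, 1, 0, 0, 1, 1, 1, 1}, 4) == 0xFFFF00FF);
  // 8-byte elements (e.g. a v4i64-sized load) with only element 0 used.
  assert(usedBytesMask({1, 0, 0, 0}, 8) == 0x000000FF);
  return 0;
}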
@@ -3440,6 +3534,12 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
     return std::nullopt;
   }
 
+  // If we have a masked load, convert it to a normal load now
+  std::optional<uint32_t> UsedBytesMask = std::nullopt;
+  if (LD->getOpcode() == ISD::MLOAD)
+    std::tie(LD, UsedBytesMask) =
+        convertMLOADToLoadWithUsedBytesMask(LD, DAG, STI);
+
   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
   // Therefore, we must ensure the type is legal. For i1 and i8, we set the
   // loaded type to i16 and propagate the "real" type as the memory type.
@@ -3468,9 +3568,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
   // Copy regular operands
   SmallVector<SDValue, 8> OtherOps(LD->ops());
 
+  OtherOps.push_back(
+      DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
   // The select routine does not have access to the LoadSDNode instance, so
   // pass along the extension information
-  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  OtherOps.push_back(
+      DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
 
   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
                                           LD->getMemOperand());
@@ -3558,6 +3662,42 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   llvm_unreachable("Unexpected custom lowering for load");
 }
 
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
+  // masked loads of these types and have to handle them here.
+  // v2f32 also needs to be handled here if the subtarget has f32x2
+  // instructions, making it legal.
+  //
+  // Note: misaligned masked loads should never reach this point
+  // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+  // will validate alignment. Therefore, we do not need to special case handle
+  // them here.
+  EVT VT = Op.getValueType();
+  if (NVPTX::isPackedVectorTy(VT)) {
+    auto Result = convertMLOADToLoadWithUsedBytesMask(
+        cast<MemSDNode>(Op.getNode()), DAG, STI);
+    MemSDNode *LD = std::get<0>(Result);
+    uint32_t UsedBytesMask = std::get<1>(Result);
+
+    SDLoc DL(LD);
+
+    // Copy regular operands
+    SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+    OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+    // We currently are not lowering extending loads, but pass the extension
+    // type anyway as later handling expects it.
+    OtherOps.push_back(
+        DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    SDValue NewLD =
+        DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
+                                LD->getMemoryVT(), LD->getMemOperand());
+    return NewLD;
+  }
+  return SDValue();
+}
+
 static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -3944,9 +4084,10 @@ void NVPTXTargetLowering::LowerAsmOperandForConstraint(
 // because we need the information that is only available in the "Value" type
 // of destination
 // pointer. In particular, the address space information.
-bool NVPTXTargetLowering::getTgtMemIntrinsic(
-    IntrinsicInfo &Info, const CallInst &I,
-    MachineFunction &MF, unsigned Intrinsic) const {
+bool NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
+                                             const CallBase &I,
+                                             MachineFunction &MF,
+                                             unsigned Intrinsic) const {
   switch (Intrinsic) {
   default:
     return false;
@@ -5420,8 +5561,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   if (!NVPTX::isPackedVectorTy(ElementVT) || ElementVT == MVT::v4i8)
     return SDValue();
 
-  SmallVector<SDNode *> DeadCopyToRegs;
-
   // Check whether all outputs are either used by an extractelt or are
   // glue/chain nodes
   if (!all_of(N->uses(), [&](SDUse &U) {
@@ -5458,7 +5597,7 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(LD);
 
   // the new opcode after we double the number of operands
-  NVPTXISD::NodeType Opcode;
+  unsigned Opcode;
   SmallVector<SDValue> Operands(LD->ops());
   unsigned OldNumOutputs; // non-glue, non-chain outputs
   switch (LD->getOpcode()) {
@@ -5468,6 +5607,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
     // here.
     Opcode = NVPTXISD::LoadV2;
+    // append a "full" used bytes mask operand right before the extension type
+    // operand, signifying that all bytes are used.
+    Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
     Operands.push_back(DCI.DAG.getIntPtrConstant(
         cast<LoadSDNode>(LD)->getExtensionType(), DL));
     break;
@@ -5476,9 +5618,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     Opcode = NVPTXISD::LoadV4;
     break;
   case NVPTXISD::LoadV4:
-    // V8 is only supported for f32. Don't forget, we're not changing the load
-    // size here. This is already a 256-bit load.
-    if (ElementVT != MVT::v2f32)
+    // V8 is only supported for f32/i32. Don't forget, we're not changing the
+    // load size here. This is already a 256-bit load.
+    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
       return SDValue();
     OldNumOutputs = 4;
     Opcode = NVPTXISD::LoadV8;
@@ -5541,7 +5683,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
   auto *ST = cast<MemSDNode>(N);
 
   // The new opcode after we double the number of operands.
-  NVPTXISD::NodeType Opcode;
+  unsigned Opcode;
   switch (N->getOpcode()) {
   case ISD::STORE:
     // Any packed type is legal, so the legalizer will not have lowered
@@ -5553,9 +5695,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
     Opcode = NVPTXISD::StoreV4;
     break;
   case NVPTXISD::StoreV4:
-    // V8 is only supported for f32. Don't forget, we're not changing the store
-    // size here. This is already a 256-bit store.
-    if (ElementVT != MVT::v2f32)
+    // V8 is only supported for f32/i32. Don't forget, we're not changing the
+    // store size here. This is already a 256-bit store.
+    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
      return SDValue();
     Opcode = NVPTXISD::StoreV8;
     break;
@@ -5676,7 +5818,7 @@ static SDValue PerformFADDCombine(SDNode *N,
 }
 
 /// Get 3-input version of a 2-input min/max opcode
-static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
+static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
   switch (MinMax2Opcode) {
   case ISD::FMAXNUM:
   case ISD::FMAXIMUMNUM:
@@ -5707,7 +5849,7 @@ static SDValue PerformFMinMaxCombine(SDNode *N,
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   unsigned MinMaxOp2 = N->getOpcode();
-  NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
+  unsigned MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
 
   if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
     // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
@@ -6706,6 +6848,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
     ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
+  case ISD::MLOAD:
     replaceLoadVector(N, DAG, Results, STI);
     return;
   case ISD::INTRINSIC_W_CHAIN:
