diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX')
-rw-r--r-- | llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h | 2 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXTargetStreamer.cpp | 5 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 277 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 5 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 670 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 10 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 409 |
8 files changed, 410 insertions, 970 deletions
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp index 614b321..ce9cd12 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp @@ -15,8 +15,6 @@ using namespace llvm; -void NVPTXMCAsmInfo::anchor() {} - NVPTXMCAsmInfo::NVPTXMCAsmInfo(const Triple &TheTriple, const MCTargetOptions &Options) { if (TheTriple.getArch() == Triple::nvptx64) { diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h index 77c4dae..f071406 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h @@ -19,8 +19,6 @@ namespace llvm { class Triple; class NVPTXMCAsmInfo : public MCAsmInfo { - virtual void anchor(); - public: explicit NVPTXMCAsmInfo(const Triple &TheTriple, const MCTargetOptions &Options); diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXTargetStreamer.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXTargetStreamer.cpp index 9f91143..329e3b5 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXTargetStreamer.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXTargetStreamer.cpp @@ -97,10 +97,7 @@ void NVPTXTargetStreamer::changeSection(const MCSection *CurSection, if (isDwarfSection(FI, Section)) { // Emit DWARF .file directives in the outermost scope. outputDwarfFileDirectives(); - OS << "\t.section"; - Section->printSwitchToSection(*getStreamer().getContext().getAsmInfo(), - getStreamer().getContext().getTargetTriple(), - OS, SubSection); + OS << "\t.section\t" << Section->getName() << '\n'; // DWARF sections are enclosed into braces - emit the open one. OS << "\t{\n"; HasSections = true; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 65e7c56..95abcde 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -56,9 +56,7 @@ INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false) NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm, CodeGenOptLevel OptLevel) - : SelectionDAGISel(tm, OptLevel), TM(tm) { - doMulWide = (OptLevel > CodeGenOptLevel::None); -} + : SelectionDAGISel(tm, OptLevel), TM(tm) {} bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) { Subtarget = &MF.getSubtarget<NVPTXSubtarget>(); @@ -145,18 +143,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { if (tryStoreVector(N)) return; break; - case NVPTXISD::LoadParam: - case NVPTXISD::LoadParamV2: - case NVPTXISD::LoadParamV4: - if (tryLoadParam(N)) - return; - break; - case NVPTXISD::StoreParam: - case NVPTXISD::StoreParamV2: - case NVPTXISD::StoreParamV4: - if (tryStoreParam(N)) - return; - break; case ISD::INTRINSIC_W_CHAIN: if (tryIntrinsicChain(N)) return; @@ -1462,267 +1448,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { return true; } -bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { - SDValue Chain = Node->getOperand(0); - SDValue Offset = Node->getOperand(2); - SDValue Glue = Node->getOperand(3); - SDLoc DL(Node); - MemSDNode *Mem = cast<MemSDNode>(Node); - - unsigned VecSize; - switch (Node->getOpcode()) { - default: - return false; - case NVPTXISD::LoadParam: - VecSize = 1; - break; - case NVPTXISD::LoadParamV2: - VecSize = 2; - break; - case NVPTXISD::LoadParamV4: - VecSize = 4; - break; - } - - EVT EltVT = Node->getValueType(0); - EVT MemVT = Mem->getMemoryVT(); - - std::optional<unsigned> Opcode; - - switch (VecSize) { - default: - return false; - case 1: - Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, - NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16, - NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64); - break; - case 2: - Opcode = - pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8, - NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32, - NVPTX::LoadParamMemV2I64); - break; - case 4: - Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, - NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16, - NVPTX::LoadParamMemV4I32, {/* no v4i64 */}); - break; - } - if (!Opcode) - return false; - - SDVTList VTs; - if (VecSize == 1) { - VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue); - } else if (VecSize == 2) { - VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue); - } else { - EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue }; - VTs = CurDAG->getVTList(EVTs); - } - - unsigned OffsetVal = Offset->getAsZExtVal(); - - SmallVector<SDValue, 2> Ops( - {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue}); - - ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops)); - return true; -} - -// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri) -#define getOpcV2H(ty, opKind0, opKind1) \ - NVPTX::StoreParamV2##ty##_##opKind0##opKind1 - -#define getOpcV2H1(ty, opKind0, isImm1) \ - (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r) - -#define getOpcodeForVectorStParamV2(ty, isimm) \ - (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1]) - -#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \ - NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3 - -#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \ - (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \ - : getOpcV4H(ty, opKind0, opKind1, opKind2, r) - -#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \ - (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \ - : getOpcV4H3(ty, opKind0, opKind1, r, isImm3) - -#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \ - (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \ - : getOpcV4H2(ty, opKind0, r, isImm2, isImm3) - -#define getOpcodeForVectorStParamV4(ty, isimm) \ - (isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \ - : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3]) - -#define getOpcodeForVectorStParam(n, ty, isimm) \ - (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \ - : getOpcodeForVectorStParamV4(ty, isimm) - -static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops, - unsigned NumElts, - MVT::SimpleValueType MemTy, - SelectionDAG *CurDAG, SDLoc DL) { - // Determine which inputs are registers and immediates make new operators - // with constant values - SmallVector<bool, 4> IsImm(NumElts, false); - for (unsigned i = 0; i < NumElts; i++) { - IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i])); - if (IsImm[i]) { - SDValue Imm = Ops[i]; - if (MemTy == MVT::f32 || MemTy == MVT::f64) { - const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm); - const ConstantFP *CF = ConstImm->getConstantFPValue(); - Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0)); - } else { - const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm); - const ConstantInt *CI = ConstImm->getConstantIntValue(); - Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0)); - } - Ops[i] = Imm; - } - } - - // Get opcode for MemTy, size, and register/immediate operand ordering - switch (MemTy) { - case MVT::i8: - return getOpcodeForVectorStParam(NumElts, I8, IsImm); - case MVT::i16: - return getOpcodeForVectorStParam(NumElts, I16, IsImm); - case MVT::i32: - return getOpcodeForVectorStParam(NumElts, I32, IsImm); - case MVT::i64: - assert(NumElts == 2 && "MVT too large for NumElts > 2"); - return getOpcodeForVectorStParamV2(I64, IsImm); - case MVT::f32: - return getOpcodeForVectorStParam(NumElts, F32, IsImm); - case MVT::f64: - assert(NumElts == 2 && "MVT too large for NumElts > 2"); - return getOpcodeForVectorStParamV2(F64, IsImm); - - // These cases don't support immediates, just use the all register version - // and generate moves. - case MVT::i1: - return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr - : NVPTX::StoreParamV4I8_rrrr; - case MVT::f16: - case MVT::bf16: - return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr - : NVPTX::StoreParamV4I16_rrrr; - case MVT::v2f16: - case MVT::v2bf16: - case MVT::v2i16: - case MVT::v4i8: - return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr - : NVPTX::StoreParamV4I32_rrrr; - default: - llvm_unreachable("Cannot select st.param for unknown MemTy"); - } -} - -bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { - SDLoc DL(N); - SDValue Chain = N->getOperand(0); - SDValue Param = N->getOperand(1); - unsigned ParamVal = Param->getAsZExtVal(); - SDValue Offset = N->getOperand(2); - unsigned OffsetVal = Offset->getAsZExtVal(); - MemSDNode *Mem = cast<MemSDNode>(N); - SDValue Glue = N->getOperand(N->getNumOperands() - 1); - - // How many elements do we have? - unsigned NumElts; - switch (N->getOpcode()) { - default: - llvm_unreachable("Unexpected opcode"); - case NVPTXISD::StoreParam: - NumElts = 1; - break; - case NVPTXISD::StoreParamV2: - NumElts = 2; - break; - case NVPTXISD::StoreParamV4: - NumElts = 4; - break; - } - - // Build vector of operands - SmallVector<SDValue, 8> Ops; - for (unsigned i = 0; i < NumElts; ++i) - Ops.push_back(N->getOperand(i + 3)); - Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32), - CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue}); - - // Determine target opcode - // If we have an i1, use an 8-bit store. The lowering code in - // NVPTXISelLowering will have already emitted an upcast. - std::optional<unsigned> Opcode; - switch (NumElts) { - default: - llvm_unreachable("Unexpected NumElts"); - case 1: { - MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy; - SDValue Imm = Ops[0]; - if (MemTy != MVT::f16 && MemTy != MVT::bf16 && - (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) { - // Convert immediate to target constant - if (MemTy == MVT::f32 || MemTy == MVT::f64) { - const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm); - const ConstantFP *CF = ConstImm->getConstantFPValue(); - Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0)); - } else { - const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm); - const ConstantInt *CI = ConstImm->getConstantIntValue(); - Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0)); - } - Ops[0] = Imm; - // Use immediate version of store param - Opcode = - pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i, - NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i); - } else - Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy, - NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r, - NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r); - if (Opcode == NVPTX::StoreParamI8_r) { - // Fine tune the opcode depending on the size of the operand. - // This helps to avoid creating redundant COPY instructions in - // InstrEmitter::AddRegisterOperand(). - switch (Ops[0].getSimpleValueType().SimpleTy) { - default: - break; - case MVT::i32: - Opcode = NVPTX::StoreParamI8TruncI32_r; - break; - case MVT::i64: - Opcode = NVPTX::StoreParamI8TruncI64_r; - break; - } - } - break; - } - case 2: - case 4: { - MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy; - Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL); - break; - } - } - - SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue); - SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops); - MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand(); - CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef}); - - ReplaceNode(N, Ret); - return true; -} - /// SelectBFE - Look for instruction sequences that can be made more efficient /// by using the 'bfe' (bit-field extract) PTX instruction bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index b99b4ef..9e0f88e5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -40,9 +40,6 @@ private: class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel { const NVPTXTargetMachine &TM; - // If true, generate mul.wide from sext and mul - bool doMulWide; - NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const; bool usePrecSqrtF32(const SDNode *N) const; bool useF32FTZ() const; @@ -78,8 +75,6 @@ private: bool tryLDG(MemSDNode *N); bool tryStore(SDNode *N); bool tryStoreVector(SDNode *N); - bool tryLoadParam(SDNode *N); - bool tryStoreParam(SDNode *N); bool tryFence(SDNode *N); void SelectAddrSpaceCast(SDNode *N); bool tryBFE(SDNode *N); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index ddcecc00..4fd3623 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -843,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, - ISD::STORE}); + ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. @@ -1075,12 +1075,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(NVPTXISD::DeclareArrayParam) MAKE_CASE(NVPTXISD::DeclareScalarParam) MAKE_CASE(NVPTXISD::CALL) - MAKE_CASE(NVPTXISD::LoadParam) - MAKE_CASE(NVPTXISD::LoadParamV2) - MAKE_CASE(NVPTXISD::LoadParamV4) - MAKE_CASE(NVPTXISD::StoreParam) - MAKE_CASE(NVPTXISD::StoreParamV2) - MAKE_CASE(NVPTXISD::StoreParamV4) MAKE_CASE(NVPTXISD::MoveParam) MAKE_CASE(NVPTXISD::UNPACK_VECTOR) MAKE_CASE(NVPTXISD::BUILD_VECTOR) @@ -1318,105 +1312,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty, return DL.getABITypeAlign(Ty); } -static bool adjustElementType(EVT &ElementType) { - switch (ElementType.getSimpleVT().SimpleTy) { - default: - return false; - case MVT::f16: - case MVT::bf16: - ElementType = MVT::i16; - return true; - case MVT::f32: - case MVT::v2f16: - case MVT::v2bf16: - ElementType = MVT::i32; - return true; - case MVT::f64: - ElementType = MVT::i64; - return true; - } -} - -// Use byte-store when the param address of the argument value is unaligned. -// This may happen when the return value is a field of a packed structure. -// -// This is called in LowerCall() when passing the param values. -static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain, - uint64_t Offset, EVT ElementType, - SDValue StVal, SDValue &InGlue, - unsigned ArgID, const SDLoc &dl) { - // Bit logic only works on integer types - if (adjustElementType(ElementType)) - StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal); - - // Store each byte - SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue); - for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { - // Shift the byte to the last byte position - SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal, - DAG.getConstant(i * 8, dl, MVT::i32)); - SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32), - DAG.getConstant(Offset + i, dl, MVT::i32), - ShiftVal, InGlue}; - // Trunc store only the last byte by using - // st.param.b8 - // The register type can be larger than b8. - Chain = DAG.getMemIntrinsicNode( - NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8, - MachinePointerInfo(), Align(1), MachineMemOperand::MOStore); - InGlue = Chain.getValue(1); - } - return Chain; -} - -// Use byte-load when the param adress of the returned value is unaligned. -// This may happen when the returned value is a field of a packed structure. -static SDValue -LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset, - EVT ElementType, SDValue &InGlue, - SmallVectorImpl<SDValue> &TempProxyRegOps, - const SDLoc &dl) { - // Bit logic only works on integer types - EVT MergedType = ElementType; - adjustElementType(MergedType); - - // Load each byte and construct the whole value. Initial value to 0 - SDValue RetVal = DAG.getConstant(0, dl, MergedType); - // LoadParamMemI8 loads into i16 register only - SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue); - for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) { - SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32), - DAG.getConstant(Offset + i, dl, MVT::i32), - InGlue}; - // This will be selected to LoadParamMemI8 - SDValue LdVal = - DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands, - MVT::i8, MachinePointerInfo(), Align(1)); - SDValue TmpLdVal = LdVal.getValue(0); - Chain = LdVal.getValue(1); - InGlue = LdVal.getValue(2); - - TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl, - TmpLdVal.getSimpleValueType(), TmpLdVal); - TempProxyRegOps.push_back(TmpLdVal); - - SDValue CMask = DAG.getConstant(255, dl, MergedType); - SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32); - // Need to extend the i16 register to the whole width. - TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal); - // Mask off the high bits. Leave only the lower 8bits. - // Do this because we are using loadparam.b8. - TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask); - // Shift and merge - TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift); - RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal); - } - if (ElementType != MergedType) - RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal); - - return RetVal; -} - static bool shouldConvertToIndirectCall(const CallBase *CB, const GlobalAddressSDNode *Func) { if (!Func) @@ -1483,10 +1378,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SelectionDAG &DAG = CLI.DAG; SDLoc dl = CLI.DL; - SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; - SDValue Chain = CLI.Chain; + const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; SDValue Callee = CLI.Callee; - bool &isTailCall = CLI.IsTailCall; ArgListTy &Args = CLI.getArgs(); Type *RetTy = CLI.RetTy; const CallBase *CB = CLI.CB; @@ -1496,6 +1389,36 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, return DAG.getConstant(I, dl, MVT::i32); }; + const unsigned UniqueCallSite = GlobalUniqueCallSite++; + const SDValue CallChain = CLI.Chain; + const SDValue StartChain = + DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl); + SDValue DeclareGlue = StartChain.getValue(1); + + SmallVector<SDValue, 16> CallPrereqs{StartChain}; + + const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) { + // PTX ABI requires integral types to be at least 32 bits in size. FP16 is + // loaded/stored using i16, so it's handled here as well. + const unsigned SizeBits = promoteScalarArgumentSize(Size * 8); + SDValue Declare = + DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue}, + {StartChain, Symbol, GetI32(SizeBits), DeclareGlue}); + CallPrereqs.push_back(Declare); + DeclareGlue = Declare.getValue(1); + return Declare; + }; + + const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align, + unsigned Size) { + SDValue Declare = DAG.getNode( + NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue}, + {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue}); + CallPrereqs.push_back(Declare); + DeclareGlue = Declare.getValue(1); + return Declare; + }; + // Variadic arguments. // // Normally, for each argument, we declare a param scalar or a param @@ -1511,15 +1434,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // // After all vararg is processed, 'VAOffset' holds the size of the // vararg byte array. + assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) && + "Non-VarArg function with extra arguments"); - SDValue VADeclareParam; // vararg byte array const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic - unsigned VAOffset = 0; // current offset in the param array + unsigned VAOffset = 0; // current offset in the param array - const unsigned UniqueCallSite = GlobalUniqueCallSite++; - SDValue TempChain = Chain; - Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl); - SDValue InGlue = Chain.getValue(1); + const SDValue VADeclareParam = + CLI.Args.size() > FirstVAArg + ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32), + Align(STI.getMaxRequiredAlignment()), 0) + : SDValue(); // Args.size() and Outs.size() need not match. // Outs.size() will be larger @@ -1580,43 +1505,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) && "type size mismatch"); - const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> { - if (IsVAArg) { - if (ArgI == FirstVAArg) { - VADeclareParam = DAG.getNode( - NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue}, - {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()), - GetI32(0), InGlue}); - return VADeclareParam; - } - return std::nullopt; - } - if (IsByVal || shouldPassAsArray(Arg.Ty)) { - // declare .param .align <align> .b8 .param<n>[<size>]; - return DAG.getNode(NVPTXISD::DeclareArrayParam, dl, - {MVT::Other, MVT::Glue}, - {Chain, ParamSymbol, GetI32(ArgAlign.value()), - GetI32(TypeSize), InGlue}); - } + const SDValue ArgDeclare = [&]() { + if (IsVAArg) + return VADeclareParam; + + if (IsByVal || shouldPassAsArray(Arg.Ty)) + return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize); + assert(ArgOuts.size() == 1 && "We must pass only one value as non-array"); - // declare .param .b<size> .param<n>; - - // PTX ABI requires integral types to be at least 32 bits in - // size. FP16 is loaded/stored using i16, so it's handled - // here as well. - const unsigned PromotedSize = - (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) - ? promoteScalarArgumentSize(TypeSize * 8) - : TypeSize * 8; - - return DAG.getNode(NVPTXISD::DeclareScalarParam, dl, - {MVT::Other, MVT::Glue}, - {Chain, ParamSymbol, GetI32(PromotedSize), InGlue}); + assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) && + "Only int and float types are supported as non-array arguments"); + + return MakeDeclareScalarParam(ParamSymbol, TypeSize); }(); - if (ArgDeclare) { - Chain = ArgDeclare->getValue(0); - InGlue = ArgDeclare->getValue(1); - } // PTX Interoperability Guide 3.3(A): [Integer] Values shorter // than 32-bits are sign extended or zero extended, depending on @@ -1626,36 +1527,25 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32; const auto GetStoredValue = [&](const unsigned I, EVT EltVT, - const Align PartAlign) { - SDValue StVal; + const MaybeAlign PartAlign) { if (IsByVal) { SDValue Ptr = ArgOutVals[0]; auto MPI = refinePtrAS(Ptr, DAG, DL, *this); SDValue SrcAddr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I])); - StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign); - } else { - StVal = ArgOutVals[I]; - - auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType()); - if (PromotedVT != StVal.getValueType()) { - StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT, - StVal); - } + return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign); } + SDValue StVal = ArgOutVals[I]; + assert(promoteScalarIntegerPTX(StVal.getValueType()) == + StVal.getValueType() && + "OutVal type should always be legal"); - if (ExtendIntegerParam) { - assert(VTs.size() == 1 && "Scalar can't have multiple parts."); - // zext/sext to i32 - StVal = - DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, MVT::i32, StVal); - } else if (EltVT.getSizeInBits() < 16) { - // Use 16-bit registers for small stores as it's the - // smallest general purpose register size supported by NVPTX. - StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); - } - return StVal; + const EVT VTI = promoteScalarIntegerPTX(VTs[I]); + const EVT StoreVT = + ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); + + return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl); }; const auto VectorInfo = @@ -1664,23 +1554,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, unsigned J = 0; for (const unsigned NumElts : VectorInfo) { const int CurOffset = Offsets[J]; - EVT EltVT = promoteScalarIntegerPTX(VTs[J]); - const Align PartAlign = commonAlignment(ArgAlign, CurOffset); - - // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a - // scalar store. In such cases, fall back to byte stores. - if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) { - - SDValue StVal = GetStoredValue(J, EltVT, PartAlign); - Chain = LowerUnalignedStoreParam(DAG, Chain, - CurOffset + (IsByVal ? VAOffset : 0), - EltVT, StVal, InGlue, ArgI, dl); - - // LowerUnalignedStoreParam took care of inserting the necessary nodes - // into the SDAG, so just move on to the next element. - J++; - continue; - } + const EVT EltVT = promoteScalarIntegerPTX(VTs[J]); if (IsVAArg && !IsByVal) // Align each part of the variadic argument to their type. @@ -1688,44 +1562,45 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, assert((IsVAArg || VAOffset == 0) && "VAOffset must be 0 for non-VA args"); - SmallVector<SDValue, 6> StoreOperands{ - Chain, GetI32(IsVAArg ? FirstVAArg : ArgI), - GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))}; - // Record the values to store. - for (const unsigned K : llvm::seq(NumElts)) - StoreOperands.push_back(GetStoredValue(J + K, EltVT, PartAlign)); - StoreOperands.push_back(InGlue); + const unsigned Offset = + (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset)); + SDValue Ptr = + DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset)); - NVPTXISD::NodeType Op; - switch (NumElts) { - case 1: - Op = NVPTXISD::StoreParam; - break; - case 2: - Op = NVPTXISD::StoreParamV2; - break; - case 4: - Op = NVPTXISD::StoreParamV4; - break; - default: - llvm_unreachable("Invalid vector info."); + const MaybeAlign CurrentAlign = ExtendIntegerParam + ? MaybeAlign(std::nullopt) + : commonAlignment(ArgAlign, Offset); + + SDValue Val; + if (NumElts == 1) { + Val = GetStoredValue(J, EltVT, CurrentAlign); + } else { + SmallVector<SDValue, 8> StoreVals; + for (const unsigned K : llvm::seq(NumElts)) { + SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign); + if (ValJ.getValueType().isVector()) + DAG.ExtractVectorElements(ValJ, StoreVals); + else + StoreVals.push_back(ValJ); + } + + EVT VT = EVT::getVectorVT( + *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size()); + Val = DAG.getBuildVector(VT, dl, StoreVals); } - // Adjust type of the store op if we've extended the scalar - // return value. - EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; - Chain = DAG.getMemIntrinsicNode( - Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands, - TheStoreType, MachinePointerInfo(), PartAlign, - MachineMemOperand::MOStore); - InGlue = Chain.getValue(1); + SDValue StoreParam = + DAG.getStore(ArgDeclare, dl, Val, Ptr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); + CallPrereqs.push_back(StoreParam); // TODO: We may need to support vector types that can be passed // as scalars in variadic arguments. if (IsVAArg && !IsByVal) { assert(NumElts == 1 && "Vectorization is expected to be disabled for variadics."); + const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT; VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext())); } @@ -1736,33 +1611,21 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, VAOffset += TypeSize; } - GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); - // Handle Result if (!Ins.empty()) { - const SDValue RetDeclare = [&]() { - const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); - const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy); - if (shouldPassAsArray(RetTy)) { - const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); - return DAG.getNode(NVPTXISD::DeclareArrayParam, dl, - {MVT::Other, MVT::Glue}, - {Chain, RetSymbol, GetI32(RetAlign.value()), - GetI32(ResultSize / 8), InGlue}); - } - const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize); - return DAG.getNode( - NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue}, - {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue}); - }(); - Chain = RetDeclare.getValue(0); - InGlue = RetDeclare.getValue(1); + const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); + const unsigned ResultSize = DL.getTypeAllocSize(RetTy); + if (shouldPassAsArray(RetTy)) { + const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL); + MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize); + } else { + MakeDeclareScalarParam(RetSymbol, ResultSize); + } } - const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs); // Set the size of the vararg param byte array if the callee is a variadic // function and the variadic part is not empty. - if (HasVAArgs) { + if (VADeclareParam) { SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0), VADeclareParam.getOperand(1), VADeclareParam.getOperand(2), GetI32(VAOffset), @@ -1771,6 +1634,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, VADeclareParam->getVTList(), DeclareParamOps); } + const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); // If the type of the callsite does not match that of the function, convert // the callsite to an indirect call. const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func); @@ -1800,15 +1664,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // instruction. // The prototype is embedded in a string and put as the operand for a // CallPrototype SDNode which will print out to the value of the string. + const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs); std::string Proto = getPrototype(DL, RetTy, Args, CLI.Outs, HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB, UniqueCallSite); const char *ProtoStr = nvTM->getStrPool().save(Proto).data(); - Chain = DAG.getNode( - NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue}, - {Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue}); - InGlue = Chain.getValue(1); + const SDValue PrototypeDeclare = DAG.getNode( + NVPTXISD::CallPrototype, dl, MVT::Other, + {StartChain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32)}); + CallPrereqs.push_back(PrototypeDeclare); } if (ConvertToIndirectCall) { @@ -1826,24 +1691,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const unsigned NumArgs = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size()); /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, - /// NumParams, Callee, Proto, InGlue) - Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue}, - {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall), - GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, - GetI32(Proto), InGlue}); - InGlue = Chain.getValue(1); - + /// NumParams, Callee, Proto) + const SDValue CallToken = DAG.getTokenFactor(dl, CallPrereqs); + const SDValue Call = DAG.getNode( + NVPTXISD::CALL, dl, MVT::Other, + {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall), + GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)}); + + SmallVector<SDValue, 16> LoadChains{Call}; SmallVector<SDValue, 16> ProxyRegOps; - // An item of the vector is filled if the element does not need a ProxyReg - // operation on it and should be added to InVals as is. ProxyRegOps and - // ProxyRegTruncates contain empty/none items at the same index. - SmallVector<SDValue, 16> RetElts; - // A temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()` - // to use the values of `LoadParam`s and to be replaced later then - // `CALLSEQ_END` is added. - SmallVector<SDValue, 16> TempProxyRegOps; - - // Generate loads from param memory/moves from registers for result if (!Ins.empty()) { SmallVector<EVT, 16> VTs; SmallVector<uint64_t, 16> Offsets; @@ -1860,104 +1716,65 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); unsigned I = 0; - for (const unsigned VectorizedSize : VectorInfo) { - EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]); - EVT EltType = Ins[I].VT; - const Align EltAlign = commonAlignment(RetAlign, Offsets[I]); - - if (TheLoadType != VTs[I]) - EltType = TheLoadType; - - if (ExtendIntegerRetVal) { - TheLoadType = MVT::i32; - EltType = MVT::i32; - } else if (TheLoadType.getSizeInBits() < 16) { - EltType = MVT::i16; - } + for (const unsigned NumElts : VectorInfo) { + const MaybeAlign CurrentAlign = + ExtendIntegerRetVal ? MaybeAlign(std::nullopt) + : commonAlignment(RetAlign, Offsets[I]); - // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a - // scalar load. In such cases, fall back to byte loads. - if (VectorizedSize == 1 && RetTy->isAggregateType() && - EltAlign < DAG.getEVTAlign(TheLoadType)) { - SDValue Ret = LowerUnalignedLoadRetParam( - DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl); - ProxyRegOps.push_back(SDValue()); - RetElts.resize(I); - RetElts.push_back(Ret); - - I++; - continue; - } + const EVT VTI = promoteScalarIntegerPTX(VTs[I]); + const EVT LoadVT = + ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI); - SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType); - LoadVTs.append({MVT::Other, MVT::Glue}); + const unsigned PackingAmt = + LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1; - NVPTXISD::NodeType Op; - switch (VectorizedSize) { - case 1: - Op = NVPTXISD::LoadParam; - break; - case 2: - Op = NVPTXISD::LoadParamV2; - break; - case 4: - Op = NVPTXISD::LoadParamV4; - break; - default: - llvm_unreachable("Invalid vector info."); - } + const EVT VecVT = NumElts == 1 ? LoadVT + : EVT::getVectorVT(*DAG.getContext(), + LoadVT.getScalarType(), + NumElts * PackingAmt); - SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue}; - SDValue RetVal = DAG.getMemIntrinsicNode( - Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType, - MachinePointerInfo(), EltAlign, MachineMemOperand::MOLoad); + const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32); + SDValue Ptr = + DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I])); - for (const unsigned J : llvm::seq(VectorizedSize)) { - ProxyRegOps.push_back(RetVal.getValue(J)); - } + SDValue R = + DAG.getLoad(VecVT, dl, Call, Ptr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign); - Chain = RetVal.getValue(VectorizedSize); - InGlue = RetVal.getValue(VectorizedSize + 1); + LoadChains.push_back(R.getValue(1)); - I += VectorizedSize; + if (NumElts == 1) + ProxyRegOps.push_back(R); + else + for (const unsigned J : llvm::seq(NumElts)) { + SDValue Elt = DAG.getNode( + LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR + : ISD::EXTRACT_VECTOR_ELT, + dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl)); + ProxyRegOps.push_back(Elt); + } + I += NumElts; } } - Chain = - DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl); - InGlue = Chain.getValue(1); + const SDValue EndToken = DAG.getTokenFactor(dl, LoadChains); + const SDValue CallEnd = DAG.getCALLSEQ_END(EndToken, UniqueCallSite, + UniqueCallSite + 1, SDValue(), dl); // Append ProxyReg instructions to the chain to make sure that `callseq_end` // will not get lost. Otherwise, during libcalls expansion, the nodes can become // dangling. - for (const unsigned I : llvm::seq(ProxyRegOps.size())) { - if (I < RetElts.size() && RetElts[I]) { - InVals.push_back(RetElts[I]); - continue; - } - - SDValue Ret = - DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(), - {Chain, ProxyRegOps[I]}); - - const EVT ExpectedVT = Ins[I].VT; - if (!Ret.getValueType().bitsEq(ExpectedVT)) { - Ret = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Ret); - } + for (const auto [I, Reg] : llvm::enumerate(ProxyRegOps)) { + SDValue Proxy = + DAG.getNode(NVPTXISD::ProxyReg, dl, Reg.getValueType(), {CallEnd, Reg}); + SDValue Ret = correctParamType(Proxy, Ins[I].VT, Ins[I].Flags, DAG, dl); InVals.push_back(Ret); } - for (SDValue &T : TempProxyRegOps) { - SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(), - {Chain, T.getOperand(0)}); - DAG.ReplaceAllUsesWith(T, Repl); - DAG.RemoveDeadNode(T.getNode()); - } - - // set isTailCall to false for now, until we figure out how to express + // set IsTailCall to false for now, until we figure out how to express // tail call optimization in PTX - isTailCall = false; - return Chain; + CLI.IsTailCall = false; + return CallEnd; } SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, @@ -5117,10 +4934,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { Operands.push_back(DCI.DAG.getIntPtrConstant( cast<LoadSDNode>(LD)->getExtensionType(), DL)); break; - case NVPTXISD::LoadParamV2: - OldNumOutputs = 2; - Opcode = NVPTXISD::LoadParamV4; - break; case NVPTXISD::LoadV2: OldNumOutputs = 2; Opcode = NVPTXISD::LoadV4; @@ -5201,12 +5014,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N, MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV2; break; - case NVPTXISD::StoreParam: - Opcode = NVPTXISD::StoreParamV2; - break; - case NVPTXISD::StoreParamV2: - Opcode = NVPTXISD::StoreParamV4; - break; case NVPTXISD::StoreV2: MemVT = ST->getMemoryVT(); Opcode = NVPTXISD::StoreV4; @@ -5218,7 +5025,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N, return SDValue(); Opcode = NVPTXISD::StoreV8; break; - case NVPTXISD::StoreParamV4: case NVPTXISD::StoreV8: // PTX doesn't support the next doubling of operands return SDValue(); @@ -5263,30 +5069,11 @@ static SDValue combinePackingMovIntoStore(SDNode *N, MemVT, ST->getMemOperand()); } -static SDValue PerformStoreCombineHelper(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - unsigned Front, unsigned Back) { - if (all_of(N->ops().drop_front(Front).drop_back(Back), - [](const SDUse &U) { return U.get()->isUndef(); })) - // Operand 0 is the previous value in the chain. Cannot return EntryToken - // as the previous value will become unused and eliminated later. - return N->getOperand(0); - - return combinePackingMovIntoStore(N, DCI, Front, Back); -} - static SDValue PerformStoreCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { return combinePackingMovIntoStore(N, DCI, 1, 2); } -static SDValue PerformStoreParamCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { - // Operands from the 3rd to the 2nd last one are the values to be stored. - // {Chain, ArgID, Offset, Val, Glue} - return PerformStoreCombineHelper(N, DCI, 3, 1); -} - /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD. /// static SDValue PerformADDCombine(SDNode *N, @@ -5432,6 +5219,42 @@ static SDValue PerformREMCombine(SDNode *N, return SDValue(); } +// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y) +static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + CodeGenOptLevel OptLevel) { + if (OptLevel == CodeGenOptLevel::None) + return SDValue(); + + SDValue Op = N->getOperand(0); + if (!Op.hasOneUse()) + return SDValue(); + EVT ToVT = N->getValueType(0); + EVT FromVT = Op.getValueType(); + if (!((ToVT == MVT::i32 && FromVT == MVT::i16) || + (ToVT == MVT::i64 && FromVT == MVT::i32))) + return SDValue(); + if (!(Op.getOpcode() == ISD::MUL || + (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1))))) + return SDValue(); + + SDLoc DL(N); + unsigned ExtOpcode = N->getOpcode(); + unsigned Opcode = 0; + if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap()) + Opcode = NVPTXISD::MUL_WIDE_SIGNED; + else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap()) + Opcode = NVPTXISD::MUL_WIDE_UNSIGNED; + else + return SDValue(); + SDValue RHS = Op.getOperand(1); + if (Op.getOpcode() == ISD::SHL) { + const auto ShiftAmt = Op.getConstantOperandVal(1); + const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt; + RHS = DCI.DAG.getConstant(MulVal, DL, ToVT); + } + return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS); +} + enum OperandSignedness { Signed = 0, Unsigned, @@ -5942,6 +5765,86 @@ static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, N->getConstantOperandAPInt(2), N->getConstantOperandVal(3)), SDLoc(N), N->getValueType(0)); + return SDValue(); +} + +// During call lowering we wrap the return values in a ProxyReg node which +// depend on the chain value produced by the completed call. This ensures that +// the full call is emitted in cases where libcalls are used to legalize +// operations. To improve the functioning of other DAG combines we pull all +// operations we can through one of these nodes, ensuring that the ProxyReg +// directly wraps a load. That is: +// +// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0))) +// +static SDValue sinkProxyReg(SDValue R, SDValue Chain, + TargetLowering::DAGCombinerInfo &DCI) { + switch (R.getOpcode()) { + case ISD::TRUNCATE: + case ISD::ANY_EXTEND: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + case ISD::BITCAST: { + if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI)) + return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), V); + return SDValue(); + } + case ISD::SHL: + case ISD::SRL: + case ISD::SRA: + case ISD::OR: { + if (SDValue A = sinkProxyReg(R.getOperand(0), Chain, DCI)) + if (SDValue B = sinkProxyReg(R.getOperand(1), Chain, DCI)) + return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), A, B); + return SDValue(); + } + case ISD::Constant: + return R; + case ISD::LOAD: + case NVPTXISD::LoadV2: + case NVPTXISD::LoadV4: { + return DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(R), R.getValueType(), + {Chain, R}); + } + case ISD::BUILD_VECTOR: { + if (DCI.isBeforeLegalize()) + return SDValue(); + + SmallVector<SDValue, 16> Ops; + for (auto &Op : R->ops()) { + SDValue V = sinkProxyReg(Op, Chain, DCI); + if (!V) + return SDValue(); + Ops.push_back(V); + } + return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops); + } + case ISD::EXTRACT_VECTOR_ELT: { + if (DCI.isBeforeLegalize()) + return SDValue(); + + if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI)) + return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R), + R.getValueType(), V, R.getOperand(1)); + return SDValue(); + } + default: + return SDValue(); + } +} + +static SDValue combineProxyReg(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + + SDValue Chain = N->getOperand(0); + SDValue Reg = N->getOperand(1); + + // If the ProxyReg is not wrapping a load, try to pull the operations through + // the ProxyReg. + if (Reg.getOpcode() != ISD::LOAD) { + if (SDValue V = sinkProxyReg(Reg, Chain, DCI)) + return V; + } return SDValue(); } @@ -5958,6 +5861,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return combineADDRSPACECAST(N, DCI); case ISD::AND: return PerformANDCombine(N, DCI); + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + return combineMulWide(N, DCI, OptLevel); case ISD::BUILD_VECTOR: return PerformBUILD_VECTORCombine(N, DCI); case ISD::EXTRACT_VECTOR_ELT: @@ -5965,7 +5871,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, case ISD::FADD: return PerformFADDCombine(N, DCI, OptLevel); case ISD::LOAD: - case NVPTXISD::LoadParamV2: case NVPTXISD::LoadV2: case NVPTXISD::LoadV4: return combineUnpackingMovIntoLoad(N, DCI); @@ -5973,6 +5878,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return PerformMULCombine(N, DCI, OptLevel); case NVPTXISD::PRMT: return combinePRMT(N, DCI, OptLevel); + case NVPTXISD::ProxyReg: + return combineProxyReg(N, DCI); case ISD::SETCC: return PerformSETCCCombine(N, DCI, STI.getSmVersion()); case ISD::SHL: @@ -5980,10 +5887,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, case ISD::SREM: case ISD::UREM: return PerformREMCombine(N, DCI, OptLevel); - case NVPTXISD::StoreParam: - case NVPTXISD::StoreParamV2: - case NVPTXISD::StoreParamV4: - return PerformStoreParamCombine(N, DCI); case ISD::STORE: case NVPTXISD::StoreV2: case NVPTXISD::StoreV4: @@ -6332,6 +6235,22 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, Results.push_back(NewValue.getValue(3)); } +static void replaceProxyReg(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI, + SmallVectorImpl<SDValue> &Results) { + SDValue Chain = N->getOperand(0); + SDValue Reg = N->getOperand(1); + + MVT VT = TLI.getRegisterType(*DAG.getContext(), Reg.getValueType()); + + SDValue NewReg = DAG.getAnyExtOrTrunc(Reg, SDLoc(N), VT); + SDValue NewProxy = + DAG.getNode(NVPTXISD::ProxyReg, SDLoc(N), VT, {Chain, NewReg}); + SDValue Res = DAG.getAnyExtOrTrunc(NewProxy, SDLoc(N), N->getValueType(0)); + + Results.push_back(Res); +} + void NVPTXTargetLowering::ReplaceNodeResults( SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { switch (N->getOpcode()) { @@ -6349,6 +6268,9 @@ void NVPTXTargetLowering::ReplaceNodeResults( case ISD::CopyFromReg: ReplaceCopyFromReg_128(N, DAG, Results); return; + case NVPTXISD::ProxyReg: + replaceProxyReg(N, DAG, *this, Results); + return; } } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 228e2aa..cf72a1e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -38,7 +38,7 @@ enum NodeType : unsigned { /// This node represents a PTX call instruction. It's operands are as follows: /// /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, - /// NumParams, Callee, Proto, InGlue) + /// NumParams, Callee, Proto) CALL, MoveParam, @@ -84,13 +84,7 @@ enum NodeType : unsigned { StoreV2, StoreV4, StoreV8, - LoadParam, - LoadParamV2, - LoadParamV4, - StoreParam, - StoreParamV2, - StoreParamV4, - LAST_MEMORY_OPCODE = StoreParamV4, + LAST_MEMORY_OPCODE = StoreV8, }; } diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 442b900..6000b40 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -125,8 +125,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">; def doNoF32FTZ : Predicate<"!useF32FTZ()">; def doRsqrtOpt : Predicate<"doRsqrtOpt()">; -def doMulWide : Predicate<"doMulWide">; - def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">; def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">; def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">; @@ -836,36 +834,28 @@ def MULWIDES64 : BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">; def MULWIDES64Imm : BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">; -def MULWIDES64Imm64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">; def MULWIDEU64 : BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">; def MULWIDEU64Imm : BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">; -def MULWIDEU64Imm64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">; def MULWIDES32 : BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">; def MULWIDES32Imm : BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">; -def MULWIDES32Imm32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">; def MULWIDEU32 : BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">; def MULWIDEU32Imm : BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">; -def MULWIDEU32Imm32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">; -def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>; -def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>; -def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>; +def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>; +def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; +def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; // Matchers for signed, unsigned mul.wide ISD nodes. -let Predicates = [doMulWide] in { +let Predicates = [hasOptEnabled] in { def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>; def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>; def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>; @@ -877,85 +867,6 @@ let Predicates = [doMulWide] in { def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>; } -// Predicates used for converting some patterns to mul.wide. -def SInt32Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isSignedIntN(32); -}]>; - -def UInt32Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isIntN(32); -}]>; - -def SInt16Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isSignedIntN(16); -}]>; - -def UInt16Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isIntN(16); -}]>; - -def IntConst_0_30 : PatLeaf<(imm), [{ - // Check if 0 <= v < 31; only then will the result of (x << v) be an int32. - const APInt &v = N->getAPIntValue(); - return v.sge(0) && v.slt(31); -}]>; - -def IntConst_0_14 : PatLeaf<(imm), [{ - // Check if 0 <= v < 15; only then will the result of (x << v) be an int16. - const APInt &v = N->getAPIntValue(); - return v.sge(0) && v.slt(15); -}]>; - -def SHL2MUL32 : SDNodeXForm<imm, [{ - const APInt &v = N->getAPIntValue(); - APInt temp(32, 1); - return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32); -}]>; - -def SHL2MUL16 : SDNodeXForm<imm, [{ - const APInt &v = N->getAPIntValue(); - APInt temp(16, 1); - return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16); -}]>; - -// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide. -let Predicates = [doMulWide] in { - def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDES64Imm $a, (SHL2MUL32 $b))>; - def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDEU64Imm $a, (SHL2MUL32 $b))>; - - def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDES32Imm $a, (SHL2MUL16 $b))>; - def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDEU32Imm $a, (SHL2MUL16 $b))>; - - // Convert "sign/zero-extend then multiply" to mul.wide. - def : Pat<(mul (sext i32:$a), (sext i32:$b)), - (MULWIDES64 $a, $b)>; - def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)), - (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>; - - def : Pat<(mul (zext i32:$a), (zext i32:$b)), - (MULWIDEU64 $a, $b)>; - def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)), - (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>; - - def : Pat<(mul (sext i16:$a), (sext i16:$b)), - (MULWIDES32 $a, $b)>; - def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)), - (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>; - - def : Pat<(mul (zext i16:$a), (zext i16:$b)), - (MULWIDEU32 $a, $b)>; - def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)), - (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>; -} - // // Integer multiply-add // @@ -991,6 +902,39 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>; defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>; } +multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> { + def rrr: + BasicNVPTXInst<(outs BigT.RC:$dst), + (ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c), + "mad.wide." # PtxSuffix, + [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>; + def rri: + BasicNVPTXInst<(outs BigT.RC:$dst), + (ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c), + "mad.wide." # PtxSuffix, + [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>; + def rir: + BasicNVPTXInst<(outs BigT.RC:$dst), + (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c), + "mad.wide." # PtxSuffix, + [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>; + def rii: + BasicNVPTXInst<(outs BigT.RC:$dst), + (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c), + "mad.wide." # PtxSuffix, + [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>; +} + +def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>; +def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>; + +let Predicates = [hasOptEnabled] in { +defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>; +defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>; +defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>; +defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>; +} + foreach t = [I16RT, I32RT, I64RT] in { def NEG_S # t.Size : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), @@ -1516,20 +1460,19 @@ def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtN // Byte extraction via shift/trunc/sext -def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), - (CVT_s8_s32 $s, CvtNONE)>; -def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), +def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), (CVT_s8_s32 $s, CvtNONE)>; +def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), (CVT_s8_s64 $s, CvtNONE)>; + +def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), (BFE_S32rii $s, imm:$o, 8)>; +def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), (BFE_S64rii $s, imm:$o, 8)>; + +def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), (CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>; -def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), - (BFE_S32rii $s, imm:$o, 8)>; +def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), + (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; + def : Pat<(i16 (sra (i16 (trunc i32:$s)), (i32 8))), (CVT_s8_s32 (BFE_S32rii $s, 8, 8), CvtNONE)>; -def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), - (BFE_S64rii $s, imm:$o, 8)>; -def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), - (CVT_s8_s64 $s, CvtNONE)>; -def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), - (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; //----------------------------------- // Comparison instructions (setp, set) @@ -1713,56 +1656,39 @@ def : Pat<(i64 frameindex:$fi), (LEA_ADDRi64 (to_tframeindex $fi), 0)>; //----------------------------------- // Comparison and Selection //----------------------------------- +// TODO: These patterns seem very specific and brittle. We should try to find +// a more general solution. def cond_signed : PatLeaf<(cond), [{ return isSignedIntSetCC(N->get()); }]>; -def cond_not_signed : PatLeaf<(cond), [{ - return !isSignedIntSetCC(N->get()); -}]>; +// A 16-bit signed comparison of sign-extended byte extracts can be converted +// to 32-bit comparison if we change the PRMT to sign-extend the extracted +// bytes. +def : Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)), + (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)), + cond_signed:$cc), + (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE), + (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE), + (cond2cc $cc))>; + +// A 16-bit comparison of truncated byte extracts can be be converted to 32-bit +// comparison because we know that the truncate is just trancating off zeros +// and that the most-significant byte is also zeros so the meaning of signed and +// unsigned comparisons will not be changed. +def : Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), + (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), + cond:$cc), + (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), + (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), + (cond2cc $cc))>; -// comparisons of i8 extracted with PRMT as i32 -// It's faster to do comparison directly on i32 extracted by PRMT, -// instead of the long conversion and sign extending. -def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)), - (i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)), - cond_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; - -def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)), - (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)), - cond_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; - -def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), - (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), - cond_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; - -def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), - (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), - cond_not_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; def SDTDeclareArrayParam : SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>; def SDTDeclareScalarParam : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; -def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>; -def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>; -def SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>; -def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>; -def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>; -def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisSameAs<0, 1>]>; def SDTProxyReg : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>]>; @@ -1774,104 +1700,20 @@ def declare_array_param : def declare_scalar_param : SDNode<"NVPTXISD::DeclareScalarParam", SDTDeclareScalarParam, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; - -def LoadParam : - SDNode<"NVPTXISD::LoadParam", SDTLoadParamProfile, - [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; -def LoadParamV2 : - SDNode<"NVPTXISD::LoadParamV2", SDTLoadParamV2Profile, - [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; -def LoadParamV4 : - SDNode<"NVPTXISD::LoadParamV4", SDTLoadParamV4Profile, - [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; -def StoreParam : - SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def StoreParamV2 : - SDNode<"NVPTXISD::StoreParamV2", SDTStoreParamV2Profile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def StoreParamV4 : - SDNode<"NVPTXISD::StoreParamV4", SDTStoreParamV4Profile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def MoveParam : SDNode<"NVPTXISD::MoveParam", SDTMoveParamProfile, []>; def proxy_reg : SDNode<"NVPTXISD::ProxyReg", SDTProxyReg, [SDNPHasChain]>; /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, - /// NumParams, Callee, Proto, InGlue) + /// NumParams, Callee, Proto) def SDTCallProfile : SDTypeProfile<0, 6, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<5, i32>]>; -def call : - SDNode<"NVPTXISD::CALL", SDTCallProfile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; - -let mayLoad = true in { - class LoadParamMemInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs regclass:$dst), (ins Offseti32imm:$b), - !strconcat("ld.param", opstr, " \t$dst, [retval0$b];"), - []>; - - class LoadParamV2MemInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs regclass:$dst, regclass:$dst2), (ins Offseti32imm:$b), - !strconcat("ld.param.v2", opstr, - " \t{{$dst, $dst2}}, [retval0$b];"), []>; - - class LoadParamV4MemInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs regclass:$dst, regclass:$dst2, regclass:$dst3, - regclass:$dst4), - (ins Offseti32imm:$b), - !strconcat("ld.param.v4", opstr, - " \t{{$dst, $dst2, $dst3, $dst4}}, [retval0$b];"), - []>; -} - -let mayStore = true in { - - multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> { - foreach op = [IMMType, regclass] in - if !or(support_imm, !isa<NVPTXRegClass>(op)) then - def _ # !if(!isa<NVPTXRegClass>(op), "r", "i") - : NVPTXInst<(outs), - (ins op:$val, i32imm:$a, Offseti32imm:$b), - "st.param" # opstr # " \t[param$a$b], $val;", - []>; - } - - multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> { - foreach op1 = [IMMType, regclass] in - foreach op2 = [IMMType, regclass] in - def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i") - # !if(!isa<NVPTXRegClass>(op2), "r", "i") - : NVPTXInst<(outs), - (ins op1:$val1, op2:$val2, - i32imm:$a, Offseti32imm:$b), - "st.param.v2" # opstr # " \t[param$a$b], {{$val1, $val2}};", - []>; - } - - multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> { - foreach op1 = [IMMType, regclass] in - foreach op2 = [IMMType, regclass] in - foreach op3 = [IMMType, regclass] in - foreach op4 = [IMMType, regclass] in - def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i") - # !if(!isa<NVPTXRegClass>(op2), "r", "i") - # !if(!isa<NVPTXRegClass>(op3), "r", "i") - # !if(!isa<NVPTXRegClass>(op4), "r", "i") - - : NVPTXInst<(outs), - (ins op1:$val1, op2:$val2, op3:$val3, op4:$val4, - i32imm:$a, Offseti32imm:$b), - "st.param.v4" # opstr # - " \t[param$a$b], {{$val1, $val2, $val3, $val4}};", - []>; - } -} +def call : SDNode<"NVPTXISD::CALL", SDTCallProfile, [SDNPHasChain, SDNPSideEffect]>; /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, -/// NumParams, Callee, Proto, InGlue) +/// NumParams, Callee, Proto) def CallOperand : Operand<i32> { let PrintMethod = "printCallOperand"; } @@ -1908,43 +1750,6 @@ foreach is_convergent = [0, 1] in { (call_uni_inst $addr, imm:$rets, imm:$params)>; } -def LoadParamMemI64 : LoadParamMemInst<B64, ".b64">; -def LoadParamMemI32 : LoadParamMemInst<B32, ".b32">; -def LoadParamMemI16 : LoadParamMemInst<B16, ".b16">; -def LoadParamMemI8 : LoadParamMemInst<B16, ".b8">; -def LoadParamMemV2I64 : LoadParamV2MemInst<B64, ".b64">; -def LoadParamMemV2I32 : LoadParamV2MemInst<B32, ".b32">; -def LoadParamMemV2I16 : LoadParamV2MemInst<B16, ".b16">; -def LoadParamMemV2I8 : LoadParamV2MemInst<B16, ".b8">; -def LoadParamMemV4I32 : LoadParamV4MemInst<B32, ".b32">; -def LoadParamMemV4I16 : LoadParamV4MemInst<B16, ".b16">; -def LoadParamMemV4I8 : LoadParamV4MemInst<B16, ".b8">; - -defm StoreParamI64 : StoreParamInst<B64, i64imm, ".b64">; -defm StoreParamI32 : StoreParamInst<B32, i32imm, ".b32">; -defm StoreParamI16 : StoreParamInst<B16, i16imm, ".b16">; -defm StoreParamI8 : StoreParamInst<B16, i8imm, ".b8">; - -defm StoreParamI8TruncI32 : StoreParamInst<B32, i8imm, ".b8", /* support_imm */ false>; -defm StoreParamI8TruncI64 : StoreParamInst<B64, i8imm, ".b8", /* support_imm */ false>; - -defm StoreParamV2I64 : StoreParamV2Inst<B64, i64imm, ".b64">; -defm StoreParamV2I32 : StoreParamV2Inst<B32, i32imm, ".b32">; -defm StoreParamV2I16 : StoreParamV2Inst<B16, i16imm, ".b16">; -defm StoreParamV2I8 : StoreParamV2Inst<B16, i8imm, ".b8">; - -defm StoreParamV4I32 : StoreParamV4Inst<B32, i32imm, ".b32">; -defm StoreParamV4I16 : StoreParamV4Inst<B16, i16imm, ".b16">; -defm StoreParamV4I8 : StoreParamV4Inst<B16, i8imm, ".b8">; - -defm StoreParamF32 : StoreParamInst<B32, f32imm, ".b32">; -defm StoreParamF64 : StoreParamInst<B64, f64imm, ".b64">; - -defm StoreParamV2F32 : StoreParamV2Inst<B32, f32imm, ".b32">; -defm StoreParamV2F64 : StoreParamV2Inst<B64, f64imm, ".b64">; - -defm StoreParamV4F32 : StoreParamV4Inst<B32, f32imm, ".b32">; - def DECLARE_PARAM_array : NVPTXInst<(outs), (ins i32imm:$a, i32imm:$align, i32imm:$size), ".param .align $align .b8 \t$a[$size];", []>; @@ -1957,6 +1762,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size), def : Pat<(declare_scalar_param externalsym:$a, imm:$size), (DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>; +// Call prototype wrapper, this is a dummy instruction that just prints it's +// operand which is string defining the prototype. +def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; +def CallPrototype : + SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; } +def CALL_PROTOTYPE : + NVPTXInst<(outs), (ins ProtoIdent:$ident), + "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; + + foreach t = [I32RT, I64RT] in { defvar inst_name = "MOV" # t.Size # "_PARAM"; def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>; @@ -1976,6 +1793,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>; defm ProxyRegB32 : ProxyRegInst<"b32", B32>; defm ProxyRegB64 : ProxyRegInst<"b64", B64>; + +// Callseq start and end + +// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because +// they define the scope in which the declared params may be used. Therefore +// we add these flags to ensure ld.param and st.param are not sunk or hoisted +// out of that scope. + +def callseq_start : SDNode<"ISD::CALLSEQ_START", + SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOutGlue, + SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>; +def callseq_end : SDNode<"ISD::CALLSEQ_END", + SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, + SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>; + +def Callseq_Start : + NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), + "\\{ // callseq $amt1, $amt2", + [(callseq_start timm:$amt1, timm:$amt2)]>; +def Callseq_End : + NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), + "\\} // callseq $amt1", + [(callseq_end timm:$amt1, timm:$amt2)]>; + // // Load / Store Handling // @@ -2519,26 +2362,6 @@ def : Pat<(brcond i32:$a, bb:$target), def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target), (CBranchOther $a, bb:$target)>; -// Call -def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>, - SDTCisVT<1, i32>]>; -def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; - -def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart, - [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>; -def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd, - [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, - SDNPSideEffect]>; - -def Callseq_Start : - NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), - "\\{ // callseq $amt1, $amt2", - [(callseq_start timm:$amt1, timm:$amt2)]>; -def Callseq_End : - NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), - "\\} // callseq $amt1", - [(callseq_end timm:$amt1, timm:$amt2)]>; - // trap instruction def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>; // Emit an `exit` as well to convey to ptxas that `trap` exits the CFG. @@ -2547,18 +2370,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[ // brkpt instruction def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>; -// Call prototype wrapper -def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; -def CallPrototype : - SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def ProtoIdent : Operand<i32> { - let PrintMethod = "printProtoIdent"; -} -def CALL_PROTOTYPE : - NVPTXInst<(outs), (ins ProtoIdent:$ident), - "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; - def SDTDynAllocaOp : SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>; |