diff options
author | Alex MacLean <amaclean@nvidia.com> | 2025-07-17 11:10:23 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-17 11:10:23 -0700 |
commit | f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee (patch) | |
tree | 2d122a62d34e80ac13449e6180ea5f1c2209271a /llvm/lib | |
parent | 3b11aaaf94fe6c7b4ccfd031f952265f706c1b68 (diff) | |
download | llvm-f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee.zip llvm-f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee.tar.gz llvm-f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee.tar.bz2 |
[NVPTX] Add PRMT constant folding and cleanup usage of PRMT node (#148906)
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 244 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 23 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 18 |
3 files changed, 203 insertions, 82 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index d017c65..7aa06f9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, MVT::v32i32, MVT::v64i32, MVT::v128i32}, Custom); - setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom); - // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol - setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i128, Custom); + // Enable custom lowering for the following: + // * MVT::i128 - clusterlaunchcontrol + // * MVT::i32 - prmt + // * MVT::Other - internal.addrspace.wrap + setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other}, + Custom); } const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { @@ -2060,6 +2063,19 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { return DAG.getBuildVector(Node->getValueType(0), dl, Ops); } +static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL, + SelectionDAG &DAG, + unsigned Mode = NVPTX::PTXPrmtMode::NONE) { + return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32, + {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)}); +} + +static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL, + SelectionDAG &DAG, + unsigned Mode = NVPTX::PTXPrmtMode::NONE) { + return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode); +} + SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { // Handle bitcasting from v2i8 without hitting the default promotion // strategy which goes through stack memory. @@ -2111,15 +2127,12 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op, L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32); R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32); } - return DAG.getNode( - NVPTXISD::PRMT, DL, MVT::v4i8, - {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32), - DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)}); + return getPRMT(L, R, SelectionValue, DL, DAG); }; auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340); auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340); auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410); - return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210); + return DAG.getBitcast(VT, PRMT3210); } // Get value or the Nth operand as an APInt(32). Undef values treated as 0. @@ -2176,11 +2189,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32, DAG.getZExtOrTrunc(Index, DL, MVT::i32), DAG.getConstant(0x7770, DL, MVT::i32)); - SDValue PRMT = DAG.getNode( - NVPTXISD::PRMT, DL, MVT::i32, - {DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32), - Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)}); - return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0)); + SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector), + DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG); + SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0)); + SDNodeFlags Flags; + Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8); + Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8); + Ext->setFlags(Flags); + return Ext; } // Constant index will be matched by tablegen. @@ -2242,9 +2258,9 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, } SDLoc DL(Op); - return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2, - DAG.getConstant(Selector, DL, MVT::i32), - DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)); + SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1), + DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG); + return DAG.getBitcast(Op.getValueType(), PRMT); } /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift @@ -2729,10 +2745,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op, {TryCancelResponse0, TryCancelResponse1}); } +static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) { + const unsigned Mode = [&]() { + switch (Op->getConstantOperandVal(0)) { + case Intrinsic::nvvm_prmt: + return NVPTX::PTXPrmtMode::NONE; + case Intrinsic::nvvm_prmt_b4e: + return NVPTX::PTXPrmtMode::B4E; + case Intrinsic::nvvm_prmt_ecl: + return NVPTX::PTXPrmtMode::ECL; + case Intrinsic::nvvm_prmt_ecr: + return NVPTX::PTXPrmtMode::ECR; + case Intrinsic::nvvm_prmt_f4e: + return NVPTX::PTXPrmtMode::F4E; + case Intrinsic::nvvm_prmt_rc16: + return NVPTX::PTXPrmtMode::RC16; + case Intrinsic::nvvm_prmt_rc8: + return NVPTX::PTXPrmtMode::RC8; + default: + llvm_unreachable("unsupported/unhandled intrinsic"); + } + }(); + SDLoc DL(Op); + SDValue A = Op->getOperand(1); + SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2) + : DAG.getConstant(0, DL, MVT::i32); + SDValue Selector = (Op->op_end() - 1)->get(); + return getPRMT(A, B, Selector, DL, DAG, Mode); +} static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) { switch (Op->getConstantOperandVal(0)) { default: return Op; + case Intrinsic::nvvm_prmt: + case Intrinsic::nvvm_prmt_b4e: + case Intrinsic::nvvm_prmt_ecl: + case Intrinsic::nvvm_prmt_ecr: + case Intrinsic::nvvm_prmt_f4e: + case Intrinsic::nvvm_prmt_rc16: + case Intrinsic::nvvm_prmt_rc8: + return lowerPrmtIntrinsic(Op, DAG); case Intrinsic::nvvm_internal_addrspace_wrap: return Op.getOperand(1); case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled: @@ -5775,11 +5827,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { SDLoc DL(N); auto &DAG = DCI.DAG; - auto PRMT = DAG.getNode( - NVPTXISD::PRMT, DL, MVT::v4i8, - {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32), - DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)}); - return DAG.getNode(ISD::BITCAST, DL, VT, PRMT); + auto PRMT = + getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1), + (Op1Bytes << 8) | Op0Bytes, DL, DAG); + return DAG.getBitcast(VT, PRMT); } static SDValue combineADDRSPACECAST(SDNode *N, @@ -5797,47 +5848,120 @@ static SDValue combineADDRSPACECAST(SDNode *N, return SDValue(); } +// Given a constant selector value and a prmt mode, return the selector value +// normalized to the generic prmt mode. See the PTX ISA documentation for more +// details: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt +static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) { + if (Mode == NVPTX::PTXPrmtMode::NONE) + return Selector; + + const unsigned V = Selector.trunc(2).getZExtValue(); + + const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2, + unsigned S3) { + return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12)); + }; + + switch (Mode) { + case NVPTX::PTXPrmtMode::F4E: + return GetSelector(V, V + 1, V + 2, V + 3); + case NVPTX::PTXPrmtMode::B4E: + return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7); + case NVPTX::PTXPrmtMode::RC8: + return GetSelector(V, V, V, V); + case NVPTX::PTXPrmtMode::ECL: + return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U); + case NVPTX::PTXPrmtMode::ECR: + return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V); + case NVPTX::PTXPrmtMode::RC16: { + unsigned V1 = (V & 1) << 1; + return GetSelector(V1, V1 + 1, V1, V1 + 1); + } + default: + llvm_unreachable("Invalid PRMT mode"); + } +} + +static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) { + // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}} + APInt BitField = B.concat(A); + APInt SelectorVal = getPRMTSelector(Selector, Mode); + APInt Result(32, 0); + for (unsigned I : llvm::seq(4U)) { + APInt Sel = SelectorVal.extractBits(4, I * 4); + unsigned Idx = Sel.getLoBits(3).getZExtValue(); + unsigned Sign = Sel.getHiBits(1).getZExtValue(); + APInt Byte = BitField.extractBits(8, Idx * 8); + if (Sign) + Byte = Byte.ashr(8); + Result.insertBits(Byte, I * 8); + } + return Result; +} + +static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + CodeGenOptLevel OptLevel) { + if (OptLevel == CodeGenOptLevel::None) + return SDValue(); + + // Constant fold PRMT + if (isa<ConstantSDNode>(N->getOperand(0)) && + isa<ConstantSDNode>(N->getOperand(1)) && + isa<ConstantSDNode>(N->getOperand(2))) + return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0), + N->getConstantOperandAPInt(1), + N->getConstantOperandAPInt(2), + N->getConstantOperandVal(3)), + SDLoc(N), N->getValueType(0)); + + return SDValue(); +} + SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel(); switch (N->getOpcode()) { - default: break; - case ISD::ADD: - return PerformADDCombine(N, DCI, OptLevel); - case ISD::FADD: - return PerformFADDCombine(N, DCI, OptLevel); - case ISD::MUL: - return PerformMULCombine(N, DCI, OptLevel); - case ISD::SHL: - return PerformSHLCombine(N, DCI, OptLevel); - case ISD::AND: - return PerformANDCombine(N, DCI); - case ISD::UREM: - case ISD::SREM: - return PerformREMCombine(N, DCI, OptLevel); - case ISD::SETCC: - return PerformSETCCCombine(N, DCI, STI.getSmVersion()); - case ISD::LOAD: - case NVPTXISD::LoadParamV2: - case NVPTXISD::LoadV2: - case NVPTXISD::LoadV4: - return combineUnpackingMovIntoLoad(N, DCI); - case NVPTXISD::StoreParam: - case NVPTXISD::StoreParamV2: - case NVPTXISD::StoreParamV4: - return PerformStoreParamCombine(N, DCI); - case ISD::STORE: - case NVPTXISD::StoreV2: - case NVPTXISD::StoreV4: - return PerformStoreCombine(N, DCI); - case ISD::EXTRACT_VECTOR_ELT: - return PerformEXTRACTCombine(N, DCI); - case ISD::VSELECT: - return PerformVSELECTCombine(N, DCI); - case ISD::BUILD_VECTOR: - return PerformBUILD_VECTORCombine(N, DCI); - case ISD::ADDRSPACECAST: - return combineADDRSPACECAST(N, DCI); + default: + break; + case ISD::ADD: + return PerformADDCombine(N, DCI, OptLevel); + case ISD::ADDRSPACECAST: + return combineADDRSPACECAST(N, DCI); + case ISD::AND: + return PerformANDCombine(N, DCI); + case ISD::BUILD_VECTOR: + return PerformBUILD_VECTORCombine(N, DCI); + case ISD::EXTRACT_VECTOR_ELT: + return PerformEXTRACTCombine(N, DCI); + case ISD::FADD: + return PerformFADDCombine(N, DCI, OptLevel); + case ISD::LOAD: + case NVPTXISD::LoadParamV2: + case NVPTXISD::LoadV2: + case NVPTXISD::LoadV4: + return combineUnpackingMovIntoLoad(N, DCI); + case ISD::MUL: + return PerformMULCombine(N, DCI, OptLevel); + case NVPTXISD::PRMT: + return combinePRMT(N, DCI, OptLevel); + case ISD::SETCC: + return PerformSETCCCombine(N, DCI, STI.getSmVersion()); + case ISD::SHL: + return PerformSHLCombine(N, DCI, OptLevel); + case ISD::SREM: + case ISD::UREM: + return PerformREMCombine(N, DCI, OptLevel); + case NVPTXISD::StoreParam: + case NVPTXISD::StoreParamV2: + case NVPTXISD::StoreParamV4: + return PerformStoreParamCombine(N, DCI); + case ISD::STORE: + case NVPTXISD::StoreV2: + case NVPTXISD::StoreV4: + return PerformStoreCombine(N, DCI); + case ISD::VSELECT: + return PerformVSELECTCombine(N, DCI); } return SDValue(); } @@ -6387,7 +6511,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known, ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2)); unsigned Mode = Op.getConstantOperandVal(3); - if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector) + if (!Selector) return; KnownBits AKnown = DAG.computeKnownBits(A, Depth); @@ -6396,7 +6520,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known, // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}} KnownBits BitField = BKnown.concat(AKnown); - APInt SelectorVal = Selector->getAPIntValue(); + APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode); for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) { APInt Sel = SelectorVal.extractBits(4, I * 4); unsigned Idx = Sel.getLoBits(3).getZExtValue(); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 4eef6c9..a5bb83d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1453,18 +1453,33 @@ let hasSideEffects = false in { (ins PrmtMode:$mode), "prmt.b32$mode", [(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>; + def PRMT_B32rir + : BasicFlagsNVPTXInst<(outs B32:$d), + (ins B32:$a, i32imm:$b, B32:$c), + (ins PrmtMode:$mode), + "prmt.b32$mode", + [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>; def PRMT_B32rii : BasicFlagsNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b, Hexu32imm:$c), (ins PrmtMode:$mode), "prmt.b32$mode", [(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>; - def PRMT_B32rir + def PRMT_B32irr : BasicFlagsNVPTXInst<(outs B32:$d), - (ins B32:$a, i32imm:$b, B32:$c), - (ins PrmtMode:$mode), + (ins i32imm:$a, B32:$b, B32:$c), (ins PrmtMode:$mode), + "prmt.b32$mode", + [(set i32:$d, (prmt imm:$a, i32:$b, i32:$c, imm:$mode))]>; + def PRMT_B32iri + : BasicFlagsNVPTXInst<(outs B32:$d), + (ins i32imm:$a, B32:$b, Hexu32imm:$c), (ins PrmtMode:$mode), + "prmt.b32$mode", + [(set i32:$d, (prmt imm:$a, i32:$b, imm:$c, imm:$mode))]>; + def PRMT_B32iir + : BasicFlagsNVPTXInst<(outs B32:$d), + (ins i32imm:$a, i32imm:$b, B32:$c), (ins PrmtMode:$mode), "prmt.b32$mode", - [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>; + [(set i32:$d, (prmt imm:$a, imm:$b, i32:$c, imm:$mode))]>; } diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index bad4c3c..70150bd 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1047,24 +1047,6 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass, // MISC // -class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode> - : Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c), - (PRMT_B32rrr $a, $b, $c, prmt_mode)>; - -class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode> - : Pat<(prmt_intrinsic i32:$a, i32:$c), - (PRMT_B32rir $a, (i32 0), $c, prmt_mode)>; - -def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>; -def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>; -def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>; - -def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>; -def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>; -def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>; -def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>; - - def INT_NVVM_NANOSLEEP_I : BasicNVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32", [(int_nvvm_nanosleep imm:$i)]>, Requires<[hasPTX<63>, hasSM<70>]>; |