aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorAlex MacLean <amaclean@nvidia.com>2025-07-17 11:10:23 -0700
committerGitHub <noreply@github.com>2025-07-17 11:10:23 -0700
commitf480e1b8258eac3565b3ffaf3f8ed0f77eb87fee (patch)
tree2d122a62d34e80ac13449e6180ea5f1c2209271a /llvm/lib
parent3b11aaaf94fe6c7b4ccfd031f952265f706c1b68 (diff)
downloadllvm-f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee.zip
llvm-f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee.tar.gz
llvm-f480e1b8258eac3565b3ffaf3f8ed0f77eb87fee.tar.bz2
[NVPTX] Add PRMT constant folding and cleanup usage of PRMT node (#148906)
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp244
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td23
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td18
3 files changed, 203 insertions, 82 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d017c65..7aa06f9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
MVT::v32i32, MVT::v64i32, MVT::v128i32},
Custom);
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
- // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i128, Custom);
+ // Enable custom lowering for the following:
+ // * MVT::i128 - clusterlaunchcontrol
+ // * MVT::i32 - prmt
+ // * MVT::Other - internal.addrspace.wrap
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
+ Custom);
}
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -2060,6 +2063,19 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}
+static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
+ SelectionDAG &DAG,
+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
+ return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
+ {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
+}
+
+static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
+ SelectionDAG &DAG,
+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
+ return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
+}
+
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
// Handle bitcasting from v2i8 without hitting the default promotion
// strategy which goes through stack memory.
@@ -2111,15 +2127,12 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
}
- return DAG.getNode(
- NVPTXISD::PRMT, DL, MVT::v4i8,
- {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
- DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
+ return getPRMT(L, R, SelectionValue, DL, DAG);
};
auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
- return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
+ return DAG.getBitcast(VT, PRMT3210);
}
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2189,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
DAG.getConstant(0x7770, DL, MVT::i32));
- SDValue PRMT = DAG.getNode(
- NVPTXISD::PRMT, DL, MVT::i32,
- {DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),
- Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
- return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
+ SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),
+ DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
+ SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
+ SDNodeFlags Flags;
+ Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
+ Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
+ Ext->setFlags(Flags);
+ return Ext;
}
// Constant index will be matched by tablegen.
@@ -2242,9 +2258,9 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
}
SDLoc DL(Op);
- return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
- DAG.getConstant(Selector, DL, MVT::i32),
- DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
+ SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1),
+ DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG);
+ return DAG.getBitcast(Op.getValueType(), PRMT);
}
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2745,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
{TryCancelResponse0, TryCancelResponse1});
}
+static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
+ const unsigned Mode = [&]() {
+ switch (Op->getConstantOperandVal(0)) {
+ case Intrinsic::nvvm_prmt:
+ return NVPTX::PTXPrmtMode::NONE;
+ case Intrinsic::nvvm_prmt_b4e:
+ return NVPTX::PTXPrmtMode::B4E;
+ case Intrinsic::nvvm_prmt_ecl:
+ return NVPTX::PTXPrmtMode::ECL;
+ case Intrinsic::nvvm_prmt_ecr:
+ return NVPTX::PTXPrmtMode::ECR;
+ case Intrinsic::nvvm_prmt_f4e:
+ return NVPTX::PTXPrmtMode::F4E;
+ case Intrinsic::nvvm_prmt_rc16:
+ return NVPTX::PTXPrmtMode::RC16;
+ case Intrinsic::nvvm_prmt_rc8:
+ return NVPTX::PTXPrmtMode::RC8;
+ default:
+ llvm_unreachable("unsupported/unhandled intrinsic");
+ }
+ }();
+ SDLoc DL(Op);
+ SDValue A = Op->getOperand(1);
+ SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2)
+ : DAG.getConstant(0, DL, MVT::i32);
+ SDValue Selector = (Op->op_end() - 1)->get();
+ return getPRMT(A, B, Selector, DL, DAG, Mode);
+}
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
switch (Op->getConstantOperandVal(0)) {
default:
return Op;
+ case Intrinsic::nvvm_prmt:
+ case Intrinsic::nvvm_prmt_b4e:
+ case Intrinsic::nvvm_prmt_ecl:
+ case Intrinsic::nvvm_prmt_ecr:
+ case Intrinsic::nvvm_prmt_f4e:
+ case Intrinsic::nvvm_prmt_rc16:
+ case Intrinsic::nvvm_prmt_rc8:
+ return lowerPrmtIntrinsic(Op, DAG);
case Intrinsic::nvvm_internal_addrspace_wrap:
return Op.getOperand(1);
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5827,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
auto &DAG = DCI.DAG;
- auto PRMT = DAG.getNode(
- NVPTXISD::PRMT, DL, MVT::v4i8,
- {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
- DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
- return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
+ auto PRMT =
+ getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
+ (Op1Bytes << 8) | Op0Bytes, DL, DAG);
+ return DAG.getBitcast(VT, PRMT);
}
static SDValue combineADDRSPACECAST(SDNode *N,
@@ -5797,47 +5848,120 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}
+// Given a constant selector value and a prmt mode, return the selector value
+// normalized to the generic prmt mode. See the PTX ISA documentation for more
+// details:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
+static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
+ if (Mode == NVPTX::PTXPrmtMode::NONE)
+ return Selector;
+
+ const unsigned V = Selector.trunc(2).getZExtValue();
+
+ const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
+ unsigned S3) {
+ return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
+ };
+
+ switch (Mode) {
+ case NVPTX::PTXPrmtMode::F4E:
+ return GetSelector(V, V + 1, V + 2, V + 3);
+ case NVPTX::PTXPrmtMode::B4E:
+ return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
+ case NVPTX::PTXPrmtMode::RC8:
+ return GetSelector(V, V, V, V);
+ case NVPTX::PTXPrmtMode::ECL:
+ return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
+ case NVPTX::PTXPrmtMode::ECR:
+ return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
+ case NVPTX::PTXPrmtMode::RC16: {
+ unsigned V1 = (V & 1) << 1;
+ return GetSelector(V1, V1 + 1, V1, V1 + 1);
+ }
+ default:
+ llvm_unreachable("Invalid PRMT mode");
+ }
+}
+
+static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
+ // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
+ APInt BitField = B.concat(A);
+ APInt SelectorVal = getPRMTSelector(Selector, Mode);
+ APInt Result(32, 0);
+ for (unsigned I : llvm::seq(4U)) {
+ APInt Sel = SelectorVal.extractBits(4, I * 4);
+ unsigned Idx = Sel.getLoBits(3).getZExtValue();
+ unsigned Sign = Sel.getHiBits(1).getZExtValue();
+ APInt Byte = BitField.extractBits(8, Idx * 8);
+ if (Sign)
+ Byte = Byte.ashr(8);
+ Result.insertBits(Byte, I * 8);
+ }
+ return Result;
+}
+
+static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ // Constant fold PRMT
+ if (isa<ConstantSDNode>(N->getOperand(0)) &&
+ isa<ConstantSDNode>(N->getOperand(1)) &&
+ isa<ConstantSDNode>(N->getOperand(2)))
+ return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),
+ N->getConstantOperandAPInt(1),
+ N->getConstantOperandAPInt(2),
+ N->getConstantOperandVal(3)),
+ SDLoc(N), N->getValueType(0));
+
+ return SDValue();
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
switch (N->getOpcode()) {
- default: break;
- case ISD::ADD:
- return PerformADDCombine(N, DCI, OptLevel);
- case ISD::FADD:
- return PerformFADDCombine(N, DCI, OptLevel);
- case ISD::MUL:
- return PerformMULCombine(N, DCI, OptLevel);
- case ISD::SHL:
- return PerformSHLCombine(N, DCI, OptLevel);
- case ISD::AND:
- return PerformANDCombine(N, DCI);
- case ISD::UREM:
- case ISD::SREM:
- return PerformREMCombine(N, DCI, OptLevel);
- case ISD::SETCC:
- return PerformSETCCCombine(N, DCI, STI.getSmVersion());
- case ISD::LOAD:
- case NVPTXISD::LoadParamV2:
- case NVPTXISD::LoadV2:
- case NVPTXISD::LoadV4:
- return combineUnpackingMovIntoLoad(N, DCI);
- case NVPTXISD::StoreParam:
- case NVPTXISD::StoreParamV2:
- case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N, DCI);
- case ISD::STORE:
- case NVPTXISD::StoreV2:
- case NVPTXISD::StoreV4:
- return PerformStoreCombine(N, DCI);
- case ISD::EXTRACT_VECTOR_ELT:
- return PerformEXTRACTCombine(N, DCI);
- case ISD::VSELECT:
- return PerformVSELECTCombine(N, DCI);
- case ISD::BUILD_VECTOR:
- return PerformBUILD_VECTORCombine(N, DCI);
- case ISD::ADDRSPACECAST:
- return combineADDRSPACECAST(N, DCI);
+ default:
+ break;
+ case ISD::ADD:
+ return PerformADDCombine(N, DCI, OptLevel);
+ case ISD::ADDRSPACECAST:
+ return combineADDRSPACECAST(N, DCI);
+ case ISD::AND:
+ return PerformANDCombine(N, DCI);
+ case ISD::BUILD_VECTOR:
+ return PerformBUILD_VECTORCombine(N, DCI);
+ case ISD::EXTRACT_VECTOR_ELT:
+ return PerformEXTRACTCombine(N, DCI);
+ case ISD::FADD:
+ return PerformFADDCombine(N, DCI, OptLevel);
+ case ISD::LOAD:
+ case NVPTXISD::LoadParamV2:
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4:
+ return combineUnpackingMovIntoLoad(N, DCI);
+ case ISD::MUL:
+ return PerformMULCombine(N, DCI, OptLevel);
+ case NVPTXISD::PRMT:
+ return combinePRMT(N, DCI, OptLevel);
+ case ISD::SETCC:
+ return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+ case ISD::SHL:
+ return PerformSHLCombine(N, DCI, OptLevel);
+ case ISD::SREM:
+ case ISD::UREM:
+ return PerformREMCombine(N, DCI, OptLevel);
+ case NVPTXISD::StoreParam:
+ case NVPTXISD::StoreParamV2:
+ case NVPTXISD::StoreParamV4:
+ return PerformStoreParamCombine(N, DCI);
+ case ISD::STORE:
+ case NVPTXISD::StoreV2:
+ case NVPTXISD::StoreV4:
+ return PerformStoreCombine(N, DCI);
+ case ISD::VSELECT:
+ return PerformVSELECTCombine(N, DCI);
}
return SDValue();
}
@@ -6387,7 +6511,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
unsigned Mode = Op.getConstantOperandVal(3);
- if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
+ if (!Selector)
return;
KnownBits AKnown = DAG.computeKnownBits(A, Depth);
@@ -6396,7 +6520,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
KnownBits BitField = BKnown.concat(AKnown);
- APInt SelectorVal = Selector->getAPIntValue();
+ APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
APInt Sel = SelectorVal.extractBits(4, I * 4);
unsigned Idx = Sel.getLoBits(3).getZExtValue();
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4eef6c9..a5bb83d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1453,18 +1453,33 @@ let hasSideEffects = false in {
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
+ def PRMT_B32rir
+ : BasicFlagsNVPTXInst<(outs B32:$d),
+ (ins B32:$a, i32imm:$b, B32:$c),
+ (ins PrmtMode:$mode),
+ "prmt.b32$mode",
+ [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
def PRMT_B32rii
: BasicFlagsNVPTXInst<(outs B32:$d),
(ins B32:$a, i32imm:$b, Hexu32imm:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
- def PRMT_B32rir
+ def PRMT_B32irr
: BasicFlagsNVPTXInst<(outs B32:$d),
- (ins B32:$a, i32imm:$b, B32:$c),
- (ins PrmtMode:$mode),
+ (ins i32imm:$a, B32:$b, B32:$c), (ins PrmtMode:$mode),
+ "prmt.b32$mode",
+ [(set i32:$d, (prmt imm:$a, i32:$b, i32:$c, imm:$mode))]>;
+ def PRMT_B32iri
+ : BasicFlagsNVPTXInst<(outs B32:$d),
+ (ins i32imm:$a, B32:$b, Hexu32imm:$c), (ins PrmtMode:$mode),
+ "prmt.b32$mode",
+ [(set i32:$d, (prmt imm:$a, i32:$b, imm:$c, imm:$mode))]>;
+ def PRMT_B32iir
+ : BasicFlagsNVPTXInst<(outs B32:$d),
+ (ins i32imm:$a, i32imm:$b, B32:$c), (ins PrmtMode:$mode),
"prmt.b32$mode",
- [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
+ [(set i32:$d, (prmt imm:$a, imm:$b, i32:$c, imm:$mode))]>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index bad4c3c..70150bd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1047,24 +1047,6 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
// MISC
//
-class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
- : Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
- (PRMT_B32rrr $a, $b, $c, prmt_mode)>;
-
-class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
- : Pat<(prmt_intrinsic i32:$a, i32:$c),
- (PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;
-
-def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
-def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
-def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;
-
-def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
-def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
-def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
-def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;
-
-
def INT_NVVM_NANOSLEEP_I : BasicNVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32",
[(int_nvvm_nanosleep imm:$i)]>,
Requires<[hasPTX<63>, hasSM<70>]>;