aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp153
1 files changed, 140 insertions, 13 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9..f2c2f46 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -731,6 +731,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
+ setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
+ setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);
// PTX does not support load / store predicate registers
setOperationAction(ISD::LOAD, MVT::i1, Custom);
@@ -2066,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
SelectionDAG &DAG,
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
+ assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
+ Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
{A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
}
@@ -4004,7 +4008,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
- case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
@@ -4027,6 +4034,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(4);
+ return true;
+ }
+
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::v4i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(16);
+ return true;
+ }
+
case Intrinsic::nvvm_atomic_add_gen_f_cta:
case Intrinsic::nvvm_atomic_add_gen_f_sys:
case Intrinsic::nvvm_atomic_add_gen_i_cta:
@@ -5060,12 +5091,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return !U.getUser()->use_empty();
}
- // Handle CopyToReg nodes that will become dead after our replacement
- if (U.getUser()->getOpcode() == ISD::CopyToReg) {
- DeadCopyToRegs.push_back(U.getUser());
- return true;
- }
-
// Otherwise, this use prevents us from splitting a value.
return false;
}))
@@ -5132,10 +5157,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
Results.push_back(NewLoad.getValue(NewNumOutputs + I));
- // Remove dead CopyToReg nodes by folding them into the chain they reference
- for (SDNode *CTR : DeadCopyToRegs)
- DCI.CombineTo(CTR, CTR->getOperand(0));
-
return DCI.DAG.getMergeValues(Results, DL);
}
@@ -5853,6 +5874,8 @@ static SDValue combineADDRSPACECAST(SDNode *N,
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
+ assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
+
if (Mode == NVPTX::PTXPrmtMode::NONE)
return Selector;
@@ -5884,6 +5907,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
}
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
+ assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
+ Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
APInt BitField = B.concat(A);
APInt SelectorVal = getPRMTSelector(Selector, Mode);
@@ -6518,10 +6543,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
KnownBits BKnown = DAG.computeKnownBits(B, Depth);
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
+ assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
+ "PRMT must have i32 operands");
+ assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
KnownBits BitField = BKnown.concat(AKnown);
APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
- for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
+ for (unsigned I : llvm::seq(4)) {
APInt Sel = SelectorVal.extractBits(4, I * 4);
unsigned Idx = Sel.getLoBits(3).getZExtValue();
unsigned Sign = Sel.getHiBits(1).getZExtValue();
@@ -6544,4 +6572,103 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
default:
break;
}
-} \ No newline at end of file
+}
+
+static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
+ const APInt &DemandedBits) {
+ APInt DemandedLHS = APInt(32, 0);
+ APInt DemandedRHS = APInt(32, 0);
+
+ for (unsigned I : llvm::seq(4)) {
+ if (DemandedBits.extractBits(8, I * 8).isZero())
+ continue;
+
+ APInt Sel = SelectorVal.extractBits(4, I * 4);
+ unsigned Idx = Sel.getLoBits(3).getZExtValue();
+ unsigned Sign = Sel.getHiBits(1).getZExtValue();
+
+ APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
+ unsigned ByteStart = (Idx % 4) * 8;
+ if (Sign)
+ Src.setBit(ByteStart + 7);
+ else
+ Src.setBits(ByteStart, ByteStart + 8);
+ }
+
+ return {DemandedLHS, DemandedRHS};
+}
+
+// Replace undef with 0 as this is easier for other optimizations such as
+// known bits.
+static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
+ if (!Op)
+ return SDValue();
+ if (Op.isUndef())
+ return DAG.getConstant(0, SDLoc(), MVT::i32);
+ return Op;
+}
+
+static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
+ const APInt &DemandedBits,
+ SelectionDAG &DAG,
+ const TargetLowering &TLI,
+ unsigned Depth) {
+ assert(PRMT.getOpcode() == NVPTXISD::PRMT);
+ SDValue Op0 = PRMT.getOperand(0);
+ SDValue Op1 = PRMT.getOperand(1);
+ auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
+ if (!SelectorConst)
+ return SDValue();
+
+ unsigned Mode = PRMT.getConstantOperandVal(3);
+ const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode);
+
+ // Try to simplify the PRMT to one of the inputs if the used bytes are all
+ // from the same input in the correct order.
+ const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
+ const unsigned SelBits = (4 - LeadingBytes) * 4;
+ if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
+ return Op0;
+ if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
+ return Op1;
+
+ auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits);
+
+ // Attempt to avoid multi-use ops if we don't need anything from them.
+ SDValue DemandedOp0 =
+ TLI.SimplifyMultipleUseDemandedBits(Op0, DemandedLHS, DAG, Depth + 1);
+ SDValue DemandedOp1 =
+ TLI.SimplifyMultipleUseDemandedBits(Op1, DemandedRHS, DAG, Depth + 1);
+
+ DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
+ DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
+ if ((DemandedOp0 && DemandedOp0 != Op0) ||
+ (DemandedOp1 && DemandedOp1 != Op1)) {
+ Op0 = DemandedOp0 ? DemandedOp0 : Op0;
+ Op1 = DemandedOp1 ? DemandedOp1 : Op1;
+ return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG);
+ }
+
+ return SDValue();
+}
+
+bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
+ SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+ KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
+ Known.resetAll();
+
+ switch (Op.getOpcode()) {
+ case NVPTXISD::PRMT:
+ if (SDValue Result = simplifyDemandedBitsForPRMT(Op, DemandedBits, TLO.DAG,
+ *this, Depth)) {
+ TLO.CombineTo(Op, Result);
+ return true;
+ }
+ break;
+ default:
+ break;
+ }
+
+ computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
+ return false;
+}