diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 153 |
1 files changed, 140 insertions, 13 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 7aa06f9..f2c2f46 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -731,6 +731,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTruncStoreAction(MVT::f32, MVT::bf16, Expand); setTruncStoreAction(MVT::f64, MVT::bf16, Expand); setTruncStoreAction(MVT::f64, MVT::f32, Expand); + setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand); + setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand); // PTX does not support load / store predicate registers setOperationAction(ISD::LOAD, MVT::i1, Custom); @@ -2066,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL, SelectionDAG &DAG, unsigned Mode = NVPTX::PTXPrmtMode::NONE) { + assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 && + Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands"); return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32, {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)}); } @@ -4004,7 +4008,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -4027,6 +4034,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(4); + return true; + } + + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v4i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: @@ -5060,12 +5091,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { return !U.getUser()->use_empty(); } - // Handle CopyToReg nodes that will become dead after our replacement - if (U.getUser()->getOpcode() == ISD::CopyToReg) { - DeadCopyToRegs.push_back(U.getUser()); - return true; - } - // Otherwise, this use prevents us from splitting a value. return false; })) @@ -5132,10 +5157,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs)) Results.push_back(NewLoad.getValue(NewNumOutputs + I)); - // Remove dead CopyToReg nodes by folding them into the chain they reference - for (SDNode *CTR : DeadCopyToRegs) - DCI.CombineTo(CTR, CTR->getOperand(0)); - return DCI.DAG.getMergeValues(Results, DL); } @@ -5853,6 +5874,8 @@ static SDValue combineADDRSPACECAST(SDNode *N, // details: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) { + assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands"); + if (Mode == NVPTX::PTXPrmtMode::NONE) return Selector; @@ -5884,6 +5907,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) { } static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) { + assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 && + Selector.getBitWidth() == 32 && "PRMT must have i32 operands"); // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}} APInt BitField = B.concat(A); APInt SelectorVal = getPRMTSelector(Selector, Mode); @@ -6518,10 +6543,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known, KnownBits BKnown = DAG.computeKnownBits(B, Depth); // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}} + assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 && + "PRMT must have i32 operands"); + assert(Known.getBitWidth() == 32 && "PRMT must have i32 result"); KnownBits BitField = BKnown.concat(AKnown); APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode); - for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) { + for (unsigned I : llvm::seq(4)) { APInt Sel = SelectorVal.extractBits(4, I * 4); unsigned Idx = Sel.getLoBits(3).getZExtValue(); unsigned Sign = Sel.getHiBits(1).getZExtValue(); @@ -6544,4 +6572,103 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode( default: break; } -}
\ No newline at end of file +} + +static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal, + const APInt &DemandedBits) { + APInt DemandedLHS = APInt(32, 0); + APInt DemandedRHS = APInt(32, 0); + + for (unsigned I : llvm::seq(4)) { + if (DemandedBits.extractBits(8, I * 8).isZero()) + continue; + + APInt Sel = SelectorVal.extractBits(4, I * 4); + unsigned Idx = Sel.getLoBits(3).getZExtValue(); + unsigned Sign = Sel.getHiBits(1).getZExtValue(); + + APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS; + unsigned ByteStart = (Idx % 4) * 8; + if (Sign) + Src.setBit(ByteStart + 7); + else + Src.setBits(ByteStart, ByteStart + 8); + } + + return {DemandedLHS, DemandedRHS}; +} + +// Replace undef with 0 as this is easier for other optimizations such as +// known bits. +static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) { + if (!Op) + return SDValue(); + if (Op.isUndef()) + return DAG.getConstant(0, SDLoc(), MVT::i32); + return Op; +} + +static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT, + const APInt &DemandedBits, + SelectionDAG &DAG, + const TargetLowering &TLI, + unsigned Depth) { + assert(PRMT.getOpcode() == NVPTXISD::PRMT); + SDValue Op0 = PRMT.getOperand(0); + SDValue Op1 = PRMT.getOperand(1); + auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2)); + if (!SelectorConst) + return SDValue(); + + unsigned Mode = PRMT.getConstantOperandVal(3); + const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode); + + // Try to simplify the PRMT to one of the inputs if the used bytes are all + // from the same input in the correct order. + const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8; + const unsigned SelBits = (4 - LeadingBytes) * 4; + if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits)) + return Op0; + if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits)) + return Op1; + + auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits); + + // Attempt to avoid multi-use ops if we don't need anything from them. + SDValue DemandedOp0 = + TLI.SimplifyMultipleUseDemandedBits(Op0, DemandedLHS, DAG, Depth + 1); + SDValue DemandedOp1 = + TLI.SimplifyMultipleUseDemandedBits(Op1, DemandedRHS, DAG, Depth + 1); + + DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG); + DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG); + if ((DemandedOp0 && DemandedOp0 != Op0) || + (DemandedOp1 && DemandedOp1 != Op1)) { + Op0 = DemandedOp0 ? DemandedOp0 : Op0; + Op1 = DemandedOp1 ? DemandedOp1 : Op1; + return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG); + } + + return SDValue(); +} + +bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode( + SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, + KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const { + Known.resetAll(); + + switch (Op.getOpcode()) { + case NVPTXISD::PRMT: + if (SDValue Result = simplifyDemandedBitsForPRMT(Op, DemandedBits, TLO.DAG, + *this, Depth)) { + TLO.CombineTo(Op, Result); + return true; + } + break; + default: + break; + } + + computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth); + return false; +} |