diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 203 |
1 files changed, 196 insertions, 7 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 437022f..9a6afa1 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -516,6 +516,44 @@ void RISCVDAGToDAGISel::selectVSETVLI(SDNode *Node) { CurDAG->getMachineNode(Opcode, DL, XLenVT, VLOperand, VTypeIOp)); } +void RISCVDAGToDAGISel::selectXSfmmVSET(SDNode *Node) { + if (!Subtarget->hasVendorXSfmmbase()) + return; + + assert(Node->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Unexpected opcode"); + + SDLoc DL(Node); + MVT XLenVT = Subtarget->getXLenVT(); + + unsigned IntNo = Node->getConstantOperandVal(0); + + assert((IntNo == Intrinsic::riscv_sf_vsettnt || + IntNo == Intrinsic::riscv_sf_vsettm || + IntNo == Intrinsic::riscv_sf_vsettk) && + "Unexpected XSfmm vset intrinsic"); + + unsigned SEW = RISCVVType::decodeVSEW(Node->getConstantOperandVal(2)); + unsigned Widen = RISCVVType::decodeTWiden(Node->getConstantOperandVal(3)); + unsigned PseudoOpCode = + IntNo == Intrinsic::riscv_sf_vsettnt ? RISCV::PseudoSF_VSETTNT + : IntNo == Intrinsic::riscv_sf_vsettm ? RISCV::PseudoSF_VSETTM + : RISCV::PseudoSF_VSETTK; + + if (IntNo == Intrinsic::riscv_sf_vsettnt) { + unsigned VTypeI = RISCVVType::encodeXSfmmVType(SEW, Widen, 0); + SDValue VTypeIOp = CurDAG->getTargetConstant(VTypeI, DL, XLenVT); + + ReplaceNode(Node, CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT, + Node->getOperand(1), VTypeIOp)); + } else { + SDValue Log2SEW = CurDAG->getTargetConstant(Log2_32(SEW), DL, XLenVT); + SDValue TWiden = CurDAG->getTargetConstant(Widen, DL, XLenVT); + ReplaceNode(Node, + CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT, + Node->getOperand(1), Log2SEW, TWiden)); + } +} + bool RISCVDAGToDAGISel::tryShrinkShlLogicImm(SDNode *Node) { MVT VT = Node->getSimpleValueType(0); unsigned Opcode = Node->getOpcode(); @@ -847,6 +885,11 @@ bool RISCVDAGToDAGISel::tryIndexedLoad(SDNode *Node) { return true; } +static Register getTileReg(uint64_t TileNum) { + assert(TileNum <= 15 && "Invalid tile number"); + return RISCV::T0 + TileNum; +} + void RISCVDAGToDAGISel::selectSF_VC_X_SE(SDNode *Node) { if (!Subtarget->hasVInstructions()) return; @@ -2035,6 +2078,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { case Intrinsic::riscv_vsetvli: case Intrinsic::riscv_vsetvlimax: return selectVSETVLI(Node); + case Intrinsic::riscv_sf_vsettnt: + case Intrinsic::riscv_sf_vsettm: + case Intrinsic::riscv_sf_vsettk: + return selectXSfmmVSET(Node); } break; } @@ -2458,6 +2505,142 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { case Intrinsic::riscv_sf_vc_i_se: selectSF_VC_X_SE(Node); return; + case Intrinsic::riscv_sf_vlte8: + case Intrinsic::riscv_sf_vlte16: + case Intrinsic::riscv_sf_vlte32: + case Intrinsic::riscv_sf_vlte64: { + unsigned Log2SEW; + unsigned PseudoInst; + switch (IntNo) { + case Intrinsic::riscv_sf_vlte8: + PseudoInst = RISCV::PseudoSF_VLTE8; + Log2SEW = 3; + break; + case Intrinsic::riscv_sf_vlte16: + PseudoInst = RISCV::PseudoSF_VLTE16; + Log2SEW = 4; + break; + case Intrinsic::riscv_sf_vlte32: + PseudoInst = RISCV::PseudoSF_VLTE32; + Log2SEW = 5; + break; + case Intrinsic::riscv_sf_vlte64: + PseudoInst = RISCV::PseudoSF_VLTE64; + Log2SEW = 6; + break; + } + + SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT); + SDValue TWidenOp = CurDAG->getTargetConstant(1, DL, XLenVT); + SDValue Operands[] = {Node->getOperand(2), + Node->getOperand(3), + Node->getOperand(4), + SEWOp, + TWidenOp, + Node->getOperand(0)}; + + MachineSDNode *TileLoad = + CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands); + if (auto *MemOp = dyn_cast<MemSDNode>(Node)) + CurDAG->setNodeMemRefs(TileLoad, {MemOp->getMemOperand()}); + + ReplaceNode(Node, TileLoad); + return; + } + case Intrinsic::riscv_sf_mm_s_s: + case Intrinsic::riscv_sf_mm_s_u: + case Intrinsic::riscv_sf_mm_u_s: + case Intrinsic::riscv_sf_mm_u_u: + case Intrinsic::riscv_sf_mm_e5m2_e5m2: + case Intrinsic::riscv_sf_mm_e5m2_e4m3: + case Intrinsic::riscv_sf_mm_e4m3_e5m2: + case Intrinsic::riscv_sf_mm_e4m3_e4m3: + case Intrinsic::riscv_sf_mm_f_f: { + bool HasFRM = false; + unsigned PseudoInst; + switch (IntNo) { + case Intrinsic::riscv_sf_mm_s_s: + PseudoInst = RISCV::PseudoSF_MM_S_S; + break; + case Intrinsic::riscv_sf_mm_s_u: + PseudoInst = RISCV::PseudoSF_MM_S_U; + break; + case Intrinsic::riscv_sf_mm_u_s: + PseudoInst = RISCV::PseudoSF_MM_U_S; + break; + case Intrinsic::riscv_sf_mm_u_u: + PseudoInst = RISCV::PseudoSF_MM_U_U; + break; + case Intrinsic::riscv_sf_mm_e5m2_e5m2: + PseudoInst = RISCV::PseudoSF_MM_E5M2_E5M2; + HasFRM = true; + break; + case Intrinsic::riscv_sf_mm_e5m2_e4m3: + PseudoInst = RISCV::PseudoSF_MM_E5M2_E4M3; + HasFRM = true; + break; + case Intrinsic::riscv_sf_mm_e4m3_e5m2: + PseudoInst = RISCV::PseudoSF_MM_E4M3_E5M2; + HasFRM = true; + break; + case Intrinsic::riscv_sf_mm_e4m3_e4m3: + PseudoInst = RISCV::PseudoSF_MM_E4M3_E4M3; + HasFRM = true; + break; + case Intrinsic::riscv_sf_mm_f_f: + if (Node->getOperand(3).getValueType().getScalarType() == MVT::bf16) + PseudoInst = RISCV::PseudoSF_MM_F_F_ALT; + else + PseudoInst = RISCV::PseudoSF_MM_F_F; + HasFRM = true; + break; + } + uint64_t TileNum = Node->getConstantOperandVal(2); + SDValue Op1 = Node->getOperand(3); + SDValue Op2 = Node->getOperand(4); + MVT VT = Op1->getSimpleValueType(0); + unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits()); + SDValue TmOp = Node->getOperand(5); + SDValue TnOp = Node->getOperand(6); + SDValue TkOp = Node->getOperand(7); + SDValue TWidenOp = Node->getOperand(8); + SDValue Chain = Node->getOperand(0); + + // sf.mm.f.f with sew=32, twiden=2 is invalid + if (IntNo == Intrinsic::riscv_sf_mm_f_f && Log2SEW == 5 && + TWidenOp->getAsZExtVal() == 2) + reportFatalUsageError("sf.mm.f.f doesn't support (sew=32, twiden=2)"); + + SmallVector<SDValue, 10> Operands( + {CurDAG->getRegister(getTileReg(TileNum), XLenVT), Op1, Op2}); + if (HasFRM) + Operands.push_back( + CurDAG->getTargetConstant(RISCVFPRndMode::DYN, DL, XLenVT)); + Operands.append({TmOp, TnOp, TkOp, + CurDAG->getTargetConstant(Log2SEW, DL, XLenVT), TWidenOp, + Chain}); + + auto *NewNode = + CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands); + + ReplaceNode(Node, NewNode); + return; + } + case Intrinsic::riscv_sf_vtzero_t: { + uint64_t TileNum = Node->getConstantOperandVal(2); + SDValue Tm = Node->getOperand(3); + SDValue Tn = Node->getOperand(4); + SDValue Log2SEW = Node->getOperand(5); + SDValue TWiden = Node->getOperand(6); + SDValue Chain = Node->getOperand(0); + auto *NewNode = CurDAG->getMachineNode( + RISCV::PseudoSF_VTZERO_T, DL, Node->getVTList(), + {CurDAG->getRegister(getTileReg(TileNum), XLenVT), Tm, Tn, Log2SEW, + TWiden, Chain}); + + ReplaceNode(Node, NewNode); + return; + } } break; } @@ -3353,14 +3536,20 @@ bool RISCVDAGToDAGISel::selectSETCC(SDValue N, ISD::CondCode ExpectedCCVal, 0); return true; } - // If the RHS is [-2047,2048], we can use addi with -RHS to produce 0 if the - // LHS is equal to the RHS and non-zero otherwise. + // If the RHS is [-2047,2048], we can use addi/addiw with -RHS to produce 0 + // if the LHS is equal to the RHS and non-zero otherwise. if (isInt<12>(CVal) || CVal == 2048) { - Val = SDValue( - CurDAG->getMachineNode( - RISCV::ADDI, DL, N->getValueType(0), LHS, - CurDAG->getSignedTargetConstant(-CVal, DL, N->getValueType(0))), - 0); + unsigned Opc = RISCV::ADDI; + if (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG && + cast<VTSDNode>(LHS.getOperand(1))->getVT() == MVT::i32) { + Opc = RISCV::ADDIW; + LHS = LHS.getOperand(0); + } + + Val = SDValue(CurDAG->getMachineNode(Opc, DL, N->getValueType(0), LHS, + CurDAG->getSignedTargetConstant( + -CVal, DL, N->getValueType(0))), + 0); return true; } if (isPowerOf2_64(CVal) && Subtarget->hasStdExtZbs()) { |