diff options
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 25 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoP.td | 71 |
2 files changed, 55 insertions, 41 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 08f3ac4..5c4b1f3 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -8859,14 +8859,27 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, case ISD::SRA: if (Op.getSimpleValueType().isFixedLengthVector()) { if (Subtarget.hasStdExtP()) { - // We have patterns for scalar/immediate shift amount, so no lowering - // needed. - if (Op.getOperand(1)->getOpcode() == ISD::SPLAT_VECTOR) - return Op; - // There's no vector-vector version of shift instruction in P extension // so we need to unroll to scalar computation and pack them back. - return DAG.UnrollVectorOp(Op.getNode()); + if (Op.getOperand(1)->getOpcode() != ISD::SPLAT_VECTOR) + return DAG.UnrollVectorOp(Op.getNode()); + + unsigned Opc; + switch (Op.getOpcode()) { + default: + llvm_unreachable("Unexpected opcode"); + case ISD::SHL: + Opc = RISCVISD::PSHL; + break; + case ISD::SRL: + Opc = RISCVISD::PSRL; + break; + case ISD::SRA: + Opc = RISCVISD::PSRA; + break; + } + return DAG.getNode(Opc, SDLoc(Op), Op.getValueType(), Op.getOperand(0), + Op.getOperand(1).getOperand(0)); } return lowerToScalableOp(Op, DAG); } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index 3066732..2e8e4c9 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -1658,6 +1658,13 @@ def riscv_mulhr : RVSDNode<"MULHR", SDTIntBinOp>; def riscv_mulhru : RVSDNode<"MULHRU", SDTIntBinOp>; def riscv_mulhrsu : RVSDNode<"MULHRSU", SDTIntBinOp>; +def STD_RISCVPackedShift : SDTypeProfile<1, 2, [SDTCisVec<0>, + SDTCisSameAs<0, 1>, + SDTCisVT<2, XLenVT>]>; +def riscv_pshl : RVSDNode<"PSHL", STD_RISCVPackedShift>; +def riscv_psrl : RVSDNode<"PSRL", STD_RISCVPackedShift>; +def riscv_psra : RVSDNode<"PSRA", STD_RISCVPackedShift>; + // Bitwise merge: res = (~op0 & op1) | (op0 & op2) def SDT_RISCVMERGE : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0, 1>, @@ -1766,23 +1773,23 @@ let Predicates = [HasStdExtP] in { def: Pat<(XLenVecI16VT (riscv_mulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>; // 8-bit logical shift left/right patterns - def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))), + def: Pat<(XLenVecI8VT (riscv_pshl GPR:$rs1, uimm3:$shamt)), (PSLLI_B GPR:$rs1, uimm3:$shamt)>; - def: Pat<(XLenVecI8VT (srl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))), + def: Pat<(XLenVecI8VT (riscv_psrl GPR:$rs1, uimm3:$shamt)), (PSRLI_B GPR:$rs1, uimm3:$shamt)>; // 16-bit logical shift left/right patterns - def: Pat<(XLenVecI16VT (shl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))), + def: Pat<(XLenVecI16VT (riscv_pshl GPR:$rs1, uimm4:$shamt)), (PSLLI_H GPR:$rs1, uimm4:$shamt)>; - def: Pat<(XLenVecI16VT (srl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))), + def: Pat<(XLenVecI16VT (riscv_psrl GPR:$rs1, uimm4:$shamt)), (PSRLI_H GPR:$rs1, uimm4:$shamt)>; // 8-bit arithmetic shift right patterns - def: Pat<(XLenVecI8VT (sra GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))), + def: Pat<(XLenVecI8VT (riscv_psra GPR:$rs1, uimm3:$shamt)), (PSRAI_B GPR:$rs1, uimm3:$shamt)>; // 16-bit arithmetic shift right patterns - def: Pat<(XLenVecI16VT (sra GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))), + def: Pat<(XLenVecI16VT (riscv_psra GPR:$rs1, uimm4:$shamt)), (PSRAI_H GPR:$rs1, uimm4:$shamt)>; // 16-bit signed saturation shift left patterns @@ -1790,29 +1797,23 @@ let Predicates = [HasStdExtP] in { (PSSLAI_H GPR:$rs1, uimm4:$shamt)>; // 8-bit logical shift left/right - def: Pat<(XLenVecI8VT (shl GPR:$rs1, - (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(XLenVecI8VT (riscv_pshl GPR:$rs1, GPR:$rs2)), (PSLL_BS GPR:$rs1, GPR:$rs2)>; - def: Pat<(XLenVecI8VT (srl GPR:$rs1, - (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(XLenVecI8VT (riscv_psrl GPR:$rs1, GPR:$rs2)), (PSRL_BS GPR:$rs1, GPR:$rs2)>; // 8-bit arithmetic shift left/right - def: Pat<(XLenVecI8VT (sra GPR:$rs1, - (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(XLenVecI8VT (riscv_psra GPR:$rs1, GPR:$rs2)), (PSRA_BS GPR:$rs1, GPR:$rs2)>; // 16-bit logical shift left/right - def: Pat<(XLenVecI16VT (shl GPR:$rs1, - (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(XLenVecI16VT (riscv_pshl GPR:$rs1, GPR:$rs2)), (PSLL_HS GPR:$rs1, GPR:$rs2)>; - def: Pat<(XLenVecI16VT (srl GPR:$rs1, - (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(XLenVecI16VT (riscv_psrl GPR:$rs1, GPR:$rs2)), (PSRL_HS GPR:$rs1, GPR:$rs2)>; // 16-bit arithmetic shift left/right - def: Pat<(XLenVecI16VT (sra GPR:$rs1, - (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(XLenVecI16VT (riscv_psra GPR:$rs1, GPR:$rs2)), (PSRA_HS GPR:$rs1, GPR:$rs2)>; // 8-bit PLI SD node pattern @@ -1973,14 +1974,28 @@ let Predicates = [HasStdExtP, IsRV64] in { def: Pat<(v2i32 (mul GPR:$rs1, GPR:$rs2)), (PACK (MUL_W00 GPR:$rs1, GPR:$rs2), (MUL_W11 GPR:$rs1, GPR:$rs2))>; + // 32-bit logical shift left/right patterns + def: Pat<(v2i32 (riscv_pshl GPR:$rs1, uimm5:$shamt)), + (PSLLI_W GPR:$rs1, uimm5:$shamt)>; + def: Pat<(v2i32 (riscv_psrl GPR:$rs1, uimm5:$shamt)), + (PSRLI_W GPR:$rs1, uimm5:$shamt)>; + + // 32-bit arithmetic shift left/right patterns + def: Pat<(v2i32 (riscv_psra GPR:$rs1, uimm5:$shamt)), + (PSRAI_W GPR:$rs1, uimm5:$shamt)>; + + // 32-bit signed saturation shift left patterns + def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))), + (PSSLAI_W GPR:$rs1, uimm5:$shamt)>; + // 32-bit logical shift left/right - def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(v2i32 (riscv_pshl GPR:$rs1, GPR:$rs2)), (PSLL_WS GPR:$rs1, GPR:$rs2)>; - def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(v2i32 (riscv_psrl GPR:$rs1, GPR:$rs2)), (PSRL_WS GPR:$rs1, GPR:$rs2)>; // 32-bit arithmetic shift left/right - def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))), + def: Pat<(v2i32 (riscv_psra GPR:$rs1, GPR:$rs2)), (PSRA_WS GPR:$rs1, GPR:$rs2)>; // splat pattern @@ -2007,20 +2022,6 @@ let Predicates = [HasStdExtP, IsRV64] in { def: Pat<(v2i32 (smax GPR:$rs1, GPR:$rs2)), (PMAX_W GPR:$rs1, GPR:$rs2)>; def: Pat<(v2i32 (umax GPR:$rs1, GPR:$rs2)), (PMAXU_W GPR:$rs1, GPR:$rs2)>; - // 32-bit logical shift left/right patterns - def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))), - (PSLLI_W GPR:$rs1, uimm5:$shamt)>; - def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))), - (PSRLI_W GPR:$rs1, uimm5:$shamt)>; - - // 32-bit arithmetic shift left/right patterns - def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))), - (PSRAI_W GPR:$rs1, uimm5:$shamt)>; - - // 32-bit signed saturation shift left patterns - def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))), - (PSSLAI_W GPR:$rs1, uimm5:$shamt)>; - // 32-bit vselect patterns def: Pat<(v2i32 (vselect (v2i32 GPR:$mask), GPR:$true_v, GPR:$false_v)), (MERGE GPR:$mask, GPR:$false_v, GPR:$true_v)>; |
