aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp25
-rw-r--r--llvm/lib/Target/RISCV/RISCVInstrInfoP.td71
2 files changed, 55 insertions, 41 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 08f3ac4..5c4b1f3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8859,14 +8859,27 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::SRA:
if (Op.getSimpleValueType().isFixedLengthVector()) {
if (Subtarget.hasStdExtP()) {
- // We have patterns for scalar/immediate shift amount, so no lowering
- // needed.
- if (Op.getOperand(1)->getOpcode() == ISD::SPLAT_VECTOR)
- return Op;
-
// There's no vector-vector version of shift instruction in P extension
// so we need to unroll to scalar computation and pack them back.
- return DAG.UnrollVectorOp(Op.getNode());
+ if (Op.getOperand(1)->getOpcode() != ISD::SPLAT_VECTOR)
+ return DAG.UnrollVectorOp(Op.getNode());
+
+ unsigned Opc;
+ switch (Op.getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode");
+ case ISD::SHL:
+ Opc = RISCVISD::PSHL;
+ break;
+ case ISD::SRL:
+ Opc = RISCVISD::PSRL;
+ break;
+ case ISD::SRA:
+ Opc = RISCVISD::PSRA;
+ break;
+ }
+ return DAG.getNode(Opc, SDLoc(Op), Op.getValueType(), Op.getOperand(0),
+ Op.getOperand(1).getOperand(0));
}
return lowerToScalableOp(Op, DAG);
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 3066732..2e8e4c9 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1658,6 +1658,13 @@ def riscv_mulhr : RVSDNode<"MULHR", SDTIntBinOp>;
def riscv_mulhru : RVSDNode<"MULHRU", SDTIntBinOp>;
def riscv_mulhrsu : RVSDNode<"MULHRSU", SDTIntBinOp>;
+def STD_RISCVPackedShift : SDTypeProfile<1, 2, [SDTCisVec<0>, // vector result
+ SDTCisSameAs<0, 1>, // operand 1 has the result's type
+ SDTCisVT<2, XLenVT>]>; // operand 2 is an XLenVT shift amount
+def riscv_pshl : RVSDNode<"PSHL", STD_RISCVPackedShift>;
+def riscv_psrl : RVSDNode<"PSRL", STD_RISCVPackedShift>;
+def riscv_psra : RVSDNode<"PSRA", STD_RISCVPackedShift>;
+
// Bitwise merge: res = (~op0 & op1) | (op0 & op2)
def SDT_RISCVMERGE : SDTypeProfile<1, 3, [SDTCisInt<0>,
SDTCisSameAs<0, 1>,
@@ -1766,23 +1773,23 @@ let Predicates = [HasStdExtP] in {
def: Pat<(XLenVecI16VT (riscv_mulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
// 8-bit logical shift left/right patterns
- def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
+ def: Pat<(XLenVecI8VT (riscv_pshl GPR:$rs1, uimm3:$shamt)),
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
- def: Pat<(XLenVecI8VT (srl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
+ def: Pat<(XLenVecI8VT (riscv_psrl GPR:$rs1, uimm3:$shamt)),
(PSRLI_B GPR:$rs1, uimm3:$shamt)>;
// 16-bit logical shift left/right patterns
- def: Pat<(XLenVecI16VT (shl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
+ def: Pat<(XLenVecI16VT (riscv_pshl GPR:$rs1, uimm4:$shamt)),
(PSLLI_H GPR:$rs1, uimm4:$shamt)>;
- def: Pat<(XLenVecI16VT (srl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
+ def: Pat<(XLenVecI16VT (riscv_psrl GPR:$rs1, uimm4:$shamt)),
(PSRLI_H GPR:$rs1, uimm4:$shamt)>;
// 8-bit arithmetic shift right patterns
- def: Pat<(XLenVecI8VT (sra GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
+ def: Pat<(XLenVecI8VT (riscv_psra GPR:$rs1, uimm3:$shamt)),
(PSRAI_B GPR:$rs1, uimm3:$shamt)>;
// 16-bit arithmetic shift right patterns
- def: Pat<(XLenVecI16VT (sra GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
+ def: Pat<(XLenVecI16VT (riscv_psra GPR:$rs1, uimm4:$shamt)),
(PSRAI_H GPR:$rs1, uimm4:$shamt)>;
// 16-bit signed saturation shift left patterns
@@ -1790,29 +1797,23 @@ let Predicates = [HasStdExtP] in {
(PSSLAI_H GPR:$rs1, uimm4:$shamt)>;
// 8-bit logical shift left/right
- def: Pat<(XLenVecI8VT (shl GPR:$rs1,
- (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(XLenVecI8VT (riscv_pshl GPR:$rs1, GPR:$rs2)),
(PSLL_BS GPR:$rs1, GPR:$rs2)>;
- def: Pat<(XLenVecI8VT (srl GPR:$rs1,
- (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(XLenVecI8VT (riscv_psrl GPR:$rs1, GPR:$rs2)),
(PSRL_BS GPR:$rs1, GPR:$rs2)>;
// 8-bit arithmetic shift left/right
- def: Pat<(XLenVecI8VT (sra GPR:$rs1,
- (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(XLenVecI8VT (riscv_psra GPR:$rs1, GPR:$rs2)),
(PSRA_BS GPR:$rs1, GPR:$rs2)>;
// 16-bit logical shift left/right
- def: Pat<(XLenVecI16VT (shl GPR:$rs1,
- (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(XLenVecI16VT (riscv_pshl GPR:$rs1, GPR:$rs2)),
(PSLL_HS GPR:$rs1, GPR:$rs2)>;
- def: Pat<(XLenVecI16VT (srl GPR:$rs1,
- (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(XLenVecI16VT (riscv_psrl GPR:$rs1, GPR:$rs2)),
(PSRL_HS GPR:$rs1, GPR:$rs2)>;
// 16-bit arithmetic shift left/right
- def: Pat<(XLenVecI16VT (sra GPR:$rs1,
- (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(XLenVecI16VT (riscv_psra GPR:$rs1, GPR:$rs2)),
(PSRA_HS GPR:$rs1, GPR:$rs2)>;
// 8-bit PLI SD node pattern
@@ -1973,14 +1974,28 @@ let Predicates = [HasStdExtP, IsRV64] in {
def: Pat<(v2i32 (mul GPR:$rs1, GPR:$rs2)),
(PACK (MUL_W00 GPR:$rs1, GPR:$rs2), (MUL_W11 GPR:$rs1, GPR:$rs2))>;
+ // 32-bit logical shift left/right patterns
+ def: Pat<(v2i32 (riscv_pshl GPR:$rs1, uimm5:$shamt)),
+ (PSLLI_W GPR:$rs1, uimm5:$shamt)>;
+ def: Pat<(v2i32 (riscv_psrl GPR:$rs1, uimm5:$shamt)),
+ (PSRLI_W GPR:$rs1, uimm5:$shamt)>;
+
+ // 32-bit arithmetic shift right patterns
+ def: Pat<(v2i32 (riscv_psra GPR:$rs1, uimm5:$shamt)),
+ (PSRAI_W GPR:$rs1, uimm5:$shamt)>;
+
+ // 32-bit signed saturation shift left patterns
+ def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
+ (PSSLAI_W GPR:$rs1, uimm5:$shamt)>;
+
// 32-bit logical shift left/right
- def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(v2i32 (riscv_pshl GPR:$rs1, GPR:$rs2)),
(PSLL_WS GPR:$rs1, GPR:$rs2)>;
- def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(v2i32 (riscv_psrl GPR:$rs1, GPR:$rs2)),
(PSRL_WS GPR:$rs1, GPR:$rs2)>;
// 32-bit arithmetic shift left/right
- def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+ def: Pat<(v2i32 (riscv_psra GPR:$rs1, GPR:$rs2)),
(PSRA_WS GPR:$rs1, GPR:$rs2)>;
// splat pattern
@@ -2007,20 +2022,6 @@ let Predicates = [HasStdExtP, IsRV64] in {
def: Pat<(v2i32 (smax GPR:$rs1, GPR:$rs2)), (PMAX_W GPR:$rs1, GPR:$rs2)>;
def: Pat<(v2i32 (umax GPR:$rs1, GPR:$rs2)), (PMAXU_W GPR:$rs1, GPR:$rs2)>;
- // 32-bit logical shift left/right patterns
- def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
- (PSLLI_W GPR:$rs1, uimm5:$shamt)>;
- def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
- (PSRLI_W GPR:$rs1, uimm5:$shamt)>;
-
- // 32-bit arithmetic shift left/right patterns
- def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
- (PSRAI_W GPR:$rs1, uimm5:$shamt)>;
-
- // 32-bit signed saturation shift left patterns
- def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
- (PSSLAI_W GPR:$rs1, uimm5:$shamt)>;
-
// 32-bit vselect patterns
def: Pat<(v2i32 (vselect (v2i32 GPR:$mask), GPR:$true_v, GPR:$false_v)),
(MERGE GPR:$mask, GPR:$false_v, GPR:$true_v)>;