diff options
| author | Craig Topper <craig.topper@sifive.com> | 2026-02-11 22:37:47 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-11 22:37:47 -0800 |
| commit | db588931c5bca3f09965689b883a56abf322bbd4 (patch) | |
| tree | 44a7793b552bd6271d3f6fee8078186eecb5fa7a /llvm/lib | |
| parent | 4323b3604cf6a26afbefcdb79edbd4a9a2e0a000 (diff) | |
| download | llvm-db588931c5bca3f09965689b883a56abf322bbd4.tar.gz llvm-db588931c5bca3f09965689b883a56abf322bbd4.tar.bz2 llvm-db588931c5bca3f09965689b883a56abf322bbd4.zip | |
[RISCV] Use NSRL/NSRA for legalizing i64 shifts with P extension on RV32. (#181040)
If the shift amount might be in the range [0, 31], we can use
NSRL/NSRA to shift the i64 value to compute the lower 32 bits of
the result.
If the shift amount is >= 32, the high half of the result is all
zeros or sign bits. Otherwise it is a srl/sra of the high bits.
I've handled the constant case in ReplaceNodeResults but deferred
the non-constant case to lowerShiftRightParts. This function is
not called for constants. This gives the opportunity for DAGCombine to
optimize the SRL_PARTS/SRA_PARTS if the shift amount can be proven
to be >= 32 or < 32.
Sequences were also discussed on the P extension mailing list here
https://lists.riscv.org/g/tech-p-ext/message/861
Assisted-by: claude
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 34 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 65 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoP.td | 9 |
3 files changed, 103 insertions, 5 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index db65d6ac1a5d..b7dfd7e24338 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -1791,6 +1791,40 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { CurDAG->RemoveDeadNode(Node); return; } + case RISCVISD::NSRL: + case RISCVISD::NSRA: { + assert(Subtarget->hasStdExtP() && !Subtarget->is64Bit() && VT == MVT::i32 && + "Unexpected opcode"); + + bool IsSRA = Node->getOpcode() == RISCVISD::NSRA; + SDValue Lo = Node->getOperand(0); + SDValue Hi = Node->getOperand(1); + SDValue ShAmt = Node->getOperand(2); + + SDValue Ops[] = { + CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32), Lo, + CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32), Hi, + CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)}; + SDValue Pair = SDValue(CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, + DL, MVT::Untyped, Ops), + 0); + + MachineSDNode *Res; + if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) { + unsigned Opc = IsSRA ? RISCV::NSRAI : RISCV::NSRLI; + Res = CurDAG->getMachineNode( + Opc, DL, MVT::i32, Pair, + CurDAG->getTargetConstant(*ShAmtC->getConstantIntValue(), DL, + MVT::i32)); + } else { + // NSRL/NSRA only read 6 bits of the shift amount. + selectShiftMask(ShAmt, 64, ShAmt); + unsigned Opc = IsSRA ? RISCV::NSRA : RISCV::NSRL; + Res = CurDAG->getMachineNode(Opc, DL, MVT::i32, Pair, ShAmt); + } + ReplaceNode(Node, Res); + return; + } case ISD::LOAD: { if (tryIndexedLoad(Node)) return; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index f0fef6d8b66f..944354a0fd2b 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -358,10 +358,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, MVT::i32, Custom); setOperationAction({ISD::UADDO, ISD::USUBO}, MVT::i32, Custom); setOperationAction({ISD::SADDO, ISD::SSUBO}, MVT::i32, Custom); - } else { - // Custom legalize i64 ADD/SUB for RV32+P. - if (Subtarget.hasStdExtP()) - setOperationAction({ISD::ADD, ISD::SUB}, MVT::i64, Custom); + } else if (Subtarget.hasStdExtP()) { + // Custom legalize i64 ADD/SUB/SRL/SRA for RV32+P. + setOperationAction({ISD::ADD, ISD::SUB}, MVT::i64, Custom); + setOperationAction({ISD::SRL, ISD::SRA}, MVT::i64, Custom); } if (!Subtarget.hasStdExtZmmul()) { setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU}, XLenVT, Expand); @@ -10294,6 +10294,37 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, SDValue Shamt = Op.getOperand(2); EVT VT = Lo.getValueType(); + // With P extension on RV32, use NSRL/NSRA for the low part. + if (Subtarget.hasStdExtP() && !Subtarget.is64Bit()) { + SDValue LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, VT, + Lo, Hi, Shamt); + // Mask shift amount to avoid UB when Shamt >= 32. + SDValue ShamtMasked = + DAG.getNode(ISD::AND, DL, VT, Shamt, DAG.getConstant(31, DL, VT)); + SDValue HiRes = + DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked); + + // Create a mask that is -1 when Shamt >= 32, 0 otherwise. + // FIXME: We should use a select and let LowerSelect make the + // optimizations. + SDValue ShAmtExt = + DAG.getNode(ISD::SHL, DL, VT, Shamt, DAG.getConstant(26, DL, VT)); + SDValue Mask = + DAG.getNode(ISD::SRA, DL, VT, ShAmtExt, DAG.getConstant(31, DL, VT)); + + if (IsSRA) { + // sra hi, hi, (mask & 31) - shifts by 31 when shamt >= 32 + SDValue MaskAmt = + DAG.getNode(ISD::AND, DL, VT, Mask, DAG.getConstant(31, DL, VT)); + HiRes = DAG.getNode(ISD::SRA, DL, VT, HiRes, MaskAmt); + } else { + // andn hi, hi, mask - clears hi when shamt >= 32 + HiRes = DAG.getNode(ISD::AND, DL, VT, HiRes, DAG.getNOT(DL, Mask, VT)); + } + + return DAG.getMergeValues({LoRes, HiRes}, DL); + } + // SRA expansion: // if Shamt-XLEN < 0: // Shamt < XLEN // Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - ShAmt)) @@ -15275,7 +15306,31 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, break; } - assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && + if (VT == MVT::i64) { + assert(!Subtarget.is64Bit() && Subtarget.hasStdExtP() && + "Unexpected custom legalisation"); + assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) && + "Unexpected custom legalisation"); + + // Only handle constant shifts < 32. Non-constant shifts are handled by + // lowerShiftRightParts, and shifts >= 32 use default legalization. + auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!ShAmtC || ShAmtC->getZExtValue() >= 32) + break; + + auto [Lo, Hi] = DAG.SplitScalar(N->getOperand(0), DL, MVT::i32, MVT::i32); + + bool IsSRA = N->getOpcode() == ISD::SRA; + SDValue LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, + MVT::i32, Lo, Hi, N->getOperand(1)); + SDValue HiRes = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi, + N->getOperand(1)); + SDValue Res = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, LoRes, HiRes); + Results.push_back(Res); + return; + } + + assert(VT == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); if (N->getOperand(1).getOpcode() != ISD::Constant) { // If we can use a BSET instruction, allow default promotion to apply. diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index 9ea45e77bccf..e39e400f7d36 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -1495,6 +1495,15 @@ def riscv_subd : RVSDNode<"SUBD", SDT_RISCVIntBinOpD>; def riscv_wmulsu : RVSDNode<"WMULSU", SDTIntBinHiLoOp>; +// Narrowing shift: res = nsrl(lo, hi, shamt) is equivalent to +// res = truncate (srl (build_pair lo, hi), shamt), XLenVT +def SDT_RISCVNarrowingShift : SDTypeProfile<1, 3, [SDTCisVT<0, i32>, + SDTCisVT<1, i32>, + SDTCisVT<2, i32>, + SDTCisVT<3, i32>]>; +def riscv_nsrl : RVSDNode<"NSRL", SDT_RISCVNarrowingShift>; +def riscv_nsra : RVSDNode<"NSRA", SDT_RISCVNarrowingShift>; + // Averaging subtraction, (a - b) >> 2 def riscv_asub : RVSDNode<"ASUB", SDTIntBinOp>; def riscv_asubu : RVSDNode<"ASUBU", SDTIntBinOp>; |
