aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
author: Craig Topper <craig.topper@sifive.com> 2026-02-11 22:37:47 -0800
committer: GitHub <noreply@github.com> 2026-02-11 22:37:47 -0800
commit: db588931c5bca3f09965689b883a56abf322bbd4 (patch)
tree: 44a7793b552bd6271d3f6fee8078186eecb5fa7a /llvm/lib
parent: 4323b3604cf6a26afbefcdb79edbd4a9a2e0a000 (diff)
download: llvm-db588931c5bca3f09965689b883a56abf322bbd4.tar.gz
llvm-db588931c5bca3f09965689b883a56abf322bbd4.tar.bz2
llvm-db588931c5bca3f09965689b883a56abf322bbd4.zip
[RISCV] Use NSRL/NSRA for legalizing i64 shifts with P extension on RV32. (#181040)
If the shift amount might be in the range [0, 31], we can use NSRL/NSRA to shift the i64 value to compute the lower 32 bits of the result. If the shift amount is >= 32, the high half of the result is all zeros or sign bits. Otherwise it is an srl/sra of the high bits.

I've handled the constant case in ReplaceNodeResults but deferred the non-constant case to lowerShiftRightParts. lowerShiftRightParts is not called for constant shift amounts. Deferring gives DAGCombine the opportunity to optimize the SRL_PARTS/SRA_PARTS if the shift amount can be proven to be >= 32 or < 32.

Sequences were also discussed on the P extension mailing list here: https://lists.riscv.org/g/tech-p-ext/message/861

Assisted-by: claude
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 34
-rw-r--r-- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 65
-rw-r--r-- llvm/lib/Target/RISCV/RISCVInstrInfoP.td | 9
3 files changed, 103 insertions, 5 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index db65d6ac1a5d..b7dfd7e24338 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1791,6 +1791,40 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
CurDAG->RemoveDeadNode(Node);
return;
}
+ case RISCVISD::NSRL:
+ case RISCVISD::NSRA: {
+ assert(Subtarget->hasStdExtP() && !Subtarget->is64Bit() && VT == MVT::i32 &&
+ "Unexpected opcode");
+
+ bool IsSRA = Node->getOpcode() == RISCVISD::NSRA;
+ SDValue Lo = Node->getOperand(0);
+ SDValue Hi = Node->getOperand(1);
+ SDValue ShAmt = Node->getOperand(2);
+
+ SDValue Ops[] = {
+ CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32), Lo,
+ CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32), Hi,
+ CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};
+ SDValue Pair = SDValue(CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE,
+ DL, MVT::Untyped, Ops),
+ 0);
+
+ MachineSDNode *Res;
+ if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) {
+ unsigned Opc = IsSRA ? RISCV::NSRAI : RISCV::NSRLI;
+ Res = CurDAG->getMachineNode(
+ Opc, DL, MVT::i32, Pair,
+ CurDAG->getTargetConstant(*ShAmtC->getConstantIntValue(), DL,
+ MVT::i32));
+ } else {
+ // NSRL/NSRA only read 6 bits of the shift amount.
+ selectShiftMask(ShAmt, 64, ShAmt);
+ unsigned Opc = IsSRA ? RISCV::NSRA : RISCV::NSRL;
+ Res = CurDAG->getMachineNode(Opc, DL, MVT::i32, Pair, ShAmt);
+ }
+ ReplaceNode(Node, Res);
+ return;
+ }
case ISD::LOAD: {
if (tryIndexedLoad(Node))
return;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0fef6d8b66f..944354a0fd2b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -358,10 +358,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
MVT::i32, Custom);
setOperationAction({ISD::UADDO, ISD::USUBO}, MVT::i32, Custom);
setOperationAction({ISD::SADDO, ISD::SSUBO}, MVT::i32, Custom);
- } else {
- // Custom legalize i64 ADD/SUB for RV32+P.
- if (Subtarget.hasStdExtP())
- setOperationAction({ISD::ADD, ISD::SUB}, MVT::i64, Custom);
+ } else if (Subtarget.hasStdExtP()) {
+ // Custom legalize i64 ADD/SUB/SRL/SRA for RV32+P.
+ setOperationAction({ISD::ADD, ISD::SUB}, MVT::i64, Custom);
+ setOperationAction({ISD::SRL, ISD::SRA}, MVT::i64, Custom);
}
if (!Subtarget.hasStdExtZmmul()) {
setOperationAction({ISD::MUL, ISD::MULHS, ISD::MULHU}, XLenVT, Expand);
@@ -10294,6 +10294,37 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
SDValue Shamt = Op.getOperand(2);
EVT VT = Lo.getValueType();
+ // With P extension on RV32, use NSRL/NSRA for the low part.
+ if (Subtarget.hasStdExtP() && !Subtarget.is64Bit()) {
+ SDValue LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, VT,
+ Lo, Hi, Shamt);
+ // Mask shift amount to avoid UB when Shamt >= 32.
+ SDValue ShamtMasked =
+ DAG.getNode(ISD::AND, DL, VT, Shamt, DAG.getConstant(31, DL, VT));
+ SDValue HiRes =
+ DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
+
+ // Create a mask that is -1 when Shamt >= 32, 0 otherwise.
+ // FIXME: We should use a select and let LowerSelect make the
+ // optimizations.
+ SDValue ShAmtExt =
+ DAG.getNode(ISD::SHL, DL, VT, Shamt, DAG.getConstant(26, DL, VT));
+ SDValue Mask =
+ DAG.getNode(ISD::SRA, DL, VT, ShAmtExt, DAG.getConstant(31, DL, VT));
+
+ if (IsSRA) {
+ // sra hi, hi, (mask & 31) - shifts by 31 when shamt >= 32
+ SDValue MaskAmt =
+ DAG.getNode(ISD::AND, DL, VT, Mask, DAG.getConstant(31, DL, VT));
+ HiRes = DAG.getNode(ISD::SRA, DL, VT, HiRes, MaskAmt);
+ } else {
+ // andn hi, hi, mask - clears hi when shamt >= 32
+ HiRes = DAG.getNode(ISD::AND, DL, VT, HiRes, DAG.getNOT(DL, Mask, VT));
+ }
+
+ return DAG.getMergeValues({LoRes, HiRes}, DL);
+ }
+
// SRA expansion:
// if Shamt-XLEN < 0: // Shamt < XLEN
// Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - ShAmt))
@@ -15275,7 +15306,31 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
break;
}
- assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
+ if (VT == MVT::i64) {
+ assert(!Subtarget.is64Bit() && Subtarget.hasStdExtP() &&
+ "Unexpected custom legalisation");
+ assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
+ "Unexpected custom legalisation");
+
+ // Only handle constant shifts < 32. Non-constant shifts are handled by
+ // lowerShiftRightParts, and shifts >= 32 use default legalization.
+ auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1));
+ if (!ShAmtC || ShAmtC->getZExtValue() >= 32)
+ break;
+
+ auto [Lo, Hi] = DAG.SplitScalar(N->getOperand(0), DL, MVT::i32, MVT::i32);
+
+ bool IsSRA = N->getOpcode() == ISD::SRA;
+ SDValue LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL,
+ MVT::i32, Lo, Hi, N->getOperand(1));
+ SDValue HiRes = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi,
+ N->getOperand(1));
+ SDValue Res = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, LoRes, HiRes);
+ Results.push_back(Res);
+ return;
+ }
+
+ assert(VT == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
if (N->getOperand(1).getOpcode() != ISD::Constant) {
// If we can use a BSET instruction, allow default promotion to apply.
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 9ea45e77bccf..e39e400f7d36 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1495,6 +1495,15 @@ def riscv_subd : RVSDNode<"SUBD", SDT_RISCVIntBinOpD>;
def riscv_wmulsu : RVSDNode<"WMULSU", SDTIntBinHiLoOp>;
+// Narrowing shift: res = nsrl(lo, hi, shamt) is equivalent to
+// res = truncate (srl (build_pair lo, hi), shamt), XLenVT; nsra uses sra.
+def SDT_RISCVNarrowingShift : SDTypeProfile<1, 3, [SDTCisVT<0, i32>,
+ SDTCisVT<1, i32>,
+ SDTCisVT<2, i32>,
+ SDTCisVT<3, i32>]>;
+def riscv_nsrl : RVSDNode<"NSRL", SDT_RISCVNarrowingShift>;
+def riscv_nsra : RVSDNode<"NSRA", SDT_RISCVNarrowingShift>;
+
// Averaging subtraction, (a - b) >> 2
def riscv_asub : RVSDNode<"ASUB", SDTIntBinOp>;
def riscv_asubu : RVSDNode<"ASUBU", SDTIntBinOp>;