aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
diff options
context:
space:
mode:
authormingmingl <mingmingl@google.com>2025-02-04 11:11:14 -0800
committermingmingl <mingmingl@google.com>2025-02-04 11:11:14 -0800
commite91747a92d27ecf799427bf563f9f64f7c4d2447 (patch)
tree7aa5a8a9170deec293e152bdf2be804399dcd612 /llvm/lib/Target/RISCV/RISCVISelLowering.cpp
parent3a8d9337d816aef41c3ca1484be8b933a71a3c46 (diff)
parent53d6e59b594639417cdbfcfa2d18cea64acb4009 (diff)
downloadllvm-users/mingmingl-llvm/spr/sdpglobalvariable.zip
llvm-users/mingmingl-llvm/spr/sdpglobalvariable.tar.gz
llvm-users/mingmingl-llvm/spr/sdpglobalvariable.tar.bz2
Merge branch 'main' into users/mingmingl-llvm/spr/sdpglobalvariableusers/mingmingl-llvm/spr/sdpglobalvariable
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp81
1 files changed, 79 insertions, 2 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8e3caf5..7c3b583 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17759,6 +17759,83 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
return DAG.getZExtOrTrunc(Pop, DL, VT);
}
+static SDValue performSHLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const RISCVSubtarget &Subtarget) {
+ // (shl (zext x), y) -> (vwsll x, y)
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ return V;
+
+ // (shl (sext x), C) -> (vwmulsu x, 1u << C)
+ // (shl (zext x), C) -> (vwmulu x, 1u << C)
+
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ SDValue LHS = N->getOperand(0);
+ if (!LHS.hasOneUse())
+ return SDValue();
+ unsigned Opcode;
+ switch (LHS.getOpcode()) {
+ case ISD::SIGN_EXTEND:
+ case RISCVISD::VSEXT_VL:
+ Opcode = RISCVISD::VWMULSU_VL;
+ break;
+ case ISD::ZERO_EXTEND:
+ case RISCVISD::VZEXT_VL:
+ Opcode = RISCVISD::VWMULU_VL;
+ break;
+ default:
+ return SDValue();
+ }
+
+ SDValue RHS = N->getOperand(1);
+ APInt ShAmt;
+ uint64_t ShAmtInt;
+ if (ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
+ ShAmtInt = ShAmt.getZExtValue();
+ else if (RHS.getOpcode() == RISCVISD::VMV_V_X_VL &&
+ RHS.getOperand(1).getOpcode() == ISD::Constant)
+ ShAmtInt = RHS.getConstantOperandVal(1);
+ else
+ return SDValue();
+
+ // Better foldings:
+ // (shl (sext x), 1) -> (vwadd x, x)
+ // (shl (zext x), 1) -> (vwaddu x, x)
+ if (ShAmtInt <= 1)
+ return SDValue();
+
+ SDValue NarrowOp = LHS.getOperand(0);
+ MVT NarrowVT = NarrowOp.getSimpleValueType();
+ uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
+ if (ShAmtInt >= NarrowBits)
+ return SDValue();
+ MVT VT = N->getSimpleValueType(0);
+ if (NarrowBits * 2 != VT.getScalarSizeInBits())
+ return SDValue();
+
+ SelectionDAG &DAG = DCI.DAG;
+ SDLoc DL(N);
+ SDValue Passthru, Mask, VL;
+ switch (N->getOpcode()) {
+ case ISD::SHL:
+ Passthru = DAG.getUNDEF(VT);
+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ break;
+ case RISCVISD::SHL_VL:
+ Passthru = N->getOperand(2);
+ Mask = N->getOperand(3);
+ VL = N->getOperand(4);
+ break;
+ default:
+ llvm_unreachable("Expected SHL");
+ }
+ return DAG.getNode(Opcode, DL, VT, NarrowOp,
+ DAG.getConstant(1ULL << ShAmtInt, SDLoc(RHS), NarrowVT),
+ Passthru, Mask, VL);
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -18392,7 +18469,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::SHL_VL:
- if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ if (SDValue V = performSHLCombine(N, DCI, Subtarget))
return V;
[[fallthrough]];
case RISCVISD::SRA_VL:
@@ -18417,7 +18494,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SRL:
case ISD::SHL: {
if (N->getOpcode() == ISD::SHL) {
- if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ if (SDValue V = performSHLCombine(N, DCI, Subtarget))
return V;
}
SDValue ShAmt = N->getOperand(1);