aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp228
1 files changed, 222 insertions, 6 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a3ccbd8..637f194 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -284,6 +284,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass);
}
+ // fixed vector is stored in GPRs for P extension packed operations
+ if (Subtarget.enablePExtCodeGen()) {
+ if (Subtarget.is64Bit()) {
+ addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass);
+ addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass);
+ addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass);
+ } else {
+ addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass);
+ addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass);
+ }
+ }
+
// Compute derived properties from the register classes.
computeRegisterProperties(STI.getRegisterInfo());
@@ -492,6 +504,34 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN, ISD::FCANONICALIZE};
+ if (Subtarget.enablePExtCodeGen()) {
+ setTargetDAGCombine(ISD::TRUNCATE);
+ setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
+ setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand);
+ SmallVector<MVT, 2> VTs;
+ if (Subtarget.is64Bit()) {
+ VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8});
+ setTruncStoreAction(MVT::v2i64, MVT::v2i32, Expand);
+ setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
+ setTruncStoreAction(MVT::v8i16, MVT::v8i8, Expand);
+ setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
+ setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand);
+ setOperationAction(ISD::LOAD, MVT::v2i16, Custom);
+ setOperationAction(ISD::LOAD, MVT::v4i8, Custom);
+ } else {
+ VTs.append({MVT::v2i16, MVT::v4i8});
+ }
+ setOperationAction(ISD::UADDSAT, VTs, Legal);
+ setOperationAction(ISD::SADDSAT, VTs, Legal);
+ setOperationAction(ISD::USUBSAT, VTs, Legal);
+ setOperationAction(ISD::SSUBSAT, VTs, Legal);
+ setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
+ setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
+ setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
+ setOperationAction(ISD::BITCAST, VTs, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
+ }
+
if (Subtarget.hasStdExtZfbfmin()) {
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
@@ -1776,6 +1816,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
MaxLoadsPerMemcmp = Subtarget.getMaxLoadsPerMemcmp(/*OptSize=*/false);
}
+TargetLoweringBase::LegalizeTypeAction
+RISCVTargetLowering::getPreferredVectorAction(MVT VT) const {
+  // On RV64 with P-extension codegen enabled, the sub-XLEN packed vectors
+  // v2i16 and v4i8 are widened to the GPR-sized forms (v4i16/v8i8) instead of
+  // being split or scalarized.
+  const bool PExtRV64 = Subtarget.is64Bit() && Subtarget.enablePExtCodeGen();
+  if (PExtRV64 && (VT == MVT::v2i16 || VT == MVT::v4i8))
+    return TypeWidenVector;
+
+  return TargetLoweringBase::getPreferredVectorAction(VT);
+}
+
EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL,
LLVMContext &Context,
EVT VT) const {
@@ -4391,6 +4440,37 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
MVT XLenVT = Subtarget.getXLenVT();
SDLoc DL(Op);
+ // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants
+ if (Subtarget.enablePExtCodeGen()) {
+ bool IsPExtVector =
+ (VT == MVT::v2i16 || VT == MVT::v4i8) ||
+ (Subtarget.is64Bit() &&
+ (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
+ if (IsPExtVector) {
+ if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
+ if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
+ int64_t SplatImm = C->getSExtValue();
+ bool IsValidImm = false;
+
+ // Check immediate range based on vector type
+ if (VT == MVT::v8i8 || VT == MVT::v4i8) {
+ // PLI_B uses 8-bit signed or unsigned immediate
+ IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm);
+ if (isUInt<8>(SplatImm))
+ SplatImm = (int8_t)SplatImm;
+ } else {
+ // PLI_H and PLI_W use 10-bit signed immediate
+ IsValidImm = isInt<10>(SplatImm);
+ }
+
+ if (IsValidImm) {
+ SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT);
+ return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
+ }
+ }
+ }
+ }
+ }
// Proper support for f16 requires Zvfh. bf16 always requires special
// handling. We need to cast the scalar to integer and create an integer
@@ -7546,6 +7626,19 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
}
+ if (Subtarget.enablePExtCodeGen()) {
+ bool Is32BitCast =
+ (VT == MVT::i32 && (Op0VT == MVT::v4i8 || Op0VT == MVT::v2i16)) ||
+ (Op0VT == MVT::i32 && (VT == MVT::v4i8 || VT == MVT::v2i16));
+ bool Is64BitCast =
+ (VT == MVT::i64 && (Op0VT == MVT::v8i8 || Op0VT == MVT::v4i16 ||
+ Op0VT == MVT::v2i32)) ||
+ (Op0VT == MVT::i64 &&
+ (VT == MVT::v8i8 || VT == MVT::v4i16 || VT == MVT::v2i32));
+ if (Is32BitCast || Is64BitCast)
+ return Op;
+ }
+
// Consider other scalar<->scalar casts as legal if the types are legal.
// Otherwise expand them.
if (!VT.isVector() && !Op0VT.isVector()) {
@@ -8218,6 +8311,17 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
auto *Store = cast<StoreSDNode>(Op);
SDValue StoredVal = Store->getValue();
EVT VT = StoredVal.getValueType();
+ if (Subtarget.enablePExtCodeGen()) {
+ if (VT == MVT::v2i16 || VT == MVT::v4i8) {
+ SDValue DL(Op);
+ SDValue Cast = DAG.getBitcast(MVT::i32, StoredVal);
+ SDValue NewStore =
+ DAG.getStore(Store->getChain(), DL, Cast, Store->getBasePtr(),
+ Store->getPointerInfo(), Store->getBaseAlign(),
+ Store->getMemOperand()->getFlags());
+ return NewStore;
+ }
+ }
if (VT == MVT::f64) {
assert(Subtarget.hasStdExtZdinx() && !Subtarget.hasStdExtZilsd() &&
!Subtarget.is64Bit() && "Unexpected custom legalisation");
@@ -10500,6 +10604,17 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
return DAG.getNode(RISCVISD::FMV_H_X, DL, EltVT, IntExtract);
}
+ if (Subtarget.enablePExtCodeGen() && VecVT.isFixedLengthVector()) {
+ if (VecVT != MVT::v4i16 && VecVT != MVT::v2i16 && VecVT != MVT::v8i8 &&
+ VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
+ return SDValue();
+ SDValue Extracted = DAG.getBitcast(XLenVT, Vec);
+ unsigned ElemWidth = EltVT.getSizeInBits();
+ SDValue Shamt = DAG.getNode(ISD::MUL, DL, XLenVT, Idx,
+ DAG.getConstant(ElemWidth, DL, XLenVT));
+ return DAG.getNode(ISD::SRL, DL, XLenVT, Extracted, Shamt);
+ }
+
// If this is a fixed vector, we need to convert it to a scalable vector.
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
@@ -14642,6 +14757,21 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
return;
}
+ if (Subtarget.is64Bit() && Subtarget.enablePExtCodeGen()) {
+ SDLoc DL(N);
+ SDValue ExtLoad =
+ DAG.getExtLoad(ISD::SEXTLOAD, DL, MVT::i64, Ld->getChain(),
+ Ld->getBasePtr(), MVT::i32, Ld->getMemOperand());
+ if (N->getValueType(0) == MVT::v2i16) {
+ Results.push_back(DAG.getBitcast(MVT::v4i16, ExtLoad));
+ Results.push_back(ExtLoad.getValue(1));
+ } else if (N->getValueType(0) == MVT::v4i8) {
+ Results.push_back(DAG.getBitcast(MVT::v8i8, ExtLoad));
+ Results.push_back(ExtLoad.getValue(1));
+ }
+ return;
+ }
+
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
@@ -14997,6 +15127,21 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes));
break;
}
+ case RISCVISD::PASUB:
+ case RISCVISD::PASUBU: {
+ MVT VT = N->getSimpleValueType(0);
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ assert(VT == MVT::v2i16 || VT == MVT::v4i8);
+ MVT NewVT = MVT::v4i16;
+ if (VT == MVT::v4i8)
+ NewVT = MVT::v8i8;
+ SDValue Undef = DAG.getUNDEF(VT);
+ Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
+ Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
+ Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
+ return;
+ }
case ISD::EXTRACT_VECTOR_ELT: {
// Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
// type is illegal (currently only vXi64 RV32).
@@ -16104,11 +16249,84 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
}
+// Handle P extension averaging subtraction pattern:
+//   (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
+//     -> (PASUB/PASUBU vXiY:$a, vXiY:$b)
+// Returns the replacement node, or SDValue() if the pattern does not match.
+static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
+                                   const RISCVSubtarget &Subtarget) {
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  if (N0.getOpcode() != ISD::SRL)
+    return SDValue();
+
+  // Only the GPR-sized P-extension packed vector types are handled.
+  MVT VecVT = VT.getSimpleVT();
+  if (VecVT != MVT::v4i16 && VecVT != MVT::v2i16 && VecVT != MVT::v8i8 &&
+      VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
+    return SDValue();
+
+  // The shift amount must be a constant splat of exactly 1 (the averaging
+  // halving step). dyn_cast also rejects non-BUILD_VECTOR shift amounts.
+  auto *BV = dyn_cast<BuildVectorSDNode>(N0.getOperand(1));
+  if (!BV)
+    return SDValue();
+  SDValue Splat = BV->getSplatValue();
+  if (!Splat)
+    return SDValue();
+  auto *C = dyn_cast<ConstantSDNode>(Splat);
+  if (!C || C->getZExtValue() != 1)
+    return SDValue();
+
+  // The shifted value must be a subtraction.
+  SDValue Sub = N0.getOperand(0);
+  if (Sub.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  SDValue LHS = Sub.getOperand(0);
+  SDValue RHS = Sub.getOperand(1);
+
+  // Both subtraction operands must carry the same extension kind; that kind
+  // selects the signed (PASUB) or unsigned (PASUBU) averaging subtraction.
+  unsigned Opc;
+  if (LHS.getOpcode() == ISD::SIGN_EXTEND &&
+      RHS.getOpcode() == ISD::SIGN_EXTEND)
+    Opc = RISCVISD::PASUB;
+  else if (LHS.getOpcode() == ISD::ZERO_EXTEND &&
+           RHS.getOpcode() == ISD::ZERO_EXTEND)
+    Opc = RISCVISD::PASUBU;
+  else
+    return SDValue();
+
+  // The extends must originate from the truncate's result type, so the
+  // replacement node operates directly on the narrow vectors.
+  SDValue A = LHS.getOperand(0);
+  SDValue B = RHS.getOperand(0);
+  if (A.getValueType() != VT || B.getValueType() != VT)
+    return SDValue();
+
+  // Build the target-specific SelectionDAG node (instruction selection maps
+  // it to the PASUB/PASUBU instruction later).
+  return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
+}
+
static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
+ if (VT.isFixedLengthVector() && Subtarget.enablePExtCodeGen())
+ return combinePExtTruncate(N, DAG, Subtarget);
+
// Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero
// extending X. This is safe since we only need the LSB after the shift and
// shift amounts larger than 31 would produce poison. If we wait until
@@ -22203,8 +22421,7 @@ static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
MachineFunction &MF = *BB->getParent();
DebugLoc DL = MI.getDebugLoc();
- const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
- const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo();
+ const RISCVInstrInfo &TII = *MF.getSubtarget<RISCVSubtarget>().getInstrInfo();
Register LoReg = MI.getOperand(0).getReg();
Register HiReg = MI.getOperand(1).getReg();
Register SrcReg = MI.getOperand(2).getReg();
@@ -22213,7 +22430,7 @@ static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
TII.storeRegToStackSlot(*BB, MI, SrcReg, MI.getOperand(2).isKill(), FI, SrcRC,
- RI, Register());
+ Register());
MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
MachineMemOperand *MMOLo =
MF.getMachineMemOperand(MPI, MachineMemOperand::MOLoad, 4, Align(8));
@@ -22239,8 +22456,7 @@ static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI,
MachineFunction &MF = *BB->getParent();
DebugLoc DL = MI.getDebugLoc();
- const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
- const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo();
+ const RISCVInstrInfo &TII = *MF.getSubtarget<RISCVSubtarget>().getInstrInfo();
Register DstReg = MI.getOperand(0).getReg();
Register LoReg = MI.getOperand(1).getReg();
Register HiReg = MI.getOperand(2).getReg();
@@ -22263,7 +22479,7 @@ static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI,
.addFrameIndex(FI)
.addImm(4)
.addMemOperand(MMOHi);
- TII.loadRegFromStackSlot(*BB, MI, DstReg, FI, DstRC, RI, Register());
+ TII.loadRegFromStackSlot(*BB, MI, DstReg, FI, DstRC, Register());
MI.eraseFromParent(); // The pseudo instruction is gone now.
return BB;
}