diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 228 |
1 files changed, 222 insertions, 6 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index a3ccbd8..637f194 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -284,6 +284,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass); } + // fixed vector is stored in GPRs for P extension packed operations + if (Subtarget.enablePExtCodeGen()) { + if (Subtarget.is64Bit()) { + addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass); + addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass); + addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass); + } else { + addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass); + addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass); + } + } + // Compute derived properties from the register classes. computeRegisterProperties(STI.getRegisterInfo()); @@ -492,6 +504,34 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, ISD::FTRUNC, ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN, ISD::FCANONICALIZE}; + if (Subtarget.enablePExtCodeGen()) { + setTargetDAGCombine(ISD::TRUNCATE); + setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); + setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand); + SmallVector<MVT, 2> VTs; + if (Subtarget.is64Bit()) { + VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8}); + setTruncStoreAction(MVT::v2i64, MVT::v2i32, Expand); + setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand); + setTruncStoreAction(MVT::v8i16, MVT::v8i8, Expand); + setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); + setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand); + setOperationAction(ISD::LOAD, MVT::v2i16, Custom); + setOperationAction(ISD::LOAD, MVT::v4i8, Custom); + } else { + VTs.append({MVT::v2i16, MVT::v4i8}); + } + setOperationAction(ISD::UADDSAT, VTs, Legal); + setOperationAction(ISD::SADDSAT, VTs, Legal); + setOperationAction(ISD::USUBSAT, VTs, Legal); + setOperationAction(ISD::SSUBSAT, VTs, Legal); + setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal); + setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal); + setOperationAction(ISD::BUILD_VECTOR, VTs, Custom); + setOperationAction(ISD::BITCAST, VTs, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom); + } + if (Subtarget.hasStdExtZfbfmin()) { setOperationAction(ISD::BITCAST, MVT::i16, Custom); setOperationAction(ISD::ConstantFP, MVT::bf16, Expand); @@ -1776,6 +1816,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, MaxLoadsPerMemcmp = Subtarget.getMaxLoadsPerMemcmp(/*OptSize=*/false); } +TargetLoweringBase::LegalizeTypeAction +RISCVTargetLowering::getPreferredVectorAction(MVT VT) const { + if (Subtarget.is64Bit() && Subtarget.enablePExtCodeGen()) + if (VT == MVT::v2i16 || VT == MVT::v4i8) + return TypeWidenVector; + + return TargetLoweringBase::getPreferredVectorAction(VT); +} + EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &Context, EVT VT) const { @@ -4391,6 +4440,37 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, MVT XLenVT = Subtarget.getXLenVT(); SDLoc DL(Op); + // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants + if (Subtarget.enablePExtCodeGen()) { + bool IsPExtVector = + (VT == MVT::v2i16 || VT == MVT::v4i8) || + (Subtarget.is64Bit() && + (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32)); + if (IsPExtVector) { + if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) { + if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) { + int64_t SplatImm = C->getSExtValue(); + bool IsValidImm = false; + + // Check immediate range based on vector type + if (VT == MVT::v8i8 || VT == MVT::v4i8) { + // PLI_B uses 8-bit unsigned or unsigned immediate + IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm); + if (isUInt<8>(SplatImm)) + SplatImm = (int8_t)SplatImm; + } else { + // PLI_H and PLI_W use 10-bit signed immediate + IsValidImm = isInt<10>(SplatImm); + } + + if (IsValidImm) { + SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT); + return DAG.getNode(RISCVISD::PLI, DL, VT, Imm); + } + } + } + } + } // Proper support for f16 requires Zvfh. bf16 always requires special // handling. We need to cast the scalar to integer and create an integer @@ -7546,6 +7626,19 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi); } + if (Subtarget.enablePExtCodeGen()) { + bool Is32BitCast = + (VT == MVT::i32 && (Op0VT == MVT::v4i8 || Op0VT == MVT::v2i16)) || + (Op0VT == MVT::i32 && (VT == MVT::v4i8 || VT == MVT::v2i16)); + bool Is64BitCast = + (VT == MVT::i64 && (Op0VT == MVT::v8i8 || Op0VT == MVT::v4i16 || + Op0VT == MVT::v2i32)) || + (Op0VT == MVT::i64 && + (VT == MVT::v8i8 || VT == MVT::v4i16 || VT == MVT::v2i32)); + if (Is32BitCast || Is64BitCast) + return Op; + } + // Consider other scalar<->scalar casts as legal if the types are legal. // Otherwise expand them. if (!VT.isVector() && !Op0VT.isVector()) { @@ -8218,6 +8311,17 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, auto *Store = cast<StoreSDNode>(Op); SDValue StoredVal = Store->getValue(); EVT VT = StoredVal.getValueType(); + if (Subtarget.enablePExtCodeGen()) { + if (VT == MVT::v2i16 || VT == MVT::v4i8) { + SDValue DL(Op); + SDValue Cast = DAG.getBitcast(MVT::i32, StoredVal); + SDValue NewStore = + DAG.getStore(Store->getChain(), DL, Cast, Store->getBasePtr(), + Store->getPointerInfo(), Store->getBaseAlign(), + Store->getMemOperand()->getFlags()); + return NewStore; + } + } if (VT == MVT::f64) { assert(Subtarget.hasStdExtZdinx() && !Subtarget.hasStdExtZilsd() && !Subtarget.is64Bit() && "Unexpected custom legalisation"); @@ -10500,6 +10604,17 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, return DAG.getNode(RISCVISD::FMV_H_X, DL, EltVT, IntExtract); } + if (Subtarget.enablePExtCodeGen() && VecVT.isFixedLengthVector()) { + if (VecVT != MVT::v4i16 && VecVT != MVT::v2i16 && VecVT != MVT::v8i8 && + VecVT != MVT::v4i8 && VecVT != MVT::v2i32) + return SDValue(); + SDValue Extracted = DAG.getBitcast(XLenVT, Vec); + unsigned ElemWidth = EltVT.getSizeInBits(); + SDValue Shamt = DAG.getNode(ISD::MUL, DL, XLenVT, Idx, + DAG.getConstant(ElemWidth, DL, XLenVT)); + return DAG.getNode(ISD::SRL, DL, XLenVT, Extracted, Shamt); + } + // If this is a fixed vector, we need to convert it to a scalable vector. MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { @@ -14642,6 +14757,21 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, return; } + if (Subtarget.is64Bit() && Subtarget.enablePExtCodeGen()) { + SDLoc DL(N); + SDValue ExtLoad = + DAG.getExtLoad(ISD::SEXTLOAD, DL, MVT::i64, Ld->getChain(), + Ld->getBasePtr(), MVT::i32, Ld->getMemOperand()); + if (N->getValueType(0) == MVT::v2i16) { + Results.push_back(DAG.getBitcast(MVT::v4i16, ExtLoad)); + Results.push_back(ExtLoad.getValue(1)); + } else if (N->getValueType(0) == MVT::v4i8) { + Results.push_back(DAG.getBitcast(MVT::v8i8, ExtLoad)); + Results.push_back(ExtLoad.getValue(1)); + } + return; + } + assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && "Unexpected custom legalisation"); @@ -14997,6 +15127,21 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes)); break; } + case RISCVISD::PASUB: + case RISCVISD::PASUBU: { + MVT VT = N->getSimpleValueType(0); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + assert(VT == MVT::v2i16 || VT == MVT::v4i8); + MVT NewVT = MVT::v4i16; + if (VT == MVT::v4i8) + NewVT = MVT::v8i8; + SDValue Undef = DAG.getUNDEF(VT); + Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef}); + Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef}); + Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1})); + return; + } case ISD::EXTRACT_VECTOR_ELT: { // Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element // type is illegal (currently only vXi64 RV32). @@ -16104,11 +16249,84 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(ISD::TRUNCATE, DL, VT, Min); } +// Handle P extension averaging subtraction pattern: +// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1))) +// -> PASUB/PASUBU +static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + if (N0.getOpcode() != ISD::SRL) + return SDValue(); + + MVT VecVT = VT.getSimpleVT(); + if (VecVT != MVT::v4i16 && VecVT != MVT::v2i16 && VecVT != MVT::v8i8 && + VecVT != MVT::v4i8 && VecVT != MVT::v2i32) + return SDValue(); + + // Check if shift amount is 1 + SDValue ShAmt = N0.getOperand(1); + if (ShAmt.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(ShAmt.getNode()); + if (!BV) + return SDValue(); + SDValue Splat = BV->getSplatValue(); + if (!Splat) + return SDValue(); + ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat); + if (!C) + return SDValue(); + if (C->getZExtValue() != 1) + return SDValue(); + + // Check for SUB operation + SDValue Sub = N0.getOperand(0); + if (Sub.getOpcode() != ISD::SUB) + return SDValue(); + + SDValue LHS = Sub.getOperand(0); + SDValue RHS = Sub.getOperand(1); + + // Check if both operands are sign/zero extends from the target + // type + bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND && + RHS.getOpcode() == ISD::SIGN_EXTEND; + bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND && + RHS.getOpcode() == ISD::ZERO_EXTEND; + + if (!IsSignExt && !IsZeroExt) + return SDValue(); + + SDValue A = LHS.getOperand(0); + SDValue B = RHS.getOperand(0); + + // Check if the extends are from our target vector type + if (A.getValueType() != VT || B.getValueType() != VT) + return SDValue(); + + // Determine the instruction based on type and signedness + unsigned Opc; + if (IsSignExt) + Opc = RISCVISD::PASUB; + else if (IsZeroExt) + Opc = RISCVISD::PASUBU; + else + return SDValue(); + + // Create the machine node directly + return DAG.getNode(Opc, SDLoc(N), VT, {A, B}); +} + static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + if (VT.isFixedLengthVector() && Subtarget.enablePExtCodeGen()) + return combinePExtTruncate(N, DAG, Subtarget); + // Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero // extending X. This is safe since we only need the LSB after the shift and // shift amounts larger than 31 would produce poison. If we wait until @@ -22203,8 +22421,7 @@ static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI, MachineFunction &MF = *BB->getParent(); DebugLoc DL = MI.getDebugLoc(); - const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); - const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo(); + const RISCVInstrInfo &TII = *MF.getSubtarget<RISCVSubtarget>().getInstrInfo(); Register LoReg = MI.getOperand(0).getReg(); Register HiReg = MI.getOperand(1).getReg(); Register SrcReg = MI.getOperand(2).getReg(); @@ -22213,7 +22430,7 @@ static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI, int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF); TII.storeRegToStackSlot(*BB, MI, SrcReg, MI.getOperand(2).isKill(), FI, SrcRC, - RI, Register()); + Register()); MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI); MachineMemOperand *MMOLo = MF.getMachineMemOperand(MPI, MachineMemOperand::MOLoad, 4, Align(8)); @@ -22239,8 +22456,7 @@ static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI, MachineFunction &MF = *BB->getParent(); DebugLoc DL = MI.getDebugLoc(); - const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); - const TargetRegisterInfo *RI = MF.getSubtarget().getRegisterInfo(); + const RISCVInstrInfo &TII = *MF.getSubtarget<RISCVSubtarget>().getInstrInfo(); Register DstReg = MI.getOperand(0).getReg(); Register LoReg = MI.getOperand(1).getReg(); Register HiReg = MI.getOperand(2).getReg(); @@ -22263,7 +22479,7 @@ static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI, .addFrameIndex(FI) .addImm(4) .addMemOperand(MMOHi); - TII.loadRegFromStackSlot(*BB, MI, DstReg, FI, DstRC, RI, Register()); + TII.loadRegFromStackSlot(*BB, MI, DstReg, FI, DstRC, Register()); MI.eraseFromParent(); // The pseudo instruction is gone now. return BB; } |
