Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--	llvm/lib/Target/RISCV/RISCVISelLowering.cpp	166
1 file changed, 106 insertions(+), 60 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f7275eb..540c2e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -691,7 +691,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_FP_TO_UINT, ISD::VP_SETCC, ISD::VP_SIGN_EXTEND,
ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE, ISD::VP_SMIN,
ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
- ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE};
+ ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
+ ISD::VP_SADDSAT, ISD::VP_UADDSAT, ISD::VP_SSUBSAT,
+ ISD::VP_USUBSAT};
static const unsigned FloatingPointVPOps[] = {
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
@@ -830,7 +832,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
VT, Custom);
setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
Custom);
- setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
VT, Legal);
@@ -956,6 +957,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// between vXf16 and vXf64 must be lowered as sequences which convert via
// vXf32.
setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
+ setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
Custom);
@@ -3240,45 +3242,49 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
// Note that this method will also match potentially unappealing index
// sequences, like <i32 0, i32 50939494>; however, it is left to the caller to
// determine whether this is worth generating code for.
-static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
- unsigned NumElts = Op.getNumOperands();
+static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
+ unsigned EltSizeInBits) {
assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
+ if (!cast<BuildVectorSDNode>(Op)->isConstant())
+ return std::nullopt;
bool IsInteger = Op.getValueType().isInteger();
std::optional<unsigned> SeqStepDenom;
std::optional<int64_t> SeqStepNum, SeqAddend;
std::optional<std::pair<uint64_t, unsigned>> PrevElt;
- unsigned EltSizeInBits = Op.getValueType().getScalarSizeInBits();
- for (unsigned Idx = 0; Idx < NumElts; Idx++) {
- // Assume undef elements match the sequence; we just have to be careful
- // when interpolating across them.
- if (Op.getOperand(Idx).isUndef())
+ assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
+
+ // First extract the ops into a list of constant integer values. This may not
+ // be possible for floats if they're not all representable as integers.
+ SmallVector<std::optional<uint64_t>> Elts(Op.getNumOperands());
+ const unsigned OpSize = Op.getScalarValueSizeInBits();
+ for (auto [Idx, Elt] : enumerate(Op->op_values())) {
+ if (Elt.isUndef()) {
+ Elts[Idx] = std::nullopt;
continue;
-
- uint64_t Val;
+ }
if (IsInteger) {
- // The BUILD_VECTOR must be all constants.
- if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
- return std::nullopt;
- Val = Op.getConstantOperandVal(Idx) &
- maskTrailingOnes<uint64_t>(EltSizeInBits);
+ Elts[Idx] = Elt->getAsZExtVal() & maskTrailingOnes<uint64_t>(OpSize);
} else {
- // The BUILD_VECTOR must be all constants.
- if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
- return std::nullopt;
- if (auto ExactInteger = getExactInteger(
- cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
- EltSizeInBits))
- Val = *ExactInteger;
- else
+ auto ExactInteger =
+ getExactInteger(cast<ConstantFPSDNode>(Elt)->getValueAPF(), OpSize);
+ if (!ExactInteger)
return std::nullopt;
+ Elts[Idx] = *ExactInteger;
}
+ }
+
+ for (auto [Idx, Elt] : enumerate(Elts)) {
+ // Assume undef elements match the sequence; we just have to be careful
+ // when interpolating across them.
+ if (!Elt)
+ continue;
if (PrevElt) {
// Calculate the step since the last non-undef element, and ensure
// it's consistent across the entire sequence.
unsigned IdxDiff = Idx - PrevElt->second;
- int64_t ValDiff = SignExtend64(Val - PrevElt->first, EltSizeInBits);
+ int64_t ValDiff = SignExtend64(*Elt - PrevElt->first, EltSizeInBits);
// A zero value difference means that we're somewhere in the middle
// of a fractional step, e.g. <0,0,0*,0,1,1,1,1>. Wait until we notice a
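
For orientation, here is a minimal standalone model of the two-pass structure this hunk introduces: first materialize each element as an optional integer (nullopt for undef), then derive a common step between consecutive defined elements. Names are illustrative, and the fractional-step (numerator/denominator) handling of the real matcher is deliberately omitted.

    #include <cstdint>
    #include <optional>
    #include <utility>
    #include <vector>

    // Sign-extend the low Bits bits of V (mirrors llvm::SignExtend64).
    static int64_t signExtend64(uint64_t V, unsigned Bits) {
      return int64_t(V << (64 - Bits)) >> (64 - Bits);
    }

    // Return the common step of the sequence, or nullopt if there is none.
    // Undef elements (nullopt) are assumed to match the sequence.
    static std::optional<int64_t>
    matchStep(const std::vector<std::optional<uint64_t>> &Elts, unsigned EltBits) {
      std::optional<int64_t> Step;
      std::optional<std::pair<uint64_t, unsigned>> Prev; // (value, index)
      for (unsigned Idx = 0; Idx < Elts.size(); ++Idx) {
        if (!Elts[Idx])
          continue; // undef: interpolate across it
        if (Prev) {
          unsigned IdxDiff = Idx - Prev->second;
          int64_t ValDiff = signExtend64(*Elts[Idx] - Prev->first, EltBits);
          if (ValDiff % int64_t(IdxDiff) != 0)
            return std::nullopt; // fractional step: needs the denominator logic
          int64_t ThisStep = ValDiff / int64_t(IdxDiff);
          if (Step && *Step != ThisStep)
            return std::nullopt; // inconsistent step
          Step = ThisStep;
        }
        Prev = std::make_pair(*Elts[Idx], Idx);
      }
      return Step;
    }
    // matchStep({1, 3, std::nullopt, 7}, 32) yields 2; {0, 1, 3} yields nullopt.

Separating extraction from matching is what lets the float path bail out early on non-exact integers without duplicating the conversion in the validation loop at the end of the function.
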
@@ -3308,8 +3314,8 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
}
// Record this non-undef element for later.
- if (!PrevElt || PrevElt->first != Val)
- PrevElt = std::make_pair(Val, Idx);
+ if (!PrevElt || PrevElt->first != *Elt)
+ PrevElt = std::make_pair(*Elt, Idx);
}
// We need to have logged a step for this to count as a legal index sequence.
@@ -3318,21 +3324,12 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
// Loop back through the sequence and validate elements we might have skipped
// while waiting for a valid step. While doing this, log any sequence addend.
- for (unsigned Idx = 0; Idx < NumElts; Idx++) {
- if (Op.getOperand(Idx).isUndef())
+ for (auto [Idx, Elt] : enumerate(Elts)) {
+ if (!Elt)
continue;
- uint64_t Val;
- if (IsInteger) {
- Val = Op.getConstantOperandVal(Idx) &
- maskTrailingOnes<uint64_t>(EltSizeInBits);
- } else {
- Val = *getExactInteger(
- cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
- EltSizeInBits);
- }
uint64_t ExpectedVal =
(int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
- int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits);
+ int64_t Addend = SignExtend64(*Elt - ExpectedVal, EltSizeInBits);
if (!SeqAddend)
SeqAddend = Addend;
else if (Addend != SeqAddend)
@@ -3598,7 +3595,7 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
// Try and match index sequences, which we can lower to the vid instruction
// with optional modifications. An all-undef vector is matched by
// getSplatValue, above.
- if (auto SimpleVID = isSimpleVIDSequence(Op)) {
+ if (auto SimpleVID = isSimpleVIDSequence(Op, Op.getScalarValueSizeInBits())) {
int64_t StepNumerator = SimpleVID->StepNumerator;
unsigned StepDenominator = SimpleVID->StepDenominator;
int64_t Addend = SimpleVID->Addend;
@@ -3853,11 +3850,10 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
// If we're compiling for an exact VLEN value, we can split our work per
// register in the register group.
- const unsigned MinVLen = Subtarget.getRealMinVLen();
- const unsigned MaxVLen = Subtarget.getRealMaxVLen();
- if (MinVLen == MaxVLen && VT.getSizeInBits().getKnownMinValue() > MinVLen) {
+ if (const auto VLen = Subtarget.getRealVLen();
+ VLen && VT.getSizeInBits().getKnownMinValue() > *VLen) {
MVT ElemVT = VT.getVectorElementType();
- unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
+ unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits();
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg);
MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget);
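
The repeated min/max comparison that this and the following hunks delete is folded into a single optional-returning subtarget query. Its definition is not part of this diff; a plausible sketch, assuming it simply wraps the existing accessors:

    #include <optional>

    // Hypothetical model of RISCVSubtarget::getRealVLen(): the exact VLEN,
    // if the min and max bounds pin it down.
    std::optional<unsigned> getRealVLen(unsigned RealMinVLen,
                                        unsigned RealMaxVLen) {
      if (RealMinVLen != RealMaxVLen)
        return std::nullopt; // VLEN not known exactly at compile time
      return RealMinVLen;
    }

Callers then test the optional once instead of comparing two values, which also removes the risk of checking one bound and using the other.
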
@@ -4768,9 +4764,8 @@ static SDValue lowerShuffleViaVRegSplitting(ShuffleVectorSDNode *SVN,
// If we don't know exact data layout, not much we can do. If this
// is already m1 or smaller, no point in splitting further.
- const unsigned MinVLen = Subtarget.getRealMinVLen();
- const unsigned MaxVLen = Subtarget.getRealMaxVLen();
- if (MinVLen != MaxVLen || VT.getSizeInBits().getFixedValue() <= MinVLen)
+ const auto VLen = Subtarget.getRealVLen();
+ if (!VLen || VT.getSizeInBits().getFixedValue() <= *VLen)
return SDValue();
// Avoid picking up bitrotate patterns which we have a linear-in-lmul
@@ -4781,7 +4776,7 @@ static SDValue lowerShuffleViaVRegSplitting(ShuffleVectorSDNode *SVN,
return SDValue();
MVT ElemVT = VT.getVectorElementType();
- unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
+ unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits();
unsigned VRegsPerSrc = NumElts / ElemsPerVReg;
SmallVector<std::pair<int, SmallVector<int>>>
@@ -5759,6 +5754,10 @@ static unsigned getRISCVVLOp(SDValue Op) {
VP_CASE(SINT_TO_FP) // VP_SINT_TO_FP
VP_CASE(UINT_TO_FP) // VP_UINT_TO_FP
VP_CASE(BITREVERSE) // VP_BITREVERSE
+ VP_CASE(SADDSAT) // VP_SADDSAT
+ VP_CASE(UADDSAT) // VP_UADDSAT
+ VP_CASE(SSUBSAT) // VP_SSUBSAT
+ VP_CASE(USUBSAT) // VP_USUBSAT
VP_CASE(BSWAP) // VP_BSWAP
VP_CASE(CTLZ) // VP_CTLZ
VP_CASE(CTTZ) // VP_CTTZ
@@ -6798,6 +6797,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::VP_UDIV:
case ISD::VP_SREM:
case ISD::VP_UREM:
+ case ISD::VP_UADDSAT:
+ case ISD::VP_USUBSAT:
+ case ISD::VP_SADDSAT:
+ case ISD::VP_SSUBSAT:
return lowerVPOp(Op, DAG);
case ISD::VP_AND:
case ISD::VP_OR:
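
The new cases route through the generic lowerVPOp path and, via the VP_CASE table above, presumably select the same saturating RVV instructions (vsadd, vsaddu, vssub, vssubu) as the non-VP nodes. As a reminder of the per-lane semantics, a scalar model of signed saturating add:

    #include <cstdint>
    #include <limits>

    // Per-lane semantics of VP_SADDSAT on i8: add, then clamp to the signed
    // range instead of wrapping.
    int8_t saddSat8(int8_t A, int8_t B) {
      int16_t Sum = int16_t(A) + int16_t(B); // wide enough to never overflow
      if (Sum > std::numeric_limits<int8_t>::max())
        return std::numeric_limits<int8_t>::max();
      if (Sum < std::numeric_limits<int8_t>::min())
        return std::numeric_limits<int8_t>::min();
      return int8_t(Sum);
    }
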
@@ -7384,6 +7387,26 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
if (SDValue V = combineSelectToBinOp(Op.getNode(), DAG, Subtarget))
return V;
+ // (select c, c1, c2) -> (add (czero_nez c2 - c1, c), c1)
+ // (select c, c1, c2) -> (add (czero_eqz c1 - c2, c), c2)
+ if (isa<ConstantSDNode>(TrueV) && isa<ConstantSDNode>(FalseV)) {
+ const APInt &TrueVal = TrueV->getAsAPIntVal();
+ const APInt &FalseVal = FalseV->getAsAPIntVal();
+ const int TrueValCost = RISCVMatInt::getIntMatCost(
+ TrueVal, Subtarget.getXLen(), Subtarget, /*CompressionCost=*/true);
+ const int FalseValCost = RISCVMatInt::getIntMatCost(
+ FalseVal, Subtarget.getXLen(), Subtarget, /*CompressionCost=*/true);
+ bool IsCZERO_NEZ = TrueValCost <= FalseValCost;
+ SDValue LHSVal = DAG.getConstant(
+ IsCZERO_NEZ ? FalseVal - TrueVal : TrueVal - FalseVal, DL, VT);
+ SDValue RHSVal =
+ DAG.getConstant(IsCZERO_NEZ ? TrueVal : FalseVal, DL, VT);
+ SDValue CMOV =
+ DAG.getNode(IsCZERO_NEZ ? RISCVISD::CZERO_NEZ : RISCVISD::CZERO_EQZ,
+ DL, VT, LHSVal, CondV);
+ return DAG.getNode(ISD::ADD, DL, VT, CMOV, RHSVal);
+ }
+
// (select c, t, f) -> (or (czero_eqz t, c), (czero_nez f, c))
// Unless we have the short forward branch optimization.
if (!Subtarget.hasConditionalMoveFusion())
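
To see why the add/czero pair reproduces the select: czero.nez produces zero when the condition is nonzero and passes its input through otherwise (czero.eqz is the converse). A scalar model of the IsCZERO_NEZ orientation:

    #include <cstdint>

    // Model of (select c, t, f) -> (add (czero_nez (f - t), c), t).
    uint64_t selectViaCzeroNez(uint64_t C, uint64_t T, uint64_t F) {
      uint64_t Diff = (C != 0) ? 0 : (F - T); // czero.nez (f - t), c
      return Diff + T; // c != 0 gives t; c == 0 gives (f - t) + t == f
    }

Only one constant is materialized in full; IsCZERO_NEZ picks the orientation so that it is the cheaper of the two, per getIntMatCost.
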
@@ -8313,15 +8336,13 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
// constant index, we can always perform the extract in m1 (or
// smaller) as we can determine the register corresponding to
// the index in the register group.
- const unsigned MinVLen = Subtarget.getRealMinVLen();
- const unsigned MaxVLen = Subtarget.getRealMaxVLen();
+ const auto VLen = Subtarget.getRealVLen();
if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx);
- IdxC && MinVLen == MaxVLen &&
- VecVT.getSizeInBits().getKnownMinValue() > MinVLen) {
+ IdxC && VLen && VecVT.getSizeInBits().getKnownMinValue() > *VLen) {
MVT M1VT = getLMUL1VT(ContainerVT);
unsigned OrigIdx = IdxC->getZExtValue();
EVT ElemVT = VecVT.getVectorElementType();
- unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
+ unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits();
unsigned RemIdx = OrigIdx % ElemsPerVReg;
unsigned SubRegIdx = OrigIdx / ElemsPerVReg;
unsigned ExtractIdx =
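
A worked instance of the index split, under an assumed exact VLEN of 128 with i32 elements (numbers are illustrative only):

    // Element 9 of an LMUL>1 value lives in a single m1 register of the group.
    constexpr unsigned VLen = 128, ElemBits = 32, OrigIdx = 9;
    constexpr unsigned ElemsPerVReg = VLen / ElemBits;     // 4 elements per vreg
    constexpr unsigned SubRegIdx = OrigIdx / ElemsPerVReg; // register 2 of the group
    constexpr unsigned RemIdx = OrigIdx % ElemsPerVReg;    // element 1 within it
    static_assert(SubRegIdx == 2 && RemIdx == 1);
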
@@ -9782,15 +9803,14 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
if (OrigIdx == 0)
return Op;
- const unsigned MinVLen = Subtarget.getRealMinVLen();
- const unsigned MaxVLen = Subtarget.getRealMaxVLen();
+ const auto VLen = Subtarget.getRealVLen();
// If the subvector is a fixed-length type and we don't know VLEN
// exactly, we cannot use subregister manipulation to simplify the codegen; we
// don't know which register of a LMUL group contains the specific subvector
// as we only know the minimum register size. Therefore we must slide the
// vector group down the full amount.
- if (SubVecVT.isFixedLengthVector() && MinVLen != MaxVLen) {
+ if (SubVecVT.isFixedLengthVector() && !VLen) {
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(VecVT);
@@ -9837,8 +9857,8 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
// and decomposeSubvectorInsertExtractToSubRegs takes this into account. So if
// we have a fixed length subvector, we need to adjust the index by 1/vscale.
if (SubVecVT.isFixedLengthVector()) {
- assert(MinVLen == MaxVLen);
- unsigned Vscale = MinVLen / RISCV::RVVBitsPerBlock;
+ assert(VLen);
+ unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
auto Decompose =
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
VecVT, ContainerSubVecVT, OrigIdx / Vscale, TRI);
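
The 1/vscale adjustment, worked through for an assumed VLEN of 128 (RISCV::RVVBitsPerBlock is 64, so vscale is 2): a fixed element index is divided by vscale to obtain the equivalent index in scalable-vector units, which is what decomposeSubvectorInsertExtractToSubRegs expects.

    constexpr unsigned VLen = 128, RVVBitsPerBlock = 64;
    constexpr unsigned Vscale = VLen / RVVBitsPerBlock; // 2
    constexpr unsigned OrigIdx = 8;                     // in fixed-vector elements
    constexpr unsigned ScalableIdx = OrigIdx / Vscale;  // 4, i.e. 4 * vscale
    static_assert(ScalableIdx == 4);
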
@@ -12872,6 +12892,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineSubOfBoolean(N, DAG))
return V;
+ EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
// fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
@@ -12879,7 +12900,6 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
isNullConstant(N1.getOperand(1))) {
ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
if (CCVal == ISD::SETLT) {
- EVT VT = N->getValueType(0);
SDLoc DL(N);
unsigned ShAmt = N0.getValueSizeInBits() - 1;
return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
@@ -12887,6 +12907,29 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
}
}
+ // sub (zext, zext) -> sext (sub (zext, zext))
+ // where the sum of the extend widths matches, and the inner zexts
+ // add at least one bit. (For profitability on rvv, we use a
+ // power of two for both inner and outer extend.)
+ if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
+ N0.getOpcode() == N1.getOpcode() && N0.getOpcode() == ISD::ZERO_EXTEND &&
+ N0.hasOneUse() && N1.hasOneUse()) {
+ SDValue Src0 = N0.getOperand(0);
+ SDValue Src1 = N1.getOperand(0);
+ EVT SrcVT = Src0.getValueType();
+ if (Subtarget.getTargetLowering()->isTypeLegal(SrcVT) &&
+ SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
+ SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
+ LLVMContext &C = *DAG.getContext();
+ EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+ EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+ Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+ Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+ return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
+ DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
+ }
+ }
+
// fold (sub x, (select lhs, rhs, cc, 0, y)) ->
// (select lhs, rhs, cc, x, (sub x, y))
return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
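
The narrowing is lossless because two zero-extended values more than one bit narrower than half the result width have a difference that the half-width signed type represents exactly; sign-extending the narrow difference therefore reproduces the wide one. A scalar model for the i8-to-i32 case:

    #include <cstdint>

    // sub (zext i8, zext i8) at i32 == sext (sub at i16): the difference of
    // two u8 values lies in [-255, 255], which i16 represents exactly.
    int32_t subViaNarrow(uint8_t A, uint8_t B) {
      int16_t Narrow = int16_t(uint16_t(A) - uint16_t(B)); // i16 sub of zexts
      return int32_t(Narrow); // sext back to i32
    }
    // Agrees with int32_t(uint32_t(A) - uint32_t(B)) for all inputs.

The subtract then runs at half the element width, with a single sign extend recovering the original type.
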
@@ -15978,7 +16021,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (Index.getOpcode() == ISD::BUILD_VECTOR &&
MGN->getExtensionType() == ISD::NON_EXTLOAD && isTypeLegal(VT)) {
- if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index);
+ // The sequence will be XLenVT, not the type of Index. Tell
+ // isSimpleVIDSequence this so we avoid overflow.
+ if (std::optional<VIDSequence> SimpleVID =
+ isSimpleVIDSequence(Index, Subtarget.getXLen());
SimpleVID && SimpleVID->StepDenominator == 1) {
const int64_t StepNumerator = SimpleVID->StepNumerator;
const int64_t Addend = SimpleVID->Addend;
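
A small demonstration of the overflow the new comment guards against: the vid-based sequence is materialized in XLenVT, so element differences must be interpreted at XLen width rather than at the width of Index (values here are hypothetical):

    #include <cstdint>
    #include <cstdio>

    static int64_t signExtend64(uint64_t V, unsigned Bits) {
      return int64_t(V << (64 - Bits)) >> (64 - Bits);
    }

    int main() {
      uint64_t Elt0 = 0, Elt1 = 200; // an i8 index pair <0, 200>
      // At 8 bits the step reads as -56; a vid sequence generated in 64-bit
      // XLenVT with that step would be <0, -56>, not the intended <0, 200>.
      std::printf("%lld\n", (long long)signExtend64(Elt1 - Elt0, 8));  // -56
      std::printf("%lld\n", (long long)signExtend64(Elt1 - Elt0, 64)); // 200
      return 0;
    }
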