diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
-rw-r--r-- | llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp | 184 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h | 2 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp | 61 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 30 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrGISel.td | 25 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td | 8 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td | 35 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVSubtarget.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 13 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 16 |
11 files changed, 327 insertions, 54 deletions
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp index 7f35107..38c1f9868 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp @@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) .clampScalar(0, s32, sXLen) .minScalarSameAs(1, 0); + auto &ExtActions = + getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}) + .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST), + typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))); if (ST.is64Bit()) { - getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}) - .legalFor({{sXLen, s32}}) - .maxScalar(0, sXLen); - + ExtActions.legalFor({{sXLen, s32}}); getActionDefinitionsBuilder(G_SEXT_INREG) .customFor({sXLen}) .maxScalar(0, sXLen) .lower(); } else { - getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar(0, sXLen); - getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower(); } + ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST)) + .maxScalar(0, sXLen); // Merge/Unmerge for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) { @@ -235,7 +236,9 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) getActionDefinitionsBuilder(G_ICMP) .legalFor({{sXLen, sXLen}, {sXLen, p0}}) - .widenScalarToNextPow2(1) + .legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), + typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))) + .widenScalarOrEltToNextPow2OrMinSize(1, 8) .clampScalar(1, sXLen, sXLen) .clampScalar(0, sXLen, sXLen); @@ -418,6 +421,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) .clampScalar(0, sXLen, sXLen) .customFor({sXLen}); + auto &SplatActions = + getActionDefinitionsBuilder(G_SPLAT_VECTOR) + .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST), + typeIs(1, sXLen))) + .customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1))); + // Handle case of s64 element vectors on RV32. If the subtarget does not have + // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget + // does have f64, then we don't know whether the type is an f64 or an i64, + // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it, + // depending on how the instructions it consumes are legalized. They are not + // legalized yet since legalization is in reverse postorder, so we cannot + // make the decision at this moment. + if (XLen == 32) { + if (ST.hasVInstructionsF64() && ST.hasStdExtD()) + SplatActions.legalIf(all( + typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64))); + else if (ST.hasVInstructionsI64()) + SplatActions.customIf(all( + typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64))); + } + + SplatActions.clampScalar(1, sXLen, sXLen); + getLegacyLegalizerInfo().computeTables(); } @@ -576,7 +602,145 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI, auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3)); MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val)); } + MI.eraseFromParent(); + return true; +} + +// Custom-lower extensions from mask vectors by using a vselect either with 1 +// for zero/any-extension or -1 for sign-extension: +// (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0) +// Note that any-extension is lowered identically to zero-extension. +bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI, + MachineIRBuilder &MIB) const { + + unsigned Opc = MI.getOpcode(); + assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT || + Opc == TargetOpcode::G_ANYEXT); + + MachineRegisterInfo &MRI = *MIB.getMRI(); + Register Dst = MI.getOperand(0).getReg(); + Register Src = MI.getOperand(1).getReg(); + + LLT DstTy = MRI.getType(Dst); + int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1; + LLT DstEltTy = DstTy.getElementType(); + auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0)); + auto SplatTrue = + MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal)); + MIB.buildSelect(Dst, Src, SplatTrue, SplatZero); + + MI.eraseFromParent(); + return true; +} + +/// Return the type of the mask type suitable for masking the provided +/// vector type. This is simply an i1 element type vector of the same +/// (possibly scalable) length. +static LLT getMaskTypeFor(LLT VecTy) { + assert(VecTy.isVector()); + ElementCount EC = VecTy.getElementCount(); + return LLT::vector(EC, LLT::scalar(1)); +} + +/// Creates an all ones mask suitable for masking a vector of type VecTy with +/// vector length VL. +static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL, + MachineIRBuilder &MIB, + MachineRegisterInfo &MRI) { + LLT MaskTy = getMaskTypeFor(VecTy); + return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL}); +} + +/// Gets the two common "VL" operands: an all-ones mask and the vector length. +/// VecTy is a scalable vector type. +static std::pair<MachineInstrBuilder, Register> +buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB, + MachineRegisterInfo &MRI) { + LLT VecTy = Dst.getLLTTy(MRI); + assert(VecTy.isScalableVector() && "Expecting scalable container type"); + Register VL(RISCV::X0); + MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI); + return {Mask, VL}; +} + +static MachineInstrBuilder +buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo, + Register Hi, Register VL, MachineIRBuilder &MIB, + MachineRegisterInfo &MRI) { + // TODO: If the Hi bits of the splat are undefined, then it's fine to just + // splat Lo even if it might be sign extended. I don't think we have + // introduced a case where we're build a s64 where the upper bits are undef + // yet. + + // Fall back to a stack store and stride x0 vector load. + // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in + // preprocessDAG in SDAG. + return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst}, + {Passthru, Lo, Hi, VL}); +} + +static MachineInstrBuilder +buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru, + const SrcOp &Scalar, Register VL, + MachineIRBuilder &MIB, MachineRegisterInfo &MRI) { + assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!"); + auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar); + return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0), + Unmerge.getReg(1), VL, MIB, MRI); +} + +// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a +// legal equivalently-sized i8 type, so we can use that as a go-between. +// Splats of s1 types that have constant value can be legalized as VMSET_VL or +// VMCLR_VL. +bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI, + MachineIRBuilder &MIB) const { + assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR); + + MachineRegisterInfo &MRI = *MIB.getMRI(); + + Register Dst = MI.getOperand(0).getReg(); + Register SplatVal = MI.getOperand(1).getReg(); + + LLT VecTy = MRI.getType(Dst); + LLT XLenTy(STI.getXLenVT()); + + // Handle case of s64 element vectors on rv32 + if (XLenTy.getSizeInBits() == 32 && + VecTy.getElementType().getSizeInBits() == 64) { + auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI); + buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB, + MRI); + MI.eraseFromParent(); + return true; + } + + // All-zeros or all-ones splats are handled specially. + MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal); + if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) { + auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second; + MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL}); + MI.eraseFromParent(); + return true; + } + if (isNullOrNullSplat(SplatValMI, MRI)) { + auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second; + MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL}); + MI.eraseFromParent(); + return true; + } + // Handle non-constant mask splat (i.e. not sure if it's all zeros or all + // ones) by promoting it to an s8 splat. + LLT InterEltTy = LLT::scalar(8); + LLT InterTy = VecTy.changeElementType(InterEltTy); + auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal); + auto And = + MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1)); + auto LHS = MIB.buildSplatVector(InterTy, And); + auto ZeroSplat = + MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0)); + MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat); MI.eraseFromParent(); return true; } @@ -640,6 +804,12 @@ bool RISCVLegalizerInfo::legalizeCustom( return legalizeVAStart(MI, MIRBuilder); case TargetOpcode::G_VSCALE: return legalizeVScale(MI, MIRBuilder); + case TargetOpcode::G_ZEXT: + case TargetOpcode::G_SEXT: + case TargetOpcode::G_ANYEXT: + return legalizeExt(MI, MIRBuilder); + case TargetOpcode::G_SPLAT_VECTOR: + return legalizeSplatVector(MI, MIRBuilder); } llvm_unreachable("expected switch to return"); diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h index e2a98c8..5bb1e7a 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h +++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h @@ -43,6 +43,8 @@ private: bool legalizeVAStart(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const; bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const; + bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const; + bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const; }; } // end namespace llvm #endif diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp index 888bcc4..86e4434 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp @@ -290,16 +290,7 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { switch (Opc) { case TargetOpcode::G_ADD: - case TargetOpcode::G_SUB: { - if (MRI.getType(MI.getOperand(0).getReg()).isVector()) { - LLT Ty = MRI.getType(MI.getOperand(0).getReg()); - return getInstructionMapping( - DefaultMappingID, /*Cost=*/1, - getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue()), - NumOperands); - } - } - LLVM_FALLTHROUGH; + case TargetOpcode::G_SUB: case TargetOpcode::G_SHL: case TargetOpcode::G_ASHR: case TargetOpcode::G_LSHR: @@ -320,14 +311,6 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case TargetOpcode::G_PTR_ADD: case TargetOpcode::G_PTRTOINT: case TargetOpcode::G_INTTOPTR: - case TargetOpcode::G_TRUNC: - case TargetOpcode::G_ANYEXT: - case TargetOpcode::G_SEXT: - case TargetOpcode::G_ZEXT: - case TargetOpcode::G_SEXTLOAD: - case TargetOpcode::G_ZEXTLOAD: - return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping, - NumOperands); case TargetOpcode::G_FADD: case TargetOpcode::G_FSUB: case TargetOpcode::G_FMUL: @@ -338,25 +321,48 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case TargetOpcode::G_FMAXNUM: case TargetOpcode::G_FMINNUM: { LLT Ty = MRI.getType(MI.getOperand(0).getReg()); - return getInstructionMapping(DefaultMappingID, /*Cost=*/1, - getFPValueMapping(Ty.getSizeInBits()), - NumOperands); + TypeSize Size = Ty.getSizeInBits(); + + const ValueMapping *Mapping; + if (Ty.isVector()) + Mapping = getVRBValueMapping(Size.getKnownMinValue()); + else if (isPreISelGenericFloatingPointOpcode(Opc)) + Mapping = getFPValueMapping(Size.getFixedValue()); + else + Mapping = GPRValueMapping; + +#ifndef NDEBUG + // Make sure all the operands are using similar size and type. + for (unsigned Idx = 1; Idx != NumOperands; ++Idx) { + LLT OpTy = MRI.getType(MI.getOperand(Idx).getReg()); + assert(Ty.isVector() == OpTy.isVector() && + "Operand has incompatible type"); + // Don't check size for GPR. + if (OpTy.isVector() || isPreISelGenericFloatingPointOpcode(Opc)) + assert(Size == OpTy.getSizeInBits() && "Operand has incompatible size"); + } +#endif // End NDEBUG + + return getInstructionMapping(DefaultMappingID, 1, Mapping, NumOperands); } + case TargetOpcode::G_SEXTLOAD: + case TargetOpcode::G_ZEXTLOAD: + return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping, + NumOperands); case TargetOpcode::G_IMPLICIT_DEF: { Register Dst = MI.getOperand(0).getReg(); LLT DstTy = MRI.getType(Dst); - uint64_t DstMinSize = DstTy.getSizeInBits().getKnownMinValue(); + unsigned DstMinSize = DstTy.getSizeInBits().getKnownMinValue(); auto Mapping = GPRValueMapping; // FIXME: May need to do a better job determining when to use FPRB. // For example, the look through COPY case: // %0:_(s32) = G_IMPLICIT_DEF // %1:_(s32) = COPY %0 // $f10_d = COPY %1(s32) - if (anyUseOnlyUseFP(Dst, MRI, TRI)) - Mapping = getFPValueMapping(DstMinSize); - if (DstTy.isVector()) Mapping = getVRBValueMapping(DstMinSize); + else if (anyUseOnlyUseFP(Dst, MRI, TRI)) + Mapping = getFPValueMapping(DstMinSize); return getInstructionMapping(DefaultMappingID, /*Cost=*/1, Mapping, NumOperands); @@ -529,7 +535,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { if (!Ty.isValid()) continue; - if (isPreISelGenericFloatingPointOpcode(Opc)) + if (Ty.isVector()) + OpdsMapping[Idx] = + getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue()); + else if (isPreISelGenericFloatingPointOpcode(Opc)) OpdsMapping[Idx] = getFPValueMapping(Ty.getSizeInBits()); else OpdsMapping[Idx] = GPRValueMapping; diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 55ba494..f99dc0b 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -3287,24 +3287,24 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits, } bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) { - // Truncates are custom lowered during legalization. - auto IsTrunc = [this](SDValue N) { - if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL) + auto IsExtOrTrunc = [](SDValue N) { + switch (N->getOpcode()) { + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + // There's no passthru on these _VL nodes so any VL/mask is ok, since any + // inactive elements will be undef. + case RISCVISD::TRUNCATE_VECTOR_VL: + case RISCVISD::VSEXT_VL: + case RISCVISD::VZEXT_VL: + return true; + default: return false; - SDValue VL; - selectVLOp(N->getOperand(2), VL); - // Any vmset_vl is ok, since any bits past VL are undefined and we can - // assume they are set. - return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL && - isa<ConstantSDNode>(VL) && - cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel; + } }; - // We can have multiple nested truncates, so unravel them all if needed. - while (N->getOpcode() == ISD::SIGN_EXTEND || - N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) { - if (!N.hasOneUse() || - N.getValueType().getSizeInBits().getKnownMinValue() < 8) + // We can have multiple nested nodes, so unravel them all if needed. + while (IsExtOrTrunc(N)) { + if (!N.hasOneUse() || N.getScalarValueSizeInBits() < 8) return false; N = N->getOperand(0); } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index ee83f9d..279d8a4 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -21115,12 +21115,10 @@ void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) { RegisterVT.getVectorElementType() == MVT::i1) { RVVArgInfos.push_back({1, RegisterVT, true}); FirstVMaskAssigned = true; - } else { - RVVArgInfos.push_back({1, RegisterVT, false}); + --NumRegs; } - RVVArgInfos.insert(RVVArgInfos.end(), --NumRegs, - {1, RegisterVT, false}); + RVVArgInfos.insert(RVVArgInfos.end(), NumRegs, {1, RegisterVT, false}); } } } diff --git a/llvm/lib/Target/RISCV/RISCVInstrGISel.td b/llvm/lib/Target/RISCV/RISCVInstrGISel.td index 54e22d6..ba40662 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrGISel.td +++ b/llvm/lib/Target/RISCV/RISCVInstrGISel.td @@ -32,3 +32,28 @@ def G_READ_VLENB : RISCVGenericInstruction { let hasSideEffects = false; } def : GINodeEquiv<G_READ_VLENB, riscv_read_vlenb>; + +// Pseudo equivalent to a RISCVISD::VMCLR_VL +def G_VMCLR_VL : RISCVGenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$vl); + let hasSideEffects = false; +} +def : GINodeEquiv<G_VMCLR_VL, riscv_vmclr_vl>; + +// Pseudo equivalent to a RISCVISD::VMSET_VL +def G_VMSET_VL : RISCVGenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type1:$vl); + let hasSideEffects = false; +} +def : GINodeEquiv<G_VMSET_VL, riscv_vmset_vl>; + +// Pseudo equivalent to a RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL. There is no +// record to mark as equivalent to using GINodeEquiv because it gets lowered +// before instruction selection. +def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction { + let OutOperandList = (outs type0:$dst); + let InOperandList = (ins type0:$passthru, type1:$hi, type1:$lo, type2:$vl); + let hasSideEffects = false; +} diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index cc44092..73d52d5 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -387,6 +387,9 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisVT<3, XLenVT>]>; def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>; def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>; +def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C), + [(riscv_sext_vl node:$A, node:$B, node:$C), + (riscv_zext_vl node:$A, node:$B, node:$C)]>; def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL", SDTypeProfile<1, 3, [SDTCisVec<0>, @@ -535,6 +538,11 @@ def riscv_zext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C), return N->hasOneUse(); }]>; +def riscv_ext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C), + (riscv_ext_vl node:$A, node:$B, node:$C), [{ + return N->hasOneUse(); +}]>; + def riscv_fpextend_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C), (riscv_fpextend_vl node:$A, node:$B, node:$C), [{ return N->hasOneUse(); diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td index 51a7a0a1..c1facc79 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td @@ -630,6 +630,19 @@ foreach vtiToWti = AllWidenableIntVectors in { (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; def : Pat<(riscv_shl_vl + (wti.Vector (riscv_zext_vl_oneuse + (vti.Vector vti.RegClass:$rs2), + (vti.Mask V0), VLOpFrag)), + (wti.Vector (riscv_ext_vl_oneuse + (vti.Vector vti.RegClass:$rs1), + (vti.Mask V0), VLOpFrag)), + (wti.Vector wti.RegClass:$merge), + (vti.Mask V0), VLOpFrag), + (!cast<Instruction>("PseudoVWSLL_VV_"#vti.LMul.MX#"_MASK") + wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1, + (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + + def : Pat<(riscv_shl_vl (wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))), (wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))), (wti.Vector wti.RegClass:$merge), @@ -639,6 +652,17 @@ foreach vtiToWti = AllWidenableIntVectors in { (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; def : Pat<(riscv_shl_vl + (wti.Vector (riscv_zext_vl_oneuse + (vti.Vector vti.RegClass:$rs2), + (vti.Mask V0), VLOpFrag)), + (wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))), + (wti.Vector wti.RegClass:$merge), + (vti.Mask V0), VLOpFrag), + (!cast<Instruction>("PseudoVWSLL_VX_"#vti.LMul.MX#"_MASK") + wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1, + (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + + def : Pat<(riscv_shl_vl (wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))), (wti.Vector (SplatPat_uimm5 uimm5:$rs1)), (wti.Vector wti.RegClass:$merge), @@ -647,6 +671,17 @@ foreach vtiToWti = AllWidenableIntVectors in { wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + def : Pat<(riscv_shl_vl + (wti.Vector (riscv_zext_vl_oneuse + (vti.Vector vti.RegClass:$rs2), + (vti.Mask V0), VLOpFrag)), + (wti.Vector (SplatPat_uimm5 uimm5:$rs1)), + (wti.Vector wti.RegClass:$merge), + (vti.Mask V0), VLOpFrag), + (!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK") + wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1, + (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + def : Pat<(riscv_vwsll_vl (vti.Vector vti.RegClass:$rs2), (vti.Vector vti.RegClass:$rs1), diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index ba108912..85f8f5f 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -254,6 +254,7 @@ public: const LegalizerInfo *getLegalizerInfo() const override; const RegisterBankInfo *getRegBankInfo() const override; + bool isTargetAndroid() const { return getTargetTriple().isAndroid(); } bool isTargetFuchsia() const { return getTargetTriple().isOSFuchsia(); } bool useConstantPoolForLargeInts() const; diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 38304ff..aeec063 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -245,6 +245,10 @@ RISCVTTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, return TTI::TCC_Free; } +bool RISCVTTIImpl::hasActiveVectorLength(unsigned, Type *DataTy, Align) const { + return ST->hasVInstructions(); +} + TargetTransformInfo::PopcntSupportKind RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) { assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2"); @@ -861,9 +865,14 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, } // TODO: add more intrinsic case Intrinsic::experimental_stepvector: { - unsigned Cost = 1; // vid auto LT = getTypeLegalizationCost(RetTy); - return Cost + (LT.first - 1); + // Legalisation of illegal types involves an `index' instruction plus + // (LT.first - 1) vector adds. + if (ST->hasVInstructions()) + return getRISCVInstructionCost(RISCV::VID_V, LT.second, CostKind) + + (LT.first - 1) * + getRISCVInstructionCost(RISCV::VADD_VX, LT.second, CostKind); + return 1 + (LT.first - 1); } case Intrinsic::vp_rint: { // RISC-V target uses at least 5 instructions to lower rounding intrinsics. diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index ac32aea..c0169ea 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -78,6 +78,22 @@ public: const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind); + /// \name EVL Support for predicated vectorization. + /// Whether the target supports the %evl parameter of VP intrinsic efficiently + /// in hardware, for the given opcode and type/alignment. (see LLVM Language + /// Reference - "Vector Predication Intrinsics", + /// https://llvm.org/docs/LangRef.html#vector-predication-intrinsics and + /// "IR-level VP intrinsics", + /// https://llvm.org/docs/Proposals/VectorPredication.html#ir-level-vp-intrinsics). + /// \param Opcode the opcode of the instruction checked for predicated version + /// support. + /// \param DataType the type of the instruction with the \p Opcode checked for + /// prediction support. + /// \param Alignment the alignment for memory access operation checked for + /// predicated version support. + bool hasActiveVectorLength(unsigned Opcode, Type *DataType, + Align Alignment) const; + TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth); bool shouldExpandReduction(const IntrinsicInst *II) const; |