diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVFeatures.td | 3 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 78 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td | 83 | ||||
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 3 |
4 files changed, 163 insertions, 4 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index 5b72334..0b964c4 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -956,6 +956,9 @@ def FeatureStdExtSsdbltrp def FeatureStdExtSmepmp : RISCVExtension<1, 0, "Enhanced Physical Memory Protection">; +def FeatureStdExtSmpmpmt + : RISCVExperimentalExtension<0, 6, "PMP-based Memory Types Extension">; + def FeatureStdExtSmrnmi : RISCVExtension<1, 0, "Resumable Non-Maskable Interrupts">; def HasStdExtSmrnmi : Predicate<"Subtarget->hasStdExtSmrnmi()">, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1977d33..a3ccbd8 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -87,6 +87,11 @@ static cl::opt<bool> "be combined with a shift"), cl::init(true)); +// TODO: Support more ops +static const unsigned ZvfbfaVPOps[] = {ISD::VP_FNEG, ISD::VP_FABS, + ISD::VP_FCOPYSIGN}; +static const unsigned ZvfbfaOps[] = {ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN}; + RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, const RISCVSubtarget &STI) : TargetLowering(TM), Subtarget(STI) { @@ -1208,6 +1213,61 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, } }; + // Sets common actions for zvfbfa, some of instructions are supported + // natively so that we don't need to promote them. + const auto SetZvfbfaActions = [&](MVT VT) { + setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); + setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT, + Custom); + setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom); + setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom); + setOperationAction({ISD::LROUND, ISD::LLROUND}, VT, Custom); + setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT, + Custom); + setOperationAction(ISD::SELECT_CC, VT, Expand); + setOperationAction({ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, VT, Custom); + setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::CONCAT_VECTORS, + ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR, + ISD::VECTOR_DEINTERLEAVE, ISD::VECTOR_INTERLEAVE, + ISD::VECTOR_REVERSE, ISD::VECTOR_SPLICE, + ISD::VECTOR_COMPRESS}, + VT, Custom); + setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom); + setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom); + + setOperationAction(ISD::FCOPYSIGN, VT, Legal); + setOperationAction(ZvfbfaVPOps, VT, Custom); + + MVT EltVT = VT.getVectorElementType(); + if (isTypeLegal(EltVT)) + setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT, + ISD::EXTRACT_VECTOR_ELT}, + VT, Custom); + else + setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT}, + EltVT, Custom); + setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE, + ISD::MGATHER, ISD::MSCATTER, ISD::VP_LOAD, + ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD, + ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER, + ISD::VP_SCATTER}, + VT, Custom); + setOperationAction(ISD::VP_LOAD_FF, VT, Custom); + + // Expand FP operations that need libcalls. + setOperationAction(FloatingPointLibCallOps, VT, Expand); + + // Custom split nxv32[b]f16 since nxv32[b]f32 is not legal. + if (getLMUL(VT) == RISCVVType::LMUL_8) { + setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom); + setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom); + } else { + MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); + setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT); + setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT); + } + }; + if (Subtarget.hasVInstructionsF16()) { for (MVT VT : F16VecVTs) { if (!isTypeLegal(VT)) @@ -1222,7 +1282,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, } } - if (Subtarget.hasVInstructionsBF16Minimal()) { + if (Subtarget.hasVInstructionsBF16()) { + for (MVT VT : BF16VecVTs) { + if (!isTypeLegal(VT)) + continue; + SetZvfbfaActions(VT); + } + } else if (Subtarget.hasVInstructionsBF16Minimal()) { for (MVT VT : BF16VecVTs) { if (!isTypeLegal(VT)) continue; @@ -1501,6 +1567,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // available. setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom); } + if (Subtarget.hasStdExtZvfbfa()) { + setOperationAction(ZvfbfaOps, VT, Custom); + setOperationAction(ZvfbfaVPOps, VT, Custom); + } setOperationAction( {ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT, Custom); @@ -7245,7 +7315,11 @@ static bool isPromotedOpNeedingSplit(SDValue Op, return (Op.getValueType() == MVT::nxv32f16 && (Subtarget.hasVInstructionsF16Minimal() && !Subtarget.hasVInstructionsF16())) || - Op.getValueType() == MVT::nxv32bf16; + (Op.getValueType() == MVT::nxv32bf16 && + Subtarget.hasVInstructionsBF16Minimal() && + (!Subtarget.hasVInstructionsBF16() || + (!llvm::is_contained(ZvfbfaOps, Op.getOpcode()) && + !llvm::is_contained(ZvfbfaVPOps, Op.getOpcode())))); } static SDValue SplitVectorOp(SDValue Op, SelectionDAG &DAG) { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td index b9c5b75..ffb2ac0 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td @@ -701,5 +701,86 @@ let Predicates = [HasStdExtZvfbfa] in { FRM_DYN, fvti.AVL, fvti.Log2SEW, TA_MA)>; } -} + + foreach vti = AllBF16Vectors in { + // 13.12. Vector Floating-Point Sign-Injection Instructions + def : Pat<(fabs (vti.Vector vti.RegClass:$rs)), + (!cast<Instruction>("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>; + // Handle fneg with VFSGNJN using the same input for both operands. + def : Pat<(fneg (vti.Vector vti.RegClass:$rs)), + (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>; + + def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), + (vti.Vector vti.RegClass:$rs2))), + (!cast<Instruction>("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; + def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), + (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))), + (!cast<Instruction>("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; + + def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), + (vti.Vector (fneg vti.RegClass:$rs2)))), + (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; + def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1), + (vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))), + (!cast<Instruction>("PseudoVFSGNJN_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), + vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>; + + // 13.12. Vector Floating-Point Sign-Injection Instructions + def : Pat<(riscv_fabs_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast<Instruction>("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK") + (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, + vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TA_MA)>; + // Handle fneg with VFSGNJN using the same input for both operands. + def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK") + (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs, + vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TA_MA)>; + + def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), + (vti.Vector vti.RegClass:$rs2), + vti.RegClass:$passthru, + (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast<Instruction>("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK") + vti.RegClass:$passthru, vti.RegClass:$rs1, + vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TAIL_AGNOSTIC)>; + + def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), + (riscv_fneg_vl vti.RegClass:$rs2, + (vti.Mask true_mask), + VLOpFrag), + srcvalue, + (vti.Mask true_mask), + VLOpFrag), + (!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW) + (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, + vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TA_MA)>; + + def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1), + (SplatFPOp vti.ScalarRegClass:$rs2), + vti.RegClass:$passthru, + (vti.Mask VMV0:$vm), + VLOpFrag), + (!cast<Instruction>("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK") + vti.RegClass:$passthru, vti.RegClass:$rs1, + vti.ScalarRegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, + TAIL_AGNOSTIC)>; + } + } } // Predicates = [HasStdExtZvfbfa] diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 332433b..3d8eb40 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1683,7 +1683,8 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, !TypeSize::isKnownLE(DL.getTypeSizeInBits(Src), SrcLT.second.getSizeInBits()) || !TypeSize::isKnownLE(DL.getTypeSizeInBits(Dst), - DstLT.second.getSizeInBits())) + DstLT.second.getSizeInBits()) || + SrcLT.first > 1 || DstLT.first > 1) return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); // The split cost is handled by the base getCastInstrCost |
