diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64.td | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64Features.td | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.td | 6 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 30 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/SVEInstrFormats.td | 36 |
6 files changed, 66 insertions, 16 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64.td b/llvm/lib/Target/AArch64/AArch64.td index 4634653..a4529a5 100644 --- a/llvm/lib/Target/AArch64/AArch64.td +++ b/llvm/lib/Target/AArch64/AArch64.td @@ -74,7 +74,7 @@ def SVEUnsupported : AArch64Unsupported { } def SME2p3Unsupported : AArch64Unsupported { - let F = [HasSVE2p3_or_SME2p3]; + let F = [HasSVE2p3_or_SME2p3, HasSVE_B16MM]; } def SME2p2Unsupported : AArch64Unsupported { diff --git a/llvm/lib/Target/AArch64/AArch64Features.td b/llvm/lib/Target/AArch64/AArch64Features.td index cd7a209..11e476f 100644 --- a/llvm/lib/Target/AArch64/AArch64Features.td +++ b/llvm/lib/Target/AArch64/AArch64Features.td @@ -613,6 +613,12 @@ def FeatureSVE2p3 : ExtensionWithMArch<"sve2p3", "SVE2p3", "FEAT_SVE2p3", def FeatureSME2p3 : ExtensionWithMArch<"sme2p3", "SME2p3", "FEAT_SME2p3", "Enable Armv9.7-A Scalable Matrix Extension 2.3 instructions", [FeatureSME2p2]>; +def FeatureSVE_B16MM : ExtensionWithMArch<"sve-b16mm", "SVE_B16MM", "FEAT_SVE_B16MM", + "Enable Armv9.7-A SVE non-widening BFloat16 matrix multiply-accumulate", [FeatureSVE]>; + +def FeatureF16MM : ExtensionWithMArch<"f16mm", "F16MM", "FEAT_F16MM", + "Enable Armv9.7-A non-widening half-precision matrix multiply-accumulate", [FeatureFullFP16]>; + //===----------------------------------------------------------------------===// // Other Features //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 395d124..7d23d42 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -246,6 +246,12 @@ def HasCMH : Predicate<"Subtarget->hasCMH()">, AssemblerPredicateWithAll<(all_of FeatureCMH), "cmh">; def HasLSCP : Predicate<"Subtarget->hasLSCP()">, AssemblerPredicateWithAll<(all_of FeatureLSCP), "lscp">; +def HasSVE2p2 : Predicate<"Subtarget->hasSVE2p2()">, + AssemblerPredicateWithAll<(all_of FeatureSVE2p2), "sve2p2">; +def HasSVE_B16MM : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE_B16MM()">, + AssemblerPredicateWithAll<(all_of FeatureSVE_B16MM), "sve-b16mm">; +def HasF16MM : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasF16MM()">, + AssemblerPredicateWithAll<(all_of FeatureF16MM), "f16mm">; // A subset of SVE(2) instructions are legal in Streaming SVE execution mode, // they should be enabled if either has been specified. diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index d8b4903..545493f 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -2569,7 +2569,7 @@ let Predicates = [HasBF16, HasSVE_or_SME] in { } // End HasBF16, HasSVE_or_SME let Predicates = [HasBF16, HasSVE] in { - defm BFMMLA_ZZZ_HtoS : sve_fp_matrix_mla<0b01, "bfmmla", ZPR32, ZPR16, int_aarch64_sve_bfmmla, nxv4f32, nxv8bf16>; + defm BFMMLA_ZZZ_HtoS : sve_fp_matrix_mla<0b011, "bfmmla", ZPR32, ZPR16, int_aarch64_sve_bfmmla, nxv4f32, nxv8bf16>; } // End HasBF16, HasSVE let Predicates = [HasBF16, HasSVE_or_SME] in { @@ -3680,15 +3680,15 @@ let Predicates = [HasSVE_or_SME, HasMatMulInt8] in { } // End HasSVE_or_SME, HasMatMulInt8 let Predicates = [HasSVE, HasMatMulFP32] in { - defm FMMLA_ZZZ_S : sve_fp_matrix_mla<0b10, "fmmla", ZPR32, ZPR32, int_aarch64_sve_fmmla, nxv4f32, nxv4f32>; + defm FMMLA_ZZZ_S : sve_fp_matrix_mla<0b101, "fmmla", ZPR32, ZPR32, int_aarch64_sve_fmmla, nxv4f32, nxv4f32>; } // End HasSVE, HasMatMulFP32 let Predicates = [HasSVE_F16F32MM] in { - def FMLLA_ZZZ_HtoS : sve_fp_matrix_mla<0b00, "fmmla", ZPR32, ZPR16>; + def FMLLA_ZZZ_HtoS : sve_fp_matrix_mla<0b001, "fmmla", ZPR32, ZPR16>; } // End HasSVE_F16F32MM let Predicates = [HasSVE, HasMatMulFP64] in { - defm FMMLA_ZZZ_D : sve_fp_matrix_mla<0b11, "fmmla", ZPR64, ZPR64, int_aarch64_sve_fmmla, nxv2f64, nxv2f64>; + defm FMMLA_ZZZ_D : sve_fp_matrix_mla<0b111, "fmmla", ZPR64, ZPR64, int_aarch64_sve_fmmla, nxv2f64, nxv2f64>; defm LD1RO_B_IMM : sve_mem_ldor_si<0b00, "ld1rob", Z_b, ZPR8, nxv16i8, nxv16i1, AArch64ld1ro_z>; defm LD1RO_H_IMM : sve_mem_ldor_si<0b01, "ld1roh", Z_h, ZPR16, nxv8i16, nxv8i1, AArch64ld1ro_z>; defm LD1RO_W_IMM : sve_mem_ldor_si<0b10, "ld1row", Z_s, ZPR32, nxv4i32, nxv4i1, AArch64ld1ro_z>; @@ -4631,9 +4631,31 @@ let Predicates = [HasSVE2p3_or_SME2p3] in { defm SABAL_ZZZ : sve2_int_two_way_absdiff_accum_long<0b0, "sabal">; defm UABAL_ZZZ : sve2_int_two_way_absdiff_accum_long<0b1, "uabal">; + // SVE2 integer dot product + def SDOT_ZZZ_BtoH : sve_intx_dot<0b01, 0b00000, 0b0, "sdot", ZPR16, ZPR8>; + def UDOT_ZZZ_BtoH : sve_intx_dot<0b01, 0b00000, 0b1, "udot", ZPR16, ZPR8>; + + // SVE2 integer indexed dot product + def SDOT_ZZZI_BtoH : sve_intx_dot_by_indexed_elem_x<0b0, "sdot">; + def UDOT_ZZZI_BtoH : sve_intx_dot_by_indexed_elem_x<0b1, "udot">; + } // End HasSME2p3orSVE2p3 //===----------------------------------------------------------------------===// +// SVE_B16MM Instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSVE_B16MM] in { + def BFMMLA_ZZZ_H : sve_fp_matrix_mla<0b110, "bfmmla", ZPR16, ZPR16>; +} + +//===----------------------------------------------------------------------===// +// F16MM Instructions +//===----------------------------------------------------------------------===// +let Predicates = [HasSVE2p2, HasF16MM] in { + def FMMLA_ZZZ_H : sve_fp_matrix_mla<0b100, "fmmla", ZPR16, ZPR16>; +} + +//===----------------------------------------------------------------------===// // SME2.2 or SVE2.2 instructions - Legal in streaming mode iff target has SME2p2 //===----------------------------------------------------------------------===// let Predicates = [HasNonStreamingSVE2p2_or_SME2p2] in { diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp index babfdc3..c2f7323 100644 --- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp +++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp @@ -3892,6 +3892,8 @@ static const struct Extension { {"gcie", {AArch64::FeatureGCIE}}, {"sme2p3", {AArch64::FeatureSME2p3}}, {"sve2p3", {AArch64::FeatureSVE2p3}}, + {"sve-b16mm", {AArch64::FeatureSVE_B16MM}}, + {"f16mm", {AArch64::FeatureF16MM}}, }; static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) { diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 337c6b6..290a869 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -3787,7 +3787,7 @@ multiclass sve2p1_two_way_dot_vv<string mnemonic, bit u, SDPatternOperator intri // SVE Integer Dot Product Group - Indexed Group //===----------------------------------------------------------------------===// -class sve_intx_dot_by_indexed_elem<bit sz, bit U, string asm, +class sve_intx_dot_by_indexed_elem<bit U, string asm, ZPRRegOp zprty1, ZPRRegOp zprty2, ZPRRegOp zprty3, Operand itype> : I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty3:$Zm, itype:$iop), @@ -3795,8 +3795,7 @@ class sve_intx_dot_by_indexed_elem<bit sz, bit U, string asm, "", []>, Sched<[]> { bits<5> Zda; bits<5> Zn; - let Inst{31-23} = 0b010001001; - let Inst{22} = sz; + let Inst{31-24} = 0b01000100; let Inst{21} = 0b1; let Inst{15-11} = 0; let Inst{10} = U; @@ -3810,16 +3809,18 @@ class sve_intx_dot_by_indexed_elem<bit sz, bit U, string asm, multiclass sve_intx_dot_by_indexed_elem<bit opc, string asm, SDPatternOperator op> { - def _BtoS : sve_intx_dot_by_indexed_elem<0b0, opc, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS32b_timm> { + def _BtoS : sve_intx_dot_by_indexed_elem<opc, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS32b_timm> { bits<2> iop; bits<3> Zm; + let Inst{23-22} = 0b10; let Inst{20-19} = iop; let Inst{18-16} = Zm; } - def _HtoD : sve_intx_dot_by_indexed_elem<0b1, opc, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD32b_timm> { + def _HtoD : sve_intx_dot_by_indexed_elem<opc, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD32b_timm> { bits<1> iop; bits<4> Zm; - let Inst{20} = iop; + let Inst{23-22} = 0b11; + let Inst{20} = iop; let Inst{19-16} = Zm; } @@ -3827,6 +3828,16 @@ multiclass sve_intx_dot_by_indexed_elem<bit opc, string asm, def : SVE_4_Op_Imm_Pat<nxv2i64, op, nxv2i64, nxv8i16, nxv8i16, i32, VectorIndexD32b_timm, !cast<Instruction>(NAME # _HtoD)>; } +class sve_intx_dot_by_indexed_elem_x<bit opc, string asm> +: sve_intx_dot_by_indexed_elem<opc, asm, ZPR16, ZPR8, ZPR3b8, VectorIndexH32b_timm> { + bits<3> iop; + bits<3> Zm; + let Inst{23} = 0b0; + let Inst{22} = iop{2}; + let Inst{20-19} = iop{1-0}; + let Inst{18-16} = Zm; +} + //===----------------------------------------------------------------------===// // SVE2 Complex Integer Dot Product Group //===----------------------------------------------------------------------===// @@ -9616,17 +9627,18 @@ multiclass sve_int_dot_mixed_indexed<bit U, string asm, SDPatternOperator op> { // SVE Floating Point Matrix Multiply Accumulate Group //===----------------------------------------------------------------------===// -class sve_fp_matrix_mla<bits<2> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_ty> +class sve_fp_matrix_mla<bits<3> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_ty> : I<(outs zda_ty:$Zda), (ins zda_ty:$_Zda, reg_ty:$Zn, reg_ty:$Zm), asm, "\t$Zda, $Zn, $Zm", "", []>, Sched<[]> { bits<5> Zda; bits<5> Zn; bits<5> Zm; let Inst{31-24} = 0b01100100; - let Inst{23-22} = opc; + let Inst{23-22} = opc{2-1}; let Inst{21} = 1; let Inst{20-16} = Zm; - let Inst{15-10} = 0b111001; + let Inst{15-11} = 0b11100; + let Inst{10} = opc{0}; let Inst{9-5} = Zn; let Inst{4-0} = Zda; @@ -9636,10 +9648,12 @@ class sve_fp_matrix_mla<bits<2> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_t let mayRaiseFPException = 1; } -multiclass sve_fp_matrix_mla<bits<2> opc, string asm, ZPRRegOp zda_ty, ZPRRegOp reg_ty, SDPatternOperator op, ValueType zda_vt, ValueType reg_vt> { +multiclass sve_fp_matrix_mla<bits<3> opc, string asm, ZPRRegOp zda_ty, + ZPRRegOp reg_ty, SDPatternOperator op, + ValueType zda_vt, ValueType reg_vt> { def NAME : sve_fp_matrix_mla<opc, asm, zda_ty, reg_ty>; - def : SVE_3_Op_Pat<zda_vt, op , zda_vt, reg_vt, reg_vt, !cast<Instruction>(NAME)>; + def : SVE_3_Op_Pat<zda_vt, op, zda_vt, reg_vt, reg_vt, !cast<Instruction>(NAME)>; } //===----------------------------------------------------------------------===// |
