diff options
-rw-r--r-- | riscv/insns/vfwmaccbf16_vf.h | 5 | ||||
-rw-r--r-- | riscv/insns/vfwmaccbf16_vv.h | 5 | ||||
-rw-r--r-- | riscv/riscv.mk.in | 5 | ||||
-rw-r--r-- | riscv/v_ext_macros.h | 54 |
4 files changed, 69 insertions, 0 deletions
diff --git a/riscv/insns/vfwmaccbf16_vf.h b/riscv/insns/vfwmaccbf16_vf.h new file mode 100644 index 0000000..2c77b3b --- /dev/null +++ b/riscv/insns/vfwmaccbf16_vf.h @@ -0,0 +1,5 @@ +// vfwmaccbf16.vf vd, vs2, rs1 +VI_VFP_BF16_VF_LOOP_WIDE +({ + vd = f32_mulAdd(rs1, vs2, vd); +}) diff --git a/riscv/insns/vfwmaccbf16_vv.h b/riscv/insns/vfwmaccbf16_vv.h new file mode 100644 index 0000000..bd8f305 --- /dev/null +++ b/riscv/insns/vfwmaccbf16_vv.h @@ -0,0 +1,5 @@ +// vfwmaccbf16.vv vd, vs2, vs1 +VI_VFP_BF16_VV_LOOP_WIDE +({ + vd = f32_mulAdd(vs1, vs2, vd); +}) diff --git a/riscv/riscv.mk.in b/riscv/riscv.mk.in index a83bec2..1cfe627 100644 --- a/riscv/riscv.mk.in +++ b/riscv/riscv.mk.in @@ -1367,9 +1367,14 @@ riscv_insn_ext_zvfbfmin = \ vfncvtbf16_f_f_w \ vfwcvtbf16_f_f_v \ +riscv_insn_ext_zvfbfwma = \ + vfwmaccbf16_vv \ + vfwmaccbf16_vf \ + riscv_insn_ext_bf16 = \ $(riscv_insn_ext_zfbfmin) \ $(riscv_insn_ext_zvfbfmin) \ + $(riscv_insn_ext_zvfbfwma) \ riscv_insn_list = \ $(riscv_insn_ext_a) \ diff --git a/riscv/v_ext_macros.h b/riscv/v_ext_macros.h index 376c330..41256c7 100644 --- a/riscv/v_ext_macros.h +++ b/riscv/v_ext_macros.h @@ -1488,11 +1488,27 @@ reg_t index[P.VU.vlmax]; \ reg_t UNUSED rs2_num = insn.rs2(); \ softfloat_roundingMode = STATE.frm->read(); +#define VI_VFP_BF16_COMMON \ + require_fp; \ + require((P.VU.vsew == e16 && p->extension_enabled(EXT_ZVFBFWMA))); \ + require_vector(true); \ + require(STATE.frm->read() < 0x5); \ + reg_t UNUSED vl = P.VU.vl->read(); \ + reg_t UNUSED rd_num = insn.rd(); \ + reg_t UNUSED rs1_num = insn.rs1(); \ + reg_t UNUSED rs2_num = insn.rs2(); \ + softfloat_roundingMode = STATE.frm->read(); + #define VI_VFP_LOOP_BASE \ VI_VFP_COMMON \ for (reg_t i = P.VU.vstart->read(); i < vl; ++i) { \ VI_LOOP_ELEMENT_SKIP(); +#define VI_VFP_BF16_LOOP_BASE \ + VI_VFP_BF16_COMMON \ + for (reg_t i = P.VU.vstart->read(); i < vl; ++i) { \ + VI_LOOP_ELEMENT_SKIP(); + #define VI_VFP_LOOP_CMP_BASE \ VI_VFP_COMMON \ for (reg_t i = P.VU.vstart->read(); i < vl; ++i) { \ @@ -1818,6 +1834,25 @@ reg_t index[P.VU.vlmax]; \ DEBUG_RVV_FP_VV; \ VI_VFP_LOOP_END +#define VI_VFP_BF16_VF_LOOP_WIDE(BODY) \ + VI_CHECK_DSS(false); \ + VI_VFP_BF16_LOOP_BASE \ + switch (P.VU.vsew) { \ + case e16: { \ + float32_t &vd = P.VU.elt<float32_t>(rd_num, i, true); \ + float32_t vs2 = bf16_to_f32(P.VU.elt<bfloat16_t>(rs2_num, i)); \ + float32_t rs1 = bf16_to_f32(FRS1_BF); \ + BODY; \ + set_fp_exceptions; \ + break; \ + } \ + default: \ + require(0); \ + break; \ + }; \ + DEBUG_RVV_FP_VV; \ + VI_VFP_LOOP_END + #define VI_VFP_VV_LOOP_WIDE(BODY16, BODY32) \ VI_CHECK_DSS(true); \ VI_VFP_LOOP_BASE \ @@ -1845,6 +1880,25 @@ reg_t index[P.VU.vlmax]; \ DEBUG_RVV_FP_VV; \ VI_VFP_LOOP_END +#define VI_VFP_BF16_VV_LOOP_WIDE(BODY) \ + VI_CHECK_DSS(true); \ + VI_VFP_BF16_LOOP_BASE \ + switch (P.VU.vsew) { \ + case e16: { \ + float32_t &vd = P.VU.elt<float32_t>(rd_num, i, true); \ + float32_t vs2 = bf16_to_f32(P.VU.elt<bfloat16_t>(rs2_num, i)); \ + float32_t vs1 = bf16_to_f32(P.VU.elt<bfloat16_t>(rs1_num, i)); \ + BODY; \ + set_fp_exceptions; \ + break; \ + } \ + default: \ + require(0); \ + break; \ + }; \ + DEBUG_RVV_FP_VV; \ + VI_VFP_LOOP_END + #define VI_VFP_WF_LOOP_WIDE(BODY16, BODY32) \ VI_CHECK_DDS(false); \ VI_VFP_LOOP_BASE \ |