diff options
Diffstat (limited to 'target/arm/tcg/sme_helper.c')
-rw-r--r-- | target/arm/tcg/sme_helper.c | 153 |
1 files changed, 107 insertions, 46 deletions
diff --git a/target/arm/tcg/sme_helper.c b/target/arm/tcg/sme_helper.c index 5a6dd76..de0c6e5 100644 --- a/target/arm/tcg/sme_helper.c +++ b/target/arm/tcg/sme_helper.c @@ -22,8 +22,8 @@ #include "internals.h" #include "tcg/tcg-gvec-desc.h" #include "exec/helper-proto.h" -#include "exec/cpu_ldst.h" -#include "exec/exec-all.h" +#include "accel/tcg/cpu-ldst.h" +#include "accel/tcg/helper-retaddr.h" #include "qemu/int128.h" #include "fpu/softfloat.h" #include "vec_internal.h" @@ -517,6 +517,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg, clr_fn(za, 0, reg_off); } + set_helper_retaddr(ra); + while (reg_off <= reg_last) { uint64_t pg = vg[reg_off >> 6]; do { @@ -529,6 +531,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg, } while (reg_off <= reg_last && (reg_off & 63)); } + clear_helper_retaddr(); + /* * Use the slow path to manage the cross-page misalignment. * But we know this is RAM and cannot trap. @@ -543,6 +547,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg, reg_last = info.reg_off_last[1]; host = info.page[1].host; + set_helper_retaddr(ra); + do { uint64_t pg = vg[reg_off >> 6]; do { @@ -554,6 +560,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg, reg_off += esize; } while (reg_off & 63); } while (reg_off <= reg_last); + + clear_helper_retaddr(); } } @@ -701,6 +709,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg, reg_last = info.reg_off_last[0]; host = info.page[0].host; + set_helper_retaddr(ra); + while (reg_off <= reg_last) { uint64_t pg = vg[reg_off >> 6]; do { @@ -711,6 +721,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg, } while (reg_off <= reg_last && (reg_off & 63)); } + clear_helper_retaddr(); + /* * Use the slow path to manage the cross-page misalignment. * But we know this is RAM and cannot trap. @@ -725,6 +737,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg, reg_last = info.reg_off_last[1]; host = info.page[1].host; + set_helper_retaddr(ra); + do { uint64_t pg = vg[reg_off >> 6]; do { @@ -734,6 +748,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg, reg_off += 1 << esz; } while (reg_off & 63); } while (reg_off <= reg_last); + + clear_helper_retaddr(); } } @@ -888,7 +904,7 @@ void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn, } void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn, - void *vpm, void *vst, uint32_t desc) + void *vpm, float_status *fpst_in, uint32_t desc) { intptr_t row, col, oprsz = simd_maxsz(desc); uint32_t neg = simd_data(desc) << 31; @@ -900,7 +916,7 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn, * update the cumulative fp exception status. It also produces * default nans. */ - fpst = *(float_status *)vst; + fpst = *fpst_in; set_default_nan_mode(true, &fpst); for (row = 0; row < oprsz; ) { @@ -930,13 +946,13 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn, } void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn, - void *vpm, void *vst, uint32_t desc) + void *vpm, float_status *fpst_in, uint32_t desc) { intptr_t row, col, oprsz = simd_oprsz(desc) / 8; uint64_t neg = (uint64_t)simd_data(desc) << 63; uint64_t *za = vza, *zn = vzn, *zm = vzm; uint8_t *pn = vpn, *pm = vpm; - float_status fpst = *(float_status *)vst; + float_status fpst = *fpst_in; set_default_nan_mode(true, &fpst); @@ -976,12 +992,23 @@ static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg) } static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2, - float_status *s_std, float_status *s_odd) + float_status *s_f16, float_status *s_std, + float_status *s_odd) { - float64 e1r = float16_to_float64(e1 & 0xffff, true, s_std); - float64 e1c = float16_to_float64(e1 >> 16, true, s_std); - float64 e2r = float16_to_float64(e2 & 0xffff, true, s_std); - float64 e2c = float16_to_float64(e2 >> 16, true, s_std); + /* + * We need three different float_status for different parts of this + * operation: + * - the input conversion of the float16 values must use the + * f16-specific float_status, so that the FPCR.FZ16 control is applied + * - operations on float32 including the final accumulation must use + * the normal float_status, so that FPCR.FZ is applied + * - we have pre-set-up copy of s_std which is set to round-to-odd, + * for the multiply (see below) + */ + float64 e1r = float16_to_float64(e1 & 0xffff, true, s_f16); + float64 e1c = float16_to_float64(e1 >> 16, true, s_f16); + float64 e2r = float16_to_float64(e2 & 0xffff, true, s_f16); + float64 e2c = float16_to_float64(e2 >> 16, true, s_f16); float64 t64; float32 t32; @@ -1003,20 +1030,23 @@ static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2, } void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn, - void *vpm, void *vst, uint32_t desc) + void *vpm, CPUARMState *env, uint32_t desc) { intptr_t row, col, oprsz = simd_maxsz(desc); uint32_t neg = simd_data(desc) * 0x80008000u; uint16_t *pn = vpn, *pm = vpm; - float_status fpst_odd, fpst_std; + float_status fpst_odd, fpst_std, fpst_f16; /* - * Make a copy of float_status because this operation does not - * update the cumulative fp exception status. It also produces - * default nans. Make a second copy with round-to-odd -- see above. + * Make copies of the fp status fields we use, because this operation + * does not update the cumulative fp exception status. It also + * produces default NaNs. We also need a second copy of fp_status with + * round-to-odd -- see above. */ - fpst_std = *(float_status *)vst; + fpst_f16 = env->vfp.fp_status[FPST_A64_F16]; + fpst_std = env->vfp.fp_status[FPST_A64]; set_default_nan_mode(true, &fpst_std); + set_default_nan_mode(true, &fpst_f16); fpst_odd = fpst_std; set_float_rounding_mode(float_round_to_odd, &fpst_odd); @@ -1036,7 +1066,8 @@ void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn, uint32_t m = *(uint32_t *)(vzm + H1_4(col)); m = f16mop_adj_pair(m, pcol, 0); - *a = f16_dotadd(*a, n, m, &fpst_std, &fpst_odd); + *a = f16_dotadd(*a, n, m, + &fpst_f16, &fpst_std, &fpst_odd); } col += 4; pcol >>= 4; @@ -1048,38 +1079,68 @@ void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn, } } -void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn, - void *vpm, uint32_t desc) +void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, + void *vpn, void *vpm, CPUARMState *env, uint32_t desc) { intptr_t row, col, oprsz = simd_maxsz(desc); uint32_t neg = simd_data(desc) * 0x80008000u; uint16_t *pn = vpn, *pm = vpm; + float_status fpst, fpst_odd; - for (row = 0; row < oprsz; ) { - uint16_t prow = pn[H2(row >> 4)]; - do { - void *vza_row = vza + tile_vslice_offset(row); - uint32_t n = *(uint32_t *)(vzn + H1_4(row)); + if (is_ebf(env, &fpst, &fpst_odd)) { + for (row = 0; row < oprsz; ) { + uint16_t prow = pn[H2(row >> 4)]; + do { + void *vza_row = vza + tile_vslice_offset(row); + uint32_t n = *(uint32_t *)(vzn + H1_4(row)); - n = f16mop_adj_pair(n, prow, neg); + n = f16mop_adj_pair(n, prow, neg); - for (col = 0; col < oprsz; ) { - uint16_t pcol = pm[H2(col >> 4)]; - do { - if (prow & pcol & 0b0101) { - uint32_t *a = vza_row + H1_4(col); - uint32_t m = *(uint32_t *)(vzm + H1_4(col)); + for (col = 0; col < oprsz; ) { + uint16_t pcol = pm[H2(col >> 4)]; + do { + if (prow & pcol & 0b0101) { + uint32_t *a = vza_row + H1_4(col); + uint32_t m = *(uint32_t *)(vzm + H1_4(col)); - m = f16mop_adj_pair(m, pcol, 0); - *a = bfdotadd(*a, n, m); - } - col += 4; - pcol >>= 4; - } while (col & 15); - } - row += 4; - prow >>= 4; - } while (row & 15); + m = f16mop_adj_pair(m, pcol, 0); + *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd); + } + col += 4; + pcol >>= 4; + } while (col & 15); + } + row += 4; + prow >>= 4; + } while (row & 15); + } + } else { + for (row = 0; row < oprsz; ) { + uint16_t prow = pn[H2(row >> 4)]; + do { + void *vza_row = vza + tile_vslice_offset(row); + uint32_t n = *(uint32_t *)(vzn + H1_4(row)); + + n = f16mop_adj_pair(n, prow, neg); + + for (col = 0; col < oprsz; ) { + uint16_t pcol = pm[H2(col >> 4)]; + do { + if (prow & pcol & 0b0101) { + uint32_t *a = vza_row + H1_4(col); + uint32_t m = *(uint32_t *)(vzm + H1_4(col)); + + m = f16mop_adj_pair(m, pcol, 0); + *a = bfdotadd(*a, n, m, &fpst); + } + col += 4; + pcol >>= 4; + } while (col & 15); + } + row += 4; + prow >>= 4; + } while (row & 15); + } } } @@ -1146,10 +1207,10 @@ static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \ uint64_t sum = 0; \ /* Apply P to N as a mask, making the inactive elements 0. */ \ n &= expand_pred_h(p); \ - sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0); \ - sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16); \ - sum += (NTYPE)(n >> 32) * (MTYPE)(m >> 32); \ - sum += (NTYPE)(n >> 48) * (MTYPE)(m >> 48); \ + sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0); \ + sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16); \ + sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32); \ + sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48); \ return neg ? a - sum : a + sum; \ } |