Diffstat (limited to 'target/arm/tcg/sme_helper.c')
 -rw-r--r--  target/arm/tcg/sme_helper.c | 153
 1 file changed, 107 insertions(+), 46 deletions(-)
diff --git a/target/arm/tcg/sme_helper.c b/target/arm/tcg/sme_helper.c
index 5a6dd76..de0c6e5 100644
--- a/target/arm/tcg/sme_helper.c
+++ b/target/arm/tcg/sme_helper.c
@@ -22,8 +22,8 @@
 #include "internals.h"
 #include "tcg/tcg-gvec-desc.h"
 #include "exec/helper-proto.h"
-#include "exec/cpu_ldst.h"
-#include "exec/exec-all.h"
+#include "accel/tcg/cpu-ldst.h"
+#include "accel/tcg/helper-retaddr.h"
 #include "qemu/int128.h"
 #include "fpu/softfloat.h"
 #include "vec_internal.h"
@@ -517,6 +517,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
         clr_fn(za, 0, reg_off);
     }
 
+    set_helper_retaddr(ra);
+
     while (reg_off <= reg_last) {
         uint64_t pg = vg[reg_off >> 6];
         do {
@@ -529,6 +531,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
         } while (reg_off <= reg_last && (reg_off & 63));
     }
 
+    clear_helper_retaddr();
+
     /*
      * Use the slow path to manage the cross-page misalignment.
      * But we know this is RAM and cannot trap.
@@ -543,6 +547,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
         reg_last = info.reg_off_last[1];
         host = info.page[1].host;
 
+        set_helper_retaddr(ra);
+
         do {
             uint64_t pg = vg[reg_off >> 6];
             do {
@@ -554,6 +560,8 @@ void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
                 reg_off += esize;
             } while (reg_off & 63);
         } while (reg_off <= reg_last);
+
+        clear_helper_retaddr();
     }
 }
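
A note on the pattern added above, which is mirrored in sme_st1 below: set_helper_retaddr() records the host return address that the helper obtained via GETPC(), so that a fault taken while dereferencing host memory directly (the user-only fast path) can be unwound back to the guest instruction; clear_helper_retaddr() must be reached before leaving the fast path. A minimal sketch of the shape, reusing the names from this file:

    /* Sketch only, not the exact loop body from sme_ld1/sme_st1. */
    set_helper_retaddr(ra);                    /* ra came from GETPC() */
    while (reg_off <= reg_last) {
        host_fn(za, reg_off, host + reg_off);  /* may fault; unwinds via ra */
        reg_off += esize;
    }
    clear_helper_retaddr();
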
@@ -701,6 +709,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
     reg_last = info.reg_off_last[0];
     host = info.page[0].host;
 
+    set_helper_retaddr(ra);
+
     while (reg_off <= reg_last) {
         uint64_t pg = vg[reg_off >> 6];
         do {
@@ -711,6 +721,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
         } while (reg_off <= reg_last && (reg_off & 63));
     }
 
+    clear_helper_retaddr();
+
     /*
      * Use the slow path to manage the cross-page misalignment.
      * But we know this is RAM and cannot trap.
@@ -725,6 +737,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
         reg_last = info.reg_off_last[1];
         host = info.page[1].host;
 
+        set_helper_retaddr(ra);
+
         do {
             uint64_t pg = vg[reg_off >> 6];
             do {
@@ -734,6 +748,8 @@ void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
                 reg_off += 1 << esz;
             } while (reg_off & 63);
         } while (reg_off <= reg_last);
+
+        clear_helper_retaddr();
     }
 }
@@ -888,7 +904,7 @@ void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
 }
 
 void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
-                         void *vpm, void *vst, uint32_t desc)
+                         void *vpm, float_status *fpst_in, uint32_t desc)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
     uint32_t neg = simd_data(desc) << 31;
@@ -900,7 +916,7 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
      * update the cumulative fp exception status. It also produces
      * default nans.
      */
-    fpst = *(float_status *)vst;
+    fpst = *fpst_in;
     set_default_nan_mode(true, &fpst);
 
     for (row = 0; row < oprsz; ) {
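
Passing float_status *fpst_in instead of void *vst moves the cast out of the helper body and into the call interface. Assuming the usual target/arm helper conventions, the matching change in helper-sme.h (not part of this diff) would swap the argument type from ptr to fpst, along these lines:

    /* Assumed form of the corresponding declaration change: */
    DEF_HELPER_FLAGS_7(sme_fmopa_s, TCG_CALL_NO_RWG,
                       void, ptr, ptr, ptr, ptr, ptr, fpst, i32)
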
@@ -930,13 +946,13 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
 }
 
 void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
-                         void *vpm, void *vst, uint32_t desc)
+                         void *vpm, float_status *fpst_in, uint32_t desc)
 {
     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
     uint64_t neg = (uint64_t)simd_data(desc) << 63;
     uint64_t *za = vza, *zn = vzn, *zm = vzm;
     uint8_t *pn = vpn, *pm = vpm;
-    float_status fpst = *(float_status *)vst;
+    float_status fpst = *fpst_in;
 
     set_default_nan_mode(true, &fpst);
@@ -976,12 +992,23 @@ static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
 }
 
 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
-                          float_status *s_std, float_status *s_odd)
+                          float_status *s_f16, float_status *s_std,
+                          float_status *s_odd)
 {
-    float64 e1r = float16_to_float64(e1 & 0xffff, true, s_std);
-    float64 e1c = float16_to_float64(e1 >> 16, true, s_std);
-    float64 e2r = float16_to_float64(e2 & 0xffff, true, s_std);
-    float64 e2c = float16_to_float64(e2 >> 16, true, s_std);
+    /*
+     * We need three different float_status for different parts of this
+     * operation:
+     *  - the input conversion of the float16 values must use the
+     *    f16-specific float_status, so that the FPCR.FZ16 control is applied
+     *  - operations on float32 including the final accumulation must use
+     *    the normal float_status, so that FPCR.FZ is applied
+     *  - we have a pre-set-up copy of s_std which is set to round-to-odd,
+     *    for the multiply (see below)
+     */
+    float64 e1r = float16_to_float64(e1 & 0xffff, true, s_f16);
+    float64 e1c = float16_to_float64(e1 >> 16, true, s_f16);
+    float64 e2r = float16_to_float64(e2 & 0xffff, true, s_f16);
+    float64 e2c = float16_to_float64(e2 >> 16, true, s_f16);
     float64 t64;
     float32 t32;
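
The remainder of f16_dotadd (unchanged, so not shown in this hunk) is where the round-to-odd status comes in: the Arm pseudocode FPDot rounds the two-way dot product only once, so the first product is computed in round-to-odd and the second is folded in with a fused multiply-add that rounds once to float32 precision. Roughly, in terms of QEMU's softfloat API:

    t64 = float64_mul(e1r, e2r, s_odd);                /* product 1, round-to-odd */
    t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);  /* product 2 + add, single rounding */
    t32 = float64_to_float32(t64, s_std);              /* exact: already rounded */
    return float32_add(sum, t32, s_std);               /* final, unfused accumulate */
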
@@ -1003,20 +1030,23 @@ static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
 }
 
 void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
-                         void *vpm, void *vst, uint32_t desc)
+                         void *vpm, CPUARMState *env, uint32_t desc)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
     uint32_t neg = simd_data(desc) * 0x80008000u;
     uint16_t *pn = vpn, *pm = vpm;
-    float_status fpst_odd, fpst_std;
+    float_status fpst_odd, fpst_std, fpst_f16;
 
     /*
-     * Make a copy of float_status because this operation does not
-     * update the cumulative fp exception status. It also produces
-     * default nans. Make a second copy with round-to-odd -- see above.
+     * Make copies of the fp status fields we use, because this operation
+     * does not update the cumulative fp exception status. It also
+     * produces default NaNs. We also need a second copy of fp_status with
+     * round-to-odd -- see above.
      */
-    fpst_std = *(float_status *)vst;
+    fpst_f16 = env->vfp.fp_status[FPST_A64_F16];
+    fpst_std = env->vfp.fp_status[FPST_A64];
     set_default_nan_mode(true, &fpst_std);
+    set_default_nan_mode(true, &fpst_f16);
     fpst_odd = fpst_std;
     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
@@ -1036,7 +1066,8 @@ void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
                         uint32_t m = *(uint32_t *)(vzm + H1_4(col));
 
                         m = f16mop_adj_pair(m, pcol, 0);
-                        *a = f16_dotadd(*a, n, m, &fpst_std, &fpst_odd);
+                        *a = f16_dotadd(*a, n, m,
+                                        &fpst_f16, &fpst_std, &fpst_odd);
                     }
                     col += 4;
                     pcol >>= 4;
@@ -1048,38 +1079,68 @@ void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
     }
 }
 
-void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn,
-                        void *vpm, uint32_t desc)
+void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm,
+                        void *vpn, void *vpm, CPUARMState *env, uint32_t desc)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
     uint32_t neg = simd_data(desc) * 0x80008000u;
     uint16_t *pn = vpn, *pm = vpm;
+    float_status fpst, fpst_odd;
 
-    for (row = 0; row < oprsz; ) {
-        uint16_t prow = pn[H2(row >> 4)];
-        do {
-            void *vza_row = vza + tile_vslice_offset(row);
-            uint32_t n = *(uint32_t *)(vzn + H1_4(row));
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (row = 0; row < oprsz; ) {
+            uint16_t prow = pn[H2(row >> 4)];
+            do {
+                void *vza_row = vza + tile_vslice_offset(row);
+                uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-            n = f16mop_adj_pair(n, prow, neg);
+                n = f16mop_adj_pair(n, prow, neg);
 
-            for (col = 0; col < oprsz; ) {
-                uint16_t pcol = pm[H2(col >> 4)];
-                do {
-                    if (prow & pcol & 0b0101) {
-                        uint32_t *a = vza_row + H1_4(col);
-                        uint32_t m = *(uint32_t *)(vzm + H1_4(col));
+                for (col = 0; col < oprsz; ) {
+                    uint16_t pcol = pm[H2(col >> 4)];
+                    do {
+                        if (prow & pcol & 0b0101) {
+                            uint32_t *a = vza_row + H1_4(col);
+                            uint32_t m = *(uint32_t *)(vzm + H1_4(col));
 
-                        m = f16mop_adj_pair(m, pcol, 0);
-                        *a = bfdotadd(*a, n, m);
-                    }
-                    col += 4;
-                    pcol >>= 4;
-                } while (col & 15);
-            }
-            row += 4;
-            prow >>= 4;
-        } while (row & 15);
+                            m = f16mop_adj_pair(m, pcol, 0);
+                            *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
+                        }
+                        col += 4;
+                        pcol >>= 4;
+                    } while (col & 15);
+                }
+                row += 4;
+                prow >>= 4;
+            } while (row & 15);
+        }
+    } else {
+        for (row = 0; row < oprsz; ) {
+            uint16_t prow = pn[H2(row >> 4)];
+            do {
+                void *vza_row = vza + tile_vslice_offset(row);
+                uint32_t n = *(uint32_t *)(vzn + H1_4(row));
+
+                n = f16mop_adj_pair(n, prow, neg);
+
+                for (col = 0; col < oprsz; ) {
+                    uint16_t pcol = pm[H2(col >> 4)];
+                    do {
+                        if (prow & pcol & 0b0101) {
+                            uint32_t *a = vza_row + H1_4(col);
+                            uint32_t m = *(uint32_t *)(vzm + H1_4(col));
+
+                            m = f16mop_adj_pair(m, pcol, 0);
+                            *a = bfdotadd(*a, n, m, &fpst);
+                        }
+                        col += 4;
+                        pcol >>= 4;
+                    } while (col & 15);
+                }
+                row += 4;
+                prow >>= 4;
+            } while (row & 15);
+        }
     }
 }
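
The two copies of the loop differ only in the dot-product primitive. is_ebf(), shared with the other bfloat16 dot-product helpers via vec_internal.h, inspects FPCR.EBF, fills in the float_status pair, and returns true when the FEAT_EBF16 extended behaviour is selected; bfdotadd_ebf() then honours the live FP controls and uses the round-to-odd status for its intermediate product, while bfdotadd() implements the standard, unfused BFDOT behaviour. Condensed to the essential difference:

    float_status fpst, fpst_odd;
    if (is_ebf(env, &fpst, &fpst_odd)) {
        *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);  /* FPCR.EBF == 1 */
    } else {
        *a = bfdotadd(*a, n, m, &fpst);                 /* standard BFDOT */
    }
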
@@ -1146,10 +1207,10 @@ static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
     uint64_t sum = 0;                                              \
     /* Apply P to N as a mask, making the inactive elements 0. */  \
     n &= expand_pred_h(p);                                         \
-    sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                      \
-    sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                    \
-    sum += (NTYPE)(n >> 32) * (MTYPE)(m >> 32);                    \
-    sum += (NTYPE)(n >> 48) * (MTYPE)(m >> 48);                    \
+    sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0);             \
+    sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16);           \
+    sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32);           \
+    sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48);           \
     return neg ? a - sum : a + sum;                                \
 }
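
Why the casts matter: after the usual arithmetic promotions the products are otherwise computed in (signed) int, and for the unsigned 16-bit element types the product can exceed INT_MAX, which is undefined behaviour in C. Casting the first operand to int64_t forces the multiply into 64 bits (and makes the sign extension explicit for the signed variants). For example:

    /* uint16_t * uint16_t promotes to int and can overflow: */
    uint64_t bad  = (uint16_t)0xffff * (uint16_t)0xffff;           /* UB: 0xfffe0001 > INT_MAX */
    uint64_t good = (int64_t)(uint16_t)0xffff * (uint16_t)0xffff;  /* OK: 64-bit multiply */
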