diff options
Diffstat (limited to 'target/arm/tcg/vec_helper.c')
-rw-r--r-- | target/arm/tcg/vec_helper.c | 147 |
1 files changed, 106 insertions, 41 deletions
diff --git a/target/arm/tcg/vec_helper.c b/target/arm/tcg/vec_helper.c index 616ec54..b0de74b 100644 --- a/target/arm/tcg/vec_helper.c +++ b/target/arm/tcg/vec_helper.c @@ -2790,39 +2790,58 @@ DO_MMLA_B(gvec_usmmla_b, do_usmmla_b) * BFloat16 Dot Product */ -float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2) +bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp) { /* FPCR is ignored for BFDOT and BFMMLA. */ - float_status bf_status = { + *statusp = (float_status){ .tininess_before_rounding = float_tininess_before_rounding, .float_rounding_mode = float_round_to_odd_inf, .flush_to_zero = true, .flush_inputs_to_zero = true, .default_nan_mode = true, }; + + return false; +} + +float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst) +{ float32 t1, t2; /* * Extract each BFloat16 from the element pair, and shift * them such that they become float32. */ - t1 = float32_mul(e1 << 16, e2 << 16, &bf_status); - t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, &bf_status); - t1 = float32_add(t1, t2, &bf_status); - t1 = float32_add(sum, t1, &bf_status); + t1 = float32_mul(e1 << 16, e2 << 16, fpst); + t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, fpst); + t1 = float32_add(t1, t2, fpst); + t1 = float32_add(sum, t1, fpst); return t1; } +float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2, + float_status *fpst, float_status *fpst_odd) +{ + g_assert_not_reached(); +} + void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va, CPUARMState *env, uint32_t desc) { intptr_t i, opr_sz = simd_oprsz(desc); float32 *d = vd, *a = va; uint32_t *n = vn, *m = vm; + float_status fpst, fpst_odd; - for (i = 0; i < opr_sz / 4; ++i) { - d[i] = bfdotadd(a[i], n[i], m[i]); + if (is_ebf(env, &fpst, &fpst_odd)) { + for (i = 0; i < opr_sz / 4; ++i) { + d[i] = bfdotadd_ebf(a[i], n[i], m[i], &fpst, &fpst_odd); + } + } else { + for (i = 0; i < opr_sz / 4; ++i) { + d[i] = bfdotadd(a[i], n[i], m[i], &fpst); + } } clear_tail(d, opr_sz, simd_maxsz(desc)); } @@ -2836,12 +2855,23 @@ void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm, intptr_t eltspersegment = MIN(16 / 4, elements); float32 *d = vd, *a = va; uint32_t *n = vn, *m = vm; + float_status fpst, fpst_odd; - for (i = 0; i < elements; i += eltspersegment) { - uint32_t m_idx = m[i + H4(index)]; + if (is_ebf(env, &fpst, &fpst_odd)) { + for (i = 0; i < elements; i += eltspersegment) { + uint32_t m_idx = m[i + H4(index)]; - for (j = i; j < i + eltspersegment; j++) { - d[j] = bfdotadd(a[j], n[j], m_idx); + for (j = i; j < i + eltspersegment; j++) { + d[j] = bfdotadd_ebf(a[j], n[j], m_idx, &fpst, &fpst_odd); + } + } + } else { + for (i = 0; i < elements; i += eltspersegment) { + uint32_t m_idx = m[i + H4(index)]; + + for (j = i; j < i + eltspersegment; j++) { + d[j] = bfdotadd(a[j], n[j], m_idx, &fpst); + } } } clear_tail(d, opr_sz, simd_maxsz(desc)); @@ -2853,37 +2883,72 @@ void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va, intptr_t s, opr_sz = simd_oprsz(desc); float32 *d = vd, *a = va; uint32_t *n = vn, *m = vm; + float_status fpst, fpst_odd; - for (s = 0; s < opr_sz / 4; s += 4) { - float32 sum00, sum01, sum10, sum11; + if (is_ebf(env, &fpst, &fpst_odd)) { + for (s = 0; s < opr_sz / 4; s += 4) { + float32 sum00, sum01, sum10, sum11; - /* - * Process the entire segment at once, writing back the - * results only after we've consumed all of the inputs. - * - * Key to indices by column: - * i j i k j k - */ - sum00 = a[s + H4(0 + 0)]; - sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]); - sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]); - - sum01 = a[s + H4(0 + 1)]; - sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]); - sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]); - - sum10 = a[s + H4(2 + 0)]; - sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]); - sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]); - - sum11 = a[s + H4(2 + 1)]; - sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]); - sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]); - - d[s + H4(0 + 0)] = sum00; - d[s + H4(0 + 1)] = sum01; - d[s + H4(2 + 0)] = sum10; - d[s + H4(2 + 1)] = sum11; + /* + * Process the entire segment at once, writing back the + * results only after we've consumed all of the inputs. + * + * Key to indices by column: + * i j i k j k + */ + sum00 = a[s + H4(0 + 0)]; + sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd); + sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd); + + sum01 = a[s + H4(0 + 1)]; + sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd); + sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd); + + sum10 = a[s + H4(2 + 0)]; + sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd); + sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd); + + sum11 = a[s + H4(2 + 1)]; + sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd); + sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd); + + d[s + H4(0 + 0)] = sum00; + d[s + H4(0 + 1)] = sum01; + d[s + H4(2 + 0)] = sum10; + d[s + H4(2 + 1)] = sum11; + } + } else { + for (s = 0; s < opr_sz / 4; s += 4) { + float32 sum00, sum01, sum10, sum11; + + /* + * Process the entire segment at once, writing back the + * results only after we've consumed all of the inputs. + * + * Key to indices by column: + * i j i k j k + */ + sum00 = a[s + H4(0 + 0)]; + sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst); + sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst); + + sum01 = a[s + H4(0 + 1)]; + sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst); + sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst); + + sum10 = a[s + H4(2 + 0)]; + sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst); + sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst); + + sum11 = a[s + H4(2 + 1)]; + sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst); + sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst); + + d[s + H4(0 + 0)] = sum00; + d[s + H4(0 + 1)] = sum01; + d[s + H4(2 + 0)] = sum10; + d[s + H4(2 + 1)] = sum11; + } } clear_tail(d, opr_sz, simd_maxsz(desc)); } |