diff options
author | Levy Hsu <admin@levyhsu.com> | 2024-09-25 14:32:35 +1100 |
---|---|---|
committer | Levy Hsu <admin@levyhsu.com> | 2024-10-10 01:54:32 +0000 |
commit | 8718727509b2d038d00afa3bd5ef8e0df216a287 (patch) | |
tree | 727f0972a10bf47ff00ed79eec5dab5886185a83 /gcc | |
parent | 00ede02bc8bb73da8f4bf1d7542142cd923b0c54 (diff) | |
download | gcc-8718727509b2d038d00afa3bd5ef8e0df216a287.zip gcc-8718727509b2d038d00afa3bd5ef8e0df216a287.tar.gz gcc-8718727509b2d038d00afa3bd5ef8e0df216a287.tar.bz2 |
x86: Implement Fast-Math Float Truncation to BF16 via PSRLD Instruction

gcc/ChangeLog:

	* config/i386/i386.md: Rewrite insn truncsfbf2.

gcc/testsuite/ChangeLog:

	* gcc.target/i386/truncsfbf-1.c: New test.
	* gcc.target/i386/truncsfbf-2.c: New test.

Diffstat (limited to 'gcc')
-rw-r--r-- | gcc/config/i386/i386.md | 16 | ||||
-rw-r--r-- | gcc/testsuite/gcc.target/i386/truncsfbf-1.c | 9 | ||||
-rw-r--r-- | gcc/testsuite/gcc.target/i386/truncsfbf-2.c | 65 |
3 files changed, 83 insertions, 7 deletions
diff --git a/gcc/config/i386/i386.md b/gcc/config/i386/i386.md
index fb9befc..e4d1c56 100644
--- a/gcc/config/i386/i386.md
+++ b/gcc/config/i386/i386.md
@@ -5673,16 +5673,18 @@
    (set_attr "mode" "HF")])
 
 (define_insn "truncsfbf2"
-  [(set (match_operand:BF 0 "register_operand" "=x, v")
+  [(set (match_operand:BF 0 "register_operand" "=x,x,v,Yv")
         (float_truncate:BF
-          (match_operand:SF 1 "register_operand" "x,v")))]
-  "((TARGET_AVX512BF16 && TARGET_AVX512VL) || TARGET_AVXNECONVERT)
-   && !HONOR_NANS (BFmode) && flag_unsafe_math_optimizations"
+          (match_operand:SF 1 "register_operand" "0,x,v,Yv")))]
+  "TARGET_SSE2 && flag_unsafe_math_optimizations && !HONOR_NANS (BFmode)"
   "@
+  psrld\t{$16, %0|%0, 16}
   %{vex%} vcvtneps2bf16\t{%1, %0|%0, %1}
-  vcvtneps2bf16\t{%1, %0|%0, %1}"
-  [(set_attr "isa" "avxneconvert,avx512bf16vl")
-   (set_attr "prefix" "vex,evex")])
+  vcvtneps2bf16\t{%1, %0|%0, %1}
+  vpsrld\t{$16, %1, %0|%0, %1, 16}"
+  [(set_attr "isa" "noavx,avxneconvert,avx512bf16vl,avx")
+   (set_attr "prefix" "orig,vex,evex,vex")
+   (set_attr "type" "sseishft1,ssecvt,ssecvt,sseishft1")])
 
 ;; Signed conversion to DImode.
 
diff --git a/gcc/testsuite/gcc.target/i386/truncsfbf-1.c b/gcc/testsuite/gcc.target/i386/truncsfbf-1.c
new file mode 100644
index 0000000..dd3ff8a
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/truncsfbf-1.c
@@ -0,0 +1,9 @@
+/* { dg-do compile } */
+/* { dg-options "-msse2 -O2 -ffast-math" } */
+/* { dg-final { scan-assembler-times "psrld" 1 } } */
+
+__bf16
+foo (float a)
+{
+  return a;
+}
diff --git a/gcc/testsuite/gcc.target/i386/truncsfbf-2.c b/gcc/testsuite/gcc.target/i386/truncsfbf-2.c
new file mode 100644
index 0000000..f4952f8
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/truncsfbf-2.c
@@ -0,0 +1,65 @@
+/* { dg-do run } */
+/* { dg-options "-msse2 -O2 -ffast-math" } */
+
+#include <stdlib.h>
+#include <stdint.h>
+#include <string.h>
+#include <math.h>
+
+__bf16
+foo (float a)
+{
+  return a;
+}
+
+static __bf16
+CALC (float *a)
+{
+  uint32_t bits;
+  memcpy (&bits, a, sizeof (bits));
+  bits >>= 16;
+  uint16_t bfloat16_bits = (uint16_t) bits;
+  __bf16 bf16;
+  memcpy (&bf16, &bfloat16_bits, sizeof (bf16));
+  return bf16;
+}
+
+int
+main (void)
+{
+  float test_values[] = { 0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f, 1000.0f, -1000.0f,
+                          3.1415926f, -3.1415926f, 1e-8f, -1e-8f,
+                          1.0e+38f, -1.0e+38f, 1.0e-38f, -1.0e-38f };
+  size_t num_values = sizeof (test_values) / sizeof (test_values[0]);
+
+  for (size_t i = 0; i < num_values; ++i)
+    {
+      float original = test_values[i];
+      __bf16 hw_bf16 = foo (original);
+      __bf16 sw_bf16 = CALC (&original);
+
+      /* Verify psrld $16, %0 == %0 >> 16 */
+      if (memcmp (&hw_bf16, &sw_bf16, sizeof (__bf16)) != 0)
+        abort ();
+
+      /* Reconstruct the float value from the __bf16 bits */
+      uint16_t bf16_bits;
+      memcpy (&bf16_bits, &hw_bf16, sizeof (bf16_bits));
+      uint32_t reconstructed_bits = ((uint32_t) bf16_bits) << 16;
+      float converted;
+      memcpy (&converted, &reconstructed_bits, sizeof (converted));
+
+      float diff = fabsf (original - converted);
+
+      /* Expected Maximum Precision Loss */
+      uint32_t orig_bits;
+      memcpy (&orig_bits, &original, sizeof (orig_bits));
+      int exponent = ((orig_bits >> 23) & 0xFF) - 127;
+      float expected_loss = (exponent == -127)
+                            ? ldexpf (1.0f, -126 - 7)
+                            : ldexpf (1.0f, exponent - 7);
+      if (diff > expected_loss)
+        abort ();
+    }
+  return 0;
+}