aboutsummaryrefslogtreecommitdiff
path: root/gcc
diff options
context:
space:
mode:
authorLevy Hsu <admin@levyhsu.com>2024-09-25 14:32:35 +1100
committerLevy Hsu <admin@levyhsu.com>2024-10-10 01:54:32 +0000
commit8718727509b2d038d00afa3bd5ef8e0df216a287 (patch)
tree727f0972a10bf47ff00ed79eec5dab5886185a83 /gcc
parent00ede02bc8bb73da8f4bf1d7542142cd923b0c54 (diff)
downloadgcc-8718727509b2d038d00afa3bd5ef8e0df216a287.zip
gcc-8718727509b2d038d00afa3bd5ef8e0df216a287.tar.gz
gcc-8718727509b2d038d00afa3bd5ef8e0df216a287.tar.bz2
x86: Implement Fast-Math Float Truncation to BF16 via PSRLD Instruction
gcc/ChangeLog: * config/i386/i386.md: Rewrite insn truncsfbf2. gcc/testsuite/ChangeLog: * gcc.target/i386/truncsfbf-1.c: New test. * gcc.target/i386/truncsfbf-2.c: New test.
Diffstat (limited to 'gcc')
-rw-r--r--gcc/config/i386/i386.md16
-rw-r--r--gcc/testsuite/gcc.target/i386/truncsfbf-1.c9
-rw-r--r--gcc/testsuite/gcc.target/i386/truncsfbf-2.c65
3 files changed, 83 insertions, 7 deletions
diff --git a/gcc/config/i386/i386.md b/gcc/config/i386/i386.md
index fb9befc..e4d1c56 100644
--- a/gcc/config/i386/i386.md
+++ b/gcc/config/i386/i386.md
@@ -5673,16 +5673,18 @@
(set_attr "mode" "HF")])
(define_insn "truncsfbf2"
- [(set (match_operand:BF 0 "register_operand" "=x, v")
+ [(set (match_operand:BF 0 "register_operand" "=x,x,v,Yv")
(float_truncate:BF
- (match_operand:SF 1 "register_operand" "x,v")))]
- "((TARGET_AVX512BF16 && TARGET_AVX512VL) || TARGET_AVXNECONVERT)
- && !HONOR_NANS (BFmode) && flag_unsafe_math_optimizations"
+ (match_operand:SF 1 "register_operand" "0,x,v,Yv")))]
+ "TARGET_SSE2 && flag_unsafe_math_optimizations && !HONOR_NANS (BFmode)"
"@
+ psrld\t{$16, %0|%0, 16}
%{vex%} vcvtneps2bf16\t{%1, %0|%0, %1}
- vcvtneps2bf16\t{%1, %0|%0, %1}"
- [(set_attr "isa" "avxneconvert,avx512bf16vl")
- (set_attr "prefix" "vex,evex")])
+ vcvtneps2bf16\t{%1, %0|%0, %1}
+ vpsrld\t{$16, %1, %0|%0, %1, 16}"
+ [(set_attr "isa" "noavx,avxneconvert,avx512bf16vl,avx")
+ (set_attr "prefix" "orig,vex,evex,vex")
+ (set_attr "type" "sseishft1,ssecvt,ssecvt,sseishft1")])
;; Signed conversion to DImode.
diff --git a/gcc/testsuite/gcc.target/i386/truncsfbf-1.c b/gcc/testsuite/gcc.target/i386/truncsfbf-1.c
new file mode 100644
index 0000000..dd3ff8a
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/truncsfbf-1.c
@@ -0,0 +1,9 @@
+/* { dg-do compile } */
+/* { dg-options "-msse2 -O2 -ffast-math" } */
+/* { dg-final { scan-assembler-times "psrld" 1 } } */
+
+__bf16
+foo (float a)
+{
+ return a;
+}
diff --git a/gcc/testsuite/gcc.target/i386/truncsfbf-2.c b/gcc/testsuite/gcc.target/i386/truncsfbf-2.c
new file mode 100644
index 0000000..f4952f8
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/truncsfbf-2.c
@@ -0,0 +1,65 @@
+/* { dg-do run } */
+/* { dg-options "-msse2 -O2 -ffast-math" } */
+
+#include <stdlib.h>
+#include <stdint.h>
+#include <string.h>
+#include <math.h>
+
+__bf16
+foo (float a)
+{
+ return a;
+}
+
+static __bf16
+CALC (float *a)
+{
+ uint32_t bits;
+ memcpy (&bits, a, sizeof (bits));
+ bits >>= 16;
+ uint16_t bfloat16_bits = (uint16_t) bits;
+ __bf16 bf16;
+ memcpy (&bf16, &bfloat16_bits, sizeof (bf16));
+ return bf16;
+}
+
+int
+main (void)
+{
+ float test_values[] = { 0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f, 1000.0f, -1000.0f,
+ 3.1415926f, -3.1415926f, 1e-8f, -1e-8f,
+ 1.0e+38f, -1.0e+38f, 1.0e-38f, -1.0e-38f };
+ size_t num_values = sizeof (test_values) / sizeof (test_values[0]);
+
+ for (size_t i = 0; i < num_values; ++i)
+ {
+ float original = test_values[i];
+ __bf16 hw_bf16 = foo (original);
+ __bf16 sw_bf16 = CALC (&original);
+
+ /* Verify psrld $16, %0 == %0 >> 16 */
+ if (memcmp (&hw_bf16, &sw_bf16, sizeof (__bf16)) != 0)
+ abort ();
+
+ /* Reconstruct the float value from the __bf16 bits */
+ uint16_t bf16_bits;
+ memcpy (&bf16_bits, &hw_bf16, sizeof (bf16_bits));
+ uint32_t reconstructed_bits = ((uint32_t) bf16_bits) << 16;
+ float converted;
+ memcpy (&converted, &reconstructed_bits, sizeof (converted));
+
+ float diff = fabsf (original - converted);
+
+ /* Expected Maximum Precision Loss */
+ uint32_t orig_bits;
+ memcpy (&orig_bits, &original, sizeof (orig_bits));
+ int exponent = ((orig_bits >> 23) & 0xFF) - 127;
+ float expected_loss = (exponent == -127)
+ ? ldexpf (1.0f, -126 - 7)
+ : ldexpf (1.0f, exponent - 7);
+ if (diff > expected_loss)
+ abort ();
+ }
+ return 0;
+}