Skip to main content

core/stdarch/crates/core_arch/src/x86_64/
amx.rs

1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
6/// Load tile configuration from a 64-byte memory location specified by mem_addr.
7/// The tile configuration format is specified below, and includes the tile type palette,
8/// the number of bytes per row, and the number of rows. If the specified palette_id is zero,
9/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
10/// Any invalid configurations will result in #GP fault.
11///
/// # Safety
///
/// `mem_addr` must point to the 64-byte configuration block described above, and the
/// `amx-tile` target feature must be available on the executing CPU.
///
12/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
13#[inline]
14#[target_feature(enable = "amx-tile")]
15#[cfg_attr(test, assert_instr(ldtilecfg))]
16#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
17pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
    // Thin wrapper: the pointer goes straight to the `llvm.x86.ldtilecfg` binding declared below.
18    ldtilecfg(mem_addr);
19}
20
21/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr.
22/// The tile configuration format is specified below, and includes the tile type palette,
23/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
24///
/// # Safety
///
/// `mem_addr` must point to a writable 64-byte region, and the `amx-tile` target feature
/// must be available on the executing CPU.
///
25/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
26#[inline]
27#[target_feature(enable = "amx-tile")]
28#[cfg_attr(test, assert_instr(sttilecfg))]
29#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
30pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
    // Thin wrapper: the pointer goes straight to the `llvm.x86.sttilecfg` binding declared below.
31    sttilecfg(mem_addr);
32}
33
34/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig.
35///
/// The destination tile is selected by the const parameter `DST`.
///
36/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
37#[inline]
38#[rustc_legacy_const_generics(0)]
39#[target_feature(enable = "amx-tile")]
40#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
41#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
42pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
43    static_assert_uimm_bits!(DST, 3);
    // The index is narrowed to `i8` inline and handed to the `llvm.x86.tileloadd64` binding.
44    tileloadd64(DST as i8, base, stride);
45}
46
47/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
48///
49/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
50#[inline]
51#[target_feature(enable = "amx-tile")]
52#[cfg_attr(test, assert_instr(tilerelease))]
53#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
54pub unsafe fn _tile_release() {
    // Takes no operands; simply invokes the `llvm.x86.tilerelease` binding declared below.
55    tilerelease();
56}
57
58/// Store the tile specified by src to memory specified by base address and stride using the tile configuration previously configured via _tile_loadconfig.
59///
/// Note that, despite its name, the const parameter `DST` selects the tile being stored
/// (the "src" tile in the description above); memory at `base` is the destination.
///
60/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
61#[inline]
62#[rustc_legacy_const_generics(0)]
63#[target_feature(enable = "amx-tile")]
64#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
65#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
66pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
67    static_assert_uimm_bits!(DST, 3);
    // The index is narrowed to `i8` inline and handed to the `llvm.x86.tilestored64` binding.
68    tilestored64(DST as i8, base, stride);
69}
70
71/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration
72/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will
73/// likely not be reused in the near future and the data caching can be optimized accordingly.
74///
/// The destination tile is selected by the const parameter `DST`.
///
75/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
76#[inline]
77#[rustc_legacy_const_generics(0)]
78#[target_feature(enable = "amx-tile")]
79#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
80#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
81pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
82    static_assert_uimm_bits!(DST, 3);
    // The index is narrowed to `i8` inline and handed to the `llvm.x86.tileloaddt164` binding.
83    tileloaddt164(DST as i8, base, stride);
84}
85
86/// Zero the tile specified by the const parameter `DST` (called tdest in Intel's description).
87///
88/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
89#[inline]
90#[rustc_legacy_const_generics(0)]
91#[target_feature(enable = "amx-tile")]
92#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
93#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
94pub unsafe fn _tile_zero<const DST: i32>() {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
95    static_assert_uimm_bits!(DST, 3);
    // The index is narrowed to `i8` inline and handed to the `llvm.x86.tilezero` binding.
96    tilezero(DST as i8);
97}
98
99/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
100/// accumulating the intermediate single-precision (32-bit) floating-point elements
101/// with elements in dst, and store the 32-bit result back to tile dst.
102///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
103/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
104#[inline]
105#[rustc_legacy_const_generics(0, 1, 2)]
106#[target_feature(enable = "amx-bf16")]
107#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
108#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
109pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
110    static_assert_uimm_bits!(DST, 3);
111    static_assert_uimm_bits!(A, 3);
112    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbf16ps` binding.
113    tdpbf16ps(DST as i8, A as i8, B as i8);
114}
115
116/// Compute dot-product of bytes in tiles with a source/destination accumulator.
117/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
118/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
119/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
120///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
121/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
122#[inline]
123#[rustc_legacy_const_generics(0, 1, 2)]
124#[target_feature(enable = "amx-int8")]
125#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
126#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
127pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
128    static_assert_uimm_bits!(DST, 3);
129    static_assert_uimm_bits!(A, 3);
130    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbssd` binding.
131    tdpbssd(DST as i8, A as i8, B as i8);
132}
133
134/// Compute dot-product of bytes in tiles with a source/destination accumulator.
135/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
136/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
137/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
138///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
139/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
140#[inline]
141#[rustc_legacy_const_generics(0, 1, 2)]
142#[target_feature(enable = "amx-int8")]
143#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
144#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
145pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
146    static_assert_uimm_bits!(DST, 3);
147    static_assert_uimm_bits!(A, 3);
148    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbsud` binding.
149    tdpbsud(DST as i8, A as i8, B as i8);
150}
151
152/// Compute dot-product of bytes in tiles with a source/destination accumulator.
153/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
154/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
155/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
156///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
157/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
158#[inline]
159#[rustc_legacy_const_generics(0, 1, 2)]
160#[target_feature(enable = "amx-int8")]
161#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
162#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
163pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
164    static_assert_uimm_bits!(DST, 3);
165    static_assert_uimm_bits!(A, 3);
166    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbusd` binding.
167    tdpbusd(DST as i8, A as i8, B as i8);
168}
169
170/// Compute dot-product of bytes in tiles with a source/destination accumulator.
171/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
172/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
173/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
174///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
175/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
176#[inline]
177#[rustc_legacy_const_generics(0, 1, 2)]
178#[target_feature(enable = "amx-int8")]
179#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
180#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
181pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
182    static_assert_uimm_bits!(DST, 3);
183    static_assert_uimm_bits!(A, 3);
184    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbuud` binding.
185    tdpbuud(DST as i8, A as i8, B as i8);
186}
187
188/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
189/// accumulating the intermediate single-precision (32-bit) floating-point elements
190/// with elements in dst, and store the 32-bit result back to tile dst.
191///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
192/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
193#[inline]
194#[rustc_legacy_const_generics(0, 1, 2)]
195#[target_feature(enable = "amx-fp16")]
196#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
197#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
198pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
199    static_assert_uimm_bits!(DST, 3);
200    static_assert_uimm_bits!(A, 3);
201    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpfp16ps` binding.
202    tdpfp16ps(DST as i8, A as i8, B as i8);
203}
204
205/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
206/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
207/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
208/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
209/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
210/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
211/// and then accumulated into the corresponding row and column of dst.
212///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
213/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
214#[inline]
215#[rustc_legacy_const_generics(0, 1, 2)]
216#[target_feature(enable = "amx-complex")]
217#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
218#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
219pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
220    static_assert_uimm_bits!(DST, 3);
221    static_assert_uimm_bits!(A, 3);
222    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tcmmimfp16ps` binding.
223    tcmmimfp16ps(DST as i8, A as i8, B as i8);
224}
225
226/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
227/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
228/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
229/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
230/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
231/// the a element is multiplied with the imaginary part of the corresponding b elements.
232/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
233///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
///
234/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
235#[inline]
236#[rustc_legacy_const_generics(0, 1, 2)]
237#[target_feature(enable = "amx-complex")]
238#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
239#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
240pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
241    static_assert_uimm_bits!(DST, 3);
242    static_assert_uimm_bits!(A, 3);
243    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tcmmrlfp16ps` binding.
244    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
245}
246
247/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
248/// floating-point elements in tile b, accumulating the intermediate single-precision
249/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
250/// back to tile dst.
///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
251#[inline]
252#[rustc_legacy_const_generics(0, 1, 2)]
253#[target_feature(enable = "amx-fp8")]
254#[cfg_attr(
255    all(test, not(target_vendor = "apple")),
256    assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
257)]
258#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
259pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
260    static_assert_uimm_bits!(DST, 3);
261    static_assert_uimm_bits!(A, 3);
262    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbf8ps` binding.
263    tdpbf8ps(DST as i8, A as i8, B as i8);
264}
265
266/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
267/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
268/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
269/// back to tile dst.
///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
270#[inline]
271#[rustc_legacy_const_generics(0, 1, 2)]
272#[target_feature(enable = "amx-fp8")]
273#[cfg_attr(
274    all(test, not(target_vendor = "apple")),
275    assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
276)]
277#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
278pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
279    static_assert_uimm_bits!(DST, 3);
280    static_assert_uimm_bits!(A, 3);
281    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdpbhf8ps` binding.
282    tdpbhf8ps(DST as i8, A as i8, B as i8);
283}
284
285/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
286/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
287/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
288/// back to tile dst.
///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
289#[inline]
290#[rustc_legacy_const_generics(0, 1, 2)]
291#[target_feature(enable = "amx-fp8")]
292#[cfg_attr(
293    all(test, not(target_vendor = "apple")),
294    assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
295)]
296#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
297pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
298    static_assert_uimm_bits!(DST, 3);
299    static_assert_uimm_bits!(A, 3);
300    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdphbf8ps` binding.
301    tdphbf8ps(DST as i8, A as i8, B as i8);
302}
303
304/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
305/// floating-point elements in tile b, accumulating the intermediate single-precision
306/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
307/// back to tile dst.
///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
308#[inline]
309#[rustc_legacy_const_generics(0, 1, 2)]
310#[target_feature(enable = "amx-fp8")]
311#[cfg_attr(
312    all(test, not(target_vendor = "apple")),
313    assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
314)]
315#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
316pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
317    static_assert_uimm_bits!(DST, 3);
318    static_assert_uimm_bits!(A, 3);
319    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tdphf8ps` binding.
320    tdphf8ps(DST as i8, A as i8, B as i8);
321}
322
323/// Load tile rows from memory specified by base address and stride into destination tile dst
324/// using the tile configuration previously configured via _tile_loadconfig.
325/// Additionally, this intrinsic indicates the source memory location is likely to become
326/// read-shared by multiple processors, i.e., read in the future by at least one other processor
327/// before it is written, assuming it is ever written in the future.
///
/// The destination tile is selected by the const parameter `DST`.
328#[inline]
329#[rustc_legacy_const_generics(0)]
330#[target_feature(enable = "amx-movrs")]
331#[cfg_attr(
332    all(test, not(target_vendor = "apple")),
333    assert_instr(tileloaddrs, DST = 0)
334)]
335#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
336pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
337    static_assert_uimm_bits!(DST, 3);
    // The index is narrowed to `i8` inline and handed to the `llvm.x86.tileloaddrs64` binding.
338    tileloaddrs64(DST as i8, base, stride);
339}
340
341/// Load tile rows from memory specified by base address and stride into destination tile dst
342/// using the tile configuration previously configured via _tile_loadconfig.
343/// Provides a hint to the implementation that the data would be reused but does not need
344/// to be resident in the nearest cache levels.
345/// Additionally, this intrinsic indicates the source memory location is likely to become
346/// read-shared by multiple processors, i.e., read in the future by at least one other processor
347/// before it is written, assuming it is ever written in the future.
///
/// The destination tile is selected by the const parameter `DST`.
348#[inline]
349#[rustc_legacy_const_generics(0)]
350#[target_feature(enable = "amx-movrs")]
351#[cfg_attr(
352    all(test, not(target_vendor = "apple")),
353    assert_instr(tileloaddrst1, DST = 0)
354)]
355#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
356pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
357    static_assert_uimm_bits!(DST, 3);
    // The index is narrowed to `i8` inline and handed to the `llvm.x86.tileloaddrst164` binding.
358    tileloaddrst164(DST as i8, base, stride);
359}
360
361/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
362/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
363/// results into a packed single precision tile.
364/// For each possible combination of (row of a, column of b), it performs
365///  - convert to TF32
366///  - multiply the corresponding elements of a and b
367///  - accumulate the results into the corresponding row and column of dst using round-to-nearest-even
368/// rounding mode.
369/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
370/// handled and *not* treated as zero.
///
/// The const parameters `DST`, `A` and `B` are the tile indices of dst, a and b respectively.
371#[inline]
372#[rustc_legacy_const_generics(0, 1, 2)]
373#[target_feature(enable = "amx-tf32")]
374#[cfg_attr(
375    all(test, not(target_vendor = "apple")),
376    assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
377)]
378#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
379pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
    // Compile-time guards: every tile index must fit in an unsigned 3-bit immediate.
380    static_assert_uimm_bits!(DST, 3);
381    static_assert_uimm_bits!(A, 3);
382    static_assert_uimm_bits!(B, 3);
    // Indices are narrowed to `i8` inline and handed to the `llvm.x86.tmmultf32ps` binding.
383    tmmultf32ps(DST as i8, A as i8, B as i8);
384}
385
386/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
387/// elements to packed single-precision (32-bit) floating-point elements.
///
/// The tile is selected by the const parameter `TILE`; the row index is the runtime argument `row`.
388#[inline]
389#[rustc_legacy_const_generics(0)]
390#[target_feature(enable = "amx-avx512,avx10.2")]
391#[cfg_attr(
392    all(test, not(target_vendor = "apple")),
393    assert_instr(tcvtrowd2ps, TILE = 0)
394)]
395#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
396pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
397    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `f32x16`; `as_m512` converts it to the public vector type.
398    tcvtrowd2ps(TILE as i8, row).as_m512()
399}
400
401/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
402/// elements to packed single-precision (32-bit) floating-point elements.
///
/// Immediate-row variant: both the tile (`TILE`) and the row index (`ROW`) are compile-time
/// constants, whereas `_tile_cvtrowd2ps` takes the row as a runtime argument.
403#[inline]
404#[rustc_legacy_const_generics(0, 1)]
405#[target_feature(enable = "amx-avx512,avx10.2")]
406#[cfg_attr(
407    all(test, not(target_vendor = "apple")),
408    assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
409)]
410#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
411pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
    // Compile-time guards: tile index fits in 3 unsigned bits, row index in 6 unsigned bits.
412    static_assert_uimm_bits!(TILE, 3);
413    static_assert_uimm_bits!(ROW, 6);
    // The binding returns an `f32x16`; `as_m512` converts it to the public vector type.
414    tcvtrowd2psi(TILE as i8, ROW as u32).as_m512()
415}
416
417/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
418/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
419/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const parameter `TILE`; the row index is the runtime argument `row`.
420#[inline]
421#[rustc_legacy_const_generics(0)]
422#[target_feature(enable = "amx-avx512,avx10.2")]
423#[cfg_attr(
424    all(test, not(target_vendor = "apple")),
425    assert_instr(tcvtrowps2phh, TILE = 0)
426)]
427#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
428pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
429    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `f16x32`; `as_m512h` converts it to the public vector type.
430    tcvtrowps2phh(TILE as i8, row).as_m512h()
431}
432
433/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
434/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
435/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// Immediate-row variant: both the tile (`TILE`) and the row index (`ROW`) are compile-time
/// constants, whereas `_tile_cvtrowps2phh` takes the row as a runtime argument.
436#[inline]
437#[rustc_legacy_const_generics(0, 1)]
438#[target_feature(enable = "amx-avx512,avx10.2")]
439#[cfg_attr(
440    all(test, not(target_vendor = "apple")),
441    assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
442)]
443#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
444pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h {
    // Compile-time guards: tile index fits in 3 unsigned bits, row index in 6 unsigned bits.
445    static_assert_uimm_bits!(TILE, 3);
446    static_assert_uimm_bits!(ROW, 6);
    // The binding returns an `f16x32`; `as_m512h` converts it to the public vector type.
447    tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h()
448}
449
450/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
451/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
452/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const parameter `TILE`; the row index is the runtime argument `row`.
453#[inline]
454#[rustc_legacy_const_generics(0)]
455#[target_feature(enable = "amx-avx512,avx10.2")]
456#[cfg_attr(
457    all(test, not(target_vendor = "apple")),
458    assert_instr(tcvtrowps2phl, TILE = 0)
459)]
460#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
461pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
462    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `f16x32`; `as_m512h` converts it to the public vector type.
463    tcvtrowps2phl(TILE as i8, row).as_m512h()
464}
465
466/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
467/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
468/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// Immediate-row variant: both the tile (`TILE`) and the row index (`ROW`) are compile-time
/// constants, whereas `_tile_cvtrowps2phl` takes the row as a runtime argument.
469#[inline]
470#[rustc_legacy_const_generics(0, 1)]
471#[target_feature(enable = "amx-avx512,avx10.2")]
472#[cfg_attr(
473    all(test, not(target_vendor = "apple")),
474    assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
475)]
476#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
477pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h {
    // Compile-time guards: tile index fits in 3 unsigned bits, row index in 6 unsigned bits.
478    static_assert_uimm_bits!(TILE, 3);
479    static_assert_uimm_bits!(ROW, 6);
    // The binding returns an `f16x32`; `as_m512h` converts it to the public vector type.
480    tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
481}
482
483/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
484/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
485/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const parameter `TILE`; the row index is the runtime argument `row`.
486#[inline]
487#[rustc_legacy_const_generics(0)]
488#[target_feature(enable = "amx-avx512,avx10.2")]
489#[cfg_attr(
490    all(test, not(target_vendor = "apple")),
491    assert_instr(tcvtrowps2bf16h, TILE = 0)
492)]
493#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
494pub unsafe fn _tile_cvtrowps2bf16h<const TILE: i32>(row: u32) -> __m512bh {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
495    static_assert_uimm_bits!(TILE, 3);
    // The binding returns a `u16x32`; `as_m512bh` converts it to the public vector type.
496    tcvtrowps2bf16h(TILE as i8, row).as_m512bh()
497}
498
499/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
500/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
501/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// Immediate-row variant: both the tile (`TILE`) and the row index (`ROW`) are compile-time
/// constants, whereas `_tile_cvtrowps2bf16h` takes the row as a runtime argument.
502#[inline]
503#[rustc_legacy_const_generics(0, 1)]
504#[target_feature(enable = "amx-avx512,avx10.2")]
505#[cfg_attr(
506    all(test, not(target_vendor = "apple")),
507    assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0)
508)]
509#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
510pub unsafe fn _tile_cvtrowps2bf16hi<const TILE: i32, const ROW: i32>() -> __m512bh {
    // Compile-time guards: tile index fits in 3 unsigned bits, row index in 6 unsigned bits.
511    static_assert_uimm_bits!(TILE, 3);
512    static_assert_uimm_bits!(ROW, 6);
    // The binding returns a `u16x32`; `as_m512bh` converts it to the public vector type.
513    tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh()
514}
515
516/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
517/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
518/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const parameter `TILE`; the row index is the runtime argument `row`.
519#[inline]
520#[rustc_legacy_const_generics(0)]
521#[target_feature(enable = "amx-avx512,avx10.2")]
522#[cfg_attr(
523    all(test, not(target_vendor = "apple")),
524    assert_instr(tcvtrowps2bf16l, TILE = 0)
525)]
526#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
527pub unsafe fn _tile_cvtrowps2bf16l<const TILE: i32>(row: u32) -> __m512bh {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
528    static_assert_uimm_bits!(TILE, 3);
    // The binding returns a `u16x32`; `as_m512bh` converts it to the public vector type.
529    tcvtrowps2bf16l(TILE as i8, row).as_m512bh()
530}
531
532/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
533/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
534/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// Immediate-row variant: both the tile (`TILE`) and the row index (`ROW`) are compile-time
/// constants, whereas `_tile_cvtrowps2bf16l` takes the row as a runtime argument.
535#[inline]
536#[rustc_legacy_const_generics(0, 1)]
537#[target_feature(enable = "amx-avx512,avx10.2")]
538#[cfg_attr(
539    all(test, not(target_vendor = "apple")),
540    assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0)
541)]
542#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
543pub unsafe fn _tile_cvtrowps2bf16li<const TILE: i32, const ROW: i32>() -> __m512bh {
    // Compile-time guards: tile index fits in 3 unsigned bits, row index in 6 unsigned bits.
544    static_assert_uimm_bits!(TILE, 3);
545    static_assert_uimm_bits!(ROW, 6);
    // The binding returns a `u16x32`; `as_m512bh` converts it to the public vector type.
546    tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh()
547}
548
549/// Moves one row of tile data into a zmm vector register
///
/// The tile is selected by the const parameter `TILE`; the row index is the runtime argument `row`.
550#[inline]
551#[rustc_legacy_const_generics(0)]
552#[target_feature(enable = "amx-avx512,avx10.2")]
553#[cfg_attr(
554    all(test, not(target_vendor = "apple")),
555    assert_instr(tilemovrow, TILE = 0)
556)]
557#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
558pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
    // Compile-time guard: the tile index must fit in an unsigned 3-bit immediate.
559    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `i32x16`; `as_m512i` converts it to the public vector type.
560    tilemovrow(TILE as i8, row).as_m512i()
561}
562
563/// Moves one row of tile data into a zmm vector register
///
/// Immediate-row variant: both the tile (`TILE`) and the row index (`ROW`) are compile-time
/// constants, whereas `_tile_movrow` takes the row as a runtime argument.
564#[inline]
565#[rustc_legacy_const_generics(0, 1)]
566#[target_feature(enable = "amx-avx512,avx10.2")]
567#[cfg_attr(
568    all(test, not(target_vendor = "apple")),
569    assert_instr(tilemovrow, TILE = 0, ROW = 0)
570)]
571#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
572pub unsafe fn _tile_movrowi<const TILE: i32, const ROW: i32>() -> __m512i {
    // Compile-time guards: tile index fits in 3 unsigned bits, row index in 6 unsigned bits.
573    static_assert_uimm_bits!(TILE, 3);
574    static_assert_uimm_bits!(ROW, 6);
    // The binding returns an `i32x16`; `as_m512i` converts it to the public vector type.
575    tilemovrowi(TILE as i8, ROW as u32).as_m512i()
576}
577
// Raw bindings to the LLVM AMX intrinsics. Each `link_name` is the LLVM intrinsic that the
// public `_tile_*` wrapper of the corresponding name forwards to.
// NOTE(review): `improper_ctypes` is presumably allowed because the SIMD vector return types
// (`f32x16`, `f16x32`, `u16x32`, `i32x16`) are not FFI-safe for a real "C" ABI — confirm.
578#[allow(improper_ctypes)]
579unsafe extern "C" {
    // Tile configuration, load/store, release and zeroing.
580    #[link_name = "llvm.x86.ldtilecfg"]
581    fn ldtilecfg(mem_addr: *const u8);
582    #[link_name = "llvm.x86.sttilecfg"]
583    fn sttilecfg(mem_addr: *mut u8);
584    #[link_name = "llvm.x86.tileloadd64"]
585    fn tileloadd64(dst: i8, base: *const u8, stride: usize);
586    #[link_name = "llvm.x86.tileloaddt164"]
587    fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
588    #[link_name = "llvm.x86.tilerelease"]
589    fn tilerelease();
590    #[link_name = "llvm.x86.tilestored64"]
591    fn tilestored64(dst: i8, base: *mut u8, stride: usize);
592    #[link_name = "llvm.x86.tilezero"]
593    fn tilezero(dst: i8);
    // Dot-product / matrix-multiply operations taking three tile indices (dst, a, b).
594    #[link_name = "llvm.x86.tdpbf16ps"]
595    fn tdpbf16ps(dst: i8, a: i8, b: i8);
596    #[link_name = "llvm.x86.tdpbuud"]
597    fn tdpbuud(dst: i8, a: i8, b: i8);
598    #[link_name = "llvm.x86.tdpbusd"]
599    fn tdpbusd(dst: i8, a: i8, b: i8);
600    #[link_name = "llvm.x86.tdpbsud"]
601    fn tdpbsud(dst: i8, a: i8, b: i8);
602    #[link_name = "llvm.x86.tdpbssd"]
603    fn tdpbssd(dst: i8, a: i8, b: i8);
604    #[link_name = "llvm.x86.tdpfp16ps"]
605    fn tdpfp16ps(dst: i8, a: i8, b: i8);
606    #[link_name = "llvm.x86.tcmmimfp16ps"]
607    fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
608    #[link_name = "llvm.x86.tcmmrlfp16ps"]
609    fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
610    #[link_name = "llvm.x86.tdpbf8ps"]
611    fn tdpbf8ps(dst: i8, a: i8, b: i8);
612    #[link_name = "llvm.x86.tdpbhf8ps"]
613    fn tdpbhf8ps(dst: i8, a: i8, b: i8);
614    #[link_name = "llvm.x86.tdphbf8ps"]
615    fn tdphbf8ps(dst: i8, a: i8, b: i8);
616    #[link_name = "llvm.x86.tdphf8ps"]
617    fn tdphf8ps(dst: i8, a: i8, b: i8);
    // Read-shared tile loads (used by the `amx-movrs` wrappers).
618    #[link_name = "llvm.x86.tileloaddrs64"]
619    fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
620    #[link_name = "llvm.x86.tileloaddrst164"]
621    fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
622    #[link_name = "llvm.x86.tmmultf32ps"]
623    fn tmmultf32ps(dst: i8, a: i8, b: i8);
    // Tile-row to vector-register moves/conversions; the `...i` bindings back the wrappers
    // whose row index is a const generic rather than a runtime argument.
624    #[link_name = "llvm.x86.tcvtrowd2ps"]
625    fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
626    #[link_name = "llvm.x86.tcvtrowd2psi"]
627    fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16;
628    #[link_name = "llvm.x86.tcvtrowps2phh"]
629    fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
630    #[link_name = "llvm.x86.tcvtrowps2phhi"]
631    fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32;
632    #[link_name = "llvm.x86.tcvtrowps2phl"]
633    fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
634    #[link_name = "llvm.x86.tcvtrowps2phli"]
635    fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
636    #[link_name = "llvm.x86.tcvtrowps2bf16h"]
637    fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32;
638    #[link_name = "llvm.x86.tcvtrowps2bf16hi"]
639    fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32;
640    #[link_name = "llvm.x86.tcvtrowps2bf16l"]
641    fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32;
642    #[link_name = "llvm.x86.tcvtrowps2bf16li"]
643    fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32;
644    #[link_name = "llvm.x86.tilemovrow"]
645    fn tilemovrow(tile: i8, row: u32) -> i32x16;
646    #[link_name = "llvm.x86.tilemovrowi"]
647    fn tilemovrowi(tile: i8, row: u32) -> i32x16;
648}
649
650#[cfg(test)]
651mod tests {
652    use crate::core_arch::x86::_mm_cvtness_sbh;
653    use crate::core_arch::x86_64::*;
654    use core::{array, mem::transmute};
655    use stdarch_test::simd_test;
656    #[cfg(target_os = "linux")]
657    use syscalls::{Sysno, syscall};
658
    /// In-memory image of the tile configuration block that the tests pass to
    /// `_tile_loadconfig` and read back with `_tile_storeconfig`.
    ///
    /// `repr(C, packed)` keeps the declared field order and removes padding, so
    /// the fields add up to exactly 64 bytes (1 + 1 + 14 + 16 + 16 + 8 + 8).
    #[allow(non_camel_case_types)]
    #[repr(C, packed)]
    #[derive(Copy, Clone, Default, Debug, PartialEq)]
    struct __tilecfg {
        /// palette id; the tests in this module only use 0 or 1
        palette: u8,
        /// always 0 in the tests of this module
        /// NOTE(review): presumably the row at which an interrupted tile
        /// load/store resumes; confirm against the ISA reference
        start_row: u8,
        /// reserved, must be zero
        reserved_a0: [u8; 14],
        /// number of bytes of one row in each tile
        colsb: [u16; 8],
        /// reserved, must be zero
        reserved_b0: [u16; 8],
        /// number of rows in each tile
        rows: [u8; 8],
        /// reserved, must be zero
        reserved_c0: [u8; 8],
    }
677
678    impl __tilecfg {
679        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
680            Self {
681                palette,
682                start_row,
683                reserved_a0: [0u8; 14],
684                colsb,
685                reserved_b0: [0u16; 8],
686                rows,
687                reserved_c0: [0u8; 8],
688            }
689        }
690
691        const fn as_ptr(&self) -> *const u8 {
692            self as *const Self as *const u8
693        }
694
695        fn as_mut_ptr(&mut self) -> *mut u8 {
696            self as *mut Self as *mut u8
697        }
698    }
699
    /// Variant of `_init_amx` compiled on every target OS except Linux.
    /// It is intentionally empty, which lets the tests call `_init_amx()`
    /// unconditionally.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}
703
704    #[cfg(target_os = "linux")]
705    #[target_feature(enable = "amx-tile")]
706    #[inline]
707    unsafe fn _init_amx() {
708        let mut ret: usize;
709        let mut xfeatures: usize = 0;
710        ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
711            .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
712        if ret != 0 {
713            panic!("Failed to get XFEATURES");
714        } else {
715            match 0b11 & (xfeatures >> 17) {
716                0 => panic!("AMX is not available"),
717                1 => {
718                    ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
719                        .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
720                    if ret != 0 {
721                        panic!("Failed to enable AMX");
722                    }
723                }
724                3 => {}
725                _ => unreachable!(),
726            }
727        }
728    }
729
730    #[simd_test(enable = "amx-tile")]
731    fn test_tile_loadconfig() {
732        unsafe {
733            let config = __tilecfg::default();
734            _tile_loadconfig(config.as_ptr());
735            _tile_release();
736        }
737    }
738
739    #[simd_test(enable = "amx-tile")]
740    fn test_tile_storeconfig() {
741        unsafe {
742            let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
743            _tile_loadconfig(config.as_ptr());
744            let mut _config = __tilecfg::default();
745            _tile_storeconfig(_config.as_mut_ptr());
746            _tile_release();
747            assert_eq!(config, _config);
748        }
749    }
750
751    #[simd_test(enable = "amx-tile")]
752    fn test_tile_zero() {
753        unsafe {
754            _init_amx();
755            let mut config = __tilecfg::default();
756            config.palette = 1;
757            config.colsb[0] = 64;
758            config.rows[0] = 16;
759            _tile_loadconfig(config.as_ptr());
760            _tile_zero::<0>();
761            let mut out = [[1_i8; 64]; 16];
762            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
763            _tile_release();
764            assert_eq!(out, [[0; 64]; 16]);
765        }
766    }
767
768    #[simd_test(enable = "amx-tile")]
769    fn test_tile_stored() {
770        unsafe {
771            _init_amx();
772            let mut config = __tilecfg::default();
773            config.palette = 1;
774            config.colsb[0] = 64;
775            config.rows[0] = 16;
776            _tile_loadconfig(config.as_ptr());
777            _tile_zero::<0>();
778            let mut out = [[1_i8; 64]; 16];
779            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
780            _tile_release();
781            assert_eq!(out, [[0; 64]; 16]);
782        }
783    }
784
785    #[simd_test(enable = "amx-tile")]
786    fn test_tile_loadd() {
787        unsafe {
788            _init_amx();
789            let mut config = __tilecfg::default();
790            config.palette = 1;
791            config.colsb[0] = 64;
792            config.rows[0] = 16;
793            _tile_loadconfig(config.as_ptr());
794            _tile_zero::<0>();
795            let mat = [1_i8; 1024];
796            _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
797            let mut out = [[0_i8; 64]; 16];
798            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
799            _tile_release();
800            assert_eq!(out, [[1; 64]; 16]);
801        }
802    }
803
804    #[simd_test(enable = "amx-tile")]
805    fn test_tile_stream_loadd() {
806        unsafe {
807            _init_amx();
808            let mut config = __tilecfg::default();
809            config.palette = 1;
810            config.colsb[0] = 64;
811            config.rows[0] = 16;
812            _tile_loadconfig(config.as_ptr());
813            _tile_zero::<0>();
814            let mat = [1_i8; 1024];
815            _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
816            let mut out = [[0_i8; 64]; 16];
817            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
818            _tile_release();
819            assert_eq!(out, [[1; 64]; 16]);
820        }
821    }
822
823    #[simd_test(enable = "amx-tile")]
824    fn test_tile_release() {
825        unsafe {
826            _tile_release();
827        }
828    }
829
830    #[simd_test(enable = "amx-bf16,avx512f")]
831    fn test_tile_dpbf16ps() {
832        unsafe {
833            _init_amx();
834            let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
835            let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
836            let ones: [u8; 1024] = transmute([bf16_1; 512]);
837            let twos: [u8; 1024] = transmute([bf16_2; 512]);
838            let mut res = [[0f32; 16]; 16];
839            let mut config = __tilecfg::default();
840            config.palette = 1;
841            (0..=2).for_each(|i| {
842                config.colsb[i] = 64;
843                config.rows[i] = 16;
844            });
845            _tile_loadconfig(config.as_ptr());
846            _tile_zero::<0>();
847            _tile_loadd::<1>(&ones as *const u8, 64);
848            _tile_loadd::<2>(&twos as *const u8, 64);
849            _tile_dpbf16ps::<0, 1, 2>();
850            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
851            _tile_release();
852            assert_eq!(res, [[64f32; 16]; 16]);
853        }
854    }
855
856    #[simd_test(enable = "amx-int8")]
857    fn test_tile_dpbssd() {
858        unsafe {
859            _init_amx();
860            let ones = [-1_i8; 1024];
861            let twos = [-2_i8; 1024];
862            let mut res = [[0_i32; 16]; 16];
863            let mut config = __tilecfg::default();
864            config.palette = 1;
865            (0..=2).for_each(|i| {
866                config.colsb[i] = 64;
867                config.rows[i] = 16;
868            });
869            _tile_loadconfig(config.as_ptr());
870            _tile_zero::<0>();
871            _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
872            _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
873            _tile_dpbssd::<0, 1, 2>();
874            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
875            _tile_release();
876            assert_eq!(res, [[128_i32; 16]; 16]);
877        }
878    }
879
880    #[simd_test(enable = "amx-int8")]
881    fn test_tile_dpbsud() {
882        unsafe {
883            _init_amx();
884            let ones = [-1_i8; 1024];
885            let twos = [2_u8; 1024];
886            let mut res = [[0_i32; 16]; 16];
887            let mut config = __tilecfg::default();
888            config.palette = 1;
889            (0..=2).for_each(|i| {
890                config.colsb[i] = 64;
891                config.rows[i] = 16;
892            });
893            _tile_loadconfig(config.as_ptr());
894            _tile_zero::<0>();
895            _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
896            _tile_loadd::<2>(&twos as *const u8, 64);
897            _tile_dpbsud::<0, 1, 2>();
898            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
899            _tile_release();
900            assert_eq!(res, [[-128_i32; 16]; 16]);
901        }
902    }
903
904    #[simd_test(enable = "amx-int8")]
905    fn test_tile_dpbusd() {
906        unsafe {
907            _init_amx();
908            let ones = [1_u8; 1024];
909            let twos = [-2_i8; 1024];
910            let mut res = [[0_i32; 16]; 16];
911            let mut config = __tilecfg::default();
912            config.palette = 1;
913            (0..=2).for_each(|i| {
914                config.colsb[i] = 64;
915                config.rows[i] = 16;
916            });
917            _tile_loadconfig(config.as_ptr());
918            _tile_zero::<0>();
919            _tile_loadd::<1>(&ones as *const u8, 64);
920            _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
921            _tile_dpbusd::<0, 1, 2>();
922            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
923            _tile_release();
924            assert_eq!(res, [[-128_i32; 16]; 16]);
925        }
926    }
927
928    #[simd_test(enable = "amx-int8")]
929    fn test_tile_dpbuud() {
930        unsafe {
931            _init_amx();
932            let ones = [1_u8; 1024];
933            let twos = [2_u8; 1024];
934            let mut res = [[0_i32; 16]; 16];
935            let mut config = __tilecfg::default();
936            config.palette = 1;
937            (0..=2).for_each(|i| {
938                config.colsb[i] = 64;
939                config.rows[i] = 16;
940            });
941            _tile_loadconfig(config.as_ptr());
942            _tile_zero::<0>();
943            _tile_loadd::<1>(&ones as *const u8, 64);
944            _tile_loadd::<2>(&twos as *const u8, 64);
945            _tile_dpbuud::<0, 1, 2>();
946            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
947            _tile_release();
948            assert_eq!(res, [[128_i32; 16]; 16]);
949        }
950    }
951
952    #[simd_test(enable = "amx-fp16")]
953    fn test_tile_dpfp16ps() {
954        unsafe {
955            _init_amx();
956            let ones = [1f16; 512];
957            let twos = [2f16; 512];
958            let mut res = [[0f32; 16]; 16];
959            let mut config = __tilecfg::default();
960            config.palette = 1;
961            (0..=2).for_each(|i| {
962                config.colsb[i] = 64;
963                config.rows[i] = 16;
964            });
965            _tile_loadconfig(config.as_ptr());
966            _tile_zero::<0>();
967            _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
968            _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
969            _tile_dpfp16ps::<0, 1, 2>();
970            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
971            _tile_release();
972            assert_eq!(res, [[64f32; 16]; 16]);
973        }
974    }
975
976    #[simd_test(enable = "amx-complex")]
977    fn test_tile_cmmimfp16ps() {
978        unsafe {
979            _init_amx();
980            let ones = [1f16; 512];
981            let twos = [2f16; 512];
982            let mut res = [[0f32; 16]; 16];
983            let mut config = __tilecfg::default();
984            config.palette = 1;
985            (0..=2).for_each(|i| {
986                config.colsb[i] = 64;
987                config.rows[i] = 16;
988            });
989            _tile_loadconfig(config.as_ptr());
990            _tile_zero::<0>();
991            _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
992            _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
993            _tile_cmmimfp16ps::<0, 1, 2>();
994            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
995            _tile_release();
996            assert_eq!(res, [[64f32; 16]; 16]);
997        }
998    }
999
1000    #[simd_test(enable = "amx-complex")]
1001    fn test_tile_cmmrlfp16ps() {
1002        unsafe {
1003            _init_amx();
1004            let ones = [1f16; 512];
1005            let twos = [2f16; 512];
1006            let mut res = [[0f32; 16]; 16];
1007            let mut config = __tilecfg::default();
1008            config.palette = 1;
1009            (0..=2).for_each(|i| {
1010                config.colsb[i] = 64;
1011                config.rows[i] = 16;
1012            });
1013            _tile_loadconfig(config.as_ptr());
1014            _tile_zero::<0>();
1015            _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
1016            _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
1017            _tile_cmmrlfp16ps::<0, 1, 2>();
1018            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
1019            _tile_release();
1020            assert_eq!(res, [[0f32; 16]; 16]);
1021        }
1022    }
1023
    // One-byte encodings of 1.0 and 2.0 used to fill the operand tiles of the
    // `amx-fp8` tests.
    // NOTE(review): "BF8" is presumably the 1-5-2 (sign/exponent/mantissa) layout
    // and "HF8" the 1-4-3 layout; the bit patterns are consistent with that
    // (0x3c = 0_01111_00, 0x38 = 0_0111_000, 0x40 = exponent field one higher
    // in both layouts). Confirm against the ISA reference.
    const BF8_ONE: u8 = 0x3c;
    const BF8_TWO: u8 = 0x40;
    const HF8_ONE: u8 = 0x38;
    const HF8_TWO: u8 = 0x40;
1028
1029    #[simd_test(enable = "amx-fp8")]
1030    fn test_tile_dpbf8ps() {
1031        unsafe {
1032            _init_amx();
1033            let ones = [BF8_ONE; 1024];
1034            let twos = [BF8_TWO; 1024];
1035            let mut res = [[0.0_f32; 16]; 16];
1036            let mut config = __tilecfg::default();
1037            config.palette = 1;
1038            (0..=2).for_each(|i| {
1039                config.colsb[i] = 64;
1040                config.rows[i] = 16;
1041            });
1042            _tile_loadconfig(config.as_ptr());
1043            _tile_zero::<0>();
1044            _tile_loadd::<1>(&ones as *const u8, 64);
1045            _tile_loadd::<2>(&twos as *const u8, 64);
1046            _tile_dpbf8ps::<0, 1, 2>();
1047            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1048            _tile_release();
1049            assert_eq!(res, [[128.0_f32; 16]; 16]);
1050        }
1051    }
1052
1053    #[simd_test(enable = "amx-fp8")]
1054    fn test_tile_dpbhf8ps() {
1055        unsafe {
1056            _init_amx();
1057            let ones = [BF8_ONE; 1024];
1058            let twos = [HF8_TWO; 1024];
1059            let mut res = [[0.0_f32; 16]; 16];
1060            let mut config = __tilecfg::default();
1061            config.palette = 1;
1062            (0..=2).for_each(|i| {
1063                config.colsb[i] = 64;
1064                config.rows[i] = 16;
1065            });
1066            _tile_loadconfig(config.as_ptr());
1067            _tile_zero::<0>();
1068            _tile_loadd::<1>(&ones as *const u8, 64);
1069            _tile_loadd::<2>(&twos as *const u8, 64);
1070            _tile_dpbhf8ps::<0, 1, 2>();
1071            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1072            _tile_release();
1073            assert_eq!(res, [[128.0_f32; 16]; 16]);
1074        }
1075    }
1076
1077    #[simd_test(enable = "amx-fp8")]
1078    fn test_tile_dphbf8ps() {
1079        unsafe {
1080            _init_amx();
1081            let ones = [HF8_ONE; 1024];
1082            let twos = [BF8_TWO; 1024];
1083            let mut res = [[0.0_f32; 16]; 16];
1084            let mut config = __tilecfg::default();
1085            config.palette = 1;
1086            (0..=2).for_each(|i| {
1087                config.colsb[i] = 64;
1088                config.rows[i] = 16;
1089            });
1090            _tile_loadconfig(config.as_ptr());
1091            _tile_zero::<0>();
1092            _tile_loadd::<1>(&ones as *const u8, 64);
1093            _tile_loadd::<2>(&twos as *const u8, 64);
1094            _tile_dphbf8ps::<0, 1, 2>();
1095            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1096            _tile_release();
1097            assert_eq!(res, [[128.0_f32; 16]; 16]);
1098        }
1099    }
1100
1101    #[simd_test(enable = "amx-fp8")]
1102    fn test_tile_dphf8ps() {
1103        unsafe {
1104            _init_amx();
1105            let ones = [HF8_ONE; 1024];
1106            let twos = [HF8_TWO; 1024];
1107            let mut res = [[0.0_f32; 16]; 16];
1108            let mut config = __tilecfg::default();
1109            config.palette = 1;
1110            (0..=2).for_each(|i| {
1111                config.colsb[i] = 64;
1112                config.rows[i] = 16;
1113            });
1114            _tile_loadconfig(config.as_ptr());
1115            _tile_zero::<0>();
1116            _tile_loadd::<1>(&ones as *const u8, 64);
1117            _tile_loadd::<2>(&twos as *const u8, 64);
1118            _tile_dphf8ps::<0, 1, 2>();
1119            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1120            _tile_release();
1121            assert_eq!(res, [[128.0_f32; 16]; 16]);
1122        }
1123    }
1124
1125    #[simd_test(enable = "amx-movrs")]
1126    fn test_tile_loaddrs() {
1127        unsafe {
1128            _init_amx();
1129            let mut config = __tilecfg::default();
1130            config.palette = 1;
1131            config.colsb[0] = 64;
1132            config.rows[0] = 16;
1133            _tile_loadconfig(config.as_ptr());
1134            _tile_zero::<0>();
1135            let mat = [1_i8; 1024];
1136            _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1137            let mut out = [[0_i8; 64]; 16];
1138            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1139            _tile_release();
1140            assert_eq!(out, [[1; 64]; 16]);
1141        }
1142    }
1143
1144    #[simd_test(enable = "amx-movrs")]
1145    fn test_tile_stream_loaddrs() {
1146        unsafe {
1147            _init_amx();
1148            let mut config = __tilecfg::default();
1149            config.palette = 1;
1150            config.colsb[0] = 64;
1151            config.rows[0] = 16;
1152            _tile_loadconfig(config.as_ptr());
1153            _tile_zero::<0>();
1154            let mat = [1_i8; 1024];
1155            _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1156            let mut out = [[0_i8; 64]; 16];
1157            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1158            _tile_release();
1159            assert_eq!(out, [[1; 64]; 16]);
1160        }
1161    }
1162
1163    #[simd_test(enable = "amx-avx512,avx10.2")]
1164    fn test_tile_movrow() {
1165        unsafe {
1166            _init_amx();
1167            let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1168
1169            let mut config = __tilecfg::default();
1170            config.palette = 1;
1171            config.colsb[0] = 64;
1172            config.rows[0] = 16;
1173            _tile_loadconfig(config.as_ptr());
1174            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1175            for i in 0..16 {
1176                let row = _tile_movrow::<0>(i);
1177                assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1178            }
1179        }
1180    }
1181
1182    macro_rules! wrap_imm4 {
1183        ($name:ident :: <$TILE:literal>, $row:expr) => {
1184            match $row {
1185                0 => $name::<$TILE, 0>(),
1186                1 => $name::<$TILE, 1>(),
1187                2 => $name::<$TILE, 2>(),
1188                3 => $name::<$TILE, 3>(),
1189                4 => $name::<$TILE, 4>(),
1190                5 => $name::<$TILE, 5>(),
1191                6 => $name::<$TILE, 6>(),
1192                7 => $name::<$TILE, 7>(),
1193                8 => $name::<$TILE, 8>(),
1194                9 => $name::<$TILE, 9>(),
1195                10 => $name::<$TILE, 10>(),
1196                11 => $name::<$TILE, 11>(),
1197                12 => $name::<$TILE, 12>(),
1198                13 => $name::<$TILE, 13>(),
1199                14 => $name::<$TILE, 14>(),
1200                15 => $name::<$TILE, 15>(),
1201                _ => panic!("row index out of range"),
1202            }
1203        };
1204    }
1205
1206    #[simd_test(enable = "amx-avx512,avx10.2")]
1207    fn test_tile_movrowi() {
1208        unsafe {
1209            _init_amx();
1210            let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1211
1212            let mut config = __tilecfg::default();
1213            config.palette = 1;
1214            config.colsb[0] = 64;
1215            config.rows[0] = 16;
1216            _tile_loadconfig(config.as_ptr());
1217            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1218
1219            for i in 0..16 {
1220                let row = wrap_imm4!(_tile_movrowi::<0>, i);
1221                assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1222            }
1223        }
1224    }
1225
1226    #[simd_test(enable = "amx-avx512,avx10.2")]
1227    fn test_tile_cvtrowd2ps() {
1228        unsafe {
1229            _init_amx();
1230            let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1231
1232            let mut config = __tilecfg::default();
1233            config.palette = 1;
1234            config.colsb[0] = 64;
1235            config.rows[0] = 16;
1236            _tile_loadconfig(config.as_ptr());
1237            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1238            for i in 0..16 {
1239                let row = _tile_cvtrowd2ps::<0>(i);
1240                assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1241            }
1242        }
1243    }
1244
1245    #[simd_test(enable = "amx-avx512,avx10.2")]
1246    fn test_tile_cvtrowd2psi() {
1247        unsafe {
1248            _init_amx();
1249            let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1250
1251            let mut config = __tilecfg::default();
1252            config.palette = 1;
1253            config.colsb[0] = 64;
1254            config.rows[0] = 16;
1255            _tile_loadconfig(config.as_ptr());
1256            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1257
1258            for i in 0..16 {
1259                let row = wrap_imm4!(_tile_cvtrowd2psi::<0>, i);
1260                assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1261            }
1262        }
1263    }
1264
1265    #[simd_test(enable = "amx-avx512,avx10.2")]
1266    fn test_tile_cvtrowps2phh() {
1267        unsafe {
1268            _init_amx();
1269            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1270
1271            let mut config = __tilecfg::default();
1272            config.palette = 1;
1273            config.colsb[0] = 64;
1274            config.rows[0] = 16;
1275            _tile_loadconfig(config.as_ptr());
1276            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1277            for i in 0..16 {
1278                let row = _tile_cvtrowps2phh::<0>(i);
1279                assert_eq!(
1280                    *row.as_f16x32().as_array(),
1281                    array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1282                );
1283            }
1284        }
1285    }
1286
1287    #[simd_test(enable = "amx-avx512,avx10.2")]
1288    fn test_tile_cvtrowps2phhi() {
1289        unsafe {
1290            _init_amx();
1291            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1292
1293            let mut config = __tilecfg::default();
1294            config.palette = 1;
1295            config.colsb[0] = 64;
1296            config.rows[0] = 16;
1297            _tile_loadconfig(config.as_ptr());
1298            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1299            for i in 0..16 {
1300                let row = wrap_imm4!(_tile_cvtrowps2phhi::<0>, i);
1301                assert_eq!(
1302                    *row.as_f16x32().as_array(),
1303                    array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1304                );
1305            }
1306        }
1307    }
1308
1309    #[simd_test(enable = "amx-avx512,avx10.2")]
1310    fn test_tile_cvtrowps2phl() {
1311        unsafe {
1312            _init_amx();
1313            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1314
1315            let mut config = __tilecfg::default();
1316            config.palette = 1;
1317            config.colsb[0] = 64;
1318            config.rows[0] = 16;
1319            _tile_loadconfig(config.as_ptr());
1320            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1321            for i in 0..16 {
1322                let row = _tile_cvtrowps2phl::<0>(i);
1323                assert_eq!(
1324                    *row.as_f16x32().as_array(),
1325                    array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1326                );
1327            }
1328        }
1329    }
1330
1331    #[simd_test(enable = "amx-avx512,avx10.2")]
1332    fn test_tile_cvtrowps2phli() {
1333        unsafe {
1334            _init_amx();
1335            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1336
1337            let mut config = __tilecfg::default();
1338            config.palette = 1;
1339            config.colsb[0] = 64;
1340            config.rows[0] = 16;
1341            _tile_loadconfig(config.as_ptr());
1342            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1343            for i in 0..16 {
1344                let row = wrap_imm4!(_tile_cvtrowps2phli::<0>, i);
1345                assert_eq!(
1346                    *row.as_f16x32().as_array(),
1347                    array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1348                );
1349            }
1350        }
1351    }
1352
1353    #[simd_test(enable = "amx-avx512,avx10.2")]
1354    fn test_tile_cvtrowps2bf16h() {
1355        unsafe {
1356            _init_amx();
1357            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1358
1359            let mut config = __tilecfg::default();
1360            config.palette = 1;
1361            config.colsb[0] = 64;
1362            config.rows[0] = 16;
1363            _tile_loadconfig(config.as_ptr());
1364            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1365            for i in 0..16 {
1366                let row = _tile_cvtrowps2bf16h::<0>(i);
1367                assert_eq!(
1368                    *row.as_u16x32().as_array(),
1369                    array::from_fn(|j| if j & 1 == 0 {
1370                        0
1371                    } else {
1372                        _mm_cvtness_sbh(i as _).to_bits()
1373                    })
1374                );
1375            }
1376        }
1377    }
1378
1379    #[simd_test(enable = "amx-avx512,avx10.2")]
1380    fn test_tile_cvtrowps2bf16hi() {
1381        unsafe {
1382            _init_amx();
1383            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1384
1385            let mut config = __tilecfg::default();
1386            config.palette = 1;
1387            config.colsb[0] = 64;
1388            config.rows[0] = 16;
1389            _tile_loadconfig(config.as_ptr());
1390            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1391            for i in 0..16 {
1392                let row = wrap_imm4!(_tile_cvtrowps2bf16hi::<0>, i);
1393                assert_eq!(
1394                    *row.as_u16x32().as_array(),
1395                    array::from_fn(|j| if j & 1 == 0 {
1396                        0
1397                    } else {
1398                        _mm_cvtness_sbh(i as _).to_bits()
1399                    })
1400                );
1401            }
1402        }
1403    }
1404
1405    #[simd_test(enable = "amx-avx512,avx10.2")]
1406    fn test_tile_cvtrowps2bf16l() {
1407        unsafe {
1408            _init_amx();
1409            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1410
1411            let mut config = __tilecfg::default();
1412            config.palette = 1;
1413            config.colsb[0] = 64;
1414            config.rows[0] = 16;
1415            _tile_loadconfig(config.as_ptr());
1416            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1417            for i in 0..16 {
1418                let row = _tile_cvtrowps2bf16l::<0>(i);
1419                assert_eq!(
1420                    *row.as_u16x32().as_array(),
1421                    array::from_fn(|j| if j & 1 == 0 {
1422                        _mm_cvtness_sbh(i as _).to_bits()
1423                    } else {
1424                        0
1425                    })
1426                );
1427            }
1428        }
1429    }
1430
1431    #[simd_test(enable = "amx-avx512,avx10.2")]
1432    fn test_tile_cvtrowps2bf16li() {
1433        unsafe {
1434            _init_amx();
1435            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1436
1437            let mut config = __tilecfg::default();
1438            config.palette = 1;
1439            config.colsb[0] = 64;
1440            config.rows[0] = 16;
1441            _tile_loadconfig(config.as_ptr());
1442            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1443            for i in 0..16 {
1444                let row = wrap_imm4!(_tile_cvtrowps2bf16li::<0>, i);
1445                assert_eq!(
1446                    *row.as_u16x32().as_array(),
1447                    array::from_fn(|j| if j & 1 == 0 {
1448                        _mm_cvtness_sbh(i as _).to_bits()
1449                    } else {
1450                        0
1451                    })
1452                );
1453            }
1454        }
1455    }
1456
1457    #[simd_test(enable = "amx-tf32")]
1458    fn test_tile_mmultf32ps() {
1459        unsafe {
1460            _init_amx();
1461            let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1462            let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1463            let mut res = [[0.0; 16]; 16];
1464
1465            let mut config = __tilecfg::default();
1466            config.palette = 1;
1467            (0..=2).for_each(|i| {
1468                config.colsb[i] = 64;
1469                config.rows[i] = 16;
1470            });
1471            _tile_loadconfig(config.as_ptr());
1472            _tile_zero::<0>();
1473            _tile_loadd::<1>(a.as_ptr().cast(), 64);
1474            _tile_loadd::<2>(b.as_ptr().cast(), 64);
1475            _tile_mmultf32ps::<0, 1, 2>();
1476            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1477            _tile_release();
1478
1479            let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1480            assert_eq!(res, expected);
1481        }
1482    }
1483}