aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Headers/amxcomplextransposeintrin.h
blob: 11abaf98e93719d2e71e90d4399054a70b1c9504 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
/*===----- amxcomplextransposeintrin.h - AMX-COMPLEX and AMX-TRANSPOSE ------===
 *
 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 * See https://llvm.org/LICENSE.txt for license information.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 *
 *===------------------------------------------------------------------------===
 */

#ifndef __IMMINTRIN_H
#error                                                                         \
    "Never use <amxcomplextransposeintrin.h> directly; include <immintrin.h> instead."
#endif // __IMMINTRIN_H

#ifndef __AMX_COMPLEXTRANSPOSEINTRIN_H
#define __AMX_COMPLEXTRANSPOSEINTRIN_H
#ifdef __x86_64__

// Attribute set applied to every function defined in this header: force
// inlining, suppress debug info for the wrapper itself, and enable the
// "amx-complex" and "amx-transpose" target features for the function body.
// The macro is private to this header and is #undef'd before the closing
// include guard.
#define __DEFAULT_FN_ATTRS                                                     \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("amx-complex,amx-transpose")))

/// Perform matrix multiplication of two tiles containing complex elements and
///    accumulate the results into a packed single precision tile. Each dword
///    element in input tiles \a a and \a b is interpreted as a complex number
///    with FP16 real part and FP16 imaginary part.
/// Calculates the imaginary part of the result. For each possible combination
///    of (transposed column of \a a, column of \a b), it performs a set of
///    multiplication and accumulations on all corresponding complex numbers
///    (one from \a a and one from \a b). The imaginary part of the \a a element
///    is multiplied with the real part of the corresponding \a b element, and
///    the real part of the \a a element is multiplied with the imaginary part
///    of the corresponding \a b element. The two accumulated results are
///    added, and then accumulated into the corresponding row and column of
///    \a dst.
///
/// The macro forwards \a dst, \a a and \a b, each parenthesized, to
///    \c __builtin_ia32_ttcmmimfp16ps.
///
/// \headerfile <x86intrin.h>
///
/// \code
/// void _tile_tcmmimfp16ps(__tile dst, __tile a, __tile b);
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
///	tmp := dst.row[m]
///	FOR k := 0 TO a.rows - 1
///		FOR n := 0 TO (dst.colsb / 4) - 1
///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1])
///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0])
///		ENDFOR
///	ENDFOR
///	write_row_and_zero(dst, m, tmp, dst.colsb)
/// ENDFOR
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TTCMMIMFP16PS instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param a
///    The 1st source tile. Max size is 1024 Bytes.
/// \param b
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_tcmmimfp16ps(dst, a, b)                                          \
  __builtin_ia32_ttcmmimfp16ps((dst), (a), (b))

/// Perform matrix multiplication of two tiles containing complex elements and
///    accumulate the results into a packed single precision tile. Each dword
///    element in input tiles \a a and \a b is interpreted as a complex number
///    with FP16 real part and FP16 imaginary part.
/// Calculates the real part of the result. For each possible combination
///    of (transposed column of \a a, column of \a b), it performs a set of
///    multiplication and accumulations on all corresponding complex numbers
///    (one from \a a and one from \a b). The real part of the \a a element is
///    multiplied with the real part of the corresponding \a b element, and the
///    negated imaginary part of the \a a element is multiplied with the
///    imaginary part of the corresponding \a b element. The two accumulated
///    results are added, and then accumulated into the corresponding row and
///    column of \a dst.
///
/// The macro forwards \a dst, \a a and \a b, each parenthesized, to
///    \c __builtin_ia32_ttcmmrlfp16ps.
///
/// \headerfile <x86intrin.h>
///
/// \code
/// void _tile_tcmmrlfp16ps(__tile dst, __tile a, __tile b);
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
///	tmp := dst.row[m]
///	FOR k := 0 TO a.rows - 1
///		FOR n := 0 TO (dst.colsb / 4) - 1
///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+0])
///			tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+1])
///		ENDFOR
///	ENDFOR
///	write_row_and_zero(dst, m, tmp, dst.colsb)
/// ENDFOR
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TTCMMRLFP16PS instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param a
///    The 1st source tile. Max size is 1024 Bytes.
/// \param b
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_tcmmrlfp16ps(dst, a, b)                                          \
  __builtin_ia32_ttcmmrlfp16ps((dst), (a), (b))

/// Perform matrix conjugate transpose and multiplication of two tiles
///    containing complex elements and accumulate the results into a packed
///    single precision tile. Each dword element in input tiles \a a and \a b
///    is interpreted as a complex number with FP16 real part and FP16 imaginary
///    part.
/// Calculates the imaginary part of the result. For each possible combination
///    of (transposed column of \a a, column of \a b), it performs a set of
///    multiplication and accumulations on all corresponding complex numbers
///    (one from \a a and one from \a b). The negated imaginary part of the \a a
///    element is multiplied with the real part of the corresponding \a b
///    element, and the real part of the \a a element is multiplied with the
///    imaginary part of the corresponding \a b element. The two accumulated
///    results are added, and then accumulated into the corresponding row and
///    column of \a dst.
///
/// The macro forwards \a dst, \a a and \a b, each parenthesized, to
///    \c __builtin_ia32_tconjtcmmimfp16ps.
///
/// \headerfile <x86intrin.h>
///
/// \code
/// void _tile_conjtcmmimfp16ps(__tile dst, __tile a, __tile b);
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
///	tmp := dst.row[m]
///	FOR k := 0 TO a.rows - 1
///		FOR n := 0 TO (dst.colsb / 4) - 1
///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1])
///			tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0])
///		ENDFOR
///	ENDFOR
///	write_row_and_zero(dst, m, tmp, dst.colsb)
/// ENDFOR
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TCONJTCMMIMFP16PS instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param a
///    The 1st source tile. Max size is 1024 Bytes.
/// \param b
///    The 2nd source tile. Max size is 1024 Bytes.
#define _tile_conjtcmmimfp16ps(dst, a, b)                                      \
  __builtin_ia32_tconjtcmmimfp16ps((dst), (a), (b))

/// Perform conjugate transpose of an FP16-pair of complex elements from \a a
///    and write the result to \a dst. As the operation below shows, the real
///    half of each pair is copied and the imaginary half is negated.
///
/// The macro forwards \a dst and \a a, each parenthesized, to
///    \c __builtin_ia32_tconjtfp16.
///
/// \headerfile <x86intrin.h>
///
/// \code
/// void _tile_conjtfp16(__tile dst, __tile a);
/// \endcode
///
/// \code{.operation}
/// FOR i := 0 TO dst.rows - 1
///	FOR j := 0 TO (dst.colsb / 4) - 1
///		tmp.fp16[2*j+0] := a.row[j].fp16[2*i+0]
///		tmp.fp16[2*j+1] := -a.row[j].fp16[2*i+1]
///	ENDFOR
///	write_row_and_zero(dst, i, tmp, dst.colsb)
/// ENDFOR
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TCONJTFP16 instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param a
///    The source tile. Max size is 1024 Bytes.
#define _tile_conjtfp16(dst, a) __builtin_ia32_tconjtfp16((dst), (a))

/// Internal helper used by __tile_tcmmimfp16ps. Forwards the three shape
///    values and the three tile values, unchanged and in the same order, to
///    \c __builtin_ia32_ttcmmimfp16ps_internal and returns the tile it yields.
///
/// \param m
///    1st shape value; __tile_tcmmimfp16ps passes \c src0.row.
/// \param n
///    2nd shape value; __tile_tcmmimfp16ps passes \c src1.col.
/// \param k
///    3rd shape value; __tile_tcmmimfp16ps passes \c src0.col.
/// \param dst
///    Incoming destination tile value (the caller passes \c dst->tile).
/// \param src1
///    1st source tile value (the caller passes \c src0.tile).
/// \param src2
///    2nd source tile value (the caller passes \c src1.tile).
/// \returns The tile value produced by the builtin.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmimfp16ps_internal(
    unsigned short m, unsigned short n, unsigned short k, _tile1024i dst,
    _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_ttcmmimfp16ps_internal(m, n, k, dst, src1, src2);
}

/// Internal helper used by __tile_tcmmrlfp16ps. Forwards the three shape
///    values and the three tile values, unchanged and in the same order, to
///    \c __builtin_ia32_ttcmmrlfp16ps_internal and returns the tile it yields.
///
/// \param m
///    1st shape value; __tile_tcmmrlfp16ps passes \c src0.row.
/// \param n
///    2nd shape value; __tile_tcmmrlfp16ps passes \c src1.col.
/// \param k
///    3rd shape value; __tile_tcmmrlfp16ps passes \c src0.col.
/// \param dst
///    Incoming destination tile value (the caller passes \c dst->tile).
/// \param src1
///    1st source tile value (the caller passes \c src0.tile).
/// \param src2
///    2nd source tile value (the caller passes \c src1.tile).
/// \returns The tile value produced by the builtin.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmrlfp16ps_internal(
    unsigned short m, unsigned short n, unsigned short k, _tile1024i dst,
    _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_ttcmmrlfp16ps_internal(m, n, k, dst, src1, src2);
}

/// Internal helper used by __tile_conjtcmmimfp16ps. Forwards the three shape
///    values and the three tile values, unchanged and in the same order, to
///    \c __builtin_ia32_tconjtcmmimfp16ps_internal and returns the tile it
///    yields.
///
/// \param m
///    1st shape value; __tile_conjtcmmimfp16ps passes \c src0.row.
/// \param n
///    2nd shape value; __tile_conjtcmmimfp16ps passes \c src1.col.
/// \param k
///    3rd shape value; __tile_conjtcmmimfp16ps passes \c src0.col.
/// \param dst
///    Incoming destination tile value (the caller passes \c dst->tile).
/// \param src1
///    1st source tile value (the caller passes \c src0.tile).
/// \param src2
///    2nd source tile value (the caller passes \c src1.tile).
/// \returns The tile value produced by the builtin.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtcmmimfp16ps_internal(
    unsigned short m, unsigned short n, unsigned short k, _tile1024i dst,
    _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tconjtcmmimfp16ps_internal(m, n, k, dst, src1, src2);
}

/// Internal helper used by __tile_conjtfp16. Forwards the two shape values and
///    the source tile value, unchanged and in the same order, to
///    \c __builtin_ia32_tconjtfp16_internal and returns the tile it yields.
///
/// \param m
///    1st shape value; __tile_conjtfp16 passes \c src.row.
/// \param n
///    2nd shape value; __tile_conjtfp16 passes \c src.col.
/// \param src
///    Source tile value (the caller passes \c src.tile).
/// \returns The tile value produced by the builtin.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS
_tile_conjtfp16_internal(unsigned short m, unsigned short n, _tile1024i src) {
  return __builtin_ia32_tconjtfp16_internal(m, n, src);
}

/// Perform matrix multiplication of two tiles containing complex elements and
///    accumulate the results into a packed single precision tile. Each dword
///    element in input tiles \a src0 and \a src1 is interpreted as a complex
///    number with FP16 real part and FP16 imaginary part.
///    This function calculates the imaginary part of the result.
///
///    The shape values handed to the underlying helper are \c src0.row,
///    \c src1.col and \c src0.col, in that order. The current contents of
///    \c dst->tile are passed in as the destination operand, and \c dst->tile
///    is then overwritten with the returned tile.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TTCMMIMFP16PS </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
__DEFAULT_FN_ATTRS
static void __tile_tcmmimfp16ps(__tile1024i *dst, __tile1024i src0,
                                __tile1024i src1) {
  dst->tile = _tile_tcmmimfp16ps_internal(src0.row, src1.col, src0.col,
                                          dst->tile, src0.tile, src1.tile);
}

/// Perform matrix multiplication of two tiles containing complex elements and
///    accumulate the results into a packed single precision tile. Each dword
///    element in input tiles \a src0 and \a src1 is interpreted as a complex
///    number with FP16 real part and FP16 imaginary part.
///    This function calculates the real part of the result.
///
///    The shape values handed to the underlying helper are \c src0.row,
///    \c src1.col and \c src0.col, in that order. The current contents of
///    \c dst->tile are passed in as the destination operand, and \c dst->tile
///    is then overwritten with the returned tile.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TTCMMRLFP16PS </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
__DEFAULT_FN_ATTRS
static void __tile_tcmmrlfp16ps(__tile1024i *dst, __tile1024i src0,
                                __tile1024i src1) {
  dst->tile = _tile_tcmmrlfp16ps_internal(src0.row, src1.col, src0.col,
                                          dst->tile, src0.tile, src1.tile);
}

/// Perform matrix conjugate transpose and multiplication of two tiles
///    containing complex elements and accumulate the results into a packed
///    single precision tile. Each dword element in input tiles \a src0 and
///    \a src1 is interpreted as a complex number with FP16 real part and FP16
///    imaginary part.
///    This function calculates the imaginary part of the result.
///
///    The shape values handed to the underlying helper are \c src0.row,
///    \c src1.col and \c src0.col, in that order. The current contents of
///    \c dst->tile are passed in as the destination operand, and \c dst->tile
///    is then overwritten with the returned tile.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TCONJTCMMIMFP16PS </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src0
///    The 1st source tile. Max size is 1024 Bytes.
/// \param src1
///    The 2nd source tile. Max size is 1024 Bytes.
__DEFAULT_FN_ATTRS
static void __tile_conjtcmmimfp16ps(__tile1024i *dst, __tile1024i src0,
                                    __tile1024i src1) {
  dst->tile = _tile_conjtcmmimfp16ps_internal(src0.row, src1.col, src0.col,
                                              dst->tile, src0.tile, src1.tile);
}

/// Perform conjugate transpose of an FP16-pair of complex elements from
///    \a src and write the result to \a dst.
///
///    The shape values handed to the underlying helper are \c src.row and
///    \c src.col, in that order. Only \c dst->tile is written; the previous
///    contents of \a dst are not read by this function.
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TCONJTFP16 </c> instruction.
///
/// \param dst
///    The destination tile. Max size is 1024 Bytes.
/// \param src
///    The source tile. Max size is 1024 Bytes.
__DEFAULT_FN_ATTRS
static void __tile_conjtfp16(__tile1024i *dst, __tile1024i src) {
  dst->tile = _tile_conjtfp16_internal(src.row, src.col, src.tile);
}

#undef __DEFAULT_FN_ATTRS

#endif // __x86_64__
#endif // __AMX_COMPLEXTRANSPOSEINTRIN_H