1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
6/// Load tile configuration from a 64-byte memory location specified by mem_addr.
7/// The tile configuration format is specified below, and includes the tile type pallette,
8/// the number of bytes per row, and the number of rows. If the specified pallette_id is zero,
9/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
10/// Any invalid configurations will result in #GP fault.
11///
12/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
13#[inline]
14#[target_feature(enable = "amx-tile")]
15#[cfg_attr(test, assert_instr(ldtilecfg))]
16#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
17pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
18 ldtilecfg(mem_addr);
19}
20
21/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr.
22/// The tile configuration format is specified below, and includes the tile type pallette,
23/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
24///
25/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
26#[inline]
27#[target_feature(enable = "amx-tile")]
28#[cfg_attr(test, assert_instr(sttilecfg))]
29#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
30pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
31 sttilecfg(mem_addr);
32}
33
34/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig.
35///
36/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
37#[inline]
38#[rustc_legacy_const_generics(0)]
39#[target_feature(enable = "amx-tile")]
40#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
41#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
42pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
43 static_assert_uimm_bits!(DST, 3);
44 tileloadd64(DST as i8, base, stride);
45}
46
47/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
48///
49/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
50#[inline]
51#[target_feature(enable = "amx-tile")]
52#[cfg_attr(test, assert_instr(tilerelease))]
53#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
54pub unsafe fn _tile_release() {
55 tilerelease();
56}
57
58/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig.
59///
60/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
61#[inline]
62#[rustc_legacy_const_generics(0)]
63#[target_feature(enable = "amx-tile")]
64#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
65#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
66pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
67 static_assert_uimm_bits!(DST, 3);
68 tilestored64(DST as i8, base, stride);
69}
70
71/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration
72/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will
73/// likely not be reused in the near future and the data caching can be optimized accordingly.
74///
75/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
76#[inline]
77#[rustc_legacy_const_generics(0)]
78#[target_feature(enable = "amx-tile")]
79#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
80#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
81pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
82 static_assert_uimm_bits!(DST, 3);
83 tileloaddt164(DST as i8, base, stride);
84}
85
86/// Zero the tile specified by tdest.
87///
88/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
89#[inline]
90#[rustc_legacy_const_generics(0)]
91#[target_feature(enable = "amx-tile")]
92#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
93#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
94pub unsafe fn _tile_zero<const DST: i32>() {
95 static_assert_uimm_bits!(DST, 3);
96 tilezero(DST as i8);
97}
98
99/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
100/// accumulating the intermediate single-precision (32-bit) floating-point elements
101/// with elements in dst, and store the 32-bit result back to tile dst.
102///
103/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
104#[inline]
105#[rustc_legacy_const_generics(0, 1, 2)]
106#[target_feature(enable = "amx-bf16")]
107#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
108#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
109pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
110 static_assert_uimm_bits!(DST, 3);
111 static_assert_uimm_bits!(A, 3);
112 static_assert_uimm_bits!(B, 3);
113 tdpbf16ps(DST as i8, A as i8, B as i8);
114}
115
116/// Compute dot-product of bytes in tiles with a source/destination accumulator.
117/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
118/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
119/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
120///
121/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
122#[inline]
123#[rustc_legacy_const_generics(0, 1, 2)]
124#[target_feature(enable = "amx-int8")]
125#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
126#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
127pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
128 static_assert_uimm_bits!(DST, 3);
129 static_assert_uimm_bits!(A, 3);
130 static_assert_uimm_bits!(B, 3);
131 tdpbssd(DST as i8, A as i8, B as i8);
132}
133
134/// Compute dot-product of bytes in tiles with a source/destination accumulator.
135/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
136/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
137/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
138///
139/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
140#[inline]
141#[rustc_legacy_const_generics(0, 1, 2)]
142#[target_feature(enable = "amx-int8")]
143#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
144#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
145pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
146 static_assert_uimm_bits!(DST, 3);
147 static_assert_uimm_bits!(A, 3);
148 static_assert_uimm_bits!(B, 3);
149 tdpbsud(DST as i8, A as i8, B as i8);
150}
151
152/// Compute dot-product of bytes in tiles with a source/destination accumulator.
153/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
154/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
155/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
156///
157/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
158#[inline]
159#[rustc_legacy_const_generics(0, 1, 2)]
160#[target_feature(enable = "amx-int8")]
161#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
162#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
163pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
164 static_assert_uimm_bits!(DST, 3);
165 static_assert_uimm_bits!(A, 3);
166 static_assert_uimm_bits!(B, 3);
167 tdpbusd(DST as i8, A as i8, B as i8);
168}
169
170/// Compute dot-product of bytes in tiles with a source/destination accumulator.
171/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
172/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
173/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
174///
175/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
176#[inline]
177#[rustc_legacy_const_generics(0, 1, 2)]
178#[target_feature(enable = "amx-int8")]
179#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
180#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
181pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
182 static_assert_uimm_bits!(DST, 3);
183 static_assert_uimm_bits!(A, 3);
184 static_assert_uimm_bits!(B, 3);
185 tdpbuud(DST as i8, A as i8, B as i8);
186}
187
188/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
189/// accumulating the intermediate single-precision (32-bit) floating-point elements
190/// with elements in dst, and store the 32-bit result back to tile dst.
191///
192/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
193#[inline]
194#[rustc_legacy_const_generics(0, 1, 2)]
195#[target_feature(enable = "amx-fp16")]
196#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
197#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
198pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
199 static_assert_uimm_bits!(DST, 3);
200 static_assert_uimm_bits!(A, 3);
201 static_assert_uimm_bits!(B, 3);
202 tdpfp16ps(DST as i8, A as i8, B as i8);
203}
204
205/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
206/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
207/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
208/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
209/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
210/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
211/// and then accumulated into the corresponding row and column of dst.
212///
213/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
214#[inline]
215#[rustc_legacy_const_generics(0, 1, 2)]
216#[target_feature(enable = "amx-complex")]
217#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
218#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
219pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
220 static_assert_uimm_bits!(DST, 3);
221 static_assert_uimm_bits!(A, 3);
222 static_assert_uimm_bits!(B, 3);
223 tcmmimfp16ps(DST as i8, A as i8, B as i8);
224}
225
226/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
227/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
228/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
229/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
230/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
231/// the a element is multiplied with the imaginary part of the corresponding b elements.
232/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
233///
234/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
235#[inline]
236#[rustc_legacy_const_generics(0, 1, 2)]
237#[target_feature(enable = "amx-complex")]
238#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
239#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
240pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
241 static_assert_uimm_bits!(DST, 3);
242 static_assert_uimm_bits!(A, 3);
243 static_assert_uimm_bits!(B, 3);
244 tcmmrlfp16ps(DST as i8, A as i8, B as i8);
245}
246
247/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
248/// floating-point elements in tile b, accumulating the intermediate single-precision
249/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
250/// back to tile dst.
251#[inline]
252#[rustc_legacy_const_generics(0, 1, 2)]
253#[target_feature(enable = "amx-fp8")]
254#[cfg_attr(
255 all(test, any(target_os = "linux", target_env = "msvc")),
256 assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
257)]
258#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
259pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
260 static_assert_uimm_bits!(DST, 3);
261 static_assert_uimm_bits!(A, 3);
262 static_assert_uimm_bits!(B, 3);
263 tdpbf8ps(DST as i8, A as i8, B as i8);
264}
265
266/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
267/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
268/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
269/// back to tile dst.
270#[inline]
271#[rustc_legacy_const_generics(0, 1, 2)]
272#[target_feature(enable = "amx-fp8")]
273#[cfg_attr(
274 all(test, any(target_os = "linux", target_env = "msvc")),
275 assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
276)]
277#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
278pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
279 static_assert_uimm_bits!(DST, 3);
280 static_assert_uimm_bits!(A, 3);
281 static_assert_uimm_bits!(B, 3);
282 tdpbhf8ps(DST as i8, A as i8, B as i8);
283}
284
285/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
286/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
287/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
288/// back to tile dst.
289#[inline]
290#[rustc_legacy_const_generics(0, 1, 2)]
291#[target_feature(enable = "amx-fp8")]
292#[cfg_attr(
293 all(test, any(target_os = "linux", target_env = "msvc")),
294 assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
295)]
296#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
297pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
298 static_assert_uimm_bits!(DST, 3);
299 static_assert_uimm_bits!(A, 3);
300 static_assert_uimm_bits!(B, 3);
301 tdphbf8ps(DST as i8, A as i8, B as i8);
302}
303
304/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
305/// floating-point elements in tile b, accumulating the intermediate single-precision
306/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
307/// back to tile dst.
308#[inline]
309#[rustc_legacy_const_generics(0, 1, 2)]
310#[target_feature(enable = "amx-fp8")]
311#[cfg_attr(
312 all(test, any(target_os = "linux", target_env = "msvc")),
313 assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
314)]
315#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
316pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
317 static_assert_uimm_bits!(DST, 3);
318 static_assert_uimm_bits!(A, 3);
319 static_assert_uimm_bits!(B, 3);
320 tdphf8ps(DST as i8, A as i8, B as i8);
321}
322
323/// Load tile rows from memory specified by base address and stride into destination tile dst
324/// using the tile configuration previously configured via _tile_loadconfig.
325/// Additionally, this intrinsic indicates the source memory location is likely to become
326/// read-shared by multiple processors, i.e., read in the future by at least one other processor
327/// before it is written, assuming it is ever written in the future.
328#[inline]
329#[rustc_legacy_const_generics(0)]
330#[target_feature(enable = "amx-movrs")]
331#[cfg_attr(
332 all(test, any(target_os = "linux", target_env = "msvc")),
333 assert_instr(tileloaddrs, DST = 0)
334)]
335#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
336pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
337 static_assert_uimm_bits!(DST, 3);
338 tileloaddrs64(DST as i8, base, stride);
339}
340
341/// Load tile rows from memory specified by base address and stride into destination tile dst
342/// using the tile configuration previously configured via _tile_loadconfig.
343/// Provides a hint to the implementation that the data would be reused but does not need
344/// to be resident in the nearest cache levels.
345/// Additionally, this intrinsic indicates the source memory location is likely to become
346/// read-shared by multiple processors, i.e., read in the future by at least one other processor
347/// before it is written, assuming it is ever written in the future.
348#[inline]
349#[rustc_legacy_const_generics(0)]
350#[target_feature(enable = "amx-movrs")]
351#[cfg_attr(
352 all(test, any(target_os = "linux", target_env = "msvc")),
353 assert_instr(tileloaddrst1, DST = 0)
354)]
355#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
356pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
357 static_assert_uimm_bits!(DST, 3);
358 tileloaddrst164(DST as i8, base, stride);
359}
360
361/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
362/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
363/// results into a packed single precision tile.
364/// For each possible combination of (row of a, column of b), it performs
365/// - convert to TF32
366/// - multiply the corresponding elements of a and b
367/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even
368/// rounding mode.
369/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
370/// handled and *not* treated as zero.
371#[inline]
372#[rustc_legacy_const_generics(0, 1, 2)]
373#[target_feature(enable = "amx-tf32")]
374#[cfg_attr(
375 all(test, any(target_os = "linux", target_env = "msvc")),
376 assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
377)]
378#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
379pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
380 static_assert_uimm_bits!(DST, 3);
381 static_assert_uimm_bits!(A, 3);
382 static_assert_uimm_bits!(B, 3);
383 tmmultf32ps(DST as i8, A as i8, B as i8);
384}
385
386/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
387/// elements to packed single-precision (32-bit) floating-point elements.
388#[inline]
389#[rustc_legacy_const_generics(0)]
390#[target_feature(enable = "amx-avx512,avx10.2")]
391#[cfg_attr(
392 all(test, any(target_os = "linux", target_env = "msvc")),
393 assert_instr(tcvtrowd2ps, TILE = 0)
394)]
395#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
396pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
397 static_assert_uimm_bits!(TILE, 3);
398 tcvtrowd2ps(TILE as i8, row).as_m512()
399}
400
401/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
402/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
403/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
404#[inline]
405#[rustc_legacy_const_generics(0)]
406#[target_feature(enable = "amx-avx512,avx10.2")]
407#[cfg_attr(
408 all(test, any(target_os = "linux", target_env = "msvc")),
409 assert_instr(tcvtrowps2phh, TILE = 0)
410)]
411#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
412pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
413 static_assert_uimm_bits!(TILE, 3);
414 tcvtrowps2phh(TILE as i8, row).as_m512h()
415}
416
417/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
418/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
419/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
420#[inline]
421#[rustc_legacy_const_generics(0)]
422#[target_feature(enable = "amx-avx512,avx10.2")]
423#[cfg_attr(
424 all(test, any(target_os = "linux", target_env = "msvc")),
425 assert_instr(tcvtrowps2phl, TILE = 0)
426)]
427#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
428pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
429 static_assert_uimm_bits!(TILE, 3);
430 tcvtrowps2phl(TILE as i8, row).as_m512h()
431}
432
433/// Moves one row of tile data into a zmm vector register
434#[inline]
435#[rustc_legacy_const_generics(0)]
436#[target_feature(enable = "amx-avx512,avx10.2")]
437#[cfg_attr(
438 all(test, any(target_os = "linux", target_env = "msvc")),
439 assert_instr(tilemovrow, TILE = 0)
440)]
441#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
442pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
443 static_assert_uimm_bits!(TILE, 3);
444 tilemovrow(TILE as i8, row).as_m512i()
445}
446
447#[allow(improper_ctypes)]
448unsafe extern "C" {
449 #[link_name = "llvm.x86.ldtilecfg"]
450 unsafefn ldtilecfg(mem_addr: *const u8);
451 #[link_name = "llvm.x86.sttilecfg"]
452 unsafefn sttilecfg(mem_addr: *mut u8);
453 #[link_name = "llvm.x86.tileloadd64"]
454 unsafefn tileloadd64(dst: i8, base: *const u8, stride: usize);
455 #[link_name = "llvm.x86.tileloaddt164"]
456 unsafefn tileloaddt164(dst: i8, base: *const u8, stride: usize);
457 #[link_name = "llvm.x86.tilerelease"]
458 unsafefn tilerelease();
459 #[link_name = "llvm.x86.tilestored64"]
460 unsafefn tilestored64(dst: i8, base: *mut u8, stride: usize);
461 #[link_name = "llvm.x86.tilezero"]
462 unsafefn tilezero(dst: i8);
463 #[link_name = "llvm.x86.tdpbf16ps"]
464 unsafefn tdpbf16ps(dst: i8, a: i8, b: i8);
465 #[link_name = "llvm.x86.tdpbuud"]
466 unsafefn tdpbuud(dst: i8, a: i8, b: i8);
467 #[link_name = "llvm.x86.tdpbusd"]
468 unsafefn tdpbusd(dst: i8, a: i8, b: i8);
469 #[link_name = "llvm.x86.tdpbsud"]
470 unsafefn tdpbsud(dst: i8, a: i8, b: i8);
471 #[link_name = "llvm.x86.tdpbssd"]
472 unsafefn tdpbssd(dst: i8, a: i8, b: i8);
473 #[link_name = "llvm.x86.tdpfp16ps"]
474 unsafefn tdpfp16ps(dst: i8, a: i8, b: i8);
475 #[link_name = "llvm.x86.tcmmimfp16ps"]
476 unsafefn tcmmimfp16ps(dst: i8, a: i8, b: i8);
477 #[link_name = "llvm.x86.tcmmrlfp16ps"]
478 unsafefn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
479 #[link_name = "llvm.x86.tdpbf8ps"]
480 unsafefn tdpbf8ps(dst: i8, a: i8, b: i8);
481 #[link_name = "llvm.x86.tdpbhf8ps"]
482 unsafefn tdpbhf8ps(dst: i8, a: i8, b: i8);
483 #[link_name = "llvm.x86.tdphbf8ps"]
484 unsafefn tdphbf8ps(dst: i8, a: i8, b: i8);
485 #[link_name = "llvm.x86.tdphf8ps"]
486 unsafefn tdphf8ps(dst: i8, a: i8, b: i8);
487 #[link_name = "llvm.x86.tileloaddrs64"]
488 unsafefn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
489 #[link_name = "llvm.x86.tileloaddrst164"]
490 unsafefn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
491 #[link_name = "llvm.x86.tmmultf32ps"]
492 unsafefn tmmultf32ps(dst: i8, a: i8, b: i8);
493 #[link_name = "llvm.x86.tcvtrowd2ps"]
494 unsafefn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
495 #[link_name = "llvm.x86.tcvtrowps2phh"]
496 unsafefn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
497 #[link_name = "llvm.x86.tcvtrowps2phl"]
498 unsafefn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
499 #[link_name = "llvm.x86.tilemovrow"]
500 unsafefn tilemovrow(tile: i8, row: u32) -> i32x16;
501}
502
503#[cfg(test)]
504mod tests {
505 use crate::core_arch::x86::_mm_cvtness_sbh;
506 use crate::core_arch::x86_64::*;
507 use core::{array, mem::transmute};
508 use stdarch_test::simd_test;
509 #[cfg(target_os = "linux")]
510 use syscalls::{Sysno, syscall};
511
512 #[allow(non_camel_case_types)]
513 #[repr(C, packed)]
514 #[derive(Copy, Clone, Default, Debug, PartialEq)]
515 struct __tilecfg {
516 /// 0 `or` 1
517 palette: u8,
518 start_row: u8,
519 /// reserved, must be zero
520 reserved_a0: [u8; 14],
521 /// number of bytes of one row in each tile
522 colsb: [u16; 8],
523 /// reserved, must be zero
524 reserved_b0: [u16; 8],
525 /// number of rows in each tile
526 rows: [u8; 8],
527 /// reserved, must be zero
528 reserved_c0: [u8; 8],
529 }
530
531 impl __tilecfg {
532 fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
533 Self {
534 palette,
535 start_row,
536 reserved_a0: [0u8; 14],
537 colsb,
538 reserved_b0: [0u16; 8],
539 rows,
540 reserved_c0: [0u8; 8],
541 }
542 }
543
544 const fn as_ptr(&self) -> *const u8 {
545 self as *const Self as *const u8
546 }
547
548 fn as_mut_ptr(&mut self) -> *mut u8 {
549 self as *mut Self as *mut u8
550 }
551 }
552
553 #[cfg(not(target_os = "linux"))]
554 #[target_feature(enable = "amx-tile")]
555 fn _init_amx() {}
556
557 #[cfg(target_os = "linux")]
558 #[target_feature(enable = "amx-tile")]
559 #[inline]
560 unsafe fn _init_amx() {
561 let mut ret: usize;
562 let mut xfeatures: usize = 0;
563 ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
564 .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
565 if ret != 0 {
566 panic!("Failed to get XFEATURES");
567 } else {
568 match 0b11 & (xfeatures >> 17) {
569 0 => panic!("AMX is not available"),
570 1 => {
571 ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
572 .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
573 if ret != 0 {
574 panic!("Failed to enable AMX");
575 }
576 }
577 3 => {}
578 _ => unreachable!(),
579 }
580 }
581 }
582
583 #[simd_test(enable = "amx-tile")]
584 fn test_tile_loadconfig() {
585 unsafe {
586 let config = __tilecfg::default();
587 _tile_loadconfig(config.as_ptr());
588 _tile_release();
589 }
590 }
591
592 #[simd_test(enable = "amx-tile")]
593 fn test_tile_storeconfig() {
594 unsafe {
595 let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
596 _tile_loadconfig(config.as_ptr());
597 let mut _config = __tilecfg::default();
598 _tile_storeconfig(_config.as_mut_ptr());
599 _tile_release();
600 assert_eq!(config, _config);
601 }
602 }
603
604 #[simd_test(enable = "amx-tile")]
605 fn test_tile_zero() {
606 unsafe {
607 _init_amx();
608 let mut config = __tilecfg::default();
609 config.palette = 1;
610 config.colsb[0] = 64;
611 config.rows[0] = 16;
612 _tile_loadconfig(config.as_ptr());
613 _tile_zero::<0>();
614 let mut out = [[1_i8; 64]; 16];
615 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
616 _tile_release();
617 assert_eq!(out, [[0; 64]; 16]);
618 }
619 }
620
621 #[simd_test(enable = "amx-tile")]
622 fn test_tile_stored() {
623 unsafe {
624 _init_amx();
625 let mut config = __tilecfg::default();
626 config.palette = 1;
627 config.colsb[0] = 64;
628 config.rows[0] = 16;
629 _tile_loadconfig(config.as_ptr());
630 _tile_zero::<0>();
631 let mut out = [[1_i8; 64]; 16];
632 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
633 _tile_release();
634 assert_eq!(out, [[0; 64]; 16]);
635 }
636 }
637
638 #[simd_test(enable = "amx-tile")]
639 fn test_tile_loadd() {
640 unsafe {
641 _init_amx();
642 let mut config = __tilecfg::default();
643 config.palette = 1;
644 config.colsb[0] = 64;
645 config.rows[0] = 16;
646 _tile_loadconfig(config.as_ptr());
647 _tile_zero::<0>();
648 let mat = [1_i8; 1024];
649 _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
650 let mut out = [[0_i8; 64]; 16];
651 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
652 _tile_release();
653 assert_eq!(out, [[1; 64]; 16]);
654 }
655 }
656
657 #[simd_test(enable = "amx-tile")]
658 fn test_tile_stream_loadd() {
659 unsafe {
660 _init_amx();
661 let mut config = __tilecfg::default();
662 config.palette = 1;
663 config.colsb[0] = 64;
664 config.rows[0] = 16;
665 _tile_loadconfig(config.as_ptr());
666 _tile_zero::<0>();
667 let mat = [1_i8; 1024];
668 _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
669 let mut out = [[0_i8; 64]; 16];
670 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
671 _tile_release();
672 assert_eq!(out, [[1; 64]; 16]);
673 }
674 }
675
676 #[simd_test(enable = "amx-tile")]
677 fn test_tile_release() {
678 unsafe {
679 _tile_release();
680 }
681 }
682
683 #[simd_test(enable = "amx-bf16,avx512f")]
684 fn test_tile_dpbf16ps() {
685 unsafe {
686 _init_amx();
687 let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
688 let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
689 let ones: [u8; 1024] = transmute([bf16_1; 512]);
690 let twos: [u8; 1024] = transmute([bf16_2; 512]);
691 let mut res = [[0f32; 16]; 16];
692 let mut config = __tilecfg::default();
693 config.palette = 1;
694 (0..=2).for_each(|i| {
695 config.colsb[i] = 64;
696 config.rows[i] = 16;
697 });
698 _tile_loadconfig(config.as_ptr());
699 _tile_zero::<0>();
700 _tile_loadd::<1>(&ones as *const u8, 64);
701 _tile_loadd::<2>(&twos as *const u8, 64);
702 _tile_dpbf16ps::<0, 1, 2>();
703 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
704 _tile_release();
705 assert_eq!(res, [[64f32; 16]; 16]);
706 }
707 }
708
709 #[simd_test(enable = "amx-int8")]
710 fn test_tile_dpbssd() {
711 unsafe {
712 _init_amx();
713 let ones = [-1_i8; 1024];
714 let twos = [-2_i8; 1024];
715 let mut res = [[0_i32; 16]; 16];
716 let mut config = __tilecfg::default();
717 config.palette = 1;
718 (0..=2).for_each(|i| {
719 config.colsb[i] = 64;
720 config.rows[i] = 16;
721 });
722 _tile_loadconfig(config.as_ptr());
723 _tile_zero::<0>();
724 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
725 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
726 _tile_dpbssd::<0, 1, 2>();
727 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
728 _tile_release();
729 assert_eq!(res, [[128_i32; 16]; 16]);
730 }
731 }
732
733 #[simd_test(enable = "amx-int8")]
734 fn test_tile_dpbsud() {
735 unsafe {
736 _init_amx();
737 let ones = [-1_i8; 1024];
738 let twos = [2_u8; 1024];
739 let mut res = [[0_i32; 16]; 16];
740 let mut config = __tilecfg::default();
741 config.palette = 1;
742 (0..=2).for_each(|i| {
743 config.colsb[i] = 64;
744 config.rows[i] = 16;
745 });
746 _tile_loadconfig(config.as_ptr());
747 _tile_zero::<0>();
748 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
749 _tile_loadd::<2>(&twos as *const u8, 64);
750 _tile_dpbsud::<0, 1, 2>();
751 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
752 _tile_release();
753 assert_eq!(res, [[-128_i32; 16]; 16]);
754 }
755 }
756
757 #[simd_test(enable = "amx-int8")]
758 fn test_tile_dpbusd() {
759 unsafe {
760 _init_amx();
761 let ones = [1_u8; 1024];
762 let twos = [-2_i8; 1024];
763 let mut res = [[0_i32; 16]; 16];
764 let mut config = __tilecfg::default();
765 config.palette = 1;
766 (0..=2).for_each(|i| {
767 config.colsb[i] = 64;
768 config.rows[i] = 16;
769 });
770 _tile_loadconfig(config.as_ptr());
771 _tile_zero::<0>();
772 _tile_loadd::<1>(&ones as *const u8, 64);
773 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
774 _tile_dpbusd::<0, 1, 2>();
775 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
776 _tile_release();
777 assert_eq!(res, [[-128_i32; 16]; 16]);
778 }
779 }
780
781 #[simd_test(enable = "amx-int8")]
782 fn test_tile_dpbuud() {
783 unsafe {
784 _init_amx();
785 let ones = [1_u8; 1024];
786 let twos = [2_u8; 1024];
787 let mut res = [[0_i32; 16]; 16];
788 let mut config = __tilecfg::default();
789 config.palette = 1;
790 (0..=2).for_each(|i| {
791 config.colsb[i] = 64;
792 config.rows[i] = 16;
793 });
794 _tile_loadconfig(config.as_ptr());
795 _tile_zero::<0>();
796 _tile_loadd::<1>(&ones as *const u8, 64);
797 _tile_loadd::<2>(&twos as *const u8, 64);
798 _tile_dpbuud::<0, 1, 2>();
799 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
800 _tile_release();
801 assert_eq!(res, [[128_i32; 16]; 16]);
802 }
803 }
804
805 #[simd_test(enable = "amx-fp16")]
806 fn test_tile_dpfp16ps() {
807 unsafe {
808 _init_amx();
809 let ones = [1f16; 512];
810 let twos = [2f16; 512];
811 let mut res = [[0f32; 16]; 16];
812 let mut config = __tilecfg::default();
813 config.palette = 1;
814 (0..=2).for_each(|i| {
815 config.colsb[i] = 64;
816 config.rows[i] = 16;
817 });
818 _tile_loadconfig(config.as_ptr());
819 _tile_zero::<0>();
820 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
821 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
822 _tile_dpfp16ps::<0, 1, 2>();
823 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
824 _tile_release();
825 assert_eq!(res, [[64f32; 16]; 16]);
826 }
827 }
828
829 #[simd_test(enable = "amx-complex")]
830 fn test_tile_cmmimfp16ps() {
831 unsafe {
832 _init_amx();
833 let ones = [1f16; 512];
834 let twos = [2f16; 512];
835 let mut res = [[0f32; 16]; 16];
836 let mut config = __tilecfg::default();
837 config.palette = 1;
838 (0..=2).for_each(|i| {
839 config.colsb[i] = 64;
840 config.rows[i] = 16;
841 });
842 _tile_loadconfig(config.as_ptr());
843 _tile_zero::<0>();
844 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
845 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
846 _tile_cmmimfp16ps::<0, 1, 2>();
847 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
848 _tile_release();
849 assert_eq!(res, [[64f32; 16]; 16]);
850 }
851 }
852
853 #[simd_test(enable = "amx-complex")]
854 fn test_tile_cmmrlfp16ps() {
855 unsafe {
856 _init_amx();
857 let ones = [1f16; 512];
858 let twos = [2f16; 512];
859 let mut res = [[0f32; 16]; 16];
860 let mut config = __tilecfg::default();
861 config.palette = 1;
862 (0..=2).for_each(|i| {
863 config.colsb[i] = 64;
864 config.rows[i] = 16;
865 });
866 _tile_loadconfig(config.as_ptr());
867 _tile_zero::<0>();
868 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
869 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
870 _tile_cmmrlfp16ps::<0, 1, 2>();
871 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
872 _tile_release();
873 assert_eq!(res, [[0f32; 16]; 16]);
874 }
875 }
876
// 8-bit float encodings of 1.0 and 2.0 used by the FP8 tests below.
//
// BF8 is laid out as 1 sign / 5 exponent / 2 mantissa bits (exponent bias 15);
// HF8 is laid out as 1 sign / 4 exponent / 3 mantissa bits (exponent bias 7).
// Each constant is the biased exponent shifted past an all-zero mantissa.
const BF8_ONE: u8 = 15 << 2; // 0x3c
const BF8_TWO: u8 = 16 << 2; // 0x40
const HF8_ONE: u8 = 7 << 3; // 0x38
const HF8_TWO: u8 = 8 << 3; // 0x40
881
882 #[simd_test(enable = "amx-fp8")]
883 fn test_tile_dpbf8ps() {
884 unsafe {
885 _init_amx();
886 let ones = [BF8_ONE; 1024];
887 let twos = [BF8_TWO; 1024];
888 let mut res = [[0.0_f32; 16]; 16];
889 let mut config = __tilecfg::default();
890 config.palette = 1;
891 (0..=2).for_each(|i| {
892 config.colsb[i] = 64;
893 config.rows[i] = 16;
894 });
895 _tile_loadconfig(config.as_ptr());
896 _tile_zero::<0>();
897 _tile_loadd::<1>(&ones as *const u8, 64);
898 _tile_loadd::<2>(&twos as *const u8, 64);
899 _tile_dpbf8ps::<0, 1, 2>();
900 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
901 _tile_release();
902 assert_eq!(res, [[128.0_f32; 16]; 16]);
903 }
904 }
905
906 #[simd_test(enable = "amx-fp8")]
907 fn test_tile_dpbhf8ps() {
908 unsafe {
909 _init_amx();
910 let ones = [BF8_ONE; 1024];
911 let twos = [HF8_TWO; 1024];
912 let mut res = [[0.0_f32; 16]; 16];
913 let mut config = __tilecfg::default();
914 config.palette = 1;
915 (0..=2).for_each(|i| {
916 config.colsb[i] = 64;
917 config.rows[i] = 16;
918 });
919 _tile_loadconfig(config.as_ptr());
920 _tile_zero::<0>();
921 _tile_loadd::<1>(&ones as *const u8, 64);
922 _tile_loadd::<2>(&twos as *const u8, 64);
923 _tile_dpbhf8ps::<0, 1, 2>();
924 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
925 _tile_release();
926 assert_eq!(res, [[128.0_f32; 16]; 16]);
927 }
928 }
929
930 #[simd_test(enable = "amx-fp8")]
931 fn test_tile_dphbf8ps() {
932 unsafe {
933 _init_amx();
934 let ones = [HF8_ONE; 1024];
935 let twos = [BF8_TWO; 1024];
936 let mut res = [[0.0_f32; 16]; 16];
937 let mut config = __tilecfg::default();
938 config.palette = 1;
939 (0..=2).for_each(|i| {
940 config.colsb[i] = 64;
941 config.rows[i] = 16;
942 });
943 _tile_loadconfig(config.as_ptr());
944 _tile_zero::<0>();
945 _tile_loadd::<1>(&ones as *const u8, 64);
946 _tile_loadd::<2>(&twos as *const u8, 64);
947 _tile_dphbf8ps::<0, 1, 2>();
948 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
949 _tile_release();
950 assert_eq!(res, [[128.0_f32; 16]; 16]);
951 }
952 }
953
954 #[simd_test(enable = "amx-fp8")]
955 fn test_tile_dphf8ps() {
956 unsafe {
957 _init_amx();
958 let ones = [HF8_ONE; 1024];
959 let twos = [HF8_TWO; 1024];
960 let mut res = [[0.0_f32; 16]; 16];
961 let mut config = __tilecfg::default();
962 config.palette = 1;
963 (0..=2).for_each(|i| {
964 config.colsb[i] = 64;
965 config.rows[i] = 16;
966 });
967 _tile_loadconfig(config.as_ptr());
968 _tile_zero::<0>();
969 _tile_loadd::<1>(&ones as *const u8, 64);
970 _tile_loadd::<2>(&twos as *const u8, 64);
971 _tile_dphf8ps::<0, 1, 2>();
972 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
973 _tile_release();
974 assert_eq!(res, [[128.0_f32; 16]; 16]);
975 }
976 }
977
978 #[simd_test(enable = "amx-movrs")]
979 fn test_tile_loaddrs() {
980 unsafe {
981 _init_amx();
982 let mut config = __tilecfg::default();
983 config.palette = 1;
984 config.colsb[0] = 64;
985 config.rows[0] = 16;
986 _tile_loadconfig(config.as_ptr());
987 _tile_zero::<0>();
988 let mat = [1_i8; 1024];
989 _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
990 let mut out = [[0_i8; 64]; 16];
991 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
992 _tile_release();
993 assert_eq!(out, [[1; 64]; 16]);
994 }
995 }
996
997 #[simd_test(enable = "amx-movrs")]
998 fn test_tile_stream_loaddrs() {
999 unsafe {
1000 _init_amx();
1001 let mut config = __tilecfg::default();
1002 config.palette = 1;
1003 config.colsb[0] = 64;
1004 config.rows[0] = 16;
1005 _tile_loadconfig(config.as_ptr());
1006 _tile_zero::<0>();
1007 let mat = [1_i8; 1024];
1008 _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1009 let mut out = [[0_i8; 64]; 16];
1010 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1011 _tile_release();
1012 assert_eq!(out, [[1; 64]; 16]);
1013 }
1014 }
1015
1016 #[simd_test(enable = "amx-avx512,avx10.2")]
1017 fn test_tile_movrow() {
1018 unsafe {
1019 _init_amx();
1020 let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1021
1022 let mut config = __tilecfg::default();
1023 config.palette = 1;
1024 config.colsb[0] = 64;
1025 config.rows[0] = 16;
1026 _tile_loadconfig(config.as_ptr());
1027 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1028 for i in 0..16 {
1029 let row = _tile_movrow::<0>(i);
1030 assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1031 }
1032 }
1033 }
1034
1035 #[simd_test(enable = "amx-avx512,avx10.2")]
1036 fn test_tile_cvtrowd2ps() {
1037 unsafe {
1038 _init_amx();
1039 let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1040
1041 let mut config = __tilecfg::default();
1042 config.palette = 1;
1043 config.colsb[0] = 64;
1044 config.rows[0] = 16;
1045 _tile_loadconfig(config.as_ptr());
1046 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1047 for i in 0..16 {
1048 let row = _tile_cvtrowd2ps::<0>(i);
1049 assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1050 }
1051 }
1052 }
1053
1054 #[simd_test(enable = "amx-avx512,avx10.2")]
1055 fn test_tile_cvtrowps2phh() {
1056 unsafe {
1057 _init_amx();
1058 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1059
1060 let mut config = __tilecfg::default();
1061 config.palette = 1;
1062 config.colsb[0] = 64;
1063 config.rows[0] = 16;
1064 _tile_loadconfig(config.as_ptr());
1065 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1066 for i in 0..16 {
1067 let row = _tile_cvtrowps2phh::<0>(i);
1068 assert_eq!(
1069 *row.as_f16x32().as_array(),
1070 array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1071 );
1072 }
1073 }
1074 }
1075
1076 #[simd_test(enable = "amx-avx512,avx10.2")]
1077 fn test_tile_cvtrowps2phl() {
1078 unsafe {
1079 _init_amx();
1080 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1081
1082 let mut config = __tilecfg::default();
1083 config.palette = 1;
1084 config.colsb[0] = 64;
1085 config.rows[0] = 16;
1086 _tile_loadconfig(config.as_ptr());
1087 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1088 for i in 0..16 {
1089 let row = _tile_cvtrowps2phl::<0>(i);
1090 assert_eq!(
1091 *row.as_f16x32().as_array(),
1092 array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1093 );
1094 }
1095 }
1096 }
1097
1098 #[simd_test(enable = "amx-tf32")]
1099 fn test_tile_mmultf32ps() {
1100 unsafe {
1101 _init_amx();
1102 let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1103 let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1104 let mut res = [[0.0; 16]; 16];
1105
1106 let mut config = __tilecfg::default();
1107 config.palette = 1;
1108 (0..=2).for_each(|i| {
1109 config.colsb[i] = 64;
1110 config.rows[i] = 16;
1111 });
1112 _tile_loadconfig(config.as_ptr());
1113 _tile_zero::<0>();
1114 _tile_loadd::<1>(a.as_ptr().cast(), 64);
1115 _tile_loadd::<2>(b.as_ptr().cast(), 64);
1116 _tile_mmultf32ps::<0, 1, 2>();
1117 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1118 _tile_release();
1119
1120 let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1121 assert_eq!(res, expected);
1122 }
1123 }
1124}
1125