1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
/// Load tile configuration from a 64-byte memory location specified by `mem_addr`.
/// The tile configuration format (see Intel's documentation linked below) includes the tile type palette,
/// the number of bytes per row, and the number of rows. If the specified palette id is zero,
/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
/// Any invalid configuration will result in a #GP fault.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(ldtilecfg))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
    // Thin wrapper: the pointer is handed unchanged to the `ldtilecfg` binding
    // declared in this file's `extern "C"` block.
    ldtilecfg(mem_addr);
}
20
/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`.
/// The tile configuration format (see Intel's documentation linked below) includes the tile type palette,
/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(sttilecfg))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
    // Thin wrapper: the destination pointer is handed unchanged to the
    // `sttilecfg` binding declared in this file's `extern "C"` block.
    sttilecfg(mem_addr);
}
33
/// Load tile rows from memory specified by base address and stride into destination tile `DST`
/// using the tile configuration previously configured via `_tile_loadconfig`.
///
/// `DST` is validated at compile time with `static_assert_uimm_bits!(DST, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    // The tile index is passed to the binding as an `i8`.
    tileloadd64(DST as i8, base, stride);
}
46
/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilerelease))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_release() {
    // Takes no operands; simply invokes the `tilerelease` binding.
    tilerelease();
}
57
/// Store the tile specified by `DST` to memory specified by base address and stride
/// using the tile configuration previously configured via `_tile_loadconfig`.
///
/// Note that, despite its name, the `DST` const parameter selects the *source* tile here;
/// the destination is the memory at `base`.
/// `DST` is validated at compile time with `static_assert_uimm_bits!(DST, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    // The tile index is passed to the binding as an `i8`.
    tilestored64(DST as i8, base, stride);
}
70
/// Load tile rows from memory specified by base address and stride into destination tile `DST` using the tile configuration
/// previously configured via `_tile_loadconfig`. This intrinsic provides a hint to the implementation that the data will
/// likely not be reused in the near future and the data caching can be optimized accordingly.
///
/// `DST` is validated at compile time with `static_assert_uimm_bits!(DST, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    // Same argument shape as `_tile_loadd`, but routed to the `tileloaddt164` binding.
    tileloaddt164(DST as i8, base, stride);
}
85
/// Zero the tile specified by `DST`.
///
/// `DST` is validated at compile time with `static_assert_uimm_bits!(DST, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_zero<const DST: i32>() {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    // The tile index is passed to the binding as an `i8`.
    tilezero(DST as i8);
}
98
/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles `A` and `B`,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in `DST`, and store the 32-bit result back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-bf16")]
#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Tile indices are passed to the binding as `i8` values.
    tdpbf16ps(DST as i8, A as i8, B as i8);
}
115
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in `A` with corresponding
/// signed 8-bit integers in `B`, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in `DST`, and store the 32-bit result back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Signed x signed variant; tile indices are passed as `i8` values.
    tdpbssd(DST as i8, A as i8, B as i8);
}
133
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in `A` with corresponding
/// unsigned 8-bit integers in `B`, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in `DST`, and store the 32-bit result back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Signed x unsigned variant; tile indices are passed as `i8` values.
    tdpbsud(DST as i8, A as i8, B as i8);
}
151
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in `A` with corresponding
/// signed 8-bit integers in `B`, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in `DST`, and store the 32-bit result back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Unsigned x signed variant; tile indices are passed as `i8` values.
    tdpbusd(DST as i8, A as i8, B as i8);
}
169
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in `A` with corresponding
/// unsigned 8-bit integers in `B`, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in `DST`, and store the 32-bit result back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Unsigned x unsigned variant; tile indices are passed as `i8` values.
    tdpbuud(DST as i8, A as i8, B as i8);
}
187
/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles `A` and `B`,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in `DST`, and store the 32-bit result back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp16")]
#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Tile indices are passed to the binding as `i8` values.
    tdpfp16ps(DST as i8, A as i8, B as i8);
}
204
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
/// Each dword element in input tiles `A` and `B` is interpreted as a complex number with FP16 real part and FP16 imaginary part.
/// Calculates the imaginary part of the result. For each possible combination of (row of `A`, column of `B`),
/// it performs a set of multiplications and accumulations on all corresponding complex numbers (one from `A` and one from `B`).
/// The imaginary part of the `A` element is multiplied with the real part of the corresponding `B` element, and the real part of
/// the `A` element is multiplied with the imaginary part of the corresponding `B` element. The two accumulated results are added,
/// and then accumulated into the corresponding row and column of `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Imaginary-part variant; tile indices are passed as `i8` values.
    tcmmimfp16ps(DST as i8, A as i8, B as i8);
}
225
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
/// Each dword element in input tiles `A` and `B` is interpreted as a complex number with FP16 real part and FP16 imaginary part.
/// Calculates the real part of the result. For each possible combination of (row of `A`, column of `B`),
/// it performs a set of multiplications and accumulations on all corresponding complex numbers (one from `A` and one from `B`).
/// The real part of the `A` element is multiplied with the real part of the corresponding `B` element, and the negated imaginary part of
/// the `A` element is multiplied with the imaginary part of the corresponding `B` element.
/// The two accumulated results are added, and then accumulated into the corresponding row and column of `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Real-part variant; tile indices are passed as `i8` values.
    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
}
246
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile `A` and BF8 (8-bit E5M2)
/// floating-point elements in tile `B`, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in `DST`, and store the 32-bit result
/// back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // BF8 x BF8 variant; tile indices are passed as `i8` values.
    tdpbf8ps(DST as i8, A as i8, B as i8);
}
265
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile `A` and HF8
/// (8-bit E4M3) floating-point elements in tile `B`, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in `DST`, and store the 32-bit result
/// back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // BF8 x HF8 variant; tile indices are passed as `i8` values.
    tdpbhf8ps(DST as i8, A as i8, B as i8);
}
284
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile `A` and BF8
/// (8-bit E5M2) floating-point elements in tile `B`, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in `DST`, and store the 32-bit result
/// back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // HF8 x BF8 variant; tile indices are passed as `i8` values.
    tdphbf8ps(DST as i8, A as i8, B as i8);
}
303
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile `A` and HF8 (8-bit E4M3)
/// floating-point elements in tile `B`, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in `DST`, and store the 32-bit result
/// back to tile `DST`.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // HF8 x HF8 variant; tile indices are passed as `i8` values.
    tdphf8ps(DST as i8, A as i8, B as i8);
}
322
/// Load tile rows from memory specified by base address and stride into destination tile `DST`
/// using the tile configuration previously configured via `_tile_loadconfig`.
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
///
/// `DST` is validated at compile time with `static_assert_uimm_bits!(DST, 3)`.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrs, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    // The tile index is passed to the binding as an `i8`.
    tileloaddrs64(DST as i8, base, stride);
}
340
/// Load tile rows from memory specified by base address and stride into destination tile `DST`
/// using the tile configuration previously configured via `_tile_loadconfig`.
/// Provides a hint to the implementation that the data would be reused but does not need
/// to be resident in the nearest cache levels.
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
///
/// `DST` is validated at compile time with `static_assert_uimm_bits!(DST, 3)`.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrst1, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    // Same argument shape as `_tile_loaddrs`, but routed to the `tileloaddrst164` binding.
    tileloaddrst164(DST as i8, base, stride);
}
360
/// Perform matrix multiplication of two tiles `A` and `B`, containing packed single precision (32-bit)
/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
/// results into a packed single precision tile.
/// For each possible combination of (row of `A`, column of `B`), it performs
///  - convert to TF32
///  - multiply the corresponding elements of `A` and `B`
///  - accumulate the results into the corresponding row and column of `DST` using round-to-nearest-even
///    rounding mode.
///
/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
/// handled and *not* treated as zero.
///
/// All three tile indices are validated at compile time with `static_assert_uimm_bits!(_, 3)`.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-tf32")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
    // Each tile index must fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Tile indices are passed to the binding as `i8` values.
    tmmultf32ps(DST as i8, A as i8, B as i8);
}
385
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
///
/// `TILE` selects the tile and is validated at compile time with
/// `static_assert_uimm_bits!(TILE, 3)`; `row` is a runtime value forwarded unchanged.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowd2ps, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `f32x16`, which is converted to the public `__m512` type.
    tcvtrowd2ps(TILE as i8, row).as_m512()
}
400
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16 bits within each 32-bit element of the returned vector.
///
/// `TILE` selects the tile and is validated at compile time with
/// `static_assert_uimm_bits!(TILE, 3)`; `row` is a runtime value forwarded unchanged.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phh, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `f16x32`, which is converted to the public `__m512h` type.
    tcvtrowps2phh(TILE as i8, row).as_m512h()
}
416
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16 bits within each 32-bit element of the returned vector.
///
/// `TILE` selects the tile and is validated at compile time with
/// `static_assert_uimm_bits!(TILE, 3)`; `row` is a runtime value forwarded unchanged.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phl, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `f16x32`, which is converted to the public `__m512h` type.
    tcvtrowps2phl(TILE as i8, row).as_m512h()
}
432
/// Moves one row of tile data into a zmm vector register.
///
/// `TILE` selects the tile and is validated at compile time with
/// `static_assert_uimm_bits!(TILE, 3)`; `row` is a runtime value forwarded unchanged.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tilemovrow, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
    // Reject tile indices that do not fit in a 3-bit unsigned immediate.
    static_assert_uimm_bits!(TILE, 3);
    // The binding returns an `i32x16`, which is converted to the public `__m512i` type.
    tilemovrow(TILE as i8, row).as_m512i()
}
446
447#[allow(improper_ctypes)]
448unsafe extern "C" {
449 #[link_name = "llvm.x86.ldtilecfg"]
450 unsafefn ldtilecfg(mem_addr: *const u8);
451 #[link_name = "llvm.x86.sttilecfg"]
452 unsafefn sttilecfg(mem_addr: *mut u8);
453 #[link_name = "llvm.x86.tileloadd64"]
454 unsafefn tileloadd64(dst: i8, base: *const u8, stride: usize);
455 #[link_name = "llvm.x86.tileloaddt164"]
456 unsafefn tileloaddt164(dst: i8, base: *const u8, stride: usize);
457 #[link_name = "llvm.x86.tilerelease"]
458 unsafefn tilerelease();
459 #[link_name = "llvm.x86.tilestored64"]
460 unsafefn tilestored64(dst: i8, base: *mut u8, stride: usize);
461 #[link_name = "llvm.x86.tilezero"]
462 unsafefn tilezero(dst: i8);
463 #[link_name = "llvm.x86.tdpbf16ps"]
464 unsafefn tdpbf16ps(dst: i8, a: i8, b: i8);
465 #[link_name = "llvm.x86.tdpbuud"]
466 unsafefn tdpbuud(dst: i8, a: i8, b: i8);
467 #[link_name = "llvm.x86.tdpbusd"]
468 unsafefn tdpbusd(dst: i8, a: i8, b: i8);
469 #[link_name = "llvm.x86.tdpbsud"]
470 unsafefn tdpbsud(dst: i8, a: i8, b: i8);
471 #[link_name = "llvm.x86.tdpbssd"]
472 unsafefn tdpbssd(dst: i8, a: i8, b: i8);
473 #[link_name = "llvm.x86.tdpfp16ps"]
474 unsafefn tdpfp16ps(dst: i8, a: i8, b: i8);
475 #[link_name = "llvm.x86.tcmmimfp16ps"]
476 unsafefn tcmmimfp16ps(dst: i8, a: i8, b: i8);
477 #[link_name = "llvm.x86.tcmmrlfp16ps"]
478 unsafefn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
479 #[link_name = "llvm.x86.tdpbf8ps"]
480 unsafefn tdpbf8ps(dst: i8, a: i8, b: i8);
481 #[link_name = "llvm.x86.tdpbhf8ps"]
482 unsafefn tdpbhf8ps(dst: i8, a: i8, b: i8);
483 #[link_name = "llvm.x86.tdphbf8ps"]
484 unsafefn tdphbf8ps(dst: i8, a: i8, b: i8);
485 #[link_name = "llvm.x86.tdphf8ps"]
486 unsafefn tdphf8ps(dst: i8, a: i8, b: i8);
487 #[link_name = "llvm.x86.tileloaddrs64"]
488 unsafefn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
489 #[link_name = "llvm.x86.tileloaddrst164"]
490 unsafefn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
491 #[link_name = "llvm.x86.tmmultf32ps"]
492 unsafefn tmmultf32ps(dst: i8, a: i8, b: i8);
493 #[link_name = "llvm.x86.tcvtrowd2ps"]
494 unsafefn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
495 #[link_name = "llvm.x86.tcvtrowps2phh"]
496 unsafefn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
497 #[link_name = "llvm.x86.tcvtrowps2phl"]
498 unsafefn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
499 #[link_name = "llvm.x86.tilemovrow"]
500 unsafefn tilemovrow(tile: i8, row: u32) -> i32x16;
501}
502
503#[cfg(test)]
504mod tests {
505 use crate::core_arch::x86::_mm_cvtness_sbh;
506 use crate::core_arch::x86_64::*;
507 use core::{array, mem::transmute};
508 use stdarch_test::simd_test;
509 #[cfg(target_os = "linux")]
510 use syscalls::{Sysno, syscall};
511
512 #[allow(non_camel_case_types)]
513 #[repr(packed)]
514 #[derive(Copy, Clone, Default, Debug, PartialEq)]
515 struct __tilecfg {
516 /// 0 `or` 1
517 palette: u8,
518 start_row: u8,
519 /// reserved, must be zero
520 reserved_a0: [u8; 14],
521 /// number of bytes of one row in each tile
522 colsb: [u16; 8],
523 /// reserved, must be zero
524 reserved_b0: [u16; 8],
525 /// number of rows in each tile
526 rows: [u8; 8],
527 /// reserved, must be zero
528 reserved_c0: [u8; 8],
529 }
530
531 impl __tilecfg {
532 fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
533 Self {
534 palette,
535 start_row,
536 reserved_a0: [0u8; 14],
537 colsb,
538 reserved_b0: [0u16; 8],
539 rows,
540 reserved_c0: [0u8; 8],
541 }
542 }
543
544 const fn as_ptr(&self) -> *const u8 {
545 self as *const Self as *const u8
546 }
547
548 fn as_mut_ptr(&mut self) -> *mut u8 {
549 self as *mut Self as *mut u8
550 }
551 }
552
    /// Stand-in used on targets other than Linux: performs no setup, so tests can
    /// call `_init_amx()` unconditionally.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}
556
557 #[cfg(target_os = "linux")]
558 #[target_feature(enable = "amx-tile")]
559 #[inline]
560 unsafe fn _init_amx() {
561 let mut ret: usize;
562 let mut xfeatures: usize = 0;
563 ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
564 .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
565 if ret != 0 {
566 panic!("Failed to get XFEATURES");
567 } else {
568 match 0b11 & (xfeatures >> 17) {
569 0 => panic!("AMX is not available"),
570 1 => {
571 ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
572 .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
573 if ret != 0 {
574 panic!("Failed to enable AMX");
575 }
576 }
577 3 => {}
578 _ => unreachable!(),
579 }
580 }
581 }
582
583 #[simd_test(enable = "amx-tile")]
584 unsafe fn test_tile_loadconfig() {
585 let config = __tilecfg::default();
586 _tile_loadconfig(config.as_ptr());
587 _tile_release();
588 }
589
590 #[simd_test(enable = "amx-tile")]
591 unsafe fn test_tile_storeconfig() {
592 let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
593 _tile_loadconfig(config.as_ptr());
594 let mut _config = __tilecfg::default();
595 _tile_storeconfig(_config.as_mut_ptr());
596 _tile_release();
597 assert_eq!(config, _config);
598 }
599
600 #[simd_test(enable = "amx-tile")]
601 unsafe fn test_tile_zero() {
602 _init_amx();
603 let mut config = __tilecfg::default();
604 config.palette = 1;
605 config.colsb[0] = 64;
606 config.rows[0] = 16;
607 _tile_loadconfig(config.as_ptr());
608 _tile_zero::<0>();
609 let mut out = [[1_i8; 64]; 16];
610 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
611 _tile_release();
612 assert_eq!(out, [[0; 64]; 16]);
613 }
614
615 #[simd_test(enable = "amx-tile")]
616 unsafe fn test_tile_stored() {
617 _init_amx();
618 let mut config = __tilecfg::default();
619 config.palette = 1;
620 config.colsb[0] = 64;
621 config.rows[0] = 16;
622 _tile_loadconfig(config.as_ptr());
623 _tile_zero::<0>();
624 let mut out = [[1_i8; 64]; 16];
625 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
626 _tile_release();
627 assert_eq!(out, [[0; 64]; 16]);
628 }
629
630 #[simd_test(enable = "amx-tile")]
631 unsafe fn test_tile_loadd() {
632 _init_amx();
633 let mut config = __tilecfg::default();
634 config.palette = 1;
635 config.colsb[0] = 64;
636 config.rows[0] = 16;
637 _tile_loadconfig(config.as_ptr());
638 _tile_zero::<0>();
639 let mat = [1_i8; 1024];
640 _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
641 let mut out = [[0_i8; 64]; 16];
642 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
643 _tile_release();
644 assert_eq!(out, [[1; 64]; 16]);
645 }
646
647 #[simd_test(enable = "amx-tile")]
648 unsafe fn test_tile_stream_loadd() {
649 _init_amx();
650 let mut config = __tilecfg::default();
651 config.palette = 1;
652 config.colsb[0] = 64;
653 config.rows[0] = 16;
654 _tile_loadconfig(config.as_ptr());
655 _tile_zero::<0>();
656 let mat = [1_i8; 1024];
657 _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
658 let mut out = [[0_i8; 64]; 16];
659 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
660 _tile_release();
661 assert_eq!(out, [[1; 64]; 16]);
662 }
663
    /// Smoke test: calls `_tile_release` on its own, with no configuration loaded
    /// by this test and nothing asserted afterwards.
    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_release() {
        _tile_release();
    }
668
669 #[simd_test(enable = "amx-bf16,avx512f")]
670 unsafe fn test_tile_dpbf16ps() {
671 _init_amx();
672 let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
673 let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
674 let ones: [u8; 1024] = transmute([bf16_1; 512]);
675 let twos: [u8; 1024] = transmute([bf16_2; 512]);
676 let mut res = [[0f32; 16]; 16];
677 let mut config = __tilecfg::default();
678 config.palette = 1;
679 (0..=2).for_each(|i| {
680 config.colsb[i] = 64;
681 config.rows[i] = 16;
682 });
683 _tile_loadconfig(config.as_ptr());
684 _tile_zero::<0>();
685 _tile_loadd::<1>(&ones as *const u8, 64);
686 _tile_loadd::<2>(&twos as *const u8, 64);
687 _tile_dpbf16ps::<0, 1, 2>();
688 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
689 _tile_release();
690 assert_eq!(res, [[64f32; 16]; 16]);
691 }
692
693 #[simd_test(enable = "amx-int8")]
694 unsafe fn test_tile_dpbssd() {
695 _init_amx();
696 let ones = [-1_i8; 1024];
697 let twos = [-2_i8; 1024];
698 let mut res = [[0_i32; 16]; 16];
699 let mut config = __tilecfg::default();
700 config.palette = 1;
701 (0..=2).for_each(|i| {
702 config.colsb[i] = 64;
703 config.rows[i] = 16;
704 });
705 _tile_loadconfig(config.as_ptr());
706 _tile_zero::<0>();
707 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
708 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
709 _tile_dpbssd::<0, 1, 2>();
710 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
711 _tile_release();
712 assert_eq!(res, [[128_i32; 16]; 16]);
713 }
714
715 #[simd_test(enable = "amx-int8")]
716 unsafe fn test_tile_dpbsud() {
717 _init_amx();
718 let ones = [-1_i8; 1024];
719 let twos = [2_u8; 1024];
720 let mut res = [[0_i32; 16]; 16];
721 let mut config = __tilecfg::default();
722 config.palette = 1;
723 (0..=2).for_each(|i| {
724 config.colsb[i] = 64;
725 config.rows[i] = 16;
726 });
727 _tile_loadconfig(config.as_ptr());
728 _tile_zero::<0>();
729 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
730 _tile_loadd::<2>(&twos as *const u8, 64);
731 _tile_dpbsud::<0, 1, 2>();
732 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
733 _tile_release();
734 assert_eq!(res, [[-128_i32; 16]; 16]);
735 }
736
737 #[simd_test(enable = "amx-int8")]
738 unsafe fn test_tile_dpbusd() {
739 _init_amx();
740 let ones = [1_u8; 1024];
741 let twos = [-2_i8; 1024];
742 let mut res = [[0_i32; 16]; 16];
743 let mut config = __tilecfg::default();
744 config.palette = 1;
745 (0..=2).for_each(|i| {
746 config.colsb[i] = 64;
747 config.rows[i] = 16;
748 });
749 _tile_loadconfig(config.as_ptr());
750 _tile_zero::<0>();
751 _tile_loadd::<1>(&ones as *const u8, 64);
752 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
753 _tile_dpbusd::<0, 1, 2>();
754 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
755 _tile_release();
756 assert_eq!(res, [[-128_i32; 16]; 16]);
757 }
758
759 #[simd_test(enable = "amx-int8")]
760 unsafe fn test_tile_dpbuud() {
761 _init_amx();
762 let ones = [1_u8; 1024];
763 let twos = [2_u8; 1024];
764 let mut res = [[0_i32; 16]; 16];
765 let mut config = __tilecfg::default();
766 config.palette = 1;
767 (0..=2).for_each(|i| {
768 config.colsb[i] = 64;
769 config.rows[i] = 16;
770 });
771 _tile_loadconfig(config.as_ptr());
772 _tile_zero::<0>();
773 _tile_loadd::<1>(&ones as *const u8, 64);
774 _tile_loadd::<2>(&twos as *const u8, 64);
775 _tile_dpbuud::<0, 1, 2>();
776 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
777 _tile_release();
778 assert_eq!(res, [[128_i32; 16]; 16]);
779 }
780
781 #[simd_test(enable = "amx-fp16")]
782 unsafe fn test_tile_dpfp16ps() {
783 _init_amx();
784 let ones = [1f16; 512];
785 let twos = [2f16; 512];
786 let mut res = [[0f32; 16]; 16];
787 let mut config = __tilecfg::default();
788 config.palette = 1;
789 (0..=2).for_each(|i| {
790 config.colsb[i] = 64;
791 config.rows[i] = 16;
792 });
793 _tile_loadconfig(config.as_ptr());
794 _tile_zero::<0>();
795 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
796 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
797 _tile_dpfp16ps::<0, 1, 2>();
798 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
799 _tile_release();
800 assert_eq!(res, [[64f32; 16]; 16]);
801 }
802
803 #[simd_test(enable = "amx-complex")]
804 unsafe fn test_tile_cmmimfp16ps() {
805 _init_amx();
806 let ones = [1f16; 512];
807 let twos = [2f16; 512];
808 let mut res = [[0f32; 16]; 16];
809 let mut config = __tilecfg::default();
810 config.palette = 1;
811 (0..=2).for_each(|i| {
812 config.colsb[i] = 64;
813 config.rows[i] = 16;
814 });
815 _tile_loadconfig(config.as_ptr());
816 _tile_zero::<0>();
817 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
818 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
819 _tile_cmmimfp16ps::<0, 1, 2>();
820 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
821 _tile_release();
822 assert_eq!(res, [[64f32; 16]; 16]);
823 }
824
825 #[simd_test(enable = "amx-complex")]
826 unsafe fn test_tile_cmmrlfp16ps() {
827 _init_amx();
828 let ones = [1f16; 512];
829 let twos = [2f16; 512];
830 let mut res = [[0f32; 16]; 16];
831 let mut config = __tilecfg::default();
832 config.palette = 1;
833 (0..=2).for_each(|i| {
834 config.colsb[i] = 64;
835 config.rows[i] = 16;
836 });
837 _tile_loadconfig(config.as_ptr());
838 _tile_zero::<0>();
839 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
840 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
841 _tile_cmmrlfp16ps::<0, 1, 2>();
842 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
843 _tile_release();
844 assert_eq!(res, [[0f32; 16]; 16]);
845 }
846
// FP8 bit patterns for the values 1.0 and 2.0, written as sign_exponent_mantissa.
// BF8 has a 5-bit exponent and 2-bit mantissa (E5M2, bias 15); HF8 has a 4-bit
// exponent and 3-bit mantissa (E4M3, bias 7).
const BF8_ONE: u8 = 0b0_01111_00;
const BF8_TWO: u8 = 0b0_10000_00;
const HF8_ONE: u8 = 0b0_0111_000;
const HF8_TWO: u8 = 0b0_1000_000;
851
852 #[simd_test(enable = "amx-fp8")]
853 unsafe fn test_tile_dpbf8ps() {
854 _init_amx();
855 let ones = [BF8_ONE; 1024];
856 let twos = [BF8_TWO; 1024];
857 let mut res = [[0.0_f32; 16]; 16];
858 let mut config = __tilecfg::default();
859 config.palette = 1;
860 (0..=2).for_each(|i| {
861 config.colsb[i] = 64;
862 config.rows[i] = 16;
863 });
864 _tile_loadconfig(config.as_ptr());
865 _tile_zero::<0>();
866 _tile_loadd::<1>(&ones as *const u8, 64);
867 _tile_loadd::<2>(&twos as *const u8, 64);
868 _tile_dpbf8ps::<0, 1, 2>();
869 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
870 _tile_release();
871 assert_eq!(res, [[128.0_f32; 16]; 16]);
872 }
873
874 #[simd_test(enable = "amx-fp8")]
875 unsafe fn test_tile_dpbhf8ps() {
876 _init_amx();
877 let ones = [BF8_ONE; 1024];
878 let twos = [HF8_TWO; 1024];
879 let mut res = [[0.0_f32; 16]; 16];
880 let mut config = __tilecfg::default();
881 config.palette = 1;
882 (0..=2).for_each(|i| {
883 config.colsb[i] = 64;
884 config.rows[i] = 16;
885 });
886 _tile_loadconfig(config.as_ptr());
887 _tile_zero::<0>();
888 _tile_loadd::<1>(&ones as *const u8, 64);
889 _tile_loadd::<2>(&twos as *const u8, 64);
890 _tile_dpbhf8ps::<0, 1, 2>();
891 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
892 _tile_release();
893 assert_eq!(res, [[128.0_f32; 16]; 16]);
894 }
895
896 #[simd_test(enable = "amx-fp8")]
897 unsafe fn test_tile_dphbf8ps() {
898 _init_amx();
899 let ones = [HF8_ONE; 1024];
900 let twos = [BF8_TWO; 1024];
901 let mut res = [[0.0_f32; 16]; 16];
902 let mut config = __tilecfg::default();
903 config.palette = 1;
904 (0..=2).for_each(|i| {
905 config.colsb[i] = 64;
906 config.rows[i] = 16;
907 });
908 _tile_loadconfig(config.as_ptr());
909 _tile_zero::<0>();
910 _tile_loadd::<1>(&ones as *const u8, 64);
911 _tile_loadd::<2>(&twos as *const u8, 64);
912 _tile_dphbf8ps::<0, 1, 2>();
913 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
914 _tile_release();
915 assert_eq!(res, [[128.0_f32; 16]; 16]);
916 }
917
918 #[simd_test(enable = "amx-fp8")]
919 unsafe fn test_tile_dphf8ps() {
920 _init_amx();
921 let ones = [HF8_ONE; 1024];
922 let twos = [HF8_TWO; 1024];
923 let mut res = [[0.0_f32; 16]; 16];
924 let mut config = __tilecfg::default();
925 config.palette = 1;
926 (0..=2).for_each(|i| {
927 config.colsb[i] = 64;
928 config.rows[i] = 16;
929 });
930 _tile_loadconfig(config.as_ptr());
931 _tile_zero::<0>();
932 _tile_loadd::<1>(&ones as *const u8, 64);
933 _tile_loadd::<2>(&twos as *const u8, 64);
934 _tile_dphf8ps::<0, 1, 2>();
935 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
936 _tile_release();
937 assert_eq!(res, [[128.0_f32; 16]; 16]);
938 }
939
940 #[simd_test(enable = "amx-movrs")]
941 unsafe fn test_tile_loaddrs() {
942 _init_amx();
943 let mut config = __tilecfg::default();
944 config.palette = 1;
945 config.colsb[0] = 64;
946 config.rows[0] = 16;
947 _tile_loadconfig(config.as_ptr());
948 _tile_zero::<0>();
949 let mat = [1_i8; 1024];
950 _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
951 let mut out = [[0_i8; 64]; 16];
952 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
953 _tile_release();
954 assert_eq!(out, [[1; 64]; 16]);
955 }
956
957 #[simd_test(enable = "amx-movrs")]
958 unsafe fn test_tile_stream_loaddrs() {
959 _init_amx();
960 let mut config = __tilecfg::default();
961 config.palette = 1;
962 config.colsb[0] = 64;
963 config.rows[0] = 16;
964 _tile_loadconfig(config.as_ptr());
965 _tile_zero::<0>();
966 let mat = [1_i8; 1024];
967 _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
968 let mut out = [[0_i8; 64]; 16];
969 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
970 _tile_release();
971 assert_eq!(out, [[1; 64]; 16]);
972 }
973
974 #[simd_test(enable = "amx-avx512,avx10.2")]
975 unsafe fn test_tile_movrow() {
976 _init_amx();
977 let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
978
979 let mut config = __tilecfg::default();
980 config.palette = 1;
981 config.colsb[0] = 64;
982 config.rows[0] = 16;
983 _tile_loadconfig(config.as_ptr());
984 _tile_loadd::<0>(array.as_ptr().cast(), 64);
985 for i in 0..16 {
986 let row = _tile_movrow::<0>(i);
987 assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
988 }
989 }
990
991 #[simd_test(enable = "amx-avx512,avx10.2")]
992 unsafe fn test_tile_cvtrowd2ps() {
993 _init_amx();
994 let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
995
996 let mut config = __tilecfg::default();
997 config.palette = 1;
998 config.colsb[0] = 64;
999 config.rows[0] = 16;
1000 _tile_loadconfig(config.as_ptr());
1001 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1002 for i in 0..16 {
1003 let row = _tile_cvtrowd2ps::<0>(i);
1004 assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1005 }
1006 }
1007
1008 #[simd_test(enable = "amx-avx512,avx10.2")]
1009 unsafe fn test_tile_cvtrowps2phh() {
1010 _init_amx();
1011 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1012
1013 let mut config = __tilecfg::default();
1014 config.palette = 1;
1015 config.colsb[0] = 64;
1016 config.rows[0] = 16;
1017 _tile_loadconfig(config.as_ptr());
1018 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1019 for i in 0..16 {
1020 let row = _tile_cvtrowps2phh::<0>(i);
1021 assert_eq!(
1022 *row.as_f16x32().as_array(),
1023 array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1024 );
1025 }
1026 }
1027
1028 #[simd_test(enable = "amx-avx512,avx10.2")]
1029 unsafe fn test_tile_cvtrowps2phl() {
1030 _init_amx();
1031 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1032
1033 let mut config = __tilecfg::default();
1034 config.palette = 1;
1035 config.colsb[0] = 64;
1036 config.rows[0] = 16;
1037 _tile_loadconfig(config.as_ptr());
1038 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1039 for i in 0..16 {
1040 let row = _tile_cvtrowps2phl::<0>(i);
1041 assert_eq!(
1042 *row.as_f16x32().as_array(),
1043 array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1044 );
1045 }
1046 }
1047
1048 #[simd_test(enable = "amx-tf32")]
1049 unsafe fn test_tile_mmultf32ps() {
1050 _init_amx();
1051 let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1052 let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1053 let mut res = [[0.0; 16]; 16];
1054
1055 let mut config = __tilecfg::default();
1056 config.palette = 1;
1057 (0..=2).for_each(|i| {
1058 config.colsb[i] = 64;
1059 config.rows[i] = 16;
1060 });
1061 _tile_loadconfig(config.as_ptr());
1062 _tile_zero::<0>();
1063 _tile_loadd::<1>(a.as_ptr().cast(), 64);
1064 _tile_loadd::<2>(b.as_ptr().cast(), 64);
1065 _tile_mmultf32ps::<0, 1, 2>();
1066 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1067 _tile_release();
1068
1069 let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1070 assert_eq!(res, expected);
1071 }
1072}
1073