1#[cfg(test)]
2use stdarch_test::assert_instr;
3
4/// Load tile configuration from a 64-byte memory location specified by mem_addr.
5/// The tile configuration format is specified below, and includes the tile type pallette,
6/// the number of bytes per row, and the number of rows. If the specified pallette_id is zero,
7/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed.
8/// Any invalid configurations will result in #GP fault.
9///
10/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
11#[inline]
12#[target_feature(enable = "amx-tile")]
13#[cfg_attr(test, assert_instr(ldtilecfg))]
14#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
15pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
16 ldtilecfg(mem_addr);
17}
18
19/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr.
20/// The tile configuration format is specified below, and includes the tile type pallette,
21/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory.
22///
23/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
24#[inline]
25#[target_feature(enable = "amx-tile")]
26#[cfg_attr(test, assert_instr(sttilecfg))]
27#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
28pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
29 sttilecfg(mem_addr);
30}
31
32/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig.
33///
34/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
35#[inline]
36#[rustc_legacy_const_generics(0)]
37#[target_feature(enable = "amx-tile")]
38#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
39#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
40pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
41 static_assert_uimm_bits!(DST, 3);
42 tileloadd64(DST as i8, base, stride);
43}
44
45/// Release the tile configuration to return to the init state, which releases all storage it currently holds.
46///
47/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
48#[inline]
49#[target_feature(enable = "amx-tile")]
50#[cfg_attr(test, assert_instr(tilerelease))]
51#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
52pub unsafe fn _tile_release() {
53 tilerelease();
54}
55
56/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig.
57///
58/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
59#[inline]
60#[rustc_legacy_const_generics(0)]
61#[target_feature(enable = "amx-tile")]
62#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
63#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
64pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
65 static_assert_uimm_bits!(DST, 3);
66 tilestored64(DST as i8, base, stride);
67}
68
69/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration
70/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will
71/// likely not be reused in the near future and the data caching can be optimized accordingly.
72///
73/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
74#[inline]
75#[rustc_legacy_const_generics(0)]
76#[target_feature(enable = "amx-tile")]
77#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
78#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
79pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
80 static_assert_uimm_bits!(DST, 3);
81 tileloaddt164(DST as i8, base, stride);
82}
83
84/// Zero the tile specified by tdest.
85///
86/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
87#[inline]
88#[rustc_legacy_const_generics(0)]
89#[target_feature(enable = "amx-tile")]
90#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
91#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
92pub unsafe fn _tile_zero<const DST: i32>() {
93 static_assert_uimm_bits!(DST, 3);
94 tilezero(DST as i8);
95}
96
97/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
98/// accumulating the intermediate single-precision (32-bit) floating-point elements
99/// with elements in dst, and store the 32-bit result back to tile dst.
100///
101/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
102#[inline]
103#[rustc_legacy_const_generics(0, 1, 2)]
104#[target_feature(enable = "amx-bf16")]
105#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
106#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
107pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
108 static_assert_uimm_bits!(DST, 3);
109 static_assert_uimm_bits!(A, 3);
110 static_assert_uimm_bits!(B, 3);
111 tdpbf16ps(DST as i8, A as i8, B as i8);
112}
113
114/// Compute dot-product of bytes in tiles with a source/destination accumulator.
115/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
116/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
117/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
118///
119/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
120#[inline]
121#[rustc_legacy_const_generics(0, 1, 2)]
122#[target_feature(enable = "amx-int8")]
123#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
124#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
125pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
126 static_assert_uimm_bits!(DST, 3);
127 static_assert_uimm_bits!(A, 3);
128 static_assert_uimm_bits!(B, 3);
129 tdpbssd(DST as i8, A as i8, B as i8);
130}
131
132/// Compute dot-product of bytes in tiles with a source/destination accumulator.
133/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
134/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
135/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
136///
137/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
138#[inline]
139#[rustc_legacy_const_generics(0, 1, 2)]
140#[target_feature(enable = "amx-int8")]
141#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
142#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
143pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
144 static_assert_uimm_bits!(DST, 3);
145 static_assert_uimm_bits!(A, 3);
146 static_assert_uimm_bits!(B, 3);
147 tdpbsud(DST as i8, A as i8, B as i8);
148}
149
150/// Compute dot-product of bytes in tiles with a source/destination accumulator.
151/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
152/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
153/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
154///
155/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
156#[inline]
157#[rustc_legacy_const_generics(0, 1, 2)]
158#[target_feature(enable = "amx-int8")]
159#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
160#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
161pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
162 static_assert_uimm_bits!(DST, 3);
163 static_assert_uimm_bits!(A, 3);
164 static_assert_uimm_bits!(B, 3);
165 tdpbusd(DST as i8, A as i8, B as i8);
166}
167
168/// Compute dot-product of bytes in tiles with a source/destination accumulator.
169/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
170/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
171/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
172///
173/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
174#[inline]
175#[rustc_legacy_const_generics(0, 1, 2)]
176#[target_feature(enable = "amx-int8")]
177#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
178#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
179pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
180 static_assert_uimm_bits!(DST, 3);
181 static_assert_uimm_bits!(A, 3);
182 static_assert_uimm_bits!(B, 3);
183 tdpbuud(DST as i8, A as i8, B as i8);
184}
185
186/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
187/// accumulating the intermediate single-precision (32-bit) floating-point elements
188/// with elements in dst, and store the 32-bit result back to tile dst.
189///
190/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
191#[inline]
192#[rustc_legacy_const_generics(0, 1, 2)]
193#[target_feature(enable = "amx-fp16")]
194#[cfg_attr(
195 all(test, any(target_os = "linux", target_env = "msvc")),
196 assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2)
197)]
198#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
199pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
200 static_assert_uimm_bits!(DST, 3);
201 static_assert_uimm_bits!(A, 3);
202 static_assert_uimm_bits!(B, 3);
203 tdpfp16ps(DST as i8, A as i8, B as i8);
204}
205
206/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
207/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
208/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
209/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
210/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
211/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
212/// and then accumulated into the corresponding row and column of dst.
213///
214/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
215#[inline]
216#[rustc_legacy_const_generics(0, 1, 2)]
217#[target_feature(enable = "amx-complex")]
218#[cfg_attr(
219 all(test, any(target_os = "linux", target_env = "msvc")),
220 assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2)
221)]
222#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
223pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
224 static_assert_uimm_bits!(DST, 3);
225 static_assert_uimm_bits!(A, 3);
226 static_assert_uimm_bits!(B, 3);
227 tcmmimfp16ps(DST as i8, A as i8, B as i8);
228}
229
230/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
231/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
232/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
233/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
234/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
235/// the a element is multiplied with the imaginary part of the corresponding b elements.
236/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
237///
238/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
239#[inline]
240#[rustc_legacy_const_generics(0, 1, 2)]
241#[target_feature(enable = "amx-complex")]
242#[cfg_attr(
243 all(test, any(target_os = "linux", target_env = "msvc")),
244 assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2)
245)]
246#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
247pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
248 static_assert_uimm_bits!(DST, 3);
249 static_assert_uimm_bits!(A, 3);
250 static_assert_uimm_bits!(B, 3);
251 tcmmrlfp16ps(DST as i8, A as i8, B as i8);
252}
253
254#[allow(improper_ctypes)]
255unsafe extern "C" {
256 #[link_name = "llvm.x86.ldtilecfg"]
257 unsafefn ldtilecfg(mem_addr: *const u8);
258 #[link_name = "llvm.x86.sttilecfg"]
259 unsafefn sttilecfg(mem_addr: *mut u8);
260 #[link_name = "llvm.x86.tileloadd64"]
261 unsafefn tileloadd64(dst: i8, base: *const u8, stride: usize);
262 #[link_name = "llvm.x86.tileloaddt164"]
263 unsafefn tileloaddt164(dst: i8, base: *const u8, stride: usize);
264 #[link_name = "llvm.x86.tilerelease"]
265 unsafefn tilerelease();
266 #[link_name = "llvm.x86.tilestored64"]
267 unsafefn tilestored64(dst: i8, base: *mut u8, stride: usize);
268 #[link_name = "llvm.x86.tilezero"]
269 unsafefn tilezero(dst: i8);
270 #[link_name = "llvm.x86.tdpbf16ps"]
271 unsafefn tdpbf16ps(dst: i8, a: i8, b: i8);
272 #[link_name = "llvm.x86.tdpbuud"]
273 unsafefn tdpbuud(dst: i8, a: i8, b: i8);
274 #[link_name = "llvm.x86.tdpbusd"]
275 unsafefn tdpbusd(dst: i8, a: i8, b: i8);
276 #[link_name = "llvm.x86.tdpbsud"]
277 unsafefn tdpbsud(dst: i8, a: i8, b: i8);
278 #[link_name = "llvm.x86.tdpbssd"]
279 unsafefn tdpbssd(dst: i8, a: i8, b: i8);
280 #[link_name = "llvm.x86.tdpfp16ps"]
281 unsafefn tdpfp16ps(dst: i8, a: i8, b: i8);
282 #[link_name = "llvm.x86.tcmmimfp16ps"]
283 unsafefn tcmmimfp16ps(dst: i8, a: i8, b: i8);
284 #[link_name = "llvm.x86.tcmmrlfp16ps"]
285 unsafefn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
286}
287
#[cfg(test)]
mod tests {
    use crate::core_arch::x86::_mm_cvtness_sbh;
    use crate::core_arch::x86_64::*;
    use core::mem::transmute;
    use stdarch_test::simd_test;
    #[cfg(target_os = "linux")]
    use syscalls::{Sysno, syscall};

    /// In-memory image of the 64-byte tile configuration block that the tests
    /// pass to `_tile_loadconfig` and receive from `_tile_storeconfig`.
    ///
    /// The fields add up to 64 bytes (1 + 1 + 14 + 16 + 16 + 8 + 8) and the
    /// struct is `packed`, so no padding is inserted between them.
    #[allow(non_camel_case_types)]
    #[repr(packed)]
    #[derive(Copy, Clone, Default, Debug, PartialEq)]
    struct __tilecfg {
        /// 0 `or` 1
        palette: u8,
        start_row: u8,
        /// reserved, must be zero
        reserved_a0: [u8; 14],
        /// number of bytes of one row in each tile
        colsb: [u16; 8],
        /// reserved, must be zero
        reserved_b0: [u16; 8],
        /// number of rows in each tile
        rows: [u8; 8],
        /// reserved, must be zero
        reserved_c0: [u8; 8],
    }

    impl __tilecfg {
        /// Builds a configuration from the meaningful fields; every reserved
        /// field is left at zero.
        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
            Self {
                palette,
                start_row,
                colsb,
                rows,
                ..Self::default()
            }
        }

        /// Read-only byte pointer to the start of the configuration block.
        const fn as_ptr(&self) -> *const u8 {
            (self as *const Self).cast::<u8>()
        }

        /// Writable byte pointer to the start of the configuration block.
        fn as_mut_ptr(&mut self) -> *mut u8 {
            (self as *mut Self).cast::<u8>()
        }
    }

    /// Builds the configuration shared by the data tests: palette 1, with each
    /// of the first `tiles` tiles set to 16 rows of 64 bytes. All other fields
    /// stay zero.
    fn tile_config(tiles: usize) -> __tilecfg {
        let mut config = __tilecfg::default();
        config.palette = 1;
        for tile in 0..tiles {
            config.colsb[tile] = 64;
            config.rows[tile] = 16;
        }
        config
    }

    /// No permission request is performed outside Linux.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}

    /// Queries the xstate components this process is permitted to use and, if
    /// only the tile configuration component is permitted, requests permission
    /// for the tile data component as well.
    ///
    /// Panics if AMX is reported as unavailable or if either `arch_prctl` call
    /// fails or returns a non-zero status.
    #[cfg(target_os = "linux")]
    #[target_feature(enable = "amx-tile")]
    #[inline]
    unsafe fn _init_amx() {
        // `arch_prctl` sub-functions for xstate component permissions.
        const ARCH_GET_XCOMP_PERM: usize = 0x1022;
        const ARCH_REQ_XCOMP_PERM: usize = 0x1023;
        // xstate component numbers of the AMX tile configuration and tile data
        // (names taken from the Linux xstate documentation).
        const XFEATURE_XTILECFG: usize = 17;
        const XFEATURE_XTILEDATA: usize = 18;

        let mut xfeatures: usize = 0;
        let status = syscall!(
            Sysno::arch_prctl,
            ARCH_GET_XCOMP_PERM,
            &mut xfeatures as *mut usize
        )
        .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
        if status != 0 {
            panic!("Failed to get XFEATURES");
        }

        // Two adjacent permission bits: bit 0 = tile config, bit 1 = tile data.
        match 0b11 & (xfeatures >> XFEATURE_XTILECFG) {
            0 => panic!("AMX is not available"),
            1 => {
                let status = syscall!(Sysno::arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)
                    .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
                if status != 0 {
                    panic!("Failed to enable AMX");
                }
            }
            3 => {}
            _ => unreachable!(),
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadconfig() {
        let config = __tilecfg::default();
        _tile_loadconfig(config.as_ptr());
        _tile_release();
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_storeconfig() {
        let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
        _tile_loadconfig(config.as_ptr());
        let mut stored = __tilecfg::default();
        _tile_storeconfig(stored.as_mut_ptr());
        _tile_release();
        // The configuration read back must match the one that was loaded.
        assert_eq!(config, stored);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_zero() {
        _init_amx();
        let config = tile_config(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        // Pre-fill with ones so the assertion proves the store overwrote them.
        let mut out = [[1_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[0; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stored() {
        _init_amx();
        let config = tile_config(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        // Pre-fill with ones so the assertion proves the store overwrote them.
        let mut out = [[1_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[0; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadd() {
        _init_amx();
        let config = tile_config(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_loadd::<0>(mat.as_ptr().cast::<u8>(), 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stream_loadd() {
        _init_amx();
        let config = tile_config(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_stream_loadd::<0>(mat.as_ptr().cast::<u8>(), 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_release() {
        _tile_release();
    }

    #[simd_test(enable = "amx-bf16,avx512f")]
    unsafe fn test_tile_dpbf16ps() {
        _init_amx();
        let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
        let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
        let ones: [u8; 1024] = transmute([bf16_1; 512]);
        let twos: [u8; 1024] = transmute([bf16_2; 512]);
        let mut res = [[0f32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbf16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 32 bf16 pairs per output element, each contributing 1.0 * 2.0.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbssd() {
        _init_amx();
        let ones = [-1_i8; 1024];
        let twos = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_dpbssd::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 64 byte products per output element, each (-1) * (-2).
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbsud() {
        _init_amx();
        let ones = [-1_i8; 1024];
        let twos = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbsud::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 64 byte products per output element, each (-1) * 2.
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbusd() {
        _init_amx();
        let ones = [1_u8; 1024];
        let twos = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_dpbusd::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 64 byte products per output element, each 1 * (-2).
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbuud() {
        _init_amx();
        let ones = [1_u8; 1024];
        let twos = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbuud::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 64 byte products per output element, each 1 * 2.
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-fp16")]
    unsafe fn test_tile_dpfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_dpfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 32 fp16 pairs per output element, each contributing 1.0 * 2.0.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmimfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_cmmimfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // Imaginary part: 16 complex pairs, each 1*2 + 1*2.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmrlfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = tile_config(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_cmmrlfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // Real part: 16 complex pairs, each 1*2 - 1*2.
        assert_eq!(res, [[0f32; 16]; 16]);
    }
}
632