1 | #[cfg (test)] |
2 | use stdarch_test::assert_instr; |
3 | |
4 | /// Load tile configuration from a 64-byte memory location specified by mem_addr. |
5 | /// The tile configuration format is specified below, and includes the tile type pallette, |
6 | /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, |
7 | /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. |
8 | /// Any invalid configurations will result in #GP fault. |
9 | /// |
10 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) |
11 | #[inline ] |
12 | #[target_feature (enable = "amx-tile" )] |
13 | #[cfg_attr (test, assert_instr(ldtilecfg))] |
14 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
15 | pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { |
16 | ldtilecfg(mem_addr); |
17 | } |
18 | |
19 | /// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. |
20 | /// The tile configuration format is specified below, and includes the tile type pallette, |
21 | /// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. |
22 | /// |
23 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) |
24 | #[inline ] |
25 | #[target_feature (enable = "amx-tile" )] |
26 | #[cfg_attr (test, assert_instr(sttilecfg))] |
27 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
28 | pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { |
29 | sttilecfg(mem_addr); |
30 | } |
31 | |
32 | /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. |
33 | /// |
34 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) |
35 | #[inline ] |
36 | #[rustc_legacy_const_generics (0)] |
37 | #[target_feature (enable = "amx-tile" )] |
38 | #[cfg_attr (test, assert_instr(tileloadd, DST = 0))] |
39 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
40 | pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) { |
41 | static_assert_uimm_bits!(DST, 3); |
42 | tileloadd64(DST as i8, base, stride); |
43 | } |
44 | |
45 | /// Release the tile configuration to return to the init state, which releases all storage it currently holds. |
46 | /// |
47 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878) |
48 | #[inline ] |
49 | #[target_feature (enable = "amx-tile" )] |
50 | #[cfg_attr (test, assert_instr(tilerelease))] |
51 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
52 | pub unsafe fn _tile_release() { |
53 | tilerelease(); |
54 | } |
55 | |
56 | /// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. |
57 | /// |
58 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) |
59 | #[inline ] |
60 | #[rustc_legacy_const_generics (0)] |
61 | #[target_feature (enable = "amx-tile" )] |
62 | #[cfg_attr (test, assert_instr(tilestored, DST = 0))] |
63 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
64 | pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) { |
65 | static_assert_uimm_bits!(DST, 3); |
66 | tilestored64(DST as i8, base, stride); |
67 | } |
68 | |
69 | /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration |
70 | /// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will |
71 | /// likely not be reused in the near future and the data caching can be optimized accordingly. |
72 | /// |
73 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) |
74 | #[inline ] |
75 | #[rustc_legacy_const_generics (0)] |
76 | #[target_feature (enable = "amx-tile" )] |
77 | #[cfg_attr (test, assert_instr(tileloaddt1, DST = 0))] |
78 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
79 | pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) { |
80 | static_assert_uimm_bits!(DST, 3); |
81 | tileloaddt164(DST as i8, base, stride); |
82 | } |
83 | |
84 | /// Zero the tile specified by tdest. |
85 | /// |
86 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) |
87 | #[inline ] |
88 | #[rustc_legacy_const_generics (0)] |
89 | #[target_feature (enable = "amx-tile" )] |
90 | #[cfg_attr (test, assert_instr(tilezero, DST = 0))] |
91 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
92 | pub unsafe fn _tile_zero<const DST: i32>() { |
93 | static_assert_uimm_bits!(DST, 3); |
94 | tilezero(DST as i8); |
95 | } |
96 | |
97 | /// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, |
98 | /// accumulating the intermediate single-precision (32-bit) floating-point elements |
99 | /// with elements in dst, and store the 32-bit result back to tile dst. |
100 | /// |
101 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864) |
102 | #[inline ] |
103 | #[rustc_legacy_const_generics (0, 1, 2)] |
104 | #[target_feature (enable = "amx-bf16" )] |
105 | #[cfg_attr (test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))] |
106 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
107 | pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() { |
108 | static_assert_uimm_bits!(DST, 3); |
109 | static_assert_uimm_bits!(A, 3); |
110 | static_assert_uimm_bits!(B, 3); |
111 | tdpbf16ps(DST as i8, A as i8, B as i8); |
112 | } |
113 | |
114 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
115 | /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding |
116 | /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. |
117 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
118 | /// |
119 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866) |
120 | #[inline ] |
121 | #[rustc_legacy_const_generics (0, 1, 2)] |
122 | #[target_feature (enable = "amx-int8" )] |
123 | #[cfg_attr (test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))] |
124 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
125 | pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() { |
126 | static_assert_uimm_bits!(DST, 3); |
127 | static_assert_uimm_bits!(A, 3); |
128 | static_assert_uimm_bits!(B, 3); |
129 | tdpbssd(DST as i8, A as i8, B as i8); |
130 | } |
131 | |
132 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
133 | /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding |
134 | /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. |
135 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
136 | /// |
137 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868) |
138 | #[inline ] |
139 | #[rustc_legacy_const_generics (0, 1, 2)] |
140 | #[target_feature (enable = "amx-int8" )] |
141 | #[cfg_attr (test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))] |
142 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
143 | pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() { |
144 | static_assert_uimm_bits!(DST, 3); |
145 | static_assert_uimm_bits!(A, 3); |
146 | static_assert_uimm_bits!(B, 3); |
147 | tdpbsud(DST as i8, A as i8, B as i8); |
148 | } |
149 | |
150 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
151 | /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding |
152 | /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. |
153 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
154 | /// |
155 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870) |
156 | #[inline ] |
157 | #[rustc_legacy_const_generics (0, 1, 2)] |
158 | #[target_feature (enable = "amx-int8" )] |
159 | #[cfg_attr (test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))] |
160 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
161 | pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() { |
162 | static_assert_uimm_bits!(DST, 3); |
163 | static_assert_uimm_bits!(A, 3); |
164 | static_assert_uimm_bits!(B, 3); |
165 | tdpbusd(DST as i8, A as i8, B as i8); |
166 | } |
167 | |
168 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
169 | /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding |
170 | /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. |
171 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
172 | /// |
173 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872) |
174 | #[inline ] |
175 | #[rustc_legacy_const_generics (0, 1, 2)] |
176 | #[target_feature (enable = "amx-int8" )] |
177 | #[cfg_attr (test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))] |
178 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
179 | pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() { |
180 | static_assert_uimm_bits!(DST, 3); |
181 | static_assert_uimm_bits!(A, 3); |
182 | static_assert_uimm_bits!(B, 3); |
183 | tdpbuud(DST as i8, A as i8, B as i8); |
184 | } |
185 | |
186 | /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, |
187 | /// accumulating the intermediate single-precision (32-bit) floating-point elements |
188 | /// with elements in dst, and store the 32-bit result back to tile dst. |
189 | /// |
190 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874) |
191 | #[inline ] |
192 | #[rustc_legacy_const_generics (0, 1, 2)] |
193 | #[target_feature (enable = "amx-fp16" )] |
194 | #[cfg_attr (test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))] |
195 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
196 | pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() { |
197 | static_assert_uimm_bits!(DST, 3); |
198 | static_assert_uimm_bits!(A, 3); |
199 | static_assert_uimm_bits!(B, 3); |
200 | tdpfp16ps(DST as i8, A as i8, B as i8); |
201 | } |
202 | |
203 | /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. |
204 | /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. |
205 | /// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), |
206 | /// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). |
207 | /// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of |
208 | /// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, |
209 | /// and then accumulated into the corresponding row and column of dst. |
210 | /// |
211 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860) |
212 | #[inline ] |
213 | #[rustc_legacy_const_generics (0, 1, 2)] |
214 | #[target_feature (enable = "amx-complex" )] |
215 | #[cfg_attr (test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))] |
216 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
217 | pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() { |
218 | static_assert_uimm_bits!(DST, 3); |
219 | static_assert_uimm_bits!(A, 3); |
220 | static_assert_uimm_bits!(B, 3); |
221 | tcmmimfp16ps(DST as i8, A as i8, B as i8); |
222 | } |
223 | |
224 | /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. |
225 | /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. |
226 | /// Calculates the real part of the result. For each possible combination of (row of a, column of b), |
227 | /// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). |
228 | /// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of |
229 | /// the a element is multiplied with the imaginary part of the corresponding b elements. |
230 | /// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. |
231 | /// |
232 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862) |
233 | #[inline ] |
234 | #[rustc_legacy_const_generics (0, 1, 2)] |
235 | #[target_feature (enable = "amx-complex" )] |
236 | #[cfg_attr (test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))] |
237 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
238 | pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() { |
239 | static_assert_uimm_bits!(DST, 3); |
240 | static_assert_uimm_bits!(A, 3); |
241 | static_assert_uimm_bits!(B, 3); |
242 | tcmmrlfp16ps(DST as i8, A as i8, B as i8); |
243 | } |
244 | |
245 | #[allow (improper_ctypes)] |
246 | unsafe extern "C" { |
247 | #[link_name = "llvm.x86.ldtilecfg" ] |
248 | unsafefn ldtilecfg(mem_addr: *const u8); |
249 | #[link_name = "llvm.x86.sttilecfg" ] |
250 | unsafefn sttilecfg(mem_addr: *mut u8); |
251 | #[link_name = "llvm.x86.tileloadd64" ] |
252 | unsafefn tileloadd64(dst: i8, base: *const u8, stride: usize); |
253 | #[link_name = "llvm.x86.tileloaddt164" ] |
254 | unsafefn tileloaddt164(dst: i8, base: *const u8, stride: usize); |
255 | #[link_name = "llvm.x86.tilerelease" ] |
256 | unsafefn tilerelease(); |
257 | #[link_name = "llvm.x86.tilestored64" ] |
258 | unsafefn tilestored64(dst: i8, base: *mut u8, stride: usize); |
259 | #[link_name = "llvm.x86.tilezero" ] |
260 | unsafefn tilezero(dst: i8); |
261 | #[link_name = "llvm.x86.tdpbf16ps" ] |
262 | unsafefn tdpbf16ps(dst: i8, a: i8, b: i8); |
263 | #[link_name = "llvm.x86.tdpbuud" ] |
264 | unsafefn tdpbuud(dst: i8, a: i8, b: i8); |
265 | #[link_name = "llvm.x86.tdpbusd" ] |
266 | unsafefn tdpbusd(dst: i8, a: i8, b: i8); |
267 | #[link_name = "llvm.x86.tdpbsud" ] |
268 | unsafefn tdpbsud(dst: i8, a: i8, b: i8); |
269 | #[link_name = "llvm.x86.tdpbssd" ] |
270 | unsafefn tdpbssd(dst: i8, a: i8, b: i8); |
271 | #[link_name = "llvm.x86.tdpfp16ps" ] |
272 | unsafefn tdpfp16ps(dst: i8, a: i8, b: i8); |
273 | #[link_name = "llvm.x86.tcmmimfp16ps" ] |
274 | unsafefn tcmmimfp16ps(dst: i8, a: i8, b: i8); |
275 | #[link_name = "llvm.x86.tcmmrlfp16ps" ] |
276 | unsafefn tcmmrlfp16ps(dst: i8, a: i8, b: i8); |
277 | } |
278 | |
#[cfg(test)]
mod tests {
    use crate::core_arch::x86::_mm_cvtness_sbh;
    use crate::core_arch::x86_64::*;
    use stdarch_test::simd_test;
    #[cfg(target_os = "linux")]
    use syscalls::{Sysno, syscall};

    /// The 64-byte tile configuration block read by `_tile_loadconfig` and written by
    /// `_tile_storeconfig`.
    #[allow(non_camel_case_types)]
    #[repr(packed)]
    #[derive(Copy, Clone, Default, Debug, PartialEq)]
    struct __tilecfg {
        /// Palette id: 0 (init state) or 1.
        palette: u8,
        start_row: u8,
        /// reserved, must be zero
        reserved_a0: [u8; 14],
        /// number of bytes of one row in each tile
        colsb: [u16; 8],
        /// reserved, must be zero
        reserved_b0: [u16; 8],
        /// number of rows in each tile
        rows: [u8; 8],
        /// reserved, must be zero
        reserved_c0: [u8; 8],
    }

    // The configuration instructions read/write exactly 64 bytes; catch layout mistakes at
    // compile time rather than as a #GP fault at run time.
    const _: () = assert!(core::mem::size_of::<__tilecfg>() == 64);

    impl __tilecfg {
        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
            Self {
                palette,
                start_row,
                reserved_a0: [0u8; 14],
                colsb,
                reserved_b0: [0u16; 8],
                rows,
                reserved_c0: [0u8; 8],
            }
        }

        /// Palette-1 configuration in which tiles `0..count` each have 16 rows of 64 bytes and
        /// the remaining tiles are left unconfigured (zero rows and columns).
        fn with_full_tiles(count: usize) -> Self {
            let mut colsb = [0u16; 8];
            let mut rows = [0u8; 8];
            colsb[..count].fill(64);
            rows[..count].fill(16);
            Self::new(1, 0, colsb, rows)
        }

        const fn as_ptr(&self) -> *const u8 {
            self as *const Self as *const u8
        }

        fn as_mut_ptr(&mut self) -> *mut u8 {
            self as *mut Self as *mut u8
        }
    }

    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}

    /// Ensure this process is allowed to use AMX tile data.
    ///
    /// The permission is process-wide, so every test that executes an AMX instruction calls
    /// this first instead of relying on another test having requested it already.
    #[cfg(target_os = "linux")]
    #[target_feature(enable = "amx-tile")]
    #[inline]
    unsafe fn _init_amx() {
        const ARCH_GET_XCOMP_PERM: usize = 0x1022;
        const ARCH_REQ_XCOMP_PERM: usize = 0x1023;
        // XSAVE state component number of AMX tile data.
        const XFEATURE_XTILEDATA: usize = 18;

        let mut xfeatures: usize = 0;
        let ret = syscall!(Sysno::arch_prctl, ARCH_GET_XCOMP_PERM, &mut xfeatures as *mut usize)
            .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
        if ret != 0 {
            panic!("Failed to get XFEATURES");
        }
        // Inspect the two AMX components: bit 17 (tile config) and bit 18 (tile data).
        match 0b11 & (xfeatures >> 17) {
            0 => panic!("AMX is not available"),
            1 => {
                let ret = syscall!(Sysno::arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)
                    .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
                if ret != 0 {
                    panic!("Failed to enable AMX");
                }
            }
            3 => {}
            _ => unreachable!("tile data permitted without tile config"),
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadconfig() {
        _init_amx();
        let config = __tilecfg::default();
        _tile_loadconfig(config.as_ptr());
        _tile_release();
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_storeconfig() {
        _init_amx();
        let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
        _tile_loadconfig(config.as_ptr());
        let mut stored = __tilecfg::default();
        _tile_storeconfig(stored.as_mut_ptr());
        _tile_release();
        assert_eq!(config, stored);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_zero() {
        _init_amx();
        let config = __tilecfg::with_full_tiles(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mut out = [[1_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(out, [[0; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stored() {
        _init_amx();
        let config = __tilecfg::with_full_tiles(1);
        _tile_loadconfig(config.as_ptr());
        // A value that differs per row and per column (max 3 * 15 + 63 = 108, fits in i8),
        // so a transposed or mis-strided store cannot pass.
        let mat: [[i8; 64]; 16] =
            core::array::from_fn(|r| core::array::from_fn(|c| (r * 3 + c) as i8));
        _tile_loadd::<0>(mat.as_ptr().cast(), 64);
        // Store with a 128-byte stride into a wider buffer: row `r` must land at byte
        // `r * 128`, and the padding bytes after each row must stay untouched.
        let mut out = [[-1_i8; 128]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast(), 128);
        _tile_release();
        for (row, expected) in out.iter().zip(mat.iter()) {
            assert_eq!(&row[..64], &expected[..]);
            assert!(row[64..].iter().all(|&b| b == -1));
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadd() {
        _init_amx();
        let config = __tilecfg::with_full_tiles(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_loadd::<0>(mat.as_ptr().cast(), 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stream_loadd() {
        _init_amx();
        let config = __tilecfg::with_full_tiles(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_stream_loadd::<0>(mat.as_ptr().cast(), 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_release() {
        _init_amx();
        _tile_release();
    }

    #[simd_test(enable = "amx-bf16,avx512f")]
    unsafe fn test_tile_dpbf16ps() {
        _init_amx();
        let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
        let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
        let ones = [bf16_1; 512];
        let twos = [bf16_2; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast(), 64);
        _tile_dpbf16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        // Each lane sums 16 pairs of (1 * 2 + 1 * 2) = 16 * 4 = 64.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbssd() {
        _init_amx();
        let ones = [-1_i8; 1024];
        let twos = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast(), 64);
        _tile_dpbssd::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        // 16 groups per lane, each 4 * (-1 * -2) = 8, giving 128.
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbsud() {
        _init_amx();
        let ones = [-1_i8; 1024];
        let twos = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbsud::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbusd() {
        _init_amx();
        let ones = [1_u8; 1024];
        let twos = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast(), 64);
        _tile_dpbusd::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbuud() {
        _init_amx();
        let ones = [1_u8; 1024];
        let twos = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbuud::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-fp16")]
    unsafe fn test_tile_dpfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast(), 64);
        _tile_dpfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmimfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast(), 64);
        _tile_cmmimfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        // 16 complex pairs per lane, each imag = 1 * 2 + 1 * 2 = 4, giving 64.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmrlfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::with_full_tiles(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast(), 64);
        _tile_cmmrlfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
        _tile_release();
        // Each complex pair contributes real = 1 * 2 - 1 * 2 = 0.
        assert_eq!(res, [[0f32; 16]; 16]);
    }
}
623 | |