1 | #[cfg (test)] |
2 | use stdarch_test::assert_instr; |
3 | |
4 | /// Load tile configuration from a 64-byte memory location specified by mem_addr. |
5 | /// The tile configuration format is specified below, and includes the tile type pallette, |
6 | /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, |
7 | /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. |
8 | /// Any invalid configurations will result in #GP fault. |
9 | /// |
10 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) |
11 | #[inline ] |
12 | #[target_feature (enable = "amx-tile" )] |
13 | #[cfg_attr (test, assert_instr(ldtilecfg))] |
14 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
15 | pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { |
16 | ldtilecfg(mem_addr); |
17 | } |
18 | |
19 | /// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. |
20 | /// The tile configuration format is specified below, and includes the tile type pallette, |
21 | /// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. |
22 | /// |
23 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) |
24 | #[inline ] |
25 | #[target_feature (enable = "amx-tile" )] |
26 | #[cfg_attr (test, assert_instr(sttilecfg))] |
27 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
28 | pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { |
29 | sttilecfg(mem_addr); |
30 | } |
31 | |
32 | /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. |
33 | /// |
34 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) |
35 | #[inline ] |
36 | #[rustc_legacy_const_generics (0)] |
37 | #[target_feature (enable = "amx-tile" )] |
38 | #[cfg_attr (test, assert_instr(tileloadd, DST = 0))] |
39 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
40 | pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) { |
41 | static_assert_uimm_bits!(DST, 3); |
42 | tileloadd64(DST as i8, base, stride); |
43 | } |
44 | |
45 | /// Release the tile configuration to return to the init state, which releases all storage it currently holds. |
46 | /// |
47 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878) |
48 | #[inline ] |
49 | #[target_feature (enable = "amx-tile" )] |
50 | #[cfg_attr (test, assert_instr(tilerelease))] |
51 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
52 | pub unsafe fn _tile_release() { |
53 | tilerelease(); |
54 | } |
55 | |
56 | /// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. |
57 | /// |
58 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) |
59 | #[inline ] |
60 | #[rustc_legacy_const_generics (0)] |
61 | #[target_feature (enable = "amx-tile" )] |
62 | #[cfg_attr (test, assert_instr(tilestored, DST = 0))] |
63 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
64 | pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) { |
65 | static_assert_uimm_bits!(DST, 3); |
66 | tilestored64(DST as i8, base, stride); |
67 | } |
68 | |
69 | /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration |
70 | /// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will |
71 | /// likely not be reused in the near future and the data caching can be optimized accordingly. |
72 | /// |
73 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) |
74 | #[inline ] |
75 | #[rustc_legacy_const_generics (0)] |
76 | #[target_feature (enable = "amx-tile" )] |
77 | #[cfg_attr (test, assert_instr(tileloaddt1, DST = 0))] |
78 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
79 | pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) { |
80 | static_assert_uimm_bits!(DST, 3); |
81 | tileloaddt164(DST as i8, base, stride); |
82 | } |
83 | |
84 | /// Zero the tile specified by tdest. |
85 | /// |
86 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) |
87 | #[inline ] |
88 | #[rustc_legacy_const_generics (0)] |
89 | #[target_feature (enable = "amx-tile" )] |
90 | #[cfg_attr (test, assert_instr(tilezero, DST = 0))] |
91 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
92 | pub unsafe fn _tile_zero<const DST: i32>() { |
93 | static_assert_uimm_bits!(DST, 3); |
94 | tilezero(DST as i8); |
95 | } |
96 | |
97 | /// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, |
98 | /// accumulating the intermediate single-precision (32-bit) floating-point elements |
99 | /// with elements in dst, and store the 32-bit result back to tile dst. |
100 | /// |
101 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864) |
102 | #[inline ] |
103 | #[rustc_legacy_const_generics (0, 1, 2)] |
104 | #[target_feature (enable = "amx-bf16" )] |
105 | #[cfg_attr (test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))] |
106 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
107 | pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() { |
108 | static_assert_uimm_bits!(DST, 3); |
109 | static_assert_uimm_bits!(A, 3); |
110 | static_assert_uimm_bits!(B, 3); |
111 | tdpbf16ps(DST as i8, A as i8, B as i8); |
112 | } |
113 | |
114 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
115 | /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding |
116 | /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. |
117 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
118 | /// |
119 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866) |
120 | #[inline ] |
121 | #[rustc_legacy_const_generics (0, 1, 2)] |
122 | #[target_feature (enable = "amx-int8" )] |
123 | #[cfg_attr (test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))] |
124 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
125 | pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() { |
126 | static_assert_uimm_bits!(DST, 3); |
127 | static_assert_uimm_bits!(A, 3); |
128 | static_assert_uimm_bits!(B, 3); |
129 | tdpbssd(DST as i8, A as i8, B as i8); |
130 | } |
131 | |
132 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
133 | /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding |
134 | /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. |
135 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
136 | /// |
137 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868) |
138 | #[inline ] |
139 | #[rustc_legacy_const_generics (0, 1, 2)] |
140 | #[target_feature (enable = "amx-int8" )] |
141 | #[cfg_attr (test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))] |
142 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
143 | pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() { |
144 | static_assert_uimm_bits!(DST, 3); |
145 | static_assert_uimm_bits!(A, 3); |
146 | static_assert_uimm_bits!(B, 3); |
147 | tdpbsud(DST as i8, A as i8, B as i8); |
148 | } |
149 | |
150 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
151 | /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding |
152 | /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. |
153 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
154 | /// |
155 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870) |
156 | #[inline ] |
157 | #[rustc_legacy_const_generics (0, 1, 2)] |
158 | #[target_feature (enable = "amx-int8" )] |
159 | #[cfg_attr (test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))] |
160 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
161 | pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() { |
162 | static_assert_uimm_bits!(DST, 3); |
163 | static_assert_uimm_bits!(A, 3); |
164 | static_assert_uimm_bits!(B, 3); |
165 | tdpbusd(DST as i8, A as i8, B as i8); |
166 | } |
167 | |
168 | /// Compute dot-product of bytes in tiles with a source/destination accumulator. |
169 | /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding |
170 | /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. |
171 | /// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. |
172 | /// |
173 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872) |
174 | #[inline ] |
175 | #[rustc_legacy_const_generics (0, 1, 2)] |
176 | #[target_feature (enable = "amx-int8" )] |
177 | #[cfg_attr (test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))] |
178 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
179 | pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() { |
180 | static_assert_uimm_bits!(DST, 3); |
181 | static_assert_uimm_bits!(A, 3); |
182 | static_assert_uimm_bits!(B, 3); |
183 | tdpbuud(DST as i8, A as i8, B as i8); |
184 | } |
185 | |
186 | /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, |
187 | /// accumulating the intermediate single-precision (32-bit) floating-point elements |
188 | /// with elements in dst, and store the 32-bit result back to tile dst. |
189 | /// |
190 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874) |
191 | #[inline ] |
192 | #[rustc_legacy_const_generics (0, 1, 2)] |
193 | #[target_feature (enable = "amx-fp16" )] |
194 | #[cfg_attr ( |
195 | all(test, any(target_os = "linux" , target_env = "msvc" )), |
196 | assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2) |
197 | )] |
198 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
199 | pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() { |
200 | static_assert_uimm_bits!(DST, 3); |
201 | static_assert_uimm_bits!(A, 3); |
202 | static_assert_uimm_bits!(B, 3); |
203 | tdpfp16ps(DST as i8, A as i8, B as i8); |
204 | } |
205 | |
206 | /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. |
207 | /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. |
208 | /// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), |
209 | /// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). |
210 | /// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of |
211 | /// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, |
212 | /// and then accumulated into the corresponding row and column of dst. |
213 | /// |
214 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860) |
215 | #[inline ] |
216 | #[rustc_legacy_const_generics (0, 1, 2)] |
217 | #[target_feature (enable = "amx-complex" )] |
218 | #[cfg_attr ( |
219 | all(test, any(target_os = "linux" , target_env = "msvc" )), |
220 | assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2) |
221 | )] |
222 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
223 | pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() { |
224 | static_assert_uimm_bits!(DST, 3); |
225 | static_assert_uimm_bits!(A, 3); |
226 | static_assert_uimm_bits!(B, 3); |
227 | tcmmimfp16ps(DST as i8, A as i8, B as i8); |
228 | } |
229 | |
230 | /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. |
231 | /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. |
232 | /// Calculates the real part of the result. For each possible combination of (row of a, column of b), |
233 | /// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). |
234 | /// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of |
235 | /// the a element is multiplied with the imaginary part of the corresponding b elements. |
236 | /// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. |
237 | /// |
238 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862) |
239 | #[inline ] |
240 | #[rustc_legacy_const_generics (0, 1, 2)] |
241 | #[target_feature (enable = "amx-complex" )] |
242 | #[cfg_attr ( |
243 | all(test, any(target_os = "linux" , target_env = "msvc" )), |
244 | assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2) |
245 | )] |
246 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
247 | pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() { |
248 | static_assert_uimm_bits!(DST, 3); |
249 | static_assert_uimm_bits!(A, 3); |
250 | static_assert_uimm_bits!(B, 3); |
251 | tcmmrlfp16ps(DST as i8, A as i8, B as i8); |
252 | } |
253 | |
254 | #[allow (improper_ctypes)] |
255 | unsafe extern "C" { |
256 | #[link_name = "llvm.x86.ldtilecfg" ] |
257 | unsafefn ldtilecfg(mem_addr: *const u8); |
258 | #[link_name = "llvm.x86.sttilecfg" ] |
259 | unsafefn sttilecfg(mem_addr: *mut u8); |
260 | #[link_name = "llvm.x86.tileloadd64" ] |
261 | unsafefn tileloadd64(dst: i8, base: *const u8, stride: usize); |
262 | #[link_name = "llvm.x86.tileloaddt164" ] |
263 | unsafefn tileloaddt164(dst: i8, base: *const u8, stride: usize); |
264 | #[link_name = "llvm.x86.tilerelease" ] |
265 | unsafefn tilerelease(); |
266 | #[link_name = "llvm.x86.tilestored64" ] |
267 | unsafefn tilestored64(dst: i8, base: *mut u8, stride: usize); |
268 | #[link_name = "llvm.x86.tilezero" ] |
269 | unsafefn tilezero(dst: i8); |
270 | #[link_name = "llvm.x86.tdpbf16ps" ] |
271 | unsafefn tdpbf16ps(dst: i8, a: i8, b: i8); |
272 | #[link_name = "llvm.x86.tdpbuud" ] |
273 | unsafefn tdpbuud(dst: i8, a: i8, b: i8); |
274 | #[link_name = "llvm.x86.tdpbusd" ] |
275 | unsafefn tdpbusd(dst: i8, a: i8, b: i8); |
276 | #[link_name = "llvm.x86.tdpbsud" ] |
277 | unsafefn tdpbsud(dst: i8, a: i8, b: i8); |
278 | #[link_name = "llvm.x86.tdpbssd" ] |
279 | unsafefn tdpbssd(dst: i8, a: i8, b: i8); |
280 | #[link_name = "llvm.x86.tdpfp16ps" ] |
281 | unsafefn tdpfp16ps(dst: i8, a: i8, b: i8); |
282 | #[link_name = "llvm.x86.tcmmimfp16ps" ] |
283 | unsafefn tcmmimfp16ps(dst: i8, a: i8, b: i8); |
284 | #[link_name = "llvm.x86.tcmmrlfp16ps" ] |
285 | unsafefn tcmmrlfp16ps(dst: i8, a: i8, b: i8); |
286 | } |
287 | |
#[cfg(test)]
mod tests {
    use crate::core_arch::x86::_mm_cvtness_sbh;
    use crate::core_arch::x86_64::*;
    use core::mem::transmute;
    use stdarch_test::simd_test;
    #[cfg(target_os = "linux")]
    use syscalls::{Sysno, syscall};

    /// The 64-byte tile configuration read by `_tile_loadconfig` and written by
    /// `_tile_storeconfig` (eight tiles).
    #[allow(non_camel_case_types)]
    #[repr(packed)]
    #[derive(Copy, Clone, Default, Debug, PartialEq)]
    struct __tilecfg {
        /// palette id: 0 (init state) or 1
        palette: u8,
        /// row to resume from; 0 for a fresh configuration
        start_row: u8,
        /// reserved, must be zero
        reserved_a0: [u8; 14],
        /// number of bytes of one row in each tile
        colsb: [u16; 8],
        /// reserved, must be zero
        reserved_b0: [u16; 8],
        /// number of rows in each tile
        rows: [u8; 8],
        /// reserved, must be zero
        reserved_c0: [u8; 8],
    }

    impl __tilecfg {
        /// Builds a configuration with all reserved bytes zeroed.
        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
            Self {
                palette,
                start_row,
                reserved_a0: [0u8; 14],
                colsb,
                reserved_b0: [0u16; 8],
                rows,
                reserved_c0: [0u8; 8],
            }
        }

        /// Palette-1 configuration in which tiles `0..tiles` are 16 rows of 64 bytes and
        /// the remaining tiles are left zero-sized.
        ///
        /// The arrays are filled before constructing `Self` because borrowing a field of
        /// this `#[repr(packed)]` struct (e.g. via `slice::fill`) is not allowed.
        fn uniform_16x64(tiles: usize) -> Self {
            let mut colsb = [0_u16; 8];
            let mut rows = [0_u8; 8];
            colsb[..tiles].fill(64);
            rows[..tiles].fill(16);
            Self::new(1, 0, colsb, rows)
        }

        const fn as_ptr(&self) -> *const u8 {
            self as *const Self as *const u8
        }

        fn as_mut_ptr(&mut self) -> *mut u8 {
            self as *mut Self as *mut u8
        }
    }

    /// No AMX permission request is made on non-Linux targets.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}

    /// Ensures this process may use AMX tile data. If it is not yet permitted, permission
    /// is requested via `arch_prctl(ARCH_REQ_XCOMP_PERM)`.
    ///
    /// Panics if a syscall fails or the kernel reports AMX as unavailable.
    #[cfg(target_os = "linux")]
    #[target_feature(enable = "amx-tile")]
    #[inline]
    unsafe fn _init_amx() {
        const ARCH_GET_XCOMP_PERM: usize = 0x1022;
        const ARCH_REQ_XCOMP_PERM: usize = 0x1023;
        const XFEATURE_XTILECFG: usize = 17;
        const XFEATURE_XTILEDATA: usize = 18;

        let mut xfeatures: usize = 0;
        let ret = syscall!(
            Sysno::arch_prctl,
            ARCH_GET_XCOMP_PERM,
            &mut xfeatures as *mut usize
        )
        .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
        if ret != 0 {
            panic!("Failed to get XFEATURES");
        }
        // Bit 0 of the extracted pair is XTILECFG, bit 1 is XTILEDATA.
        match 0b11 & (xfeatures >> XFEATURE_XTILECFG) {
            0 => panic!("AMX is not available"),
            // Tile config is permitted but tile data is not: request it.
            1 => {
                let ret = syscall!(Sysno::arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)
                    .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
                if ret != 0 {
                    panic!("Failed to enable AMX");
                }
            }
            3 => {}
            // Tile data permitted without tile config is not expected.
            _ => unreachable!(),
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadconfig() {
        let config = __tilecfg::default();
        _tile_loadconfig(config.as_ptr());
        _tile_release();
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_storeconfig() {
        // A palette-1 configuration is loaded below, so request AMX permission first,
        // like every other test that configures tiles.
        _init_amx();
        let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
        _tile_loadconfig(config.as_ptr());
        let mut stored = __tilecfg::default();
        _tile_storeconfig(stored.as_mut_ptr());
        _tile_release();
        assert_eq!(config, stored);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_zero() {
        _init_amx();
        let config = __tilecfg::uniform_16x64(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mut out = [[1_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[0; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stored() {
        _init_amx();
        let config = __tilecfg::uniform_16x64(1);
        _tile_loadconfig(config.as_ptr());
        // Row `i` of the source matrix holds the value `i` so row placement is observable.
        let mut mat = [[0_i8; 64]; 16];
        for (i, row) in mat.iter_mut().enumerate() {
            *row = [i as i8; 64];
        }
        _tile_loadd::<0>(mat.as_ptr().cast::<u8>(), 64);
        // Store with a 128-byte stride: tile row `i` must land in `out[2 * i]`, and the
        // odd rows of `out` must be left untouched.
        let mut out = [[-1_i8; 64]; 32];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 128);
        _tile_release();
        for (i, row) in out.iter().enumerate() {
            let expected: [i8; 64] = if i % 2 == 0 {
                [(i / 2) as i8; 64]
            } else {
                [-1; 64]
            };
            assert_eq!(*row, expected);
        }
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_loadd() {
        _init_amx();
        let config = __tilecfg::uniform_16x64(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_loadd::<0>(mat.as_ptr().cast::<u8>(), 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_stream_loadd() {
        _init_amx();
        let config = __tilecfg::uniform_16x64(1);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        let mat = [1_i8; 1024];
        _tile_stream_loadd::<0>(mat.as_ptr().cast::<u8>(), 64);
        let mut out = [[0_i8; 64]; 16];
        _tile_stored::<0>(out.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(out, [[1; 64]; 16]);
    }

    #[simd_test(enable = "amx-tile")]
    unsafe fn test_tile_release() {
        _tile_release();
    }

    #[simd_test(enable = "amx-bf16,avx512f")]
    unsafe fn test_tile_dpbf16ps() {
        _init_amx();
        let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
        let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
        let ones: [u8; 1024] = transmute([bf16_1; 512]);
        let twos: [u8; 1024] = transmute([bf16_2; 512]);
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbf16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // 16 dwords per row × 2 bf16 products of 1.0 * 2.0 each = 64.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbssd() {
        _init_amx();
        let ones = [-1_i8; 1024];
        let twos = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_dpbssd::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbsud() {
        _init_amx();
        let ones = [-1_i8; 1024];
        let twos = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbsud::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbusd() {
        _init_amx();
        let ones = [1_u8; 1024];
        let twos = [-2_i8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_dpbusd::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(res, [[-128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-int8")]
    unsafe fn test_tile_dpbuud() {
        _init_amx();
        let ones = [1_u8; 1024];
        let twos = [2_u8; 1024];
        let mut res = [[0_i32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr(), 64);
        _tile_loadd::<2>(twos.as_ptr(), 64);
        _tile_dpbuud::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(res, [[128_i32; 16]; 16]);
    }

    #[simd_test(enable = "amx-fp16")]
    unsafe fn test_tile_dpfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_dpfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmimfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_cmmimfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // Imaginary part per element: 1 * 2 + 1 * 2 = 4, over 16 elements = 64.
        assert_eq!(res, [[64f32; 16]; 16]);
    }

    #[simd_test(enable = "amx-complex")]
    unsafe fn test_tile_cmmrlfp16ps() {
        _init_amx();
        let ones = [1f16; 512];
        let twos = [2f16; 512];
        let mut res = [[0f32; 16]; 16];
        let config = __tilecfg::uniform_16x64(3);
        _tile_loadconfig(config.as_ptr());
        _tile_zero::<0>();
        _tile_loadd::<1>(ones.as_ptr().cast::<u8>(), 64);
        _tile_loadd::<2>(twos.as_ptr().cast::<u8>(), 64);
        _tile_cmmrlfp16ps::<0, 1, 2>();
        _tile_stored::<0>(res.as_mut_ptr().cast::<u8>(), 64);
        _tile_release();
        // Real part per element: 1 * 2 - 1 * 2 = 0.
        assert_eq!(res, [[0f32; 16]; 16]);
    }
}
632 | |