| 1 | use crate::core_arch::{simd::*, x86::*}; |
| 2 | |
| 3 | #[cfg (test)] |
| 4 | use stdarch_test::assert_instr; |
| 5 | |
| 6 | /// Load tile configuration from a 64-byte memory location specified by mem_addr. |
| 7 | /// The tile configuration format is specified below, and includes the tile type pallette, |
| 8 | /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, |
| 9 | /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. |
| 10 | /// Any invalid configurations will result in #GP fault. |
| 11 | /// |
| 12 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) |
| 13 | #[inline ] |
| 14 | #[target_feature (enable = "amx-tile" )] |
| 15 | #[cfg_attr (test, assert_instr(ldtilecfg))] |
| 16 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
| 17 | pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { |
| 18 | ldtilecfg(mem_addr); |
| 19 | } |
| 20 | |
| 21 | /// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. |
| 22 | /// The tile configuration format is specified below, and includes the tile type pallette, |
| 23 | /// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. |
| 24 | /// |
| 25 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) |
| 26 | #[inline ] |
| 27 | #[target_feature (enable = "amx-tile" )] |
| 28 | #[cfg_attr (test, assert_instr(sttilecfg))] |
| 29 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
| 30 | pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { |
| 31 | sttilecfg(mem_addr); |
| 32 | } |
| 33 | |
| 34 | /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. |
| 35 | /// |
| 36 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) |
| 37 | #[inline ] |
| 38 | #[rustc_legacy_const_generics (0)] |
| 39 | #[target_feature (enable = "amx-tile" )] |
| 40 | #[cfg_attr (test, assert_instr(tileloadd, DST = 0))] |
| 41 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
| 42 | pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) { |
| 43 | static_assert_uimm_bits!(DST, 3); |
| 44 | tileloadd64(DST as i8, base, stride); |
| 45 | } |
| 46 | |
/// Release the tile configuration to return to the init state, which releases all storage it
/// currently holds.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilerelease))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_release() {
    tilerelease();
}
| 57 | |
| 58 | /// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. |
| 59 | /// |
| 60 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) |
| 61 | #[inline ] |
| 62 | #[rustc_legacy_const_generics (0)] |
| 63 | #[target_feature (enable = "amx-tile" )] |
| 64 | #[cfg_attr (test, assert_instr(tilestored, DST = 0))] |
| 65 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
| 66 | pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) { |
| 67 | static_assert_uimm_bits!(DST, 3); |
| 68 | tilestored64(DST as i8, base, stride); |
| 69 | } |
| 70 | |
| 71 | /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration |
| 72 | /// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will |
| 73 | /// likely not be reused in the near future and the data caching can be optimized accordingly. |
| 74 | /// |
| 75 | /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) |
| 76 | #[inline ] |
| 77 | #[rustc_legacy_const_generics (0)] |
| 78 | #[target_feature (enable = "amx-tile" )] |
| 79 | #[cfg_attr (test, assert_instr(tileloaddt1, DST = 0))] |
| 80 | #[unstable (feature = "x86_amx_intrinsics" , issue = "126622" )] |
| 81 | pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) { |
| 82 | static_assert_uimm_bits!(DST, 3); |
| 83 | tileloaddt164(DST as i8, base, stride); |
| 84 | } |
| 85 | |
/// Zero the tile specified by `DST`.
///
/// `DST` selects tile register 0 through 7; out-of-range values are rejected at compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_zero<const DST: i32>() {
    static_assert_uimm_bits!(DST, 3);
    tilezero(DST as i8);
}
| 98 | |
/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in dst, and store the 32-bit result back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-bf16")]
#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbf16ps(DST as i8, A as i8, B as i8);
}
| 115 | |
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit
/// result back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbssd(DST as i8, A as i8, B as i8);
}
| 133 | |
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit
/// result back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbsud(DST as i8, A as i8, B as i8);
}
| 151 | |
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit
/// result back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbusd(DST as i8, A as i8, B as i8);
}
| 169 | |
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit
/// result back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbuud(DST as i8, A as i8, B as i8);
}
| 187 | |
/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in dst, and store the 32-bit result back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp16")]
#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpfp16ps(DST as i8, A as i8, B as i8);
}
| 204 | |
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the
/// results into a packed single precision tile.
///
/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real
/// part and FP16 imaginary part. Calculates the imaginary part of the result. For each possible
/// combination of (row of a, column of b), it performs a set of multiplication and accumulations
/// on all corresponding complex numbers (one from a and one from b). The imaginary part of the a
/// element is multiplied with the real part of the corresponding b element, and the real part of
/// the a element is multiplied with the imaginary part of the corresponding b elements. The two
/// accumulated results are added, and then accumulated into the corresponding row and column of
/// dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tcmmimfp16ps(DST as i8, A as i8, B as i8);
}
| 225 | |
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the
/// results into a packed single precision tile.
///
/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real
/// part and FP16 imaginary part. Calculates the real part of the result. For each possible
/// combination of (row of a, column of b), it performs a set of multiplication and accumulations
/// on all corresponding complex numbers (one from a and one from b). The real part of the a
/// element is multiplied with the real part of the corresponding b element, and the negated
/// imaginary part of the a element is multiplied with the imaginary part of the corresponding b
/// elements. The two accumulated results are added, and then accumulated into the corresponding
/// row and column of dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
}
| 246 | |
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
/// floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbf8ps(DST as i8, A as i8, B as i8);
}
| 265 | |
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbhf8ps(DST as i8, A as i8, B as i8);
}
| 284 | |
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdphbf8ps(DST as i8, A as i8, B as i8);
}
| 303 | |
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
/// floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdphf8ps(DST as i8, A as i8, B as i8);
}
| 322 | |
/// Load tile rows from memory specified by base address and stride into destination tile `DST`
/// using the tile configuration previously configured via [`_tile_loadconfig`].
///
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
///
/// `DST` selects tile register 0 through 7; out-of-range values are rejected at compile time.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrs, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    static_assert_uimm_bits!(DST, 3);
    tileloaddrs64(DST as i8, base, stride);
}
| 340 | |
/// Load tile rows from memory specified by base address and stride into destination tile `DST`
/// using the tile configuration previously configured via [`_tile_loadconfig`].
///
/// Provides a hint to the implementation that the data would be reused but does not need
/// to be resident in the nearest cache levels.
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
///
/// `DST` selects tile register 0 through 7; out-of-range values are rejected at compile time.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrst1, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    static_assert_uimm_bits!(DST, 3);
    tileloaddrst164(DST as i8, base, stride);
}
| 360 | |
/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
/// results into a packed single precision tile.
///
/// For each possible combination of (row of a, column of b), it performs:
///
/// - convert to TF32
/// - multiply the corresponding elements of a and b
/// - accumulate the results into the corresponding row and column of dst using
///   round-to-nearest-even rounding mode.
///
/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
/// handled and *not* treated as zero.
///
/// `DST`, `A` and `B` select tile registers 0 through 7; out-of-range values are rejected at
/// compile time.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-tf32")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tmmultf32ps(DST as i8, A as i8, B as i8);
}
| 385 | |
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
///
/// `TILE` selects tile register 0 through 7 (out-of-range values are rejected at compile time);
/// `row` selects which row of that tile is moved.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowd2ps, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
    static_assert_uimm_bits!(TILE, 3);
    tcvtrowd2ps(TILE as i8, row).as_m512()
}
| 400 | |
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// `TILE` selects tile register 0 through 7 (out-of-range values are rejected at compile time);
/// `row` selects which row of that tile is moved.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phh, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
    static_assert_uimm_bits!(TILE, 3);
    tcvtrowps2phh(TILE as i8, row).as_m512h()
}
| 416 | |
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// `TILE` selects tile register 0 through 7 (out-of-range values are rejected at compile time);
/// `row` selects which row of that tile is moved.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phl, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
    static_assert_uimm_bits!(TILE, 3);
    tcvtrowps2phl(TILE as i8, row).as_m512h()
}
| 432 | |
/// Moves one row of tile data into a zmm vector register, without any element conversion.
///
/// `TILE` selects tile register 0 through 7 (out-of-range values are rejected at compile time);
/// `row` selects which row of that tile is moved.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tilemovrow, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
    static_assert_uimm_bits!(TILE, 3);
    tilemovrow(TILE as i8, row).as_m512i()
}
| 446 | |
| 447 | #[allow (improper_ctypes)] |
| 448 | unsafe extern "C" { |
| 449 | #[link_name = "llvm.x86.ldtilecfg" ] |
| 450 | unsafefn ldtilecfg(mem_addr: *const u8); |
| 451 | #[link_name = "llvm.x86.sttilecfg" ] |
| 452 | unsafefn sttilecfg(mem_addr: *mut u8); |
| 453 | #[link_name = "llvm.x86.tileloadd64" ] |
| 454 | unsafefn tileloadd64(dst: i8, base: *const u8, stride: usize); |
| 455 | #[link_name = "llvm.x86.tileloaddt164" ] |
| 456 | unsafefn tileloaddt164(dst: i8, base: *const u8, stride: usize); |
| 457 | #[link_name = "llvm.x86.tilerelease" ] |
| 458 | unsafefn tilerelease(); |
| 459 | #[link_name = "llvm.x86.tilestored64" ] |
| 460 | unsafefn tilestored64(dst: i8, base: *mut u8, stride: usize); |
| 461 | #[link_name = "llvm.x86.tilezero" ] |
| 462 | unsafefn tilezero(dst: i8); |
| 463 | #[link_name = "llvm.x86.tdpbf16ps" ] |
| 464 | unsafefn tdpbf16ps(dst: i8, a: i8, b: i8); |
| 465 | #[link_name = "llvm.x86.tdpbuud" ] |
| 466 | unsafefn tdpbuud(dst: i8, a: i8, b: i8); |
| 467 | #[link_name = "llvm.x86.tdpbusd" ] |
| 468 | unsafefn tdpbusd(dst: i8, a: i8, b: i8); |
| 469 | #[link_name = "llvm.x86.tdpbsud" ] |
| 470 | unsafefn tdpbsud(dst: i8, a: i8, b: i8); |
| 471 | #[link_name = "llvm.x86.tdpbssd" ] |
| 472 | unsafefn tdpbssd(dst: i8, a: i8, b: i8); |
| 473 | #[link_name = "llvm.x86.tdpfp16ps" ] |
| 474 | unsafefn tdpfp16ps(dst: i8, a: i8, b: i8); |
| 475 | #[link_name = "llvm.x86.tcmmimfp16ps" ] |
| 476 | unsafefn tcmmimfp16ps(dst: i8, a: i8, b: i8); |
| 477 | #[link_name = "llvm.x86.tcmmrlfp16ps" ] |
| 478 | unsafefn tcmmrlfp16ps(dst: i8, a: i8, b: i8); |
| 479 | #[link_name = "llvm.x86.tdpbf8ps" ] |
| 480 | unsafefn tdpbf8ps(dst: i8, a: i8, b: i8); |
| 481 | #[link_name = "llvm.x86.tdpbhf8ps" ] |
| 482 | unsafefn tdpbhf8ps(dst: i8, a: i8, b: i8); |
| 483 | #[link_name = "llvm.x86.tdphbf8ps" ] |
| 484 | unsafefn tdphbf8ps(dst: i8, a: i8, b: i8); |
| 485 | #[link_name = "llvm.x86.tdphf8ps" ] |
| 486 | unsafefn tdphf8ps(dst: i8, a: i8, b: i8); |
| 487 | #[link_name = "llvm.x86.tileloaddrs64" ] |
| 488 | unsafefn tileloaddrs64(dst: i8, base: *const u8, stride: usize); |
| 489 | #[link_name = "llvm.x86.tileloaddrst164" ] |
| 490 | unsafefn tileloaddrst164(dst: i8, base: *const u8, stride: usize); |
| 491 | #[link_name = "llvm.x86.tmmultf32ps" ] |
| 492 | unsafefn tmmultf32ps(dst: i8, a: i8, b: i8); |
| 493 | #[link_name = "llvm.x86.tcvtrowd2ps" ] |
| 494 | unsafefn tcvtrowd2ps(tile: i8, row: u32) -> f32x16; |
| 495 | #[link_name = "llvm.x86.tcvtrowps2phh" ] |
| 496 | unsafefn tcvtrowps2phh(tile: i8, row: u32) -> f16x32; |
| 497 | #[link_name = "llvm.x86.tcvtrowps2phl" ] |
| 498 | unsafefn tcvtrowps2phl(tile: i8, row: u32) -> f16x32; |
| 499 | #[link_name = "llvm.x86.tilemovrow" ] |
| 500 | unsafefn tilemovrow(tile: i8, row: u32) -> i32x16; |
| 501 | } |
| 502 | |
| 503 | #[cfg (test)] |
| 504 | mod tests { |
| 505 | use crate::core_arch::x86::_mm_cvtness_sbh; |
| 506 | use crate::core_arch::x86_64::*; |
| 507 | use core::{array, mem::transmute}; |
| 508 | use stdarch_test::simd_test; |
| 509 | #[cfg (target_os = "linux" )] |
| 510 | use syscalls::{Sysno, syscall}; |
| 511 | |
    /// Test-side mirror of the 64-byte tile configuration read by `_tile_loadconfig` and
    /// written by `_tile_storeconfig`.
    ///
    /// `repr(C, packed)` keeps the fields in declaration order with no padding, so the
    /// struct is exactly 64 bytes (1 + 1 + 14 + 16 + 16 + 8 + 8).
    #[allow(non_camel_case_types)]
    #[repr(C, packed)]
    #[derive(Copy, Clone, Default, Debug, PartialEq)]
    struct __tilecfg {
        /// palette id: 0 (init state) or 1
        palette: u8,
        /// start row; per the Intel SDM this is the restart point for interrupted tile
        /// loads/stores — the tests always set it to 0
        start_row: u8,
        /// reserved, must be zero
        reserved_a0: [u8; 14],
        /// number of bytes of one row in each tile
        colsb: [u16; 8],
        /// reserved, must be zero
        reserved_b0: [u16; 8],
        /// number of rows in each tile
        rows: [u8; 8],
        /// reserved, must be zero
        reserved_c0: [u8; 8],
    }
| 530 | |
| 531 | impl __tilecfg { |
| 532 | fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self { |
| 533 | Self { |
| 534 | palette, |
| 535 | start_row, |
| 536 | reserved_a0: [0u8; 14], |
| 537 | colsb, |
| 538 | reserved_b0: [0u16; 8], |
| 539 | rows, |
| 540 | reserved_c0: [0u8; 8], |
| 541 | } |
| 542 | } |
| 543 | |
| 544 | const fn as_ptr(&self) -> *const u8 { |
| 545 | self as *const Self as *const u8 |
| 546 | } |
| 547 | |
| 548 | fn as_mut_ptr(&mut self) -> *mut u8 { |
| 549 | self as *mut Self as *mut u8 |
| 550 | } |
| 551 | } |
| 552 | |
    /// Non-Linux stand-in for the AMX setup helper: does nothing, so tests can call
    /// `_init_amx()` unconditionally on every target.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}
| 556 | |
| 557 | #[cfg (target_os = "linux" )] |
| 558 | #[target_feature (enable = "amx-tile" )] |
| 559 | #[inline ] |
| 560 | unsafe fn _init_amx() { |
| 561 | let mut ret: usize; |
| 562 | let mut xfeatures: usize = 0; |
| 563 | ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize) |
| 564 | .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed" ); |
| 565 | if ret != 0 { |
| 566 | panic!("Failed to get XFEATURES" ); |
| 567 | } else { |
| 568 | match 0b11 & (xfeatures >> 17) { |
| 569 | 0 => panic!("AMX is not available" ), |
| 570 | 1 => { |
| 571 | ret = syscall!(Sysno::arch_prctl, 0x1023, 18) |
| 572 | .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed" ); |
| 573 | if ret != 0 { |
| 574 | panic!("Failed to enable AMX" ); |
| 575 | } |
| 576 | } |
| 577 | 3 => {} |
| 578 | _ => unreachable!(), |
| 579 | } |
| 580 | } |
| 581 | } |
| 582 | |
| 583 | #[simd_test(enable = "amx-tile" )] |
| 584 | fn test_tile_loadconfig() { |
| 585 | unsafe { |
| 586 | let config = __tilecfg::default(); |
| 587 | _tile_loadconfig(config.as_ptr()); |
| 588 | _tile_release(); |
| 589 | } |
| 590 | } |
| 591 | |
| 592 | #[simd_test(enable = "amx-tile" )] |
| 593 | fn test_tile_storeconfig() { |
| 594 | unsafe { |
| 595 | let config = __tilecfg::new(1, 0, [32; 8], [8; 8]); |
| 596 | _tile_loadconfig(config.as_ptr()); |
| 597 | let mut _config = __tilecfg::default(); |
| 598 | _tile_storeconfig(_config.as_mut_ptr()); |
| 599 | _tile_release(); |
| 600 | assert_eq!(config, _config); |
| 601 | } |
| 602 | } |
| 603 | |
| 604 | #[simd_test(enable = "amx-tile" )] |
| 605 | fn test_tile_zero() { |
| 606 | unsafe { |
| 607 | _init_amx(); |
| 608 | let mut config = __tilecfg::default(); |
| 609 | config.palette = 1; |
| 610 | config.colsb[0] = 64; |
| 611 | config.rows[0] = 16; |
| 612 | _tile_loadconfig(config.as_ptr()); |
| 613 | _tile_zero::<0>(); |
| 614 | let mut out = [[1_i8; 64]; 16]; |
| 615 | _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); |
| 616 | _tile_release(); |
| 617 | assert_eq!(out, [[0; 64]; 16]); |
| 618 | } |
| 619 | } |
| 620 | |
| 621 | #[simd_test(enable = "amx-tile" )] |
| 622 | fn test_tile_stored() { |
| 623 | unsafe { |
| 624 | _init_amx(); |
| 625 | let mut config = __tilecfg::default(); |
| 626 | config.palette = 1; |
| 627 | config.colsb[0] = 64; |
| 628 | config.rows[0] = 16; |
| 629 | _tile_loadconfig(config.as_ptr()); |
| 630 | _tile_zero::<0>(); |
| 631 | let mut out = [[1_i8; 64]; 16]; |
| 632 | _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); |
| 633 | _tile_release(); |
| 634 | assert_eq!(out, [[0; 64]; 16]); |
| 635 | } |
| 636 | } |
| 637 | |
| 638 | #[simd_test(enable = "amx-tile" )] |
| 639 | fn test_tile_loadd() { |
| 640 | unsafe { |
| 641 | _init_amx(); |
| 642 | let mut config = __tilecfg::default(); |
| 643 | config.palette = 1; |
| 644 | config.colsb[0] = 64; |
| 645 | config.rows[0] = 16; |
| 646 | _tile_loadconfig(config.as_ptr()); |
| 647 | _tile_zero::<0>(); |
| 648 | let mat = [1_i8; 1024]; |
| 649 | _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); |
| 650 | let mut out = [[0_i8; 64]; 16]; |
| 651 | _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); |
| 652 | _tile_release(); |
| 653 | assert_eq!(out, [[1; 64]; 16]); |
| 654 | } |
| 655 | } |
| 656 | |
| 657 | #[simd_test(enable = "amx-tile" )] |
| 658 | fn test_tile_stream_loadd() { |
| 659 | unsafe { |
| 660 | _init_amx(); |
| 661 | let mut config = __tilecfg::default(); |
| 662 | config.palette = 1; |
| 663 | config.colsb[0] = 64; |
| 664 | config.rows[0] = 16; |
| 665 | _tile_loadconfig(config.as_ptr()); |
| 666 | _tile_zero::<0>(); |
| 667 | let mat = [1_i8; 1024]; |
| 668 | _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); |
| 669 | let mut out = [[0_i8; 64]; 16]; |
| 670 | _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); |
| 671 | _tile_release(); |
| 672 | assert_eq!(out, [[1; 64]; 16]); |
| 673 | } |
| 674 | } |
| 675 | |
| 676 | #[simd_test(enable = "amx-tile" )] |
| 677 | fn test_tile_release() { |
| 678 | unsafe { |
| 679 | _tile_release(); |
| 680 | } |
| 681 | } |
| 682 | |
| 683 | #[simd_test(enable = "amx-bf16,avx512f" )] |
| 684 | fn test_tile_dpbf16ps() { |
| 685 | unsafe { |
| 686 | _init_amx(); |
| 687 | let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); |
| 688 | let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); |
| 689 | let ones: [u8; 1024] = transmute([bf16_1; 512]); |
| 690 | let twos: [u8; 1024] = transmute([bf16_2; 512]); |
| 691 | let mut res = [[0f32; 16]; 16]; |
| 692 | let mut config = __tilecfg::default(); |
| 693 | config.palette = 1; |
| 694 | (0..=2).for_each(|i| { |
| 695 | config.colsb[i] = 64; |
| 696 | config.rows[i] = 16; |
| 697 | }); |
| 698 | _tile_loadconfig(config.as_ptr()); |
| 699 | _tile_zero::<0>(); |
| 700 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 701 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 702 | _tile_dpbf16ps::<0, 1, 2>(); |
| 703 | _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); |
| 704 | _tile_release(); |
| 705 | assert_eq!(res, [[64f32; 16]; 16]); |
| 706 | } |
| 707 | } |
| 708 | |
| 709 | #[simd_test(enable = "amx-int8" )] |
| 710 | fn test_tile_dpbssd() { |
| 711 | unsafe { |
| 712 | _init_amx(); |
| 713 | let ones = [-1_i8; 1024]; |
| 714 | let twos = [-2_i8; 1024]; |
| 715 | let mut res = [[0_i32; 16]; 16]; |
| 716 | let mut config = __tilecfg::default(); |
| 717 | config.palette = 1; |
| 718 | (0..=2).for_each(|i| { |
| 719 | config.colsb[i] = 64; |
| 720 | config.rows[i] = 16; |
| 721 | }); |
| 722 | _tile_loadconfig(config.as_ptr()); |
| 723 | _tile_zero::<0>(); |
| 724 | _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); |
| 725 | _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); |
| 726 | _tile_dpbssd::<0, 1, 2>(); |
| 727 | _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); |
| 728 | _tile_release(); |
| 729 | assert_eq!(res, [[128_i32; 16]; 16]); |
| 730 | } |
| 731 | } |
| 732 | |
| 733 | #[simd_test(enable = "amx-int8" )] |
| 734 | fn test_tile_dpbsud() { |
| 735 | unsafe { |
| 736 | _init_amx(); |
| 737 | let ones = [-1_i8; 1024]; |
| 738 | let twos = [2_u8; 1024]; |
| 739 | let mut res = [[0_i32; 16]; 16]; |
| 740 | let mut config = __tilecfg::default(); |
| 741 | config.palette = 1; |
| 742 | (0..=2).for_each(|i| { |
| 743 | config.colsb[i] = 64; |
| 744 | config.rows[i] = 16; |
| 745 | }); |
| 746 | _tile_loadconfig(config.as_ptr()); |
| 747 | _tile_zero::<0>(); |
| 748 | _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); |
| 749 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 750 | _tile_dpbsud::<0, 1, 2>(); |
| 751 | _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); |
| 752 | _tile_release(); |
| 753 | assert_eq!(res, [[-128_i32; 16]; 16]); |
| 754 | } |
| 755 | } |
| 756 | |
| 757 | #[simd_test(enable = "amx-int8" )] |
| 758 | fn test_tile_dpbusd() { |
| 759 | unsafe { |
| 760 | _init_amx(); |
| 761 | let ones = [1_u8; 1024]; |
| 762 | let twos = [-2_i8; 1024]; |
| 763 | let mut res = [[0_i32; 16]; 16]; |
| 764 | let mut config = __tilecfg::default(); |
| 765 | config.palette = 1; |
| 766 | (0..=2).for_each(|i| { |
| 767 | config.colsb[i] = 64; |
| 768 | config.rows[i] = 16; |
| 769 | }); |
| 770 | _tile_loadconfig(config.as_ptr()); |
| 771 | _tile_zero::<0>(); |
| 772 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 773 | _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); |
| 774 | _tile_dpbusd::<0, 1, 2>(); |
| 775 | _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); |
| 776 | _tile_release(); |
| 777 | assert_eq!(res, [[-128_i32; 16]; 16]); |
| 778 | } |
| 779 | } |
| 780 | |
| 781 | #[simd_test(enable = "amx-int8" )] |
| 782 | fn test_tile_dpbuud() { |
| 783 | unsafe { |
| 784 | _init_amx(); |
| 785 | let ones = [1_u8; 1024]; |
| 786 | let twos = [2_u8; 1024]; |
| 787 | let mut res = [[0_i32; 16]; 16]; |
| 788 | let mut config = __tilecfg::default(); |
| 789 | config.palette = 1; |
| 790 | (0..=2).for_each(|i| { |
| 791 | config.colsb[i] = 64; |
| 792 | config.rows[i] = 16; |
| 793 | }); |
| 794 | _tile_loadconfig(config.as_ptr()); |
| 795 | _tile_zero::<0>(); |
| 796 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 797 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 798 | _tile_dpbuud::<0, 1, 2>(); |
| 799 | _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); |
| 800 | _tile_release(); |
| 801 | assert_eq!(res, [[128_i32; 16]; 16]); |
| 802 | } |
| 803 | } |
| 804 | |
| 805 | #[simd_test(enable = "amx-fp16" )] |
| 806 | fn test_tile_dpfp16ps() { |
| 807 | unsafe { |
| 808 | _init_amx(); |
| 809 | let ones = [1f16; 512]; |
| 810 | let twos = [2f16; 512]; |
| 811 | let mut res = [[0f32; 16]; 16]; |
| 812 | let mut config = __tilecfg::default(); |
| 813 | config.palette = 1; |
| 814 | (0..=2).for_each(|i| { |
| 815 | config.colsb[i] = 64; |
| 816 | config.rows[i] = 16; |
| 817 | }); |
| 818 | _tile_loadconfig(config.as_ptr()); |
| 819 | _tile_zero::<0>(); |
| 820 | _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); |
| 821 | _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); |
| 822 | _tile_dpfp16ps::<0, 1, 2>(); |
| 823 | _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); |
| 824 | _tile_release(); |
| 825 | assert_eq!(res, [[64f32; 16]; 16]); |
| 826 | } |
| 827 | } |
| 828 | |
| 829 | #[simd_test(enable = "amx-complex" )] |
| 830 | fn test_tile_cmmimfp16ps() { |
| 831 | unsafe { |
| 832 | _init_amx(); |
| 833 | let ones = [1f16; 512]; |
| 834 | let twos = [2f16; 512]; |
| 835 | let mut res = [[0f32; 16]; 16]; |
| 836 | let mut config = __tilecfg::default(); |
| 837 | config.palette = 1; |
| 838 | (0..=2).for_each(|i| { |
| 839 | config.colsb[i] = 64; |
| 840 | config.rows[i] = 16; |
| 841 | }); |
| 842 | _tile_loadconfig(config.as_ptr()); |
| 843 | _tile_zero::<0>(); |
| 844 | _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); |
| 845 | _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); |
| 846 | _tile_cmmimfp16ps::<0, 1, 2>(); |
| 847 | _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); |
| 848 | _tile_release(); |
| 849 | assert_eq!(res, [[64f32; 16]; 16]); |
| 850 | } |
| 851 | } |
| 852 | |
| 853 | #[simd_test(enable = "amx-complex" )] |
| 854 | fn test_tile_cmmrlfp16ps() { |
| 855 | unsafe { |
| 856 | _init_amx(); |
| 857 | let ones = [1f16; 512]; |
| 858 | let twos = [2f16; 512]; |
| 859 | let mut res = [[0f32; 16]; 16]; |
| 860 | let mut config = __tilecfg::default(); |
| 861 | config.palette = 1; |
| 862 | (0..=2).for_each(|i| { |
| 863 | config.colsb[i] = 64; |
| 864 | config.rows[i] = 16; |
| 865 | }); |
| 866 | _tile_loadconfig(config.as_ptr()); |
| 867 | _tile_zero::<0>(); |
| 868 | _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); |
| 869 | _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); |
| 870 | _tile_cmmrlfp16ps::<0, 1, 2>(); |
| 871 | _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); |
| 872 | _tile_release(); |
| 873 | assert_eq!(res, [[0f32; 16]; 16]); |
| 874 | } |
| 875 | } |
| 876 | |
// FP8 encodings of 1.0 and 2.0 used by the AMX-FP8 tests. The binary literals
// are grouped as sign | exponent | mantissa.
//
// BF8 (E5M2, exponent bias 15): 1.0 = 0x3c, 2.0 = 0x40.
const BF8_ONE: u8 = 0b0_01111_00;
const BF8_TWO: u8 = 0b0_10000_00;
// HF8 (E4M3, exponent bias 7): 1.0 = 0x38, 2.0 = 0x40.
const HF8_ONE: u8 = 0b0_0111_000;
const HF8_TWO: u8 = 0b0_1000_000;
| 881 | |
| 882 | #[simd_test(enable = "amx-fp8" )] |
| 883 | fn test_tile_dpbf8ps() { |
| 884 | unsafe { |
| 885 | _init_amx(); |
| 886 | let ones = [BF8_ONE; 1024]; |
| 887 | let twos = [BF8_TWO; 1024]; |
| 888 | let mut res = [[0.0_f32; 16]; 16]; |
| 889 | let mut config = __tilecfg::default(); |
| 890 | config.palette = 1; |
| 891 | (0..=2).for_each(|i| { |
| 892 | config.colsb[i] = 64; |
| 893 | config.rows[i] = 16; |
| 894 | }); |
| 895 | _tile_loadconfig(config.as_ptr()); |
| 896 | _tile_zero::<0>(); |
| 897 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 898 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 899 | _tile_dpbf8ps::<0, 1, 2>(); |
| 900 | _tile_stored::<0>(res.as_mut_ptr().cast(), 64); |
| 901 | _tile_release(); |
| 902 | assert_eq!(res, [[128.0_f32; 16]; 16]); |
| 903 | } |
| 904 | } |
| 905 | |
| 906 | #[simd_test(enable = "amx-fp8" )] |
| 907 | fn test_tile_dpbhf8ps() { |
| 908 | unsafe { |
| 909 | _init_amx(); |
| 910 | let ones = [BF8_ONE; 1024]; |
| 911 | let twos = [HF8_TWO; 1024]; |
| 912 | let mut res = [[0.0_f32; 16]; 16]; |
| 913 | let mut config = __tilecfg::default(); |
| 914 | config.palette = 1; |
| 915 | (0..=2).for_each(|i| { |
| 916 | config.colsb[i] = 64; |
| 917 | config.rows[i] = 16; |
| 918 | }); |
| 919 | _tile_loadconfig(config.as_ptr()); |
| 920 | _tile_zero::<0>(); |
| 921 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 922 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 923 | _tile_dpbhf8ps::<0, 1, 2>(); |
| 924 | _tile_stored::<0>(res.as_mut_ptr().cast(), 64); |
| 925 | _tile_release(); |
| 926 | assert_eq!(res, [[128.0_f32; 16]; 16]); |
| 927 | } |
| 928 | } |
| 929 | |
| 930 | #[simd_test(enable = "amx-fp8" )] |
| 931 | fn test_tile_dphbf8ps() { |
| 932 | unsafe { |
| 933 | _init_amx(); |
| 934 | let ones = [HF8_ONE; 1024]; |
| 935 | let twos = [BF8_TWO; 1024]; |
| 936 | let mut res = [[0.0_f32; 16]; 16]; |
| 937 | let mut config = __tilecfg::default(); |
| 938 | config.palette = 1; |
| 939 | (0..=2).for_each(|i| { |
| 940 | config.colsb[i] = 64; |
| 941 | config.rows[i] = 16; |
| 942 | }); |
| 943 | _tile_loadconfig(config.as_ptr()); |
| 944 | _tile_zero::<0>(); |
| 945 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 946 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 947 | _tile_dphbf8ps::<0, 1, 2>(); |
| 948 | _tile_stored::<0>(res.as_mut_ptr().cast(), 64); |
| 949 | _tile_release(); |
| 950 | assert_eq!(res, [[128.0_f32; 16]; 16]); |
| 951 | } |
| 952 | } |
| 953 | |
| 954 | #[simd_test(enable = "amx-fp8" )] |
| 955 | fn test_tile_dphf8ps() { |
| 956 | unsafe { |
| 957 | _init_amx(); |
| 958 | let ones = [HF8_ONE; 1024]; |
| 959 | let twos = [HF8_TWO; 1024]; |
| 960 | let mut res = [[0.0_f32; 16]; 16]; |
| 961 | let mut config = __tilecfg::default(); |
| 962 | config.palette = 1; |
| 963 | (0..=2).for_each(|i| { |
| 964 | config.colsb[i] = 64; |
| 965 | config.rows[i] = 16; |
| 966 | }); |
| 967 | _tile_loadconfig(config.as_ptr()); |
| 968 | _tile_zero::<0>(); |
| 969 | _tile_loadd::<1>(&ones as *const u8, 64); |
| 970 | _tile_loadd::<2>(&twos as *const u8, 64); |
| 971 | _tile_dphf8ps::<0, 1, 2>(); |
| 972 | _tile_stored::<0>(res.as_mut_ptr().cast(), 64); |
| 973 | _tile_release(); |
| 974 | assert_eq!(res, [[128.0_f32; 16]; 16]); |
| 975 | } |
| 976 | } |
| 977 | |
| 978 | #[simd_test(enable = "amx-movrs" )] |
| 979 | fn test_tile_loaddrs() { |
| 980 | unsafe { |
| 981 | _init_amx(); |
| 982 | let mut config = __tilecfg::default(); |
| 983 | config.palette = 1; |
| 984 | config.colsb[0] = 64; |
| 985 | config.rows[0] = 16; |
| 986 | _tile_loadconfig(config.as_ptr()); |
| 987 | _tile_zero::<0>(); |
| 988 | let mat = [1_i8; 1024]; |
| 989 | _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64); |
| 990 | let mut out = [[0_i8; 64]; 16]; |
| 991 | _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); |
| 992 | _tile_release(); |
| 993 | assert_eq!(out, [[1; 64]; 16]); |
| 994 | } |
| 995 | } |
| 996 | |
| 997 | #[simd_test(enable = "amx-movrs" )] |
| 998 | fn test_tile_stream_loaddrs() { |
| 999 | unsafe { |
| 1000 | _init_amx(); |
| 1001 | let mut config = __tilecfg::default(); |
| 1002 | config.palette = 1; |
| 1003 | config.colsb[0] = 64; |
| 1004 | config.rows[0] = 16; |
| 1005 | _tile_loadconfig(config.as_ptr()); |
| 1006 | _tile_zero::<0>(); |
| 1007 | let mat = [1_i8; 1024]; |
| 1008 | _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64); |
| 1009 | let mut out = [[0_i8; 64]; 16]; |
| 1010 | _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); |
| 1011 | _tile_release(); |
| 1012 | assert_eq!(out, [[1; 64]; 16]); |
| 1013 | } |
| 1014 | } |
| 1015 | |
| 1016 | #[simd_test(enable = "amx-avx512,avx10.2" )] |
| 1017 | fn test_tile_movrow() { |
| 1018 | unsafe { |
| 1019 | _init_amx(); |
| 1020 | let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]); |
| 1021 | |
| 1022 | let mut config = __tilecfg::default(); |
| 1023 | config.palette = 1; |
| 1024 | config.colsb[0] = 64; |
| 1025 | config.rows[0] = 16; |
| 1026 | _tile_loadconfig(config.as_ptr()); |
| 1027 | _tile_loadd::<0>(array.as_ptr().cast(), 64); |
| 1028 | for i in 0..16 { |
| 1029 | let row = _tile_movrow::<0>(i); |
| 1030 | assert_eq!(*row.as_u8x64().as_array(), [i as _; _]); |
| 1031 | } |
| 1032 | } |
| 1033 | } |
| 1034 | |
| 1035 | #[simd_test(enable = "amx-avx512,avx10.2" )] |
| 1036 | fn test_tile_cvtrowd2ps() { |
| 1037 | unsafe { |
| 1038 | _init_amx(); |
| 1039 | let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]); |
| 1040 | |
| 1041 | let mut config = __tilecfg::default(); |
| 1042 | config.palette = 1; |
| 1043 | config.colsb[0] = 64; |
| 1044 | config.rows[0] = 16; |
| 1045 | _tile_loadconfig(config.as_ptr()); |
| 1046 | _tile_loadd::<0>(array.as_ptr().cast(), 64); |
| 1047 | for i in 0..16 { |
| 1048 | let row = _tile_cvtrowd2ps::<0>(i); |
| 1049 | assert_eq!(*row.as_f32x16().as_array(), [i as _; _]); |
| 1050 | } |
| 1051 | } |
| 1052 | } |
| 1053 | |
| 1054 | #[simd_test(enable = "amx-avx512,avx10.2" )] |
| 1055 | fn test_tile_cvtrowps2phh() { |
| 1056 | unsafe { |
| 1057 | _init_amx(); |
| 1058 | let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); |
| 1059 | |
| 1060 | let mut config = __tilecfg::default(); |
| 1061 | config.palette = 1; |
| 1062 | config.colsb[0] = 64; |
| 1063 | config.rows[0] = 16; |
| 1064 | _tile_loadconfig(config.as_ptr()); |
| 1065 | _tile_loadd::<0>(array.as_ptr().cast(), 64); |
| 1066 | for i in 0..16 { |
| 1067 | let row = _tile_cvtrowps2phh::<0>(i); |
| 1068 | assert_eq!( |
| 1069 | *row.as_f16x32().as_array(), |
| 1070 | array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ }) |
| 1071 | ); |
| 1072 | } |
| 1073 | } |
| 1074 | } |
| 1075 | |
| 1076 | #[simd_test(enable = "amx-avx512,avx10.2" )] |
| 1077 | fn test_tile_cvtrowps2phl() { |
| 1078 | unsafe { |
| 1079 | _init_amx(); |
| 1080 | let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); |
| 1081 | |
| 1082 | let mut config = __tilecfg::default(); |
| 1083 | config.palette = 1; |
| 1084 | config.colsb[0] = 64; |
| 1085 | config.rows[0] = 16; |
| 1086 | _tile_loadconfig(config.as_ptr()); |
| 1087 | _tile_loadd::<0>(array.as_ptr().cast(), 64); |
| 1088 | for i in 0..16 { |
| 1089 | let row = _tile_cvtrowps2phl::<0>(i); |
| 1090 | assert_eq!( |
| 1091 | *row.as_f16x32().as_array(), |
| 1092 | array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 }) |
| 1093 | ); |
| 1094 | } |
| 1095 | } |
| 1096 | } |
| 1097 | |
| 1098 | #[simd_test(enable = "amx-tf32" )] |
| 1099 | fn test_tile_mmultf32ps() { |
| 1100 | unsafe { |
| 1101 | _init_amx(); |
| 1102 | let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); |
| 1103 | let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _]; |
| 1104 | let mut res = [[0.0; 16]; 16]; |
| 1105 | |
| 1106 | let mut config = __tilecfg::default(); |
| 1107 | config.palette = 1; |
| 1108 | (0..=2).for_each(|i| { |
| 1109 | config.colsb[i] = 64; |
| 1110 | config.rows[i] = 16; |
| 1111 | }); |
| 1112 | _tile_loadconfig(config.as_ptr()); |
| 1113 | _tile_zero::<0>(); |
| 1114 | _tile_loadd::<1>(a.as_ptr().cast(), 64); |
| 1115 | _tile_loadd::<2>(b.as_ptr().cast(), 64); |
| 1116 | _tile_mmultf32ps::<0, 1, 2>(); |
| 1117 | _tile_stored::<0>(res.as_mut_ptr().cast(), 64); |
| 1118 | _tile_release(); |
| 1119 | |
| 1120 | let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32)); |
| 1121 | assert_eq!(res, expected); |
| 1122 | } |
| 1123 | } |
| 1124 | } |
| 1125 | |