#![allow(dead_code, unused_imports)]

macro_rules! convert_fn {
    (fn $name:ident($var:ident : $vartype:ty) -> $restype:ty {
        if feature("f16c") { $f16c:expr }
        else { $fallback:expr }}) => {
        #[inline]
        pub(crate) fn $name($var: $vartype) -> $restype {
            // Use runtime CPU feature detection if std is available
            #[cfg(all(
                feature = "use-intrinsics",
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64"),
                not(target_feature = "f16c")
            ))]
            {
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            // Use intrinsics directly when f16c is enabled as a compile-time target feature
            // (this also covers no_std builds)
            #[cfg(all(
                feature = "use-intrinsics",
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))]
            {
                $f16c
            }
            // Fall back to the software implementation
            #[cfg(any(
                not(feature = "use-intrinsics"),
                not(any(target_arch = "x86", target_arch = "x86_64")),
                all(not(feature = "std"), not(target_feature = "f16c"))
            ))]
            {
                $fallback
            }
        }
    };
}
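
// As a rough illustration, when std-based runtime detection is in effect the `f32_to_f16`
// invocation below expands to approximately:
//
//     #[inline]
//     pub(crate) fn f32_to_f16(f: f32) -> u16 {
//         if is_x86_feature_detected!("f16c") {
//             unsafe { x86::f32_to_f16_x86_f16c(f) }
//         } else {
//             f32_to_f16_fallback(f)
//         }
//     }
//
// With f16c as a compile-time target feature only the intrinsic call remains, and in the
// pure-software configuration only the fallback call remains.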

convert_fn! {
    fn f32_to_f16(f: f32) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f64_to_f16(f: f64) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f16_to_f32(i: u16) -> f32 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

convert_fn! {
    fn f16_to_f64(i: u16) -> f64 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}

// TODO: While the SIMD versions are faster, a further improvement would be to do runtime
// feature detection once at the start of each slice-conversion method rather than per chunk.
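//
// One possible shape for that change (a sketch only; `src` is a hypothetical `&[f32]` input
// and the slice-conversion methods live elsewhere):
//
//     let use_f16c = is_x86_feature_detected!("f16c");
//     for chunk in src.chunks_exact(4) {
//         let converted = if use_f16c {
//             unsafe { x86::f32x4_to_f16x4_x86_f16c(chunk) }
//         } else {
//             f32x4_to_f16x4_fallback(chunk)
//         };
//         // ...copy `converted` into the destination slice...
//     }
//
// so detection happens once per call instead of once per 4-element chunk.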

convert_fn! {
    fn f32x4_to_f16x4(f: &[f32]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f32x4(i: &[u16]) -> [f32; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

convert_fn! {
    fn f64x4_to_f16x4(f: &[f64]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f64x4(i: &[u16]) -> [f64; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}

/////////////// Fallbacks ////////////////

// The functions below round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) The removed part is greater than a tie:
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) The removed part is exactly a tie and the retained part is odd:
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If the removed part is a tie and the retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
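//
// A quick worked check with round_bit = 0x1000 (the normal f32 -> f16 path below), where
// 3 * round_bit - 1 == 0x2FFF:
//     mantissa = ..1800: round bit set, lower bits nonzero            -> above tie -> round up
//     mantissa = ..1000: round bit set, lower bits zero, 0x2000 clear -> tie, even -> keep
//     mantissa = ..3000: round bit set, lower bits zero, 0x2000 set   -> tie, odd  -> round up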

fn f32_to_f16_fallback(value: f32) -> u16 {
    // Convert to raw bits
    let x = value.to_bits();

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normal or will underflow below; start assembling the half-precision value
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget the hidden leading mantissa bit when assembling the mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
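
// For example, 1.0f32 has bits 0x3F80_0000: sign = 0, biased exponent 127 (unbiased 0) and a
// zero mantissa, so half_exp = 0 + 15 = 15 and the result is (15 << 10) | 0 = 0x3C00, the
// half-precision 1.0. Values of 2^16 and above (e.g. 65536.0) hit the overflow branch and
// return 0x7C00, positive infinity.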

fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bits, truncating the last 32 bits of the mantissa; that precision will
    // always be lost on half-precision.
    let val = value.to_bits();
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normal or will underflow below; start assembling the half-precision value
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget the hidden leading mantissa bit when assembling the mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
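
// For example, 1.0f64 has bits 0x3FF0_0000_0000_0000; its upper 32 bits are 0x3FF0_0000, so
// the biased exponent is 1023 (unbiased 0) and the truncated mantissa is 0, giving the same
// 0x3C00 result as the f32 path.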

fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    if i & 0x7FFFu16 == 0 {
        return f32::from_bits((i as u32) << 16);
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits are set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return f32::from_bits((half_sign << 16) | 0x7F80_0000u32);
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return f32::from_bits((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13));
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting the exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = (half_man as u16).leading_zeros() - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return f32::from_bits(sign | exp | man);
    }

    // Rebias exponent for a normal (non-subnormal) value
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    f32::from_bits(sign | exp | man)
}
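
// For example, 0x3C00 (half 1.0) has half_exp = 15 and half_man = 0, so the rebiased exponent
// is (0 + 127) << 23 = 0x3F80_0000, i.e. 1.0f32. The smallest half subnormal, 0x0001, takes
// the subnormal branch: e = 15 - 6 = 9, exp = (127 - 15 - 9) << 23 and man = 0, which is 2^-24.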

fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    if i & 0x7FFFu16 == 0 {
        return f64::from_bits((i as u64) << 48);
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits are set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return f64::from_bits((half_sign << 48) | 0x7FF0_0000_0000_0000u64);
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return f64::from_bits((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42));
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting the exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = (half_man as u16).leading_zeros() - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return f64::from_bits(sign | exp | man);
    }

    // Rebias exponent for a normal (non-subnormal) value
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    f64::from_bits(sign | exp | man)
}
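
// The same examples carry over: 0x3C00 maps to 1.0f64 with exponent (0 + 1023) << 52 and a
// zero mantissa, and the subnormal 0x0001 again comes out as 2^-24, just with the f64 bias of
// 1023 instead of 127.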

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16]) -> [f32; 4] {
    debug_assert!(v.len() >= 4);

    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32]) -> [u16; 4] {
    debug_assert!(v.len() >= 4);

    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16]) -> [f64; 4] {
    debug_assert!(v.len() >= 4);

    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64]) -> [u16; 4] {
    debug_assert!(v.len() >= 4);

    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

/////////////// x86/x86_64 f16c ////////////////
#[cfg(all(
    feature = "use-intrinsics",
    any(target_arch = "x86", target_arch = "x86_64")
))]
mod x86 {
    use core::{mem::MaybeUninit, ptr};

    #[cfg(target_arch = "x86")]
    use core::arch::x86::{__m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{
        __m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT,
    };

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
        // Place the half value in the low lane of a zeroed vector, convert, and read back the
        // low f32 lane.
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        vec.as_mut_ptr().cast::<u16>().write(i);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
        // Place the f32 value in the low lane, convert with round-to-nearest, and read back
        // the low u16 lane.
        let mut vec = MaybeUninit::<__m128>::zeroed();
        vec.as_mut_ptr().cast::<f32>().write(f);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16]) -> [f32; 4] {
        debug_assert!(v.len() >= 4);

        // Copy the four half values into the low 64 bits of a vector and convert all at once.
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32]) -> [u16; 4] {
        debug_assert!(v.len() >= 4);

        // Copy the four f32 values into a vector and convert them all with round-to-nearest.
        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16]) -> [f64; 4] {
        debug_assert!(v.len() >= 4);

        // Convert half -> f32 with the intrinsic, then widen to f64 with regular casts.
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        let array = *(&retval as *const __m128).cast::<[f32; 4]>();
        // Let the compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        [
            array[0] as f64,
            array[1] as f64,
            array[2] as f64,
            array[3] as f64,
        ]
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64]) -> [u16; 4] {
        debug_assert!(v.len() >= 4);

        // Narrow f64 -> f32 with regular casts, then convert f32 -> half with the intrinsic.
        // Let the compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];

        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }
}
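
// Illustrative sanity checks for the software fallbacks: a minimal sketch that exercises the
// values from the worked examples above.
#[cfg(test)]
mod fallback_sanity_tests {
    use super::*;

    #[test]
    fn scalar_conversions_match_worked_examples() {
        // 1.0 is 0x3C00 in half precision, in every direction.
        assert_eq!(f32_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f64_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f16_to_f32_fallback(0x3C00), 1.0f32);
        assert_eq!(f16_to_f64_fallback(0x3C00), 1.0f64);

        // The smallest half subnormal, 0x0001, is 2^-24 (f32 exponent field 103, mantissa 0).
        assert_eq!(f16_to_f32_fallback(0x0001).to_bits(), 103u32 << 23);

        // 2^16 overflows the half exponent range and becomes positive infinity.
        assert_eq!(f32_to_f16_fallback(65536.0), 0x7C00);
    }

    #[test]
    fn slice_fallbacks_round_trip_exact_values() {
        // All four values are exactly representable in half precision.
        let src = [0.0f32, 1.0, -1.0, 0.5];
        let half = f32x4_to_f16x4_fallback(&src);
        let back = f16x4_to_f32x4_fallback(&half);
        assert_eq!(back, src);
    }
}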