#![allow(dead_code, unused_imports)]

macro_rules! convert_fn {
    (fn $name:ident($var:ident : $vartype:ty) -> $restype:ty {
        if feature("f16c") { $f16c:expr }
        else { $fallback:expr }}) => {
        #[inline]
        pub(crate) fn $name($var: $vartype) -> $restype {
            // Use runtime CPU feature detection when std is available
            #[cfg(all(
                feature = "use-intrinsics",
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64"),
                not(target_feature = "f16c")
            ))]
            {
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            // Use intrinsics directly when f16c is enabled as a compile-time target
            // feature (this also covers no_std builds)
            #[cfg(all(
                feature = "use-intrinsics",
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))]
            {
                $f16c
            }
            // Fall back to the software implementation
            #[cfg(any(
                not(feature = "use-intrinsics"),
                not(any(target_arch = "x86", target_arch = "x86_64")),
                all(not(feature = "std"), not(target_feature = "f16c"))
            ))]
            {
                $fallback
            }
        }
    };
}

convert_fn! {
    fn f32_to_f16(f: f32) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f64_to_f16(f: f64) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f16_to_f32(i: u16) -> f32 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

convert_fn! {
    fn f16_to_f64(i: u16) -> f64 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}
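
// Illustrative sketch (not part of the original source): values that are exactly
// representable in half precision should survive an f32 -> f16 -> f32 round trip
// through the dispatching functions above, whichever path the macro selects.
#[cfg(test)]
mod scalar_convert_example {
    use super::*;

    #[test]
    fn exact_values_roundtrip() {
        for &f in &[0.0f32, 1.0, -2.5, 0.25, 65504.0] {
            let bits = f32_to_f16(f);
            assert_eq!(f16_to_f32(bits), f);
        }
    }
}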

// TODO: While the SIMD versions are faster, further improvements could be made by doing runtime
// feature detection once at the beginning of each slice conversion method, rather than per chunk

convert_fn! {
    fn f32x4_to_f16x4(f: &[f32]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f32x4(i: &[u16]) -> [f32; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

convert_fn! {
    fn f64x4_to_f16x4(f: &[f64]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f64x4(i: &[u16]) -> [f64; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}
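
// Illustrative sketch (not part of the original source): the 4-wide converters
// above should agree elementwise with the scalar converters, regardless of which
// implementation the macro dispatches to. Exactly representable inputs avoid any
// rounding ambiguity.
#[cfg(test)]
mod slice_convert_example {
    use super::*;

    #[test]
    fn f32x4_matches_scalar() {
        let input = [0.5f32, -1.25, 3.5, 65504.0];
        let packed = f32x4_to_f16x4(&input);
        for (f, bits) in input.iter().zip(packed.iter()) {
            assert_eq!(f32_to_f16(*f), *bits);
        }
    }
}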

/////////////// Fallbacks ////////////////

// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd.
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0

fn f32_to_f16_fallback(value: f32) -> u16 {
    // Convert to raw bits
    let x = value.to_bits();

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
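
// Illustrative sketch of the round-to-nearest-even rule described above (not part
// of the original source). In the half-precision binade [2048, 4096) the spacing
// is 2, so 2049.0 and 2051.0 are exact ties while 2049.5 lies above a tie point.
#[cfg(test)]
mod round_to_even_example {
    use super::*;

    #[test]
    fn ties_round_to_even() {
        // 2049.0 ties between 2048.0 (0x6800) and 2050.0 (0x6801); the retained
        // mantissa of 2048.0 is even, so the tie rounds down.
        assert_eq!(f32_to_f16_fallback(2049.0), 0x6800);
        // 2051.0 ties between 2050.0 (0x6801) and 2052.0 (0x6802); the retained
        // mantissa of 2050.0 is odd, so the tie rounds up.
        assert_eq!(f32_to_f16_fallback(2051.0), 0x6802);
        // 2049.5 is past the tie point, so it simply rounds up to 2050.0.
        assert_eq!(f32_to_f16_fallback(2049.5), 0x6801);
    }
}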

fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bits, truncating the last 32 bits of the mantissa; that precision will
    // always be lost on half-precision.
    let val = value.to_bits();
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
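
// Illustrative sketch (not part of the original source): a few exact
// double-precision inputs and their expected half-precision bit patterns.
#[cfg(test)]
mod f64_to_f16_example {
    use super::*;

    #[test]
    fn exact_values() {
        // 1.0 is 0x3C00 in half precision (biased exponent 15, empty mantissa).
        assert_eq!(f64_to_f16_fallback(1.0), 0x3C00);
        // -2.0 sets the sign bit with biased exponent 16: 0xC000.
        assert_eq!(f64_to_f16_fallback(-2.0), 0xC000);
        // Infinity maps to all exponent bits set with a zero mantissa.
        assert_eq!(f64_to_f16_fallback(f64::INFINITY), 0x7C00);
    }
}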

fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    if i & 0x7FFFu16 == 0 {
        return f32::from_bits((i as u32) << 16);
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return f32::from_bits((half_sign << 16) | 0x7F80_0000u32);
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return f32::from_bits((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13));
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = (half_man as u16).leading_zeros() - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return f32::from_bits(sign | exp | man);
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    f32::from_bits(sign | exp | man)
}
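
// Illustrative sketch (not part of the original source): half-precision
// subnormals are renormalized into ordinary f32 values by the branch above.
// Bit patterns are compared to keep the expected values exact.
#[cfg(test)]
mod subnormal_convert_example {
    use super::*;

    #[test]
    fn subnormals_become_normal_f32() {
        // Smallest positive subnormal half (0x0001) is 2^-24, i.e. f32 bits 0x3380_0000.
        assert_eq!(f16_to_f32_fallback(0x0001).to_bits(), 0x3380_0000);
        // Largest subnormal half (0x03FF) is 1023 * 2^-24, i.e. f32 bits 0x387F_C000.
        assert_eq!(f16_to_f32_fallback(0x03FF).to_bits(), 0x387F_C000);
    }
}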

fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    if i & 0x7FFFu16 == 0 {
        return f64::from_bits((i as u64) << 48);
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return f64::from_bits((half_sign << 48) | 0x7FF0_0000_0000_0000u64);
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return f64::from_bits((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42));
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = (half_man as u16).leading_zeros() - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return f64::from_bits(sign | exp | man);
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    f64::from_bits(sign | exp | man)
}

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16]) -> [f32; 4] {
    debug_assert!(v.len() >= 4);

    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32]) -> [u16; 4] {
    debug_assert!(v.len() >= 4);

    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16]) -> [f64; 4] {
    debug_assert!(v.len() >= 4);

    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64]) -> [u16; 4] {
    debug_assert!(v.len() >= 4);

    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

/////////////// x86/x86_64 f16c ////////////////
#[cfg(all(
    feature = "use-intrinsics",
    any(target_arch = "x86", target_arch = "x86_64")
))]
mod x86 {
    use core::{mem::MaybeUninit, ptr};

    #[cfg(target_arch = "x86")]
    use core::arch::x86::{__m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{
        __m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT,
    };

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        vec.as_mut_ptr().cast::<u16>().write(i);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
        let mut vec = MaybeUninit::<__m128>::zeroed();
        vec.as_mut_ptr().cast::<f32>().write(f);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16]) -> [f32; 4] {
        debug_assert!(v.len() >= 4);

        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32]) -> [u16; 4] {
        debug_assert!(v.len() >= 4);

        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16]) -> [f64; 4] {
        debug_assert!(v.len() >= 4);

        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        let array = *(&retval as *const __m128).cast::<[f32; 4]>();
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        [
            array[0] as f64,
            array[1] as f64,
            array[2] as f64,
            array[3] as f64,
        ]
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64]) -> [u16; 4] {
        debug_assert!(v.len() >= 4);

        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];

        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }
}