1#![allow(dead_code, unused_imports)]
2
/// Generates a `pub(crate)`, `#[inline]` conversion function whose body is chosen by `cfg`
/// between an F16C hardware expression and a software expression.
///
/// Expected invocation shape (the `if feature("f16c")` tokens are matched literally; they are
/// syntax for this macro, not a real condition):
///
/// ```text
/// convert_fn! {
///     fn name(arg: ArgType) -> RetType {
///         if feature("f16c") { hardware_expr } else { software_expr }
///     }
/// }
/// ```
///
/// Exactly one of the three blocks below survives `cfg` stripping for any configuration.
macro_rules! convert_fn {
    (fn $fn_name:ident($arg:ident : $arg_ty:ty) -> $ret_ty:ty {
            if feature("f16c") { $hardware:expr }
            else { $software:expr }}) => {
        #[inline]
        pub(crate) fn $fn_name($arg: $arg_ty) -> $ret_ty {
            // Intrinsics requested, `std` present, F16C not enabled for the compile target:
            // ask the CPU at run time and branch.
            #[cfg(all(
                feature = "use-intrinsics",
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64"),
                not(target_feature = "f16c")
            ))]
            {
                if is_x86_feature_detected!("f16c") {
                    $hardware
                } else {
                    $software
                }
            }
            // Intrinsics requested and F16C enabled for the compile target: no detection needed.
            #[cfg(all(
                feature = "use-intrinsics",
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))]
            {
                $hardware
            }
            // Every remaining configuration: intrinsics not requested, a non-x86 target, or
            // neither `std` (for detection) nor compile-time F16C is available.
            #[cfg(not(all(
                feature = "use-intrinsics",
                any(target_arch = "x86", target_arch = "x86_64"),
                any(feature = "std", target_feature = "f16c")
            )))]
            {
                $software
            }
        }
    };
}
44
// Converts one `f32` to a `u16` holding half-precision bits.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
convert_fn! {
    fn f32_to_f16(f: f32) -> u16 {
        if feature("f16c") {
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}
54
// Converts one `f64` to a `u16` holding half-precision bits.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
convert_fn! {
    fn f64_to_f16(f: f64) -> u16 {
        if feature("f16c") {
            // The hardware path reuses the f32 conversion after narrowing with `as f32`.
            // NOTE(review): narrowing first means the value is rounded twice (f64 -> f32 ->
            // f16), which can differ from a single f64 -> f16 rounding near ties; confirm this
            // is acceptable for callers.
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}
64
// Converts half-precision bits stored in a `u16` to an `f32`.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
convert_fn! {
    fn f16_to_f32(i: u16) -> f32 {
        if feature("f16c") {
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}
74
// Converts half-precision bits stored in a `u16` to an `f64`.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
convert_fn! {
    fn f16_to_f64(i: u16) -> f64 {
        if feature("f16c") {
            // The hardware path converts to `f32` and then widens with `as f64`.
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}
84
85// TODO: While SIMD versions are faster, further improvements can be made by doing runtime feature
86// detection once at beginning of convert slice method, rather than per chunk
87
// Converts the first four `f32` values of a slice to four half-precision bit patterns.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
// Both branches read elements 0 through 3 of `f`, so callers are expected to pass at least
// four elements.
convert_fn! {
    fn f32x4_to_f16x4(f: &[f32]) -> [u16; 4] {
        if feature("f16c") {
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}
97
// Converts the first four half-precision bit patterns of a slice to four `f32` values.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
// Both branches read elements 0 through 3 of `i`, so callers are expected to pass at least
// four elements.
convert_fn! {
    fn f16x4_to_f32x4(i: &[u16]) -> [f32; 4] {
        if feature("f16c") {
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}
107
// Converts the first four `f64` values of a slice to four half-precision bit patterns.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
// Both branches read elements 0 through 3 of `f`, so callers are expected to pass at least
// four elements.
convert_fn! {
    fn f64x4_to_f16x4(f: &[f64]) -> [u16; 4] {
        if feature("f16c") {
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}
117
// Converts the first four half-precision bit patterns of a slice to four `f64` values.
//
// `convert_fn!` expands this into a `pub(crate)` function that evaluates the first branch when
// its `cfg` conditions (or run-time detection) select F16C, and the `else` branch otherwise.
// Both branches read elements 0 through 3 of `i`, so callers are expected to pass at least
// four elements.
convert_fn! {
    fn f16x4_to_f64x4(i: &[u16]) -> [f64; 4] {
        if feature("f16c") {
            // SAFETY: `convert_fn!` only evaluates this branch when F16C is enabled for the
            // compile target or was detected at run time.
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}
127
128/////////////// Fallbacks ////////////////
129
130// In the below functions, round to nearest, with ties to even.
131// Let us call the most significant bit that will be shifted out the round_bit.
132//
133// Round up if either
134// a) Removed part > tie.
135// (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
136// b) Removed part == tie, and retained part is odd.
137// (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
138// (If removed part == tie and retained part is even, do not round up.)
139// These two conditions can be combined into one:
140// (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
141// which can be simplified into
142// (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
143
/// Software conversion of an `f32` to half-precision bits, rounding to nearest with ties to even.
///
/// Infinities map to signed infinity, NaNs stay NaN (the quiet bit is forced and the top payload
/// bits are kept), exponents too large for half precision give signed infinity, and small
/// magnitudes become subnormals or signed zero.
fn f32_to_f16_fallback(value: f32) -> u16 {
    let bits = value.to_bits();

    // Split into fields; the sign is moved straight to its half-precision position.
    let sign16 = (bits >> 16) & 0x8000;
    let biased = (bits >> 23) & 0xFF;
    let frac = bits & 0x007F_FFFF;

    // All exponent bits set: infinity when the fraction is zero, NaN otherwise.
    if biased == 0xFF {
        let quiet = if frac != 0 { 0x0200 } else { 0 };
        return (sign16 | 0x7C00 | quiet | (frac >> 13)) as u16;
    }

    // Exponent re-biased from 127 (single) to 15 (half).
    let e16 = biased as i32 - 127 + 15;

    // Too large for a finite half: signed infinity.
    if e16 >= 0x1F {
        return (sign16 | 0x7C00) as u16;
    }

    if e16 > 0 {
        // Normal result. Thirteen fraction bits are dropped; the highest of them is the round
        // bit (see the rounding note above the fallback functions).
        let packed = sign16 | ((e16 as u32) << 10) | (frac >> 13);
        let round = 0x0000_1000u32;
        let round_up = (frac & round) != 0 && (frac & (3 * round - 1)) != 0;
        // A carry out of the fraction correctly bumps the exponent (possibly to infinity).
        return if round_up {
            (packed + 1) as u16
        } else {
            packed as u16
        };
    }

    // Subnormal or zero result from here on.
    if e16 < -10 {
        // Below half of the smallest subnormal: nothing can round up, return signed zero.
        return sign16 as u16;
    }

    // Make the implicit leading one explicit before shifting it into the fraction field.
    let full = frac | 0x0080_0000;
    let shift = (14 - e16) as u32;
    let round = 1u32 << (shift - 1);
    let mut half_frac = full >> shift;
    if (full & round) != 0 && (full & (3 * round - 1)) != 0 {
        half_frac += 1;
    }
    // Subnormals carry no exponent bits.
    (sign16 | half_frac) as u16
}
202
/// Software conversion of an `f64` to half-precision bits, rounding to nearest with ties to even.
///
/// Only the upper 32 bits of the double are processed directly. The lower 32 mantissa bits can
/// never be retained, but they still decide whether the discarded part is exactly a tie or
/// strictly above it, so they are folded into a single sticky bit before rounding. Without that
/// bit a value just above a tie would be rounded as if it were an exact tie.
///
/// Infinities map to signed infinity, NaNs stay NaN (the quiet bit is forced and the top payload
/// bits are kept), exponents too large for half precision give signed infinity, and small
/// magnitudes become subnormals or signed zero.
fn f64_to_f16_fallback(value: f64) -> u16 {
    let val = value.to_bits();
    let x = (val >> 32) as u32;
    // Non-zero if any of the 32 mantissa bits dropped by the shift above is set.
    let sticky = if (val as u32) != 0 { 1u32 } else { 0u32 };

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    // Bit 0 of `man` always lies below the round bit in every path below (at least 10 bits are
    // shifted out), so OR-ing the sticky bit into it only affects the "greater than tie" test.
    let man = (x & 0x000F_FFFFu32) | sticky;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits). `man` already
        // accounts for the low 32 bits through the sticky bit.
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is finite, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent (1023), then bias for half precision (15)
    let half_exp = ((exp >> 20) as i32) - 1023 + 15;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        if half_exp < -10 {
            // Magnitude is below half of the smallest subnormal, so rounding cannot produce
            // anything but signed zero.
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let shift = (11 - half_exp) as u32;
        let mut half_man = man >> shift;
        // Check for rounding (see comment above functions)
        let round_bit = 1u32 << (shift - 1);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it; a carry out of the mantissa correctly bumps the exponent
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
268
/// Software conversion of half-precision bits to an `f32`.
///
/// The conversion is exact: zeros keep their sign, subnormals are renormalized, infinities stay
/// infinite, and NaNs keep their payload (shifted up) with the most significant mantissa bit set.
fn f16_to_f32_fallback(i: u16) -> f32 {
    let bits = u32::from(i);

    let sign = (bits & 0x8000) << 16;
    let exp_field = (bits >> 10) & 0x1F;
    let frac = bits & 0x03FF;

    // Everything below the sign bit of the result, selected by the kind of input.
    let magnitude = match (exp_field, frac) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: normalize by shifting the leading one out of the fraction field and
        // lowering the exponent by the same amount.
        (0, _) => {
            // Leading zeros of the 10-bit fraction (`frac` is a u32 with 22 unused high bits).
            let adjust = frac.leading_zeros() - 22;
            let exp = (127 - 15 - adjust) << 23;
            let man = (frac << (14 + adjust)) & 0x007F_FFFF;
            exp | man
        }
        // Signed infinity.
        (0x1F, 0) => 0x7F80_0000,
        // NaN: keep the payload and force the most significant mantissa bit.
        (0x1F, _) => 0x7FC0_0000 | (frac << 13),
        // Normal number: rebias the exponent from 15 to 127 and widen the fraction.
        _ => ((exp_field + 127 - 15) << 23) | (frac << 13),
    };

    f32::from_bits(sign | magnitude)
}
311
/// Software conversion of half-precision bits to an `f64`.
///
/// The conversion is exact: zeros keep their sign, subnormals are renormalized, infinities stay
/// infinite, and NaNs keep their payload (shifted up) with the most significant mantissa bit set.
fn f16_to_f64_fallback(i: u16) -> f64 {
    let bits = u64::from(i);

    let sign = (bits & 0x8000) << 48;
    let exp_field = (bits >> 10) & 0x1F;
    let frac = bits & 0x03FF;

    // Everything below the sign bit of the result, selected by the kind of input.
    let magnitude = match (exp_field, frac) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: normalize by shifting the leading one out of the fraction field and
        // lowering the exponent by the same amount.
        (0, _) => {
            // Leading zeros of the 10-bit fraction (`frac` is a u64 with 54 unused high bits).
            let adjust = frac.leading_zeros() - 54;
            let exp = u64::from(1023 - 15 - adjust) << 52;
            let man = (frac << (43 + adjust)) & 0x000F_FFFF_FFFF_FFFF;
            exp | man
        }
        // Signed infinity.
        (0x1F, 0) => 0x7FF0_0000_0000_0000,
        // NaN: keep the payload and force the most significant mantissa bit.
        (0x1F, _) => 0x7FF8_0000_0000_0000 | (frac << 42),
        // Normal number: rebias the exponent from 15 to 1023 and widen the fraction.
        _ => ((exp_field + 1023 - 15) << 52) | (frac << 42),
    };

    f64::from_bits(sign | magnitude)
}
354
355#[inline]
356fn f16x4_to_f32x4_fallback(v: &[u16]) -> [f32; 4] {
357 debug_assert!(v.len() >= 4);
358
359 [
360 f16_to_f32_fallback(v[0]),
361 f16_to_f32_fallback(v[1]),
362 f16_to_f32_fallback(v[2]),
363 f16_to_f32_fallback(v[3]),
364 ]
365}
366
367#[inline]
368fn f32x4_to_f16x4_fallback(v: &[f32]) -> [u16; 4] {
369 debug_assert!(v.len() >= 4);
370
371 [
372 f32_to_f16_fallback(v[0]),
373 f32_to_f16_fallback(v[1]),
374 f32_to_f16_fallback(v[2]),
375 f32_to_f16_fallback(v[3]),
376 ]
377}
378
379#[inline]
380fn f16x4_to_f64x4_fallback(v: &[u16]) -> [f64; 4] {
381 debug_assert!(v.len() >= 4);
382
383 [
384 f16_to_f64_fallback(v[0]),
385 f16_to_f64_fallback(v[1]),
386 f16_to_f64_fallback(v[2]),
387 f16_to_f64_fallback(v[3]),
388 ]
389}
390
391#[inline]
392fn f64x4_to_f16x4_fallback(v: &[f64]) -> [u16; 4] {
393 debug_assert!(v.len() >= 4);
394
395 [
396 f64_to_f16_fallback(v[0]),
397 f64_to_f16_fallback(v[1]),
398 f64_to_f16_fallback(v[2]),
399 f64_to_f16_fallback(v[3]),
400 ]
401}
402
403/////////////// x86/x86_64 f16c ////////////////
/// Thin wrappers around the x86/x86_64 F16C conversion intrinsics.
///
/// Every function here is compiled with `#[target_feature(enable = "f16c")]`, so each one is
/// `unsafe`: the caller must make sure the running CPU supports F16C. The slice-taking functions
/// check the slice length with a real `assert!` (not only in debug builds), because they are
/// reached from safe wrappers and would otherwise read past the end of a short slice in release
/// builds.
#[cfg(all(
    feature = "use-intrinsics",
    any(target_arch = "x86", target_arch = "x86_64")
))]
mod x86 {
    use core::{mem::MaybeUninit, ptr};

    #[cfg(target_arch = "x86")]
    use core::arch::x86::{__m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{
        __m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT,
    };

    /// Converts one half-precision bit pattern to `f32`.
    ///
    /// # Safety
    ///
    /// The CPU must support the F16C instruction set.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
        // Zero the whole vector, then place the input in the lowest 16-bit lane.
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        vec.as_mut_ptr().cast::<u16>().write(i);
        let retval = _mm_cvtph_ps(vec.assume_init());
        // Read the lowest `f32` lane of the result.
        *(&retval as *const __m128).cast()
    }

    /// Converts one `f32` to a half-precision bit pattern, rounding to nearest.
    ///
    /// # Safety
    ///
    /// The CPU must support the F16C instruction set.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
        // Zero the whole vector, then place the input in the lowest `f32` lane.
        let mut vec = MaybeUninit::<__m128>::zeroed();
        vec.as_mut_ptr().cast::<f32>().write(f);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        // Read the lowest 16-bit lane of the result.
        *(&retval as *const __m128i).cast()
    }

    /// Converts the first four half-precision bit patterns of `v` to `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `v` has fewer than four elements.
    ///
    /// # Safety
    ///
    /// The CPU must support the F16C instruction set.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16]) -> [f32; 4] {
        // Must hold in release builds too: the raw copy below reads four elements.
        assert!(v.len() >= 4);

        // Only the low 64 bits are filled; the rest stays zeroed.
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    /// Converts the first four `f32` values of `v` to half-precision bit patterns, rounding to
    /// nearest.
    ///
    /// # Panics
    ///
    /// Panics if `v` has fewer than four elements.
    ///
    /// # Safety
    ///
    /// The CPU must support the F16C instruction set.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32]) -> [u16; 4] {
        // Must hold in release builds too: the raw copy below reads four elements.
        assert!(v.len() >= 4);

        // Four `f32` values fill all 16 bytes, so the vector is fully initialized by the copy.
        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        // The four results occupy the low 64 bits of the returned vector.
        *(&retval as *const __m128i).cast()
    }

    /// Converts the first four half-precision bit patterns of `v` to `f64` by converting to
    /// `f32` in hardware and widening each lane.
    ///
    /// # Panics
    ///
    /// Panics if `v` has fewer than four elements.
    ///
    /// # Safety
    ///
    /// The CPU must support the F16C instruction set.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16]) -> [f64; 4] {
        // Must hold in release builds too: the raw copy below reads four elements.
        assert!(v.len() >= 4);

        // Only the low 64 bits are filled; the rest stays zeroed.
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        let array = *(&retval as *const __m128).cast::<[f32; 4]>();
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        [
            array[0] as f64,
            array[1] as f64,
            array[2] as f64,
            array[3] as f64,
        ]
    }

    /// Converts the first four `f64` values of `v` to half-precision bit patterns by narrowing
    /// each to `f32` and converting in hardware.
    ///
    /// # Panics
    ///
    /// Panics if `v` has fewer than four elements.
    ///
    /// # Safety
    ///
    /// The CPU must support the F16C instruction set.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64]) -> [u16; 4] {
        assert!(v.len() >= 4);

        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];

        // Four `f32` values fill all 16 bytes, so the vector is fully initialized by the copy.
        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        // The four results occupy the low 64 bits of the returned vector.
        *(&retval as *const __m128i).cast()
    }
}
492