#![allow(dead_code, unused_imports)]
use crate::leading_zeros::leading_zeros_u16;
use core::mem;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86;

#[cfg(target_arch = "aarch64")]
mod aarch64;

macro_rules! convert_fn {
    (if x86_feature("f16c") { $f16c:expr }
    else if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Use intrinsics directly when the feature is enabled for the compile
            // target (e.g. when building for no_std)
            if #[cfg(all(
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))] {
                $f16c
            }
            else if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Use CPU feature detection if using std
            else if #[cfg(all(
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64")
            ))] {
                use std::arch::is_x86_feature_detected;
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64"
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}

#[inline]
pub(crate) fn f32_to_f16(f: f32) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32_to_f16_fp16(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f64_to_f16(f: f64) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64_to_f16_fp16(f) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f32(i: u16) -> f32 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f32_fp16(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f64(i: u16) -> f64 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f64_fp16(i) }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32x4_to_f16x4_fp16(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f32x4_fp16(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64x4_to_f16x4_fp16(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f64x4_fp16(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f32x4_to_f16x4_fp16);
                result
            }
        } else {
            f32x8_to_f16x8_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f32x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f32; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f32x4_fp16);
                result
            }
        } else {
            f16x8_to_f32x8_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f64x4_to_f16x4_fp16);
                result
            }
        } else {
            f64x8_to_f16x8_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f64x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f64; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f64x4_fp16);
                result
            }
        } else {
            f16x8_to_f64x8_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f32x8_to_f16x8_x86_f16c,
                x86::f32x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f32x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f32_to_f16_fallback)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f32x8_x86_f16c,
                x86::f16x4_to_f32x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f32x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f32_fallback)
        }
    }
}

#[inline]
pub(crate) fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f64x8_to_f16x8_x86_f16c,
                x86::f64x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f64x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f64_to_f16_fallback)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f64x8_x86_f16c,
                x86::f16x4_to_f64x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f64x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f64_fallback)
        }
    }
}

macro_rules! math_fn {
    (if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Use intrinsics directly when the feature is enabled for the compile
            // target (e.g. when building for no_std)
            if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Use CPU feature detection if using std
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
                not(target_feature = "fp16")
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}

#[inline]
pub(crate) fn add_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::add_f16_fp16(a, b) }
        } else {
            add_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn subtract_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::subtract_f16_fp16(a, b) }
        } else {
            subtract_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn multiply_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::multiply_f16_fp16(a, b) }
        } else {
            multiply_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn divide_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::divide_f16_fp16(a, b) }
        } else {
            divide_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn remainder_f16(a: u16, b: u16) -> u16 {
    remainder_f16_fallback(a, b)
}

#[inline]
pub(crate) fn product_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            // Fold from 0x3C00 (1.0 in binary16): folding from zero bits would
            // force every product to zero, unlike the software fallback.
            iter.fold(0x3C00, |acc, x| unsafe { aarch64::multiply_f16_fp16(acc, x) })
        } else {
            product_f16_fallback(iter)
        }
    }
}

#[inline]
pub(crate) fn sum_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            // 0 is the bit pattern of +0.0, the additive identity
            iter.fold(0, |acc, x| unsafe { aarch64::add_f16_fp16(acc, x) })
        } else {
            sum_f16_fallback(iter)
        }
    }
}

/// Chunks sliced into x8 or x4 arrays
#[inline]
fn convert_chunked_slice_8<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    fn8: unsafe fn(&[S; 8]) -> [D; 8],
    fn4: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(8);
    let mut dst_chunks = dst.chunks_exact_mut(8);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 8] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &fn8(chunk) });
    }

    // Process remainder
    if src_remainder.len() > 4 {
        let mut buf: [S; 8] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn8(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    } else if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn4(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}

/// Chunks sliced into x4 arrays
#[inline]
fn convert_chunked_slice_4<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    f: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(4);
    let mut dst_chunks = dst.chunks_exact_mut(4);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 4] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &f(chunk) });
    }

    // Process remainder
    if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { f(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}
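
// Note on the chunked helpers above: a partial trailing chunk is copied into a
// zero-initialized (`Default`) buffer, converted with the full-width kernel, and
// only the valid prefix is written back to the destination. For a 6-element
// slice and a 4-wide kernel, that means one full chunk plus a padded 2-element
// remainder; see the sanity-check sketch at the end of this module.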

/////////////// Fallbacks ////////////////

// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd.
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
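//
// Since round_bit is a single power of two, (round_bit - 1) and (2 * round_bit)
// share no bits, so their OR equals their sum, 3 * round_bit - 1.
//
// Worked example for the f32 -> f16 normal path (round_bit = 0x1000), with the
// exponent bits of 1.0:
//   mantissa 0x00_1000 (1.0 + 2^-11):   exact tie, retained part even -> stays 0x3C00
//   mantissa 0x00_3000 (1.0 + 3*2^-11): exact tie, retained part odd  -> rounds up to 0x3C02
//   mantissa 0x00_1001:                 removed part above the tie    -> rounds up to 0x3C01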

#[inline]
pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    // Convert to raw bytes
    let x: u32 = unsafe { mem::transmute(value) };

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

#[inline]
pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always
    // be lost on half-precision.
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    let val: u64 = unsafe { mem::transmute(value) };
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

#[inline]
pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute((i as u32) << 16) };
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute((half_sign << 16) | 0x7F80_0000u32) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13))
            };
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return unsafe { mem::transmute(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    unsafe { mem::transmute(sign | exp | man) }
}

#[inline]
pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute((i as u64) << 48) };
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute((half_sign << 48) | 0x7FF0_0000_0000_0000u64) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42))
            };
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return unsafe { mem::transmute(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    unsafe { mem::transmute(sign | exp | man) }
}

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16; 4]) -> [f32; 4] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32; 4]) -> [u16; 4] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16; 4]) -> [f64; 4] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64; 4]) -> [u16; 4] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x8_to_f32x8_fallback(v: &[u16; 8]) -> [f32; 8] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
        f16_to_f32_fallback(v[4]),
        f16_to_f32_fallback(v[5]),
        f16_to_f32_fallback(v[6]),
        f16_to_f32_fallback(v[7]),
    ]
}

#[inline]
fn f32x8_to_f16x8_fallback(v: &[f32; 8]) -> [u16; 8] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
        f32_to_f16_fallback(v[4]),
        f32_to_f16_fallback(v[5]),
        f32_to_f16_fallback(v[6]),
        f32_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn f16x8_to_f64x8_fallback(v: &[u16; 8]) -> [f64; 8] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
        f16_to_f64_fallback(v[4]),
        f16_to_f64_fallback(v[5]),
        f16_to_f64_fallback(v[6]),
        f16_to_f64_fallback(v[7]),
    ]
}

#[inline]
fn f64x8_to_f16x8_fallback(v: &[f64; 8]) -> [u16; 8] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
        f64_to_f16_fallback(v[4]),
        f64_to_f16_fallback(v[5]),
        f64_to_f16_fallback(v[6]),
        f64_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn slice_fallback<S: Copy, D>(src: &[S], dst: &mut [D], f: fn(S) -> D) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}

#[inline]
fn add_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) + f16_to_f32(b))
}

#[inline]
fn subtract_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) - f16_to_f32(b))
}

#[inline]
fn multiply_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) * f16_to_f32(b))
}

#[inline]
fn divide_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) / f16_to_f32(b))
}

#[inline]
fn remainder_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) % f16_to_f32(b))
}

#[inline]
fn product_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
    f32_to_f16(iter.map(f16_to_f32).product())
}

#[inline]
fn sum_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
    f32_to_f16(iter.map(f16_to_f32).sum())
}

// TODO SIMD arithmetic
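
// A minimal sanity-check sketch of the software fallbacks above (module and test
// names here are illustrative, not part of the public surface). It exercises the
// fallback paths directly so the results do not depend on detected CPU features;
// the expected bit patterns follow from the IEEE 754 encodings and the
// round-to-nearest-even rule documented earlier in this file.
#[cfg(test)]
mod fallback_sanity {
    use super::*;

    #[test]
    fn scalar_roundtrip() {
        // 1.0 <-> 0x3C00 (sign 0, biased exponent 15, zero mantissa)
        assert_eq!(f32_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f16_to_f32_fallback(0x3C00), 1.0);
        // -2.0 <-> 0xC000 (sign 1, biased exponent 16, zero mantissa)
        assert_eq!(f32_to_f16_fallback(-2.0), 0xC000);
        assert_eq!(f16_to_f64_fallback(0xC000), -2.0);
    }

    #[test]
    fn round_to_nearest_ties_to_even() {
        // 0x3F80_1000 is 1.0 + 2^-11: an exact tie between 0x3C00 and 0x3C01;
        // the even retained mantissa wins.
        assert_eq!(f32_to_f16_fallback(f32::from_bits(0x3F80_1000)), 0x3C00);
        // The same tie with an odd retained mantissa rounds up to the even neighbor.
        assert_eq!(f32_to_f16_fallback(f32::from_bits(0x3F80_3000)), 0x3C02);
        // Just above the tie always rounds up.
        assert_eq!(f32_to_f16_fallback(f32::from_bits(0x3F80_1001)), 0x3C01);
    }

    #[test]
    fn chunked_slice_pads_the_remainder() {
        // Six elements: one full 4-wide chunk plus a padded 2-element remainder.
        unsafe fn kernel(v: &[f32; 4]) -> [u16; 4] {
            f32x4_to_f16x4_fallback(v)
        }
        let src = [0.5f32, 1.0, 1.5, 2.0, 2.5, 3.0];
        let mut dst = [0u16; 6];
        convert_chunked_slice_4(&src, &mut dst, kernel);

        let mut expected = [0u16; 6];
        slice_fallback(&src, &mut expected, f32_to_f16_fallback);
        assert_eq!(dst, expected);
    }
}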