#![allow(dead_code, unused_imports)]
use crate::leading_zeros::leading_zeros_u16;
use core::mem;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86;

#[cfg(target_arch = "aarch64")]
mod aarch64;

macro_rules! convert_fn {
    (if x86_feature("f16c") { $f16c:expr }
    else if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Use intrinsics directly when the required feature is enabled at
            // compile time or when building for no_std
            if #[cfg(all(
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))] {
                $f16c
            }
            else if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Use CPU feature detection if using std
            else if #[cfg(all(
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64")
            ))] {
                use std::arch::is_x86_feature_detected;
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}
|

#[inline]
pub(crate) fn f32_to_f16(f: f32) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32_to_f16_fp16(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f64_to_f16(f: f64) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64_to_f16_fp16(f) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f32(i: u16) -> f32 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f32_fp16(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f64(i: u16) -> f64 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f64_fp16(i) }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32x4_to_f16x4_fp16(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f32x4_fp16(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64x4_to_f16x4_fp16(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f64x4_fp16(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f32x4_to_f16x4_fp16);
                result
            }
        } else {
            f32x8_to_f16x8_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f32x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f32; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f32x4_fp16);
                result
            }
        } else {
            f16x8_to_f32x8_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f64x4_to_f16x4_fp16);
                result
            }
        } else {
            f64x8_to_f16x8_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f64x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f64; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f64x4_fp16);
                result
            }
        } else {
            f16x8_to_f64x8_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f32x8_to_f16x8_x86_f16c,
                x86::f32x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f32x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f32_to_f16_fallback)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f32x8_x86_f16c,
                x86::f16x4_to_f32x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f32x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f32_fallback)
        }
    }
}

#[inline]
pub(crate) fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f64x8_to_f16x8_x86_f16c,
                x86::f64x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f64x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f64_to_f16_fallback)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f64x8_x86_f16c,
                x86::f16x4_to_f64x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f64x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f64_fallback)
        }
    }
}
|

macro_rules! math_fn {
    (if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Use intrinsics directly when the required feature is enabled at
            // compile time or when building for no_std
            if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Use CPU feature detection if using std
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
                not(target_feature = "fp16")
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}
|

#[inline]
pub(crate) fn add_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::add_f16_fp16(a, b) }
        } else {
            add_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn subtract_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::subtract_f16_fp16(a, b) }
        } else {
            subtract_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn multiply_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::multiply_f16_fp16(a, b) }
        } else {
            multiply_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn divide_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::divide_f16_fp16(a, b) }
        } else {
            divide_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn remainder_f16(a: u16, b: u16) -> u16 {
    remainder_f16_fallback(a, b)
}
|

#[inline]
pub(crate) fn product_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            // Start the accumulator at 1.0 (0x3C00 in binary16); folding from 0
            // would force every product to zero.
            iter.fold(0x3C00, |acc, x| unsafe { aarch64::multiply_f16_fp16(acc, x) })
        } else {
            product_f16_fallback(iter)
        }
    }
}

#[inline]
pub(crate) fn sum_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            iter.fold(0, |acc, x| unsafe { aarch64::add_f16_fp16(acc, x) })
        } else {
            sum_f16_fallback(iter)
        }
    }
}
|

/// Converts a slice in x8 chunks, using a single x4 pass for any remainder
#[inline]
fn convert_chunked_slice_8<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    fn8: unsafe fn(&[S; 8]) -> [D; 8],
    fn4: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(8);
    let mut dst_chunks = dst.chunks_exact_mut(8);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 8] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &fn8(chunk) });
    }

    // Process remainder
    if src_remainder.len() > 4 {
        let mut buf: [S; 8] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn8(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    } else if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn4(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}
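
// Note on the remainder handling above (illustrative): with src.len() == 13,
// one full x8 chunk is converted by `fn8`; the remaining 5 elements are padded
// with `S::default()` into an `[S; 8]` buffer, converted with `fn8`, and only
// the first 5 results are copied out. A remainder of 4 or fewer elements is
// instead padded into an `[S; 4]` buffer and converted with `fn4`.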
|

/// Converts a slice in x4 chunks
#[inline]
fn convert_chunked_slice_4<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    f: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(4);
    let mut dst_chunks = dst.chunks_exact_mut(4);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 4] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &f(chunk) });
    }

    // Process remainder
    if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { f(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}
|

/////////////// Fallbacks ////////////////

// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd.
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
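//
// Worked example (illustrative), using the f32 -> f16 normal path below where
// round_bit = 0x1000 and the mantissa is shifted right by 13:
//   mantissa = 0x0000_1000: the removed part is exactly the tie and the
//     retained LSB (bit 13) is 0, so mantissa & 0x2FFF == 0 and nothing is added.
//   mantissa = 0x0000_3000: a tie again, but the retained LSB is 1, so the
//     result is rounded up to its even neighbour.
//   mantissa = 0x0000_1001: the removed part exceeds the tie, so round up.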
|

#[inline]
pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    // Convert to raw bytes
    let x: u32 = unsafe { mem::transmute::<f32, u32>(value) };

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
|

#[inline]
pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bytes, truncating the last 32 bits of the mantissa; that
    // precision will always be lost on half-precision.
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    let val: u64 = unsafe { mem::transmute::<f64, u64>(value) };
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
|

#[inline]
pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute::<u32, f32>((i as u32) << 16) };
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute::<u32, f32>((half_sign << 16) | 0x7F80_0000u32) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute::<u32, f32>((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13))
            };
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return unsafe { mem::transmute::<u32, f32>(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    unsafe { mem::transmute::<u32, f32>(sign | exp | man) }
}
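
// Worked subnormal example for the function above (illustrative): for the
// smallest positive subnormal input, i = 0x0001, half_man is 1, so
// leading_zeros_u16(1) == 15 and e == 9. The rebiased exponent becomes
// (127 - 15 - 9) << 23, i.e. an exponent of -24, and
// (half_man << (14 + 9)) & 0x7F_FF_FF masks off the now-explicit leading bit,
// leaving a zero f32 mantissa, so the result is exactly 2^-24.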
|

#[inline]
pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute::<u64, f64>((i as u64) << 48) };
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe {
                mem::transmute::<u64, f64>((half_sign << 48) | 0x7FF0_0000_0000_0000u64)
            };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute::<u64, f64>(
                    (half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42),
                )
            };
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return unsafe { mem::transmute::<u64, f64>(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    unsafe { mem::transmute::<u64, f64>(sign | exp | man) }
}
|

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16; 4]) -> [f32; 4] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32; 4]) -> [u16; 4] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16; 4]) -> [f64; 4] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64; 4]) -> [u16; 4] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x8_to_f32x8_fallback(v: &[u16; 8]) -> [f32; 8] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
        f16_to_f32_fallback(v[4]),
        f16_to_f32_fallback(v[5]),
        f16_to_f32_fallback(v[6]),
        f16_to_f32_fallback(v[7]),
    ]
}

#[inline]
fn f32x8_to_f16x8_fallback(v: &[f32; 8]) -> [u16; 8] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
        f32_to_f16_fallback(v[4]),
        f32_to_f16_fallback(v[5]),
        f32_to_f16_fallback(v[6]),
        f32_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn f16x8_to_f64x8_fallback(v: &[u16; 8]) -> [f64; 8] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
        f16_to_f64_fallback(v[4]),
        f16_to_f64_fallback(v[5]),
        f16_to_f64_fallback(v[6]),
        f16_to_f64_fallback(v[7]),
    ]
}

#[inline]
fn f64x8_to_f16x8_fallback(v: &[f64; 8]) -> [u16; 8] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
        f64_to_f16_fallback(v[4]),
        f64_to_f16_fallback(v[5]),
        f64_to_f16_fallback(v[6]),
        f64_to_f16_fallback(v[7]),
    ]
}
|

#[inline]
fn slice_fallback<S: Copy, D>(src: &[S], dst: &mut [D], f: fn(S) -> D) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}
|

#[inline]
fn add_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) + f16_to_f32(b))
}

#[inline]
fn subtract_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) - f16_to_f32(b))
}

#[inline]
fn multiply_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) * f16_to_f32(b))
}

#[inline]
fn divide_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) / f16_to_f32(b))
}

#[inline]
fn remainder_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) % f16_to_f32(b))
}

#[inline]
fn product_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
    f32_to_f16(iter.map(f16_to_f32).product())
}

#[inline]
fn sum_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
    f32_to_f16(iter.map(f16_to_f32).sum())
}

// TODO SIMD arithmetic
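
// Minimal sanity checks for the software fallbacks (an illustrative sketch,
// not exhaustive; the expected constants are standard IEEE 754 binary16
// encodings: 1.0 == 0x3C00, +infinity == 0x7C00, -infinity == 0xFC00).
#[cfg(test)]
mod fallback_sanity {
    use super::*;

    #[test]
    fn roundtrip_one() {
        // 1.0 converts exactly in both directions and at both widths
        assert_eq!(f32_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f16_to_f32_fallback(0x3C00), 1.0);
        assert_eq!(f64_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f16_to_f64_fallback(0x3C00), 1.0);
    }

    #[test]
    fn overflow_saturates_to_infinity() {
        // 65536.0 is above the binary16 maximum of 65504.0
        assert_eq!(f32_to_f16_fallback(65536.0), 0x7C00);
        assert_eq!(f32_to_f16_fallback(f32::INFINITY), 0x7C00);
        assert_eq!(f32_to_f16_fallback(f32::NEG_INFINITY), 0xFC00);
    }
}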