| 1 | #![allow (dead_code, unused_imports)]
|
| 2 | use crate::leading_zeros::leading_zeros_u16;
|
| 3 | use core::mem;
|
| 4 |
|
| 5 | #[cfg (any(target_arch = "x86" , target_arch = "x86_64" ))]
|
| 6 | mod x86;
|
| 7 |
|
| 8 | #[cfg (target_arch = "aarch64" )]
|
| 9 | mod aarch64;
|
| 10 |
|
/// Selects a half-precision conversion implementation for the current target.
///
/// The matcher mirrors an `if`/`else` chain so call sites read naturally. The three expressions
/// are the x86 `f16c` implementation, the aarch64 `fp16` implementation, and the portable
/// software fallback. `cfg_if!` picks the first matching branch below:
///
/// 1. x86/x86_64 built with the `f16c` target feature: `$f16c`, unconditionally.
/// 2. aarch64 built with the `fp16` target feature: `$aarch64`, unconditionally.
/// 3. x86/x86_64 with the crate's `std` feature: runtime `f16c` detection, else `$fallback`.
/// 4. aarch64 with the crate's `std` feature: runtime `fp16` detection, else `$fallback`.
/// 5. Everything else (including `no_std` builds without the target feature): `$fallback`.
///
/// Branches whose `cfg` is false are not compiled. That is why `$aarch64` may name the
/// `aarch64` module, which is only declared on aarch64 targets (and likewise `$f16c`/`x86`).
macro_rules! convert_fn {
    (if x86_feature("f16c") { $f16c:expr }
    else if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Target feature enabled at compile time: call the intrinsics directly, no detection.
            if #[cfg(all(
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))] {
                $f16c
            }
            else if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64

            }

            // Otherwise detect the CPU feature at runtime, which needs std.
            else if #[cfg(all(
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64")
            ))] {
                use std::arch::is_x86_feature_detected;
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}
|
| 62 |
|
/// Converts an `f32` to raw half-precision bits.
///
/// Uses x86 `f16c` or aarch64 `fp16` when available (see `convert_fn!`), otherwise
/// `f32_to_f16_fallback`.
#[inline]
pub(crate) fn f32_to_f16(f: f32) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32_to_f16_fp16(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}
|
| 75 |
|
/// Converts an `f64` to raw half-precision bits.
///
/// Uses x86 `f16c` or aarch64 `fp16` when available (see `convert_fn!`), otherwise
/// `f64_to_f16_fallback`.
#[inline]
pub(crate) fn f64_to_f16(f: f64) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            // NOTE(review): this narrows to f32 first (`f as f32`) and then converts to f16,
            // i.e. it rounds twice. Some inputs just above an f16 rounding tie can therefore
            // round differently than a single f64 -> f16 conversion would. Confirm whether
            // that is acceptable here.
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64_to_f16_fp16(f) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}
|
| 88 |
|
/// Converts raw half-precision bits to an `f32`.
///
/// Uses x86 `f16c` or aarch64 `fp16` when available (see `convert_fn!`), otherwise
/// `f16_to_f32_fallback`.
#[inline]
pub(crate) fn f16_to_f32(i: u16) -> f32 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f32_fp16(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}
|
| 101 |
|
/// Converts raw half-precision bits to an `f64`.
///
/// The x86 path has no direct f16 -> f64 instruction here, so it converts to `f32` and then
/// widens with `as f64`. Other paths use aarch64 `fp16` or `f16_to_f64_fallback`.
#[inline]
pub(crate) fn f16_to_f64(i: u16) -> f64 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f64_fp16(i) }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}
|
| 114 |
|
/// Converts four `f32` values to raw half-precision bits in one call.
#[inline]
pub(crate) fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32x4_to_f16x4_fp16(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}
|
| 127 |
|
/// Converts four raw half-precision values to `f32` in one call.
#[inline]
pub(crate) fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f32x4_fp16(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}
|
| 140 |
|
/// Converts four `f64` values to raw half-precision bits in one call.
#[inline]
pub(crate) fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64x4_to_f16x4_fp16(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}
|
| 153 |
|
/// Converts four raw half-precision values to `f64` in one call.
#[inline]
pub(crate) fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f64x4_fp16(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}
|
| 166 |
|
/// Converts eight `f32` values to raw half-precision bits in one call.
///
/// The aarch64 `fp16` path has only a 4-lane kernel here, so it converts the two halves via
/// `convert_chunked_slice_4`.
#[inline]
pub(crate) fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f32x4_to_f16x4_fp16);
                result
            }
        } else {
            f32x8_to_f16x8_fallback(f)
        }
    }
}
|
| 184 |
|
/// Converts eight raw half-precision values to `f32` in one call.
///
/// The aarch64 `fp16` path converts the two 4-lane halves via `convert_chunked_slice_4`.
#[inline]
pub(crate) fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f32x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f32; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f32x4_fp16);
                result
            }
        } else {
            f16x8_to_f32x8_fallback(i)
        }
    }
}
|
| 202 |
|
/// Converts eight `f64` values to raw half-precision bits in one call.
///
/// The aarch64 `fp16` path converts the two 4-lane halves via `convert_chunked_slice_4`.
#[inline]
pub(crate) fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f64x4_to_f16x4_fp16);
                result
            }
        } else {
            f64x8_to_f16x8_fallback(f)
        }
    }
}
|
| 220 |
|
/// Converts eight raw half-precision values to `f64` in one call.
///
/// The aarch64 `fp16` path converts the two 4-lane halves via `convert_chunked_slice_4`.
#[inline]
pub(crate) fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f64x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f64; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f64x4_fp16);
                result
            }
        } else {
            f16x8_to_f64x8_fallback(i)
        }
    }
}
|
| 238 |
|
/// Converts every `f32` in `src` to raw half-precision bits in `dst`.
///
/// # Panics
///
/// Panics if `src.len() != dst.len()` (asserted by each chunking/fallback helper).
#[inline]
pub(crate) fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f32x8_to_f16x8_x86_f16c,
                x86::f32x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f32x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f32_to_f16_fallback)
        }
    }
}
|
| 252 |
|
/// Converts every raw half-precision value in `src` to `f32` in `dst`.
///
/// # Panics
///
/// Panics if `src.len() != dst.len()` (asserted by each chunking/fallback helper).
#[inline]
pub(crate) fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f32x8_x86_f16c,
                x86::f16x4_to_f32x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f32x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f32_fallback)
        }
    }
}
|
| 266 |
|
/// Converts every `f64` in `src` to raw half-precision bits in `dst`.
///
/// # Panics
///
/// Panics if `src.len() != dst.len()` (asserted by each chunking/fallback helper).
#[inline]
pub(crate) fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f64x8_to_f16x8_x86_f16c,
                x86::f64x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f64x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f64_to_f16_fallback)
        }
    }
}
|
| 280 |
|
/// Converts every raw half-precision value in `src` to `f64` in `dst`.
///
/// # Panics
///
/// Panics if `src.len() != dst.len()` (asserted by each chunking/fallback helper).
#[inline]
pub(crate) fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f64x8_x86_f16c,
                x86::f16x4_to_f64x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f64x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f64_fallback)
        }
    }
}
|
| 294 |
|
/// Selects a half-precision arithmetic implementation for the current target.
///
/// Only aarch64 `fp16` has a hardware path; every other target uses `$fallback`. Branch order:
///
/// 1. aarch64 built with the `fp16` target feature: `$aarch64`, unconditionally.
/// 2. aarch64 with the crate's `std` feature: runtime `fp16` detection, else `$fallback`.
/// 3. Everything else: `$fallback`.
///
/// Branches whose `cfg` is false are not compiled, so `$aarch64` may name the `aarch64`
/// module, which only exists on aarch64 targets.
macro_rules! math_fn {
    (if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Target feature enabled at compile time: call the intrinsics directly, no detection.
            if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Otherwise detect the CPU feature at runtime, which needs std.
            // (`not(target_feature = "fp16")` is already implied by the branch above.)
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
                not(target_feature = "fp16")
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}
|
| 328 |
|
/// Adds two raw half-precision values, returning raw half-precision bits.
#[inline]
pub(crate) fn add_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::add_f16_fp16(a, b) }
        } else {
            add_f16_fallback(a, b)
        }
    }
}
|
| 339 |
|
/// Computes `a - b` on raw half-precision values, returning raw half-precision bits.
#[inline]
pub(crate) fn subtract_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::subtract_f16_fp16(a, b) }
        } else {
            subtract_f16_fallback(a, b)
        }
    }
}
|
| 350 |
|
/// Multiplies two raw half-precision values, returning raw half-precision bits.
#[inline]
pub(crate) fn multiply_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::multiply_f16_fp16(a, b) }
        } else {
            multiply_f16_fallback(a, b)
        }
    }
}
|
| 361 |
|
/// Computes `a / b` on raw half-precision values, returning raw half-precision bits.
#[inline]
pub(crate) fn divide_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::divide_f16_fp16(a, b) }
        } else {
            divide_f16_fallback(a, b)
        }
    }
}
|
| 372 |
|
/// Computes `a % b` on raw half-precision values, returning raw half-precision bits.
///
/// There is no hardware path for remainder; it always uses the software fallback.
#[inline]
pub(crate) fn remainder_f16(a: u16, b: u16) -> u16 {
    remainder_f16_fallback(a, b)
}
|
| 377 |
|
| 378 | #[inline ]
|
| 379 | pub(crate) fn product_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
|
| 380 | math_fn! {
|
| 381 | if aarch64_feature("fp16" ) {
|
| 382 | iter.fold(0, |acc, x| unsafe { aarch64::multiply_f16_fp16(acc, x) })
|
| 383 | } else {
|
| 384 | product_f16_fallback(iter)
|
| 385 | }
|
| 386 | }
|
| 387 | }
|
| 388 |
|
/// Sums every raw half-precision value yielded by `iter`.
///
/// The hardware path folds from `0` (the bits of +0.0), so on that path an empty iterator
/// yields +0.0. The fallback accumulates in `f32` and narrows once at the end.
#[inline]
pub(crate) fn sum_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            iter.fold(0, |acc, x| unsafe { aarch64::add_f16_fp16(acc, x) })
        } else {
            sum_f16_fallback(iter)
        }
    }
}
|
| 399 |
|
/// Converts `src` into `dst` using an 8-lane kernel, with a 4-lane kernel for short tails.
///
/// Every full 8-element chunk goes through `fn8`. A leftover tail of 5–7 elements is padded
/// with `S::default()` to 8 lanes and sent through `fn8`. A tail of 1–4 elements is padded to
/// 4 lanes and sent through `fn4`. Only lanes that correspond to real input are written.
///
/// # Panics
///
/// Panics if `src` and `dst` have different lengths.
#[inline]
fn convert_chunked_slice_8<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    fn8: unsafe fn(&[S; 8]) -> [D; 8],
    fn4: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized
    let inputs = src.chunks_exact(8);
    let tail = inputs.remainder();
    let mut outputs = dst.chunks_exact_mut(8);

    for (input, output) in inputs.zip(&mut outputs) {
        // `chunks_exact(8)` only yields 8-element chunks, so this cannot fail.
        let lanes: &[S; 8] = input.try_into().unwrap();
        // SAFETY: relies on the caller having verified the kernel's CPU-feature requirements.
        let converted = unsafe { fn8(lanes) };
        output.copy_from_slice(&converted);
    }

    let count = tail.len();
    if count == 0 {
        return;
    }
    // Lengths are equal, so the destination tail also has `count` elements.
    let out_tail = outputs.into_remainder();
    if count <= 4 {
        let mut padded = [S::default(); 4];
        padded[..count].copy_from_slice(tail);
        // SAFETY: see above.
        let converted = unsafe { fn4(&padded) };
        out_tail.copy_from_slice(&converted[..count]);
    } else {
        let mut padded = [S::default(); 8];
        padded[..count].copy_from_slice(tail);
        // SAFETY: see above.
        let converted = unsafe { fn8(&padded) };
        out_tail.copy_from_slice(&converted[..count]);
    }
}
|
| 435 |
|
/// Converts `src` into `dst` four lanes at a time using the vector kernel `f`.
///
/// Full 4-element chunks are passed straight to `f`. A trailing partial chunk (1–3 elements)
/// is copied into a `Default`-padded 4-lane buffer and converted, and only the lanes that
/// correspond to real input are written back.
///
/// # Panics
///
/// Panics if `src` and `dst` have different lengths.
#[inline]
fn convert_chunked_slice_4<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    f: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(4);
    let mut dst_chunks = dst.chunks_exact_mut(4);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        // `chunks_exact(4)` only yields 4-element chunks, so this cannot fail.
        let chunk: &[S; 4] = s.try_into().unwrap();
        // SAFETY: relies on the caller having verified the kernel's CPU-feature requirements.
        // Bind the result to a local instead of borrowing a block-tail temporary.
        let converted = unsafe { f(chunk) };
        d.copy_from_slice(&converted);
    }

    // Process remainder
    if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        // SAFETY: see above.
        let converted = unsafe { f(&buf) };
        // Lengths are equal, so the destination remainder matches the source remainder.
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&converted[..dst_remainder.len()]);
    }
}
|
| 464 |
|
| 465 | /////////////// Fallbacks ////////////////
|
| 466 |
|
| 467 | // In the below functions, round to nearest, with ties to even.
|
| 468 | // Let us call the most significant bit that will be shifted out the round_bit.
|
| 469 | //
|
| 470 | // Round up if either
|
| 471 | // a) Removed part > tie.
|
| 472 | // (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
|
| 473 | // b) Removed part == tie, and retained part is odd.
|
| 474 | // (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
|
| 475 | // (If removed part == tie and retained part is even, do not round up.)
|
| 476 | // These two conditions can be combined into one:
|
| 477 | // (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
|
| 478 | // which can be simplified into
|
| 479 | // (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
|
| 480 |
|
/// Software conversion of an `f32` to raw half-precision bits.
///
/// Rounds to nearest, ties to even. Out-of-range magnitudes become signed infinity, values
/// too small for even the smallest subnormal become signed zero, and NaNs stay NaNs (the
/// quiet bit is forced on).
#[inline]
pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    let bits: u32 = unsafe { mem::transmute::<f32, u32>(value) };

    // Sign already moved into its half-precision position (bit 15).
    let sign_bit = (bits & 0x8000_0000u32) >> 16;
    let exp_field = bits & 0x7F80_0000u32;
    let frac = bits & 0x007F_FFFFu32;

    // All exponent bits set: infinity (frac == 0) or NaN (force the quiet bit).
    if exp_field == 0x7F80_0000u32 {
        let quiet = if frac != 0 { 0x0200u32 } else { 0 };
        return (sign_bit | 0x7C00u32 | quiet | (frac >> 13)) as u16;
    }

    // Re-bias the exponent from f32 (127) to f16 (15).
    let biased = ((exp_field >> 23) as i32) - 127 + 15;

    // Too large for half precision: signed infinity.
    if biased >= 0x1F {
        return (sign_bit | 0x7C00u32) as u16;
    }

    // Normal half-precision result: keep the top 10 fraction bits and round (see comment above
    // the fallbacks). A carry out of the mantissa correctly bumps the exponent, up to infinity.
    if biased > 0 {
        let round_bit = 0x0000_1000u32;
        let packed = sign_bit | ((biased as u32) << 10) | (frac >> 13);
        let round_up = (frac & round_bit) != 0 && (frac & (3 * round_bit - 1)) != 0;
        return (if round_up { packed + 1 } else { packed }) as u16;
    }

    // Subnormal result (or full underflow). The hidden bit becomes explicit and is shifted
    // down by `shift` places; beyond 24 places nothing can survive, even after rounding.
    let shift = 14 - biased;
    if shift > 24 {
        return sign_bit as u16;
    }
    let full = frac | 0x0080_0000u32;
    let round_bit = 1u32 << (shift - 1);
    let mut sub = full >> shift;
    if (full & round_bit) != 0 && (full & (3 * round_bit - 1)) != 0 {
        sub += 1;
    }
    // Subnormals have a zero exponent field.
    (sign_bit | sub) as u16
}
|
| 541 |
|
/// Software conversion of an `f64` to raw half-precision bits.
///
/// Rounds to nearest, ties to even. Out-of-range magnitudes become signed infinity, values
/// too small for even the smallest subnormal become signed zero, and NaNs stay NaNs (the
/// quiet bit is forced on).
#[inline]
pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    let val: u64 = unsafe { mem::transmute::<f64, u64>(value) };
    // Work on the upper 32 bits: sign, 11-bit exponent and the top 20 mantissa bits. The low
    // 32 mantissa bits can never appear in a half-precision result, but they still act as
    // sticky bits for rounding. They decide whether something that looks like an exact tie
    // in the upper word is really above it, so they must not simply be dropped.
    let x = (val >> 32) as u32;
    let low_sticky = (val as u32) != 0;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // A NaN whose payload lives only in the low 32 bits must still stay a NaN.
        let nan_bit = if man == 0 && !low_sticky { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions), counting the low word as sticky
        let round_bit = 1u32 << (10 - half_exp);
        if (man & round_bit) != 0 && ((man & (3 * round_bit - 1)) != 0 || low_sticky) {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions), counting the low word as sticky.
    // A carry out of the mantissa correctly bumps the exponent, up to infinity.
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && ((man & (3 * round_bit - 1)) != 0 || low_sticky) {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}
|
| 609 |
|
/// Software conversion of raw half-precision bits to an `f32`.
///
/// No rounding step is needed: the half-precision mantissa is only ever shifted left into the
/// wider `f32` mantissa. Zeros keep their sign, infinities stay infinities, NaNs become quiet
/// NaNs carrying the original payload in their high mantissa bits, and half-precision
/// subnormals are renormalized into `f32` normals.
#[inline]
pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        // Only the sign bit can be set: move it from bit 15 to bit 31.
        return unsafe { mem::transmute::<u32, f32>((i as u32) << 16) };
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute::<u32, f32>((half_sign << 16) | 0x7F80_0000u32) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute::<u32, f32>((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13))
            };
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by: the number of leading zeros inside the
        // 10-bit mantissa field (the top 6 bits of the u16 are always zero here). Assumes
        // `leading_zeros_u16` behaves like `u16::leading_zeros`.
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        // Shift the leading 1 up to bit 23 (the implicit bit) and mask it off.
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return unsafe { mem::transmute::<u32, f32>(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    unsafe { mem::transmute::<u32, f32>(sign | exp | man) }
}
|
| 657 | #[inline ]
|
| 658 | pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 {
|
| 659 | // Check for signed zero
|
| 660 | // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
|
| 661 | if i & 0x7FFFu16 == 0 {
|
| 662 | return unsafe { mem::transmute::<u64, f64>((i as u64) << 48) };
|
| 663 | }
|
| 664 |
|
| 665 | let half_sign = (i & 0x8000u16) as u64;
|
| 666 | let half_exp = (i & 0x7C00u16) as u64;
|
| 667 | let half_man = (i & 0x03FFu16) as u64;
|
| 668 |
|
| 669 | // Check for an infinity or NaN when all exponent bits set
|
| 670 | if half_exp == 0x7C00u64 {
|
| 671 | // Check for signed infinity if mantissa is zero
|
| 672 | if half_man == 0 {
|
| 673 | return unsafe {
|
| 674 | mem::transmute::<u64, f64>((half_sign << 48) | 0x7FF0_0000_0000_0000u64)
|
| 675 | };
|
| 676 | } else {
|
| 677 | // NaN, keep current mantissa but also set most significiant mantissa bit
|
| 678 | return unsafe {
|
| 679 | mem::transmute::<u64, f64>(
|
| 680 | (half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42),
|
| 681 | )
|
| 682 | };
|
| 683 | }
|
| 684 | }
|
| 685 |
|
| 686 | // Calculate double-precision components with adjusted exponent
|
| 687 | let sign = half_sign << 48;
|
| 688 | // Unbias exponent
|
| 689 | let unbiased_exp = ((half_exp as i64) >> 10) - 15;
|
| 690 |
|
| 691 | // Check for subnormals, which will be normalized by adjusting exponent
|
| 692 | if half_exp == 0 {
|
| 693 | // Calculate how much to adjust the exponent by
|
| 694 | let e = leading_zeros_u16(half_man as u16) - 6;
|
| 695 |
|
| 696 | // Rebias and adjust exponent
|
| 697 | let exp = ((1023 - 15 - e) as u64) << 52;
|
| 698 | let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
|
| 699 | return unsafe { mem::transmute::<u64, f64>(sign | exp | man) };
|
| 700 | }
|
| 701 |
|
| 702 | // Rebias exponent for a normalized normal
|
| 703 | let exp = ((unbiased_exp + 1023) as u64) << 52;
|
| 704 | let man = (half_man & 0x03FFu64) << 42;
|
| 705 | unsafe { mem::transmute::<u64, f64>(sign | exp | man) }
|
| 706 | }
|
| 707 |
|
| 708 | #[inline ]
|
| 709 | fn f16x4_to_f32x4_fallback(v: &[u16; 4]) -> [f32; 4] {
|
| 710 | [
|
| 711 | f16_to_f32_fallback(v[0]),
|
| 712 | f16_to_f32_fallback(v[1]),
|
| 713 | f16_to_f32_fallback(v[2]),
|
| 714 | f16_to_f32_fallback(v[3]),
|
| 715 | ]
|
| 716 | }
|
| 717 |
|
| 718 | #[inline ]
|
| 719 | fn f32x4_to_f16x4_fallback(v: &[f32; 4]) -> [u16; 4] {
|
| 720 | [
|
| 721 | f32_to_f16_fallback(v[0]),
|
| 722 | f32_to_f16_fallback(v[1]),
|
| 723 | f32_to_f16_fallback(v[2]),
|
| 724 | f32_to_f16_fallback(v[3]),
|
| 725 | ]
|
| 726 | }
|
| 727 |
|
| 728 | #[inline ]
|
| 729 | fn f16x4_to_f64x4_fallback(v: &[u16; 4]) -> [f64; 4] {
|
| 730 | [
|
| 731 | f16_to_f64_fallback(v[0]),
|
| 732 | f16_to_f64_fallback(v[1]),
|
| 733 | f16_to_f64_fallback(v[2]),
|
| 734 | f16_to_f64_fallback(v[3]),
|
| 735 | ]
|
| 736 | }
|
| 737 |
|
| 738 | #[inline ]
|
| 739 | fn f64x4_to_f16x4_fallback(v: &[f64; 4]) -> [u16; 4] {
|
| 740 | [
|
| 741 | f64_to_f16_fallback(v[0]),
|
| 742 | f64_to_f16_fallback(v[1]),
|
| 743 | f64_to_f16_fallback(v[2]),
|
| 744 | f64_to_f16_fallback(v[3]),
|
| 745 | ]
|
| 746 | }
|
| 747 |
|
| 748 | #[inline ]
|
| 749 | fn f16x8_to_f32x8_fallback(v: &[u16; 8]) -> [f32; 8] {
|
| 750 | [
|
| 751 | f16_to_f32_fallback(v[0]),
|
| 752 | f16_to_f32_fallback(v[1]),
|
| 753 | f16_to_f32_fallback(v[2]),
|
| 754 | f16_to_f32_fallback(v[3]),
|
| 755 | f16_to_f32_fallback(v[4]),
|
| 756 | f16_to_f32_fallback(v[5]),
|
| 757 | f16_to_f32_fallback(v[6]),
|
| 758 | f16_to_f32_fallback(v[7]),
|
| 759 | ]
|
| 760 | }
|
| 761 |
|
| 762 | #[inline ]
|
| 763 | fn f32x8_to_f16x8_fallback(v: &[f32; 8]) -> [u16; 8] {
|
| 764 | [
|
| 765 | f32_to_f16_fallback(v[0]),
|
| 766 | f32_to_f16_fallback(v[1]),
|
| 767 | f32_to_f16_fallback(v[2]),
|
| 768 | f32_to_f16_fallback(v[3]),
|
| 769 | f32_to_f16_fallback(v[4]),
|
| 770 | f32_to_f16_fallback(v[5]),
|
| 771 | f32_to_f16_fallback(v[6]),
|
| 772 | f32_to_f16_fallback(v[7]),
|
| 773 | ]
|
| 774 | }
|
| 775 |
|
| 776 | #[inline ]
|
| 777 | fn f16x8_to_f64x8_fallback(v: &[u16; 8]) -> [f64; 8] {
|
| 778 | [
|
| 779 | f16_to_f64_fallback(v[0]),
|
| 780 | f16_to_f64_fallback(v[1]),
|
| 781 | f16_to_f64_fallback(v[2]),
|
| 782 | f16_to_f64_fallback(v[3]),
|
| 783 | f16_to_f64_fallback(v[4]),
|
| 784 | f16_to_f64_fallback(v[5]),
|
| 785 | f16_to_f64_fallback(v[6]),
|
| 786 | f16_to_f64_fallback(v[7]),
|
| 787 | ]
|
| 788 | }
|
| 789 |
|
| 790 | #[inline ]
|
| 791 | fn f64x8_to_f16x8_fallback(v: &[f64; 8]) -> [u16; 8] {
|
| 792 | [
|
| 793 | f64_to_f16_fallback(v[0]),
|
| 794 | f64_to_f16_fallback(v[1]),
|
| 795 | f64_to_f16_fallback(v[2]),
|
| 796 | f64_to_f16_fallback(v[3]),
|
| 797 | f64_to_f16_fallback(v[4]),
|
| 798 | f64_to_f16_fallback(v[5]),
|
| 799 | f64_to_f16_fallback(v[6]),
|
| 800 | f64_to_f16_fallback(v[7]),
|
| 801 | ]
|
| 802 | }
|
| 803 |
|
/// Converts `src` into `dst` element by element with the scalar function `f`.
///
/// # Panics
///
/// Panics if `src` and `dst` have different lengths.
#[inline]
fn slice_fallback<S: Copy, D>(src: &[S], dst: &mut [D], f: fn(S) -> D) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}
|
| 811 |
|
| 812 | #[inline ]
|
| 813 | fn add_f16_fallback(a: u16, b: u16) -> u16 {
|
| 814 | f32_to_f16(f16_to_f32(a) + f16_to_f32(b))
|
| 815 | }
|
| 816 |
|
| 817 | #[inline ]
|
| 818 | fn subtract_f16_fallback(a: u16, b: u16) -> u16 {
|
| 819 | f32_to_f16(f16_to_f32(a) - f16_to_f32(b))
|
| 820 | }
|
| 821 |
|
| 822 | #[inline ]
|
| 823 | fn multiply_f16_fallback(a: u16, b: u16) -> u16 {
|
| 824 | f32_to_f16(f16_to_f32(a) * f16_to_f32(b))
|
| 825 | }
|
| 826 |
|
| 827 | #[inline ]
|
| 828 | fn divide_f16_fallback(a: u16, b: u16) -> u16 {
|
| 829 | f32_to_f16(f16_to_f32(a) / f16_to_f32(b))
|
| 830 | }
|
| 831 |
|
| 832 | #[inline ]
|
| 833 | fn remainder_f16_fallback(a: u16, b: u16) -> u16 {
|
| 834 | f32_to_f16(f16_to_f32(a) % f16_to_f32(b))
|
| 835 | }
|
| 836 |
|
| 837 | #[inline ]
|
| 838 | fn product_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
|
| 839 | f32_to_f16(iter.map(f16_to_f32).product())
|
| 840 | }
|
| 841 |
|
| 842 | #[inline ]
|
| 843 | fn sum_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
|
| 844 | f32_to_f16(iter.map(f16_to_f32).sum())
|
| 845 | }
|
| 846 |
|
| 847 | // TODO SIMD arithmetic
|
| 848 | |