1 | #[cfg(feature = "bytemuck")] |
---|---|

2 | use bytemuck::{Pod, Zeroable}; |

3 | use core::{ |

4 | cmp::Ordering, |

5 | fmt::{ |

6 | Binary, Debug, Display, Error, Formatter, LowerExp, LowerHex, Octal, UpperExp, UpperHex, |

7 | }, |

8 | iter::{Product, Sum}, |

9 | num::{FpCategory, ParseFloatError}, |

10 | ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign}, |

11 | str::FromStr, |

12 | }; |

13 | #[cfg(feature = "serde")] |

14 | use serde::{Deserialize, Serialize}; |

15 | #[cfg(feature = "zerocopy")] |

16 | use zerocopy::{AsBytes, FromBytes}; |

17 | |

18 | pub(crate) mod convert; |

19 | |

/// A 16-bit floating point type implementing the [`bfloat16`] format.
///
/// The [`bfloat16`] floating point format is a truncated 16-bit version of the IEEE 754 standard
/// `binary32`, a.k.a. [`f32`]. [`bf16`] has approximately the same dynamic range as [`f32`] by
/// having a lower precision than [`f16`][crate::f16]. While [`f16`][crate::f16] has a precision of
/// 11 bits, [`bf16`] has a precision of only 8 bits.
///
/// [`bf16`] is primarily intended for compact storage rather than calculations. The arithmetic
/// operators implemented for it (`+`, `-`, `*`, `/`, `%`, their assigning forms, and the
/// [`Sum`][core::iter::Sum]/[`Product`][core::iter::Product] iterator folds) work by converting
/// the operands to [`f32`], computing there, and rounding the result back to [`bf16`]. For longer
/// computations, prefer working in [`f32`] or a higher-precision type and converting to/from
/// [`bf16`] only as necessary, so rounding happens once instead of after every step.
///
/// The value is stored as its raw 16-bit pattern (sign bit, 8 exponent bits, 7 mantissa bits);
/// `#[repr(transparent)]` guarantees the same layout as a `u16`.
///
/// [`bfloat16`]: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Default)]
#[repr(transparent)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))]
#[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))]
pub struct bf16(u16);

39 | |

40 | impl bf16 { |

41 | /// Constructs a [`bf16`] value from the raw bits. |

42 | #[inline] |

43 | pub const fn from_bits(bits: u16) -> bf16 { |

44 | bf16(bits) |

45 | } |

46 | |

47 | /// Constructs a [`bf16`] value from a 32-bit floating point value. |

48 | /// |

49 | /// If the 32-bit value is too large to fit, ±∞ will result. NaN values are preserved. |

50 | /// Subnormal values that are too tiny to be represented will result in ±0. All other values |

51 | /// are truncated and rounded to the nearest representable value. |

52 | #[inline] |

53 | pub fn from_f32(value: f32) -> bf16 { |

54 | bf16(convert::f32_to_bf16(value)) |

55 | } |

56 | |

57 | /// Constructs a [`bf16`] value from a 64-bit floating point value. |

58 | /// |

59 | /// If the 64-bit value is to large to fit, ±∞ will result. NaN values are preserved. |

60 | /// 64-bit subnormal values are too tiny to be represented and result in ±0. Exponents that |

61 | /// underflow the minimum exponent will result in subnormals or ±0. All other values are |

62 | /// truncated and rounded to the nearest representable value. |

63 | #[inline] |

64 | pub fn from_f64(value: f64) -> bf16 { |

65 | bf16(convert::f64_to_bf16(value)) |

66 | } |

67 | |

68 | /// Converts a [`bf16`] into the underlying bit representation. |

69 | #[inline] |

70 | pub const fn to_bits(self) -> u16 { |

71 | self.0 |

72 | } |

73 | |

74 | /// Returns the memory representation of the underlying bit representation as a byte array in |

75 | /// little-endian byte order. |

76 | /// |

77 | /// # Examples |

78 | /// |

79 | /// ```rust |

80 | /// # use half::prelude::*; |

81 | /// let bytes = bf16::from_f32(12.5).to_le_bytes(); |

82 | /// assert_eq!(bytes, [0x48, 0x41]); |

83 | /// ``` |

84 | #[inline] |

85 | pub const fn to_le_bytes(self) -> [u8; 2] { |

86 | self.0.to_le_bytes() |

87 | } |

88 | |

89 | /// Returns the memory representation of the underlying bit representation as a byte array in |

90 | /// big-endian (network) byte order. |

91 | /// |

92 | /// # Examples |

93 | /// |

94 | /// ```rust |

95 | /// # use half::prelude::*; |

96 | /// let bytes = bf16::from_f32(12.5).to_be_bytes(); |

97 | /// assert_eq!(bytes, [0x41, 0x48]); |

98 | /// ``` |

99 | #[inline] |

100 | pub const fn to_be_bytes(self) -> [u8; 2] { |

101 | self.0.to_be_bytes() |

102 | } |

103 | |

104 | /// Returns the memory representation of the underlying bit representation as a byte array in |

105 | /// native byte order. |

106 | /// |

107 | /// As the target platform's native endianness is used, portable code should use |

108 | /// [`to_be_bytes`][bf16::to_be_bytes] or [`to_le_bytes`][bf16::to_le_bytes], as appropriate, |

109 | /// instead. |

110 | /// |

111 | /// # Examples |

112 | /// |

113 | /// ```rust |

114 | /// # use half::prelude::*; |

115 | /// let bytes = bf16::from_f32(12.5).to_ne_bytes(); |

116 | /// assert_eq!(bytes, if cfg!(target_endian = "big") { |

117 | /// [0x41, 0x48] |

118 | /// } else { |

119 | /// [0x48, 0x41] |

120 | /// }); |

121 | /// ``` |

122 | #[inline] |

123 | pub const fn to_ne_bytes(self) -> [u8; 2] { |

124 | self.0.to_ne_bytes() |

125 | } |

126 | |

127 | /// Creates a floating point value from its representation as a byte array in little endian. |

128 | /// |

129 | /// # Examples |

130 | /// |

131 | /// ```rust |

132 | /// # use half::prelude::*; |

133 | /// let value = bf16::from_le_bytes([0x48, 0x41]); |

134 | /// assert_eq!(value, bf16::from_f32(12.5)); |

135 | /// ``` |

136 | #[inline] |

137 | pub const fn from_le_bytes(bytes: [u8; 2]) -> bf16 { |

138 | bf16::from_bits(u16::from_le_bytes(bytes)) |

139 | } |

140 | |

141 | /// Creates a floating point value from its representation as a byte array in big endian. |

142 | /// |

143 | /// # Examples |

144 | /// |

145 | /// ```rust |

146 | /// # use half::prelude::*; |

147 | /// let value = bf16::from_be_bytes([0x41, 0x48]); |

148 | /// assert_eq!(value, bf16::from_f32(12.5)); |

149 | /// ``` |

150 | #[inline] |

151 | pub const fn from_be_bytes(bytes: [u8; 2]) -> bf16 { |

152 | bf16::from_bits(u16::from_be_bytes(bytes)) |

153 | } |

154 | |

155 | /// Creates a floating point value from its representation as a byte array in native endian. |

156 | /// |

157 | /// As the target platform's native endianness is used, portable code likely wants to use |

158 | /// [`from_be_bytes`][bf16::from_be_bytes] or [`from_le_bytes`][bf16::from_le_bytes], as |

159 | /// appropriate instead. |

160 | /// |

161 | /// # Examples |

162 | /// |

163 | /// ```rust |

164 | /// # use half::prelude::*; |

165 | /// let value = bf16::from_ne_bytes(if cfg!(target_endian = "big") { |

166 | /// [0x41, 0x48] |

167 | /// } else { |

168 | /// [0x48, 0x41] |

169 | /// }); |

170 | /// assert_eq!(value, bf16::from_f32(12.5)); |

171 | /// ``` |

172 | #[inline] |

173 | pub const fn from_ne_bytes(bytes: [u8; 2]) -> bf16 { |

174 | bf16::from_bits(u16::from_ne_bytes(bytes)) |

175 | } |

176 | |

177 | /// Converts a [`bf16`] value into an [`f32`] value. |

178 | /// |

179 | /// This conversion is lossless as all values can be represented exactly in [`f32`]. |

180 | #[inline] |

181 | pub fn to_f32(self) -> f32 { |

182 | convert::bf16_to_f32(self.0) |

183 | } |

184 | |

185 | /// Converts a [`bf16`] value into an [`f64`] value. |

186 | /// |

187 | /// This conversion is lossless as all values can be represented exactly in [`f64`]. |

188 | #[inline] |

189 | pub fn to_f64(self) -> f64 { |

190 | convert::bf16_to_f64(self.0) |

191 | } |

192 | |

193 | /// Returns `true` if this value is NaN and `false` otherwise. |

194 | /// |

195 | /// # Examples |

196 | /// |

197 | /// ```rust |

198 | /// # use half::prelude::*; |

199 | /// |

200 | /// let nan = bf16::NAN; |

201 | /// let f = bf16::from_f32(7.0_f32); |

202 | /// |

203 | /// assert!(nan.is_nan()); |

204 | /// assert!(!f.is_nan()); |

205 | /// ``` |

206 | #[inline] |

207 | pub const fn is_nan(self) -> bool { |

208 | self.0 & 0x7FFFu16 > 0x7F80u16 |

209 | } |

210 | |

211 | /// Returns `true` if this value is ±∞ and `false` otherwise. |

212 | /// |

213 | /// # Examples |

214 | /// |

215 | /// ```rust |

216 | /// # use half::prelude::*; |

217 | /// |

218 | /// let f = bf16::from_f32(7.0f32); |

219 | /// let inf = bf16::INFINITY; |

220 | /// let neg_inf = bf16::NEG_INFINITY; |

221 | /// let nan = bf16::NAN; |

222 | /// |

223 | /// assert!(!f.is_infinite()); |

224 | /// assert!(!nan.is_infinite()); |

225 | /// |

226 | /// assert!(inf.is_infinite()); |

227 | /// assert!(neg_inf.is_infinite()); |

228 | /// ``` |

229 | #[inline] |

230 | pub const fn is_infinite(self) -> bool { |

231 | self.0 & 0x7FFFu16 == 0x7F80u16 |

232 | } |

233 | |

234 | /// Returns `true` if this number is neither infinite nor NaN. |

235 | /// |

236 | /// # Examples |

237 | /// |

238 | /// ```rust |

239 | /// # use half::prelude::*; |

240 | /// |

241 | /// let f = bf16::from_f32(7.0f32); |

242 | /// let inf = bf16::INFINITY; |

243 | /// let neg_inf = bf16::NEG_INFINITY; |

244 | /// let nan = bf16::NAN; |

245 | /// |

246 | /// assert!(f.is_finite()); |

247 | /// |

248 | /// assert!(!nan.is_finite()); |

249 | /// assert!(!inf.is_finite()); |

250 | /// assert!(!neg_inf.is_finite()); |

251 | /// ``` |

252 | #[inline] |

253 | pub const fn is_finite(self) -> bool { |

254 | self.0 & 0x7F80u16 != 0x7F80u16 |

255 | } |

256 | |

257 | /// Returns `true` if the number is neither zero, infinite, subnormal, or NaN. |

258 | /// |

259 | /// # Examples |

260 | /// |

261 | /// ```rust |

262 | /// # use half::prelude::*; |

263 | /// |

264 | /// let min = bf16::MIN_POSITIVE; |

265 | /// let max = bf16::MAX; |

266 | /// let lower_than_min = bf16::from_f32(1.0e-39_f32); |

267 | /// let zero = bf16::from_f32(0.0_f32); |

268 | /// |

269 | /// assert!(min.is_normal()); |

270 | /// assert!(max.is_normal()); |

271 | /// |

272 | /// assert!(!zero.is_normal()); |

273 | /// assert!(!bf16::NAN.is_normal()); |

274 | /// assert!(!bf16::INFINITY.is_normal()); |

275 | /// // Values between 0 and `min` are subnormal. |

276 | /// assert!(!lower_than_min.is_normal()); |

277 | /// ``` |

278 | #[inline] |

279 | pub const fn is_normal(self) -> bool { |

280 | let exp = self.0 & 0x7F80u16; |

281 | exp != 0x7F80u16 && exp != 0 |

282 | } |

283 | |

284 | /// Returns the floating point category of the number. |

285 | /// |

286 | /// If only one property is going to be tested, it is generally faster to use the specific |

287 | /// predicate instead. |

288 | /// |

289 | /// # Examples |

290 | /// |

291 | /// ```rust |

292 | /// use std::num::FpCategory; |

293 | /// # use half::prelude::*; |

294 | /// |

295 | /// let num = bf16::from_f32(12.4_f32); |

296 | /// let inf = bf16::INFINITY; |

297 | /// |

298 | /// assert_eq!(num.classify(), FpCategory::Normal); |

299 | /// assert_eq!(inf.classify(), FpCategory::Infinite); |

300 | /// ``` |

301 | pub const fn classify(self) -> FpCategory { |

302 | let exp = self.0 & 0x7F80u16; |

303 | let man = self.0 & 0x007Fu16; |

304 | match (exp, man) { |

305 | (0, 0) => FpCategory::Zero, |

306 | (0, _) => FpCategory::Subnormal, |

307 | (0x7F80u16, 0) => FpCategory::Infinite, |

308 | (0x7F80u16, _) => FpCategory::Nan, |

309 | _ => FpCategory::Normal, |

310 | } |

311 | } |

312 | |

313 | /// Returns a number that represents the sign of `self`. |

314 | /// |

315 | /// * 1.0 if the number is positive, +0.0 or [`INFINITY`][bf16::INFINITY] |

316 | /// * −1.0 if the number is negative, −0.0` or [`NEG_INFINITY`][bf16::NEG_INFINITY] |

317 | /// * [`NAN`][bf16::NAN] if the number is NaN |

318 | /// |

319 | /// # Examples |

320 | /// |

321 | /// ```rust |

322 | /// # use half::prelude::*; |

323 | /// |

324 | /// let f = bf16::from_f32(3.5_f32); |

325 | /// |

326 | /// assert_eq!(f.signum(), bf16::from_f32(1.0)); |

327 | /// assert_eq!(bf16::NEG_INFINITY.signum(), bf16::from_f32(-1.0)); |

328 | /// |

329 | /// assert!(bf16::NAN.signum().is_nan()); |

330 | /// ``` |

331 | pub const fn signum(self) -> bf16 { |

332 | if self.is_nan() { |

333 | self |

334 | } else if self.0 & 0x8000u16 != 0 { |

335 | Self::NEG_ONE |

336 | } else { |

337 | Self::ONE |

338 | } |

339 | } |

340 | |

341 | /// Returns `true` if and only if `self` has a positive sign, including +0.0, NaNs with a |

342 | /// positive sign bit and +∞. |

343 | /// |

344 | /// # Examples |

345 | /// |

346 | /// ```rust |

347 | /// # use half::prelude::*; |

348 | /// |

349 | /// let nan = bf16::NAN; |

350 | /// let f = bf16::from_f32(7.0_f32); |

351 | /// let g = bf16::from_f32(-7.0_f32); |

352 | /// |

353 | /// assert!(f.is_sign_positive()); |

354 | /// assert!(!g.is_sign_positive()); |

355 | /// // NaN can be either positive or negative |

356 | /// assert!(nan.is_sign_positive() != nan.is_sign_negative()); |

357 | /// ``` |

358 | #[inline] |

359 | pub const fn is_sign_positive(self) -> bool { |

360 | self.0 & 0x8000u16 == 0 |

361 | } |

362 | |

363 | /// Returns `true` if and only if `self` has a negative sign, including −0.0, NaNs with a |

364 | /// negative sign bit and −∞. |

365 | /// |

366 | /// # Examples |

367 | /// |

368 | /// ```rust |

369 | /// # use half::prelude::*; |

370 | /// |

371 | /// let nan = bf16::NAN; |

372 | /// let f = bf16::from_f32(7.0f32); |

373 | /// let g = bf16::from_f32(-7.0f32); |

374 | /// |

375 | /// assert!(!f.is_sign_negative()); |

376 | /// assert!(g.is_sign_negative()); |

377 | /// // NaN can be either positive or negative |

378 | /// assert!(nan.is_sign_positive() != nan.is_sign_negative()); |

379 | /// ``` |

380 | #[inline] |

381 | pub const fn is_sign_negative(self) -> bool { |

382 | self.0 & 0x8000u16 != 0 |

383 | } |

384 | |

385 | /// Returns a number composed of the magnitude of `self` and the sign of `sign`. |

386 | /// |

387 | /// Equal to `self` if the sign of `self` and `sign` are the same, otherwise equal to `-self`. |

388 | /// If `self` is NaN, then NaN with the sign of `sign` is returned. |

389 | /// |

390 | /// # Examples |

391 | /// |

392 | /// ``` |

393 | /// # use half::prelude::*; |

394 | /// let f = bf16::from_f32(3.5); |

395 | /// |

396 | /// assert_eq!(f.copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5)); |

397 | /// assert_eq!(f.copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5)); |

398 | /// assert_eq!((-f).copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5)); |

399 | /// assert_eq!((-f).copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5)); |

400 | /// |

401 | /// assert!(bf16::NAN.copysign(bf16::from_f32(1.0)).is_nan()); |

402 | /// ``` |

403 | #[inline] |

404 | pub const fn copysign(self, sign: bf16) -> bf16 { |

405 | bf16((sign.0 & 0x8000u16) | (self.0 & 0x7FFFu16)) |

406 | } |

407 | |

408 | /// Returns the maximum of the two numbers. |

409 | /// |

410 | /// If one of the arguments is NaN, then the other argument is returned. |

411 | /// |

412 | /// # Examples |

413 | /// |

414 | /// ``` |

415 | /// # use half::prelude::*; |

416 | /// let x = bf16::from_f32(1.0); |

417 | /// let y = bf16::from_f32(2.0); |

418 | /// |

419 | /// assert_eq!(x.max(y), y); |

420 | /// ``` |

421 | #[inline] |

422 | pub fn max(self, other: bf16) -> bf16 { |

423 | if other > self && !other.is_nan() { |

424 | other |

425 | } else { |

426 | self |

427 | } |

428 | } |

429 | |

430 | /// Returns the minimum of the two numbers. |

431 | /// |

432 | /// If one of the arguments is NaN, then the other argument is returned. |

433 | /// |

434 | /// # Examples |

435 | /// |

436 | /// ``` |

437 | /// # use half::prelude::*; |

438 | /// let x = bf16::from_f32(1.0); |

439 | /// let y = bf16::from_f32(2.0); |

440 | /// |

441 | /// assert_eq!(x.min(y), x); |

442 | /// ``` |

443 | #[inline] |

444 | pub fn min(self, other: bf16) -> bf16 { |

445 | if other < self && !other.is_nan() { |

446 | other |

447 | } else { |

448 | self |

449 | } |

450 | } |

451 | |

452 | /// Restrict a value to a certain interval unless it is NaN. |

453 | /// |

454 | /// Returns `max` if `self` is greater than `max`, and `min` if `self` is less than `min`. |

455 | /// Otherwise this returns `self`. |

456 | /// |

457 | /// Note that this function returns NaN if the initial value was NaN as well. |

458 | /// |

459 | /// # Panics |

460 | /// Panics if `min > max`, `min` is NaN, or `max` is NaN. |

461 | /// |

462 | /// # Examples |

463 | /// |

464 | /// ``` |

465 | /// # use half::prelude::*; |

466 | /// assert!(bf16::from_f32(-3.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(-2.0)); |

467 | /// assert!(bf16::from_f32(0.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(0.0)); |

468 | /// assert!(bf16::from_f32(2.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(1.0)); |

469 | /// assert!(bf16::NAN.clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)).is_nan()); |

470 | /// ``` |

471 | #[inline] |

472 | pub fn clamp(self, min: bf16, max: bf16) -> bf16 { |

473 | assert!(min <= max); |

474 | let mut x = self; |

475 | if x < min { |

476 | x = min; |

477 | } |

478 | if x > max { |

479 | x = max; |

480 | } |

481 | x |

482 | } |

483 | |

484 | /// Approximate number of [`bf16`] significant digits in base 10 |

485 | pub const DIGITS: u32 = 2; |

486 | /// [`bf16`] |

487 | /// [machine epsilon](https://en.wikipedia.org/wiki/Machine_epsilon) value |

488 | /// |

489 | /// This is the difference between 1.0 and the next largest representable number. |

490 | pub const EPSILON: bf16 = bf16(0x3C00u16); |

491 | /// [`bf16`] positive Infinity (+∞) |

492 | pub const INFINITY: bf16 = bf16(0x7F80u16); |

493 | /// Number of [`bf16`] significant digits in base 2 |

494 | pub const MANTISSA_DIGITS: u32 = 8; |

495 | /// Largest finite [`bf16`] value |

496 | pub const MAX: bf16 = bf16(0x7F7F); |

497 | /// Maximum possible [`bf16`] power of 10 exponent |

498 | pub const MAX_10_EXP: i32 = 38; |

499 | /// Maximum possible [`bf16`] power of 2 exponent |

500 | pub const MAX_EXP: i32 = 128; |

501 | /// Smallest finite [`bf16`] value |

502 | pub const MIN: bf16 = bf16(0xFF7F); |

503 | /// Minimum possible normal [`bf16`] power of 10 exponent |

504 | pub const MIN_10_EXP: i32 = -37; |

505 | /// One greater than the minimum possible normal [`bf16`] power of 2 exponent |

506 | pub const MIN_EXP: i32 = -125; |

507 | /// Smallest positive normal [`bf16`] value |

508 | pub const MIN_POSITIVE: bf16 = bf16(0x0080u16); |

509 | /// [`bf16`] Not a Number (NaN) |

510 | pub const NAN: bf16 = bf16(0x7FC0u16); |

511 | /// [`bf16`] negative infinity (-∞). |

512 | pub const NEG_INFINITY: bf16 = bf16(0xFF80u16); |

513 | /// The radix or base of the internal representation of [`bf16`] |

514 | pub const RADIX: u32 = 2; |

515 | |

516 | /// Minimum positive subnormal [`bf16`] value |

517 | pub const MIN_POSITIVE_SUBNORMAL: bf16 = bf16(0x0001u16); |

518 | /// Maximum subnormal [`bf16`] value |

519 | pub const MAX_SUBNORMAL: bf16 = bf16(0x007Fu16); |

520 | |

521 | /// [`bf16`] 1 |

522 | pub const ONE: bf16 = bf16(0x3F80u16); |

523 | /// [`bf16`] 0 |

524 | pub const ZERO: bf16 = bf16(0x0000u16); |

525 | /// [`bf16`] -0 |

526 | pub const NEG_ZERO: bf16 = bf16(0x8000u16); |

527 | /// [`bf16`] -1 |

528 | pub const NEG_ONE: bf16 = bf16(0xBF80u16); |

529 | |

530 | /// [`bf16`] Euler's number (ℯ) |

531 | pub const E: bf16 = bf16(0x402Eu16); |

532 | /// [`bf16`] Archimedes' constant (π) |

533 | pub const PI: bf16 = bf16(0x4049u16); |

534 | /// [`bf16`] 1/π |

535 | pub const FRAC_1_PI: bf16 = bf16(0x3EA3u16); |

536 | /// [`bf16`] 1/√2 |

537 | pub const FRAC_1_SQRT_2: bf16 = bf16(0x3F35u16); |

538 | /// [`bf16`] 2/π |

539 | pub const FRAC_2_PI: bf16 = bf16(0x3F23u16); |

540 | /// [`bf16`] 2/√π |

541 | pub const FRAC_2_SQRT_PI: bf16 = bf16(0x3F90u16); |

542 | /// [`bf16`] π/2 |

543 | pub const FRAC_PI_2: bf16 = bf16(0x3FC9u16); |

544 | /// [`bf16`] π/3 |

545 | pub const FRAC_PI_3: bf16 = bf16(0x3F86u16); |

546 | /// [`bf16`] π/4 |

547 | pub const FRAC_PI_4: bf16 = bf16(0x3F49u16); |

548 | /// [`bf16`] π/6 |

549 | pub const FRAC_PI_6: bf16 = bf16(0x3F06u16); |

550 | /// [`bf16`] π/8 |

551 | pub const FRAC_PI_8: bf16 = bf16(0x3EC9u16); |

552 | /// [`bf16`] 𝗅𝗇 10 |

553 | pub const LN_10: bf16 = bf16(0x4013u16); |

554 | /// [`bf16`] 𝗅𝗇 2 |

555 | pub const LN_2: bf16 = bf16(0x3F31u16); |

556 | /// [`bf16`] 𝗅𝗈𝗀₁₀ℯ |

557 | pub const LOG10_E: bf16 = bf16(0x3EDEu16); |

558 | /// [`bf16`] 𝗅𝗈𝗀₁₀2 |

559 | pub const LOG10_2: bf16 = bf16(0x3E9Au16); |

560 | /// [`bf16`] 𝗅𝗈𝗀₂ℯ |

561 | pub const LOG2_E: bf16 = bf16(0x3FB9u16); |

562 | /// [`bf16`] 𝗅𝗈𝗀₂10 |

563 | pub const LOG2_10: bf16 = bf16(0x4055u16); |

564 | /// [`bf16`] √2 |

565 | pub const SQRT_2: bf16 = bf16(0x3FB5u16); |

566 | } |

567 | |

568 | impl From<bf16> for f32 { |

569 | #[inline] |

570 | fn from(x: bf16) -> f32 { |

571 | x.to_f32() |

572 | } |

573 | } |

574 | |

575 | impl From<bf16> for f64 { |

576 | #[inline] |

577 | fn from(x: bf16) -> f64 { |

578 | x.to_f64() |

579 | } |

580 | } |

581 | |

582 | impl From<i8> for bf16 { |

583 | #[inline] |

584 | fn from(x: i8) -> bf16 { |

585 | // Convert to f32, then to bf16 |

586 | bf16::from_f32(f32::from(x)) |

587 | } |

588 | } |

589 | |

590 | impl From<u8> for bf16 { |

591 | #[inline] |

592 | fn from(x: u8) -> bf16 { |

593 | // Convert to f32, then to f16 |

594 | bf16::from_f32(f32::from(x)) |

595 | } |

596 | } |

597 | |

598 | impl PartialEq for bf16 { |

599 | fn eq(&self, other: &bf16) -> bool { |

600 | if self.is_nan() || other.is_nan() { |

601 | false |

602 | } else { |

603 | (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0) |

604 | } |

605 | } |

606 | } |

607 | |

608 | impl PartialOrd for bf16 { |

609 | fn partial_cmp(&self, other: &bf16) -> Option<Ordering> { |

610 | if self.is_nan() || other.is_nan() { |

611 | None |

612 | } else { |

613 | let neg = self.0 & 0x8000u16 != 0; |

614 | let other_neg = other.0 & 0x8000u16 != 0; |

615 | match (neg, other_neg) { |

616 | (false, false) => Some(self.0.cmp(&other.0)), |

617 | (false, true) => { |

618 | if (self.0 | other.0) & 0x7FFFu16 == 0 { |

619 | Some(Ordering::Equal) |

620 | } else { |

621 | Some(Ordering::Greater) |

622 | } |

623 | } |

624 | (true, false) => { |

625 | if (self.0 | other.0) & 0x7FFFu16 == 0 { |

626 | Some(Ordering::Equal) |

627 | } else { |

628 | Some(Ordering::Less) |

629 | } |

630 | } |

631 | (true, true) => Some(other.0.cmp(&self.0)), |

632 | } |

633 | } |

634 | } |

635 | |

636 | fn lt(&self, other: &bf16) -> bool { |

637 | if self.is_nan() || other.is_nan() { |

638 | false |

639 | } else { |

640 | let neg = self.0 & 0x8000u16 != 0; |

641 | let other_neg = other.0 & 0x8000u16 != 0; |

642 | match (neg, other_neg) { |

643 | (false, false) => self.0 < other.0, |

644 | (false, true) => false, |

645 | (true, false) => (self.0 | other.0) & 0x7FFFu16 != 0, |

646 | (true, true) => self.0 > other.0, |

647 | } |

648 | } |

649 | } |

650 | |

651 | fn le(&self, other: &bf16) -> bool { |

652 | if self.is_nan() || other.is_nan() { |

653 | false |

654 | } else { |

655 | let neg = self.0 & 0x8000u16 != 0; |

656 | let other_neg = other.0 & 0x8000u16 != 0; |

657 | match (neg, other_neg) { |

658 | (false, false) => self.0 <= other.0, |

659 | (false, true) => (self.0 | other.0) & 0x7FFFu16 == 0, |

660 | (true, false) => true, |

661 | (true, true) => self.0 >= other.0, |

662 | } |

663 | } |

664 | } |

665 | |

666 | fn gt(&self, other: &bf16) -> bool { |

667 | if self.is_nan() || other.is_nan() { |

668 | false |

669 | } else { |

670 | let neg = self.0 & 0x8000u16 != 0; |

671 | let other_neg = other.0 & 0x8000u16 != 0; |

672 | match (neg, other_neg) { |

673 | (false, false) => self.0 > other.0, |

674 | (false, true) => (self.0 | other.0) & 0x7FFFu16 != 0, |

675 | (true, false) => false, |

676 | (true, true) => self.0 < other.0, |

677 | } |

678 | } |

679 | } |

680 | |

681 | fn ge(&self, other: &bf16) -> bool { |

682 | if self.is_nan() || other.is_nan() { |

683 | false |

684 | } else { |

685 | let neg = self.0 & 0x8000u16 != 0; |

686 | let other_neg = other.0 & 0x8000u16 != 0; |

687 | match (neg, other_neg) { |

688 | (false, false) => self.0 >= other.0, |

689 | (false, true) => true, |

690 | (true, false) => (self.0 | other.0) & 0x7FFFu16 == 0, |

691 | (true, true) => self.0 <= other.0, |

692 | } |

693 | } |

694 | } |

695 | } |

696 | |

697 | impl FromStr for bf16 { |

698 | type Err = ParseFloatError; |

699 | fn from_str(src: &str) -> Result<bf16, ParseFloatError> { |

700 | f32::from_str(src).map(bf16::from_f32) |

701 | } |

702 | } |

703 | |

704 | impl Debug for bf16 { |

705 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

706 | write!(f, "{:?}", self.to_f32()) |

707 | } |

708 | } |

709 | |

710 | impl Display for bf16 { |

711 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

712 | write!(f, "{}", self.to_f32()) |

713 | } |

714 | } |

715 | |

716 | impl LowerExp for bf16 { |

717 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

718 | write!(f, "{:e}", self.to_f32()) |

719 | } |

720 | } |

721 | |

722 | impl UpperExp for bf16 { |

723 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

724 | write!(f, "{:E}", self.to_f32()) |

725 | } |

726 | } |

727 | |

728 | impl Binary for bf16 { |

729 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

730 | write!(f, "{:b}", self.0) |

731 | } |

732 | } |

733 | |

734 | impl Octal for bf16 { |

735 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

736 | write!(f, "{:o}", self.0) |

737 | } |

738 | } |

739 | |

740 | impl LowerHex for bf16 { |

741 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

742 | write!(f, "{:x}", self.0) |

743 | } |

744 | } |

745 | |

746 | impl UpperHex for bf16 { |

747 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { |

748 | write!(f, "{:X}", self.0) |

749 | } |

750 | } |

751 | |

752 | impl Neg for bf16 { |

753 | type Output = Self; |

754 | |

755 | fn neg(self) -> Self::Output { |

756 | Self(self.0 ^ 0x8000) |

757 | } |

758 | } |

759 | |

760 | impl Add for bf16 { |

761 | type Output = Self; |

762 | |

763 | fn add(self, rhs: Self) -> Self::Output { |

764 | Self::from_f32(Self::to_f32(self) + Self::to_f32(rhs)) |

765 | } |

766 | } |

767 | |

768 | impl Add<&bf16> for bf16 { |

769 | type Output = <bf16 as Add<bf16>>::Output; |

770 | |

771 | #[inline] |

772 | fn add(self, rhs: &bf16) -> Self::Output { |

773 | self.add(*rhs) |

774 | } |

775 | } |

776 | |

777 | impl Add<&bf16> for &bf16 { |

778 | type Output = <bf16 as Add<bf16>>::Output; |

779 | |

780 | #[inline] |

781 | fn add(self, rhs: &bf16) -> Self::Output { |

782 | (*self).add(*rhs) |

783 | } |

784 | } |

785 | |

786 | impl Add<bf16> for &bf16 { |

787 | type Output = <bf16 as Add<bf16>>::Output; |

788 | |

789 | #[inline] |

790 | fn add(self, rhs: bf16) -> Self::Output { |

791 | (*self).add(rhs) |

792 | } |

793 | } |

794 | |

795 | impl AddAssign for bf16 { |

796 | #[inline] |

797 | fn add_assign(&mut self, rhs: Self) { |

798 | *self = (*self).add(rhs); |

799 | } |

800 | } |

801 | |

802 | impl AddAssign<&bf16> for bf16 { |

803 | #[inline] |

804 | fn add_assign(&mut self, rhs: &bf16) { |

805 | *self = (*self).add(rhs); |

806 | } |

807 | } |

808 | |

809 | impl Sub for bf16 { |

810 | type Output = Self; |

811 | |

812 | fn sub(self, rhs: Self) -> Self::Output { |

813 | Self::from_f32(Self::to_f32(self) - Self::to_f32(rhs)) |

814 | } |

815 | } |

816 | |

817 | impl Sub<&bf16> for bf16 { |

818 | type Output = <bf16 as Sub<bf16>>::Output; |

819 | |

820 | #[inline] |

821 | fn sub(self, rhs: &bf16) -> Self::Output { |

822 | self.sub(*rhs) |

823 | } |

824 | } |

825 | |

826 | impl Sub<&bf16> for &bf16 { |

827 | type Output = <bf16 as Sub<bf16>>::Output; |

828 | |

829 | #[inline] |

830 | fn sub(self, rhs: &bf16) -> Self::Output { |

831 | (*self).sub(*rhs) |

832 | } |

833 | } |

834 | |

835 | impl Sub<bf16> for &bf16 { |

836 | type Output = <bf16 as Sub<bf16>>::Output; |

837 | |

838 | #[inline] |

839 | fn sub(self, rhs: bf16) -> Self::Output { |

840 | (*self).sub(rhs) |

841 | } |

842 | } |

843 | |

844 | impl SubAssign for bf16 { |

845 | #[inline] |

846 | fn sub_assign(&mut self, rhs: Self) { |

847 | *self = (*self).sub(rhs); |

848 | } |

849 | } |

850 | |

851 | impl SubAssign<&bf16> for bf16 { |

852 | #[inline] |

853 | fn sub_assign(&mut self, rhs: &bf16) { |

854 | *self = (*self).sub(rhs); |

855 | } |

856 | } |

857 | |

858 | impl Mul for bf16 { |

859 | type Output = Self; |

860 | |

861 | fn mul(self, rhs: Self) -> Self::Output { |

862 | Self::from_f32(Self::to_f32(self) * Self::to_f32(rhs)) |

863 | } |

864 | } |

865 | |

866 | impl Mul<&bf16> for bf16 { |

867 | type Output = <bf16 as Mul<bf16>>::Output; |

868 | |

869 | #[inline] |

870 | fn mul(self, rhs: &bf16) -> Self::Output { |

871 | self.mul(*rhs) |

872 | } |

873 | } |

874 | |

875 | impl Mul<&bf16> for &bf16 { |

876 | type Output = <bf16 as Mul<bf16>>::Output; |

877 | |

878 | #[inline] |

879 | fn mul(self, rhs: &bf16) -> Self::Output { |

880 | (*self).mul(*rhs) |

881 | } |

882 | } |

883 | |

884 | impl Mul<bf16> for &bf16 { |

885 | type Output = <bf16 as Mul<bf16>>::Output; |

886 | |

887 | #[inline] |

888 | fn mul(self, rhs: bf16) -> Self::Output { |

889 | (*self).mul(rhs) |

890 | } |

891 | } |

892 | |

893 | impl MulAssign for bf16 { |

894 | #[inline] |

895 | fn mul_assign(&mut self, rhs: Self) { |

896 | *self = (*self).mul(rhs); |

897 | } |

898 | } |

899 | |

900 | impl MulAssign<&bf16> for bf16 { |

901 | #[inline] |

902 | fn mul_assign(&mut self, rhs: &bf16) { |

903 | *self = (*self).mul(rhs); |

904 | } |

905 | } |

906 | |

907 | impl Div for bf16 { |

908 | type Output = Self; |

909 | |

910 | fn div(self, rhs: Self) -> Self::Output { |

911 | Self::from_f32(Self::to_f32(self) / Self::to_f32(rhs)) |

912 | } |

913 | } |

914 | |

915 | impl Div<&bf16> for bf16 { |

916 | type Output = <bf16 as Div<bf16>>::Output; |

917 | |

918 | #[inline] |

919 | fn div(self, rhs: &bf16) -> Self::Output { |

920 | self.div(*rhs) |

921 | } |

922 | } |

923 | |

924 | impl Div<&bf16> for &bf16 { |

925 | type Output = <bf16 as Div<bf16>>::Output; |

926 | |

927 | #[inline] |

928 | fn div(self, rhs: &bf16) -> Self::Output { |

929 | (*self).div(*rhs) |

930 | } |

931 | } |

932 | |

933 | impl Div<bf16> for &bf16 { |

934 | type Output = <bf16 as Div<bf16>>::Output; |

935 | |

936 | #[inline] |

937 | fn div(self, rhs: bf16) -> Self::Output { |

938 | (*self).div(rhs) |

939 | } |

940 | } |

941 | |

942 | impl DivAssign for bf16 { |

943 | #[inline] |

944 | fn div_assign(&mut self, rhs: Self) { |

945 | *self = (*self).div(rhs); |

946 | } |

947 | } |

948 | |

949 | impl DivAssign<&bf16> for bf16 { |

950 | #[inline] |

951 | fn div_assign(&mut self, rhs: &bf16) { |

952 | *self = (*self).div(rhs); |

953 | } |

954 | } |

955 | |

956 | impl Rem for bf16 { |

957 | type Output = Self; |

958 | |

959 | fn rem(self, rhs: Self) -> Self::Output { |

960 | Self::from_f32(Self::to_f32(self) % Self::to_f32(rhs)) |

961 | } |

962 | } |

963 | |

964 | impl Rem<&bf16> for bf16 { |

965 | type Output = <bf16 as Rem<bf16>>::Output; |

966 | |

967 | #[inline] |

968 | fn rem(self, rhs: &bf16) -> Self::Output { |

969 | self.rem(*rhs) |

970 | } |

971 | } |

972 | |

973 | impl Rem<&bf16> for &bf16 { |

974 | type Output = <bf16 as Rem<bf16>>::Output; |

975 | |

976 | #[inline] |

977 | fn rem(self, rhs: &bf16) -> Self::Output { |

978 | (*self).rem(*rhs) |

979 | } |

980 | } |

981 | |

982 | impl Rem<bf16> for &bf16 { |

983 | type Output = <bf16 as Rem<bf16>>::Output; |

984 | |

985 | #[inline] |

986 | fn rem(self, rhs: bf16) -> Self::Output { |

987 | (*self).rem(rhs) |

988 | } |

989 | } |

990 | |

991 | impl RemAssign for bf16 { |

992 | #[inline] |

993 | fn rem_assign(&mut self, rhs: Self) { |

994 | *self = (*self).rem(rhs); |

995 | } |

996 | } |

997 | |

998 | impl RemAssign<&bf16> for bf16 { |

999 | #[inline] |

1000 | fn rem_assign(&mut self, rhs: &bf16) { |

1001 | *self = (*self).rem(rhs); |

1002 | } |

1003 | } |

1004 | |

1005 | impl Product for bf16 { |

1006 | #[inline] |

1007 | fn product<I: Iterator<Item = Self>>(iter: I) -> Self { |

1008 | bf16::from_f32(iter.map(|f| f.to_f32()).product()) |

1009 | } |

1010 | } |

1011 | |

1012 | impl<'a> Product<&'a bf16> for bf16 { |

1013 | #[inline] |

1014 | fn product<I: Iterator<Item = &'a bf16>>(iter: I) -> Self { |

1015 | bf16::from_f32(iter.map(|f| f.to_f32()).product()) |

1016 | } |

1017 | } |

1018 | |

1019 | impl Sum for bf16 { |

1020 | #[inline] |

1021 | fn sum<I: Iterator<Item = Self>>(iter: I) -> Self { |

1022 | bf16::from_f32(iter.map(|f| f.to_f32()).sum()) |

1023 | } |

1024 | } |

1025 | |

1026 | impl<'a> Sum<&'a bf16> for bf16 { |

1027 | #[inline] |

1028 | fn sum<I: Iterator<Item = &'a bf16>>(iter: I) -> Self { |

1029 | bf16::from_f32(iter.map(|f| f.to_f32()).product()) |

1030 | } |

1031 | } |

1032 | |

// Lints silenced for the whole test module: some tests are long lists of assertions
// (cognitive_complexity), exact float equality is the point of these tests (float_cmp), and
// negated partial-order comparisons are deliberate when probing signed zeros
// (neg_cmp_op_on_partial_ord).
#[allow(
    clippy::cognitive_complexity,
    clippy::float_cmp,
    clippy::neg_cmp_op_on_partial_ord
)]
#[cfg(test)]
mod test {
    use super::*;
    use core::cmp::Ordering;
    #[cfg(feature = "num-traits")]
    use num_traits::{AsPrimitive, FromPrimitive, ToPrimitive};
    use quickcheck_macros::quickcheck;

    #[cfg(feature = "num-traits")]
    #[test]
    fn as_primitive() {
        let two = bf16::from_f32(2.0);

        // Integer round trip.
        assert_eq!(<i32 as AsPrimitive<bf16>>::as_(2), two);
        assert_eq!(<bf16 as AsPrimitive<i32>>::as_(two), 2);

        // Single-precision round trip.
        assert_eq!(<f32 as AsPrimitive<bf16>>::as_(2.0), two);
        assert_eq!(<bf16 as AsPrimitive<f32>>::as_(two), 2.0);

        // Double-precision round trip.
        assert_eq!(<f64 as AsPrimitive<bf16>>::as_(2.0), two);
        assert_eq!(<bf16 as AsPrimitive<f64>>::as_(two), 2.0);
    }

    #[cfg(feature = "num-traits")]
    #[test]
    fn to_primitive() {
        let two = bf16::from_f32(2.0);

        assert_eq!(ToPrimitive::to_i32(&two).unwrap(), 2i32);
        assert_eq!(ToPrimitive::to_f32(&two).unwrap(), 2.0f32);
        assert_eq!(ToPrimitive::to_f64(&two).unwrap(), 2.0f64);
    }

    #[cfg(feature = "num-traits")]
    #[test]
    fn from_primitive() {
        let expected = bf16::from_f32(2.0);

        assert_eq!(<bf16 as FromPrimitive>::from_i32(2).unwrap(), expected);
        assert_eq!(<bf16 as FromPrimitive>::from_f32(2.0).unwrap(), expected);
        assert_eq!(<bf16 as FromPrimitive>::from_f64(2.0).unwrap(), expected);
    }

    #[test]
    fn test_bf16_consts_from_f32() {
        let zero = bf16::from_f32(0.0);
        let neg_zero = bf16::from_f32(-0.0);
        let neg_one = bf16::from_f32(-1.0);

        // Special values, including the sign of the signed constants.
        assert_eq!(bf16::ONE, bf16::from_f32(1.0));
        assert_eq!(bf16::ZERO, zero);
        assert!(zero.is_sign_positive());
        assert_eq!(bf16::NEG_ZERO, neg_zero);
        assert!(neg_zero.is_sign_negative());
        assert_eq!(bf16::NEG_ONE, neg_one);
        assert!(neg_one.is_sign_negative());
        assert_eq!(bf16::INFINITY, bf16::from_f32(core::f32::INFINITY));
        assert_eq!(bf16::NEG_INFINITY, bf16::from_f32(core::f32::NEG_INFINITY));
        assert!(bf16::from_f32(core::f32::NAN).is_nan());
        assert!(bf16::NAN.is_nan());

        // Mathematical constants must equal the rounded f32 value. LOG10_2 and LOG2_10 are
        // computed because core::f32::consts::{LOG10_2, LOG2_10} require rustc 1.43.0.
        use core::f32::consts as c;
        let math_consts: [(bf16, f32); 18] = [
            (bf16::E, c::E),
            (bf16::PI, c::PI),
            (bf16::FRAC_1_PI, c::FRAC_1_PI),
            (bf16::FRAC_1_SQRT_2, c::FRAC_1_SQRT_2),
            (bf16::FRAC_2_PI, c::FRAC_2_PI),
            (bf16::FRAC_2_SQRT_PI, c::FRAC_2_SQRT_PI),
            (bf16::FRAC_PI_2, c::FRAC_PI_2),
            (bf16::FRAC_PI_3, c::FRAC_PI_3),
            (bf16::FRAC_PI_4, c::FRAC_PI_4),
            (bf16::FRAC_PI_6, c::FRAC_PI_6),
            (bf16::FRAC_PI_8, c::FRAC_PI_8),
            (bf16::LN_10, c::LN_10),
            (bf16::LN_2, c::LN_2),
            (bf16::LOG10_E, c::LOG10_E),
            (bf16::LOG10_2, 2f32.log10()),
            (bf16::LOG2_E, c::LOG2_E),
            (bf16::LOG2_10, 10f32.log2()),
            (bf16::SQRT_2, c::SQRT_2),
        ];
        for &(constant, exact) in math_consts.iter() {
            assert_eq!(constant, bf16::from_f32(exact));
        }
    }

    #[test]
    fn test_bf16_consts_from_f64() {
        // Special values.
        assert_eq!(bf16::ONE, bf16::from_f64(1.0));
        assert_eq!(bf16::ZERO, bf16::from_f64(0.0));
        assert_eq!(bf16::NEG_ZERO, bf16::from_f64(-0.0));
        assert_eq!(bf16::INFINITY, bf16::from_f64(core::f64::INFINITY));
        assert_eq!(bf16::NEG_INFINITY, bf16::from_f64(core::f64::NEG_INFINITY));
        assert!(bf16::from_f64(core::f64::NAN).is_nan());
        assert!(bf16::NAN.is_nan());

        // Mathematical constants must equal the rounded f64 value. LOG10_2 and LOG2_10 are
        // computed because core::f64::consts::{LOG10_2, LOG2_10} require rustc 1.43.0.
        use core::f64::consts as c;
        let math_consts: [(bf16, f64); 18] = [
            (bf16::E, c::E),
            (bf16::PI, c::PI),
            (bf16::FRAC_1_PI, c::FRAC_1_PI),
            (bf16::FRAC_1_SQRT_2, c::FRAC_1_SQRT_2),
            (bf16::FRAC_2_PI, c::FRAC_2_PI),
            (bf16::FRAC_2_SQRT_PI, c::FRAC_2_SQRT_PI),
            (bf16::FRAC_PI_2, c::FRAC_PI_2),
            (bf16::FRAC_PI_3, c::FRAC_PI_3),
            (bf16::FRAC_PI_4, c::FRAC_PI_4),
            (bf16::FRAC_PI_6, c::FRAC_PI_6),
            (bf16::FRAC_PI_8, c::FRAC_PI_8),
            (bf16::LN_10, c::LN_10),
            (bf16::LN_2, c::LN_2),
            (bf16::LOG10_E, c::LOG10_E),
            (bf16::LOG10_2, 2f64.log10()),
            (bf16::LOG2_E, c::LOG2_E),
            (bf16::LOG2_10, 10f64.log2()),
            (bf16::SQRT_2, c::SQRT_2),
        ];
        for &(constant, exact) in math_consts.iter() {
            assert_eq!(constant, bf16::from_f64(exact));
        }
    }

    #[test]
    fn test_nan_conversion_to_smaller() {
        // NaNs whose only set mantissa bit is the lowest one: naively truncating the mantissa
        // would turn them into infinities.
        let nan64 = f64::from_bits(0x7FF0_0000_0000_0001u64);
        let neg_nan64 = f64::from_bits(0xFFF0_0000_0000_0001u64);
        let nan32 = f32::from_bits(0x7F80_0001u32);
        let neg_nan32 = f32::from_bits(0xFF80_0001u32);

        // Sanity-check the inputs themselves.
        assert!(nan64.is_nan() && nan64.is_sign_positive());
        assert!(neg_nan64.is_nan() && neg_nan64.is_sign_negative());
        assert!(nan32.is_nan() && nan32.is_sign_positive());
        assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative());

        // The built-in f64 -> f32 cast is the reference behaviour.
        let (narrowed_pos, narrowed_neg) = (nan64 as f32, neg_nan64 as f32);
        assert!(narrowed_pos.is_nan() && narrowed_pos.is_sign_positive());
        assert!(narrowed_neg.is_nan() && narrowed_neg.is_sign_negative());

        // bf16 must keep both NaN-ness and sign from either source width.
        let (half_pos, half_neg) = (bf16::from_f64(nan64), bf16::from_f64(neg_nan64));
        assert!(half_pos.is_nan() && half_pos.is_sign_positive());
        assert!(half_neg.is_nan() && half_neg.is_sign_negative());

        let (half_pos, half_neg) = (bf16::from_f32(nan32), bf16::from_f32(neg_nan32));
        assert!(half_pos.is_nan() && half_pos.is_sign_positive());
        assert!(half_neg.is_nan() && half_neg.is_sign_negative());
    }

    #[test]
    fn test_nan_conversion_to_larger() {
        let nan16 = bf16::from_bits(0x7F81u16);
        let neg_nan16 = bf16::from_bits(0xFF81u16);
        let nan32 = f32::from_bits(0x7F80_0001u32);
        let neg_nan32 = f32::from_bits(0xFF80_0001u32);

        // Sanity-check the inputs themselves.
        assert!(nan16.is_nan() && nan16.is_sign_positive());
        assert!(neg_nan16.is_nan() && neg_nan16.is_sign_negative());
        assert!(nan32.is_nan() && nan32.is_sign_positive());
        assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative());

        // bf16 -> f32 widening keeps NaN-ness and sign.
        let (wide_pos, wide_neg) = (f32::from(nan16), f32::from(neg_nan16));
        assert!(wide_pos.is_nan() && wide_pos.is_sign_positive());
        assert!(wide_neg.is_nan() && wide_neg.is_sign_negative());

        // bf16 -> f64 widening keeps NaN-ness and sign.
        let (wider_pos, wider_neg) = (f64::from(nan16), f64::from(neg_nan16));
        assert!(wider_pos.is_nan() && wider_pos.is_sign_positive());
        assert!(wider_neg.is_nan() && wider_neg.is_sign_negative());

        // The built-in f32 -> f64 conversion is the reference behaviour.
        let (ref_pos, ref_neg) = (f64::from(nan32), f64::from(neg_nan32));
        assert!(ref_pos.is_nan() && ref_pos.is_sign_positive());
        assert!(ref_neg.is_nan() && ref_neg.is_sign_negative());
    }

    #[test]
    fn test_bf16_to_f32() {
        // 7.0 is exactly representable, so the round trip is lossless.
        assert_eq!(bf16::from_f32(7.0).to_f32(), 7.0f32);

        // 7.1 is NOT exactly representable in 16-bit, so it is rounded. The error must be
        // <= 4 * EPSILON, as 7 has two more significant bits than 1.
        let error = (bf16::from_f32(7.1).to_f32() - 7.1f32).abs();
        assert!(error <= 4.0 * bf16::EPSILON.to_f32());

        // Subnormals, in both conversion directions.
        let tiny32 = f32::from_bits(0x0001_0000u32);
        for &(bits, multiple) in [(0x0001u16, 1.0f32), (0x0005u16, 5.0f32)].iter() {
            assert_eq!(bf16::from_bits(bits).to_f32(), multiple * tiny32);
            assert_eq!(bf16::from_bits(bits), bf16::from_f32(multiple * tiny32));
        }
    }

    #[test]
    fn test_bf16_to_f64() {
        // 7.0 is exactly representable, so the round trip is lossless.
        assert_eq!(bf16::from_f64(7.0).to_f64(), 7.0f64);

        // 7.1 is NOT exactly representable in 16-bit, so it is rounded. The error must be
        // <= 4 * EPSILON, as 7 has two more significant bits than 1.
        let error = (bf16::from_f64(7.1).to_f64() - 7.1f64).abs();
        assert!(error <= 4.0 * bf16::EPSILON.to_f64());

        // Subnormals, in both conversion directions.
        let tiny64 = 2.0f64.powi(-133);
        for &(bits, multiple) in [(0x0001u16, 1.0f64), (0x0005u16, 5.0f64)].iter() {
            assert_eq!(bf16::from_bits(bits).to_f64(), multiple * tiny64);
            assert_eq!(bf16::from_bits(bits), bf16::from_f64(multiple * tiny64));
        }
    }

    /// Checks `partial_cmp` and every comparison operator between `a` and `b`, in both
    /// directions, against `expected` (the ordering of `a` relative to `b`).
    fn check_all_comparisons(a: bf16, b: bf16, expected: Ordering) {
        let equal = expected == Ordering::Equal;
        let greater = expected == Ordering::Greater;
        let less = expected == Ordering::Less;

        assert_eq!(a.partial_cmp(&b), Some(expected));
        assert_eq!(b.partial_cmp(&a), Some(expected.reverse()));

        assert_eq!(a == b, equal);
        assert_eq!(b == a, equal);
        assert_eq!(a != b, !equal);
        assert_eq!(b != a, !equal);
        assert_eq!(a < b, less);
        assert_eq!(b < a, greater);
        assert_eq!(a <= b, !greater);
        assert_eq!(b <= a, !less);
        assert_eq!(a > b, greater);
        assert_eq!(b > a, less);
        assert_eq!(a >= b, !less);
        assert_eq!(b >= a, !greater);
    }

    #[test]
    fn test_comparisons() {
        let zero = bf16::from_f64(0.0);
        let one = bf16::from_f64(1.0);
        let neg_zero = bf16::from_f64(-0.0);
        let neg_one = bf16::from_f64(-1.0);

        // +0 and -0 compare equal despite differing bit patterns.
        check_all_comparisons(zero, neg_zero, Ordering::Equal);
        check_all_comparisons(one, neg_zero, Ordering::Greater);
        check_all_comparisons(one, neg_one, Ordering::Greater);
    }

    #[test]
    fn round_to_even_f32() {
        // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133
        let min_sub = bf16::from_bits(1);
        let min_sub_f = (-133f32).exp2();
        assert_eq!(bf16::from_f32(min_sub_f).to_bits(), min_sub.to_bits());
        assert_eq!(f32::from(min_sub).to_bits(), min_sub_f.to_bits());

        // Each row is (multiple of the smallest subnormal, expected multiple after rounding):
        //   below the tie -> truncate; exactly at the tie -> round to even; above the tie -> round up.
        let subnormal_cases: [(f32, u16); 9] = [
            (0.49, 0),
            (0.50, 0), // tie, 0 is even: stays
            (0.51, 1),
            (1.49, 1),
            (1.50, 2), // tie, 1 is odd: rounds up to even
            (1.51, 2),
            (2.49, 2),
            (2.50, 2), // tie, 2 is even: stays
            (2.51, 3),
        ];
        for &(scale, multiple) in subnormal_cases.iter() {
            assert_eq!(
                bf16::from_f32(min_sub_f * scale).to_bits(),
                min_sub.to_bits() * multiple
            );
        }

        // Same rules for normal values whose unit in the last place is 1.0.
        let normal_cases: [(f32, f32); 9] = [
            (250.49, 250.0),
            (250.50, 250.0),
            (250.51, 251.0),
            (251.49, 251.0),
            (251.50, 252.0),
            (251.51, 252.0),
            (252.49, 252.0),
            (252.50, 252.0),
            (252.51, 253.0),
        ];
        for &(input, expected) in normal_cases.iter() {
            assert_eq!(
                bf16::from_f32(input).to_bits(),
                bf16::from_f32(expected).to_bits()
            );
        }
    }

    #[test]
    fn round_to_even_f64() {
        // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133
        let min_sub = bf16::from_bits(1);
        let min_sub_f = (-133f64).exp2();
        assert_eq!(bf16::from_f64(min_sub_f).to_bits(), min_sub.to_bits());
        assert_eq!(f64::from(min_sub).to_bits(), min_sub_f.to_bits());

        // Each row is (multiple of the smallest subnormal, expected multiple after rounding):
        //   below the tie -> truncate; exactly at the tie -> round to even; above the tie -> round up.
        let subnormal_cases: [(f64, u16); 9] = [
            (0.49, 0),
            (0.50, 0), // tie, 0 is even: stays
            (0.51, 1),
            (1.49, 1),
            (1.50, 2), // tie, 1 is odd: rounds up to even
            (1.51, 2),
            (2.49, 2),
            (2.50, 2), // tie, 2 is even: stays
            (2.51, 3),
        ];
        for &(scale, multiple) in subnormal_cases.iter() {
            assert_eq!(
                bf16::from_f64(min_sub_f * scale).to_bits(),
                min_sub.to_bits() * multiple
            );
        }

        // Same rules for normal values whose unit in the last place is 1.0.
        let normal_cases: [(f64, f64); 9] = [
            (250.49, 250.0),
            (250.50, 250.0),
            (250.51, 251.0),
            (251.49, 251.0),
            (251.50, 252.0),
            (251.51, 252.0),
            (252.49, 252.0),
            (252.50, 252.0),
            (252.51, 253.0),
        ];
        for &(input, expected) in normal_cases.iter() {
            assert_eq!(
                bf16::from_f64(input).to_bits(),
                bf16::from_f64(expected).to_bits()
            );
        }
    }

    // Generates arbitrary bit patterns, so quickcheck also covers NaNs, infinities and subnormals.
    impl quickcheck::Arbitrary for bf16 {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            bf16(u16::arbitrary(g))
        }
    }

    #[quickcheck]
    fn qc_roundtrip_bf16_f32_is_identity(f: bf16) -> bool {
        let back = bf16::from_f32(f.to_f32());
        if !f.is_nan() {
            return back.0 == f.0;
        }
        // NaN payloads may change; only NaN-ness and sign must survive.
        back.is_nan() && back.is_sign_negative() == f.is_sign_negative()
    }

    #[quickcheck]
    fn qc_roundtrip_bf16_f64_is_identity(f: bf16) -> bool {
        let back = bf16::from_f64(f.to_f64());
        if !f.is_nan() {
            return back.0 == f.0;
        }
        // NaN payloads may change; only NaN-ness and sign must survive.
        back.is_nan() && back.is_sign_negative() == f.is_sign_negative()
    }
}

1556 |