1 | // Adapted from https://github.com/Alexhuszagh/rust-lexical. |
---|---|

2 | |

3 | //! Defines rounding schemes for floating-point numbers. |

4 | |

5 | use super::float::ExtendedFloat; |

6 | use super::num::*; |

7 | use super::shift::*; |

8 | use core::mem; |

9 | |

10 | // MASKS |

11 | |

/// Return a `u64` with only bit `n` set, i.e. `2^n`.
///
/// `n` must be smaller than the bit width of `u64` (64); this precondition
/// is only checked in debug builds.
#[inline]
pub(crate) fn nth_bit(n: u64) -> u64 {
    let width = (mem::size_of::<u64>() * 8) as u64;
    debug_assert!(n < width, "nth_bit() overflow in shl.");

    1u64 << n
}

20 | |

/// Return a mask with the lowest `n` bits set (`n` may range from 0 to 64).
///
/// The full-width case is handled separately because `1 << 64` would
/// overflow the shift.
#[inline]
pub(crate) fn lower_n_mask(n: u64) -> u64 {
    let width = (mem::size_of::<u64>() * 8) as u64;
    debug_assert!(n <= width, "lower_n_mask() overflow in shl.");

    if n != width {
        (1u64 << n) - 1
    } else {
        // Every bit is set.
        !0
    }
}

33 | |

34 | /// Calculate the halfway point for the lower `n` bits. |

35 | #[inline] |

36 | pub(crate) fn lower_n_halfway(n: u64) -> u64 { |

37 | let bits: u64 = mem::size_of::<u64>() as u64 * 8; |

38 | debug_assert!(n <= bits, "lower_n_halfway() overflow in shl."); |

39 | |

40 | if n == 0 { |

41 | 0 |

42 | } else { |

43 | nth_bit(n - 1) |

44 | } |

45 | } |

46 | |

47 | /// Calculate a bitwise mask with `n` 1 bits starting at the `bit` position. |

48 | #[inline] |

49 | pub(crate) fn internal_n_mask(bit: u64, n: u64) -> u64 { |

50 | let bits: u64 = mem::size_of::<u64>() as u64 * 8; |

51 | debug_assert!(bit <= bits, "internal_n_halfway() overflow in shl."); |

52 | debug_assert!(n <= bits, "internal_n_halfway() overflow in shl."); |

53 | debug_assert!(bit >= n, "internal_n_halfway() overflow in sub."); |

54 | |

55 | lower_n_mask(bit) ^ lower_n_mask(bit - n) |

56 | } |

57 | |

58 | // NEAREST ROUNDING |

59 | |

60 | // Shift right N-bytes and round to the nearest. |

61 | // |

62 | // Return if we are above halfway and if we are halfway. |

63 | #[inline] |

64 | pub(crate) fn round_nearest(fp: &mut ExtendedFloat, shift: i32) -> (bool, bool) { |

65 | // Extract the truncated bits using mask. |

66 | // Calculate if the value of the truncated bits are either above |

67 | // the mid-way point, or equal to it. |

68 | // |

69 | // For example, for 4 truncated bytes, the mask would be b1111 |

70 | // and the midway point would be b1000. |

71 | let mask: u64 = lower_n_mask(shift as u64); |

72 | let halfway: u64 = lower_n_halfway(shift as u64); |

73 | |

74 | let truncated_bits = fp.mant & mask; |

75 | let is_above = truncated_bits > halfway; |

76 | let is_halfway = truncated_bits == halfway; |

77 | |

78 | // Bit shift so the leading bit is in the hidden bit. |

79 | overflowing_shr(fp, shift); |

80 | |

81 | (is_above, is_halfway) |

82 | } |

83 | |

84 | // Tie rounded floating point to event. |

85 | #[inline] |

86 | pub(crate) fn tie_even(fp: &mut ExtendedFloat, is_above: bool, is_halfway: bool) { |

87 | // Extract the last bit after shifting (and determine if it is odd). |

88 | let is_odd = fp.mant & 1 == 1; |

89 | |

90 | // Calculate if we need to roundup. |

91 | // We need to roundup if we are above halfway, or if we are odd |

92 | // and at half-way (need to tie-to-even). |

93 | if is_above || (is_odd && is_halfway) { |

94 | fp.mant += 1; |

95 | } |

96 | } |

97 | |

98 | // Shift right N-bytes and round nearest, tie-to-even. |

99 | // |

100 | // Floating-point arithmetic uses round to nearest, ties to even, |

101 | // which rounds to the nearest value, if the value is halfway in between, |

102 | // round to an even value. |

103 | #[inline] |

104 | pub(crate) fn round_nearest_tie_even(fp: &mut ExtendedFloat, shift: i32) { |

105 | let (is_above, is_halfway) = round_nearest(fp, shift); |

106 | tie_even(fp, is_above, is_halfway); |

107 | } |

108 | |

109 | // DIRECTED ROUNDING |

110 | |

111 | // Shift right N-bytes and round towards a direction. |

112 | // |

113 | // Return if we have any truncated bytes. |

114 | #[inline] |

115 | fn round_toward(fp: &mut ExtendedFloat, shift: i32) -> bool { |

116 | let mask: u64 = lower_n_mask(shift as u64); |

117 | let truncated_bits = fp.mant & mask; |

118 | |

119 | // Bit shift so the leading bit is in the hidden bit. |

120 | overflowing_shr(fp, shift); |

121 | |

122 | truncated_bits != 0 |

123 | } |

124 | |

125 | // Round down. |

126 | #[inline] |

127 | fn downard(_: &mut ExtendedFloat, _: bool) {} |

128 | |

129 | // Shift right N-bytes and round toward zero. |

130 | // |

131 | // Floating-point arithmetic defines round toward zero, which rounds |

132 | // towards positive zero. |

133 | #[inline] |

134 | pub(crate) fn round_downward(fp: &mut ExtendedFloat, shift: i32) { |

135 | // Bit shift so the leading bit is in the hidden bit. |

136 | // No rounding schemes, so we just ignore everything else. |

137 | let is_truncated = round_toward(fp, shift); |

138 | downard(fp, is_truncated); |

139 | } |

140 | |

141 | // ROUND TO FLOAT |

142 | |

143 | // Shift the ExtendedFloat fraction to the fraction bits in a native float. |

144 | // |

145 | // Floating-point arithmetic uses round to nearest, ties to even, |

146 | // which rounds to the nearest value, if the value is halfway in between, |

147 | // round to an even value. |

148 | #[inline] |

149 | pub(crate) fn round_to_float<F, Algorithm>(fp: &mut ExtendedFloat, algorithm: Algorithm) |

150 | where |

151 | F: Float, |

152 | Algorithm: FnOnce(&mut ExtendedFloat, i32), |

153 | { |

154 | // Calculate the difference to allow a single calculation |

155 | // rather than a loop, to minimize the number of ops required. |

156 | // This does underflow detection. |

157 | let final_exp = fp.exp + F::DEFAULT_SHIFT; |

158 | if final_exp < F::DENORMAL_EXPONENT { |

159 | // We would end up with a denormal exponent, try to round to more |

160 | // digits. Only shift right if we can avoid zeroing out the value, |

161 | // which requires the exponent diff to be < M::BITS. The value |

162 | // is already normalized, so we shouldn't have any issue zeroing |

163 | // out the value. |

164 | let diff = F::DENORMAL_EXPONENT - fp.exp; |

165 | if diff <= u64::FULL { |

166 | // We can avoid underflow, can get a valid representation. |

167 | algorithm(fp, diff); |

168 | } else { |

169 | // Certain underflow, assign literal 0s. |

170 | fp.mant = 0; |

171 | fp.exp = 0; |

172 | } |

173 | } else { |

174 | algorithm(fp, F::DEFAULT_SHIFT); |

175 | } |

176 | |

177 | if fp.mant & F::CARRY_MASK == F::CARRY_MASK { |

178 | // Roundup carried over to 1 past the hidden bit. |

179 | shr(fp, 1); |

180 | } |

181 | } |

182 | |

183 | // AVOID OVERFLOW/UNDERFLOW |

184 | |

185 | // Avoid overflow for large values, shift left as needed. |

186 | // |

187 | // Shift until a 1-bit is in the hidden bit, if the mantissa is not 0. |

188 | #[inline] |

189 | pub(crate) fn avoid_overflow<F>(fp: &mut ExtendedFloat) |

190 | where |

191 | F: Float, |

192 | { |

193 | // Calculate the difference to allow a single calculation |

194 | // rather than a loop, minimizing the number of ops required. |

195 | if fp.exp >= F::MAX_EXPONENT { |

196 | let diff = fp.exp - F::MAX_EXPONENT; |

197 | if diff <= F::MANTISSA_SIZE { |

198 | // Our overflow mask needs to start at the hidden bit, or at |

199 | // `F::MANTISSA_SIZE+1`, and needs to have `diff+1` bits set, |

200 | // to see if our value overflows. |

201 | let bit = (F::MANTISSA_SIZE + 1) as u64; |

202 | let n = (diff + 1) as u64; |

203 | let mask = internal_n_mask(bit, n); |

204 | if (fp.mant & mask) == 0 { |

205 | // If we have no 1-bit in the hidden-bit position, |

206 | // which is index 0, we need to shift 1. |

207 | let shift = diff + 1; |

208 | shl(fp, shift); |

209 | } |

210 | } |

211 | } |

212 | } |

213 | |

214 | // ROUND TO NATIVE |

215 | |

216 | // Round an extended-precision float to a native float representation. |

217 | #[inline] |

218 | pub(crate) fn round_to_native<F, Algorithm>(fp: &mut ExtendedFloat, algorithm: Algorithm) |

219 | where |

220 | F: Float, |

221 | Algorithm: FnOnce(&mut ExtendedFloat, i32), |

222 | { |

223 | // Shift all the way left, to ensure a consistent representation. |

224 | // The following right-shifts do not work for a non-normalized number. |

225 | fp.normalize(); |

226 | |

227 | // Round so the fraction is in a native mantissa representation, |

228 | // and avoid overflow/underflow. |

229 | round_to_float::<F, _>(fp, algorithm); |

230 | avoid_overflow::<F>(fp); |

231 | } |

232 |