1 | // Adapted from https://github.com/Alexhuszagh/rust-lexical. |
2 | |
3 | //! Defines rounding schemes for floating-point numbers. |
4 | |
5 | use super::float::ExtendedFloat; |
6 | use super::num::*; |
7 | use super::shift::*; |
8 | use core::mem; |
9 | |
10 | // MASKS |
11 | |
/// Return `2^n`: a `u64` with only bit index `n` set.
///
/// In debug builds this panics when `n` is not a valid bit index for `u64`
/// (`n >= 64`), because the shift would overflow.
#[inline]
pub(crate) fn nth_bit(n: u64) -> u64 {
    let width = 8 * mem::size_of::<u64>() as u64;
    debug_assert!(n < width, "nth_bit() overflow in shl.");

    1u64 << n
}
20 | |
/// Return a mask with the lowest `n` bits set, for `n` in `0..=64`.
///
/// `n == 64` is handled separately because `1 << 64` would overflow; it
/// yields the all-ones mask. In debug builds this panics when `n > 64`.
#[inline]
pub(crate) fn lower_n_mask(n: u64) -> u64 {
    let width = 8 * mem::size_of::<u64>() as u64;
    debug_assert!(n <= width, "lower_n_mask() overflow in shl.");

    if n != width {
        // One past the mask, minus one, sets every bit below index `n`.
        (1u64 << n) - 1
    } else {
        !0u64
    }
}
33 | |
/// Return the halfway point of the lower `n` bits.
///
/// The result has only bit `n - 1` set, or is `0` when `n == 0`. For
/// example, for `n = 4` the result is `0b1000`, the midpoint of
/// `0b0000..=0b1111`. In debug builds this panics when `n > 64`.
#[inline]
pub(crate) fn lower_n_halfway(n: u64) -> u64 {
    let width = 8 * mem::size_of::<u64>() as u64;
    debug_assert!(n <= width, "lower_n_halfway() overflow in shl.");

    match n.checked_sub(1) {
        None => 0,
        Some(top_bit) => 1u64 << top_bit,
    }
}
46 | |
47 | /// Calculate a bitwise mask with `n` 1 bits starting at the `bit` position. |
48 | #[inline ] |
49 | pub(crate) fn internal_n_mask(bit: u64, n: u64) -> u64 { |
50 | let bits: u64 = mem::size_of::<u64>() as u64 * 8; |
51 | debug_assert!(bit <= bits, "internal_n_halfway() overflow in shl." ); |
52 | debug_assert!(n <= bits, "internal_n_halfway() overflow in shl." ); |
53 | debug_assert!(bit >= n, "internal_n_halfway() overflow in sub." ); |
54 | |
55 | lower_n_mask(bit) ^ lower_n_mask(bit - n) |
56 | } |
57 | |
58 | // NEAREST ROUNDING |
59 | |
60 | // Shift right N-bytes and round to the nearest. |
61 | // |
62 | // Return if we are above halfway and if we are halfway. |
63 | #[inline ] |
64 | pub(crate) fn round_nearest(fp: &mut ExtendedFloat, shift: i32) -> (bool, bool) { |
65 | // Extract the truncated bits using mask. |
66 | // Calculate if the value of the truncated bits are either above |
67 | // the mid-way point, or equal to it. |
68 | // |
69 | // For example, for 4 truncated bytes, the mask would be b1111 |
70 | // and the midway point would be b1000. |
71 | let mask: u64 = lower_n_mask(shift as u64); |
72 | let halfway: u64 = lower_n_halfway(shift as u64); |
73 | |
74 | let truncated_bits = fp.mant & mask; |
75 | let is_above = truncated_bits > halfway; |
76 | let is_halfway = truncated_bits == halfway; |
77 | |
78 | // Bit shift so the leading bit is in the hidden bit. |
79 | overflowing_shr(fp, shift); |
80 | |
81 | (is_above, is_halfway) |
82 | } |
83 | |
84 | // Tie rounded floating point to event. |
85 | #[inline ] |
86 | pub(crate) fn tie_even(fp: &mut ExtendedFloat, is_above: bool, is_halfway: bool) { |
87 | // Extract the last bit after shifting (and determine if it is odd). |
88 | let is_odd = fp.mant & 1 == 1; |
89 | |
90 | // Calculate if we need to roundup. |
91 | // We need to roundup if we are above halfway, or if we are odd |
92 | // and at half-way (need to tie-to-even). |
93 | if is_above || (is_odd && is_halfway) { |
94 | fp.mant += 1; |
95 | } |
96 | } |
97 | |
98 | // Shift right N-bytes and round nearest, tie-to-even. |
99 | // |
100 | // Floating-point arithmetic uses round to nearest, ties to even, |
101 | // which rounds to the nearest value, if the value is halfway in between, |
102 | // round to an even value. |
103 | #[inline ] |
104 | pub(crate) fn round_nearest_tie_even(fp: &mut ExtendedFloat, shift: i32) { |
105 | let (is_above, is_halfway) = round_nearest(fp, shift); |
106 | tie_even(fp, is_above, is_halfway); |
107 | } |
108 | |
109 | // DIRECTED ROUNDING |
110 | |
111 | // Shift right N-bytes and round towards a direction. |
112 | // |
113 | // Return if we have any truncated bytes. |
114 | #[inline ] |
115 | fn round_toward(fp: &mut ExtendedFloat, shift: i32) -> bool { |
116 | let mask: u64 = lower_n_mask(shift as u64); |
117 | let truncated_bits = fp.mant & mask; |
118 | |
119 | // Bit shift so the leading bit is in the hidden bit. |
120 | overflowing_shr(fp, shift); |
121 | |
122 | truncated_bits != 0 |
123 | } |
124 | |
125 | // Round down. |
126 | #[inline ] |
127 | fn downard(_: &mut ExtendedFloat, _: bool) {} |
128 | |
129 | // Shift right N-bytes and round toward zero. |
130 | // |
131 | // Floating-point arithmetic defines round toward zero, which rounds |
132 | // towards positive zero. |
133 | #[inline ] |
134 | pub(crate) fn round_downward(fp: &mut ExtendedFloat, shift: i32) { |
135 | // Bit shift so the leading bit is in the hidden bit. |
136 | // No rounding schemes, so we just ignore everything else. |
137 | let is_truncated = round_toward(fp, shift); |
138 | downard(fp, is_truncated); |
139 | } |
140 | |
141 | // ROUND TO FLOAT |
142 | |
143 | // Shift the ExtendedFloat fraction to the fraction bits in a native float. |
144 | // |
145 | // Floating-point arithmetic uses round to nearest, ties to even, |
146 | // which rounds to the nearest value, if the value is halfway in between, |
147 | // round to an even value. |
148 | #[inline ] |
149 | pub(crate) fn round_to_float<F, Algorithm>(fp: &mut ExtendedFloat, algorithm: Algorithm) |
150 | where |
151 | F: Float, |
152 | Algorithm: FnOnce(&mut ExtendedFloat, i32), |
153 | { |
154 | // Calculate the difference to allow a single calculation |
155 | // rather than a loop, to minimize the number of ops required. |
156 | // This does underflow detection. |
157 | let final_exp = fp.exp + F::DEFAULT_SHIFT; |
158 | if final_exp < F::DENORMAL_EXPONENT { |
159 | // We would end up with a denormal exponent, try to round to more |
160 | // digits. Only shift right if we can avoid zeroing out the value, |
161 | // which requires the exponent diff to be < M::BITS. The value |
162 | // is already normalized, so we shouldn't have any issue zeroing |
163 | // out the value. |
164 | let diff = F::DENORMAL_EXPONENT - fp.exp; |
165 | if diff <= u64::FULL { |
166 | // We can avoid underflow, can get a valid representation. |
167 | algorithm(fp, diff); |
168 | } else { |
169 | // Certain underflow, assign literal 0s. |
170 | fp.mant = 0; |
171 | fp.exp = 0; |
172 | } |
173 | } else { |
174 | algorithm(fp, F::DEFAULT_SHIFT); |
175 | } |
176 | |
177 | if fp.mant & F::CARRY_MASK == F::CARRY_MASK { |
178 | // Roundup carried over to 1 past the hidden bit. |
179 | shr(fp, 1); |
180 | } |
181 | } |
182 | |
183 | // AVOID OVERFLOW/UNDERFLOW |
184 | |
185 | // Avoid overflow for large values, shift left as needed. |
186 | // |
187 | // Shift until a 1-bit is in the hidden bit, if the mantissa is not 0. |
188 | #[inline ] |
189 | pub(crate) fn avoid_overflow<F>(fp: &mut ExtendedFloat) |
190 | where |
191 | F: Float, |
192 | { |
193 | // Calculate the difference to allow a single calculation |
194 | // rather than a loop, minimizing the number of ops required. |
195 | if fp.exp >= F::MAX_EXPONENT { |
196 | let diff = fp.exp - F::MAX_EXPONENT; |
197 | if diff <= F::MANTISSA_SIZE { |
198 | // Our overflow mask needs to start at the hidden bit, or at |
199 | // `F::MANTISSA_SIZE+1`, and needs to have `diff+1` bits set, |
200 | // to see if our value overflows. |
201 | let bit = (F::MANTISSA_SIZE + 1) as u64; |
202 | let n = (diff + 1) as u64; |
203 | let mask = internal_n_mask(bit, n); |
204 | if (fp.mant & mask) == 0 { |
205 | // If we have no 1-bit in the hidden-bit position, |
206 | // which is index 0, we need to shift 1. |
207 | let shift = diff + 1; |
208 | shl(fp, shift); |
209 | } |
210 | } |
211 | } |
212 | } |
213 | |
214 | // ROUND TO NATIVE |
215 | |
216 | // Round an extended-precision float to a native float representation. |
217 | #[inline ] |
218 | pub(crate) fn round_to_native<F, Algorithm>(fp: &mut ExtendedFloat, algorithm: Algorithm) |
219 | where |
220 | F: Float, |
221 | Algorithm: FnOnce(&mut ExtendedFloat, i32), |
222 | { |
223 | // Shift all the way left, to ensure a consistent representation. |
224 | // The following right-shifts do not work for a non-normalized number. |
225 | fp.normalize(); |
226 | |
227 | // Round so the fraction is in a native mantissa representation, |
228 | // and avoid overflow/underflow. |
229 | round_to_float::<F, _>(fp, algorithm); |
230 | avoid_overflow::<F>(fp); |
231 | } |
232 | |