1// Adapted from https://github.com/Alexhuszagh/rust-lexical.
2
3//! Defines rounding schemes for floating-point numbers.
4
5use super::float::ExtendedFloat;
6use super::num::*;
7use super::shift::*;
8use core::mem;
9
10// MASKS
11
/// Return `2^n`: a `u64` with only bit `n` set.
///
/// `n` must be strictly less than the bit width of `u64`; this is checked
/// with a debug assertion.
#[inline]
pub(crate) fn nth_bit(n: u64) -> u64 {
    let width = 8 * mem::size_of::<u64>() as u64;
    debug_assert!(n < width, "nth_bit() overflow in shl.");

    1u64 << n
}
20
/// Build a mask whose lowest `n` bits are set and all higher bits are clear.
///
/// `n` ranges from 0 (empty mask) up to the full bit width of `u64` (every
/// bit set); larger values trip a debug assertion.
#[inline]
pub(crate) fn lower_n_mask(n: u64) -> u64 {
    let width = 8 * mem::size_of::<u64>() as u64;
    debug_assert!(n <= width, "lower_n_mask() overflow in shl.");

    // `1 << width` would overflow the shift, so the full-width mask is
    // produced directly.
    if n == width {
        return u64::max_value();
    }
    (1 << n) - 1
}
33
/// Return the halfway value of an `n`-bit field: `2^(n-1)`, or 0 when `n` is 0.
///
/// For example, with `n == 4` the field spans `0b0000..=0b1111` and the
/// halfway value is `0b1000`. `n` may be at most the bit width of `u64`;
/// this is checked with a debug assertion.
#[inline]
pub(crate) fn lower_n_halfway(n: u64) -> u64 {
    let width = 8 * mem::size_of::<u64>() as u64;
    debug_assert!(n <= width, "lower_n_halfway() overflow in shl.");

    match n {
        0 => 0,
        // `n - 1 < width` is guaranteed by the assertion above.
        _ => 1 << (n - 1),
    }
}
46
47/// Calculate a bitwise mask with `n` 1 bits starting at the `bit` position.
48#[inline]
49pub(crate) fn internal_n_mask(bit: u64, n: u64) -> u64 {
50 let bits: u64 = mem::size_of::<u64>() as u64 * 8;
51 debug_assert!(bit <= bits, "internal_n_halfway() overflow in shl.");
52 debug_assert!(n <= bits, "internal_n_halfway() overflow in shl.");
53 debug_assert!(bit >= n, "internal_n_halfway() overflow in sub.");
54
55 lower_n_mask(bit) ^ lower_n_mask(bit - n)
56}
57
58// NEAREST ROUNDING
59
60// Shift right N-bytes and round to the nearest.
61//
62// Return if we are above halfway and if we are halfway.
63#[inline]
64pub(crate) fn round_nearest(fp: &mut ExtendedFloat, shift: i32) -> (bool, bool) {
65 // Extract the truncated bits using mask.
66 // Calculate if the value of the truncated bits are either above
67 // the mid-way point, or equal to it.
68 //
69 // For example, for 4 truncated bytes, the mask would be b1111
70 // and the midway point would be b1000.
71 let mask: u64 = lower_n_mask(shift as u64);
72 let halfway: u64 = lower_n_halfway(shift as u64);
73
74 let truncated_bits = fp.mant & mask;
75 let is_above = truncated_bits > halfway;
76 let is_halfway = truncated_bits == halfway;
77
78 // Bit shift so the leading bit is in the hidden bit.
79 overflowing_shr(fp, shift);
80
81 (is_above, is_halfway)
82}
83
84// Tie rounded floating point to event.
85#[inline]
86pub(crate) fn tie_even(fp: &mut ExtendedFloat, is_above: bool, is_halfway: bool) {
87 // Extract the last bit after shifting (and determine if it is odd).
88 let is_odd = fp.mant & 1 == 1;
89
90 // Calculate if we need to roundup.
91 // We need to roundup if we are above halfway, or if we are odd
92 // and at half-way (need to tie-to-even).
93 if is_above || (is_odd && is_halfway) {
94 fp.mant += 1;
95 }
96}
97
98// Shift right N-bytes and round nearest, tie-to-even.
99//
100// Floating-point arithmetic uses round to nearest, ties to even,
101// which rounds to the nearest value, if the value is halfway in between,
102// round to an even value.
103#[inline]
104pub(crate) fn round_nearest_tie_even(fp: &mut ExtendedFloat, shift: i32) {
105 let (is_above, is_halfway) = round_nearest(fp, shift);
106 tie_even(fp, is_above, is_halfway);
107}
108
109// DIRECTED ROUNDING
110
111// Shift right N-bytes and round towards a direction.
112//
113// Return if we have any truncated bytes.
114#[inline]
115fn round_toward(fp: &mut ExtendedFloat, shift: i32) -> bool {
116 let mask: u64 = lower_n_mask(shift as u64);
117 let truncated_bits = fp.mant & mask;
118
119 // Bit shift so the leading bit is in the hidden bit.
120 overflowing_shr(fp, shift);
121
122 truncated_bits != 0
123}
124
125// Round down.
126#[inline]
127fn downard(_: &mut ExtendedFloat, _: bool) {}
128
129// Shift right N-bytes and round toward zero.
130//
131// Floating-point arithmetic defines round toward zero, which rounds
132// towards positive zero.
133#[inline]
134pub(crate) fn round_downward(fp: &mut ExtendedFloat, shift: i32) {
135 // Bit shift so the leading bit is in the hidden bit.
136 // No rounding schemes, so we just ignore everything else.
137 let is_truncated = round_toward(fp, shift);
138 downard(fp, is_truncated);
139}
140
141// ROUND TO FLOAT
142
143// Shift the ExtendedFloat fraction to the fraction bits in a native float.
144//
145// Floating-point arithmetic uses round to nearest, ties to even,
146// which rounds to the nearest value, if the value is halfway in between,
147// round to an even value.
148#[inline]
149pub(crate) fn round_to_float<F, Algorithm>(fp: &mut ExtendedFloat, algorithm: Algorithm)
150where
151 F: Float,
152 Algorithm: FnOnce(&mut ExtendedFloat, i32),
153{
154 // Calculate the difference to allow a single calculation
155 // rather than a loop, to minimize the number of ops required.
156 // This does underflow detection.
157 let final_exp = fp.exp + F::DEFAULT_SHIFT;
158 if final_exp < F::DENORMAL_EXPONENT {
159 // We would end up with a denormal exponent, try to round to more
160 // digits. Only shift right if we can avoid zeroing out the value,
161 // which requires the exponent diff to be < M::BITS. The value
162 // is already normalized, so we shouldn't have any issue zeroing
163 // out the value.
164 let diff = F::DENORMAL_EXPONENT - fp.exp;
165 if diff <= u64::FULL {
166 // We can avoid underflow, can get a valid representation.
167 algorithm(fp, diff);
168 } else {
169 // Certain underflow, assign literal 0s.
170 fp.mant = 0;
171 fp.exp = 0;
172 }
173 } else {
174 algorithm(fp, F::DEFAULT_SHIFT);
175 }
176
177 if fp.mant & F::CARRY_MASK == F::CARRY_MASK {
178 // Roundup carried over to 1 past the hidden bit.
179 shr(fp, 1);
180 }
181}
182
183// AVOID OVERFLOW/UNDERFLOW
184
185// Avoid overflow for large values, shift left as needed.
186//
187// Shift until a 1-bit is in the hidden bit, if the mantissa is not 0.
188#[inline]
189pub(crate) fn avoid_overflow<F>(fp: &mut ExtendedFloat)
190where
191 F: Float,
192{
193 // Calculate the difference to allow a single calculation
194 // rather than a loop, minimizing the number of ops required.
195 if fp.exp >= F::MAX_EXPONENT {
196 let diff = fp.exp - F::MAX_EXPONENT;
197 if diff <= F::MANTISSA_SIZE {
198 // Our overflow mask needs to start at the hidden bit, or at
199 // `F::MANTISSA_SIZE+1`, and needs to have `diff+1` bits set,
200 // to see if our value overflows.
201 let bit = (F::MANTISSA_SIZE + 1) as u64;
202 let n = (diff + 1) as u64;
203 let mask = internal_n_mask(bit, n);
204 if (fp.mant & mask) == 0 {
205 // If we have no 1-bit in the hidden-bit position,
206 // which is index 0, we need to shift 1.
207 let shift = diff + 1;
208 shl(fp, shift);
209 }
210 }
211 }
212}
213
214// ROUND TO NATIVE
215
216// Round an extended-precision float to a native float representation.
217#[inline]
218pub(crate) fn round_to_native<F, Algorithm>(fp: &mut ExtendedFloat, algorithm: Algorithm)
219where
220 F: Float,
221 Algorithm: FnOnce(&mut ExtendedFloat, i32),
222{
223 // Shift all the way left, to ensure a consistent representation.
224 // The following right-shifts do not work for a non-normalized number.
225 fp.normalize();
226
227 // Round so the fraction is in a native mantissa representation,
228 // and avoid overflow/underflow.
229 round_to_float::<F, _>(fp, algorithm);
230 avoid_overflow::<F>(fp);
231}
232