use core::convert::TryInto;

pub use super::bit_reader::GetBitsError;
use crate::io::Read;

/// Zstandard encodes some types of data in a way that requires it to be read
/// back to front to decode it properly. `BitReaderReversed` provides a
/// convenient interface to do that.
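///
/// A rough usage sketch (illustrative only; the exact import path depends on
/// how this module is re-exported by the crate):
///
/// ```ignore
/// let data = [0xAB, 0xCD];
/// let mut br = BitReaderReversed::new(&data);
/// // The last byte of the source is consumed first.
/// assert_eq!(br.get_bits(8), 0xCD);
/// assert_eq!(br.get_bits(8), 0xAB);
/// assert_eq!(br.bits_remaining(), 0);
/// ```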
pub struct BitReaderReversed<'s> {
    /// Counts the bits in the source that have not yet been loaded into the
    /// container; goes negative once reads past the end of the stream are emulated.
    idx: isize,
    source: &'s [u8],
    /// The reader doesn't read directly from the source,
    /// it reads bits from here, and the container is
    /// "refilled" as it's emptied.
    bit_container: u64,
    bits_in_container: u8,
}

impl<'s> BitReaderReversed<'s> {
    /// How many bits are left to read by the reader.
    pub fn bits_remaining(&self) -> isize {
        self.idx + self.bits_in_container as isize
    }

    pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            idx: source.len() as isize * 8,
            source,
            bit_container: 0,
            bits_in_container: 0,
        }
    }

    /// We refill the container in full bytes, shifting the still unread portion
    /// to the left and filling the lower bits with new data.
    #[inline(always)]
    fn refill_container(&mut self) {
        let byte_idx = self.byte_idx() as usize;

        let retain_bytes = (self.bits_in_container + 7) / 8;
        let want_to_read_bits = 64 - (retain_bytes * 8);

        // If there are >= 8 bytes left to read we take the fast path.
        // The slice looks something like this: |U..UCCCCCCCCR..R|, where U are still unread bytes,
        // C are the bytes currently in the container, and R are already read bytes.
        // We shift the container a few bytes to the left by simply reading a u64 from the correct
        // position, rereading the portion of the container we have not yet returned.
        // Technically this would still work for positions lower than 8, but requiring >= 8
        // guarantees that enough bytes are in the source and leaves fewer edge cases.
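        // A worked example (illustrative numbers, not taken from a real stream): with a 32-byte
        // source, 10 bits already buffered, and idx == 184, we get byte_idx == 22,
        // retain_bytes == 2, and want_to_read_bits == 48; the fast path then reloads the u64 at
        // bytes 17..=24, re-reading the 2 partially consumed container bytes and pulling in
        // 6 fresh bytes below them.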
        if byte_idx >= 8 {
            self.refill_fast(byte_idx, retain_bytes, want_to_read_bits)
        } else {
            // In the slow path we just read however many bytes we can
            self.refill_slow(byte_idx, want_to_read_bits)
        }
    }

    #[inline(always)]
    fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) {
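        // Pick the 8-byte window whose top `retain_bytes` bytes are the partially consumed
        // bytes already in the container and whose remaining bytes are fresh data just below
        // them; reading it as a little-endian u64 rebuilds the container in one load.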
        let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize;
        let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8])
            .try_into()
            .unwrap();
        let refill = u64::from_le_bytes(tmp_bytes);
        self.bit_container = refill;
        self.bits_in_container += want_to_read_bits;
        self.idx -= want_to_read_bits as isize;
    }

    #[cold]
    fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) {
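        // Near the front of the source there may be fewer than 8 bytes left, so read only as
        // many whole bytes as remain and splice them in below the bits already in the container.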
        let can_read_bits = isize::min(want_to_read_bits as isize, self.idx);
        let can_read_bytes = can_read_bits / 8;
        let mut tmp_bytes = [0u8; 8];
        let offset @ 1..=8 = can_read_bytes as usize else {
            unreachable!()
        };
        let bits_read = offset * 8;

        let _ = (&self.source[byte_idx - (offset - 1)..]).read_exact(&mut tmp_bytes[0..offset]);
        self.bits_in_container += bits_read as u8;
        self.idx -= bits_read as isize;
        if offset < 8 {
            self.bit_container <<= bits_read;
            self.bit_container |= u64::from_le_bytes(tmp_bytes);
        } else {
            self.bit_container = u64::from_le_bytes(tmp_bytes);
        }
    }

    /// Next byte that should be read into the container.
    /// Negative values mean that the source buffer has been read into the container completely.
    fn byte_idx(&self) -> isize {
        (self.idx - 1) / 8
    }

    /// Read `n` bits from the source. Will read at most 56 bits.
    /// If there are no more bits to be read from the source, zero bits will be returned instead.
    #[inline(always)]
    pub fn get_bits(&mut self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }
        if self.bits_in_container >= n {
            return self.get_bits_unchecked(n);
        }

        self.get_bits_cold(n)
    }

    #[cold]
    fn get_bits_cold(&mut self, n: u8) -> u64 {
        let n = u8::min(n, 56);
        let signed_n = n as isize;

        if self.bits_remaining() <= 0 {
            self.idx -= signed_n;
            return 0;
        }

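        // Fewer bits remain than were requested: return what is left, shifted into the high
        // positions as if the missing low bits had been read as zeroes, and let `idx` go
        // negative so `bits_remaining` reflects the overshoot.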
        if self.bits_remaining() < signed_n {
            let emulated_read_shift = signed_n - self.bits_remaining();
            let v = self.get_bits(self.bits_remaining() as u8);
            debug_assert!(self.idx == 0);
            let value = v.wrapping_shl(emulated_read_shift as u32);
            self.idx -= emulated_read_shift;
            return value;
        }

        while (self.bits_in_container < n) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= n);

        // If we reach this point there are enough bits in the container.

        self.get_bits_unchecked(n)
    }

    /// Same as calling `get_bits` three times, but slightly more performant.
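    /// The values are returned in argument order, as if produced by
    /// `(self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))`.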
    #[inline(always)]
    pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 as usize + n2 as usize + n3 as usize;
        if sum == 0 {
            return (0, 0, 0);
        }
        if sum > 56 {
            // the sum is too large for the bit container, so get the values separately
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }
        let sum = sum as u8;

        if self.bits_in_container >= sum {
            let v1 = if n1 == 0 {
                0
            } else {
                self.get_bits_unchecked(n1)
            };
            let v2 = if n2 == 0 {
                0
            } else {
                self.get_bits_unchecked(n2)
            };
            let v3 = if n3 == 0 {
                0
            } else {
                self.get_bits_unchecked(n3)
            };

            return (v1, v2, v3);
        }

        self.get_bits_triple_cold(n1, n2, n3, sum)
    }

    #[cold]
    fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
        let sum_signed = sum as isize;

        if self.bits_remaining() <= 0 {
            self.idx -= sum_signed;
            return (0, 0, 0);
        }

        if self.bits_remaining() < sum_signed {
            return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
        }

        while (self.bits_in_container < sum) && self.idx > 0 {
            self.refill_container();
        }

        debug_assert!(self.bits_in_container >= sum);

        // If we reach this point there are enough bits in the container.

        let v1 = if n1 == 0 {
            0
        } else {
            self.get_bits_unchecked(n1)
        };
        let v2 = if n2 == 0 {
            0
        } else {
            self.get_bits_unchecked(n2)
        };
        let v3 = if n3 == 0 {
            0
        } else {
            self.get_bits_unchecked(n3)
        };

        (v1, v2, v3)
    }

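    /// Take the top `n` valid bits out of the container. The caller must make sure that at
    /// least `n` bits are currently buffered (e.g. with `bit_container == 0xCDAB` and
    /// `bits_in_container == 16`, `get_bits_unchecked(8)` yields `0xCD` and leaves 8 bits).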
    #[inline(always)]
    fn get_bits_unchecked(&mut self, n: u8) -> u64 {
        let shift_by = self.bits_in_container - n;
        let mask = (1u64 << n) - 1u64;

        let value = self.bit_container >> shift_by;
        self.bits_in_container -= n;
        let value_masked = value & mask;
        debug_assert!(value_masked < (1 << n));

        value_masked
    }

    pub fn reset(&mut self, new_source: &'s [u8]) {
        self.idx = new_source.len() as isize * 8;
        self.source = new_source;
        self.bit_container = 0;
        self.bits_in_container = 0;
    }
}

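// A minimal sanity-check sketch of the reversed reading order. The expected values follow from
// the container/shift logic above; the test names and cases are illustrative additions, not part
// of the original module.
#[cfg(test)]
mod tests {
    use super::BitReaderReversed;

    #[test]
    fn reads_back_to_front() {
        let data = [0xAB, 0xCD];
        let mut br = BitReaderReversed::new(&data);
        // The last byte of the source is consumed first.
        assert_eq!(br.get_bits(8), 0xCD);
        assert_eq!(br.get_bits(8), 0xAB);
        assert_eq!(br.bits_remaining(), 0);
    }

    #[test]
    fn over_reads_return_zero_bits() {
        let data = [0x80];
        let mut br = BitReaderReversed::new(&data);
        // The single set bit is the most significant bit of the last (and only) byte.
        assert_eq!(br.get_bits(1), 1);
        assert_eq!(br.get_bits(7), 0);
        // Past the end of the stream, zero bits are returned and `bits_remaining` goes negative.
        assert_eq!(br.get_bits(4), 0);
        assert!(br.bits_remaining() < 0);
    }
}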