1 | use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; |
2 | use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError}; |
3 | use alloc::vec::Vec; |
4 | |
5 | pub struct HuffmanTable { |
6 | decode: Vec<Entry>, |
7 | |
8 | weights: Vec<u8>, |
9 | pub max_num_bits: u8, |
10 | bits: Vec<u8>, |
11 | bit_ranks: Vec<u32>, |
12 | rank_indexes: Vec<usize>, |
13 | |
14 | fse_table: FSETable, |
15 | } |
16 | |
17 | #[derive (Debug, derive_more::Display, derive_more::From)] |
18 | #[cfg_attr (feature = "std" , derive(derive_more::Error))] |
19 | #[non_exhaustive ] |
20 | pub enum HuffmanTableError { |
21 | #[display(fmt = "{_0:?}" )] |
22 | #[from] |
23 | GetBitsError(GetBitsError), |
24 | #[display(fmt = "{_0:?}" )] |
25 | #[from] |
26 | FSEDecoderError(FSEDecoderError), |
27 | #[display(fmt = "{_0:?}" )] |
28 | #[from] |
29 | FSETableError(FSETableError), |
30 | #[display(fmt = "Source needs to have at least one byte" )] |
31 | SourceIsEmpty, |
32 | #[display( |
33 | fmt = "Header says there should be {expected_bytes} bytes for the weights but there are only {got_bytes} bytes in the stream" |
34 | )] |
35 | NotEnoughBytesForWeights { |
36 | got_bytes: usize, |
37 | expected_bytes: u8, |
38 | }, |
39 | #[display( |
40 | fmt = "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption" |
41 | )] |
42 | ExtraPadding { skipped_bits: i32 }, |
43 | #[display( |
44 | fmt = "More than 255 weights decoded (got {got} weights). Stream is probably corrupted" |
45 | )] |
46 | TooManyWeights { got: usize }, |
47 | #[display(fmt = "Can't build huffman table without any weights" )] |
48 | MissingWeights, |
49 | #[display(fmt = "Leftover must be power of two but is: {got}" )] |
50 | LeftoverIsNotAPowerOf2 { got: u32 }, |
51 | #[display( |
52 | fmt = "Not enough bytes in stream to decompress weights. Is: {have}, Should be: {need}" |
53 | )] |
54 | NotEnoughBytesToDecompressWeights { have: usize, need: usize }, |
55 | #[display( |
56 | fmt = "FSE table used more bytes: {used} than were meant to be used for the whole stream of huffman weights ({available_bytes})" |
57 | )] |
58 | FSETableUsedTooManyBytes { used: usize, available_bytes: u8 }, |
59 | #[display(fmt = "Source needs to have at least {need} bytes, got: {got}" )] |
60 | NotEnoughBytesInSource { got: usize, need: usize }, |
61 | #[display(fmt = "Cant have weight: {got} bigger than max_num_bits: {MAX_MAX_NUM_BITS}" )] |
62 | WeightBiggerThanMaxNumBits { got: u8 }, |
63 | #[display( |
64 | fmt = "max_bits derived from weights is: {got} should be lower than: {MAX_MAX_NUM_BITS}" |
65 | )] |
66 | MaxBitsTooHigh { got: u8 }, |
67 | } |
68 | |
69 | pub struct HuffmanDecoder<'table> { |
70 | table: &'table HuffmanTable, |
71 | pub state: u64, |
72 | } |
73 | |
74 | #[derive (Debug, derive_more::Display, derive_more::From)] |
75 | #[cfg_attr (feature = "std" , derive(derive_more::Error))] |
76 | #[non_exhaustive ] |
77 | pub enum HuffmanDecoderError { |
78 | #[display(fmt = "{_0:?}" )] |
79 | #[from] |
80 | GetBitsError(GetBitsError), |
81 | } |
82 | |
/// One slot of the Huffman decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// Symbol emitted when the decoder state points at this slot.
    symbol: u8,
    /// Code length of `symbol`; also the number of fresh bits pulled from the
    /// stream when the decoder moves on from this slot.
    num_bits: u8,
}
88 | |
/// Upper bound for a single weight and for the code length derived from the
/// weights; tables that exceed it are rejected while building.
const MAX_MAX_NUM_BITS: u8 = 11;
90 | |
/// Returns the 1-based position of the most significant set bit of `x`
/// (for example `1 -> 1`, `3 -> 2`, `8 -> 4`).
///
/// # Panics
/// Panics if `x` is zero, because no bit is set then.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
95 | |
96 | impl<'t> HuffmanDecoder<'t> { |
97 | pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> { |
98 | HuffmanDecoder { table, state: 0 } |
99 | } |
100 | |
101 | pub fn reset(mut self, new_table: Option<&'t HuffmanTable>) { |
102 | self.state = 0; |
103 | if let Some(next_table) = new_table { |
104 | self.table = next_table; |
105 | } |
106 | } |
107 | |
108 | pub fn decode_symbol(&mut self) -> u8 { |
109 | self.table.decode[self.state as usize].symbol |
110 | } |
111 | |
112 | pub fn init_state( |
113 | &mut self, |
114 | br: &mut BitReaderReversed<'_>, |
115 | ) -> Result<u8, HuffmanDecoderError> { |
116 | let num_bits = self.table.max_num_bits; |
117 | let new_bits = br.get_bits(num_bits)?; |
118 | self.state = new_bits; |
119 | Ok(num_bits) |
120 | } |
121 | |
122 | pub fn next_state( |
123 | &mut self, |
124 | br: &mut BitReaderReversed<'_>, |
125 | ) -> Result<u8, HuffmanDecoderError> { |
126 | let num_bits = self.table.decode[self.state as usize].num_bits; |
127 | let new_bits = br.get_bits(num_bits)?; |
128 | self.state <<= num_bits; |
129 | self.state &= self.table.decode.len() as u64 - 1; |
130 | self.state |= new_bits; |
131 | Ok(num_bits) |
132 | } |
133 | } |
134 | |
135 | impl Default for HuffmanTable { |
136 | fn default() -> Self { |
137 | Self::new() |
138 | } |
139 | } |
140 | |
141 | impl HuffmanTable { |
142 | pub fn new() -> HuffmanTable { |
143 | HuffmanTable { |
144 | decode: Vec::new(), |
145 | |
146 | weights: Vec::with_capacity(256), |
147 | max_num_bits: 0, |
148 | bits: Vec::with_capacity(256), |
149 | bit_ranks: Vec::with_capacity(11), |
150 | rank_indexes: Vec::with_capacity(11), |
151 | fse_table: FSETable::new(), |
152 | } |
153 | } |
154 | |
155 | pub fn reinit_from(&mut self, other: &Self) { |
156 | self.reset(); |
157 | self.decode.extend_from_slice(&other.decode); |
158 | self.weights.extend_from_slice(&other.weights); |
159 | self.max_num_bits = other.max_num_bits; |
160 | self.bits.extend_from_slice(&other.bits); |
161 | self.rank_indexes.extend_from_slice(&other.rank_indexes); |
162 | self.fse_table.reinit_from(&other.fse_table); |
163 | } |
164 | |
165 | pub fn reset(&mut self) { |
166 | self.decode.clear(); |
167 | self.weights.clear(); |
168 | self.max_num_bits = 0; |
169 | self.bits.clear(); |
170 | self.bit_ranks.clear(); |
171 | self.rank_indexes.clear(); |
172 | self.fse_table.reset(); |
173 | } |
174 | |
175 | pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> { |
176 | self.decode.clear(); |
177 | |
178 | let bytes_used = self.read_weights(source)?; |
179 | self.build_table_from_weights()?; |
180 | Ok(bytes_used) |
181 | } |
182 | |
183 | fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> { |
184 | use HuffmanTableError as err; |
185 | |
186 | if source.is_empty() { |
187 | return Err(err::SourceIsEmpty); |
188 | } |
189 | let header = source[0]; |
190 | let mut bits_read = 8; |
191 | |
192 | match header { |
193 | 0..=127 => { |
194 | let fse_stream = &source[1..]; |
195 | if header as usize > fse_stream.len() { |
196 | return Err(err::NotEnoughBytesForWeights { |
197 | got_bytes: fse_stream.len(), |
198 | expected_bytes: header, |
199 | }); |
200 | } |
201 | //fse decompress weights |
202 | let bytes_used_by_fse_header = self |
203 | .fse_table |
204 | .build_decoder(fse_stream, /*TODO find actual max*/ 100)?; |
205 | |
206 | if bytes_used_by_fse_header > header as usize { |
207 | return Err(err::FSETableUsedTooManyBytes { |
208 | used: bytes_used_by_fse_header, |
209 | available_bytes: header, |
210 | }); |
211 | } |
212 | |
213 | vprintln!( |
214 | "Building fse table for huffman weights used: {}" , |
215 | bytes_used_by_fse_header |
216 | ); |
217 | let mut dec1 = FSEDecoder::new(&self.fse_table); |
218 | let mut dec2 = FSEDecoder::new(&self.fse_table); |
219 | |
220 | let compressed_start = bytes_used_by_fse_header; |
221 | let compressed_length = header as usize - bytes_used_by_fse_header; |
222 | |
223 | let compressed_weights = &fse_stream[compressed_start..]; |
224 | if compressed_weights.len() < compressed_length { |
225 | return Err(err::NotEnoughBytesToDecompressWeights { |
226 | have: compressed_weights.len(), |
227 | need: compressed_length, |
228 | }); |
229 | } |
230 | let compressed_weights = &compressed_weights[..compressed_length]; |
231 | let mut br = BitReaderReversed::new(compressed_weights); |
232 | |
233 | bits_read += (bytes_used_by_fse_header + compressed_length) * 8; |
234 | |
235 | //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found |
236 | let mut skipped_bits = 0; |
237 | loop { |
238 | let val = br.get_bits(1)?; |
239 | skipped_bits += 1; |
240 | if val == 1 || skipped_bits > 8 { |
241 | break; |
242 | } |
243 | } |
244 | if skipped_bits > 8 { |
245 | //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data |
246 | return Err(err::ExtraPadding { skipped_bits }); |
247 | } |
248 | |
249 | dec1.init_state(&mut br)?; |
250 | dec2.init_state(&mut br)?; |
251 | |
252 | self.weights.clear(); |
253 | |
254 | loop { |
255 | let w = dec1.decode_symbol(); |
256 | self.weights.push(w); |
257 | dec1.update_state(&mut br)?; |
258 | |
259 | if br.bits_remaining() <= -1 { |
260 | //collect final states |
261 | self.weights.push(dec2.decode_symbol()); |
262 | break; |
263 | } |
264 | |
265 | let w = dec2.decode_symbol(); |
266 | self.weights.push(w); |
267 | dec2.update_state(&mut br)?; |
268 | |
269 | if br.bits_remaining() <= -1 { |
270 | //collect final states |
271 | self.weights.push(dec1.decode_symbol()); |
272 | break; |
273 | } |
274 | //maximum number of weights is 255 because we use u8 symbols and the last weight is inferred from the sum of all others |
275 | if self.weights.len() > 255 { |
276 | return Err(err::TooManyWeights { |
277 | got: self.weights.len(), |
278 | }); |
279 | } |
280 | } |
281 | } |
282 | _ => { |
283 | // weights are directly encoded |
284 | let weights_raw = &source[1..]; |
285 | let num_weights = header - 127; |
286 | self.weights.resize(num_weights as usize, 0); |
287 | |
288 | let bytes_needed = if num_weights % 2 == 0 { |
289 | num_weights as usize / 2 |
290 | } else { |
291 | (num_weights as usize / 2) + 1 |
292 | }; |
293 | |
294 | if weights_raw.len() < bytes_needed { |
295 | return Err(err::NotEnoughBytesInSource { |
296 | got: weights_raw.len(), |
297 | need: bytes_needed, |
298 | }); |
299 | } |
300 | |
301 | for idx in 0..num_weights { |
302 | if idx % 2 == 0 { |
303 | self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4; |
304 | } else { |
305 | self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF; |
306 | } |
307 | bits_read += 4; |
308 | } |
309 | } |
310 | } |
311 | |
312 | let bytes_read = if bits_read % 8 == 0 { |
313 | bits_read / 8 |
314 | } else { |
315 | (bits_read / 8) + 1 |
316 | }; |
317 | Ok(bytes_read as u32) |
318 | } |
319 | |
320 | fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> { |
321 | use HuffmanTableError as err; |
322 | |
323 | self.bits.clear(); |
324 | self.bits.resize(self.weights.len() + 1, 0); |
325 | |
326 | let mut weight_sum: u32 = 0; |
327 | for w in &self.weights { |
328 | if *w > MAX_MAX_NUM_BITS { |
329 | return Err(err::WeightBiggerThanMaxNumBits { got: *w }); |
330 | } |
331 | weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 }; |
332 | } |
333 | |
334 | if weight_sum == 0 { |
335 | return Err(err::MissingWeights); |
336 | } |
337 | |
338 | let max_bits = highest_bit_set(weight_sum) as u8; |
339 | let left_over = (1 << max_bits) - weight_sum; |
340 | |
341 | //left_over must be power of two |
342 | if !left_over.is_power_of_two() { |
343 | return Err(err::LeftoverIsNotAPowerOf2 { got: left_over }); |
344 | } |
345 | |
346 | let last_weight = highest_bit_set(left_over) as u8; |
347 | |
348 | for symbol in 0..self.weights.len() { |
349 | let bits = if self.weights[symbol] > 0 { |
350 | max_bits + 1 - self.weights[symbol] |
351 | } else { |
352 | 0 |
353 | }; |
354 | self.bits[symbol] = bits; |
355 | } |
356 | |
357 | self.bits[self.weights.len()] = max_bits + 1 - last_weight; |
358 | self.max_num_bits = max_bits; |
359 | |
360 | if max_bits > MAX_MAX_NUM_BITS { |
361 | return Err(err::MaxBitsTooHigh { got: max_bits }); |
362 | } |
363 | |
364 | self.bit_ranks.clear(); |
365 | self.bit_ranks.resize((max_bits + 1) as usize, 0); |
366 | for num_bits in &self.bits { |
367 | self.bit_ranks[(*num_bits) as usize] += 1; |
368 | } |
369 | |
370 | //fill with dummy symbols |
371 | self.decode.resize( |
372 | 1 << self.max_num_bits, |
373 | Entry { |
374 | symbol: 0, |
375 | num_bits: 0, |
376 | }, |
377 | ); |
378 | |
379 | //starting codes for each rank |
380 | self.rank_indexes.clear(); |
381 | self.rank_indexes.resize((max_bits + 1) as usize, 0); |
382 | |
383 | self.rank_indexes[max_bits as usize] = 0; |
384 | for bits in (1..self.rank_indexes.len() as u8).rev() { |
385 | self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize] |
386 | + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits)); |
387 | } |
388 | |
389 | assert!( |
390 | self.rank_indexes[0] == self.decode.len(), |
391 | "rank_idx[0]: {} should be: {}" , |
392 | self.rank_indexes[0], |
393 | self.decode.len() |
394 | ); |
395 | |
396 | for symbol in 0..self.bits.len() { |
397 | let bits_for_symbol = self.bits[symbol]; |
398 | if bits_for_symbol != 0 { |
399 | // allocate code for the symbol and set in the table |
400 | // a code ignores all max_bits - bits[symbol] bits, so it gets |
401 | // a range that spans all of those in the decoding table |
402 | let base_idx = self.rank_indexes[bits_for_symbol as usize]; |
403 | let len = 1 << (max_bits - bits_for_symbol); |
404 | self.rank_indexes[bits_for_symbol as usize] += len; |
405 | for idx in 0..len { |
406 | self.decode[base_idx + idx].symbol = symbol as u8; |
407 | self.decode[base_idx + idx].num_bits = bits_for_symbol; |
408 | } |
409 | } |
410 | } |
411 | |
412 | Ok(()) |
413 | } |
414 | } |
415 | |