1 | //! Utilities for decoding Huff0 encoded huffman data. |
2 | |
3 | use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; |
4 | use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError}; |
5 | use alloc::vec::Vec; |
6 | #[cfg (feature = "std" )] |
7 | use std::error::Error as StdError; |
8 | |
9 | pub struct HuffmanTable { |
10 | decode: Vec<Entry>, |
11 | /// The weight of a symbol is the number of occurences in a table. |
12 | /// This value is used in constructing a binary tree referred to as |
13 | /// a huffman tree. |
14 | weights: Vec<u8>, |
15 | /// The maximum size in bits a prefix code in the encoded data can be. |
16 | /// This value is used so that the decoder knows how many bits |
17 | /// to read from the bitstream before checking the table. This |
18 | /// value must be 11 or lower. |
19 | pub max_num_bits: u8, |
20 | bits: Vec<u8>, |
21 | bit_ranks: Vec<u32>, |
22 | rank_indexes: Vec<usize>, |
23 | /// In some cases, the list of weights is compressed using FSE compression. |
24 | fse_table: FSETable, |
25 | } |
26 | |
27 | #[derive (Debug)] |
28 | #[non_exhaustive ] |
29 | pub enum HuffmanTableError { |
30 | GetBitsError(GetBitsError), |
31 | FSEDecoderError(FSEDecoderError), |
32 | FSETableError(FSETableError), |
33 | SourceIsEmpty, |
34 | NotEnoughBytesForWeights { |
35 | got_bytes: usize, |
36 | expected_bytes: u8, |
37 | }, |
38 | ExtraPadding { |
39 | skipped_bits: i32, |
40 | }, |
41 | TooManyWeights { |
42 | got: usize, |
43 | }, |
44 | MissingWeights, |
45 | LeftoverIsNotAPowerOf2 { |
46 | got: u32, |
47 | }, |
48 | NotEnoughBytesToDecompressWeights { |
49 | have: usize, |
50 | need: usize, |
51 | }, |
52 | FSETableUsedTooManyBytes { |
53 | used: usize, |
54 | available_bytes: u8, |
55 | }, |
56 | NotEnoughBytesInSource { |
57 | got: usize, |
58 | need: usize, |
59 | }, |
60 | WeightBiggerThanMaxNumBits { |
61 | got: u8, |
62 | }, |
63 | MaxBitsTooHigh { |
64 | got: u8, |
65 | }, |
66 | } |
67 | |
#[cfg(feature = "std")]
impl StdError for HuffmanTableError {
    /// Expose the wrapped lower-level error for the variants that carry one.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let inner: &(dyn StdError + 'static) = match self {
            Self::GetBitsError(wrapped) => wrapped,
            Self::FSEDecoderError(wrapped) => wrapped,
            Self::FSETableError(wrapped) => wrapped,
            _ => return None,
        };
        Some(inner)
    }
}
79 | |
80 | impl core::fmt::Display for HuffmanTableError { |
81 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> ::core::fmt::Result { |
82 | match self { |
83 | HuffmanTableError::GetBitsError(e) => write!(f, " {:?}" , e), |
84 | HuffmanTableError::FSEDecoderError(e) => write!(f, " {:?}" , e), |
85 | HuffmanTableError::FSETableError(e) => write!(f, " {:?}" , e), |
86 | HuffmanTableError::SourceIsEmpty => write!(f, "Source needs to have at least one byte" ), |
87 | HuffmanTableError::NotEnoughBytesForWeights { |
88 | got_bytes, |
89 | expected_bytes, |
90 | } => { |
91 | write!(f, "Header says there should be {} bytes for the weights but there are only {} bytes in the stream" , |
92 | expected_bytes, |
93 | got_bytes) |
94 | } |
95 | HuffmanTableError::ExtraPadding { skipped_bits } => { |
96 | write!(f, |
97 | "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption" , |
98 | skipped_bits, |
99 | ) |
100 | } |
101 | HuffmanTableError::TooManyWeights { got } => { |
102 | write!( |
103 | f, |
104 | "More than 255 weights decoded (got {} weights). Stream is probably corrupted" , |
105 | got, |
106 | ) |
107 | } |
108 | HuffmanTableError::MissingWeights => { |
109 | write!(f, "Can \'t build huffman table without any weights" ) |
110 | } |
111 | HuffmanTableError::LeftoverIsNotAPowerOf2 { got } => { |
112 | write!(f, "Leftover must be power of two but is: {}" , got) |
113 | } |
114 | HuffmanTableError::NotEnoughBytesToDecompressWeights { have, need } => { |
115 | write!( |
116 | f, |
117 | "Not enough bytes in stream to decompress weights. Is: {}, Should be: {}" , |
118 | have, need, |
119 | ) |
120 | } |
121 | HuffmanTableError::FSETableUsedTooManyBytes { |
122 | used, |
123 | available_bytes, |
124 | } => { |
125 | write!(f, |
126 | "FSE table used more bytes: {} than were meant to be used for the whole stream of huffman weights ( {})" , |
127 | used, |
128 | available_bytes, |
129 | ) |
130 | } |
131 | HuffmanTableError::NotEnoughBytesInSource { got, need } => { |
132 | write!( |
133 | f, |
134 | "Source needs to have at least {} bytes, got: {}" , |
135 | need, got, |
136 | ) |
137 | } |
138 | HuffmanTableError::WeightBiggerThanMaxNumBits { got } => { |
139 | write!( |
140 | f, |
141 | "Cant have weight: {} bigger than max_num_bits: {}" , |
142 | got, MAX_MAX_NUM_BITS, |
143 | ) |
144 | } |
145 | HuffmanTableError::MaxBitsTooHigh { got } => { |
146 | write!( |
147 | f, |
148 | "max_bits derived from weights is: {} should be lower than: {}" , |
149 | got, MAX_MAX_NUM_BITS, |
150 | ) |
151 | } |
152 | } |
153 | } |
154 | } |
155 | |
156 | impl From<GetBitsError> for HuffmanTableError { |
157 | fn from(val: GetBitsError) -> Self { |
158 | Self::GetBitsError(val) |
159 | } |
160 | } |
161 | |
162 | impl From<FSEDecoderError> for HuffmanTableError { |
163 | fn from(val: FSEDecoderError) -> Self { |
164 | Self::FSEDecoderError(val) |
165 | } |
166 | } |
167 | |
168 | impl From<FSETableError> for HuffmanTableError { |
169 | fn from(val: FSETableError) -> Self { |
170 | Self::FSETableError(val) |
171 | } |
172 | } |
173 | |
174 | /// An interface around a huffman table used to decode data. |
175 | pub struct HuffmanDecoder<'table> { |
176 | table: &'table HuffmanTable, |
177 | /// State is used to index into the table. |
178 | pub state: u64, |
179 | } |
180 | |
181 | #[derive (Debug)] |
182 | #[non_exhaustive ] |
183 | pub enum HuffmanDecoderError { |
184 | GetBitsError(GetBitsError), |
185 | } |
186 | |
187 | impl core::fmt::Display for HuffmanDecoderError { |
188 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
189 | match self { |
190 | HuffmanDecoderError::GetBitsError(e: &GetBitsError) => write!(f, " {:?}" , e), |
191 | } |
192 | } |
193 | } |
194 | |
#[cfg(feature = "std")]
impl StdError for HuffmanDecoderError {
    /// Every variant wraps a lower-level error, which is always exposed here.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let inner: &(dyn StdError + 'static) = match self {
            Self::GetBitsError(wrapped) => wrapped,
        };
        Some(inner)
    }
}
203 | |
204 | impl From<GetBitsError> for HuffmanDecoderError { |
205 | fn from(val: GetBitsError) -> Self { |
206 | Self::GetBitsError(val) |
207 | } |
208 | } |
209 | |
/// One slot of the decoding table. It holds the symbol produced when the decoder
/// state lands on this slot, plus the length of the prefix code that selected it.
#[derive(Clone, Copy)]
pub struct Entry {
    /// The decoded byte (literal) for this slot.
    symbol: u8,
    /// Length in bits of the prefix code for `symbol`.
    num_bits: u8,
}
219 | |
/// Upper bound, in bits, on the length of any Huffman prefix code; the
/// Zstandard format caps codes at 11 bits.
const MAX_MAX_NUM_BITS: u8 = 11;
222 | |
/// Returns the 1-based position of the most significant set bit of `x`,
/// i.e. the number of bits needed to represent `x`.
///
/// # Panics
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let leading = x.leading_zeros();
    u32::BITS - leading
}
229 | |
230 | impl<'t> HuffmanDecoder<'t> { |
231 | /// Create a new decoder with the provided table |
232 | pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> { |
233 | HuffmanDecoder { table, state: 0 } |
234 | } |
235 | |
236 | /// Re-initialize the decoder, using the new table if one is provided. |
237 | /// This might used for treeless blocks, because they re-use the table from old |
238 | /// data. |
239 | pub fn reset(mut self, new_table: Option<&'t HuffmanTable>) { |
240 | self.state = 0; |
241 | if let Some(next_table) = new_table { |
242 | self.table = next_table; |
243 | } |
244 | } |
245 | |
246 | /// Decode the symbol the internal state (cursor) is pointed at and return the |
247 | /// decoded literal. |
248 | pub fn decode_symbol(&mut self) -> u8 { |
249 | self.table.decode[self.state as usize].symbol |
250 | } |
251 | |
252 | /// Initialize internal state and prepare to decode data. Then, `decode_symbol` can be called |
253 | /// to read the byte the internal cursor is pointing at, and `next_state` can be called to advance |
254 | /// the cursor until the max number of bits has been read. |
255 | pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 { |
256 | let num_bits = self.table.max_num_bits; |
257 | let new_bits = br.get_bits(num_bits); |
258 | self.state = new_bits; |
259 | num_bits |
260 | } |
261 | |
262 | /// Advance the internal cursor to the next symbol. After this, you can call `decode_symbol` |
263 | /// to read from the new position. |
264 | pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 { |
265 | // self.state stores a small section, or a window of the bit stream. The table can be indexed via this state, |
266 | // telling you how many bits identify the current symbol. |
267 | let num_bits = self.table.decode[self.state as usize].num_bits; |
268 | // New bits are read from the stream |
269 | let new_bits = br.get_bits(num_bits); |
270 | // Shift and mask out the bits that identify the current symbol |
271 | self.state <<= num_bits; |
272 | self.state &= self.table.decode.len() as u64 - 1; |
273 | // The new bits are appended at the end of the current state. |
274 | self.state |= new_bits; |
275 | num_bits |
276 | } |
277 | } |
278 | |
279 | impl Default for HuffmanTable { |
280 | fn default() -> Self { |
281 | Self::new() |
282 | } |
283 | } |
284 | |
285 | impl HuffmanTable { |
286 | /// Create a new, empty table. |
287 | pub fn new() -> HuffmanTable { |
288 | HuffmanTable { |
289 | decode: Vec::new(), |
290 | |
291 | weights: Vec::with_capacity(256), |
292 | max_num_bits: 0, |
293 | bits: Vec::with_capacity(256), |
294 | bit_ranks: Vec::with_capacity(11), |
295 | rank_indexes: Vec::with_capacity(11), |
296 | fse_table: FSETable::new(100), |
297 | } |
298 | } |
299 | |
300 | /// Completely empty the table then repopulate as a replica |
301 | /// of `other`. |
302 | pub fn reinit_from(&mut self, other: &Self) { |
303 | self.reset(); |
304 | self.decode.extend_from_slice(&other.decode); |
305 | self.weights.extend_from_slice(&other.weights); |
306 | self.max_num_bits = other.max_num_bits; |
307 | self.bits.extend_from_slice(&other.bits); |
308 | self.rank_indexes.extend_from_slice(&other.rank_indexes); |
309 | self.fse_table.reinit_from(&other.fse_table); |
310 | } |
311 | |
312 | /// Completely empty the table of all data. |
313 | pub fn reset(&mut self) { |
314 | self.decode.clear(); |
315 | self.weights.clear(); |
316 | self.max_num_bits = 0; |
317 | self.bits.clear(); |
318 | self.bit_ranks.clear(); |
319 | self.rank_indexes.clear(); |
320 | self.fse_table.reset(); |
321 | } |
322 | |
323 | /// Read from `source` and parse it into a huffman table. |
324 | /// |
325 | /// Returns the number of bytes read. |
326 | pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> { |
327 | self.decode.clear(); |
328 | |
329 | let bytes_used = self.read_weights(source)?; |
330 | self.build_table_from_weights()?; |
331 | Ok(bytes_used) |
332 | } |
333 | |
334 | /// Read weights from the provided source. |
335 | /// |
336 | /// The huffman table is represented in the encoded data as a list of weights |
337 | /// at the most basic level. After the header, weights are read, then the table |
338 | /// can be built using that list of weights. |
339 | /// |
340 | /// Returns the number of bytes read. |
341 | fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> { |
342 | use HuffmanTableError as err; |
343 | |
344 | if source.is_empty() { |
345 | return Err(err::SourceIsEmpty); |
346 | } |
347 | let header = source[0]; |
348 | let mut bits_read = 8; |
349 | |
350 | match header { |
351 | // If the header byte is less than 128, the series of weights |
352 | // is compressed using two interleaved FSE streams that share |
353 | // a distribution table. |
354 | 0..=127 => { |
355 | let fse_stream = &source[1..]; |
356 | if header as usize > fse_stream.len() { |
357 | return Err(err::NotEnoughBytesForWeights { |
358 | got_bytes: fse_stream.len(), |
359 | expected_bytes: header, |
360 | }); |
361 | } |
362 | //fse decompress weights |
363 | let bytes_used_by_fse_header = self |
364 | .fse_table |
365 | .build_decoder(fse_stream, /*TODO find actual max*/ 100)?; |
366 | |
367 | if bytes_used_by_fse_header > header as usize { |
368 | return Err(err::FSETableUsedTooManyBytes { |
369 | used: bytes_used_by_fse_header, |
370 | available_bytes: header, |
371 | }); |
372 | } |
373 | |
374 | vprintln!( |
375 | "Building fse table for huffman weights used: {}" , |
376 | bytes_used_by_fse_header |
377 | ); |
378 | // Huffman headers are compressed using two interleaved |
379 | // FSE bitstreams, where the first state (decoder) handles |
380 | // even symbols, and the second handles odd symbols. |
381 | let mut dec1 = FSEDecoder::new(&self.fse_table); |
382 | let mut dec2 = FSEDecoder::new(&self.fse_table); |
383 | |
384 | let compressed_start = bytes_used_by_fse_header; |
385 | let compressed_length = header as usize - bytes_used_by_fse_header; |
386 | |
387 | let compressed_weights = &fse_stream[compressed_start..]; |
388 | if compressed_weights.len() < compressed_length { |
389 | return Err(err::NotEnoughBytesToDecompressWeights { |
390 | have: compressed_weights.len(), |
391 | need: compressed_length, |
392 | }); |
393 | } |
394 | let compressed_weights = &compressed_weights[..compressed_length]; |
395 | let mut br = BitReaderReversed::new(compressed_weights); |
396 | |
397 | bits_read += (bytes_used_by_fse_header + compressed_length) * 8; |
398 | |
399 | //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found |
400 | let mut skipped_bits = 0; |
401 | loop { |
402 | let val = br.get_bits(1); |
403 | skipped_bits += 1; |
404 | if val == 1 || skipped_bits > 8 { |
405 | break; |
406 | } |
407 | } |
408 | if skipped_bits > 8 { |
409 | //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data |
410 | return Err(err::ExtraPadding { skipped_bits }); |
411 | } |
412 | |
413 | dec1.init_state(&mut br)?; |
414 | dec2.init_state(&mut br)?; |
415 | |
416 | self.weights.clear(); |
417 | |
418 | // The two decoders take turns decoding a single symbol and updating their state. |
419 | loop { |
420 | let w = dec1.decode_symbol(); |
421 | self.weights.push(w); |
422 | dec1.update_state(&mut br); |
423 | |
424 | if br.bits_remaining() <= -1 { |
425 | //collect final states |
426 | self.weights.push(dec2.decode_symbol()); |
427 | break; |
428 | } |
429 | |
430 | let w = dec2.decode_symbol(); |
431 | self.weights.push(w); |
432 | dec2.update_state(&mut br); |
433 | |
434 | if br.bits_remaining() <= -1 { |
435 | //collect final states |
436 | self.weights.push(dec1.decode_symbol()); |
437 | break; |
438 | } |
439 | //maximum number of weights is 255 because we use u8 symbols and the last weight is inferred from the sum of all others |
440 | if self.weights.len() > 255 { |
441 | return Err(err::TooManyWeights { |
442 | got: self.weights.len(), |
443 | }); |
444 | } |
445 | } |
446 | } |
447 | // If the header byte is greater than or equal to 128, |
448 | // weights are directly represented, where each weight is |
449 | // encoded directly as a 4 bit field. The weights will |
450 | // always be encoded with full bytes, meaning if there's |
451 | // an odd number of weights, the last weight will still |
452 | // occupy a full byte. |
453 | _ => { |
454 | // weights are directly encoded |
455 | let weights_raw = &source[1..]; |
456 | let num_weights = header - 127; |
457 | self.weights.resize(num_weights as usize, 0); |
458 | |
459 | let bytes_needed = if num_weights % 2 == 0 { |
460 | num_weights as usize / 2 |
461 | } else { |
462 | (num_weights as usize / 2) + 1 |
463 | }; |
464 | |
465 | if weights_raw.len() < bytes_needed { |
466 | return Err(err::NotEnoughBytesInSource { |
467 | got: weights_raw.len(), |
468 | need: bytes_needed, |
469 | }); |
470 | } |
471 | |
472 | for idx in 0..num_weights { |
473 | if idx % 2 == 0 { |
474 | self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4; |
475 | } else { |
476 | self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF; |
477 | } |
478 | bits_read += 4; |
479 | } |
480 | } |
481 | } |
482 | |
483 | let bytes_read = if bits_read % 8 == 0 { |
484 | bits_read / 8 |
485 | } else { |
486 | (bits_read / 8) + 1 |
487 | }; |
488 | Ok(bytes_read as u32) |
489 | } |
490 | |
491 | /// Once the weights have been read from the data, you can decode the weights |
492 | /// into a table, and use that table to decode the actual compressed data. |
493 | /// |
494 | /// This function populates the rest of the table from the series of weights. |
495 | fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> { |
496 | use HuffmanTableError as err; |
497 | |
498 | self.bits.clear(); |
499 | self.bits.resize(self.weights.len() + 1, 0); |
500 | |
501 | let mut weight_sum: u32 = 0; |
502 | for w in &self.weights { |
503 | if *w > MAX_MAX_NUM_BITS { |
504 | return Err(err::WeightBiggerThanMaxNumBits { got: *w }); |
505 | } |
506 | weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 }; |
507 | } |
508 | |
509 | if weight_sum == 0 { |
510 | return Err(err::MissingWeights); |
511 | } |
512 | |
513 | let max_bits = highest_bit_set(weight_sum) as u8; |
514 | let left_over = (1 << max_bits) - weight_sum; |
515 | |
516 | //left_over must be power of two |
517 | if !left_over.is_power_of_two() { |
518 | return Err(err::LeftoverIsNotAPowerOf2 { got: left_over }); |
519 | } |
520 | |
521 | let last_weight = highest_bit_set(left_over) as u8; |
522 | |
523 | for symbol in 0..self.weights.len() { |
524 | let bits = if self.weights[symbol] > 0 { |
525 | max_bits + 1 - self.weights[symbol] |
526 | } else { |
527 | 0 |
528 | }; |
529 | self.bits[symbol] = bits; |
530 | } |
531 | |
532 | self.bits[self.weights.len()] = max_bits + 1 - last_weight; |
533 | self.max_num_bits = max_bits; |
534 | |
535 | if max_bits > MAX_MAX_NUM_BITS { |
536 | return Err(err::MaxBitsTooHigh { got: max_bits }); |
537 | } |
538 | |
539 | self.bit_ranks.clear(); |
540 | self.bit_ranks.resize((max_bits + 1) as usize, 0); |
541 | for num_bits in &self.bits { |
542 | self.bit_ranks[(*num_bits) as usize] += 1; |
543 | } |
544 | |
545 | //fill with dummy symbols |
546 | self.decode.resize( |
547 | 1 << self.max_num_bits, |
548 | Entry { |
549 | symbol: 0, |
550 | num_bits: 0, |
551 | }, |
552 | ); |
553 | |
554 | //starting codes for each rank |
555 | self.rank_indexes.clear(); |
556 | self.rank_indexes.resize((max_bits + 1) as usize, 0); |
557 | |
558 | self.rank_indexes[max_bits as usize] = 0; |
559 | for bits in (1..self.rank_indexes.len() as u8).rev() { |
560 | self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize] |
561 | + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits)); |
562 | } |
563 | |
564 | assert!( |
565 | self.rank_indexes[0] == self.decode.len(), |
566 | "rank_idx[0]: {} should be: {}" , |
567 | self.rank_indexes[0], |
568 | self.decode.len() |
569 | ); |
570 | |
571 | for symbol in 0..self.bits.len() { |
572 | let bits_for_symbol = self.bits[symbol]; |
573 | if bits_for_symbol != 0 { |
574 | // allocate code for the symbol and set in the table |
575 | // a code ignores all max_bits - bits[symbol] bits, so it gets |
576 | // a range that spans all of those in the decoding table |
577 | let base_idx = self.rank_indexes[bits_for_symbol as usize]; |
578 | let len = 1 << (max_bits - bits_for_symbol); |
579 | self.rank_indexes[bits_for_symbol as usize] += len; |
580 | for idx in 0..len { |
581 | self.decode[base_idx + idx].symbol = symbol as u8; |
582 | self.decode[base_idx + idx].num_bits = bits_for_symbol; |
583 | } |
584 | } |
585 | } |
586 | |
587 | Ok(()) |
588 | } |
589 | } |
590 | |