1use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
2use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError};
3use alloc::vec::Vec;
4
/// Huffman decoding table built from a zstd Huffman tree description.
///
/// Besides the final lookup table (`decode`), it keeps the intermediate buffers
/// used while building it so they can be reused across rebuilds.
pub struct HuffmanTable {
    /// Lookup table indexed by a decoder's state; after a successful build it has
    /// `1 << max_num_bits` entries.
    decode: Vec<Entry>,

    /// Weights read from the tree description (the last symbol's weight is not
    /// stored here; it is inferred when the table is built).
    weights: Vec<u8>,
    /// Largest code length in bits; also the number of bits a decoder reads to
    /// initialize its state.
    pub max_num_bits: u8,
    /// Code length per symbol (0 = symbol absent); one entry longer than `weights`
    /// because it includes the inferred last symbol.
    bits: Vec<u8>,
    /// Scratch: number of symbols per code length, indexed by code length.
    bit_ranks: Vec<u32>,
    /// Scratch: next free index into `decode` for each code length.
    rank_indexes: Vec<usize>,

    /// FSE table used to decode FSE-compressed weights.
    fse_table: FSETable,
}
16
17#[derive(Debug, derive_more::Display, derive_more::From)]
18#[cfg_attr(feature = "std", derive(derive_more::Error))]
19#[non_exhaustive]
20pub enum HuffmanTableError {
21 #[display(fmt = "{_0:?}")]
22 #[from]
23 GetBitsError(GetBitsError),
24 #[display(fmt = "{_0:?}")]
25 #[from]
26 FSEDecoderError(FSEDecoderError),
27 #[display(fmt = "{_0:?}")]
28 #[from]
29 FSETableError(FSETableError),
30 #[display(fmt = "Source needs to have at least one byte")]
31 SourceIsEmpty,
32 #[display(
33 fmt = "Header says there should be {expected_bytes} bytes for the weights but there are only {got_bytes} bytes in the stream"
34 )]
35 NotEnoughBytesForWeights {
36 got_bytes: usize,
37 expected_bytes: u8,
38 },
39 #[display(
40 fmt = "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption"
41 )]
42 ExtraPadding { skipped_bits: i32 },
43 #[display(
44 fmt = "More than 255 weights decoded (got {got} weights). Stream is probably corrupted"
45 )]
46 TooManyWeights { got: usize },
47 #[display(fmt = "Can't build huffman table without any weights")]
48 MissingWeights,
49 #[display(fmt = "Leftover must be power of two but is: {got}")]
50 LeftoverIsNotAPowerOf2 { got: u32 },
51 #[display(
52 fmt = "Not enough bytes in stream to decompress weights. Is: {have}, Should be: {need}"
53 )]
54 NotEnoughBytesToDecompressWeights { have: usize, need: usize },
55 #[display(
56 fmt = "FSE table used more bytes: {used} than were meant to be used for the whole stream of huffman weights ({available_bytes})"
57 )]
58 FSETableUsedTooManyBytes { used: usize, available_bytes: u8 },
59 #[display(fmt = "Source needs to have at least {need} bytes, got: {got}")]
60 NotEnoughBytesInSource { got: usize, need: usize },
61 #[display(fmt = "Cant have weight: {got} bigger than max_num_bits: {MAX_MAX_NUM_BITS}")]
62 WeightBiggerThanMaxNumBits { got: u8 },
63 #[display(
64 fmt = "max_bits derived from weights is: {got} should be lower than: {MAX_MAX_NUM_BITS}"
65 )]
66 MaxBitsTooHigh { got: u8 },
67}
68
/// Decodes Huffman symbols by using a bit window (`state`) as an index into a
/// [`HuffmanTable`]'s decoding entries.
pub struct HuffmanDecoder<'table> {
    /// Table whose decoding entries are looked up with `state`.
    table: &'table HuffmanTable,
    /// Current index into the table's decoding entries, made of the most recently
    /// read bits.
    pub state: u64,
}
73
/// Errors returned by [`HuffmanDecoder`]'s state-loading methods.
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum HuffmanDecoderError {
    /// Reading bits from the bitstream failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    GetBitsError(GetBitsError),
}
82
/// One slot of a [`HuffmanTable`]'s decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// Symbol produced when a decoder's state indexes this slot.
    symbol: u8,
    /// Code length of `symbol`: the number of bits `HuffmanDecoder::next_state`
    /// shifts out of (and reads into) the state after this symbol.
    num_bits: u8,
}
88
/// Upper bound enforced on every individual weight and on the derived maximum code
/// length when building a [`HuffmanTable`] (the zstd format limits Huffman codes to
/// 11 bits).
const MAX_MAX_NUM_BITS: u8 = 11;
90
/// Returns the 1-based position of the most significant set bit of `x`, which is
/// the number of bits needed to represent `x` (e.g. `1 -> 1`, `255 -> 8`, `256 -> 9`).
///
/// # Panics
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    // Whatever is not a leading zero is the significant part of the value.
    let leading = x.leading_zeros();
    u32::BITS - leading
}
95
96impl<'t> HuffmanDecoder<'t> {
97 pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
98 HuffmanDecoder { table, state: 0 }
99 }
100
101 pub fn reset(mut self, new_table: Option<&'t HuffmanTable>) {
102 self.state = 0;
103 if let Some(next_table) = new_table {
104 self.table = next_table;
105 }
106 }
107
108 pub fn decode_symbol(&mut self) -> u8 {
109 self.table.decode[self.state as usize].symbol
110 }
111
112 pub fn init_state(
113 &mut self,
114 br: &mut BitReaderReversed<'_>,
115 ) -> Result<u8, HuffmanDecoderError> {
116 let num_bits = self.table.max_num_bits;
117 let new_bits = br.get_bits(num_bits)?;
118 self.state = new_bits;
119 Ok(num_bits)
120 }
121
122 pub fn next_state(
123 &mut self,
124 br: &mut BitReaderReversed<'_>,
125 ) -> Result<u8, HuffmanDecoderError> {
126 let num_bits = self.table.decode[self.state as usize].num_bits;
127 let new_bits = br.get_bits(num_bits)?;
128 self.state <<= num_bits;
129 self.state &= self.table.decode.len() as u64 - 1;
130 self.state |= new_bits;
131 Ok(num_bits)
132 }
133}
134
135impl Default for HuffmanTable {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl HuffmanTable {
142 pub fn new() -> HuffmanTable {
143 HuffmanTable {
144 decode: Vec::new(),
145
146 weights: Vec::with_capacity(256),
147 max_num_bits: 0,
148 bits: Vec::with_capacity(256),
149 bit_ranks: Vec::with_capacity(11),
150 rank_indexes: Vec::with_capacity(11),
151 fse_table: FSETable::new(),
152 }
153 }
154
155 pub fn reinit_from(&mut self, other: &Self) {
156 self.reset();
157 self.decode.extend_from_slice(&other.decode);
158 self.weights.extend_from_slice(&other.weights);
159 self.max_num_bits = other.max_num_bits;
160 self.bits.extend_from_slice(&other.bits);
161 self.rank_indexes.extend_from_slice(&other.rank_indexes);
162 self.fse_table.reinit_from(&other.fse_table);
163 }
164
165 pub fn reset(&mut self) {
166 self.decode.clear();
167 self.weights.clear();
168 self.max_num_bits = 0;
169 self.bits.clear();
170 self.bit_ranks.clear();
171 self.rank_indexes.clear();
172 self.fse_table.reset();
173 }
174
175 pub fn build_decoder(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
176 self.decode.clear();
177
178 let bytes_used = self.read_weights(source)?;
179 self.build_table_from_weights()?;
180 Ok(bytes_used)
181 }
182
183 fn read_weights(&mut self, source: &[u8]) -> Result<u32, HuffmanTableError> {
184 use HuffmanTableError as err;
185
186 if source.is_empty() {
187 return Err(err::SourceIsEmpty);
188 }
189 let header = source[0];
190 let mut bits_read = 8;
191
192 match header {
193 0..=127 => {
194 let fse_stream = &source[1..];
195 if header as usize > fse_stream.len() {
196 return Err(err::NotEnoughBytesForWeights {
197 got_bytes: fse_stream.len(),
198 expected_bytes: header,
199 });
200 }
201 //fse decompress weights
202 let bytes_used_by_fse_header = self
203 .fse_table
204 .build_decoder(fse_stream, /*TODO find actual max*/ 100)?;
205
206 if bytes_used_by_fse_header > header as usize {
207 return Err(err::FSETableUsedTooManyBytes {
208 used: bytes_used_by_fse_header,
209 available_bytes: header,
210 });
211 }
212
213 vprintln!(
214 "Building fse table for huffman weights used: {}",
215 bytes_used_by_fse_header
216 );
217 let mut dec1 = FSEDecoder::new(&self.fse_table);
218 let mut dec2 = FSEDecoder::new(&self.fse_table);
219
220 let compressed_start = bytes_used_by_fse_header;
221 let compressed_length = header as usize - bytes_used_by_fse_header;
222
223 let compressed_weights = &fse_stream[compressed_start..];
224 if compressed_weights.len() < compressed_length {
225 return Err(err::NotEnoughBytesToDecompressWeights {
226 have: compressed_weights.len(),
227 need: compressed_length,
228 });
229 }
230 let compressed_weights = &compressed_weights[..compressed_length];
231 let mut br = BitReaderReversed::new(compressed_weights);
232
233 bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
234
235 //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
236 let mut skipped_bits = 0;
237 loop {
238 let val = br.get_bits(1)?;
239 skipped_bits += 1;
240 if val == 1 || skipped_bits > 8 {
241 break;
242 }
243 }
244 if skipped_bits > 8 {
245 //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
246 return Err(err::ExtraPadding { skipped_bits });
247 }
248
249 dec1.init_state(&mut br)?;
250 dec2.init_state(&mut br)?;
251
252 self.weights.clear();
253
254 loop {
255 let w = dec1.decode_symbol();
256 self.weights.push(w);
257 dec1.update_state(&mut br)?;
258
259 if br.bits_remaining() <= -1 {
260 //collect final states
261 self.weights.push(dec2.decode_symbol());
262 break;
263 }
264
265 let w = dec2.decode_symbol();
266 self.weights.push(w);
267 dec2.update_state(&mut br)?;
268
269 if br.bits_remaining() <= -1 {
270 //collect final states
271 self.weights.push(dec1.decode_symbol());
272 break;
273 }
274 //maximum number of weights is 255 because we use u8 symbols and the last weight is inferred from the sum of all others
275 if self.weights.len() > 255 {
276 return Err(err::TooManyWeights {
277 got: self.weights.len(),
278 });
279 }
280 }
281 }
282 _ => {
283 // weights are directly encoded
284 let weights_raw = &source[1..];
285 let num_weights = header - 127;
286 self.weights.resize(num_weights as usize, 0);
287
288 let bytes_needed = if num_weights % 2 == 0 {
289 num_weights as usize / 2
290 } else {
291 (num_weights as usize / 2) + 1
292 };
293
294 if weights_raw.len() < bytes_needed {
295 return Err(err::NotEnoughBytesInSource {
296 got: weights_raw.len(),
297 need: bytes_needed,
298 });
299 }
300
301 for idx in 0..num_weights {
302 if idx % 2 == 0 {
303 self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
304 } else {
305 self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
306 }
307 bits_read += 4;
308 }
309 }
310 }
311
312 let bytes_read = if bits_read % 8 == 0 {
313 bits_read / 8
314 } else {
315 (bits_read / 8) + 1
316 };
317 Ok(bytes_read as u32)
318 }
319
320 fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> {
321 use HuffmanTableError as err;
322
323 self.bits.clear();
324 self.bits.resize(self.weights.len() + 1, 0);
325
326 let mut weight_sum: u32 = 0;
327 for w in &self.weights {
328 if *w > MAX_MAX_NUM_BITS {
329 return Err(err::WeightBiggerThanMaxNumBits { got: *w });
330 }
331 weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
332 }
333
334 if weight_sum == 0 {
335 return Err(err::MissingWeights);
336 }
337
338 let max_bits = highest_bit_set(weight_sum) as u8;
339 let left_over = (1 << max_bits) - weight_sum;
340
341 //left_over must be power of two
342 if !left_over.is_power_of_two() {
343 return Err(err::LeftoverIsNotAPowerOf2 { got: left_over });
344 }
345
346 let last_weight = highest_bit_set(left_over) as u8;
347
348 for symbol in 0..self.weights.len() {
349 let bits = if self.weights[symbol] > 0 {
350 max_bits + 1 - self.weights[symbol]
351 } else {
352 0
353 };
354 self.bits[symbol] = bits;
355 }
356
357 self.bits[self.weights.len()] = max_bits + 1 - last_weight;
358 self.max_num_bits = max_bits;
359
360 if max_bits > MAX_MAX_NUM_BITS {
361 return Err(err::MaxBitsTooHigh { got: max_bits });
362 }
363
364 self.bit_ranks.clear();
365 self.bit_ranks.resize((max_bits + 1) as usize, 0);
366 for num_bits in &self.bits {
367 self.bit_ranks[(*num_bits) as usize] += 1;
368 }
369
370 //fill with dummy symbols
371 self.decode.resize(
372 1 << self.max_num_bits,
373 Entry {
374 symbol: 0,
375 num_bits: 0,
376 },
377 );
378
379 //starting codes for each rank
380 self.rank_indexes.clear();
381 self.rank_indexes.resize((max_bits + 1) as usize, 0);
382
383 self.rank_indexes[max_bits as usize] = 0;
384 for bits in (1..self.rank_indexes.len() as u8).rev() {
385 self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
386 + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
387 }
388
389 assert!(
390 self.rank_indexes[0] == self.decode.len(),
391 "rank_idx[0]: {} should be: {}",
392 self.rank_indexes[0],
393 self.decode.len()
394 );
395
396 for symbol in 0..self.bits.len() {
397 let bits_for_symbol = self.bits[symbol];
398 if bits_for_symbol != 0 {
399 // allocate code for the symbol and set in the table
400 // a code ignores all max_bits - bits[symbol] bits, so it gets
401 // a range that spans all of those in the decoding table
402 let base_idx = self.rank_indexes[bits_for_symbol as usize];
403 let len = 1 << (max_bits - bits_for_symbol);
404 self.rank_indexes[bits_for_symbol as usize] += len;
405 for idx in 0..len {
406 self.decode[base_idx + idx].symbol = symbol as u8;
407 self.decode[base_idx + idx].num_bits = bits_for_symbol;
408 }
409 }
410 }
411
412 Ok(())
413 }
414}
415