1 | use crate::decoding::bit_reader::BitReader; |
2 | use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; |
3 | use alloc::vec::Vec; |
4 | |
5 | /// FSE decoding involves a decoding table that describes the probabilities of |
6 | /// all literals from 0 to the highest present one |
7 | /// |
8 | /// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse-table-description> |
9 | pub struct FSETable { |
10 | /// The maximum symbol in the table (inclusive). Limits the probabilities length to max_symbol + 1. |
11 | max_symbol: u8, |
12 | /// The actual table containing the decoded symbol and the compression data |
13 | /// connected to that symbol. |
14 | pub decode: Vec<Entry>, //used to decode symbols, and calculate the next state |
15 | /// The size of the table is stored in logarithm base 2 format, |
16 | /// with the **size of the table** being equal to `(1 << accuracy_log)`. |
17 | /// This value is used so that the decoder knows how many bits to read from the bitstream. |
18 | pub accuracy_log: u8, |
19 | /// In this context, probability refers to the likelihood that a symbol occurs in the given data. |
20 | /// Given this info, the encoder can assign shorter codes to symbols that appear more often, |
21 | /// and longer codes that appear less often, then the decoder can use the probability |
22 | /// to determine what code was assigned to what symbol. |
23 | /// |
24 | /// The probability of a single symbol is a value representing the proportion of times the symbol |
25 | /// would fall within the data. |
26 | /// |
27 | /// If a symbol probability is set to `-1`, it means that the probability of a symbol |
28 | /// occurring in the data is less than one. |
29 | pub symbol_probabilities: Vec<i32>, //used while building the decode Vector |
30 | /// The number of times each symbol occurs (The first entry being 0x0, the second being 0x1) and so on |
31 | /// up until the highest possible symbol (255). |
32 | symbol_counter: Vec<u32>, |
33 | } |
34 | |
35 | #[derive (Debug)] |
36 | #[non_exhaustive ] |
37 | pub enum FSETableError { |
38 | AccLogIsZero, |
39 | AccLogTooBig { |
40 | got: u8, |
41 | max: u8, |
42 | }, |
43 | GetBitsError(GetBitsError), |
44 | ProbabilityCounterMismatch { |
45 | got: u32, |
46 | expected_sum: u32, |
47 | symbol_probabilities: Vec<i32>, |
48 | }, |
49 | TooManySymbols { |
50 | got: usize, |
51 | }, |
52 | } |
53 | |
/// Exposes the wrapped bit-reader error as the source; other variants have no underlying cause.
#[cfg(feature = "std")]
impl std::error::Error for FSETableError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FSETableError::GetBitsError(source) => Some(source),
            // Listed explicitly so that adding a variant forces a decision here.
            FSETableError::AccLogIsZero
            | FSETableError::AccLogTooBig { .. }
            | FSETableError::ProbabilityCounterMismatch { .. }
            | FSETableError::TooManySymbols { .. } => None,
        }
    }
}
63 | |
64 | impl core::fmt::Display for FSETableError { |
65 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
66 | match self { |
67 | FSETableError::AccLogIsZero => write!(f, "Acclog must be at least 1" ), |
68 | FSETableError::AccLogTooBig { got, max } => { |
69 | write!( |
70 | f, |
71 | "Found FSE acc_log: {0} bigger than allowed maximum in this case: {1}" , |
72 | got, max |
73 | ) |
74 | } |
75 | FSETableError::GetBitsError(e) => write!(f, " {:?}" , e), |
76 | FSETableError::ProbabilityCounterMismatch { |
77 | got, |
78 | expected_sum, |
79 | symbol_probabilities, |
80 | } => { |
81 | write!(f, |
82 | "The counter ( {}) exceeded the expected sum: {}. This means an error or corrupted data \n {:?}" , |
83 | got, |
84 | expected_sum, |
85 | symbol_probabilities, |
86 | ) |
87 | } |
88 | FSETableError::TooManySymbols { got } => { |
89 | write!( |
90 | f, |
91 | "There are too many symbols in this distribution: {}. Max: 256" , |
92 | got, |
93 | ) |
94 | } |
95 | } |
96 | } |
97 | } |
98 | |
99 | impl From<GetBitsError> for FSETableError { |
100 | fn from(val: GetBitsError) -> Self { |
101 | Self::GetBitsError(val) |
102 | } |
103 | } |
104 | |
105 | pub struct FSEDecoder<'table> { |
106 | /// An FSE state value represents an index in the FSE table. |
107 | pub state: Entry, |
108 | /// A reference to the table used for decoding. |
109 | table: &'table FSETable, |
110 | } |
111 | |
112 | #[derive (Debug)] |
113 | #[non_exhaustive ] |
114 | pub enum FSEDecoderError { |
115 | GetBitsError(GetBitsError), |
116 | TableIsUninitialized, |
117 | } |
118 | |
/// Exposes the wrapped bit-reader error as the source; other variants have no underlying cause.
#[cfg(feature = "std")]
impl std::error::Error for FSEDecoderError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FSEDecoderError::GetBitsError(source) => Some(source),
            FSEDecoderError::TableIsUninitialized => None,
        }
    }
}
128 | |
129 | impl core::fmt::Display for FSEDecoderError { |
130 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
131 | match self { |
132 | FSEDecoderError::GetBitsError(e: &GetBitsError) => write!(f, " {:?}" , e), |
133 | FSEDecoderError::TableIsUninitialized => { |
134 | write!(f, "Tried to use an uninitialized table!" ) |
135 | } |
136 | } |
137 | } |
138 | } |
139 | |
140 | impl From<GetBitsError> for FSEDecoderError { |
141 | fn from(val: GetBitsError) -> Self { |
142 | Self::GetBitsError(val) |
143 | } |
144 | } |
145 | |
/// A single entry (one state) in an FSE decoding table.
#[derive(Copy, Clone, Debug)]
pub struct Entry {
    /// Offset added to the `num_bits`-wide value read from the stream to form the index of
    /// the next state.
    pub base_line: u32,
    /// How many bits to read from the stream when leaving this state.
    pub num_bits: u8,
    /// The symbol emitted while the decoder is in this state.
    pub symbol: u8,
}
157 | |
/// Offset added to the 4-bit field at the start of an FSE table description to obtain the
/// `Accuracy_Log`, so the encoded values `0..=15` map to accuracy logs `5..=20`.
const ACC_LOG_OFFSET: u8 = 5;
161 | |
/// Returns the 1-based position of the most significant set bit of `x`, i.e. the number of
/// bits needed to represent `x` (`1 -> 1`, `5 -> 3`, `256 -> 9`).
///
/// # Panics
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
166 | |
167 | impl<'t> FSEDecoder<'t> { |
168 | /// Initialize a new Finite State Entropy decoder. |
169 | pub fn new(table: &'t FSETable) -> FSEDecoder<'t> { |
170 | FSEDecoder { |
171 | state: table.decode.first().copied().unwrap_or(Entry { |
172 | base_line: 0, |
173 | num_bits: 0, |
174 | symbol: 0, |
175 | }), |
176 | table, |
177 | } |
178 | } |
179 | |
180 | /// Returns the byte associated with the symbol the internal cursor is pointing at. |
181 | pub fn decode_symbol(&self) -> u8 { |
182 | self.state.symbol |
183 | } |
184 | |
185 | /// Initialize internal state and prepare for decoding. After this, `decode_symbol` can be called |
186 | /// to read the first symbol and `update_state` can be called to prepare to read the next symbol. |
187 | pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> { |
188 | if self.table.accuracy_log == 0 { |
189 | return Err(FSEDecoderError::TableIsUninitialized); |
190 | } |
191 | self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize]; |
192 | |
193 | Ok(()) |
194 | } |
195 | |
196 | /// Advance the internal state to decode the next symbol in the bitstream. |
197 | pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) { |
198 | let num_bits = self.state.num_bits; |
199 | let add = bits.get_bits(num_bits); |
200 | let base_line = self.state.base_line; |
201 | let new_state = base_line + add as u32; |
202 | self.state = self.table.decode[new_state as usize]; |
203 | |
204 | //println!("Update: {}, {} -> {}", base_line, add, self.state); |
205 | } |
206 | } |
207 | |
208 | impl FSETable { |
209 | /// Initialize a new empty Finite State Entropy decoding table. |
210 | pub fn new(max_symbol: u8) -> FSETable { |
211 | FSETable { |
212 | max_symbol, |
213 | symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8 |
214 | symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8 |
215 | decode: Vec::new(), //depending on acc_log. |
216 | accuracy_log: 0, |
217 | } |
218 | } |
219 | |
220 | /// Reset `self` and update `self`'s state to mirror the provided table. |
221 | pub fn reinit_from(&mut self, other: &Self) { |
222 | self.reset(); |
223 | self.symbol_counter.extend_from_slice(&other.symbol_counter); |
224 | self.symbol_probabilities |
225 | .extend_from_slice(&other.symbol_probabilities); |
226 | self.decode.extend_from_slice(&other.decode); |
227 | self.accuracy_log = other.accuracy_log; |
228 | } |
229 | |
230 | /// Empty the table and clear all internal state. |
231 | pub fn reset(&mut self) { |
232 | self.symbol_counter.clear(); |
233 | self.symbol_probabilities.clear(); |
234 | self.decode.clear(); |
235 | self.accuracy_log = 0; |
236 | } |
237 | |
238 | /// returns how many BYTEs (not bits) were read while building the decoder |
239 | pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { |
240 | self.accuracy_log = 0; |
241 | |
242 | let bytes_read = self.read_probabilities(source, max_log)?; |
243 | self.build_decoding_table()?; |
244 | |
245 | Ok(bytes_read) |
246 | } |
247 | |
248 | /// Given the provided accuracy log, build a decoding table from that log. |
249 | pub fn build_from_probabilities( |
250 | &mut self, |
251 | acc_log: u8, |
252 | probs: &[i32], |
253 | ) -> Result<(), FSETableError> { |
254 | if acc_log == 0 { |
255 | return Err(FSETableError::AccLogIsZero); |
256 | } |
257 | self.symbol_probabilities = probs.to_vec(); |
258 | self.accuracy_log = acc_log; |
259 | self.build_decoding_table() |
260 | } |
261 | |
262 | /// Build the actual decoding table after probabilities have been read into the table. |
263 | /// After this function is called, the decoding process can begin. |
264 | fn build_decoding_table(&mut self) -> Result<(), FSETableError> { |
265 | if self.symbol_probabilities.len() > self.max_symbol as usize + 1 { |
266 | return Err(FSETableError::TooManySymbols { |
267 | got: self.symbol_probabilities.len(), |
268 | }); |
269 | } |
270 | |
271 | self.decode.clear(); |
272 | |
273 | let table_size = 1 << self.accuracy_log; |
274 | if self.decode.len() < table_size { |
275 | self.decode.reserve(table_size - self.decode.len()); |
276 | } |
277 | //fill with dummy entries |
278 | self.decode.resize( |
279 | table_size, |
280 | Entry { |
281 | base_line: 0, |
282 | num_bits: 0, |
283 | symbol: 0, |
284 | }, |
285 | ); |
286 | |
287 | let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol |
288 | |
289 | //first scan for all -1 probabilities and place them at the top of the table |
290 | for symbol in 0..self.symbol_probabilities.len() { |
291 | if self.symbol_probabilities[symbol] == -1 { |
292 | negative_idx -= 1; |
293 | let entry = &mut self.decode[negative_idx]; |
294 | entry.symbol = symbol as u8; |
295 | entry.base_line = 0; |
296 | entry.num_bits = self.accuracy_log; |
297 | } |
298 | } |
299 | |
300 | //then place in a semi-random order all of the other symbols |
301 | let mut position = 0; |
302 | for idx in 0..self.symbol_probabilities.len() { |
303 | let symbol = idx as u8; |
304 | if self.symbol_probabilities[idx] <= 0 { |
305 | continue; |
306 | } |
307 | |
308 | //for each probability point the symbol gets on slot |
309 | let prob = self.symbol_probabilities[idx]; |
310 | for _ in 0..prob { |
311 | let entry = &mut self.decode[position]; |
312 | entry.symbol = symbol; |
313 | |
314 | position = next_position(position, table_size); |
315 | while position >= negative_idx { |
316 | position = next_position(position, table_size); |
317 | //everything above negative_idx is already taken |
318 | } |
319 | } |
320 | } |
321 | |
322 | // baselines and num_bits can only be calculated when all symbols have been spread |
323 | self.symbol_counter.clear(); |
324 | self.symbol_counter |
325 | .resize(self.symbol_probabilities.len(), 0); |
326 | for idx in 0..negative_idx { |
327 | let entry = &mut self.decode[idx]; |
328 | let symbol = entry.symbol; |
329 | let prob = self.symbol_probabilities[symbol as usize]; |
330 | |
331 | let symbol_count = self.symbol_counter[symbol as usize]; |
332 | let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count); |
333 | |
334 | //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb); |
335 | |
336 | assert!(nb <= self.accuracy_log); |
337 | self.symbol_counter[symbol as usize] += 1; |
338 | |
339 | entry.base_line = bl; |
340 | entry.num_bits = nb; |
341 | } |
342 | Ok(()) |
343 | } |
344 | |
345 | /// Read the accuracy log and the probability table from the source and return the number of bytes |
346 | /// read. If the size of the table is larger than the provided `max_log`, return an error. |
347 | fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { |
348 | self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here |
349 | |
350 | let mut br = BitReader::new(source); |
351 | self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8); |
352 | if self.accuracy_log > max_log { |
353 | return Err(FSETableError::AccLogTooBig { |
354 | got: self.accuracy_log, |
355 | max: max_log, |
356 | }); |
357 | } |
358 | if self.accuracy_log == 0 { |
359 | return Err(FSETableError::AccLogIsZero); |
360 | } |
361 | |
362 | let probability_sum = 1 << self.accuracy_log; |
363 | let mut probability_counter = 0; |
364 | |
365 | while probability_counter < probability_sum { |
366 | let max_remaining_value = probability_sum - probability_counter + 1; |
367 | let bits_to_read = highest_bit_set(max_remaining_value); |
368 | |
369 | let unchecked_value = br.get_bits(bits_to_read as usize)? as u32; |
370 | |
371 | let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value); |
372 | let mask = (1 << (bits_to_read - 1)) - 1; |
373 | let small_value = unchecked_value & mask; |
374 | |
375 | let value = if small_value < low_threshold { |
376 | br.return_bits(1); |
377 | small_value |
378 | } else if unchecked_value > mask { |
379 | unchecked_value - low_threshold |
380 | } else { |
381 | unchecked_value |
382 | }; |
383 | //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value); |
384 | |
385 | let prob = (value as i32) - 1; |
386 | |
387 | self.symbol_probabilities.push(prob); |
388 | if prob != 0 { |
389 | if prob > 0 { |
390 | probability_counter += prob as u32; |
391 | } else { |
392 | // probability -1 counts as 1 |
393 | assert!(prob == -1); |
394 | probability_counter += 1; |
395 | } |
396 | } else { |
397 | //fast skip further zero probabilities |
398 | loop { |
399 | let skip_amount = br.get_bits(2)? as usize; |
400 | |
401 | self.symbol_probabilities |
402 | .resize(self.symbol_probabilities.len() + skip_amount, 0); |
403 | if skip_amount != 3 { |
404 | break; |
405 | } |
406 | } |
407 | } |
408 | } |
409 | |
410 | if probability_counter != probability_sum { |
411 | return Err(FSETableError::ProbabilityCounterMismatch { |
412 | got: probability_counter, |
413 | expected_sum: probability_sum, |
414 | symbol_probabilities: self.symbol_probabilities.clone(), |
415 | }); |
416 | } |
417 | if self.symbol_probabilities.len() > self.max_symbol as usize + 1 { |
418 | return Err(FSETableError::TooManySymbols { |
419 | got: self.symbol_probabilities.len(), |
420 | }); |
421 | } |
422 | |
423 | let bytes_read = if br.bits_read() % 8 == 0 { |
424 | br.bits_read() / 8 |
425 | } else { |
426 | (br.bits_read() / 8) + 1 |
427 | }; |
428 | |
429 | Ok(bytes_read) |
430 | } |
431 | } |
432 | |
433 | //utility functions for building the decoding table from probabilities |
/// Calculate the position of the next entry of the table given the current
/// position and size of the table.
///
/// Advances by `table_size/2 + table_size/8 + 3` slots and wraps by masking with
/// `table_size - 1`. The mask is a modulo only when `table_size` is a power of two.
fn next_position(p: usize, table_size: usize) -> usize {
    let stride = (table_size >> 1) + (table_size >> 3) + 3;
    let advanced = p + stride;
    advanced & (table_size - 1)
}
441 | |
442 | fn calc_baseline_and_numbits( |
443 | num_states_total: u32, |
444 | num_states_symbol: u32, |
445 | state_number: u32, |
446 | ) -> (u32, u8) { |
447 | let num_state_slices: u32 = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol { |
448 | num_states_symbol |
449 | } else { |
450 | 1 << (highest_bit_set(num_states_symbol)) |
451 | }; //always power of two |
452 | |
453 | let num_double_width_state_slices: u32 = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed |
454 | let num_single_width_state_slices: u32 = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states |
455 | let slice_width: u32 = num_states_total / num_state_slices; //size of a single width slice of states |
456 | let num_bits: u32 = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice |
457 | |
458 | if state_number < num_double_width_state_slices { |
459 | let baseline: u32 = num_single_width_state_slices * slice_width + state_number * slice_width * 2; |
460 | (baseline, num_bits as u8 + 1) |
461 | } else { |
462 | let index_shifted: u32 = state_number - num_double_width_state_slices; |
463 | ((index_shifted * slice_width), num_bits as u8) |
464 | } |
465 | } |
466 | |