1use crate::decoding::bit_reader::BitReader;
2use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
3use alloc::vec::Vec;
4
5/// FSE decoding involves a decoding table that describes the probabilities of
6/// all literals from 0 to the highest present one
7///
8/// <https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse-table-description>
9pub struct FSETable {
10 /// The maximum symbol in the table (inclusive). Limits the probabilities length to max_symbol + 1.
11 max_symbol: u8,
12 /// The actual table containing the decoded symbol and the compression data
13 /// connected to that symbol.
14 pub decode: Vec<Entry>, //used to decode symbols, and calculate the next state
15 /// The size of the table is stored in logarithm base 2 format,
16 /// with the **size of the table** being equal to `(1 << accuracy_log)`.
17 /// This value is used so that the decoder knows how many bits to read from the bitstream.
18 pub accuracy_log: u8,
19 /// In this context, probability refers to the likelihood that a symbol occurs in the given data.
20 /// Given this info, the encoder can assign shorter codes to symbols that appear more often,
21 /// and longer codes that appear less often, then the decoder can use the probability
22 /// to determine what code was assigned to what symbol.
23 ///
24 /// The probability of a single symbol is a value representing the proportion of times the symbol
25 /// would fall within the data.
26 ///
27 /// If a symbol probability is set to `-1`, it means that the probability of a symbol
28 /// occurring in the data is less than one.
29 pub symbol_probabilities: Vec<i32>, //used while building the decode Vector
30 /// The number of times each symbol occurs (The first entry being 0x0, the second being 0x1) and so on
31 /// up until the highest possible symbol (255).
32 symbol_counter: Vec<u32>,
33}
34
35#[derive(Debug)]
36#[non_exhaustive]
37pub enum FSETableError {
38 AccLogIsZero,
39 AccLogTooBig {
40 got: u8,
41 max: u8,
42 },
43 GetBitsError(GetBitsError),
44 ProbabilityCounterMismatch {
45 got: u32,
46 expected_sum: u32,
47 symbol_probabilities: Vec<i32>,
48 },
49 TooManySymbols {
50 got: usize,
51 },
52}
53
#[cfg(feature = "std")]
impl std::error::Error for FSETableError {
    /// Returns the underlying bit-reader error for the `GetBitsError` variant;
    /// every other variant has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FSETableError::GetBitsError(source) => Some(source),
            // Listed explicitly so that adding a variant forces this match to be revisited.
            FSETableError::AccLogIsZero
            | FSETableError::AccLogTooBig { .. }
            | FSETableError::ProbabilityCounterMismatch { .. }
            | FSETableError::TooManySymbols { .. } => None,
        }
    }
}
63
64impl core::fmt::Display for FSETableError {
65 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66 match self {
67 FSETableError::AccLogIsZero => write!(f, "Acclog must be at least 1"),
68 FSETableError::AccLogTooBig { got, max } => {
69 write!(
70 f,
71 "Found FSE acc_log: {0} bigger than allowed maximum in this case: {1}",
72 got, max
73 )
74 }
75 FSETableError::GetBitsError(e) => write!(f, "{:?}", e),
76 FSETableError::ProbabilityCounterMismatch {
77 got,
78 expected_sum,
79 symbol_probabilities,
80 } => {
81 write!(f,
82 "The counter ({}) exceeded the expected sum: {}. This means an error or corrupted data \n {:?}",
83 got,
84 expected_sum,
85 symbol_probabilities,
86 )
87 }
88 FSETableError::TooManySymbols { got } => {
89 write!(
90 f,
91 "There are too many symbols in this distribution: {}. Max: 256",
92 got,
93 )
94 }
95 }
96 }
97}
98
99impl From<GetBitsError> for FSETableError {
100 fn from(val: GetBitsError) -> Self {
101 Self::GetBitsError(val)
102 }
103}
104
105pub struct FSEDecoder<'table> {
106 /// An FSE state value represents an index in the FSE table.
107 pub state: Entry,
108 /// A reference to the table used for decoding.
109 table: &'table FSETable,
110}
111
112#[derive(Debug)]
113#[non_exhaustive]
114pub enum FSEDecoderError {
115 GetBitsError(GetBitsError),
116 TableIsUninitialized,
117}
118
#[cfg(feature = "std")]
impl std::error::Error for FSEDecoderError {
    /// Returns the underlying bit-reader error for the `GetBitsError` variant;
    /// `TableIsUninitialized` has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FSEDecoderError::GetBitsError(source) => Some(source),
            // Listed explicitly so that adding a variant forces this match to be revisited.
            FSEDecoderError::TableIsUninitialized => None,
        }
    }
}
128
129impl core::fmt::Display for FSEDecoderError {
130 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
131 match self {
132 FSEDecoderError::GetBitsError(e: &GetBitsError) => write!(f, "{:?}", e),
133 FSEDecoderError::TableIsUninitialized => {
134 write!(f, "Tried to use an uninitialized table!")
135 }
136 }
137 }
138}
139
140impl From<GetBitsError> for FSEDecoderError {
141 fn from(val: GetBitsError) -> Self {
142 Self::GetBitsError(val)
143 }
144}
145
/// One state of an FSE decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// Offset that is added to the value read from the stream to obtain the
    /// index of the next state.
    pub base_line: u32,
    /// Number of bits to read from the stream when leaving this state.
    pub num_bits: u8,
    /// The byte emitted into the decoded output when the decoder is in this state.
    pub symbol: u8,
}
157
/// Offset added to the 4-bit value read at the start of a table description
/// to obtain the table's `Accuracy_Log`.
const ACC_LOG_OFFSET: u8 = 5;
161
/// Returns the 1-based position of the most significant set bit of `x`,
/// which is the number of bits needed to represent `x`.
///
/// # Panics
///
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
166
167impl<'t> FSEDecoder<'t> {
168 /// Initialize a new Finite State Entropy decoder.
169 pub fn new(table: &'t FSETable) -> FSEDecoder<'t> {
170 FSEDecoder {
171 state: table.decode.first().copied().unwrap_or(Entry {
172 base_line: 0,
173 num_bits: 0,
174 symbol: 0,
175 }),
176 table,
177 }
178 }
179
180 /// Returns the byte associated with the symbol the internal cursor is pointing at.
181 pub fn decode_symbol(&self) -> u8 {
182 self.state.symbol
183 }
184
185 /// Initialize internal state and prepare for decoding. After this, `decode_symbol` can be called
186 /// to read the first symbol and `update_state` can be called to prepare to read the next symbol.
187 pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
188 if self.table.accuracy_log == 0 {
189 return Err(FSEDecoderError::TableIsUninitialized);
190 }
191 self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize];
192
193 Ok(())
194 }
195
196 /// Advance the internal state to decode the next symbol in the bitstream.
197 pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
198 let num_bits = self.state.num_bits;
199 let add = bits.get_bits(num_bits);
200 let base_line = self.state.base_line;
201 let new_state = base_line + add as u32;
202 self.state = self.table.decode[new_state as usize];
203
204 //println!("Update: {}, {} -> {}", base_line, add, self.state);
205 }
206}
207
208impl FSETable {
209 /// Initialize a new empty Finite State Entropy decoding table.
210 pub fn new(max_symbol: u8) -> FSETable {
211 FSETable {
212 max_symbol,
213 symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8
214 symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8
215 decode: Vec::new(), //depending on acc_log.
216 accuracy_log: 0,
217 }
218 }
219
220 /// Reset `self` and update `self`'s state to mirror the provided table.
221 pub fn reinit_from(&mut self, other: &Self) {
222 self.reset();
223 self.symbol_counter.extend_from_slice(&other.symbol_counter);
224 self.symbol_probabilities
225 .extend_from_slice(&other.symbol_probabilities);
226 self.decode.extend_from_slice(&other.decode);
227 self.accuracy_log = other.accuracy_log;
228 }
229
230 /// Empty the table and clear all internal state.
231 pub fn reset(&mut self) {
232 self.symbol_counter.clear();
233 self.symbol_probabilities.clear();
234 self.decode.clear();
235 self.accuracy_log = 0;
236 }
237
238 /// returns how many BYTEs (not bits) were read while building the decoder
239 pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
240 self.accuracy_log = 0;
241
242 let bytes_read = self.read_probabilities(source, max_log)?;
243 self.build_decoding_table()?;
244
245 Ok(bytes_read)
246 }
247
248 /// Given the provided accuracy log, build a decoding table from that log.
249 pub fn build_from_probabilities(
250 &mut self,
251 acc_log: u8,
252 probs: &[i32],
253 ) -> Result<(), FSETableError> {
254 if acc_log == 0 {
255 return Err(FSETableError::AccLogIsZero);
256 }
257 self.symbol_probabilities = probs.to_vec();
258 self.accuracy_log = acc_log;
259 self.build_decoding_table()
260 }
261
262 /// Build the actual decoding table after probabilities have been read into the table.
263 /// After this function is called, the decoding process can begin.
264 fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
265 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
266 return Err(FSETableError::TooManySymbols {
267 got: self.symbol_probabilities.len(),
268 });
269 }
270
271 self.decode.clear();
272
273 let table_size = 1 << self.accuracy_log;
274 if self.decode.len() < table_size {
275 self.decode.reserve(table_size - self.decode.len());
276 }
277 //fill with dummy entries
278 self.decode.resize(
279 table_size,
280 Entry {
281 base_line: 0,
282 num_bits: 0,
283 symbol: 0,
284 },
285 );
286
287 let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol
288
289 //first scan for all -1 probabilities and place them at the top of the table
290 for symbol in 0..self.symbol_probabilities.len() {
291 if self.symbol_probabilities[symbol] == -1 {
292 negative_idx -= 1;
293 let entry = &mut self.decode[negative_idx];
294 entry.symbol = symbol as u8;
295 entry.base_line = 0;
296 entry.num_bits = self.accuracy_log;
297 }
298 }
299
300 //then place in a semi-random order all of the other symbols
301 let mut position = 0;
302 for idx in 0..self.symbol_probabilities.len() {
303 let symbol = idx as u8;
304 if self.symbol_probabilities[idx] <= 0 {
305 continue;
306 }
307
308 //for each probability point the symbol gets on slot
309 let prob = self.symbol_probabilities[idx];
310 for _ in 0..prob {
311 let entry = &mut self.decode[position];
312 entry.symbol = symbol;
313
314 position = next_position(position, table_size);
315 while position >= negative_idx {
316 position = next_position(position, table_size);
317 //everything above negative_idx is already taken
318 }
319 }
320 }
321
322 // baselines and num_bits can only be calculated when all symbols have been spread
323 self.symbol_counter.clear();
324 self.symbol_counter
325 .resize(self.symbol_probabilities.len(), 0);
326 for idx in 0..negative_idx {
327 let entry = &mut self.decode[idx];
328 let symbol = entry.symbol;
329 let prob = self.symbol_probabilities[symbol as usize];
330
331 let symbol_count = self.symbol_counter[symbol as usize];
332 let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
333
334 //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb);
335
336 assert!(nb <= self.accuracy_log);
337 self.symbol_counter[symbol as usize] += 1;
338
339 entry.base_line = bl;
340 entry.num_bits = nb;
341 }
342 Ok(())
343 }
344
345 /// Read the accuracy log and the probability table from the source and return the number of bytes
346 /// read. If the size of the table is larger than the provided `max_log`, return an error.
347 fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
348 self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here
349
350 let mut br = BitReader::new(source);
351 self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
352 if self.accuracy_log > max_log {
353 return Err(FSETableError::AccLogTooBig {
354 got: self.accuracy_log,
355 max: max_log,
356 });
357 }
358 if self.accuracy_log == 0 {
359 return Err(FSETableError::AccLogIsZero);
360 }
361
362 let probability_sum = 1 << self.accuracy_log;
363 let mut probability_counter = 0;
364
365 while probability_counter < probability_sum {
366 let max_remaining_value = probability_sum - probability_counter + 1;
367 let bits_to_read = highest_bit_set(max_remaining_value);
368
369 let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
370
371 let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
372 let mask = (1 << (bits_to_read - 1)) - 1;
373 let small_value = unchecked_value & mask;
374
375 let value = if small_value < low_threshold {
376 br.return_bits(1);
377 small_value
378 } else if unchecked_value > mask {
379 unchecked_value - low_threshold
380 } else {
381 unchecked_value
382 };
383 //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value);
384
385 let prob = (value as i32) - 1;
386
387 self.symbol_probabilities.push(prob);
388 if prob != 0 {
389 if prob > 0 {
390 probability_counter += prob as u32;
391 } else {
392 // probability -1 counts as 1
393 assert!(prob == -1);
394 probability_counter += 1;
395 }
396 } else {
397 //fast skip further zero probabilities
398 loop {
399 let skip_amount = br.get_bits(2)? as usize;
400
401 self.symbol_probabilities
402 .resize(self.symbol_probabilities.len() + skip_amount, 0);
403 if skip_amount != 3 {
404 break;
405 }
406 }
407 }
408 }
409
410 if probability_counter != probability_sum {
411 return Err(FSETableError::ProbabilityCounterMismatch {
412 got: probability_counter,
413 expected_sum: probability_sum,
414 symbol_probabilities: self.symbol_probabilities.clone(),
415 });
416 }
417 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
418 return Err(FSETableError::TooManySymbols {
419 got: self.symbol_probabilities.len(),
420 });
421 }
422
423 let bytes_read = if br.bits_read() % 8 == 0 {
424 br.bits_read() / 8
425 } else {
426 (br.bits_read() / 8) + 1
427 };
428
429 Ok(bytes_read)
430 }
431}
432
433//utility functions for building the decoding table from probabilities
/// Returns the table position that follows `p` in the spreading order used
/// when symbols are distributed over a table of `table_size` entries.
///
/// The position advances by a fixed step and wraps around via a bit mask, so
/// `table_size` is expected to be a power of two.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let wrap_mask = table_size - 1;
    (p + step) & wrap_mask
}
441
442fn calc_baseline_and_numbits(
443 num_states_total: u32,
444 num_states_symbol: u32,
445 state_number: u32,
446) -> (u32, u8) {
447 let num_state_slices: u32 = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol {
448 num_states_symbol
449 } else {
450 1 << (highest_bit_set(num_states_symbol))
451 }; //always power of two
452
453 let num_double_width_state_slices: u32 = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed
454 let num_single_width_state_slices: u32 = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states
455 let slice_width: u32 = num_states_total / num_state_slices; //size of a single width slice of states
456 let num_bits: u32 = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice
457
458 if state_number < num_double_width_state_slices {
459 let baseline: u32 = num_single_width_state_slices * slice_width + state_number * slice_width * 2;
460 (baseline, num_bits as u8 + 1)
461 } else {
462 let index_shifted: u32 = state_number - num_double_width_state_slices;
463 ((index_shifted * slice_width), num_bits as u8)
464 }
465}
466