1use crate::decoding::bit_reader::BitReader;
2use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
3use alloc::vec::Vec;
4
/// Decoding table for one FSE (finite state entropy) distribution, together with the
/// scratch data used while building it.
///
/// A table whose `accuracy_log` is 0 has not been built; `FSEDecoder::init_state`
/// rejects such a table with `FSEDecoderError::TableIsUninitialized`.
pub struct FSETable {
    /// One entry per state, indexed by state number (`1 << accuracy_log` entries once
    /// built); used to decode symbols and to calculate the next state.
    pub decode: Vec<Entry>,

    /// Log2 of the number of states in the table; 0 means "not built yet".
    pub accuracy_log: u8,
    /// Normalized probability of each symbol (index = symbol value), used while building
    /// `decode`. A value of -1 marks a "less than one" probability that still occupies
    /// exactly one state.
    pub symbol_probabilities: Vec<i32>,
    /// Scratch buffer: per symbol, how many of its states have already been given a
    /// baseline while building `decode`.
    symbol_counter: Vec<u32>,
}
12
13impl Default for FSETable {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
/// Errors produced while reading an FSE table description or building an [`FSETable`].
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum FSETableError {
    /// The accuracy log was 0.
    #[display(fmt = "Acclog must be at least 1")]
    AccLogIsZero,
    /// The accuracy log `got` is larger than the allowed maximum `max`.
    #[display(fmt = "Found FSE acc_log: {got} bigger than allowed maximum in this case: {max}")]
    AccLogTooBig { got: u8, max: u8 },
    /// Reading bits from the input failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    GetBitsError(GetBitsError),
    /// The symbol probabilities add up to `got` instead of the table size `expected_sum`
    /// (a -1 probability counts as 1). `symbol_probabilities` holds the offending values.
    #[display(
        fmt = "The counter ({got}) exceeded the expected sum: {expected_sum}. This means an error or corrupted data \n {symbol_probabilities:?}"
    )]
    ProbabilityCounterMismatch {
        got: u32,
        expected_sum: u32,
        symbol_probabilities: Vec<i32>,
    },
    /// The distribution describes `got` symbols, more than a `u8` symbol can address.
    #[display(fmt = "There are too many symbols in this distribution: {got}. Max: 256")]
    TooManySymbols { got: usize },
}
41
/// Decodes symbols by walking the states of a borrowed [`FSETable`].
pub struct FSEDecoder<'table> {
    /// The current state; its `symbol` is what `decode_symbol` returns.
    pub state: Entry,
    /// The table whose `decode` entries the states are taken from.
    table: &'table FSETable,
}
46
/// Errors produced by [`FSEDecoder`] while reading states from the bitstream.
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum FSEDecoderError {
    /// Reading bits from the input failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    GetBitsError(GetBitsError),
    /// The decoder's table has not been built (e.g. its `accuracy_log` is 0).
    #[display(fmt = "Tried to use an uninitialized table!")]
    TableIsUninitialized,
}
57
/// One state of an FSE decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// Added to the `num_bits` bits read from the stream to form the index of the next state.
    pub base_line: u32,
    /// Number of bits to read from the stream when leaving this state.
    pub num_bits: u8,
    /// The symbol decoded while in this state.
    pub symbol: u8,
}
64
/// Offset added to the 4-bit accuracy-log field at the start of an encoded FSE table
/// description, so an encoded 0 means an accuracy log of 5.
const ACC_LOG_OFFSET: u8 = 5;
66
/// Returns the 1-based position of the most significant set bit of `x`:
/// 1 for `x == 1`, 3 for `x` in `4..=7`, 32 for `x >= 1 << 31`.
///
/// # Panics
///
/// If `x` is 0.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
71
72impl<'t> FSEDecoder<'t> {
73 pub fn new(table: &'t FSETable) -> FSEDecoder<'_> {
74 FSEDecoder {
75 state: table.decode.first().copied().unwrap_or(Entry {
76 base_line: 0,
77 num_bits: 0,
78 symbol: 0,
79 }),
80 table,
81 }
82 }
83
84 pub fn decode_symbol(&self) -> u8 {
85 self.state.symbol
86 }
87
88 pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> {
89 if self.table.accuracy_log == 0 {
90 return Err(FSEDecoderError::TableIsUninitialized);
91 }
92 self.state = self.table.decode[bits.get_bits(self.table.accuracy_log)? as usize];
93
94 Ok(())
95 }
96
97 pub fn update_state(
98 &mut self,
99 bits: &mut BitReaderReversed<'_>,
100 ) -> Result<(), FSEDecoderError> {
101 let num_bits = self.state.num_bits;
102 let add = bits.get_bits(num_bits)?;
103 let base_line = self.state.base_line;
104 let new_state = base_line + add as u32;
105 self.state = self.table.decode[new_state as usize];
106
107 //println!("Update: {}, {} -> {}", base_line, add, self.state);
108 Ok(())
109 }
110}
111
112impl FSETable {
113 pub fn new() -> FSETable {
114 FSETable {
115 symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8
116 symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8
117 decode: Vec::new(), //depending on acc_log.
118 accuracy_log: 0,
119 }
120 }
121
122 pub fn reinit_from(&mut self, other: &Self) {
123 self.reset();
124 self.symbol_counter.extend_from_slice(&other.symbol_counter);
125 self.symbol_probabilities
126 .extend_from_slice(&other.symbol_probabilities);
127 self.decode.extend_from_slice(&other.decode);
128 self.accuracy_log = other.accuracy_log;
129 }
130
131 pub fn reset(&mut self) {
132 self.symbol_counter.clear();
133 self.symbol_probabilities.clear();
134 self.decode.clear();
135 self.accuracy_log = 0;
136 }
137
138 //returns how many BYTEs (not bits) were read while building the decoder
139 pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
140 self.accuracy_log = 0;
141
142 let bytes_read = self.read_probabilities(source, max_log)?;
143 self.build_decoding_table();
144
145 Ok(bytes_read)
146 }
147
148 pub fn build_from_probabilities(
149 &mut self,
150 acc_log: u8,
151 probs: &[i32],
152 ) -> Result<(), FSETableError> {
153 if acc_log == 0 {
154 return Err(FSETableError::AccLogIsZero);
155 }
156 self.symbol_probabilities = probs.to_vec();
157 self.accuracy_log = acc_log;
158 self.build_decoding_table();
159 Ok(())
160 }
161
162 fn build_decoding_table(&mut self) {
163 self.decode.clear();
164
165 let table_size = 1 << self.accuracy_log;
166 if self.decode.len() < table_size {
167 self.decode.reserve(table_size - self.decode.len());
168 }
169 //fill with dummy entries
170 self.decode.resize(
171 table_size,
172 Entry {
173 base_line: 0,
174 num_bits: 0,
175 symbol: 0,
176 },
177 );
178
179 let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol
180
181 //first scan for all -1 probabilities and place them at the top of the table
182 for symbol in 0..self.symbol_probabilities.len() {
183 if self.symbol_probabilities[symbol] == -1 {
184 negative_idx -= 1;
185 let entry = &mut self.decode[negative_idx];
186 entry.symbol = symbol as u8;
187 entry.base_line = 0;
188 entry.num_bits = self.accuracy_log;
189 }
190 }
191
192 //then place in a semi-random order all of the other symbols
193 let mut position = 0;
194 for idx in 0..self.symbol_probabilities.len() {
195 let symbol = idx as u8;
196 if self.symbol_probabilities[idx] <= 0 {
197 continue;
198 }
199
200 //for each probability point the symbol gets on slot
201 let prob = self.symbol_probabilities[idx];
202 for _ in 0..prob {
203 let entry = &mut self.decode[position];
204 entry.symbol = symbol;
205
206 position = next_position(position, table_size);
207 while position >= negative_idx {
208 position = next_position(position, table_size);
209 //everything above negative_idx is already taken
210 }
211 }
212 }
213
214 // baselines and num_bits can only be calculated when all symbols have been spread
215 self.symbol_counter.clear();
216 self.symbol_counter
217 .resize(self.symbol_probabilities.len(), 0);
218 for idx in 0..negative_idx {
219 let entry = &mut self.decode[idx];
220 let symbol = entry.symbol;
221 let prob = self.symbol_probabilities[symbol as usize];
222
223 let symbol_count = self.symbol_counter[symbol as usize];
224 let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
225
226 //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb);
227
228 assert!(nb <= self.accuracy_log);
229 self.symbol_counter[symbol as usize] += 1;
230
231 entry.base_line = bl;
232 entry.num_bits = nb;
233 }
234 }
235
236 fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
237 self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here
238
239 let mut br = BitReader::new(source);
240 self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
241 if self.accuracy_log > max_log {
242 return Err(FSETableError::AccLogTooBig {
243 got: self.accuracy_log,
244 max: max_log,
245 });
246 }
247 if self.accuracy_log == 0 {
248 return Err(FSETableError::AccLogIsZero);
249 }
250
251 let probablility_sum = 1 << self.accuracy_log;
252 let mut probability_counter = 0;
253
254 while probability_counter < probablility_sum {
255 let max_remaining_value = probablility_sum - probability_counter + 1;
256 let bits_to_read = highest_bit_set(max_remaining_value);
257
258 let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
259
260 let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
261 let mask = (1 << (bits_to_read - 1)) - 1;
262 let small_value = unchecked_value & mask;
263
264 let value = if small_value < low_threshold {
265 br.return_bits(1);
266 small_value
267 } else if unchecked_value > mask {
268 unchecked_value - low_threshold
269 } else {
270 unchecked_value
271 };
272 //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value);
273
274 let prob = (value as i32) - 1;
275
276 self.symbol_probabilities.push(prob);
277 if prob != 0 {
278 if prob > 0 {
279 probability_counter += prob as u32;
280 } else {
281 // probability -1 counts as 1
282 assert!(prob == -1);
283 probability_counter += 1;
284 }
285 } else {
286 //fast skip further zero probabilities
287 loop {
288 let skip_amount = br.get_bits(2)? as usize;
289
290 self.symbol_probabilities
291 .resize(self.symbol_probabilities.len() + skip_amount, 0);
292 if skip_amount != 3 {
293 break;
294 }
295 }
296 }
297 }
298
299 if probability_counter != probablility_sum {
300 return Err(FSETableError::ProbabilityCounterMismatch {
301 got: probability_counter,
302 expected_sum: probablility_sum,
303 symbol_probabilities: self.symbol_probabilities.clone(),
304 });
305 }
306 if self.symbol_probabilities.len() > 256 {
307 return Err(FSETableError::TooManySymbols {
308 got: self.symbol_probabilities.len(),
309 });
310 }
311
312 let bytes_read = if br.bits_read() % 8 == 0 {
313 br.bits_read() / 8
314 } else {
315 (br.bits_read() / 8) + 1
316 };
317 Ok(bytes_read)
318 }
319}
320
321//utility functions for building the decoding table from probabilities
/// Returns the slot that follows `p` when spreading symbols over the decoding table.
///
/// The step is `table_size / 2 + table_size / 8 + 3`, wrapped with a bit mask. This
/// assumes `table_size` is a power of two, as `1 << accuracy_log` is where the table is
/// built.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    (p + step) & (table_size - 1)
}
327
328fn calc_baseline_and_numbits(
329 num_states_total: u32,
330 num_states_symbol: u32,
331 state_number: u32,
332) -> (u32, u8) {
333 let num_state_slices: u32 = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol {
334 num_states_symbol
335 } else {
336 1 << (highest_bit_set(num_states_symbol))
337 }; //always power of two
338
339 let num_double_width_state_slices: u32 = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed
340 let num_single_width_state_slices: u32 = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states
341 let slice_width: u32 = num_states_total / num_state_slices; //size of a single width slice of states
342 let num_bits: u32 = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice
343
344 if state_number < num_double_width_state_slices {
345 let baseline: u32 = num_single_width_state_slices * slice_width + state_number * slice_width * 2;
346 (baseline, num_bits as u8 + 1)
347 } else {
348 let index_shifted: u32 = state_number - num_double_width_state_slices;
349 ((index_shifted * slice_width), num_bits as u8)
350 }
351}
352