1 | use crate::decoding::bit_reader::BitReader; |
2 | use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; |
3 | use alloc::vec::Vec; |
4 | |
/// An FSE decoding table together with the symbol distribution it was built from.
pub struct FSETable {
    /// Decoding entries indexed by state; each entry gives the symbol for that
    /// state and how to compute the next state.
    pub decode: Vec<Entry>,

    /// Accuracy (table) log. Once built, `decode` holds `1 << accuracy_log`
    /// entries; `0` means the table has not been built yet.
    pub accuracy_log: u8,
    /// Normalized probability per symbol (index = symbol value). `-1` marks a
    /// "less than one" probability. Used while building `decode`.
    pub symbol_probabilities: Vec<i32>,
    /// Per-symbol scratch counter used while assigning baselines in `decode`.
    symbol_counter: Vec<u32>,
}
12 | |
13 | impl Default for FSETable { |
14 | fn default() -> Self { |
15 | Self::new() |
16 | } |
17 | } |
18 | |
/// Errors that can occur while building an [`FSETable`].
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum FSETableError {
    /// The requested accuracy log was zero.
    #[display(fmt = "Acclog must be at least 1")]
    AccLogIsZero,
    /// The accuracy log exceeds the maximum allowed in this context.
    #[display(fmt = "Found FSE acc_log: {got} bigger than allowed maximum in this case: {max}")]
    AccLogTooBig { got: u8, max: u8 },
    /// Reading the table description from the bitstream failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    GetBitsError(GetBitsError),
    /// The probabilities did not add up to the expected total
    /// (`1 << accuracy_log`, with `-1` entries counting as 1).
    #[display(
        fmt = "The counter ({got}) exceeded the expected sum: {expected_sum}. This means an error or corrupted data \n {symbol_probabilities:?}"
    )]
    ProbabilityCounterMismatch {
        got: u32,
        expected_sum: u32,
        symbol_probabilities: Vec<i32>,
    },
    /// The distribution describes more than 256 symbols.
    #[display(fmt = "There are too many symbols in this distribution: {got}. Max: 256")]
    TooManySymbols { got: usize },
}
41 | |
/// Decodes symbols by walking the states of an [`FSETable`].
pub struct FSEDecoder<'table> {
    /// Table entry for the current state; its `symbol` is what
    /// `decode_symbol` returns.
    pub state: Entry,
    /// The table whose entries the states index into.
    table: &'table FSETable,
}
46 | |
/// Errors that can occur while decoding with an [`FSEDecoder`].
#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum FSEDecoderError {
    /// Reading state bits from the bitstream failed.
    #[display(fmt = "{_0:?}")]
    #[from]
    GetBitsError(GetBitsError),
    /// The decoder was used with a table that is not (fully) built,
    /// e.g. one whose `accuracy_log` is 0.
    #[display(fmt = "Tried to use an uninitialized table!")]
    TableIsUninitialized,
}
57 | |
/// One slot (state) of an FSE decoding table.
#[derive(Copy, Clone)]
pub struct Entry {
    /// Base of the next state: the decoder reads `num_bits` bits and adds
    /// their value to `base_line` to get the next state's index.
    pub base_line: u32,
    /// Number of bits to read when leaving this state.
    pub num_bits: u8,
    /// Symbol decoded while in this state.
    pub symbol: u8,
}
64 | |
/// Offset added to the 4-bit accuracy-log field at the start of a serialized
/// FSE table description.
const ACC_LOG_OFFSET: u8 = 5;
66 | |
/// Returns the number of bits needed to represent `x`, i.e. the 1-based
/// position of its most significant set bit (`1 -> 1`, `4 -> 3`, `255 -> 8`).
///
/// # Panics
///
/// Panics if `x` is zero.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
71 | |
72 | impl<'t> FSEDecoder<'t> { |
73 | pub fn new(table: &'t FSETable) -> FSEDecoder<'_> { |
74 | FSEDecoder { |
75 | state: table.decode.first().copied().unwrap_or(Entry { |
76 | base_line: 0, |
77 | num_bits: 0, |
78 | symbol: 0, |
79 | }), |
80 | table, |
81 | } |
82 | } |
83 | |
84 | pub fn decode_symbol(&self) -> u8 { |
85 | self.state.symbol |
86 | } |
87 | |
88 | pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> { |
89 | if self.table.accuracy_log == 0 { |
90 | return Err(FSEDecoderError::TableIsUninitialized); |
91 | } |
92 | self.state = self.table.decode[bits.get_bits(self.table.accuracy_log)? as usize]; |
93 | |
94 | Ok(()) |
95 | } |
96 | |
97 | pub fn update_state( |
98 | &mut self, |
99 | bits: &mut BitReaderReversed<'_>, |
100 | ) -> Result<(), FSEDecoderError> { |
101 | let num_bits = self.state.num_bits; |
102 | let add = bits.get_bits(num_bits)?; |
103 | let base_line = self.state.base_line; |
104 | let new_state = base_line + add as u32; |
105 | self.state = self.table.decode[new_state as usize]; |
106 | |
107 | //println!("Update: {}, {} -> {}", base_line, add, self.state); |
108 | Ok(()) |
109 | } |
110 | } |
111 | |
112 | impl FSETable { |
113 | pub fn new() -> FSETable { |
114 | FSETable { |
115 | symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8 |
116 | symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8 |
117 | decode: Vec::new(), //depending on acc_log. |
118 | accuracy_log: 0, |
119 | } |
120 | } |
121 | |
122 | pub fn reinit_from(&mut self, other: &Self) { |
123 | self.reset(); |
124 | self.symbol_counter.extend_from_slice(&other.symbol_counter); |
125 | self.symbol_probabilities |
126 | .extend_from_slice(&other.symbol_probabilities); |
127 | self.decode.extend_from_slice(&other.decode); |
128 | self.accuracy_log = other.accuracy_log; |
129 | } |
130 | |
131 | pub fn reset(&mut self) { |
132 | self.symbol_counter.clear(); |
133 | self.symbol_probabilities.clear(); |
134 | self.decode.clear(); |
135 | self.accuracy_log = 0; |
136 | } |
137 | |
138 | //returns how many BYTEs (not bits) were read while building the decoder |
139 | pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { |
140 | self.accuracy_log = 0; |
141 | |
142 | let bytes_read = self.read_probabilities(source, max_log)?; |
143 | self.build_decoding_table(); |
144 | |
145 | Ok(bytes_read) |
146 | } |
147 | |
148 | pub fn build_from_probabilities( |
149 | &mut self, |
150 | acc_log: u8, |
151 | probs: &[i32], |
152 | ) -> Result<(), FSETableError> { |
153 | if acc_log == 0 { |
154 | return Err(FSETableError::AccLogIsZero); |
155 | } |
156 | self.symbol_probabilities = probs.to_vec(); |
157 | self.accuracy_log = acc_log; |
158 | self.build_decoding_table(); |
159 | Ok(()) |
160 | } |
161 | |
162 | fn build_decoding_table(&mut self) { |
163 | self.decode.clear(); |
164 | |
165 | let table_size = 1 << self.accuracy_log; |
166 | if self.decode.len() < table_size { |
167 | self.decode.reserve(table_size - self.decode.len()); |
168 | } |
169 | //fill with dummy entries |
170 | self.decode.resize( |
171 | table_size, |
172 | Entry { |
173 | base_line: 0, |
174 | num_bits: 0, |
175 | symbol: 0, |
176 | }, |
177 | ); |
178 | |
179 | let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol |
180 | |
181 | //first scan for all -1 probabilities and place them at the top of the table |
182 | for symbol in 0..self.symbol_probabilities.len() { |
183 | if self.symbol_probabilities[symbol] == -1 { |
184 | negative_idx -= 1; |
185 | let entry = &mut self.decode[negative_idx]; |
186 | entry.symbol = symbol as u8; |
187 | entry.base_line = 0; |
188 | entry.num_bits = self.accuracy_log; |
189 | } |
190 | } |
191 | |
192 | //then place in a semi-random order all of the other symbols |
193 | let mut position = 0; |
194 | for idx in 0..self.symbol_probabilities.len() { |
195 | let symbol = idx as u8; |
196 | if self.symbol_probabilities[idx] <= 0 { |
197 | continue; |
198 | } |
199 | |
200 | //for each probability point the symbol gets on slot |
201 | let prob = self.symbol_probabilities[idx]; |
202 | for _ in 0..prob { |
203 | let entry = &mut self.decode[position]; |
204 | entry.symbol = symbol; |
205 | |
206 | position = next_position(position, table_size); |
207 | while position >= negative_idx { |
208 | position = next_position(position, table_size); |
209 | //everything above negative_idx is already taken |
210 | } |
211 | } |
212 | } |
213 | |
214 | // baselines and num_bits can only be calculated when all symbols have been spread |
215 | self.symbol_counter.clear(); |
216 | self.symbol_counter |
217 | .resize(self.symbol_probabilities.len(), 0); |
218 | for idx in 0..negative_idx { |
219 | let entry = &mut self.decode[idx]; |
220 | let symbol = entry.symbol; |
221 | let prob = self.symbol_probabilities[symbol as usize]; |
222 | |
223 | let symbol_count = self.symbol_counter[symbol as usize]; |
224 | let (bl, nb) = calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count); |
225 | |
226 | //println!("symbol: {:2}, table: {}, prob: {:3}, count: {:3}, bl: {:3}, nb: {:2}", symbol, table_size, prob, symbol_count, bl, nb); |
227 | |
228 | assert!(nb <= self.accuracy_log); |
229 | self.symbol_counter[symbol as usize] += 1; |
230 | |
231 | entry.base_line = bl; |
232 | entry.num_bits = nb; |
233 | } |
234 | } |
235 | |
236 | fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { |
237 | self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here |
238 | |
239 | let mut br = BitReader::new(source); |
240 | self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8); |
241 | if self.accuracy_log > max_log { |
242 | return Err(FSETableError::AccLogTooBig { |
243 | got: self.accuracy_log, |
244 | max: max_log, |
245 | }); |
246 | } |
247 | if self.accuracy_log == 0 { |
248 | return Err(FSETableError::AccLogIsZero); |
249 | } |
250 | |
251 | let probablility_sum = 1 << self.accuracy_log; |
252 | let mut probability_counter = 0; |
253 | |
254 | while probability_counter < probablility_sum { |
255 | let max_remaining_value = probablility_sum - probability_counter + 1; |
256 | let bits_to_read = highest_bit_set(max_remaining_value); |
257 | |
258 | let unchecked_value = br.get_bits(bits_to_read as usize)? as u32; |
259 | |
260 | let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value); |
261 | let mask = (1 << (bits_to_read - 1)) - 1; |
262 | let small_value = unchecked_value & mask; |
263 | |
264 | let value = if small_value < low_threshold { |
265 | br.return_bits(1); |
266 | small_value |
267 | } else if unchecked_value > mask { |
268 | unchecked_value - low_threshold |
269 | } else { |
270 | unchecked_value |
271 | }; |
272 | //println!("{}, {}, {}", self.symbol_probablilities.len(), unchecked_value, value); |
273 | |
274 | let prob = (value as i32) - 1; |
275 | |
276 | self.symbol_probabilities.push(prob); |
277 | if prob != 0 { |
278 | if prob > 0 { |
279 | probability_counter += prob as u32; |
280 | } else { |
281 | // probability -1 counts as 1 |
282 | assert!(prob == -1); |
283 | probability_counter += 1; |
284 | } |
285 | } else { |
286 | //fast skip further zero probabilities |
287 | loop { |
288 | let skip_amount = br.get_bits(2)? as usize; |
289 | |
290 | self.symbol_probabilities |
291 | .resize(self.symbol_probabilities.len() + skip_amount, 0); |
292 | if skip_amount != 3 { |
293 | break; |
294 | } |
295 | } |
296 | } |
297 | } |
298 | |
299 | if probability_counter != probablility_sum { |
300 | return Err(FSETableError::ProbabilityCounterMismatch { |
301 | got: probability_counter, |
302 | expected_sum: probablility_sum, |
303 | symbol_probabilities: self.symbol_probabilities.clone(), |
304 | }); |
305 | } |
306 | if self.symbol_probabilities.len() > 256 { |
307 | return Err(FSETableError::TooManySymbols { |
308 | got: self.symbol_probabilities.len(), |
309 | }); |
310 | } |
311 | |
312 | let bytes_read = if br.bits_read() % 8 == 0 { |
313 | br.bits_read() / 8 |
314 | } else { |
315 | (br.bits_read() / 8) + 1 |
316 | }; |
317 | Ok(bytes_read) |
318 | } |
319 | } |
320 | |
321 | //utility functions for building the decoding table from probabilities |
/// Advances `p` to the next slot used when spreading symbols over a table of
/// `table_size` entries: steps by `table_size / 2 + table_size / 8 + 3` and
/// wraps by masking (so `table_size` is assumed to be a power of two).
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    (p + step) & (table_size - 1)
}
327 | |
328 | fn calc_baseline_and_numbits( |
329 | num_states_total: u32, |
330 | num_states_symbol: u32, |
331 | state_number: u32, |
332 | ) -> (u32, u8) { |
333 | let num_state_slices: u32 = if 1 << (highest_bit_set(num_states_symbol) - 1) == num_states_symbol { |
334 | num_states_symbol |
335 | } else { |
336 | 1 << (highest_bit_set(num_states_symbol)) |
337 | }; //always power of two |
338 | |
339 | let num_double_width_state_slices: u32 = num_state_slices - num_states_symbol; //leftovers to the power of two need to be distributed |
340 | let num_single_width_state_slices: u32 = num_states_symbol - num_double_width_state_slices; //these will not receive a double width slice of states |
341 | let slice_width: u32 = num_states_total / num_state_slices; //size of a single width slice of states |
342 | let num_bits: u32 = highest_bit_set(slice_width) - 1; //number of bits needed to read for one slice |
343 | |
344 | if state_number < num_double_width_state_slices { |
345 | let baseline: u32 = num_single_width_state_slices * slice_width + state_number * slice_width * 2; |
346 | (baseline, num_bits as u8 + 1) |
347 | } else { |
348 | let index_shifted: u32 = state_number - num_double_width_state_slices; |
349 | ((index_shifted * slice_width), num_bits as u8) |
350 | } |
351 | } |
352 | |