use crate::decoder::DecodingError;

use super::vp8::TreeNode;

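/// The result of reading a value from the bitstream. The wrapped value is
/// only meaningful if the decoder has not read past the end of the buffer,
/// which is verified later through a [`BitResultAccumulator`].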
#[must_use]
#[repr(transparent)]
pub(crate) struct BitResult<T> {
    value_if_not_past_eof: T,
}

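/// Zero-sized marker returned by `ArithmeticDecoder::start_accumulated_result`;
/// holding one represents the obligation to eventually call
/// `ArithmeticDecoder::check` on the same decoder.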
#[must_use]
pub(crate) struct BitResultAccumulator;

impl<T> BitResult<T> {
    const fn ok(value: T) -> Self {
        Self {
            value_if_not_past_eof: value,
        }
    }

    /// Instead of checking this result now, accumulate the burden of checking
    /// into an accumulator. This accumulator must be checked in the end.
    #[inline(always)]
    pub(crate) fn or_accumulate(self, acc: &mut BitResultAccumulator) -> T {
        let _ = acc;
        self.value_if_not_past_eof
    }
}

impl<T: Default> BitResult<T> {
    fn err() -> Self {
        Self {
            value_if_not_past_eof: T::default(),
        }
    }
}

#[cfg_attr(test, derive(Debug))]
pub(crate) struct ArithmeticDecoder {
    chunks: Box<[[u8; 4]]>,
    state: State,
    final_bytes: [u8; 3],
    final_bytes_remaining: i8,
}

#[cfg_attr(test, derive(Debug))]
#[derive(Clone, Copy)]
struct State {
    chunk_index: usize,
    value: u64,
    range: u32,
    bit_count: i32,
}

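/// A speculative view of the decoder used by the fast reading path: reads
/// advance only `uncommitted_state`, and `commit_if_valid` copies it back
/// into the decoder through `save_state` only if no read went past the end
/// of `chunks`.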
#[cfg_attr(test, derive(Debug))]
struct FastDecoder<'a> {
    chunks: &'a [[u8; 4]],
    uncommitted_state: State,
    save_state: &'a mut State,
}

impl ArithmeticDecoder {
    pub(crate) fn new() -> ArithmeticDecoder {
        let state = State {
            chunk_index: 0,
            value: 0,
            range: 255,
            bit_count: -8,
        };
        ArithmeticDecoder {
            chunks: Box::new([]),
            state,
            final_bytes: [0; 3],
            final_bytes_remaining: Self::FINAL_BYTES_REMAINING_EOF,
        }
    }

    pub(crate) fn init(&mut self, mut buf: Vec<[u8; 4]>, len: usize) -> Result<(), DecodingError> {
        let mut final_bytes = [0; 3];
        let final_bytes_remaining = if len == 4 * buf.len() {
            0
        } else {
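            // For example, with `len == 11` the caller passes three chunks:
            // we pop the last one, copy its three meaningful bytes into
            // `final_bytes`, and keep the two full chunks in `buf`.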
            // Pop the last chunk (which is partial), then get length.
            let Some(last_chunk) = buf.pop() else {
                return Err(DecodingError::NotEnoughInitData);
            };
            let len_rounded_down = 4 * buf.len();
            let num_bytes_popped = len - len_rounded_down;
            debug_assert!(num_bytes_popped <= 3);
            final_bytes[..num_bytes_popped].copy_from_slice(&last_chunk[..num_bytes_popped]);
            for i in num_bytes_popped..4 {
                debug_assert_eq!(last_chunk[i], 0, "unexpected {last_chunk:?}");
            }
            num_bytes_popped as i8
        };

        let chunks = buf.into_boxed_slice();
        let state = State {
            chunk_index: 0,
            value: 0,
            range: 255,
            bit_count: -8,
        };
        *self = Self {
            chunks,
            state,
            final_bytes,
            final_bytes_remaining,
        };
        Ok(())
    }

    /// Start a span of reading operations from the buffer, without stopping
    /// when the buffer runs out. For all valid WebP images, the buffer will
    /// not run out prematurely. Conversely, if the buffer ends early, the
    /// image cannot be decoded correctly and any intermediate results must
    /// be discarded anyway.
    ///
    /// Each call to `start_accumulated_result` must be followed by a call to
    /// `check` on the *same* `ArithmeticDecoder`.
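    ///
    /// A sketch of the intended pattern, assuming `decoder` is an initialized
    /// `ArithmeticDecoder` (the tests at the bottom of this file follow the
    /// same shape):
    ///
    /// ```ignore
    /// let mut acc = decoder.start_accumulated_result();
    /// let flag = decoder.read_flag().or_accumulate(&mut acc);
    /// let bits = decoder.read_literal(3).or_accumulate(&mut acc);
    /// let (flag, bits) = decoder.check(acc, (flag, bits))?;
    /// ```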
    #[inline(always)]
    pub(crate) fn start_accumulated_result(&mut self) -> BitResultAccumulator {
        BitResultAccumulator
    }

    /// Check that the read operations done so far were all valid.
    #[inline(always)]
    pub(crate) fn check<T>(
        &self,
        acc: BitResultAccumulator,
        value_if_not_past_eof: T,
    ) -> Result<T, DecodingError> {
        // The accumulator does not store any state because doing so is
        // too computationally expensive. Passing it around is a bit of a
        // formality (that is optimized out) to ensure we call `check`.
        // Instead we check whether we have read past the end of the file.
        let BitResultAccumulator = acc;

        if self.is_past_eof() {
            Err(DecodingError::BitStreamError)
        } else {
            Ok(value_if_not_past_eof)
        }
    }

    fn keep_accumulating<T>(
        &self,
        acc: BitResultAccumulator,
        value_if_not_past_eof: T,
    ) -> BitResult<T> {
        // The BitResult will be checked later by a different accumulator.
        // Because it does not carry state, that is fine.
        let BitResultAccumulator = acc;

        BitResult::ok(value_if_not_past_eof)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_bool(&mut self, probability: u8) -> BitResult<bool> {
        if let Some(b) = self.fast().read_bool(probability) {
            return BitResult::ok(b);
        }

        self.cold_read_bool(probability)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_flag(&mut self) -> BitResult<bool> {
        if let Some(b) = self.fast().read_flag() {
            return BitResult::ok(b);
        }

        self.cold_read_flag()
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_literal(&mut self, n: u8) -> BitResult<u8> {
        if let Some(v) = self.fast().read_literal(n) {
            return BitResult::ok(v);
        }

        self.cold_read_literal(n)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
        if let Some(v) = self.fast().read_optional_signed_value(n) {
            return BitResult::ok(v);
        }

        self.cold_read_optional_signed_value(n)
    }

    // This is generic and inlined just to skip the first bounds check.
    #[inline]
    pub(crate) fn read_with_tree<const N: usize>(&mut self, tree: &[TreeNode; N]) -> BitResult<i8> {
        let first_node = tree[0];
        self.read_with_tree_with_first_node(tree, first_node)
    }

    // Do not inline this because inlining significantly worsens performance.
    #[inline(never)]
    pub(crate) fn read_with_tree_with_first_node(
        &mut self,
        tree: &[TreeNode],
        first_node: TreeNode,
    ) -> BitResult<i8> {
        if let Some(v) = self.fast().read_with_tree(tree, first_node) {
            return BitResult::ok(v);
        }

        self.cold_read_with_tree(tree, usize::from(first_node.index))
    }

    // As a similar (but different) speedup to BitResult, the FastDecoder reads
    // bits under an assumption and validates it at the end.
    //
    // The idea here is that for normal-sized WebP images, the vast majority
    // of bits are somewhere other than in the last four bytes. Therefore we
    // can pretend the buffer has infinite size. After we are done reading,
    // we check whether we actually read past the end of `self.chunks`
    // (see `commit_if_valid`). If so, we backtrack (or rather, we discard
    // `uncommitted_state`) and try again with the slow approach. This might
    // result in doing double work for those last few bytes (in fact, we even
    // keep retrying the fast method to save an if-statement), but it more
    // than makes up for that by speeding up reads from the other thousands
    // or millions of bytes.
    fn fast(&mut self) -> FastDecoder<'_> {
        FastDecoder {
            chunks: &self.chunks,
            uncommitted_state: self.state,
            save_state: &mut self.state,
        }
    }

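    /// Sentinel value of `final_bytes_remaining` marking that the decoder has
    /// read past the end of the data (or was never initialized); see
    /// `is_past_eof`.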
    const FINAL_BYTES_REMAINING_EOF: i8 = -0xE;

    fn load_from_final_bytes(&mut self) {
        match self.final_bytes_remaining {
            1.. => {
                self.final_bytes_remaining -= 1;
                let byte = self.final_bytes[0];
                self.final_bytes.rotate_left(1);
                self.state.value <<= 8;
                self.state.value |= u64::from(byte);
                self.state.bit_count += 8;
            }
            0 => {
                // libwebp seems to (sometimes?) allow bitstreams that read one byte past the end.
                // This replicates that logic.
                self.final_bytes_remaining -= 1;
                self.state.value <<= 8;
                self.state.bit_count += 8;
            }
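            // The one-byte grace period above has already been used up (or
            // there was never any data): mark the stream as past EOF so that
            // `check` reports an error.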
            _ => {
                self.final_bytes_remaining = Self::FINAL_BYTES_REMAINING_EOF;
            }
        }
    }

    fn is_past_eof(&self) -> bool {
        self.final_bytes_remaining == Self::FINAL_BYTES_REMAINING_EOF
    }

    fn cold_read_bit(&mut self, probability: u8) -> BitResult<bool> {
        if self.state.bit_count < 0 {
            if let Some(chunk) = self.chunks.get(self.state.chunk_index).copied() {
                let v = u32::from_be_bytes(chunk);
                self.state.chunk_index += 1;
                self.state.value <<= 32;
                self.state.value |= u64::from(v);
                self.state.bit_count += 32;
            } else {
                self.load_from_final_bytes();
                if self.is_past_eof() {
                    return BitResult::err();
                }
            }
        }
        debug_assert!(self.state.bit_count >= 0);

        let probability = u32::from(probability);
        let split = 1 + (((self.state.range - 1) * probability) >> 8);
        let bigsplit = u64::from(split) << self.state.bit_count;

        let retval = if let Some(new_value) = self.state.value.checked_sub(bigsplit) {
            self.state.range -= split;
            self.state.value = new_value;
            true
        } else {
            self.state.range = split;
            false
        };
        debug_assert!(self.state.range > 0);

        // Compute the shift required to satisfy `self.state.range >= 128`.
        // Apply that shift to `self.state.range` and `self.state.bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `self.state.range`, which is a `u32`.
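        // For example, `range == 35` (0b10_0011) has 26 leading zeros as a
        // `u32`, so `shift == 2` and the renormalized range is 35 << 2 == 140.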
        let shift = self.state.range.leading_zeros().saturating_sub(24);
        self.state.range <<= shift;
        self.state.bit_count -= shift as i32;
        debug_assert!(self.state.range >= 128);

        BitResult::ok(retval)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_bool(&mut self, probability: u8) -> BitResult<bool> {
        self.cold_read_bit(probability)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_flag(&mut self) -> BitResult<bool> {
        self.cold_read_bit(128)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_literal(&mut self, n: u8) -> BitResult<u8> {
        let mut v = 0u8;
        let mut res = self.start_accumulated_result();

        for _ in 0..n {
            let b = self.cold_read_flag().or_accumulate(&mut res);
            v = (v << 1) + u8::from(b);
        }

        self.keep_accumulating(res, v)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
        let mut res = self.start_accumulated_result();
        let flag = self.cold_read_flag().or_accumulate(&mut res);
        if !flag {
            // We should not read further bits if the flag is not set.
            return self.keep_accumulating(res, 0);
        }
        let magnitude = self.cold_read_literal(n).or_accumulate(&mut res);
        let sign = self.cold_read_flag().or_accumulate(&mut res);

        let value = if sign {
            -i32::from(magnitude)
        } else {
            i32::from(magnitude)
        };
        self.keep_accumulating(res, value)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_with_tree(&mut self, tree: &[TreeNode], start: usize) -> BitResult<i8> {
        let mut index = start;
        let mut res = self.start_accumulated_result();

        loop {
            let node = tree[index];
            let prob = node.prob;
            let b = self.cold_read_bit(prob).or_accumulate(&mut res);
            let t = if b { node.right } else { node.left };
            let new_index = usize::from(t);
            if new_index < tree.len() {
                index = new_index;
            } else {
                let value = TreeNode::value_from_branch(t);
                return self.keep_accumulating(res, value);
            }
        }
    }
}

impl FastDecoder<'_> {
    fn commit_if_valid<T>(self, value_if_not_past_eof: T) -> Option<T> {
        // If `chunk_index > self.chunks.len()`, it means we used zeroes
        // instead of an actual chunk and `value_if_not_past_eof` is nonsense.
        if self.uncommitted_state.chunk_index <= self.chunks.len() {
            *self.save_state = self.uncommitted_state;
            Some(value_if_not_past_eof)
        } else {
            None
        }
    }

    fn read_bool(mut self, probability: u8) -> Option<bool> {
        let bit = self.fast_read_bit(probability);
        self.commit_if_valid(bit)
    }

    fn read_flag(mut self) -> Option<bool> {
        let value = self.fast_read_flag();
        self.commit_if_valid(value)
    }

    fn read_literal(mut self, n: u8) -> Option<u8> {
        let value = self.fast_read_literal(n);
        self.commit_if_valid(value)
    }

    fn read_optional_signed_value(mut self, n: u8) -> Option<i32> {
        let flag = self.fast_read_flag();
        if !flag {
            // We should not read further bits if the flag is not set.
            return self.commit_if_valid(0);
        }
        let magnitude = self.fast_read_literal(n);
        let sign = self.fast_read_flag();
        let value = if sign {
            -i32::from(magnitude)
        } else {
            i32::from(magnitude)
        };
        self.commit_if_valid(value)
    }

    fn read_with_tree(mut self, tree: &[TreeNode], first_node: TreeNode) -> Option<i8> {
        let value = self.fast_read_with_tree(tree, first_node);
        self.commit_if_valid(value)
    }

    fn fast_read_bit(&mut self, probability: u8) -> bool {
        let State {
            mut chunk_index,
            mut value,
            mut range,
            mut bit_count,
        } = self.uncommitted_state;

        if bit_count < 0 {
            let chunk = self.chunks.get(chunk_index).copied();
            // We ignore invalid data inside the `fast_` functions,
            // but we increase `chunk_index` below, so we can check
            // whether we read invalid data in `commit_if_valid`.
            let chunk = chunk.unwrap_or_default();

            let v = u32::from_be_bytes(chunk);
            chunk_index += 1;
            value <<= 32;
            value |= u64::from(v);
            bit_count += 32;
        }
        debug_assert!(bit_count >= 0);

        let probability = u32::from(probability);
        let split = 1 + (((range - 1) * probability) >> 8);
        let bigsplit = u64::from(split) << bit_count;

        let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
            range -= split;
            value = new_value;
            true
        } else {
            range = split;
            false
        };
        debug_assert!(range > 0);

        // Compute the shift required to satisfy `range >= 128`.
        // Apply that shift to `range` and `bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `range`, which is a `u32`.
        let shift = range.leading_zeros().saturating_sub(24);
        range <<= shift;
        bit_count -= shift as i32;
        debug_assert!(range >= 128);

        self.uncommitted_state = State {
            chunk_index,
            value,
            range,
            bit_count,
        };
        retval
    }

    fn fast_read_flag(&mut self) -> bool {
        let State {
            mut chunk_index,
            mut value,
            mut range,
            mut bit_count,
        } = self.uncommitted_state;

        if bit_count < 0 {
            let chunk = self.chunks.get(chunk_index).copied();
            // We ignore invalid data inside the `fast_` functions,
            // but we increase `chunk_index` below, so we can check
            // whether we read invalid data in `commit_if_valid`.
            let chunk = chunk.unwrap_or_default();

            let v = u32::from_be_bytes(chunk);
            chunk_index += 1;
            value <<= 32;
            value |= u64::from(v);
            bit_count += 32;
        }
        debug_assert!(bit_count >= 0);

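        // `read_flag` is `read_bool` with an implied probability of 128/256,
        // so the split point is simply the upper half of the current range.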
        let half_range = range / 2;
        let split = range - half_range;
        let bigsplit = u64::from(split) << bit_count;

        let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
            range = half_range;
            value = new_value;
            true
        } else {
            range = split;
            false
        };
        debug_assert!(range > 0);

        // Compute the shift required to satisfy `range >= 128`.
        // Apply that shift to `range` and `bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `range`, which is a `u32`.
        let shift = range.leading_zeros().saturating_sub(24);
        range <<= shift;
        bit_count -= shift as i32;
        debug_assert!(range >= 128);

        self.uncommitted_state = State {
            chunk_index,
            value,
            range,
            bit_count,
        };
        retval
    }

    fn fast_read_literal(&mut self, n: u8) -> u8 {
        let mut v = 0u8;
        for _ in 0..n {
            let b = self.fast_read_flag();
            v = (v << 1) + u8::from(b);
        }
        v
    }

    fn fast_read_with_tree(&mut self, tree: &[TreeNode], mut node: TreeNode) -> i8 {
        loop {
            let prob = node.prob;
            let b = self.fast_read_bit(prob);
            let i = if b { node.right } else { node.left };
            let Some(next_node) = tree.get(usize::from(i)) else {
                return TreeNode::value_from_branch(i);
            };
            node = *next_node;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arithmetic_decoder_hello_short() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hel";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; 1];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        assert_eq!(false, decoder.read_flag().or_accumulate(&mut res));
        assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res));
        assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res));
        assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res));
        assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res));
        assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res));
        decoder.check(res, ()).unwrap();
    }

    #[test]
    fn test_arithmetic_decoder_hello_long() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hello world";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; (size + 3) / 4];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        assert_eq!(false, decoder.read_flag().or_accumulate(&mut res));
        assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res));
        assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res));
        assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res));
        assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res));
        assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(31, decoder.read_literal(8).or_accumulate(&mut res));
        decoder.check(res, ()).unwrap();
    }

    #[test]
    fn test_arithmetic_decoder_uninit() {
        let mut decoder = ArithmeticDecoder::new();
        let mut res = decoder.start_accumulated_result();
        let _ = decoder.read_flag().or_accumulate(&mut res);
        let result = decoder.check(res, ());
        assert!(result.is_err());
    }
}