use crate::decoder::DecodingError;

use super::vp8::TreeNode;

#[must_use]
#[repr(transparent)]
pub(crate) struct BitResult<T> {
    value_if_not_past_eof: T,
}

#[must_use]
pub(crate) struct BitResultAccumulator;

impl<T> BitResult<T> {
    const fn ok(value: T) -> Self {
        Self {
            value_if_not_past_eof: value,
        }
    }

    /// Instead of checking this result now, accumulate the burden of checking
    /// into an accumulator. This accumulator must be checked in the end.
    #[inline(always)]
    pub(crate) fn or_accumulate(self, acc: &mut BitResultAccumulator) -> T {
        let _ = acc;
        self.value_if_not_past_eof
    }
}

impl<T: Default> BitResult<T> {
    fn err() -> Self {
        Self {
            value_if_not_past_eof: T::default(),
        }
    }
}

#[cfg_attr(test, derive(Debug))]
pub(crate) struct ArithmeticDecoder {
    chunks: Box<[[u8; 4]]>,
    state: State,
    final_bytes: [u8; 3],
    final_bytes_remaining: i8,
}

#[cfg_attr(test, derive(Debug))]
#[derive(Clone, Copy)]
struct State {
    chunk_index: usize,
    value: u64,
    range: u32,
    bit_count: i32,
}

#[cfg_attr(test, derive(Debug))]
struct FastDecoder<'a> {
    chunks: &'a [[u8; 4]],
    uncommitted_state: State,
    save_state: &'a mut State,
}

impl ArithmeticDecoder {
    pub(crate) fn new() -> ArithmeticDecoder {
        let state = State {
            chunk_index: 0,
            value: 0,
            range: 255,
            bit_count: -8,
        };
        ArithmeticDecoder {
            chunks: Box::new([]),
            state,
            final_bytes: [0; 3],
            final_bytes_remaining: Self::FINAL_BYTES_REMAINING_EOF,
        }
    }

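    /// Initialize the decoder with `len` bytes of data, provided as 4-byte
    /// chunks with the last chunk zero-padded. For example (values chosen
    /// purely for illustration): `len = 11` with three chunks keeps two full
    /// chunks and moves the three trailing bytes into `final_bytes`.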
    pub(crate) fn init(&mut self, mut buf: Vec<[u8; 4]>, len: usize) -> Result<(), DecodingError> {
        let mut final_bytes = [0; 3];
        let final_bytes_remaining = if len == 4 * buf.len() {
            0
        } else {
            // Pop the last chunk (which is partial), then get length.
            let Some(last_chunk) = buf.pop() else {
                return Err(DecodingError::NotEnoughInitData);
            };
            let len_rounded_down = 4 * buf.len();
            let num_bytes_popped = len - len_rounded_down;
            debug_assert!(num_bytes_popped <= 3);
            final_bytes[..num_bytes_popped].copy_from_slice(&last_chunk[..num_bytes_popped]);
            for i in num_bytes_popped..4 {
                debug_assert_eq!(last_chunk[i], 0, "unexpected {last_chunk:?}");
            }
            num_bytes_popped as i8
        };

        let chunks = buf.into_boxed_slice();
        let state = State {
            chunk_index: 0,
            value: 0,
            range: 255,
            bit_count: -8,
        };
        *self = Self {
            chunks,
            state,
            final_bytes,
            final_bytes_remaining,
        };
        Ok(())
    }

    /// Start a span of reading operations from the buffer, without stopping
    /// when the buffer runs out. For all valid webp images, the buffer will not
    /// run out prematurely. Conversely, if the buffer ends early, the webp image
    /// cannot be decoded correctly and any intermediate results need to be
    /// discarded anyway.
    ///
    /// Each call to `start_accumulated_result` must be followed by a call to
    /// `check` on the *same* `ArithmeticDecoder`.
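    ///
    /// A sketch of the intended usage (marked `ignore` because this API is
    /// crate-private and the decoder setup is omitted):
    ///
    /// ```ignore
    /// let mut acc = decoder.start_accumulated_result();
    /// let flag = decoder.read_flag().or_accumulate(&mut acc);
    /// let literal = decoder.read_literal(8).or_accumulate(&mut acc);
    /// // Only now do we learn whether the reads above stayed within the buffer.
    /// let (flag, literal) = decoder.check(acc, (flag, literal))?;
    /// ```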
    #[inline(always)]
    pub(crate) fn start_accumulated_result(&mut self) -> BitResultAccumulator {
        BitResultAccumulator
    }

    /// Check that the read operations done so far were all valid.
    #[inline(always)]
    pub(crate) fn check<T>(
        &self,
        acc: BitResultAccumulator,
        value_if_not_past_eof: T,
    ) -> Result<T, DecodingError> {
        // The accumulator does not store any state because doing so is
        // too computationally expensive. Passing it around is a formality
        // (which is optimized out) that ensures we remember to call `check`.
        // Instead of per-read bookkeeping, we simply check here whether we
        // have read past the end of the file.
        let BitResultAccumulator = acc;

        if self.is_past_eof() {
            Err(DecodingError::BitStreamError)
        } else {
            Ok(value_if_not_past_eof)
        }
    }

    fn keep_accumulating<T>(
        &self,
        acc: BitResultAccumulator,
        value_if_not_past_eof: T,
    ) -> BitResult<T> {
        // The BitResult will be checked later by a different accumulator.
        // Because it does not carry state, that is fine.
        let BitResultAccumulator = acc;

        BitResult::ok(value_if_not_past_eof)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_bool(&mut self, probability: u8) -> BitResult<bool> {
        if let Some(b) = self.fast().read_bool(probability) {
            return BitResult::ok(b);
        }

        self.cold_read_bool(probability)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_flag(&mut self) -> BitResult<bool> {
        if let Some(b) = self.fast().read_flag() {
            return BitResult::ok(b);
        }

        self.cold_read_flag()
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_literal(&mut self, n: u8) -> BitResult<u8> {
        if let Some(v) = self.fast().read_literal(n) {
            return BitResult::ok(v);
        }

        self.cold_read_literal(n)
    }

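    /// Reads a presence flag; when it is set, reads an `n`-bit magnitude
    /// followed by a sign flag (set meaning negative). Yields 0 when the
    /// presence flag is not set.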
    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
        if let Some(v) = self.fast().read_optional_signed_value(n) {
            return BitResult::ok(v);
        }

        self.cold_read_optional_signed_value(n)
    }

    // This is generic and inlined just to skip the first bounds check.
    #[inline]
    pub(crate) fn read_with_tree<const N: usize>(&mut self, tree: &[TreeNode; N]) -> BitResult<i8> {
        let first_node = tree[0];
        self.read_with_tree_with_first_node(tree, first_node)
    }

    // Do not inline this because inlining significantly worsens performance.
    #[inline(never)]
    pub(crate) fn read_with_tree_with_first_node(
        &mut self,
        tree: &[TreeNode],
        first_node: TreeNode,
    ) -> BitResult<i8> {
        if let Some(v) = self.fast().read_with_tree(tree, first_node) {
            return BitResult::ok(v);
        }

        self.cold_read_with_tree(tree, usize::from(first_node.index))
    }

    // As a similar (but different) speedup to BitResult, the FastDecoder reads
    // bits under an assumption and validates it at the end.
    //
    // The idea here is that for normal-sized webp images, the vast majority
    // of bits are somewhere other than in the last four bytes. Therefore we
    // can pretend the buffer has infinite size. After we are done reading,
    // we check whether we actually read past the end of `self.chunks`.
    // If so, we backtrack (or rather we discard `uncommitted_state`)
    // and try again with the slow approach. This might mean doing double
    // work for those last few bytes (in fact we even keep retrying the fast
    // method to save an if-statement), but we more than make up for that by
    // speeding up reads from the other thousands or millions of bytes.
    fn fast(&mut self) -> FastDecoder<'_> {
        FastDecoder {
            chunks: &self.chunks,
            uncommitted_state: self.state,
            save_state: &mut self.state,
        }
    }

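    // Sentinel value for `final_bytes_remaining`: the decoder has read past
    // the end of the data, including the one-byte grace allowance after the
    // final byte. A freshly constructed decoder starts in this state so that
    // reads before `init` are reported as errors by `check`.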
    const FINAL_BYTES_REMAINING_EOF: i8 = -0xE;

    fn load_from_final_bytes(&mut self) {
        match self.final_bytes_remaining {
            1.. => {
                self.final_bytes_remaining -= 1;
                let byte = self.final_bytes[0];
                self.final_bytes.rotate_left(1);
                self.state.value <<= 8;
                self.state.value |= u64::from(byte);
                self.state.bit_count += 8;
            }
            0 => {
                // libwebp seems to (sometimes?) allow bitstreams that read one byte past the end.
                // This replicates that logic.
                self.final_bytes_remaining -= 1;
                self.state.value <<= 8;
                self.state.bit_count += 8;
            }
            _ => {
                self.final_bytes_remaining = Self::FINAL_BYTES_REMAINING_EOF;
            }
        }
    }

    fn is_past_eof(&self) -> bool {
        self.final_bytes_remaining == Self::FINAL_BYTES_REMAINING_EOF
    }

    fn cold_read_bit(&mut self, probability: u8) -> BitResult<bool> {
        if self.state.bit_count < 0 {
            if let Some(chunk) = self.chunks.get(self.state.chunk_index).copied() {
                let v = u32::from_be_bytes(chunk);
                self.state.chunk_index += 1;
                self.state.value <<= 32;
                self.state.value |= u64::from(v);
                self.state.bit_count += 32;
            } else {
                self.load_from_final_bytes();
                if self.is_past_eof() {
                    return BitResult::err();
                }
            }
        }
        debug_assert!(self.state.bit_count >= 0);

        let probability = u32::from(probability);
        let split = 1 + (((self.state.range - 1) * probability) >> 8);
        let bigsplit = u64::from(split) << self.state.bit_count;

        let retval = if let Some(new_value) = self.state.value.checked_sub(bigsplit) {
            self.state.range -= split;
            self.state.value = new_value;
            true
        } else {
            self.state.range = split;
            false
        };
        debug_assert!(self.state.range > 0);

        // Compute the shift required to satisfy `self.state.range >= 128`.
        // Apply that shift to `self.state.range` and `self.state.bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `self.state.range`, which is a `u32`.
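        // For example (an arbitrary illustrative value): if `range` were 22
        // (0b1_0110), `leading_zeros()` is 27, so the shift is 3 and `range`
        // becomes 176 >= 128. If `range` is already >= 128, `leading_zeros()`
        // is at most 24 and the shift is 0.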
        let shift = self.state.range.leading_zeros().saturating_sub(24);
        self.state.range <<= shift;
        self.state.bit_count -= shift as i32;
        debug_assert!(self.state.range >= 128);

        BitResult::ok(retval)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_bool(&mut self, probability: u8) -> BitResult<bool> {
        self.cold_read_bit(probability)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_flag(&mut self) -> BitResult<bool> {
        self.cold_read_bit(128)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_literal(&mut self, n: u8) -> BitResult<u8> {
        let mut v = 0u8;
        let mut res = self.start_accumulated_result();

        for _ in 0..n {
            let b = self.cold_read_flag().or_accumulate(&mut res);
            v = (v << 1) + u8::from(b);
        }

        self.keep_accumulating(res, v)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
        let mut res = self.start_accumulated_result();
        let flag = self.cold_read_flag().or_accumulate(&mut res);
        if !flag {
            // We should not read further bits if the flag is not set.
            return self.keep_accumulating(res, 0);
        }
        let magnitude = self.cold_read_literal(n).or_accumulate(&mut res);
        let sign = self.cold_read_flag().or_accumulate(&mut res);

        let value = if sign {
            -i32::from(magnitude)
        } else {
            i32::from(magnitude)
        };
        self.keep_accumulating(res, value)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_with_tree(&mut self, tree: &[TreeNode], start: usize) -> BitResult<i8> {
        let mut index = start;
        let mut res = self.start_accumulated_result();

        loop {
            let node = tree[index];
            let prob = node.prob;
            let b = self.cold_read_bit(prob).or_accumulate(&mut res);
            let t = if b { node.right } else { node.left };
            let new_index = usize::from(t);
            if new_index < tree.len() {
                index = new_index;
            } else {
                let value = TreeNode::value_from_branch(t);
                return self.keep_accumulating(res, value);
            }
        }
    }
}

impl FastDecoder<'_> {
    fn commit_if_valid<T>(self, value_if_not_past_eof: T) -> Option<T> {
        // If `chunk_index > self.chunks.len()`, it means we used zeroes
        // instead of an actual chunk and `value_if_not_past_eof` is nonsense.
        if self.uncommitted_state.chunk_index <= self.chunks.len() {
            *self.save_state = self.uncommitted_state;
            Some(value_if_not_past_eof)
        } else {
            None
        }
    }

    fn read_bool(mut self, probability: u8) -> Option<bool> {
        let bit = self.fast_read_bit(probability);
        self.commit_if_valid(bit)
    }

    fn read_flag(mut self) -> Option<bool> {
        let value = self.fast_read_flag();
        self.commit_if_valid(value)
    }

    fn read_literal(mut self, n: u8) -> Option<u8> {
        let value = self.fast_read_literal(n);
        self.commit_if_valid(value)
    }

    fn read_optional_signed_value(mut self, n: u8) -> Option<i32> {
        let flag = self.fast_read_flag();
        if !flag {
            // We should not read further bits if the flag is not set.
            return self.commit_if_valid(0);
        }
        let magnitude = self.fast_read_literal(n);
        let sign = self.fast_read_flag();
        let value = if sign {
            -i32::from(magnitude)
        } else {
            i32::from(magnitude)
        };
        self.commit_if_valid(value)
    }

    fn read_with_tree(mut self, tree: &[TreeNode], first_node: TreeNode) -> Option<i8> {
        let value = self.fast_read_with_tree(tree, first_node);
        self.commit_if_valid(value)
    }

    fn fast_read_bit(&mut self, probability: u8) -> bool {
        let State {
            mut chunk_index,
            mut value,
            mut range,
            mut bit_count,
        } = self.uncommitted_state;

        if bit_count < 0 {
            let chunk = self.chunks.get(chunk_index).copied();
            // We ignore invalid data inside the `fast_` functions,
            // but we increase `chunk_index` below, so we can check
            // whether we read invalid data in `commit_if_valid`.
            let chunk = chunk.unwrap_or_default();

            let v = u32::from_be_bytes(chunk);
            chunk_index += 1;
            value <<= 32;
            value |= u64::from(v);
            bit_count += 32;
        }
        debug_assert!(bit_count >= 0);

        let probability = u32::from(probability);
        let split = 1 + (((range - 1) * probability) >> 8);
        let bigsplit = u64::from(split) << bit_count;

        let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
            range -= split;
            value = new_value;
            true
        } else {
            range = split;
            false
        };
        debug_assert!(range > 0);

        // Compute the shift required to satisfy `range >= 128`.
        // Apply that shift to `range` and `bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `range`, which is a `u32`.
        let shift = range.leading_zeros().saturating_sub(24);
        range <<= shift;
        bit_count -= shift as i32;
        debug_assert!(range >= 128);

        self.uncommitted_state = State {
            chunk_index,
            value,
            range,
            bit_count,
        };
        retval
    }

    fn fast_read_flag(&mut self) -> bool {
        let State {
            mut chunk_index,
            mut value,
            mut range,
            mut bit_count,
        } = self.uncommitted_state;

        if bit_count < 0 {
            let chunk = self.chunks.get(chunk_index).copied();
            // We ignore invalid data inside the `fast_` functions,
            // but we increase `chunk_index` below, so we can check
            // whether we read invalid data in `commit_if_valid`.
            let chunk = chunk.unwrap_or_default();

            let v = u32::from_be_bytes(chunk);
            chunk_index += 1;
            value <<= 32;
            value |= u64::from(v);
            bit_count += 32;
        }
        debug_assert!(bit_count >= 0);

        let half_range = range / 2;
        let split = range - half_range;
        let bigsplit = u64::from(split) << bit_count;

        let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
            range = half_range;
            value = new_value;
            true
        } else {
            range = split;
            false
        };
        debug_assert!(range > 0);

        // Compute the shift required to satisfy `range >= 128`.
        // Apply that shift to `range` and `bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `range`, which is a `u32`.
        let shift = range.leading_zeros().saturating_sub(24);
        range <<= shift;
        bit_count -= shift as i32;
        debug_assert!(range >= 128);

        self.uncommitted_state = State {
            chunk_index,
            value,
            range,
            bit_count,
        };
        retval
    }

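    // Reads `n` unbiased bits and packs them MSB-first: for example, reading
    // the bits 1, 0, 1 yields 0b101 = 5.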
    fn fast_read_literal(&mut self, n: u8) -> u8 {
        let mut v = 0u8;
        for _ in 0..n {
            let b = self.fast_read_flag();
            v = (v << 1) + u8::from(b);
        }
        v
    }

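    // Walks the token tree one probability-coded bit at a time: each bit
    // selects the left or right branch, branch values that index into `tree`
    // continue the walk, and out-of-range branch values encode leaves.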
    fn fast_read_with_tree(&mut self, tree: &[TreeNode], mut node: TreeNode) -> i8 {
        loop {
            let prob = node.prob;
            let b = self.fast_read_bit(prob);
            let i = if b { node.right } else { node.left };
            let Some(next_node) = tree.get(usize::from(i)) else {
                return TreeNode::value_from_branch(i);
            };
            node = *next_node;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arithmetic_decoder_hello_short() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hel";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; 1];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        assert_eq!(false, decoder.read_flag().or_accumulate(&mut res));
        assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res));
        assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res));
        assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res));
        assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res));
        assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res));
        decoder.check(res, ()).unwrap();
    }

    #[test]
    fn test_arithmetic_decoder_hello_long() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hello world";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; (size + 3) / 4];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        assert_eq!(false, decoder.read_flag().or_accumulate(&mut res));
        assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res));
        assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res));
        assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res));
        assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res));
        assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(31, decoder.read_literal(8).or_accumulate(&mut res));
        decoder.check(res, ()).unwrap();
    }

    #[test]
    fn test_arithmetic_decoder_uninit() {
        let mut decoder = ArithmeticDecoder::new();
        let mut res = decoder.start_accumulated_result();
        let _ = decoder.read_flag().or_accumulate(&mut res);
        let result = decoder.check(res, ());
        assert!(result.is_err());
    }
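
    // A sketch of the expected EOF behaviour: reading far more flags than the
    // three input bytes can provide should surface as an error from `check`.
    #[test]
    fn test_arithmetic_decoder_read_past_end() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hel";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; 1];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        for _ in 0..100 {
            let _ = decoder.read_flag().or_accumulate(&mut res);
        }
        assert!(decoder.check(res, ()).is_err());
    }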
}