/// Reads bit fields out of a borrowed byte slice, starting at the least
/// significant bit of each byte.
pub struct BitReader<'s> {
    /// The input bytes; only ever borrowed, never modified.
    source: &'s [u8],
    /// How many bits have already been consumed, counted from the start of `source`.
    idx: usize,
}
5
/// Errors returned by [`BitReader::get_bits`].
#[derive(Debug, derive_more::Display)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum GetBitsError {
    /// A single call asked for more bits than the reader can return at once.
    #[display(
        fmt = "Cant serve this request. The reader is limited to {limit} bits, requested {num_requested_bits} bits"
    )]
    TooManyBits {
        /// The number of bits the caller requested.
        num_requested_bits: usize,
        /// The maximum number of bits a single read can return.
        limit: u8,
    },
    /// The request would read past the end of the input.
    #[display(fmt = "Can't read {requested} bits, only have {remaining} bits left")]
    NotEnoughRemainingBits {
        /// The number of bits the caller requested.
        requested: usize,
        /// The number of unread bits left in the input at the time of the call.
        remaining: usize,
    },
}
20
21impl<'s> BitReader<'s> {
22 pub fn new(source: &'s [u8]) -> BitReader<'_> {
23 BitReader { idx: 0, source }
24 }
25
26 pub fn bits_left(&self) -> usize {
27 self.source.len() * 8 - self.idx
28 }
29
30 pub fn bits_read(&self) -> usize {
31 self.idx
32 }
33
34 pub fn return_bits(&mut self, n: usize) {
35 if n > self.idx {
36 panic!("Cant return this many bits");
37 }
38 self.idx -= n;
39 }
40
41 pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
42 if n > 64 {
43 return Err(GetBitsError::TooManyBits {
44 num_requested_bits: n,
45 limit: 64,
46 });
47 }
48 if self.bits_left() < n {
49 return Err(GetBitsError::NotEnoughRemainingBits {
50 requested: n,
51 remaining: self.bits_left(),
52 });
53 }
54
55 let old_idx = self.idx;
56
57 let bits_left_in_current_byte = 8 - (self.idx % 8);
58 let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;
59
60 //collect bits from the currently pointed to byte
61 let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);
62
63 if bits_left_in_current_byte >= n {
64 //no need for fancy stuff
65
66 //just mask all but the needed n bit
67 value &= (1 << n) - 1;
68 self.idx += n;
69 } else {
70 self.idx += bits_left_in_current_byte;
71
72 //n spans over multiple bytes
73 let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
74 let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;
75
76 assert!(
77 bits_left_in_current_byte + full_bytes_needed * 8 + bits_in_last_byte_needed == n
78 );
79
80 let mut bit_shift = bits_left_in_current_byte; //this many bits are already set in value
81
82 assert!(self.idx % 8 == 0);
83
84 //collect full bytes
85 for _ in 0..full_bytes_needed {
86 value |= u64::from(self.source[self.idx / 8]) << bit_shift;
87 self.idx += 8;
88 bit_shift += 8;
89 }
90
91 assert!(n - bit_shift == bits_in_last_byte_needed);
92
93 if bits_in_last_byte_needed > 0 {
94 let val_las_byte =
95 u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
96 value |= val_las_byte << bit_shift;
97 self.idx += bits_in_last_byte_needed;
98 }
99 }
100
101 assert!(self.idx == old_idx + n);
102
103 Ok(value)
104 }
105
106 pub fn reset(&mut self, new_source: &'s [u8]) {
107 self.idx = 0;
108 self.source = new_source;
109 }
110}
111