//! Rudimentary utility for reading canonical Huffman codes.
//! Based on <https://github.com/webmproject/libwebp/blob/7f8472a610b61ec780ef0a8873cd954ac512a505/src/utils/huffman.c>
3
4use std::io::BufRead;
5
6use crate::decoder::DecodingError;
7
8use super::lossless::BitReader;
9
/// Longest Huffman code length supported; code-length histograms are sized
/// `MAX_ALLOWED_CODE_LENGTH + 1` so they can be indexed directly by length.
const MAX_ALLOWED_CODE_LENGTH: usize = 15;

/// Number of leading bits resolved by a single primary-table lookup; codes longer than
/// this continue bit by bit in the secondary tree.
const MAX_TABLE_BITS: u8 = 10;
12
/// One node of the bit-by-bit decoding tree stored alongside the primary lookup table.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HuffmanTreeNode {
    /// Inner node. Its children sit at `index + offset` (next bit 0) and
    /// `index + offset + 1` (next bit 1) in the same node vector.
    Branch(usize),
    /// Terminal node carrying the decoded symbol.
    Leaf(u16),
    /// Node that no code has been routed through yet.
    Empty,
}
19
20#[derive(Clone, Debug)]
21enum HuffmanTreeInner {
22 Single(u16),
23 Tree {
24 tree: Vec<HuffmanTreeNode>,
25 table: Vec<u32>,
26 table_mask: u16,
27 },
28}
29
30/// Huffman tree
31#[derive(Clone, Debug)]
32pub(crate) struct HuffmanTree(HuffmanTreeInner);
33
34impl Default for HuffmanTree {
35 fn default() -> Self {
36 Self(HuffmanTreeInner::Single(0))
37 }
38}
39
40impl HuffmanTree {
41 /// Builds a tree implicitly, just from code lengths
42 pub(crate) fn build_implicit(code_lengths: Vec<u16>) -> Result<Self, DecodingError> {
43 // Count symbols and build histogram
44 let mut num_symbols = 0;
45 let mut code_length_hist = [0; MAX_ALLOWED_CODE_LENGTH + 1];
46 for &length in code_lengths.iter().filter(|&&x| x != 0) {
47 code_length_hist[usize::from(length)] += 1;
48 num_symbols += 1;
49 }
50
51 // Handle special cases
52 if num_symbols == 0 {
53 return Err(DecodingError::HuffmanError);
54 } else if num_symbols == 1 {
55 let root_symbol = code_lengths.iter().position(|&x| x != 0).unwrap() as u16;
56 return Ok(Self::build_single_node(root_symbol));
57 };
58
59 // Assign codes
60 let mut curr_code = 0;
61 let mut next_codes = [0; MAX_ALLOWED_CODE_LENGTH + 1];
62 let max_code_length = code_length_hist.iter().rposition(|&x| x != 0).unwrap() as u16;
63 for code_len in 1..usize::from(max_code_length) + 1 {
64 next_codes[code_len] = curr_code;
65 curr_code = (curr_code + code_length_hist[code_len]) << 1;
66 }
67
68 // Confirm that the huffman tree is valid
69 if curr_code != 2 << max_code_length {
70 return Err(DecodingError::HuffmanError);
71 }
72
73 // Calculate table/tree parameters
74 let table_bits = max_code_length.min(u16::from(MAX_TABLE_BITS));
75 let table_size = (1 << table_bits) as usize;
76 let table_mask = table_size as u16 - 1;
77 let tree_size = code_length_hist[table_bits as usize + 1..=max_code_length as usize]
78 .iter()
79 .sum::<u16>() as usize;
80
81 // Populate decoding table
82 let mut tree = Vec::with_capacity(2 * tree_size);
83 let mut table = vec![0; table_size];
84 for (symbol, &length) in code_lengths.iter().enumerate() {
85 if length == 0 {
86 continue;
87 }
88
89 let code = next_codes[length as usize];
90 next_codes[length as usize] += 1;
91
92 if length <= table_bits {
93 let mut j = (u16::reverse_bits(code) >> (16 - length)) as usize;
94 let entry = (u32::from(length) << 16) | symbol as u32;
95 while j < table_size {
96 table[j] = entry;
97 j += 1 << length as usize;
98 }
99 } else {
100 let table_index =
101 ((u16::reverse_bits(code) >> (16 - length)) & table_mask) as usize;
102 let table_value = table[table_index];
103
104 debug_assert_eq!(table_value >> 16, 0);
105
106 let mut node_index = if table_value == 0 {
107 let node_index = tree.len();
108 table[table_index] = (node_index + 1) as u32;
109 tree.push(HuffmanTreeNode::Empty);
110 node_index
111 } else {
112 (table_value - 1) as usize
113 };
114
115 let code = usize::from(code);
116 for depth in (0..length - table_bits).rev() {
117 let node = tree[node_index];
118
119 let offset = match node {
120 HuffmanTreeNode::Empty => {
121 // Turns a node from empty into a branch and assigns its children
122 let offset = tree.len() - node_index;
123 tree[node_index] = HuffmanTreeNode::Branch(offset);
124 tree.push(HuffmanTreeNode::Empty);
125 tree.push(HuffmanTreeNode::Empty);
126 offset
127 }
128 HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
129 HuffmanTreeNode::Branch(offset) => offset,
130 };
131
132 node_index += offset + ((code >> depth) & 1);
133 }
134
135 match tree[node_index] {
136 HuffmanTreeNode::Empty => {
137 tree[node_index] = HuffmanTreeNode::Leaf(symbol as u16);
138 }
139 HuffmanTreeNode::Leaf(_) => return Err(DecodingError::HuffmanError),
140 HuffmanTreeNode::Branch(_offset) => return Err(DecodingError::HuffmanError),
141 }
142 }
143 }
144
145 Ok(Self(HuffmanTreeInner::Tree {
146 tree,
147 table,
148 table_mask,
149 }))
150 }
151
152 pub(crate) const fn build_single_node(symbol: u16) -> Self {
153 Self(HuffmanTreeInner::Single(symbol))
154 }
155
156 pub(crate) fn build_two_node(zero: u16, one: u16) -> Self {
157 Self(HuffmanTreeInner::Tree {
158 tree: vec![
159 HuffmanTreeNode::Leaf(zero),
160 HuffmanTreeNode::Leaf(one),
161 HuffmanTreeNode::Empty,
162 ],
163 table: vec![(1 << 16) | u32::from(zero), (1 << 16) | u32::from(one)],
164 table_mask: 0x1,
165 })
166 }
167
168 pub(crate) const fn is_single_node(&self) -> bool {
169 matches!(self.0, HuffmanTreeInner::Single(_))
170 }
171
172 #[inline(never)]
173 fn read_symbol_slowpath<R: BufRead>(
174 tree: &[HuffmanTreeNode],
175 mut v: usize,
176 start_index: usize,
177 bit_reader: &mut BitReader<R>,
178 ) -> Result<u16, DecodingError> {
179 let mut depth = MAX_TABLE_BITS;
180 let mut index = start_index;
181 loop {
182 match &tree[index] {
183 HuffmanTreeNode::Branch(children_offset) => {
184 index += children_offset + (v & 1);
185 depth += 1;
186 v >>= 1;
187 }
188 HuffmanTreeNode::Leaf(symbol) => {
189 bit_reader.consume(depth)?;
190 return Ok(*symbol);
191 }
192 HuffmanTreeNode::Empty => return Err(DecodingError::HuffmanError),
193 }
194 }
195 }
196
197 /// Reads a symbol using the bit reader.
198 ///
199 /// You must call call `bit_reader.fill()` before calling this function or it may erroroneosly
200 /// detect the end of the stream and return a bitstream error.
201 pub(crate) fn read_symbol<R: BufRead>(
202 &self,
203 bit_reader: &mut BitReader<R>,
204 ) -> Result<u16, DecodingError> {
205 match &self.0 {
206 HuffmanTreeInner::Tree {
207 tree,
208 table,
209 table_mask,
210 } => {
211 let v = bit_reader.peek_full() as u16;
212 let entry = table[(v & table_mask) as usize];
213 if entry >> 16 != 0 {
214 bit_reader.consume((entry >> 16) as u8)?;
215 return Ok(entry as u16);
216 }
217
218 Self::read_symbol_slowpath(
219 tree,
220 (v >> MAX_TABLE_BITS) as usize,
221 ((entry & 0xffff) - 1) as usize,
222 bit_reader,
223 )
224 }
225 HuffmanTreeInner::Single(symbol) => Ok(*symbol),
226 }
227 }
228
229 /// Peek at the next symbol in the bitstream if it can be read with only a primary table lookup.
230 ///
231 /// Returns a tuple of the codelength and symbol value. This function may return wrong
232 /// information if there aren't enough bits in the bit reader to read the next symbol.
233 pub(crate) fn peek_symbol<R: BufRead>(&self, bit_reader: &BitReader<R>) -> Option<(u8, u16)> {
234 match &self.0 {
235 HuffmanTreeInner::Tree {
236 table, table_mask, ..
237 } => {
238 let v = bit_reader.peek_full() as u16;
239 let entry = table[(v & table_mask) as usize];
240 if entry >> 16 != 0 {
241 return Some(((entry >> 16) as u8, entry as u16));
242 }
243 None
244 }
245 HuffmanTreeInner::Single(symbol) => Some((0, *symbol)),
246 }
247 }
248}
249