//! This module contains the [`decode_literals`] function, which takes a parsed
//! literals section header and the source bytes and decodes the literals into a buffer.
3 | |
4 | use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType}; |
5 | use super::bit_reader_reverse::{BitReaderReversed, GetBitsError}; |
6 | use super::scratch::HuffmanScratch; |
7 | use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError}; |
8 | use alloc::vec::Vec; |
9 | |
10 | #[derive (Debug)] |
11 | #[non_exhaustive ] |
12 | pub enum DecompressLiteralsError { |
13 | MissingCompressedSize, |
14 | MissingNumStreams, |
15 | GetBitsError(GetBitsError), |
16 | HuffmanTableError(HuffmanTableError), |
17 | HuffmanDecoderError(HuffmanDecoderError), |
18 | UninitializedHuffmanTable, |
19 | MissingBytesForJumpHeader { got: usize }, |
20 | MissingBytesForLiterals { got: usize, needed: usize }, |
21 | ExtraPadding { skipped_bits: i32 }, |
22 | BitstreamReadMismatch { read_til: isize, expected: isize }, |
23 | DecodedLiteralCountMismatch { decoded: usize, expected: usize }, |
24 | } |
25 | |
#[cfg(feature = "std")]
impl std::error::Error for DecompressLiteralsError {
    /// Returns the wrapped lower-level error for the variants that carry one,
    /// and `None` for every other variant.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            // Patterns only bind the payload; a pattern cannot carry a type annotation.
            DecompressLiteralsError::GetBitsError(source) => Some(source),
            DecompressLiteralsError::HuffmanTableError(source) => Some(source),
            DecompressLiteralsError::HuffmanDecoderError(source) => Some(source),
            _ => None,
        }
    }
}
37 | impl core::fmt::Display for DecompressLiteralsError { |
38 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
39 | match self { |
40 | DecompressLiteralsError::MissingCompressedSize => { |
41 | write!(f, |
42 | "compressed size was none even though it must be set to something for compressed literals" , |
43 | ) |
44 | } |
45 | DecompressLiteralsError::MissingNumStreams => { |
46 | write!(f, |
47 | "num_streams was none even though it must be set to something (1 or 4) for compressed literals" , |
48 | ) |
49 | } |
50 | DecompressLiteralsError::GetBitsError(e) => write!(f, " {:?}" , e), |
51 | DecompressLiteralsError::HuffmanTableError(e) => write!(f, " {:?}" , e), |
52 | DecompressLiteralsError::HuffmanDecoderError(e) => write!(f, " {:?}" , e), |
53 | DecompressLiteralsError::UninitializedHuffmanTable => { |
54 | write!( |
55 | f, |
56 | "Tried to reuse huffman table but it was never initialized" , |
57 | ) |
58 | } |
59 | DecompressLiteralsError::MissingBytesForJumpHeader { got } => { |
60 | write!(f, "Need 6 bytes to decode jump header, got {} bytes" , got,) |
61 | } |
62 | DecompressLiteralsError::MissingBytesForLiterals { got, needed } => { |
63 | write!( |
64 | f, |
65 | "Need at least {} bytes to decode literals. Have: {} bytes" , |
66 | needed, got, |
67 | ) |
68 | } |
69 | DecompressLiteralsError::ExtraPadding { skipped_bits } => { |
70 | write!(f, |
71 | "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption" , |
72 | skipped_bits, |
73 | ) |
74 | } |
75 | DecompressLiteralsError::BitstreamReadMismatch { read_til, expected } => { |
76 | write!( |
77 | f, |
78 | "Bitstream was read till: {}, should have been: {}" , |
79 | read_til, expected, |
80 | ) |
81 | } |
82 | DecompressLiteralsError::DecodedLiteralCountMismatch { decoded, expected } => { |
83 | write!( |
84 | f, |
85 | "Did not decode enough literals: {}, Should have been: {}" , |
86 | decoded, expected, |
87 | ) |
88 | } |
89 | } |
90 | } |
91 | } |
92 | |
93 | impl From<HuffmanDecoderError> for DecompressLiteralsError { |
94 | fn from(val: HuffmanDecoderError) -> Self { |
95 | Self::HuffmanDecoderError(val) |
96 | } |
97 | } |
98 | |
99 | impl From<GetBitsError> for DecompressLiteralsError { |
100 | fn from(val: GetBitsError) -> Self { |
101 | Self::GetBitsError(val) |
102 | } |
103 | } |
104 | |
105 | impl From<HuffmanTableError> for DecompressLiteralsError { |
106 | fn from(val: HuffmanTableError) -> Self { |
107 | Self::HuffmanTableError(val) |
108 | } |
109 | } |
110 | |
111 | /// Decode and decompress the provided literals section into `target`, returning the number of bytes read. |
112 | pub fn decode_literals( |
113 | section: &LiteralsSection, |
114 | scratch: &mut HuffmanScratch, |
115 | source: &[u8], |
116 | target: &mut Vec<u8>, |
117 | ) -> Result<u32, DecompressLiteralsError> { |
118 | match section.ls_type { |
119 | LiteralsSectionType::Raw => { |
120 | target.extend(&source[0..section.regenerated_size as usize]); |
121 | Ok(section.regenerated_size) |
122 | } |
123 | LiteralsSectionType::RLE => { |
124 | target.resize(new_len:target.len() + section.regenerated_size as usize, value:source[0]); |
125 | Ok(1) |
126 | } |
127 | LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => { |
128 | let bytes_read: u32 = decompress_literals(section, scratch, source, target)?; |
129 | |
130 | //return sum of used bytes |
131 | Ok(bytes_read) |
132 | } |
133 | } |
134 | } |
135 | |
136 | /// Decompress the provided literals section and source into the provided `target`. |
137 | /// This function is used when the literals section is `Compressed` or `Treeless` |
138 | /// |
139 | /// Returns the number of bytes read. |
140 | fn decompress_literals( |
141 | section: &LiteralsSection, |
142 | scratch: &mut HuffmanScratch, |
143 | source: &[u8], |
144 | target: &mut Vec<u8>, |
145 | ) -> Result<u32, DecompressLiteralsError> { |
146 | use DecompressLiteralsError as err; |
147 | |
148 | let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize; |
149 | let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?; |
150 | |
151 | target.reserve(section.regenerated_size as usize); |
152 | let source = &source[0..compressed_size]; |
153 | let mut bytes_read = 0; |
154 | |
155 | match section.ls_type { |
156 | LiteralsSectionType::Compressed => { |
157 | //read Huffman tree description |
158 | bytes_read += scratch.table.build_decoder(source)?; |
159 | vprintln!("Built huffman table using {} bytes" , bytes_read); |
160 | } |
161 | LiteralsSectionType::Treeless => { |
162 | if scratch.table.max_num_bits == 0 { |
163 | return Err(err::UninitializedHuffmanTable); |
164 | } |
165 | } |
166 | _ => { /* nothing to do, huffman tree has been provided by previous block */ } |
167 | } |
168 | |
169 | let source = &source[bytes_read as usize..]; |
170 | |
171 | if num_streams == 4 { |
172 | //build jumptable |
173 | if source.len() < 6 { |
174 | return Err(err::MissingBytesForJumpHeader { got: source.len() }); |
175 | } |
176 | let jump1 = source[0] as usize + ((source[1] as usize) << 8); |
177 | let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8); |
178 | let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8); |
179 | bytes_read += 6; |
180 | let source = &source[6..]; |
181 | |
182 | if source.len() < jump3 { |
183 | return Err(err::MissingBytesForLiterals { |
184 | got: source.len(), |
185 | needed: jump3, |
186 | }); |
187 | } |
188 | |
189 | //decode 4 streams |
190 | let stream1 = &source[..jump1]; |
191 | let stream2 = &source[jump1..jump2]; |
192 | let stream3 = &source[jump2..jump3]; |
193 | let stream4 = &source[jump3..]; |
194 | |
195 | for stream in &[stream1, stream2, stream3, stream4] { |
196 | let mut decoder = HuffmanDecoder::new(&scratch.table); |
197 | let mut br = BitReaderReversed::new(stream); |
198 | //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found |
199 | let mut skipped_bits = 0; |
200 | loop { |
201 | let val = br.get_bits(1); |
202 | skipped_bits += 1; |
203 | if val == 1 || skipped_bits > 8 { |
204 | break; |
205 | } |
206 | } |
207 | if skipped_bits > 8 { |
208 | //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data |
209 | return Err(DecompressLiteralsError::ExtraPadding { skipped_bits }); |
210 | } |
211 | decoder.init_state(&mut br); |
212 | |
213 | while br.bits_remaining() > -(scratch.table.max_num_bits as isize) { |
214 | target.push(decoder.decode_symbol()); |
215 | decoder.next_state(&mut br); |
216 | } |
217 | if br.bits_remaining() != -(scratch.table.max_num_bits as isize) { |
218 | return Err(DecompressLiteralsError::BitstreamReadMismatch { |
219 | read_til: br.bits_remaining(), |
220 | expected: -(scratch.table.max_num_bits as isize), |
221 | }); |
222 | } |
223 | } |
224 | |
225 | bytes_read += source.len() as u32; |
226 | } else { |
227 | //just decode the one stream |
228 | assert!(num_streams == 1); |
229 | let mut decoder = HuffmanDecoder::new(&scratch.table); |
230 | let mut br = BitReaderReversed::new(source); |
231 | let mut skipped_bits = 0; |
232 | loop { |
233 | let val = br.get_bits(1); |
234 | skipped_bits += 1; |
235 | if val == 1 || skipped_bits > 8 { |
236 | break; |
237 | } |
238 | } |
239 | if skipped_bits > 8 { |
240 | //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data |
241 | return Err(DecompressLiteralsError::ExtraPadding { skipped_bits }); |
242 | } |
243 | decoder.init_state(&mut br); |
244 | while br.bits_remaining() > -(scratch.table.max_num_bits as isize) { |
245 | target.push(decoder.decode_symbol()); |
246 | decoder.next_state(&mut br); |
247 | } |
248 | bytes_read += source.len() as u32; |
249 | } |
250 | |
251 | if target.len() != section.regenerated_size as usize { |
252 | return Err(DecompressLiteralsError::DecodedLiteralCountMismatch { |
253 | decoded: target.len(), |
254 | expected: section.regenerated_size as usize, |
255 | }); |
256 | } |
257 | |
258 | Ok(bytes_read) |
259 | } |
260 | |