use std::convert::TryInto;

use simd_adler32::Adler32;

use crate::tables::{
    self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FDEFLATE_DIST_DECODE_TABLE,
    FDEFLATE_LITLEN_DECODE_TABLE, FIXED_CODE_LENGTHS, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA,
};

/// An error encountered while decompressing a deflate stream.
#[derive(Debug, PartialEq)]
pub enum DecompressionError {
    /// The zlib header is corrupt.
    BadZlibHeader,
    /// All input was consumed, but the end of the stream hasn't been reached.
    InsufficientInput,
    /// A block header specifies an invalid block type.
    InvalidBlockType,
    /// An uncompressed block's NLEN value is invalid.
    InvalidUncompressedBlockLength,
    /// Too many literals were specified.
    InvalidHlit,
    /// Too many distance codes were specified.
    InvalidHdist,
    /// Attempted to repeat a previous code before reading any codes, or past the end of the code
    /// lengths.
    InvalidCodeLengthRepeat,
    /// The stream doesn't specify a valid huffman tree.
    BadCodeLengthHuffmanTree,
    /// The stream doesn't specify a valid huffman tree.
    BadLiteralLengthHuffmanTree,
    /// The stream doesn't specify a valid huffman tree.
    BadDistanceHuffmanTree,
    /// The stream contains a literal/length code that was not allowed by the header.
    InvalidLiteralLengthCode,
    /// The stream contains a distance code that was not allowed by the header.
    InvalidDistanceCode,
    /// The stream contains a back-reference as the first symbol.
    InputStartsWithRun,
    /// The stream contains a back-reference that is too far back.
    DistanceTooFarBack,
    /// The deflate stream checksum is incorrect.
    WrongChecksum,
    /// Extra input data.
    ExtraInput,
}

struct BlockHeader {
    hlit: usize,
    hdist: usize,
    hclen: usize,
    num_lengths_read: usize,

    /// Each entry maps a 7-bit input to a code length code: the low 3 bits hold the code's
    /// length and the high 5 bits hold its symbol.
    table: [u8; 128],
    code_lengths: [u8; 320],
}

const LITERAL_ENTRY: u32 = 0x8000;
const EXCEPTIONAL_ENTRY: u32 = 0x4000;
const SECONDARY_TABLE_ENTRY: u32 = 0x2000;

/// The Decompressor state for a compressed block.
///
/// The main litlen_table uses a 12-bit input to look up the meaning of the symbol. The table is
/// split into sections:
///
/// aaaaaaaa_bbbbbbbb_1000yyyy_0000xxxx x = input_advance_bits, y = output_advance_bytes (literal)
/// 0000000z_zzzzzzzz_00000yyy_0000xxxx x = input_advance_bits, y = extra_bits, z = length_base (length)
/// 00000000_00000000_01000000_0000xxxx x = input_advance_bits (EOF)
/// 0000xxxx_xxxxxxxx_01100000_00000000 x = secondary_table_index
/// 00000000_00000000_01000000_00000000 invalid code
///
/// The distance table is a 512-entry table that maps 9 bits of input to the meaning of a
/// distance symbol.
///
/// 00000000_00000000_00000000_00000000 symbol is more than 9 bits
/// zzzzzzzz_zzzzzzzz_0000yyyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base
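///
/// The flag bits in the low half of each entry correspond to the constants above: LITERAL_ENTRY
/// is bit 15, EXCEPTIONAL_ENTRY is bit 14, and SECONDARY_TABLE_ENTRY is bit 13.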
#[repr(align(64))]
#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock {
    litlen_table: [u32; 4096],
    dist_table: [u32; 512],

    dist_symbol_lengths: [u8; 30],
    dist_symbol_masks: [u16; 30],
    dist_symbol_codes: [u16; 30],

    secondary_table: Vec<u16>,
    eof_code: u16,
    eof_mask: u16,
    eof_bits: u8,
}

const FDEFLATE_COMPRESSED_BLOCK: CompressedBlock = CompressedBlock {
    litlen_table: FDEFLATE_LITLEN_DECODE_TABLE,
    dist_table: FDEFLATE_DIST_DECODE_TABLE,
    dist_symbol_lengths: [
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ],
    dist_symbol_masks: [
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ],
    dist_symbol_codes: [
        0, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
        0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
        0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
    ],
    secondary_table: Vec::new(),
    eof_code: 0x8ff,
    eof_mask: 0xfff,
    eof_bits: 0xc,
};

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
    ZlibHeader,
    BlockHeader,
    CodeLengthCodes,
    CodeLengths,
    CompressedData,
    UncompressedData,
    Checksum,
    Done,
}

/// Decompressor for arbitrary zlib streams.
pub struct Decompressor {
    /// State for decoding a compressed block.
    compression: CompressedBlock,
    /// State for decoding a block header.
    header: BlockHeader,
    /// Number of bytes left for uncompressed block.
    uncompressed_bytes_left: u16,

    buffer: u64,
    nbits: u8,

    queued_rle: Option<(u8, usize)>,
    queued_backref: Option<(usize, usize)>,
    last_block: bool,

    state: State,
    checksum: Adler32,
    ignore_adler32: bool,
}

impl Default for Decompressor {
    fn default() -> Self {
        Self::new()
    }
}

impl Decompressor {
    /// Create a new decompressor.
    pub fn new() -> Self {
        Self {
            buffer: 0,
            nbits: 0,
            compression: CompressedBlock {
                litlen_table: [0; 4096],
                dist_table: [0; 512],
                secondary_table: Vec::new(),
                dist_symbol_lengths: [0; 30],
                dist_symbol_masks: [0; 30],
                dist_symbol_codes: [0xffff; 30],
                eof_code: 0,
                eof_mask: 0,
                eof_bits: 0,
            },
            header: BlockHeader {
                hlit: 0,
                hdist: 0,
                hclen: 0,
                table: [0; 128],
                num_lengths_read: 0,
                code_lengths: [0; 320],
            },
            uncompressed_bytes_left: 0,
            queued_rle: None,
            queued_backref: None,
            checksum: Adler32::new(),
            state: State::ZlibHeader,
            last_block: false,
            ignore_adler32: false,
        }
    }

    /// Ignore the checksum at the end of the stream.
    pub fn ignore_adler32(&mut self) {
        self.ignore_adler32 = true;
    }

    fn fill_buffer(&mut self, input: &mut &[u8]) {
        if input.len() >= 8 {
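            // Fast refill: shift in 8 whole bytes, then advance the input by the
            // number of bytes that actually fit. `nbits |= 56` leaves the low
            // three bits (the partial-byte remainder) intact, so the bit count
            // lands in the 56..=63 range matching the bytes consumed below.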
            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
            *input = &input[(63 - self.nbits as usize) / 8..];
            self.nbits |= 56;
        } else {
            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
            let mut input_data = [0; 8];
            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
            self.buffer |= u64::from_le_bytes(input_data)
                .checked_shl(self.nbits as u32)
                .unwrap_or(0);
            self.nbits += nbytes as u8 * 8;
            *input = &input[nbytes..];
        }
    }

    fn peek_bits(&mut self, nbits: u8) -> u64 {
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
        self.buffer & ((1u64 << nbits) - 1)
    }

    fn consume_bits(&mut self, nbits: u8) {
        debug_assert!(self.nbits >= nbits);
        self.buffer >>= nbits;
        self.nbits -= nbits;
    }

    fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        self.fill_buffer(remaining_input);
        if self.nbits < 3 {
            return Ok(());
        }

        let start = self.peek_bits(3);
        self.last_block = start & 1 != 0;
        match start >> 1 {
            0b00 => {
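                // Stored (uncompressed) block: the stream is padded to the next
                // byte boundary after the 3 header bits, then LEN and its one's
                // complement NLEN follow as 16-bit little-endian values.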
                let align_bits = (self.nbits - 3) % 8;
                let header_bits = 3 + 32 + align_bits;
                if self.nbits < header_bits {
                    return Ok(());
                }

                let len = (self.peek_bits(align_bits + 19) >> (align_bits + 3)) as u16;
                let nlen = (self.peek_bits(header_bits) >> (align_bits + 19)) as u16;
                if nlen != !len {
                    return Err(DecompressionError::InvalidUncompressedBlockLength);
                }

                self.state = State::UncompressedData;
                self.uncompressed_bytes_left = len;
                self.consume_bits(header_bits);
                Ok(())
            }
            0b01 => {
                self.consume_bits(3);
                // TODO: Do this statically rather than every time.
                Self::build_tables(288, &FIXED_CODE_LENGTHS, &mut self.compression, 6)?;
                self.state = State::CompressedData;
                Ok(())
            }
            0b10 => {
                if self.nbits < 17 {
                    return Ok(());
                }

                self.header.hlit = (self.peek_bits(8) >> 3) as usize + 257;
                self.header.hdist = (self.peek_bits(13) >> 8) as usize + 1;
                self.header.hclen = (self.peek_bits(17) >> 13) as usize + 4;
                if self.header.hlit > 286 {
                    return Err(DecompressionError::InvalidHlit);
                }
                if self.header.hdist > 30 {
                    return Err(DecompressionError::InvalidHdist);
                }

                self.consume_bits(17);
                self.state = State::CodeLengthCodes;
                Ok(())
            }
            0b11 => Err(DecompressionError::InvalidBlockType),
            _ => unreachable!(),
        }
    }

    fn read_code_length_codes(
        &mut self,
        remaining_input: &mut &[u8],
    ) -> Result<(), DecompressionError> {
        self.fill_buffer(remaining_input);
        if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
            return Ok(());
        }

        let mut code_length_lengths = [0; 19];
        for i in 0..self.header.hclen {
            code_length_lengths[CLCL_ORDER[i]] = self.peek_bits(3) as u8;
            self.consume_bits(3);

            // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds
            // between 56 and 63 bits total.
            if i == 17 {
                self.fill_buffer(remaining_input);
            }
        }
        let code_length_codes: [u16; 19] = crate::compute_codes(&code_length_lengths)
            .ok_or(DecompressionError::BadCodeLengthHuffmanTree)?;

        self.header.table = [255; 128];
        for i in 0..19 {
            let length = code_length_lengths[i];
            if length > 0 {
                let mut j = code_length_codes[i];
                while j < 128 {
                    self.header.table[j as usize] = ((i as u8) << 3) | length;
                    j += 1 << length;
                }
            }
        }

        self.state = State::CodeLengths;
        self.header.num_lengths_read = 0;
        Ok(())
    }

    fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        let total_lengths = self.header.hlit + self.header.hdist;
        while self.header.num_lengths_read < total_lengths {
            self.fill_buffer(remaining_input);
            if self.nbits < 7 {
                return Ok(());
            }

            let code = self.peek_bits(7);
            let entry = self.header.table[code as usize];
            let length = entry & 0x7;
            let symbol = entry >> 3;

            debug_assert!(length != 0);
            match symbol {
                0..=15 => {
                    self.header.code_lengths[self.header.num_lengths_read] = symbol;
                    self.header.num_lengths_read += 1;
                    self.consume_bits(length);
                }
                16..=18 => {
                    let (base_repeat, extra_bits) = match symbol {
                        16 => (3, 2),
                        17 => (3, 3),
                        18 => (11, 7),
                        _ => unreachable!(),
                    };

                    if self.nbits < length + extra_bits {
                        return Ok(());
                    }

                    let value = match symbol {
                        // Symbol 16 repeats the previous code length (RFC 1951
                        // section 3.2.7), so it is invalid as the first symbol.
                        16 => {
                            self.header.code_lengths[self
                                .header
                                .num_lengths_read
                                .checked_sub(1)
                                .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
                        }
                        17 => 0,
                        18 => 0,
                        _ => unreachable!(),
                    };

                    let repeat =
                        (self.peek_bits(length + extra_bits) >> length) as usize + base_repeat;
                    if self.header.num_lengths_read + repeat > total_lengths {
                        return Err(DecompressionError::InvalidCodeLengthRepeat);
                    }

                    for i in 0..repeat {
                        self.header.code_lengths[self.header.num_lengths_read + i] = value;
                    }
                    self.header.num_lengths_read += repeat;
                    self.consume_bits(length + extra_bits);
                }
                _ => unreachable!(),
            }
        }

        self.header
            .code_lengths
            .copy_within(self.header.hlit..total_lengths, 288);
        for i in self.header.hlit..288 {
            self.header.code_lengths[i] = 0;
        }
        for i in 288 + self.header.hdist..320 {
            self.header.code_lengths[i] = 0;
        }

        if self.header.hdist == 1
            && self.header.code_lengths[..286] == tables::HUFFMAN_LENGTHS
            && self.header.code_lengths[288] == 1
        {
            self.compression = FDEFLATE_COMPRESSED_BLOCK;
        } else {
            Self::build_tables(
                self.header.hlit,
                &self.header.code_lengths,
                &mut self.compression,
                6,
            )?;
        }
        self.state = State::CompressedData;
        Ok(())
    }

    fn build_tables(
        hlit: usize,
        code_lengths: &[u8],
        compression: &mut CompressedBlock,
        max_search_bits: u8,
    ) -> Result<(), DecompressionError> {
        // Build the literal/length code table.
        let lengths = &code_lengths[..288];
        let codes: [u16; 288] = crate::compute_codes(&lengths.try_into().unwrap())
            .ok_or(DecompressionError::BadLiteralLengthHuffmanTree)?;

        let table_bits = lengths.iter().cloned().max().unwrap().min(12).max(6);
        let table_size = 1 << table_bits;

        for i in 0..256 {
            let code = codes[i];
            let length = lengths[i];
            let mut j = code;

            while j < table_size && length != 0 && length <= 12 {
                compression.litlen_table[j as usize] =
                    ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8) | length as u32;
                j += 1 << length;
            }

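            // Where two short literal codes fit within `table_bits` bits, pack
            // both into a single entry so the fast path can emit two bytes per
            // table lookup.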
            if length > 0 && length <= max_search_bits {
                for ii in 0..256 {
                    let code2 = codes[ii];
                    let length2 = lengths[ii];
                    if length2 != 0 && length + length2 <= table_bits {
                        let mut j = code | (code2 << length);

                        while j < table_size {
                            compression.litlen_table[j as usize] = (ii as u32) << 24
                                | (i as u32) << 16
                                | LITERAL_ENTRY
                                | (2 << 8)
                                | ((length + length2) as u32);
                            j += 1 << (length + length2);
                        }
                    }
                }
            }
        }

        if lengths[256] != 0 && lengths[256] <= 12 {
            let mut j = codes[256];
            while j < table_size {
                compression.litlen_table[j as usize] = EXCEPTIONAL_ENTRY | lengths[256] as u32;
                j += 1 << lengths[256];
            }
        }

        let table_size = table_size as usize;
        for i in (table_size..4096).step_by(table_size) {
            compression.litlen_table.copy_within(0..table_size, i);
        }

        compression.eof_code = codes[256];
        compression.eof_mask = (1 << lengths[256]) - 1;
        compression.eof_bits = lengths[256];

        for i in 257..hlit {
            let code = codes[i];
            let length = lengths[i];
            if length != 0 && length <= 12 {
                let mut j = code;
                while j < 4096 {
                    compression.litlen_table[j as usize] = if i < 286 {
                        (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16
                            | (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8
                            | length as u32
                    } else {
                        EXCEPTIONAL_ENTRY
                    };
                    j += 1 << length;
                }
            }
        }

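        // Codes longer than 12 bits miss the primary table. Mark their slots,
        // then point each marked slot at an 8-entry secondary table indexed by
        // bits 12..15 of the input.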
        for i in 0..hlit {
            if lengths[i] > 12 {
                compression.litlen_table[(codes[i] & 0xfff) as usize] = u32::MAX;
            }
        }

        let mut secondary_table_len = 0;
        for i in 0..hlit {
            if lengths[i] > 12 {
                let j = (codes[i] & 0xfff) as usize;
                if compression.litlen_table[j] == u32::MAX {
                    compression.litlen_table[j] =
                        (secondary_table_len << 16) | EXCEPTIONAL_ENTRY | SECONDARY_TABLE_ENTRY;
                    secondary_table_len += 8;
                }
            }
        }
        assert!(secondary_table_len <= 0x7ff);
        compression.secondary_table = vec![0; secondary_table_len as usize];
        for i in 0..hlit {
            let code = codes[i];
            let length = lengths[i];
            if length > 12 {
                let j = (codes[i] & 0xfff) as usize;
                let k = (compression.litlen_table[j] >> 16) as usize;

                let mut s = code >> 12;
                while s < 8 {
                    debug_assert_eq!(compression.secondary_table[k + s as usize], 0);
                    compression.secondary_table[k + s as usize] =
                        ((i as u16) << 4) | (length as u16);
                    s += 1 << (length - 12);
                }
            }
        }
        debug_assert!(compression
            .secondary_table
            .iter()
            .all(|&x| x != 0 && (x & 0xf) > 12));

        // Build the distance code table.
        let lengths = &code_lengths[288..320];
        if lengths == [0; 32] {
            compression.dist_symbol_masks = [0; 30];
            compression.dist_symbol_codes = [0xffff; 30];
            compression.dist_table.fill(0);
        } else {
            let codes: [u16; 32] = match crate::compute_codes(&lengths.try_into().unwrap()) {
                Some(codes) => codes,
                None => {
                    if lengths.iter().filter(|&&l| l != 0).count() != 1 {
                        return Err(DecompressionError::BadDistanceHuffmanTree);
                    }
                    [0; 32]
                }
            };

            compression.dist_symbol_codes.copy_from_slice(&codes[..30]);
            compression
                .dist_symbol_lengths
                .copy_from_slice(&lengths[..30]);
            compression.dist_table.fill(0);
            for i in 0..30 {
                let length = lengths[i];
                let code = codes[i];
                if length == 0 {
                    compression.dist_symbol_masks[i] = 0;
                    compression.dist_symbol_codes[i] = 0xffff;
                } else {
                    compression.dist_symbol_masks[i] = (1 << lengths[i]) - 1;
                    if lengths[i] <= 9 {
                        let mut j = code;
                        while j < 512 {
                            compression.dist_table[j as usize] = (DIST_SYM_TO_DIST_BASE[i] as u32)
                                << 16
                                | (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8
                                | length as u32;
                            j += 1 << lengths[i];
                        }
                    }
                }
            }
        }

        Ok(())
    }

    fn read_compressed(
        &mut self,
        remaining_input: &mut &[u8],
        output: &mut [u8],
        mut output_index: usize,
    ) -> Result<usize, DecompressionError> {
        while let State::CompressedData = self.state {
            self.fill_buffer(remaining_input);
            if output_index == output.len() {
                break;
            }

            let mut bits = self.buffer;
            let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize];
            let litlen_code_bits = litlen_entry as u8;

            if litlen_entry & LITERAL_ENTRY != 0 {
                // Ultra-fast path: do 3 more consecutive table lookups and bail if any of them need the slow path.
                if self.nbits >= 48 {
                    let litlen_entry2 =
                        self.compression.litlen_table[(bits >> litlen_code_bits & 0xfff) as usize];
                    let litlen_code_bits2 = litlen_entry2 as u8;
                    let litlen_entry3 = self.compression.litlen_table
                        [(bits >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
                    let litlen_code_bits3 = litlen_entry3 as u8;
                    let litlen_entry4 = self.compression.litlen_table[(bits
                        >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
                        & 0xfff)
                        as usize];
                    let litlen_code_bits4 = litlen_entry4 as u8;
                    if litlen_entry2 & litlen_entry3 & litlen_entry4 & LITERAL_ENTRY != 0 {
                        let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
                        let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
                        let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
                        let advance_output_bytes4 = ((litlen_entry4 & 0xf00) >> 8) as usize;
                        if output_index
                            + advance_output_bytes
                            + advance_output_bytes2
                            + advance_output_bytes3
                            + advance_output_bytes4
                            < output.len()
                        {
                            self.consume_bits(
                                litlen_code_bits
                                    + litlen_code_bits2
                                    + litlen_code_bits3
                                    + litlen_code_bits4,
                            );

                            output[output_index] = (litlen_entry >> 16) as u8;
                            output[output_index + 1] = (litlen_entry >> 24) as u8;
                            output_index += advance_output_bytes;
                            output[output_index] = (litlen_entry2 >> 16) as u8;
                            output[output_index + 1] = (litlen_entry2 >> 24) as u8;
                            output_index += advance_output_bytes2;
                            output[output_index] = (litlen_entry3 >> 16) as u8;
                            output[output_index + 1] = (litlen_entry3 >> 24) as u8;
                            output_index += advance_output_bytes3;
                            output[output_index] = (litlen_entry4 >> 16) as u8;
                            output[output_index + 1] = (litlen_entry4 >> 24) as u8;
                            output_index += advance_output_bytes4;
                            continue;
                        }
                    }
                }

                // Fast path: the next symbol is <= 12 bits and a literal, the table specifies the
                // output bytes and we can directly write them to the output buffer.
                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;

                if self.nbits < litlen_code_bits {
                    break;
                } else if output_index + 1 < output.len() {
                    output[output_index] = (litlen_entry >> 16) as u8;
                    output[output_index + 1] = (litlen_entry >> 24) as u8;
                    output_index += advance_output_bytes;
                    self.consume_bits(litlen_code_bits);
                    continue;
                } else if output_index + advance_output_bytes == output.len() {
                    debug_assert_eq!(advance_output_bytes, 1);
                    output[output_index] = (litlen_entry >> 16) as u8;
                    output_index += 1;
                    self.consume_bits(litlen_code_bits);
                    break;
                } else {
                    debug_assert_eq!(advance_output_bytes, 2);
                    output[output_index] = (litlen_entry >> 16) as u8;
                    self.queued_rle = Some(((litlen_entry >> 24) as u8, 1));
                    output_index += 1;
                    self.consume_bits(litlen_code_bits);
                    break;
                }
            }

            let (length_base, length_extra_bits, litlen_code_bits) =
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
                    (
                        litlen_entry >> 16,
                        (litlen_entry >> 8) as u8,
                        litlen_code_bits,
                    )
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
                    let secondary_index = litlen_entry >> 16;
                    let secondary_entry = self.compression.secondary_table
                        [secondary_index as usize + ((bits >> 12) & 0x7) as usize];
                    let litlen_symbol = secondary_entry >> 4;
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;

                    if self.nbits < litlen_code_bits {
                        break;
                    } else if litlen_symbol < 256 {
                        self.consume_bits(litlen_code_bits);
                        output[output_index] = litlen_symbol as u8;
                        output_index += 1;
                        continue;
                    } else if litlen_symbol == 256 {
                        self.consume_bits(litlen_code_bits);
                        self.state = match self.last_block {
                            true => State::Checksum,
                            false => State::BlockHeader,
                        };
                        break;
                    }

                    (
                        LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
                        LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
                        litlen_code_bits,
                    )
                } else if litlen_code_bits == 0 {
                    return Err(DecompressionError::InvalidLiteralLengthCode);
                } else {
                    if self.nbits < litlen_code_bits {
                        break;
                    }
                    self.consume_bits(litlen_code_bits);
                    self.state = match self.last_block {
                        true => State::Checksum,
                        false => State::BlockHeader,
                    };
                    break;
                };
            bits >>= litlen_code_bits;

            let length_extra_mask = (1 << length_extra_bits) - 1;
            let length = length_base as usize + (bits & length_extra_mask) as usize;
            bits >>= length_extra_bits;

            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry != 0 {
                (
                    (dist_entry >> 16) as u16,
                    (dist_entry >> 8) as u8,
                    dist_entry as u8,
                )
            } else {
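                // The distance code is longer than 9 bits, so it missed the
                // table. Fall back to matching each symbol's code and mask.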
                let mut dist_extra_bits = 0;
                let mut dist_base = 0;
                let mut dist_advance_bits = 0;
                for i in 0..self.compression.dist_symbol_lengths.len() {
                    if bits as u16 & self.compression.dist_symbol_masks[i]
                        == self.compression.dist_symbol_codes[i]
                    {
                        dist_extra_bits = DIST_SYM_TO_DIST_EXTRA[i];
                        dist_base = DIST_SYM_TO_DIST_BASE[i];
                        dist_advance_bits = self.compression.dist_symbol_lengths[i];
                        break;
                    }
                }
                if dist_advance_bits == 0 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }
                (dist_base, dist_extra_bits, dist_advance_bits)
            };
            bits >>= dist_code_bits;

            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
            let total_bits =
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;

            if self.nbits < total_bits {
                break;
            } else if dist > output_index {
                return Err(DecompressionError::DistanceTooFarBack);
            }

            self.consume_bits(total_bits);

            let copy_length = length.min(output.len() - output_index);
            if dist == 1 {
                let last = output[output_index - 1];
                output[output_index..][..copy_length].fill(last);

                if copy_length < length {
                    self.queued_rle = Some((last, length - copy_length));
                    output_index = output.len();
                    break;
                }
            } else if output_index + length + 15 <= output.len() {
                let start = output_index - dist;
                output.copy_within(start..start + 16, output_index);

                if length > 16 || dist < 16 {
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
                        output.copy_within(start + i..start + i + 16, output_index + i);
                    }
                }
            } else {
                if dist < copy_length {
                    for i in 0..copy_length {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                } else {
                    output.copy_within(
                        output_index - dist..output_index + copy_length - dist,
                        output_index,
                    )
                }

                if copy_length < length {
                    self.queued_backref = Some((dist, length - copy_length));
                    output_index = output.len();
                    break;
                }
            }
            output_index += copy_length;
        }

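        // If the block's EOF symbol is already buffered, consume it eagerly so
        // the stream can finish even when the output buffer filled up first.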
        if self.state == State::CompressedData
            && self.queued_backref.is_none()
            && self.queued_rle.is_none()
            && self.nbits >= 15
            && self.peek_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code
        {
            self.consume_bits(self.compression.eof_bits);
            self.state = match self.last_block {
                true => State::Checksum,
                false => State::BlockHeader,
            };
        }

        Ok(output_index)
    }

    /// Decompresses a chunk of data.
    ///
    /// Returns the number of bytes read from `input` and the number of bytes written to `output`,
    /// or an error if the deflate stream is not valid. `input` is the compressed data. `output` is
    /// the buffer to write the decompressed data to, starting at index `output_position`.
    /// `end_of_input` indicates whether more data may be available in the future.
    ///
    /// The contents of `output` after `output_position` are ignored. However, this function may
    /// write additional data to `output` past what is indicated by the return value.
    ///
    /// When this function returns `Ok`, at least one of the following is true:
    /// - The input is fully consumed.
    /// - The output is full but there are more bytes to output.
    /// - The deflate stream is complete (and `is_done` will return true).
    ///
    /// # Panics
    ///
    /// This function will panic if `output_position` is out of bounds.
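    ///
    /// # Example
    ///
    /// A minimal single-shot sketch (it assumes the whole zlib stream is in
    /// memory, that the output buffer is large enough, and that this crate is
    /// built as `fdeflate`):
    ///
    /// ```
    /// let compressed = fdeflate::compress_to_vec(b"Hello world!");
    /// let mut decompressor = fdeflate::Decompressor::new();
    /// let mut output = vec![0; 64];
    /// let (_consumed, produced) = decompressor
    ///     .read(&compressed, &mut output, 0, true)
    ///     .unwrap();
    /// assert!(decompressor.is_done());
    /// assert_eq!(&output[..produced], b"Hello world!");
    /// ```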
    pub fn read(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        output_position: usize,
        end_of_input: bool,
    ) -> Result<(usize, usize), DecompressionError> {
        if let State::Done = self.state {
            return Ok((0, 0));
        }

        assert!(output_position <= output.len());

        let mut remaining_input = input;
        let mut output_index = output_position;

        if let Some((data, len)) = self.queued_rle.take() {
            let n = len.min(output.len() - output_index);
            output[output_index..][..n].fill(data);
            output_index += n;
            if n < len {
                self.queued_rle = Some((data, len - n));
                return Ok((0, n));
            }
        }
        if let Some((dist, len)) = self.queued_backref.take() {
            let n = len.min(output.len() - output_index);
            for i in 0..n {
                output[output_index + i] = output[output_index + i - dist];
            }
            output_index += n;
            if n < len {
                self.queued_backref = Some((dist, len - n));
                return Ok((0, n));
            }
        }

        // Main decoding state machine.
        let mut last_state = None;
        while last_state != Some(self.state) {
            last_state = Some(self.state);
            match self.state {
                State::ZlibHeader => {
                    self.fill_buffer(&mut remaining_input);
                    if self.nbits < 16 {
                        break;
                    }

                    let input0 = self.peek_bits(8);
                    let input1 = self.peek_bits(16) >> 8 & 0xff;
                    if input0 & 0x0f != 0x08
                        || (input0 & 0xf0) > 0x70
                        || input1 & 0x20 != 0
                        || (input0 << 8 | input1) % 31 != 0
                    {
                        return Err(DecompressionError::BadZlibHeader);
                    }

                    self.consume_bits(16);
                    self.state = State::BlockHeader;
                }
                State::BlockHeader => {
                    self.read_block_header(&mut remaining_input)?;
                }
                State::CodeLengthCodes => {
                    self.read_code_length_codes(&mut remaining_input)?;
                }
                State::CodeLengths => {
                    self.read_code_lengths(&mut remaining_input)?;
                }
                State::CompressedData => {
                    output_index =
                        self.read_compressed(&mut remaining_input, output, output_index)?
                }
                State::UncompressedData => {
                    // Drain any bytes from our buffer.
                    debug_assert_eq!(self.nbits % 8, 0);
                    while self.nbits > 0
                        && self.uncompressed_bytes_left > 0
                        && output_index < output.len()
                    {
                        output[output_index] = self.peek_bits(8) as u8;
                        self.consume_bits(8);
                        output_index += 1;
                        self.uncompressed_bytes_left -= 1;
                    }
                    // Buffer may contain one additional byte. Clear it to avoid confusion.
                    if self.nbits == 0 {
                        self.buffer = 0;
                    }

                    // Copy subsequent bytes directly from the input.
                    let copy_bytes = (self.uncompressed_bytes_left as usize)
                        .min(remaining_input.len())
                        .min(output.len() - output_index);
                    output[output_index..][..copy_bytes]
                        .copy_from_slice(&remaining_input[..copy_bytes]);
                    remaining_input = &remaining_input[copy_bytes..];
                    output_index += copy_bytes;
                    self.uncompressed_bytes_left -= copy_bytes as u16;

                    if self.uncompressed_bytes_left == 0 {
                        self.state = if self.last_block {
                            State::Checksum
                        } else {
                            State::BlockHeader
                        };
                    }
                }
                State::Checksum => {
                    self.fill_buffer(&mut remaining_input);

                    let align_bits = self.nbits % 8;
                    if self.nbits >= 32 + align_bits {
                        self.checksum.write(&output[output_position..output_index]);
                        if align_bits != 0 {
                            self.consume_bits(align_bits);
                        }
                        #[cfg(not(fuzzing))]
                        if !self.ignore_adler32
                            && (self.peek_bits(32) as u32).swap_bytes() != self.checksum.finish()
                        {
                            return Err(DecompressionError::WrongChecksum);
                        }
                        self.state = State::Done;
                        self.consume_bits(32);
                        break;
                    }
                }
                State::Done => unreachable!(),
            }
        }

        if !self.ignore_adler32 && self.state != State::Done {
            self.checksum.write(&output[output_position..output_index]);
        }

        if self.state == State::Done || !end_of_input || output_index == output.len() {
            let input_left = remaining_input.len();
            Ok((input.len() - input_left, output_index - output_position))
        } else {
            Err(DecompressionError::InsufficientInput)
        }
    }

    /// Returns true if the decompressor has finished decompressing the input.
    pub fn is_done(&self) -> bool {
        self.state == State::Done
    }
}

/// Decompress the given data.
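///
/// # Example
///
/// A small sketch (it assumes this crate is built as `fdeflate`):
///
/// ```
/// let compressed = fdeflate::compress_to_vec(b"abc");
/// assert_eq!(fdeflate::decompress_to_vec(&compressed).unwrap(), b"abc");
/// ```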
pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
    match decompress_to_vec_bounded(input, usize::MAX) {
        Ok(output) => Ok(output),
        Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
        Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
            unreachable!("Impossible to allocate more than isize::MAX bytes")
        }
    }
}

/// An error encountered while decompressing a deflate stream given a bounded maximum output.
pub enum BoundedDecompressionError {
    /// The input is not a valid deflate stream.
    DecompressionError {
        /// The underlying error.
        inner: DecompressionError,
    },

    /// The output is too large.
    OutputTooLarge {
        /// The output decoded so far.
        partial_output: Vec<u8>,
    },
}

impl From<DecompressionError> for BoundedDecompressionError {
    fn from(inner: DecompressionError) -> Self {
        BoundedDecompressionError::DecompressionError { inner }
    }
}

/// Decompress the given data, returning an error if the output is larger than
/// `maxlen` bytes.
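///
/// # Example
///
/// A sketch of hitting the bound (it assumes this crate is built as
/// `fdeflate`; the sizes are illustrative):
///
/// ```
/// use fdeflate::{decompress_to_vec_bounded, BoundedDecompressionError};
///
/// let compressed = fdeflate::compress_to_vec(&[0u8; 4096]);
/// match decompress_to_vec_bounded(&compressed, 1024) {
///     Err(BoundedDecompressionError::OutputTooLarge { partial_output }) => {
///         assert_eq!(partial_output.len(), 1024);
///     }
///     _ => panic!("expected OutputTooLarge"),
/// }
/// ```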
pub fn decompress_to_vec_bounded(
    input: &[u8],
    maxlen: usize,
) -> Result<Vec<u8>, BoundedDecompressionError> {
    let mut decoder = Decompressor::new();
    let mut output = vec![0; 1024.min(maxlen)];
    let mut input_index = 0;
    let mut output_index = 0;
    loop {
        let (consumed, produced) =
            decoder.read(&input[input_index..], &mut output, output_index, true)?;
        input_index += consumed;
        output_index += produced;
        if decoder.is_done() || output_index == maxlen {
            break;
        }
        output.resize((output_index + 32 * 1024).min(maxlen), 0);
    }
    output.resize(output_index, 0);

    if decoder.is_done() {
        Ok(output)
    } else {
        Err(BoundedDecompressionError::OutputTooLarge {
            partial_output: output,
        })
    }
}

#[cfg(test)]
mod tests {
    use crate::tables::{self, LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};

    use super::*;
    use rand::Rng;

    fn roundtrip(data: &[u8]) {
        let compressed = crate::compress_to_vec(data);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(&decompressed, data);
    }

    fn roundtrip_miniz_oxide(data: &[u8]) {
        let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(decompressed.len(), data.len());
        for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
            assert_eq!(a, b, "byte {}", i);
        }
        assert_eq!(&decompressed, data);
    }

    #[allow(unused)]
    fn compare_decompression(data: &[u8]) {
        let decompressed = decompress_to_vec(&data).unwrap();
        let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(&data).unwrap();
        for i in 0..decompressed.len().min(decompressed2.len()) {
            if decompressed[i] != decompressed2[i] {
                panic!(
                    "mismatch at index {} {:?} {:?}",
                    i,
                    &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
                    &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
                );
            }
        }
        if decompressed != decompressed2 {
            panic!(
                "length mismatch {} {} {:x?}",
                decompressed.len(),
                decompressed2.len(),
                &decompressed2[decompressed.len()..][..16]
            );
        }
    }

    #[test]
    fn tables() {
        for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
            let len_base = LEN_SYM_TO_LEN_BASE[i];
            for j in 0..(1 << bits) {
                if i == 27 && j == 31 {
                    continue;
                }
                assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
                assert_eq!(
                    LENGTH_TO_SYMBOL[len_base + j - 3],
                    i as u16 + 257,
                    "{} {}",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn fdeflate_table() {
        let mut compression = CompressedBlock {
            litlen_table: [0; 4096],
            dist_table: [0; 512],
            dist_symbol_lengths: [0; 30],
            dist_symbol_masks: [0; 30],
            dist_symbol_codes: [0; 30],
            secondary_table: Vec::new(),
            eof_code: 0,
            eof_mask: 0,
            eof_bits: 0,
        };
        let mut lengths = tables::HUFFMAN_LENGTHS.to_vec();
        lengths.resize(288, 0);
        lengths.push(1);
        lengths.resize(320, 0);
        Decompressor::build_tables(286, &lengths, &mut compression, 11).unwrap();

        assert_eq!(
            compression, FDEFLATE_COMPRESSED_BLOCK,
            "{:#x?}",
            compression
        );
    }

    #[test]
    fn it_works() {
        roundtrip(b"Hello world!");
    }

    #[test]
    fn constant() {
        roundtrip_miniz_oxide(&vec![0; 50]);
        roundtrip_miniz_oxide(&vec![5; 2048]);
        roundtrip_miniz_oxide(&vec![128; 2048]);
        roundtrip_miniz_oxide(&vec![254; 2048]);
    }

    #[test]
    fn random() {
        let mut rng = rand::thread_rng();
        let mut data = vec![0; 50000];
        for _ in 0..10 {
            for byte in &mut data {
                *byte = rng.gen::<u8>() % 5;
            }
            println!("Random data: {:?}", data);
            roundtrip_miniz_oxide(&data);
        }
    }

    #[test]
    fn ignore_adler32() {
        let mut compressed = crate::compress_to_vec(b"Hello world!");
        let last_byte = compressed.len() - 1;
        compressed[last_byte] = compressed[last_byte].wrapping_add(1);

        match decompress_to_vec(&compressed) {
            Err(DecompressionError::WrongChecksum) => {}
            r => panic!("expected WrongChecksum, got {:?}", r),
        }

        let mut decompressor = Decompressor::new();
        decompressor.ignore_adler32();
        let mut decompressed = vec![0; 1024];
        let decompressed_len = decompressor
            .read(&compressed, &mut decompressed, 0, true)
            .unwrap()
            .1;
        assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
    }

    #[test]
    fn checksum_after_eof() {
        let input = b"Hello world!";
        let compressed = crate::compress_to_vec(input);

        let mut decompressor = Decompressor::new();
        let mut decompressed = vec![0; 1024];
        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[..compressed.len() - 1],
                &mut decompressed,
                0,
                false,
            )
            .unwrap();
        assert_eq!(output_written, input.len());
        assert_eq!(input_consumed, compressed.len() - 1);

        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[input_consumed..],
                &mut decompressed[..output_written],
                output_written,
                true,
            )
            .unwrap();
        assert!(decompressor.is_done());
        assert_eq!(input_consumed, 1);
        assert_eq!(output_written, 0);

        assert_eq!(&decompressed[..input.len()], input);
    }

    #[test]
    fn zero_length() {
        let mut compressed = crate::compress_to_vec(b"");

        // Splice in zero-length non-compressed blocks.
        for _ in 0..10 {
            println!("compressed len: {}", compressed.len());
            compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
        }

        // Ensure that the full input is decompressed, regardless of whether
        // `end_of_input` is set.
        for end_of_input in [true, false] {
            let mut decompressor = Decompressor::new();
            let (input_consumed, output_written) = decompressor
                .read(&compressed, &mut [], 0, end_of_input)
                .unwrap();

            assert!(decompressor.is_done());
            assert_eq!(input_consumed, compressed.len());
            assert_eq!(output_written, 0);
        }
    }
}