1use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
2use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
3use super::scratch::HuffmanScratch;
4use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError};
5use alloc::vec::Vec;
6
7#[derive(Debug, derive_more::Display, derive_more::From)]
8#[cfg_attr(feature = "std", derive(derive_more::Error))]
9#[non_exhaustive]
10pub enum DecompressLiteralsError {
11 #[display(
12 fmt = "compressed size was none even though it must be set to something for compressed literals"
13 )]
14 MissingCompressedSize,
15 #[display(
16 fmt = "num_streams was none even though it must be set to something (1 or 4) for compressed literals"
17 )]
18 MissingNumStreams,
19 #[display(fmt = "{_0:?}")]
20 #[from]
21 GetBitsError(GetBitsError),
22 #[display(fmt = "{_0:?}")]
23 #[from]
24 HuffmanTableError(HuffmanTableError),
25 #[display(fmt = "{_0:?}")]
26 #[from]
27 HuffmanDecoderError(HuffmanDecoderError),
28 #[display(fmt = "Tried to reuse huffman table but it was never initialized")]
29 UninitializedHuffmanTable,
30 #[display(fmt = "Need 6 bytes to decode jump header, got {got} bytes")]
31 MissingBytesForJumpHeader { got: usize },
32 #[display(fmt = "Need at least {needed} bytes to decode literals. Have: {got} bytes")]
33 MissingBytesForLiterals { got: usize, needed: usize },
34 #[display(
35 fmt = "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption"
36 )]
37 ExtraPadding { skipped_bits: i32 },
38 #[display(fmt = "Bitstream was read till: {read_til}, should have been: {expected}")]
39 BitstreamReadMismatch { read_til: isize, expected: isize },
40 #[display(fmt = "Did not decode enough literals: {decoded}, Should have been: {expected}")]
41 DecodedLiteralCountMismatch { decoded: usize, expected: usize },
42}
43
44pub fn decode_literals(
45 section: &LiteralsSection,
46 scratch: &mut HuffmanScratch,
47 source: &[u8],
48 target: &mut Vec<u8>,
49) -> Result<u32, DecompressLiteralsError> {
50 match section.ls_type {
51 LiteralsSectionType::Raw => {
52 target.extend(&source[0..section.regenerated_size as usize]);
53 Ok(section.regenerated_size)
54 }
55 LiteralsSectionType::RLE => {
56 target.resize(new_len:target.len() + section.regenerated_size as usize, value:source[0]);
57 Ok(1)
58 }
59 LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
60 let bytes_read: u32 = decompress_literals(section, scratch, source, target)?;
61
62 //return sum of used bytes
63 Ok(bytes_read)
64 }
65 }
66}
67
68fn decompress_literals(
69 section: &LiteralsSection,
70 scratch: &mut HuffmanScratch,
71 source: &[u8],
72 target: &mut Vec<u8>,
73) -> Result<u32, DecompressLiteralsError> {
74 use DecompressLiteralsError as err;
75
76 let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize;
77 let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?;
78
79 target.reserve(section.regenerated_size as usize);
80 let source = &source[0..compressed_size];
81 let mut bytes_read = 0;
82
83 match section.ls_type {
84 LiteralsSectionType::Compressed => {
85 //read Huffman tree description
86 bytes_read += scratch.table.build_decoder(source)?;
87 vprintln!("Built huffman table using {} bytes", bytes_read);
88 }
89 LiteralsSectionType::Treeless => {
90 if scratch.table.max_num_bits == 0 {
91 return Err(err::UninitializedHuffmanTable);
92 }
93 }
94 _ => { /* nothing to do, huffman tree has been provided by previous block */ }
95 }
96
97 let source = &source[bytes_read as usize..];
98
99 if num_streams == 4 {
100 //build jumptable
101 if source.len() < 6 {
102 return Err(err::MissingBytesForJumpHeader { got: source.len() });
103 }
104 let jump1 = source[0] as usize + ((source[1] as usize) << 8);
105 let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
106 let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
107 bytes_read += 6;
108 let source = &source[6..];
109
110 if source.len() < jump3 {
111 return Err(err::MissingBytesForLiterals {
112 got: source.len(),
113 needed: jump3,
114 });
115 }
116
117 //decode 4 streams
118 let stream1 = &source[..jump1];
119 let stream2 = &source[jump1..jump2];
120 let stream3 = &source[jump2..jump3];
121 let stream4 = &source[jump3..];
122
123 for stream in &[stream1, stream2, stream3, stream4] {
124 let mut decoder = HuffmanDecoder::new(&scratch.table);
125 let mut br = BitReaderReversed::new(stream);
126 //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
127 let mut skipped_bits = 0;
128 loop {
129 let val = br.get_bits(1)?;
130 skipped_bits += 1;
131 if val == 1 || skipped_bits > 8 {
132 break;
133 }
134 }
135 if skipped_bits > 8 {
136 //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
137 return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
138 }
139 decoder.init_state(&mut br)?;
140
141 while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
142 target.push(decoder.decode_symbol());
143 decoder.next_state(&mut br)?;
144 }
145 if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
146 return Err(DecompressLiteralsError::BitstreamReadMismatch {
147 read_til: br.bits_remaining(),
148 expected: -(scratch.table.max_num_bits as isize),
149 });
150 }
151 }
152
153 bytes_read += source.len() as u32;
154 } else {
155 //just decode the one stream
156 assert!(num_streams == 1);
157 let mut decoder = HuffmanDecoder::new(&scratch.table);
158 let mut br = BitReaderReversed::new(source);
159 let mut skipped_bits = 0;
160 loop {
161 let val = br.get_bits(1)?;
162 skipped_bits += 1;
163 if val == 1 || skipped_bits > 8 {
164 break;
165 }
166 }
167 if skipped_bits > 8 {
168 //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
169 return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
170 }
171 decoder.init_state(&mut br)?;
172 while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
173 target.push(decoder.decode_symbol());
174 decoder.next_state(&mut br)?;
175 }
176 bytes_read += source.len() as u32;
177 }
178
179 if target.len() != section.regenerated_size as usize {
180 return Err(DecompressLiteralsError::DecodedLiteralCountMismatch {
181 decoded: target.len(),
182 expected: section.regenerated_size as usize,
183 });
184 }
185
186 Ok(bytes_read)
187}
188