//! This module contains the [`decode_literals`] function, which takes a parsed
//! literals section header together with its source bytes and decodes the literals.
| 3 | |
| 4 | use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType}; |
| 5 | use super::bit_reader_reverse::{BitReaderReversed, GetBitsError}; |
| 6 | use super::scratch::HuffmanScratch; |
| 7 | use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError}; |
| 8 | use alloc::vec::Vec; |
| 9 | |
/// Errors that can occur while decoding a literals section.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecompressLiteralsError {
    /// A `Compressed`/`Treeless` literals section had no compressed size set.
    MissingCompressedSize,
    /// A `Compressed`/`Treeless` literals section had no stream count (1 or 4) set.
    MissingNumStreams,
    /// Wraps an error from the bit reader.
    GetBitsError(GetBitsError),
    /// Wraps an error raised while building the Huffman table.
    HuffmanTableError(HuffmanTableError),
    /// Wraps an error raised by the Huffman decoder.
    HuffmanDecoderError(HuffmanDecoderError),
    /// A `Treeless` section tried to reuse a Huffman table that was never built.
    UninitializedHuffmanTable,
    /// A four-stream section had fewer than the 6 bytes needed for the jump table.
    MissingBytesForJumpHeader { got: usize },
    /// The input held fewer bytes than the literals section requires.
    MissingBytesForLiterals { got: usize, needed: usize },
    /// More than 8 zero bits preceded the end-of-stream marker bit; likely corrupt data.
    ExtraPadding { skipped_bits: i32 },
    /// A Huffman stream was not consumed exactly up to its end.
    BitstreamReadMismatch { read_til: isize, expected: isize },
    /// The number of decoded literals differs from the section's regenerated size.
    DecodedLiteralCountMismatch { decoded: usize, expected: usize },
}
| 25 | |
#[cfg(feature = "std")]
impl std::error::Error for DecompressLiteralsError {
    /// Returns the wrapped lower-level error for the variants that carry one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::GetBitsError(wrapped) => wrapped,
            Self::HuffmanTableError(wrapped) => wrapped,
            Self::HuffmanDecoderError(wrapped) => wrapped,
            // Every other variant is a leaf error with no underlying cause.
            _ => return None,
        };
        Some(inner)
    }
}
| 37 | impl core::fmt::Display for DecompressLiteralsError { |
| 38 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
| 39 | match self { |
| 40 | DecompressLiteralsError::MissingCompressedSize => { |
| 41 | write!(f, |
| 42 | "compressed size was none even though it must be set to something for compressed literals" , |
| 43 | ) |
| 44 | } |
| 45 | DecompressLiteralsError::MissingNumStreams => { |
| 46 | write!(f, |
| 47 | "num_streams was none even though it must be set to something (1 or 4) for compressed literals" , |
| 48 | ) |
| 49 | } |
| 50 | DecompressLiteralsError::GetBitsError(e) => write!(f, " {:?}" , e), |
| 51 | DecompressLiteralsError::HuffmanTableError(e) => write!(f, " {:?}" , e), |
| 52 | DecompressLiteralsError::HuffmanDecoderError(e) => write!(f, " {:?}" , e), |
| 53 | DecompressLiteralsError::UninitializedHuffmanTable => { |
| 54 | write!( |
| 55 | f, |
| 56 | "Tried to reuse huffman table but it was never initialized" , |
| 57 | ) |
| 58 | } |
| 59 | DecompressLiteralsError::MissingBytesForJumpHeader { got } => { |
| 60 | write!(f, "Need 6 bytes to decode jump header, got {} bytes" , got,) |
| 61 | } |
| 62 | DecompressLiteralsError::MissingBytesForLiterals { got, needed } => { |
| 63 | write!( |
| 64 | f, |
| 65 | "Need at least {} bytes to decode literals. Have: {} bytes" , |
| 66 | needed, got, |
| 67 | ) |
| 68 | } |
| 69 | DecompressLiteralsError::ExtraPadding { skipped_bits } => { |
| 70 | write!(f, |
| 71 | "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption" , |
| 72 | skipped_bits, |
| 73 | ) |
| 74 | } |
| 75 | DecompressLiteralsError::BitstreamReadMismatch { read_til, expected } => { |
| 76 | write!( |
| 77 | f, |
| 78 | "Bitstream was read till: {}, should have been: {}" , |
| 79 | read_til, expected, |
| 80 | ) |
| 81 | } |
| 82 | DecompressLiteralsError::DecodedLiteralCountMismatch { decoded, expected } => { |
| 83 | write!( |
| 84 | f, |
| 85 | "Did not decode enough literals: {}, Should have been: {}" , |
| 86 | decoded, expected, |
| 87 | ) |
| 88 | } |
| 89 | } |
| 90 | } |
| 91 | } |
| 92 | |
| 93 | impl From<HuffmanDecoderError> for DecompressLiteralsError { |
| 94 | fn from(val: HuffmanDecoderError) -> Self { |
| 95 | Self::HuffmanDecoderError(val) |
| 96 | } |
| 97 | } |
| 98 | |
| 99 | impl From<GetBitsError> for DecompressLiteralsError { |
| 100 | fn from(val: GetBitsError) -> Self { |
| 101 | Self::GetBitsError(val) |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | impl From<HuffmanTableError> for DecompressLiteralsError { |
| 106 | fn from(val: HuffmanTableError) -> Self { |
| 107 | Self::HuffmanTableError(val) |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | /// Decode and decompress the provided literals section into `target`, returning the number of bytes read. |
| 112 | pub fn decode_literals( |
| 113 | section: &LiteralsSection, |
| 114 | scratch: &mut HuffmanScratch, |
| 115 | source: &[u8], |
| 116 | target: &mut Vec<u8>, |
| 117 | ) -> Result<u32, DecompressLiteralsError> { |
| 118 | match section.ls_type { |
| 119 | LiteralsSectionType::Raw => { |
| 120 | target.extend(&source[0..section.regenerated_size as usize]); |
| 121 | Ok(section.regenerated_size) |
| 122 | } |
| 123 | LiteralsSectionType::RLE => { |
| 124 | target.resize(new_len:target.len() + section.regenerated_size as usize, value:source[0]); |
| 125 | Ok(1) |
| 126 | } |
| 127 | LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => { |
| 128 | let bytes_read: u32 = decompress_literals(section, scratch, source, target)?; |
| 129 | |
| 130 | //return sum of used bytes |
| 131 | Ok(bytes_read) |
| 132 | } |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | /// Decompress the provided literals section and source into the provided `target`. |
| 137 | /// This function is used when the literals section is `Compressed` or `Treeless` |
| 138 | /// |
| 139 | /// Returns the number of bytes read. |
| 140 | fn decompress_literals( |
| 141 | section: &LiteralsSection, |
| 142 | scratch: &mut HuffmanScratch, |
| 143 | source: &[u8], |
| 144 | target: &mut Vec<u8>, |
| 145 | ) -> Result<u32, DecompressLiteralsError> { |
| 146 | use DecompressLiteralsError as err; |
| 147 | |
| 148 | let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize; |
| 149 | let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?; |
| 150 | |
| 151 | target.reserve(section.regenerated_size as usize); |
| 152 | let source = &source[0..compressed_size]; |
| 153 | let mut bytes_read = 0; |
| 154 | |
| 155 | match section.ls_type { |
| 156 | LiteralsSectionType::Compressed => { |
| 157 | //read Huffman tree description |
| 158 | bytes_read += scratch.table.build_decoder(source)?; |
| 159 | vprintln!("Built huffman table using {} bytes" , bytes_read); |
| 160 | } |
| 161 | LiteralsSectionType::Treeless => { |
| 162 | if scratch.table.max_num_bits == 0 { |
| 163 | return Err(err::UninitializedHuffmanTable); |
| 164 | } |
| 165 | } |
| 166 | _ => { /* nothing to do, huffman tree has been provided by previous block */ } |
| 167 | } |
| 168 | |
| 169 | let source = &source[bytes_read as usize..]; |
| 170 | |
| 171 | if num_streams == 4 { |
| 172 | //build jumptable |
| 173 | if source.len() < 6 { |
| 174 | return Err(err::MissingBytesForJumpHeader { got: source.len() }); |
| 175 | } |
| 176 | let jump1 = source[0] as usize + ((source[1] as usize) << 8); |
| 177 | let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8); |
| 178 | let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8); |
| 179 | bytes_read += 6; |
| 180 | let source = &source[6..]; |
| 181 | |
| 182 | if source.len() < jump3 { |
| 183 | return Err(err::MissingBytesForLiterals { |
| 184 | got: source.len(), |
| 185 | needed: jump3, |
| 186 | }); |
| 187 | } |
| 188 | |
| 189 | //decode 4 streams |
| 190 | let stream1 = &source[..jump1]; |
| 191 | let stream2 = &source[jump1..jump2]; |
| 192 | let stream3 = &source[jump2..jump3]; |
| 193 | let stream4 = &source[jump3..]; |
| 194 | |
| 195 | for stream in &[stream1, stream2, stream3, stream4] { |
| 196 | let mut decoder = HuffmanDecoder::new(&scratch.table); |
| 197 | let mut br = BitReaderReversed::new(stream); |
| 198 | //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found |
| 199 | let mut skipped_bits = 0; |
| 200 | loop { |
| 201 | let val = br.get_bits(1); |
| 202 | skipped_bits += 1; |
| 203 | if val == 1 || skipped_bits > 8 { |
| 204 | break; |
| 205 | } |
| 206 | } |
| 207 | if skipped_bits > 8 { |
| 208 | //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data |
| 209 | return Err(DecompressLiteralsError::ExtraPadding { skipped_bits }); |
| 210 | } |
| 211 | decoder.init_state(&mut br); |
| 212 | |
| 213 | while br.bits_remaining() > -(scratch.table.max_num_bits as isize) { |
| 214 | target.push(decoder.decode_symbol()); |
| 215 | decoder.next_state(&mut br); |
| 216 | } |
| 217 | if br.bits_remaining() != -(scratch.table.max_num_bits as isize) { |
| 218 | return Err(DecompressLiteralsError::BitstreamReadMismatch { |
| 219 | read_til: br.bits_remaining(), |
| 220 | expected: -(scratch.table.max_num_bits as isize), |
| 221 | }); |
| 222 | } |
| 223 | } |
| 224 | |
| 225 | bytes_read += source.len() as u32; |
| 226 | } else { |
| 227 | //just decode the one stream |
| 228 | assert!(num_streams == 1); |
| 229 | let mut decoder = HuffmanDecoder::new(&scratch.table); |
| 230 | let mut br = BitReaderReversed::new(source); |
| 231 | let mut skipped_bits = 0; |
| 232 | loop { |
| 233 | let val = br.get_bits(1); |
| 234 | skipped_bits += 1; |
| 235 | if val == 1 || skipped_bits > 8 { |
| 236 | break; |
| 237 | } |
| 238 | } |
| 239 | if skipped_bits > 8 { |
| 240 | //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data |
| 241 | return Err(DecompressLiteralsError::ExtraPadding { skipped_bits }); |
| 242 | } |
| 243 | decoder.init_state(&mut br); |
| 244 | while br.bits_remaining() > -(scratch.table.max_num_bits as isize) { |
| 245 | target.push(decoder.decode_symbol()); |
| 246 | decoder.next_state(&mut br); |
| 247 | } |
| 248 | bytes_read += source.len() as u32; |
| 249 | } |
| 250 | |
| 251 | if target.len() != section.regenerated_size as usize { |
| 252 | return Err(DecompressLiteralsError::DecodedLiteralCountMismatch { |
| 253 | decoded: target.len(), |
| 254 | expected: section.regenerated_size as usize, |
| 255 | }); |
| 256 | } |
| 257 | |
| 258 | Ok(bytes_read) |
| 259 | } |
| 260 | |