1 | use super::frame; |
2 | use crate::decoding::dictionary::Dictionary; |
3 | use crate::decoding::scratch::DecoderScratch; |
4 | use crate::decoding::{self, dictionary}; |
5 | use crate::io::{Error, Read, Write}; |
6 | use alloc::collections::BTreeMap; |
7 | use alloc::vec::Vec; |
8 | use core::convert::TryInto; |
9 | use core::hash::Hasher; |
10 | |
/// This implements a decoder for zstd frames. This decoder is able to decode frames only partially and gives control
/// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once).
/// It reads bytes as needed from a provided source and can be read from to collect partial results.
///
/// If you want to just read the whole frame with an io::Read without having to deal with manually calling decode_blocks
/// you can use the provided StreamingDecoder which wraps this FrameDecoder
///
/// Workflow is as follows:
/// ```
/// use ruzstd::frame_decoder::BlockDecodingStrategy;
///
/// # #[cfg(feature = "std")]
/// use std::io::{Read, Write};
///
/// // no_std environments can use the crate's own Read traits
/// # #[cfg(not(feature = "std"))]
/// use ruzstd::io::{Read, Write};
///
/// fn decode_this(mut file: impl Read) {
///     //Create a new decoder
///     let mut frame_dec = ruzstd::FrameDecoder::new();
///     // read() can only fill the slice it is handed, so the output buffer needs a non-zero length
///     let mut result = vec![0u8; 4096];
///
///     // Use reset or init to make the decoder ready to decode the frame from the io::Read
///     frame_dec.reset(&mut file).unwrap();
///
///     // Loop until the frame has been decoded completely
///     while !frame_dec.is_finished() {
///         // decode (roughly) batch_size many bytes
///         frame_dec.decode_blocks(&mut file, BlockDecodingStrategy::UptoBytes(1024)).unwrap();
///
///         // read from the decoder to collect bytes from the internal buffer
///         let bytes_read = frame_dec.read(result.as_mut_slice()).unwrap();
///
///         // then do something with it
///         do_something(&result[0..bytes_read]);
///     }
///
///     // handle the last chunk of data
///     while frame_dec.can_collect() > 0 {
///         let x = frame_dec.read(result.as_mut_slice()).unwrap();
///
///         do_something(&result[0..x]);
///     }
/// }
///
/// fn do_something(data: &[u8]) {
/// # #[cfg(feature = "std")]
///     std::io::stdout().write_all(data).unwrap();
/// }
/// ```
pub struct FrameDecoder {
    // `None` until init()/reset() succeeded for the first time.
    state: Option<FrameDecoderState>,
    // Dictionaries registered through add_dict(), keyed by their dictionary id.
    dicts: BTreeMap<u32, Dictionary>,
}
66 | |
/// Everything the decoder tracks for the frame that is currently being decoded.
/// Created on the first init()/reset() of a `FrameDecoder` and reused for later frames.
struct FrameDecoderState {
    // The frame (including its header) as returned by `frame::read_frame_header`.
    pub frame: frame::Frame,
    // Scratch data for block decoding; owns the decode buffer the output is collected from.
    decoder_scratch: DecoderScratch,
    // Set once a block flagged as the last block of the frame has been decoded.
    frame_finished: bool,
    // Number of blocks decoded since the last reset.
    block_counter: usize,
    // Bytes consumed from the source for this frame: frame header, block headers, block bodies and checksum.
    bytes_read_counter: u64,
    // Checksum read from the source after the last block; `None` until it has been read.
    check_sum: Option<u32>,
    // Id of the dictionary the scratch was initialized from, if any.
    using_dict: Option<u32>,
}
76 | |
/// Controls how much work a single call to `FrameDecoder::decode_blocks` performs.
///
/// The limits are checked after a block has been decoded, so every call decodes at least one block
/// and decoding always stops after the last block of the frame.
pub enum BlockDecodingStrategy {
    /// Keep decoding until the last block of the frame has been decoded.
    All,
    /// Stop once at least this many blocks have been decoded by the current call.
    UptoBlocks(usize),
    /// Stop once the decode buffer has grown by at least this many bytes during the current call.
    UptoBytes(usize),
}
82 | |
83 | #[derive (Debug, derive_more::Display, derive_more::From)] |
84 | #[cfg_attr (feature = "std" , derive(derive_more::Error))] |
85 | #[non_exhaustive ] |
86 | pub enum FrameDecoderError { |
87 | #[display(fmt = "{_0:?}" )] |
88 | #[from] |
89 | ReadFrameHeaderError(frame::ReadFrameHeaderError), |
90 | #[display(fmt = "{_0:?}" )] |
91 | #[from] |
92 | FrameHeaderError(frame::FrameHeaderError), |
93 | #[display( |
94 | fmt = "Specified window_size is too big; Requested: {requested}, Max: {MAX_WINDOW_SIZE}" |
95 | )] |
96 | WindowSizeTooBig { requested: u64 }, |
97 | #[display(fmt = "{_0:?}" )] |
98 | #[from] |
99 | DictionaryDecodeError(dictionary::DictionaryDecodeError), |
100 | #[display(fmt = "Failed to parse/decode block body: {_0}" )] |
101 | #[from] |
102 | FailedToReadBlockHeader(decoding::block_decoder::BlockHeaderReadError), |
103 | #[display(fmt = "Failed to parse block header: {_0}" )] |
104 | FailedToReadBlockBody(decoding::block_decoder::DecodeBlockContentError), |
105 | #[display(fmt = "Failed to read checksum: {_0}" )] |
106 | FailedToReadChecksum(Error), |
107 | #[display(fmt = "Decoder must initialized or reset before using it" )] |
108 | NotYetInitialized, |
109 | #[display(fmt = "Decoder encountered error while initializing: {_0}" )] |
110 | FailedToInitialize(frame::FrameHeaderError), |
111 | #[display(fmt = "Decoder encountered error while draining the decodebuffer: {_0}" )] |
112 | FailedToDrainDecodebuffer(Error), |
113 | #[display( |
114 | fmt = "Target must have at least as many bytes as the contentsize of the frame reports" |
115 | )] |
116 | TargetTooSmall, |
117 | #[display( |
118 | fmt = "Frame header specified dictionary id 0x{dict_id:X} that wasnt provided by add_dict() or reset_with_dict()" |
119 | )] |
120 | DictNotProvided { dict_id: u32 }, |
121 | } |
122 | |
/// Largest window size (100 MiB) a frame header may request; `FrameDecoderState::reset` rejects
/// anything bigger with `FrameDecoderError::WindowSizeTooBig`. Also shown in that error's message.
const MAX_WINDOW_SIZE: u64 = 1024 * 1024 * 100;
124 | |
125 | impl FrameDecoderState { |
126 | pub fn new(source: impl Read) -> Result<FrameDecoderState, FrameDecoderError> { |
127 | let (frame, header_size) = frame::read_frame_header(source)?; |
128 | let window_size = frame.header.window_size()?; |
129 | Ok(FrameDecoderState { |
130 | frame, |
131 | frame_finished: false, |
132 | block_counter: 0, |
133 | decoder_scratch: DecoderScratch::new(window_size as usize), |
134 | bytes_read_counter: u64::from(header_size), |
135 | check_sum: None, |
136 | using_dict: None, |
137 | }) |
138 | } |
139 | |
140 | pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> { |
141 | let (frame, header_size) = frame::read_frame_header(source)?; |
142 | let window_size = frame.header.window_size()?; |
143 | |
144 | if window_size > MAX_WINDOW_SIZE { |
145 | return Err(FrameDecoderError::WindowSizeTooBig { |
146 | requested: window_size, |
147 | }); |
148 | } |
149 | |
150 | self.frame = frame; |
151 | self.frame_finished = false; |
152 | self.block_counter = 0; |
153 | self.decoder_scratch.reset(window_size as usize); |
154 | self.bytes_read_counter = u64::from(header_size); |
155 | self.check_sum = None; |
156 | self.using_dict = None; |
157 | Ok(()) |
158 | } |
159 | } |
160 | |
161 | impl Default for FrameDecoder { |
162 | fn default() -> Self { |
163 | Self::new() |
164 | } |
165 | } |
166 | |
167 | impl FrameDecoder { |
168 | /// This will create a new decoder without allocating anything yet. |
169 | /// init()/reset() will allocate all needed buffers if it is the first time this decoder is used |
170 | /// else they just reset these buffers with not further allocations |
171 | pub fn new() -> FrameDecoder { |
172 | FrameDecoder { |
173 | state: None, |
174 | dicts: BTreeMap::new(), |
175 | } |
176 | } |
177 | |
178 | /// init() will allocate all needed buffers if it is the first time this decoder is used |
179 | /// else they just reset these buffers with not further allocations |
180 | /// |
181 | /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer() |
182 | /// |
183 | /// equivalent to reset() |
184 | pub fn init(&mut self, source: impl Read) -> Result<(), FrameDecoderError> { |
185 | self.reset(source) |
186 | } |
187 | |
188 | /// reset() will allocate all needed buffers if it is the first time this decoder is used |
189 | /// else they just reset these buffers with not further allocations |
190 | /// |
191 | /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer() |
192 | /// |
193 | /// equivalent to init() |
194 | pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> { |
195 | use FrameDecoderError as err; |
196 | let state = match &mut self.state { |
197 | Some(s) => { |
198 | s.reset(source)?; |
199 | s |
200 | } |
201 | None => { |
202 | self.state = Some(FrameDecoderState::new(source)?); |
203 | self.state.as_mut().unwrap() |
204 | } |
205 | }; |
206 | if let Some(dict_id) = state.frame.header.dictionary_id() { |
207 | let dict = self |
208 | .dicts |
209 | .get(&dict_id) |
210 | .ok_or(err::DictNotProvided { dict_id })?; |
211 | state.decoder_scratch.init_from_dict(dict); |
212 | state.using_dict = Some(dict_id); |
213 | } |
214 | Ok(()) |
215 | } |
216 | |
217 | /// Add a dict to the FrameDecoder that can be used when needed. The FrameDecoder uses the appropriate one dynamically |
218 | pub fn add_dict(&mut self, dict: Dictionary) -> Result<(), FrameDecoderError> { |
219 | self.dicts.insert(dict.id, dict); |
220 | Ok(()) |
221 | } |
222 | |
223 | pub fn force_dict(&mut self, dict_id: u32) -> Result<(), FrameDecoderError> { |
224 | use FrameDecoderError as err; |
225 | let Some(state) = self.state.as_mut() else { |
226 | return Err(err::NotYetInitialized); |
227 | }; |
228 | |
229 | let dict = self |
230 | .dicts |
231 | .get(&dict_id) |
232 | .ok_or(err::DictNotProvided { dict_id })?; |
233 | state.decoder_scratch.init_from_dict(dict); |
234 | state.using_dict = Some(dict_id); |
235 | |
236 | Ok(()) |
237 | } |
238 | |
239 | /// Returns how many bytes the frame contains after decompression |
240 | pub fn content_size(&self) -> u64 { |
241 | match &self.state { |
242 | None => 0, |
243 | Some(s) => s.frame.header.frame_content_size(), |
244 | } |
245 | } |
246 | |
247 | /// Returns the checksum that was read from the data. Only available after all bytes have been read. It is the last 4 bytes of a zstd-frame |
248 | pub fn get_checksum_from_data(&self) -> Option<u32> { |
249 | let state = match &self.state { |
250 | None => return None, |
251 | Some(s) => s, |
252 | }; |
253 | |
254 | state.check_sum |
255 | } |
256 | |
257 | /// Returns the checksum that was calculated while decoding. |
258 | /// Only a sensible value after all decoded bytes have been collected/read from the FrameDecoder |
259 | pub fn get_calculated_checksum(&self) -> Option<u32> { |
260 | let state = match &self.state { |
261 | None => return None, |
262 | Some(s) => s, |
263 | }; |
264 | let cksum_64bit = state.decoder_scratch.buffer.hash.finish(); |
265 | //truncate to lower 32bit because reasons... |
266 | Some(cksum_64bit as u32) |
267 | } |
268 | |
269 | /// Counter for how many bytes have been consumed while decoding the frame |
270 | pub fn bytes_read_from_source(&self) -> u64 { |
271 | let state = match &self.state { |
272 | None => return 0, |
273 | Some(s) => s, |
274 | }; |
275 | state.bytes_read_counter |
276 | } |
277 | |
278 | /// Whether the current frames last block has been decoded yet |
279 | /// If this returns true you can call the drain* functions to get all content |
280 | /// (the read() function will drain automatically if this returns true) |
281 | pub fn is_finished(&self) -> bool { |
282 | let state = match &self.state { |
283 | None => return true, |
284 | Some(s) => s, |
285 | }; |
286 | if state.frame.header.descriptor.content_checksum_flag() { |
287 | state.frame_finished && state.check_sum.is_some() |
288 | } else { |
289 | state.frame_finished |
290 | } |
291 | } |
292 | |
293 | /// Counter for how many blocks have already been decoded |
294 | pub fn blocks_decoded(&self) -> usize { |
295 | let state = match &self.state { |
296 | None => return 0, |
297 | Some(s) => s, |
298 | }; |
299 | state.block_counter |
300 | } |
301 | |
302 | /// Decodes blocks from a reader. It requires that the framedecoder has been initialized first. |
303 | /// The Strategy influences how many blocks will be decoded before the function returns |
304 | /// This is important if you want to manage memory consumption carefully. If you don't care |
305 | /// about that you can just choose the strategy "All" and have all blocks of the frame decoded into the buffer |
306 | pub fn decode_blocks( |
307 | &mut self, |
308 | mut source: impl Read, |
309 | strat: BlockDecodingStrategy, |
310 | ) -> Result<bool, FrameDecoderError> { |
311 | use FrameDecoderError as err; |
312 | let state = self.state.as_mut().ok_or(err::NotYetInitialized)?; |
313 | |
314 | let mut block_dec = decoding::block_decoder::new(); |
315 | |
316 | let buffer_size_before = state.decoder_scratch.buffer.len(); |
317 | let block_counter_before = state.block_counter; |
318 | loop { |
319 | vprintln!("################" ); |
320 | vprintln!("Next Block: {}" , state.block_counter); |
321 | vprintln!("################" ); |
322 | let (block_header, block_header_size) = block_dec |
323 | .read_block_header(&mut source) |
324 | .map_err(err::FailedToReadBlockHeader)?; |
325 | state.bytes_read_counter += u64::from(block_header_size); |
326 | |
327 | vprintln!(); |
328 | vprintln!( |
329 | "Found {} block with size: {}, which will be of size: {}" , |
330 | block_header.block_type, |
331 | block_header.content_size, |
332 | block_header.decompressed_size |
333 | ); |
334 | |
335 | let bytes_read_in_block_body = block_dec |
336 | .decode_block_content(&block_header, &mut state.decoder_scratch, &mut source) |
337 | .map_err(err::FailedToReadBlockBody)?; |
338 | state.bytes_read_counter += bytes_read_in_block_body; |
339 | |
340 | state.block_counter += 1; |
341 | |
342 | vprintln!("Output: {}" , state.decoder_scratch.buffer.len()); |
343 | |
344 | if block_header.last_block { |
345 | state.frame_finished = true; |
346 | if state.frame.header.descriptor.content_checksum_flag() { |
347 | let mut chksum = [0u8; 4]; |
348 | source |
349 | .read_exact(&mut chksum) |
350 | .map_err(err::FailedToReadChecksum)?; |
351 | state.bytes_read_counter += 4; |
352 | let chksum = u32::from_le_bytes(chksum); |
353 | state.check_sum = Some(chksum); |
354 | } |
355 | break; |
356 | } |
357 | |
358 | match strat { |
359 | BlockDecodingStrategy::All => { /* keep going */ } |
360 | BlockDecodingStrategy::UptoBlocks(n) => { |
361 | if state.block_counter - block_counter_before >= n { |
362 | break; |
363 | } |
364 | } |
365 | BlockDecodingStrategy::UptoBytes(n) => { |
366 | if state.decoder_scratch.buffer.len() - buffer_size_before >= n { |
367 | break; |
368 | } |
369 | } |
370 | } |
371 | } |
372 | |
373 | Ok(state.frame_finished) |
374 | } |
375 | |
376 | /// Collect bytes and retain window_size bytes while decoding is still going on. |
377 | /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes |
378 | pub fn collect(&mut self) -> Option<Vec<u8>> { |
379 | let finished = self.is_finished(); |
380 | let state = self.state.as_mut()?; |
381 | if finished { |
382 | Some(state.decoder_scratch.buffer.drain()) |
383 | } else { |
384 | state.decoder_scratch.buffer.drain_to_window_size() |
385 | } |
386 | } |
387 | |
388 | /// Collect bytes and retain window_size bytes while decoding is still going on. |
389 | /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes |
390 | pub fn collect_to_writer(&mut self, w: impl Write) -> Result<usize, Error> { |
391 | let finished = self.is_finished(); |
392 | let state = match &mut self.state { |
393 | None => return Ok(0), |
394 | Some(s) => s, |
395 | }; |
396 | if finished { |
397 | state.decoder_scratch.buffer.drain_to_writer(w) |
398 | } else { |
399 | state.decoder_scratch.buffer.drain_to_window_size_writer(w) |
400 | } |
401 | } |
402 | |
403 | /// How many bytes can currently be collected from the decodebuffer, while decoding is going on this will be lower than the actual decodbuffer size |
404 | /// because window_size bytes need to be retained for decoding. |
405 | /// After decoding of the frame (is_finished() == true) has finished it will report all remaining bytes |
406 | pub fn can_collect(&self) -> usize { |
407 | let finished = self.is_finished(); |
408 | let state = match &self.state { |
409 | None => return 0, |
410 | Some(s) => s, |
411 | }; |
412 | if finished { |
413 | state.decoder_scratch.buffer.can_drain() |
414 | } else { |
415 | state |
416 | .decoder_scratch |
417 | .buffer |
418 | .can_drain_to_window_size() |
419 | .unwrap_or(0) |
420 | } |
421 | } |
422 | |
423 | /// Decodes as many blocks as possible from the source slice and reads from the decodebuffer into the target slice |
424 | /// The source slice may contain only parts of a frame but must contain at least one full block to make progress |
425 | /// |
426 | /// By all means use decode_blocks if you have a io.Reader available. This is just for compatibility with other decompressors |
427 | /// which try to serve an old-style c api |
428 | /// |
429 | /// Returns (read, written), if read == 0 then the source did not contain a full block and further calls with the same |
430 | /// input will not make any progress! |
431 | /// |
432 | /// Note that no kind of block can be bigger than 128kb. |
433 | /// So to be safe use at least 128*1024 (max block content size) + 3 (block_header size) + 18 (max frame_header size) bytes as your source buffer |
434 | /// |
435 | /// You may call this function with an empty source after all bytes have been decoded. This is equivalent to just call decoder.read(&mut target) |
436 | pub fn decode_from_to( |
437 | &mut self, |
438 | source: &[u8], |
439 | target: &mut [u8], |
440 | ) -> Result<(usize, usize), FrameDecoderError> { |
441 | use FrameDecoderError as err; |
442 | let bytes_read_at_start = match &self.state { |
443 | Some(s) => s.bytes_read_counter, |
444 | None => 0, |
445 | }; |
446 | |
447 | if !self.is_finished() || self.state.is_none() { |
448 | let mut mt_source = source; |
449 | |
450 | if self.state.is_none() { |
451 | self.init(&mut mt_source)?; |
452 | } |
453 | |
454 | //pseudo block to scope "state" so we can borrow self again after the block |
455 | { |
456 | let state = match &mut self.state { |
457 | Some(s) => s, |
458 | None => panic!("Bug in library" ), |
459 | }; |
460 | let mut block_dec = decoding::block_decoder::new(); |
461 | |
462 | if state.frame.header.descriptor.content_checksum_flag() |
463 | && state.frame_finished |
464 | && state.check_sum.is_none() |
465 | { |
466 | //this block is needed if the checksum were the only 4 bytes that were not included in the last decode_from_to call for a frame |
467 | if mt_source.len() >= 4 { |
468 | let chksum = mt_source[..4].try_into().expect("optimized away" ); |
469 | state.bytes_read_counter += 4; |
470 | let chksum = u32::from_le_bytes(chksum); |
471 | state.check_sum = Some(chksum); |
472 | } |
473 | return Ok((4, 0)); |
474 | } |
475 | |
476 | loop { |
477 | //check if there are enough bytes for the next header |
478 | if mt_source.len() < 3 { |
479 | break; |
480 | } |
481 | let (block_header, block_header_size) = block_dec |
482 | .read_block_header(&mut mt_source) |
483 | .map_err(err::FailedToReadBlockHeader)?; |
484 | |
485 | // check the needed size for the block before updating counters. |
486 | // If not enough bytes are in the source, the header will have to be read again, so act like we never read it in the first place |
487 | if mt_source.len() < block_header.content_size as usize { |
488 | break; |
489 | } |
490 | state.bytes_read_counter += u64::from(block_header_size); |
491 | |
492 | let bytes_read_in_block_body = block_dec |
493 | .decode_block_content( |
494 | &block_header, |
495 | &mut state.decoder_scratch, |
496 | &mut mt_source, |
497 | ) |
498 | .map_err(err::FailedToReadBlockBody)?; |
499 | state.bytes_read_counter += bytes_read_in_block_body; |
500 | state.block_counter += 1; |
501 | |
502 | if block_header.last_block { |
503 | state.frame_finished = true; |
504 | if state.frame.header.descriptor.content_checksum_flag() { |
505 | //if there are enough bytes handle this here. Else the block at the start of this function will handle it at the next call |
506 | if mt_source.len() >= 4 { |
507 | let chksum = mt_source[..4].try_into().expect("optimized away" ); |
508 | state.bytes_read_counter += 4; |
509 | let chksum = u32::from_le_bytes(chksum); |
510 | state.check_sum = Some(chksum); |
511 | } |
512 | } |
513 | break; |
514 | } |
515 | } |
516 | } |
517 | } |
518 | |
519 | let result_len = self.read(target).map_err(err::FailedToDrainDecodebuffer)?; |
520 | let bytes_read_at_end = match &mut self.state { |
521 | Some(s) => s.bytes_read_counter, |
522 | None => panic!("Bug in library" ), |
523 | }; |
524 | let read_len = bytes_read_at_end - bytes_read_at_start; |
525 | Ok((read_len as usize, result_len)) |
526 | } |
527 | } |
528 | |
529 | /// Read bytes from the decode_buffer that are no longer needed. While the frame is not yet finished |
530 | /// this will retain window_size bytes, else it will drain it completely |
531 | impl Read for FrameDecoder { |
532 | fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> { |
533 | let state: &mut FrameDecoderState = match &mut self.state { |
534 | None => return Ok(0), |
535 | Some(s: &mut FrameDecoderState) => s, |
536 | }; |
537 | if state.frame_finished { |
538 | state.decoder_scratch.buffer.read_all(target) |
539 | } else { |
540 | state.decoder_scratch.buffer.read(buf:target) |
541 | } |
542 | } |
543 | } |
544 | |