1use crate::io::{Error, Read, Write};
2use alloc::vec::Vec;
3use core::hash::Hasher;
4
5use twox_hash::XxHash64;
6
7use super::ringbuffer::RingBuffer;
8
/// Output buffer for decoded data. Decoded bytes accumulate in a ring buffer
/// and are handed out via the various `drain*` methods; an xxhash64 checksum
/// is folded over bytes as they leave the buffer (not as they enter).
pub struct Decodebuffer {
    buffer: RingBuffer,
    /// Raw content of the dictionary in use; back-references may reach into
    /// this until enough output has been produced (see `repeat`).
    pub dict_content: Vec<u8>,

    /// Number of bytes that must be retained in `buffer` so back-references
    /// can still be resolved; draining only releases bytes beyond this.
    pub window_size: usize,
    // Total bytes pushed/copied into the buffer since the last reset.
    total_output_counter: u64,
    /// Running checksum over all bytes drained from the buffer.
    pub hash: XxHash64,
}
17
/// Errors that can occur while resolving a back-reference (`Decodebuffer::repeat`).
#[derive(Debug, derive_more::Display)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
pub enum DecodebufferError {
    /// The back-reference reached into the dictionary, but the dictionary is
    /// shorter than the required reach.
    #[display(fmt = "Need {need} bytes from the dictionary but it is only {got} bytes long")]
    NotEnoughBytesInDictionary { got: usize, need: usize },
    /// The offset exceeds what the buffer (plus any allowed dictionary reach)
    /// can serve.
    #[display(fmt = "offset: {offset} bigger than buffer: {buf_len}")]
    OffsetTooBig { offset: usize, buf_len: usize },
}
27
28impl Read for Decodebuffer {
29 fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
30 let max_amount: usize = self.can_drain_to_window_size().unwrap_or(default:0);
31 let amount: usize = max_amount.min(target.len());
32
33 let mut written: usize = 0;
34 self.drain_to(amount, |buf: &[u8]| {
35 target[written..][..buf.len()].copy_from_slice(src:buf);
36 written += buf.len();
37 (buf.len(), Ok(()))
38 })?;
39 Ok(amount)
40 }
41}
42
impl Decodebuffer {
    /// Create an empty buffer for a frame with the given window size.
    pub fn new(window_size: usize) -> Decodebuffer {
        Decodebuffer {
            buffer: RingBuffer::new(),
            dict_content: Vec::new(),
            window_size,
            total_output_counter: 0,
            hash: XxHash64::with_seed(0),
        }
    }

    /// Clear all state (buffer, dictionary content, output counter, checksum)
    /// so this instance can be reused for a new frame with `window_size`.
    pub fn reset(&mut self, window_size: usize) {
        self.window_size = window_size;
        self.buffer.clear();
        // Pre-reserve so pushes within the window don't need to regrow.
        self.buffer.reserve(self.window_size);
        self.dict_content.clear();
        self.total_output_counter = 0;
        self.hash = XxHash64::with_seed(0);
    }

    /// Number of decoded bytes currently held (not yet drained).
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// True when no decoded bytes are currently held.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Append literal bytes. The checksum is NOT updated here — it is updated
    /// when bytes are drained.
    pub fn push(&mut self, data: &[u8]) {
        self.buffer.extend(data);
        self.total_output_counter += data.len() as u64;
    }

    /// Execute a match copy: append `match_length` bytes read from `offset`
    /// bytes behind the current end of output. While total output is still
    /// within the window, offsets reaching past the start of the buffer are
    /// served from `dict_content`.
    pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodebufferError> {
        if offset > self.buffer.len() {
            if self.total_output_counter <= self.window_size as u64 {
                // at least part of that repeat is from the dictionary content
                let bytes_from_dict = offset - self.buffer.len();

                if bytes_from_dict > self.dict_content.len() {
                    return Err(DecodebufferError::NotEnoughBytesInDictionary {
                        got: self.dict_content.len(),
                        need: bytes_from_dict,
                    });
                }

                if bytes_from_dict < match_length {
                    // Match starts in the dictionary but continues into the
                    // buffer: copy the dictionary tail, then recurse for the
                    // remainder (offset == buffer.len() keeps it contiguous).
                    let dict_slice =
                        &self.dict_content[self.dict_content.len() - bytes_from_dict..];
                    self.buffer.extend(dict_slice);

                    self.total_output_counter += bytes_from_dict as u64;
                    return self.repeat(self.buffer.len(), match_length - bytes_from_dict);
                } else {
                    // The whole match lies inside the dictionary.
                    // NOTE(review): total_output_counter is not advanced on
                    // this path, unlike the others — confirm that is intended.
                    let low = self.dict_content.len() - bytes_from_dict;
                    let high = low + match_length;
                    let dict_slice = &self.dict_content[low..high];
                    self.buffer.extend(dict_slice);
                }
            } else {
                return Err(DecodebufferError::OffsetTooBig {
                    offset,
                    buf_len: self.buffer.len(),
                });
            }
        } else {
            let buf_len = self.buffer.len();
            let start_idx = buf_len - offset;
            let end_idx = start_idx + match_length;

            self.buffer.reserve(match_length);
            if end_idx > buf_len {
                // The match overlaps the bytes it is producing (offset <
                // match_length), so we must copy in chunks.
                // We have at max offset bytes in one chunk, the last one can be smaller
                let mut start_idx = start_idx;
                let mut copied_counter_left = match_length;
                // TODO this can be optimized further I think.
                // Each time we copy a chunk we have a repetiton of length 'offset', so we can copy offset * iteration many bytes from start_idx
                while copied_counter_left > 0 {
                    let chunksize = usize::min(offset, copied_counter_left);

                    // SAFETY: Requirements checked:
                    // 1. start_idx + chunksize must be <= self.buffer.len()
                    //    We know that:
                    //    1. start_idx starts at buffer.len() - offset
                    //    2. chunksize <= offset (== offset for each iteration but the last, and match_length modulo offset in the last iteration)
                    //    3. the buffer grows by offset many bytes each iteration but the last
                    //    4. start_idx is increased by the same amount as the buffer grows each iteration
                    //
                    //    Thus follows: start_idx + chunksize == self.buffer.len() in each iteration but the last, where match_length modulo offset == chunksize < offset
                    //    Meaning: start_idx + chunksize <= self.buffer.len()
                    //
                    // 2. explicitly reserved enough memory for the whole match_length
                    unsafe {
                        self.buffer
                            .extend_from_within_unchecked(start_idx, chunksize)
                    };
                    copied_counter_left -= chunksize;
                    start_idx += chunksize;
                }
            } else {
                // can just copy parts of the existing buffer
                // SAFETY: Requirements checked:
                // 1. start_idx + match_length must be <= self.buffer.len()
                //    We know that:
                //    1. start_idx = self.buffer.len() - offset
                //    2. end_idx = start_idx + match_length
                //    3. end_idx <= self.buffer.len()
                //    Thus follows: start_idx + match_length <= self.buffer.len()
                //
                // 2. explicitly reserved enough memory for the whole match_length
                unsafe {
                    self.buffer
                        .extend_from_within_unchecked(start_idx, match_length)
                };
            }

            self.total_output_counter += match_length as u64;
        }

        Ok(())
    }

    /// Check if and how many bytes can currently be drawn from the buffer
    /// while still retaining `window_size` bytes for back-references.
    /// Returns `None` when the buffer holds no more than `window_size` bytes.
    pub fn can_drain_to_window_size(&self) -> Option<usize> {
        if self.buffer.len() > self.window_size {
            Some(self.buffer.len() - self.window_size)
        } else {
            None
        }
    }

    /// How many bytes can be drained if the window_size does not have to be
    /// maintained (i.e. everything currently buffered).
    pub fn can_drain(&self) -> usize {
        self.buffer.len()
    }

    /// Drain as much as possible while retaining enough so that decoding is
    /// still possible with the required window_size.
    /// At best call only if can_drain_to_window_size reports a 'high' number
    /// of bytes to reduce allocations.
    pub fn drain_to_window_size(&mut self) -> Option<Vec<u8>> {
        //TODO investigate if it is possible to return the std::vec::Drain iterator directly without collecting here
        match self.can_drain_to_window_size() {
            None => None,
            Some(can_drain) => {
                let mut vec = Vec::with_capacity(can_drain);
                self.drain_to(can_drain, |buf| {
                    vec.extend_from_slice(buf);
                    (buf.len(), Ok(()))
                })
                .ok()?;
                Some(vec)
            }
        }
    }

    /// Like [`Self::drain_to_window_size`], but streams the drained bytes into
    /// `sink` instead of collecting them. Returns how many bytes were drained.
    pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
        match self.can_drain_to_window_size() {
            None => Ok(0),
            Some(can_drain) => {
                self.drain_to(can_drain, |buf| write_all_bytes(&mut sink, buf))?;
                Ok(can_drain)
            }
        }
    }

    /// Drain the buffer completely into a fresh `Vec`, updating the checksum
    /// with everything that leaves.
    pub fn drain(&mut self) -> Vec<u8> {
        // The ring buffer may wrap; hash and copy both halves in order.
        let (slice1, slice2) = self.buffer.as_slices();
        self.hash.write(slice1);
        self.hash.write(slice2);

        let mut vec = Vec::with_capacity(slice1.len() + slice2.len());
        vec.extend_from_slice(slice1);
        vec.extend_from_slice(slice2);
        self.buffer.clear();
        vec
    }

    /// Drain the buffer completely into `sink`. Returns how many bytes were
    /// drained.
    pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
        let len = self.buffer.len();
        self.drain_to(len, |buf| write_all_bytes(&mut sink, buf))?;

        Ok(len)
    }

    /// Copy up to `target.len()` buffered bytes into `target`, removing them
    /// from the buffer. Returns how many bytes were copied.
    pub fn read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
        let amount = self.buffer.len().min(target.len());

        let mut written = 0;
        self.drain_to(amount, |buf| {
            target[written..][..buf.len()].copy_from_slice(buf);
            written += buf.len();
            (buf.len(), Ok(()))
        })?;
        Ok(amount)
    }

    /// Semantics of write_bytes:
    /// Should dump as many of the provided bytes as possible to whatever sink until no bytes are left or an error is encountered
    /// Return how many bytes have actually been dumped to the sink.
    ///
    /// Bytes that `write_bytes` reports as written are hashed and removed from
    /// the buffer even if an error is returned (partial progress is kept).
    fn drain_to(
        &mut self,
        amount: usize,
        mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>),
    ) -> Result<(), Error> {
        if amount == 0 {
            return Ok(());
        }

        // Removes the written-so-far prefix from the ring buffer on drop, so
        // partial progress survives an early return via `?`.
        struct DrainGuard<'a> {
            buffer: &'a mut RingBuffer,
            amount: usize,
        }

        impl<'a> Drop for DrainGuard<'a> {
            fn drop(&mut self) {
                if self.amount != 0 {
                    self.buffer.drop_first_n(self.amount);
                }
            }
        }

        let mut drain_guard = DrainGuard {
            buffer: &mut self.buffer,
            amount: 0,
        };

        // The ring buffer may wrap; process the two contiguous halves in order.
        let (slice1, slice2) = drain_guard.buffer.as_slices();
        let n1 = slice1.len().min(amount);
        let n2 = slice2.len().min(amount - n1);

        if n1 != 0 {
            let (written1, res1) = write_bytes(&slice1[..n1]);
            // Hash only what was actually accepted by the sink.
            self.hash.write(&slice1[..written1]);
            drain_guard.amount += written1;

            // Apparently this is what clippy thinks is the best way of expressing this
            res1?;

            // Only if the first call to write_bytes was not a partial write we can continue with slice2
            // Partial writes SHOULD never happen without res1 being an error, but lets just protect against it anyways.
            if written1 == n1 && n2 != 0 {
                let (written2, res2) = write_bytes(&slice2[..n2]);
                self.hash.write(&slice2[..written2]);
                drain_guard.amount += written2;

                // Apparently this is what clippy thinks is the best way of expressing this
                res2?;
            }
        }

        // Make sure we don't accidentally drop `DrainGuard` earlier.
        drop(drain_guard);

        Ok(())
    }
}
300
/// Like `Write::write_all` but returns the partial write length even on error.
///
/// Repeatedly calls `sink.write` until all of `buf` has been accepted or an
/// error occurs; in both cases the first tuple element is the number of bytes
/// actually written.
///
/// NOTE(review): a sink that returns `Ok(0)` forever would loop here — same
/// as the original; confirm whether that case needs handling.
fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) {
    // Fix: the original had IDE inlay-hint text pasted into the match arms
    // (`Ok(w: usize)`, `Err(e: Error)`), which is not valid Rust.
    let mut written: usize = 0;
    while written < buf.len() {
        match sink.write(&buf[written..]) {
            Ok(w) => written += w,
            Err(e) => return (written, Err(e)),
        }
    }
    (written, Ok(()))
}
312
#[cfg(test)]
mod tests {
    use super::Decodebuffer;
    use crate::io::{Error, ErrorKind, Write};

    extern crate std;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Drain through a sink that accepts at most `write_len` bytes per call,
    /// verifying that partial writes are retried and nothing is lost.
    #[test]
    fn short_writer() {
        struct ShortWriter {
            buf: Vec<u8>,
            write_len: usize,
        }

        impl Write for ShortWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                // Accept at most `write_len` bytes per call (a "short write").
                if buf.len() > self.write_len {
                    self.buf.extend_from_slice(&buf[..self.write_len]);
                    Ok(self.write_len)
                } else {
                    self.buf.extend_from_slice(buf);
                    Ok(buf.len())
                }
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut short_writer = ShortWriter {
            buf: vec![],
            write_len: 10,
        };

        // Fill the buffer exactly to the window size (100 bytes).
        let mut decode_buf = Decodebuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();
        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            // Produce 50 excess bytes, then drain exactly that excess.
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            decode_buf
                .drain_to_window_size_writer(&mut short_writer)
                .unwrap();
            assert_eq!(decode_buf.len(), 100);
        }

        // Every drained byte must have reached the sink, plus the final 100.
        assert_eq!(short_writer.buf.len(), repeats * 50);
        decode_buf.drain_to_writer(&mut short_writer).unwrap();
        assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
    }

    /// Drain through a sink that periodically fails with `WouldBlock`,
    /// verifying that retrying after the error neither loses nor duplicates
    /// bytes.
    #[test]
    fn wouldblock_writer() {
        struct WouldblockWriter {
            buf: Vec<u8>,
            last_blocked: usize,
            block_every: usize,
        }

        impl Write for WouldblockWriter {
            fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
                // Fail with WouldBlock every `block_every`-th call.
                if self.last_blocked < self.block_every {
                    self.buf.extend_from_slice(buf);
                    self.last_blocked += 1;
                    Ok(buf.len())
                } else {
                    self.last_blocked = 0;
                    Err(Error::from(ErrorKind::WouldBlock))
                }
            }

            fn flush(&mut self) -> std::result::Result<(), Error> {
                Ok(())
            }
        }

        let mut short_writer = WouldblockWriter {
            buf: vec![],
            last_blocked: 0,
            block_every: 5,
        };

        // Fill the buffer exactly to the window size (100 bytes).
        let mut decode_buf = Decodebuffer::new(100);
        decode_buf.push(b"0123456789");
        decode_buf.repeat(10, 90).unwrap();
        let repeats = 1000;
        for _ in 0..repeats {
            assert_eq!(decode_buf.len(), 100);
            decode_buf.repeat(10, 50).unwrap();
            assert_eq!(decode_buf.len(), 150);
            // Retry until the excess is fully drained; WouldBlock is expected
            // and must leave the buffer in a consistent state.
            loop {
                match decode_buf.drain_to_window_size_writer(&mut short_writer) {
                    Ok(written) => {
                        if written == 0 {
                            break;
                        }
                    }
                    Err(e) => {
                        if e.kind() == ErrorKind::WouldBlock {
                            continue;
                        } else {
                            panic!("Unexpected error {:?}", e);
                        }
                    }
                }
            }
            assert_eq!(decode_buf.len(), 100);
        }

        assert_eq!(short_writer.buf.len(), repeats * 50);
        // Same retry protocol for the final full drain.
        loop {
            match decode_buf.drain_to_writer(&mut short_writer) {
                Ok(written) => {
                    if written == 0 {
                        break;
                    }
                }
                Err(e) => {
                    if e.kind() == ErrorKind::WouldBlock {
                        continue;
                    } else {
                        panic!("Unexpected error {:?}", e);
                    }
                }
            }
        }
        assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
    }
}
447