1 | use std::ptr; |
2 | |
3 | use crate::bytes; |
4 | use crate::error::{Error, Result}; |
5 | use crate::tag; |
6 | use crate::MAX_INPUT_SIZE; |
7 | |
/// A lookup table for quickly computing the various attributes derived from a
/// tag byte. The raw `u16` entries are generated in the `tag` module; see
/// `TagLookupTable` below for the bit layout of each entry.
const TAG_LOOKUP_TABLE: TagLookupTable = TagLookupTable(tag::TAG_LOOKUP_TABLE);
11 | |
/// `WORD_MASK` is a map from the size of an integer in bytes to its
/// corresponding mask on a 32 bit integer. This is used when we need to read
/// an integer and we know there are at least 4 bytes to read from a buffer.
/// In this case, we can read a 32 bit little endian integer and mask out only
/// the bits we need. This in particular saves a branch.
const WORD_MASK: [usize; 5] = [0, 0xFF, 0xFFFF, 0xFFFFFF, 0xFFFFFFFF];
18 | |
19 | /// Returns the decompressed size (in bytes) of the compressed bytes given. |
20 | /// |
21 | /// `input` must be a sequence of bytes returned by a conforming Snappy |
22 | /// compressor. |
23 | /// |
24 | /// # Errors |
25 | /// |
26 | /// This function returns an error in the following circumstances: |
27 | /// |
28 | /// * An invalid Snappy header was seen. |
29 | /// * The total space required for decompression exceeds `2^32 - 1`. |
30 | pub fn decompress_len(input: &[u8]) -> Result<usize> { |
31 | if input.is_empty() { |
32 | return Ok(0); |
33 | } |
34 | Ok(Header::read(input)?.decompress_len) |
35 | } |
36 | |
/// Decoder is a raw decoder for decompressing bytes in the Snappy format.
///
/// This decoder does not use the Snappy frame format and simply decompresses
/// the given bytes as if it were returned from `Encoder`.
///
/// Unless you explicitly need the low-level control, you should use
/// [`read::FrameDecoder`](../read/struct.FrameDecoder.html)
/// instead, which decompresses the Snappy frame format.
#[derive (Clone, Debug, Default)]
pub struct Decoder {
    // Place holder for potential future fields. The private field also keeps
    // callers from constructing a Decoder with a struct literal, so `new()`
    // remains the only constructor.
    _dummy: (),
}
50 | |
51 | impl Decoder { |
52 | /// Return a new decoder that can be used for decompressing bytes. |
53 | pub fn new() -> Decoder { |
54 | Decoder { _dummy: () } |
55 | } |
56 | |
57 | /// Decompresses all bytes in `input` into `output`. |
58 | /// |
59 | /// `input` must be a sequence of bytes returned by a conforming Snappy |
60 | /// compressor. |
61 | /// |
62 | /// The size of `output` must be large enough to hold all decompressed |
63 | /// bytes from the `input`. The size required can be queried with the |
64 | /// `decompress_len` function. |
65 | /// |
66 | /// On success, this returns the number of bytes written to `output`. |
67 | /// |
68 | /// # Errors |
69 | /// |
70 | /// This method returns an error in the following circumstances: |
71 | /// |
72 | /// * Invalid compressed Snappy data was seen. |
73 | /// * The total space required for decompression exceeds `2^32 - 1`. |
74 | /// * `output` has length less than `decompress_len(input)`. |
75 | pub fn decompress( |
76 | &mut self, |
77 | input: &[u8], |
78 | output: &mut [u8], |
79 | ) -> Result<usize> { |
80 | if input.is_empty() { |
81 | return Err(Error::Empty); |
82 | } |
83 | let hdr = Header::read(input)?; |
84 | if hdr.decompress_len > output.len() { |
85 | return Err(Error::BufferTooSmall { |
86 | given: output.len() as u64, |
87 | min: hdr.decompress_len as u64, |
88 | }); |
89 | } |
90 | let dst = &mut output[..hdr.decompress_len]; |
91 | let mut dec = |
92 | Decompress { src: &input[hdr.len..], s: 0, dst: dst, d: 0 }; |
93 | dec.decompress()?; |
94 | Ok(dec.dst.len()) |
95 | } |
96 | |
97 | /// Decompresses all bytes in `input` into a freshly allocated `Vec`. |
98 | /// |
99 | /// This is just like the `decompress` method, except it allocates a `Vec` |
100 | /// with the right size for you. (This is intended to be a convenience |
101 | /// method.) |
102 | /// |
103 | /// This method returns an error under the same circumstances that |
104 | /// `decompress` does. |
105 | pub fn decompress_vec(&mut self, input: &[u8]) -> Result<Vec<u8>> { |
106 | let mut buf = vec![0; decompress_len(input)?]; |
107 | let n = self.decompress(input, &mut buf)?; |
108 | buf.truncate(n); |
109 | Ok(buf) |
110 | } |
111 | } |
112 | |
/// Decompress is the state of the Snappy decompressor.
struct Decompress<'s, 'd> {
    /// The original compressed bytes not including the header.
    src: &'s [u8],
    /// The current position in the compressed bytes.
    s: usize,
    /// The output buffer to write the decompressed bytes.
    dst: &'d mut [u8],
    /// The current position in the decompressed buffer.
    d: usize,
}
124 | |
impl<'s, 'd> Decompress<'s, 'd> {
    /// Decompresses snappy compressed bytes in `src` to `dst`.
    ///
    /// This assumes that the header has already been read and that `dst` is
    /// big enough to store all decompressed bytes.
    fn decompress(&mut self) -> Result<()> {
        while self.s < self.src.len() {
            let byte = self.src[self.s];
            self.s += 1;
            // The low two bits of each tag byte select its kind: 0b00 is a
            // literal, and the other three values are handled as copies.
            if byte & 0b000000_11 == 0 {
                // For short literals the upper six bits encode (length - 1);
                // tag values >= 61 are length markers resolved inside
                // `read_literal`.
                let len = (byte >> 2) as usize + 1;
                self.read_literal(len)?;
            } else {
                self.read_copy(byte)?;
            }
        }
        // A conforming stream fills `dst` (already sliced to the header's
        // decompressed length) exactly; anything else is corruption.
        if self.d != self.dst.len() {
            return Err(Error::HeaderMismatch {
                expected_len: self.dst.len() as u64,
                got_len: self.d as u64,
            });
        }
        Ok(())
    }

    /// Decompresses a literal from `src` starting at `s` to `dst` starting at
    /// `d` and returns the updated values of `s` and `d`. `s` should point to
    /// the byte immediately following the literal tag byte.
    ///
    /// `len` is the length of the literal if it's <=60. Otherwise, it's the
    /// length tag, indicating the number of bytes needed to read a little
    /// endian integer at `src[s..]`. i.e., `61 => 1 byte`, `62 => 2 bytes`,
    /// `63 => 3 bytes` and `64 => 4 bytes`.
    ///
    /// `len` must be <=64.
    #[inline (always)]
    fn read_literal(&mut self, len: usize) -> Result<()> {
        debug_assert!(len <= 64);
        let mut len = len as u64;
        // As an optimization for the common case, if the literal length is
        // <=16 and we have enough room in both `src` and `dst`, copy the
        // literal using unaligned loads and stores.
        //
        // We pick 16 bytes with the hope that it optimizes down to a 128 bit
        // load/store.
        if len <= 16
            && self.s + 16 <= self.src.len()
            && self.d + 16 <= self.dst.len()
        {
            unsafe {
                // SAFETY: We know both src and dst have at least 16 bytes of
                // wiggle room after s/d, even if `len` is <16, so the copy is
                // safe.
                let srcp = self.src.as_ptr().add(self.s);
                let dstp = self.dst.as_mut_ptr().add(self.d);
                // Hopefully uses SIMD registers for 128 bit load/store.
                // Copying a fixed 16 bytes (possibly more than `len`) is fine
                // because `s` and `d` only advance by `len` below.
                ptr::copy_nonoverlapping(srcp, dstp, 16);
            }
            self.d += len as usize;
            self.s += len as usize;
            return Ok(());
        }
        // When the length is bigger than 60, it indicates that we need to read
        // an additional 1-4 bytes to get the real length of the literal.
        if len >= 61 {
            // If there aren't at least 4 bytes left to read then we know this
            // is corrupt because the literal must have length >=61.
            if self.s as u64 + 4 > self.src.len() as u64 {
                return Err(Error::Literal {
                    len: 4,
                    src_len: (self.src.len() - self.s) as u64,
                    dst_len: (self.dst.len() - self.d) as u64,
                });
            }
            // Since we know there are 4 bytes left to read, read a 32 bit LE
            // integer and mask away the bits we don't need.
            let byte_count = len as usize - 60;
            len = bytes::read_u32_le(&self.src[self.s..]) as u64;
            len = (len & (WORD_MASK[byte_count] as u64)) + 1;
            self.s += byte_count;
        }
        // If there's not enough buffer left to load or store this literal,
        // then the input is corrupt. The subtraction form (rather than
        // `self.s + len > self.src.len()`) avoids overflow, since `len` now
        // comes from untrusted input.
        if ((self.src.len() - self.s) as u64) < len
            || ((self.dst.len() - self.d) as u64) < len
        {
            return Err(Error::Literal {
                len: len,
                src_len: (self.src.len() - self.s) as u64,
                dst_len: (self.dst.len() - self.d) as u64,
            });
        }
        unsafe {
            // SAFETY: We've already checked the bounds, so we know this copy
            // is correct.
            let srcp = self.src.as_ptr().add(self.s);
            let dstp = self.dst.as_mut_ptr().add(self.d);
            ptr::copy_nonoverlapping(srcp, dstp, len as usize);
        }
        self.s += len as usize;
        self.d += len as usize;
        Ok(())
    }

    /// Reads a copy from `src` and writes the decompressed bytes to `dst`. `s`
    /// should point to the byte immediately following the copy tag byte.
    #[inline (always)]
    fn read_copy(&mut self, tag_byte: u8) -> Result<()> {
        // Find the copy offset and len, then advance the input past the copy.
        // The rest of this function deals with reading/writing to output only.
        let entry = TAG_LOOKUP_TABLE.entry(tag_byte);
        let offset = entry.offset(self.src, self.s)?;
        let len = entry.len();
        self.s += entry.num_tag_bytes();

        // What we really care about here is whether `offset == 0` (invalid)
        // or `d < offset` (the copy would read before the start of `dst`).
        // To save an extra branch, test `d <= offset.wrapping_sub(1)`
        // instead: for `offset >= 1` that is exactly `d < offset`, and if
        // `offset` is `0` then `offset.wrapping_sub(1)` is `usize::MAX`,
        // which is also the max value of `d`, so the error branch is always
        // taken.
        if self.d <= offset.wrapping_sub(1) {
            return Err(Error::Offset {
                offset: offset as u64,
                dst_pos: self.d as u64,
            });
        }
        // When all is said and done, dst is advanced to end.
        let end = self.d + len;
        // When the copy is small and the offset is at least 8 bytes away from
        // `d`, then we can decompress the copy with two 64 bit unaligned
        // loads/stores.
        if offset >= 8 && len <= 16 && self.d + 16 <= self.dst.len() {
            unsafe {
                // SAFETY: We know dstp points to at least 16 bytes of memory
                // from the condition above, and we also know that dstp is
                // preceded by at least `offset` bytes, since the check above
                // established `offset <= d`.
                //
                // Each 8 byte load does not overlap its 8 byte store because
                // `offset >= 8`, justifying the use of copy_nonoverlapping.
                let dstp = self.dst.as_mut_ptr().add(self.d);
                let srcp = dstp.sub(offset);
                // We can't do a single 16 byte load/store because src/dst may
                // overlap with each other. Namely, the second copy here may
                // copy bytes written in the first copy!
                ptr::copy_nonoverlapping(srcp, dstp, 8);
                ptr::copy_nonoverlapping(srcp.add(8), dstp.add(8), 8);
            }
        // If we have some wiggle room, try to decompress the copy 16 bytes
        // at a time with 128 bit unaligned loads/stores. Remember, we can't
        // just do a memcpy because decompressing copies may require copying
        // overlapping memory.
        //
        // We need the extra wiggle room to make effective use of 128 bit
        // loads/stores. Even if the store ends up copying more data than we
        // need, we're careful to advance `d` by the correct amount at the end.
        } else if end + 24 <= self.dst.len() {
            unsafe {
                // SAFETY: We know that dstp is preceded by at least `offset`
                // bytes, since the check above established `offset <= d`.
                //
                // We don't know whether dstp overlaps with srcp, so we start
                // by copying from srcp to dstp until they no longer overlap.
                // The worst case is when dstp-src = 3 and copy length = 1. The
                // first loop will issue these copy operations before stopping:
                //
                //   [-1, 14] -> [0, 15]
                //   [-1, 14] -> [3, 18]
                //   [-1, 14] -> [9, 24]
                //
                // But the copy had length 1, so it was only supposed to write
                // to [0, 0]. But the last copy wrote to [9, 24], which is 24
                // extra bytes in dst *beyond* the end of the copy, which is
                // guaranteed by the conditional above.
                let mut dstp = self.dst.as_mut_ptr().add(self.d);
                let mut srcp = dstp.sub(offset);
                loop {
                    debug_assert!(dstp >= srcp);
                    let diff = (dstp as usize) - (srcp as usize);
                    if diff >= 16 {
                        break;
                    }
                    // srcp and dstp can overlap, so use ptr::copy.
                    debug_assert!(self.d + 16 <= self.dst.len());
                    ptr::copy(srcp, dstp, 16);
                    // srcp stays fixed while dstp advances by `diff`, so the
                    // gap doubles each iteration and the loop ends quickly.
                    self.d += diff as usize;
                    dstp = dstp.add(diff);
                }
                // Now that srcp and dstp are at least 16 bytes apart, the
                // 16 byte chunks no longer overlap.
                while self.d < end {
                    ptr::copy_nonoverlapping(srcp, dstp, 16);
                    srcp = srcp.add(16);
                    dstp = dstp.add(16);
                    self.d += 16;
                }
                // At this point, `d` is likely wrong (the loops above may
                // copy past `end`). We correct it before returning. Its
                // correct value is `end`, assigned below.
            }
        } else {
            if end > self.dst.len() {
                return Err(Error::CopyWrite {
                    len: len as u64,
                    dst_len: (self.dst.len() - self.d) as u64,
                });
            }
            // Finally, the slow byte-by-byte case, which should only be used
            // for the last few bytes of decompression.
            while self.d != end {
                self.dst[self.d] = self.dst[self.d - offset];
                self.d += 1;
            }
        }
        self.d = end;
        Ok(())
    }
}
340 | |
/// Header represents the single varint that starts every Snappy compressed
/// block.
#[derive (Debug)]
struct Header {
    /// The length of the header in bytes (i.e., the varint). This is the
    /// offset at which the compressed payload begins.
    len: usize,
    /// The length of the original decompressed input in bytes.
    decompress_len: usize,
}
350 | |
351 | impl Header { |
352 | /// Reads the varint header from the given input. |
353 | /// |
354 | /// If there was a problem reading the header then an error is returned. |
355 | /// If a header is returned then it is guaranteed to be valid. |
356 | #[inline (always)] |
357 | fn read(input: &[u8]) -> Result<Header> { |
358 | let (decompress_len: u64, header_len: usize) = bytes::read_varu64(data:input); |
359 | if header_len == 0 { |
360 | return Err(Error::Header); |
361 | } |
362 | if decompress_len > MAX_INPUT_SIZE { |
363 | return Err(Error::TooBig { |
364 | given: decompress_len as u64, |
365 | max: MAX_INPUT_SIZE, |
366 | }); |
367 | } |
368 | Ok(Header { len: header_len, decompress_len: decompress_len as usize }) |
369 | } |
370 | } |
371 | |
/// A lookup table for quickly computing the various attributes derived from
/// a tag byte. The attributes are most useful for the three "copy" tags
/// and include the length of the copy, part of the offset (for copy 1-byte
/// only) and the total number of bytes following the tag byte that encode
/// the other part of the offset (1 for copy 1, 2 for copy 2 and 4 for copy 4).
///
/// More specifically, the keys of the table are u8s and the values are u16s.
/// The bits of the values are laid out as follows:
///
/// xxaa abbb xxcc cccc
///
/// Where `a` is the number of bytes, `b` are the three bits of the offset
/// for copy 1 (the other 8 bits are in the byte following the tag byte; for
/// copy 2 and copy 4, `b = 0`), and `c` is the length of the copy (max of 64).
///
/// We could pack this in fewer bits, but the position of the three `b` bits
/// lines up with the most significant three bits in the total offset for copy
/// 1, which avoids an extra shift instruction.
///
/// In sum, this table is useful because it reduces branches and various
/// arithmetic operations.
struct TagLookupTable([u16; 256]);
394 | |
395 | impl TagLookupTable { |
396 | /// Look up the tag entry given the tag `byte`. |
397 | #[inline (always)] |
398 | fn entry(&self, byte: u8) -> TagEntry { |
399 | TagEntry(self.0[byte as usize] as usize) |
400 | } |
401 | } |
402 | |
/// Represents a single entry in the tag lookup table.
///
/// See the documentation in `TagLookupTable` for the bit layout.
///
/// The type is a `usize` for convenience.
struct TagEntry(usize);
409 | |
impl TagEntry {
    /// Return the total number of bytes following this tag byte that are
    /// required to encode the offset.
    fn num_tag_bytes(&self) -> usize {
        // The byte count (`a` bits) lives above bit 11 of the entry.
        self.0 >> 11
    }

    /// Return the total copy length, capped at 255.
    fn len(&self) -> usize {
        // The copy length (`c` bits) occupies the low byte of the entry.
        self.0 & 0xFF
    }

    /// Return the copy offset corresponding to this copy operation. `s` should
    /// point to the position just after the tag byte that this entry was read
    /// from.
    ///
    /// This requires reading from the compressed input since the offset is
    /// encoded in bytes following the tag byte.
    fn offset(&self, src: &[u8], s: usize) -> Result<usize> {
        let num_tag_bytes = self.num_tag_bytes();
        let trailer =
            // It is critical for this case to come first, since it is the
            // fast path. We really hope that this case gets branch
            // predicted.
            if s + 4 <= src.len() {
                unsafe {
                    // SAFETY: The conditional above guarantees that
                    // src[s..s+4] is valid to read from.
                    let p = src.as_ptr().add(s);
                    // We use WORD_MASK here to mask out the bits we don't
                    // need. While we're guaranteed to read 4 valid bytes,
                    // not all of those bytes are necessarily part of the
                    // offset. This is the key optimization: we don't need to
                    // branch on num_tag_bytes.
                    bytes::loadu_u32_le(p) as usize & WORD_MASK[num_tag_bytes]
                }
            } else if num_tag_bytes == 1 {
                // Copy 1: the trailer is the single byte at src[s].
                if s >= src.len() {
                    return Err(Error::CopyRead {
                        len: 1,
                        src_len: (src.len() - s) as u64,
                    });
                }
                src[s] as usize
            } else if num_tag_bytes == 2 {
                // Copy 2: the trailer is the little endian u16 at src[s..s+2].
                if s + 1 >= src.len() {
                    return Err(Error::CopyRead {
                        len: 2,
                        src_len: (src.len() - s) as u64,
                    });
                }
                bytes::read_u16_le(&src[s..]) as usize
            } else {
                // Any remaining trailer size (e.g. the 4 byte trailer of a
                // copy 4) that didn't satisfy the fast path above necessarily
                // runs off the end of `src`.
                return Err(Error::CopyRead {
                    len: num_tag_bytes as u64,
                    src_len: (src.len() - s) as u64,
                });
            };
        // For copy 1, bits 8-10 of the entry (the `b` bits) hold the high
        // three bits of the offset; for copy 2/4 those bits are zero, so
        // OR-ing them with the trailer yields the complete offset.
        Ok((self.0 & 0b0000_0111_0000_0000) | trailer)
    }
}
471 | |