use crate::{
    engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
    DecodeError, DecodeSliceError, PAD_BYTE,
};

#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// input len % 4
    rem: usize,
    conservative_decoded_len: usize,
}

impl GeneralPurposeEstimate {
    pub(crate) fn new(encoded_len: usize) -> Self {
        let rem: usize = encoded_len % 4;
        Self {
            rem,
            conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
        }
    }
}

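// Worked example of the conservative estimate above (arithmetic only, for reference):
// encoded_len = 5 gives rem = 1, so (5 / 4 + 1) * 3 = 6; encoded_len = 8 gives rem = 0, so
// (8 / 4 + 0) * 3 = 6. In effect, the input length is rounded up to whole quads and every quad
// is assumed to decode to 3 bytes, so the estimate never undershoots the actual decoded length.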
impl DecodeEstimate for GeneralPurposeEstimate {
    fn decoded_len_estimate(&self) -> usize {
        self.conservative_decoded_len
    }
}

/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the decode metadata, or an error.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, it's slow. If it
// is inlined(always), it's slow in a different way. Plain ol' `#[inline]` makes the benchmarks
// happiest at the moment, but this is fragile, and the best setting changes with only minor code
// modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeSliceError> {
    let input_complete_nonterminal_quads_len =
        complete_quads_len(input, estimate.rem, output.len(), decode_table)?;

    const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
    const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;
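    // Each iteration of the unrolled loop below consumes 32 input bytes (8 quads) and writes
    // 32 / 4 * 3 = 24 decoded bytes.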

    let input_complete_quads_after_unrolled_chunks_len =
        input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;

    let input_unrolled_loop_len =
        input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;

    // chunks of 32 bytes
    for (chunk_index, chunk) in input[..input_unrolled_loop_len]
        .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
        .enumerate()
    {
        let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
        let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
            ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];

        decode_chunk_8(
            &chunk[0..8],
            input_index,
            decode_table,
            &mut chunk_output[0..6],
        )?;
        decode_chunk_8(
            &chunk[8..16],
            input_index + 8,
            decode_table,
            &mut chunk_output[6..12],
        )?;
        decode_chunk_8(
            &chunk[16..24],
            input_index + 16,
            decode_table,
            &mut chunk_output[12..18],
        )?;
        decode_chunk_8(
            &chunk[24..32],
            input_index + 24,
            decode_table,
            &mut chunk_output[18..24],
        )?;
    }

    // remaining quads, except for the last possibly partial one, as it may have padding
    let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
    let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
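    // Both input lengths above are whole numbers of quads (multiples of 4), so `/ 4 * 3` maps
    // them exactly from encoded bytes to decoded bytes.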
    {
        let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];

        for (chunk_index, chunk) in input
            [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
            .chunks_exact(4)
            .enumerate()
        {
            let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];

            decode_chunk_4(
                chunk,
                input_unrolled_loop_len + chunk_index * 4,
                decode_table,
                chunk_output,
            )?;
        }
    }

    super::decode_suffix::decode_suffix(
        input,
        input_complete_nonterminal_quads_len,
        output,
        output_complete_quad_len,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}

/// Returns the length of the complete, non-terminal quads in `input`: the final quad is excluded
/// even when it is complete, since it may contain padding.
///
/// Returns an error if `output_len` is not large enough to hold those complete quads once
/// decoded, or if the input length % 4 == 1 and that last byte is an invalid value other than a
/// pad byte.
///
/// - `input` is the base64 input
/// - `input_len_rem` is input len % 4
/// - `output_len` is the length of the output slice
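///
/// For example, a 12-byte input (rem 0) yields 8, leaving the final complete quad for
/// `decode_suffix`, while a 13-byte input (rem 1) yields 12, leaving only the trailing byte.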
pub(crate) fn complete_quads_len(
    input: &[u8],
    input_len_rem: usize,
    output_len: usize,
    decode_table: &[u8; 256],
) -> Result<usize, DecodeSliceError> {
    debug_assert!(input.len() % 4 == input_len_rem);

    // detect a trailing invalid byte, like a newline, as a user convenience
    if input_len_rem == 1 {
        let last_byte = input[input.len() - 1];
        // exclude pad bytes; might be part of padding that extends from earlier in the input
        if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
            return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
        }
    };

    // skip last quad, even if it's complete, as it may have padding
    let input_complete_nonterminal_quads_len = input
        .len()
        .saturating_sub(input_len_rem)
        // if rem was 0, subtract 4 to avoid padding
        .saturating_sub((input_len_rem == 0) as usize * 4);
    debug_assert!(
        input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
    );

    // check that everything except the last quad handled by decode_suffix will fit
    if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
        return Err(DecodeSliceError::OutputSliceTooSmall);
    };
    Ok(input_complete_nonterminal_quads_len)
}

/// Decode 8 bytes of input into 6 bytes of output.
///
/// `input` is the 8 bytes to decode.
/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
/// accurately)
/// `decode_table` is the lookup table for the particular base64 alphabet.
/// `output` will have its first 6 bytes overwritten.
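///
/// Each of the eight 6-bit values is packed into a `u64` accumulator from the top down
/// (bits 63 through 58 for the first value, down to bits 21 through 16 for the last), and the
/// top 6 bytes of the accumulator, in big-endian order, form the output.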
// yes, really inline (worth 30-50% speedup)
#[inline(always)]
fn decode_chunk_8(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    let morsel = decode_table[usize::from(input[0])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
    }
    let mut accum = u64::from(morsel) << 58;

    let morsel = decode_table[usize::from(input[1])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 1,
            input[1],
        ));
    }
    accum |= u64::from(morsel) << 52;

    let morsel = decode_table[usize::from(input[2])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 2,
            input[2],
        ));
    }
    accum |= u64::from(morsel) << 46;

    let morsel = decode_table[usize::from(input[3])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 3,
            input[3],
        ));
    }
    accum |= u64::from(morsel) << 40;

    let morsel = decode_table[usize::from(input[4])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 4,
            input[4],
        ));
    }
    accum |= u64::from(morsel) << 34;

    let morsel = decode_table[usize::from(input[5])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 5,
            input[5],
        ));
    }
    accum |= u64::from(morsel) << 28;

    let morsel = decode_table[usize::from(input[6])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 6,
            input[6],
        ));
    }
    accum |= u64::from(morsel) << 22;

    let morsel = decode_table[usize::from(input[7])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 7,
            input[7],
        ));
    }
    accum |= u64::from(morsel) << 16;

    output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);

    Ok(())
}

/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output.
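///
/// Here the accumulator is a `u32`: the four 6-bit values occupy bits 31 down to 8, and its top
/// 3 bytes, in big-endian order, form the output.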
#[inline(always)]
fn decode_chunk_4(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    let morsel = decode_table[usize::from(input[0])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
    }
    let mut accum = u32::from(morsel) << 26;

    let morsel = decode_table[usize::from(input[1])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 1,
            input[1],
        ));
    }
    accum |= u32::from(morsel) << 20;

    let morsel = decode_table[usize::from(input[2])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 2,
            input[2],
        ));
    }
    accum |= u32::from(morsel) << 14;

    let morsel = decode_table[usize::from(input[3])];
    if morsel == INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 3,
            input[3],
        ));
    }
    accum |= u32::from(morsel) << 8;

    output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }
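
    // A small additional sanity check: decode_chunk_8 should report an invalid byte's absolute
    // position by adding its offset within the chunk to `index_at_start_of_input`. This assumes
    // b'\n' maps to INVALID_VALUE in the STANDARD decode table.
    #[test]
    fn decode_chunk_8_reports_absolute_invalid_byte_index() {
        let input = b"Zm9\nYmFy"; // invalid byte at offset 3 within this chunk
        let mut output = [0_u8; 8];

        let err = decode_chunk_8(&input[..], 8, &STANDARD.decode_table, &mut output).unwrap_err();
        assert_eq!(DecodeError::InvalidByte(8 + 3, b'\n'), err);
    }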

    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v"; // "foo"
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', 3], &output);
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, decoded_len_estimate) in [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to a 128-bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    (len_128 + 3) / 4 * 3,
                    estimate.conservative_decoded_len as u128
                );
            })
    }
}