| 1 | use crate::{ |
| 2 | engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, |
| 3 | DecodeError, DecodeSliceError, PAD_BYTE, |
| 4 | }; |
| 5 | |
/// Decoded-length estimate produced for the general purpose engine.
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// Encoded input length modulo 4: the number of bytes in a trailing partial quad (0 if none).
    rem: usize,
    /// Decoded length counting 3 bytes for every quad, with a trailing partial quad counted as a
    /// whole one, so it never undercounts.
    conservative_decoded_len: usize,
}
| 12 | |
| 13 | impl GeneralPurposeEstimate { |
| 14 | pub(crate) fn new(encoded_len: usize) -> Self { |
| 15 | let rem: usize = encoded_len % 4; |
| 16 | Self { |
| 17 | rem, |
| 18 | conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3, |
| 19 | } |
| 20 | } |
| 21 | } |
| 22 | |
| 23 | impl DecodeEstimate for GeneralPurposeEstimate { |
| 24 | fn decoded_len_estimate(&self) -> usize { |
| 25 | self.conservative_decoded_len |
| 26 | } |
| 27 | } |
| 28 | |
/// Decodes base64 `input` into `output`, returning the decode metadata or an error.
///
/// The input is processed in three regions:
/// 1. complete quads before the final quad, in runs of 32 input bytes; each run is decoded as
///    four 8-byte groups by `decode_chunk_8` into 24 output bytes;
/// 2. the remaining complete quads before the final quad, decoded one at a time by
///    `decode_chunk_4` into 3 output bytes each;
/// 3. the final quad (complete or partial, possibly padded), which is handed to `decode_suffix`
///    together with `decode_allow_trailing_bits` and `padding_mode`.
///
/// - `estimate.rem` must equal `input.len() % 4` (a `debug_assert!` in `complete_quads_len`
///   checks this).
/// - `decode_table` maps each input byte to its 6-bit value, or to `INVALID_VALUE`.
///
/// # Errors
///
/// - Errors from `complete_quads_len`:
///   - `DecodeSliceError::OutputSliceTooSmall` if `output` cannot hold the decoded non-final
///     quads;
///   - an `InvalidByte` error for an invalid lone trailing byte.
/// - `DecodeError::InvalidByte` (converted into `DecodeSliceError`), carrying the absolute input
///   index of the first invalid byte in the non-final quads.
/// - Whatever `decode_suffix` returns for the final quad.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeSliceError> {
    // Also verifies that `output` can hold 3 bytes per non-final quad, which is what keeps the
    // `output` slicing below in bounds.
    let input_complete_nonterminal_quads_len =
        complete_quads_len(input, estimate.rem, output.len(), decode_table)?;

    const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
    const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;

    // Bytes of non-final quads that do not fill a whole 32-byte unrolled chunk.
    let input_complete_quads_after_unrolled_chunks_len =
        input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;

    let input_unrolled_loop_len =
        input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;

    // chunks of 32 bytes
    for (chunk_index, chunk) in input[..input_unrolled_loop_len]
        .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
        .enumerate()
    {
        // Absolute input offset of this chunk, used so errors report the position in `input`.
        let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
        let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
            ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];

        decode_chunk_8(
            &chunk[0..8],
            input_index,
            decode_table,
            &mut chunk_output[0..6],
        )?;
        decode_chunk_8(
            &chunk[8..16],
            input_index + 8,
            decode_table,
            &mut chunk_output[6..12],
        )?;
        decode_chunk_8(
            &chunk[16..24],
            input_index + 16,
            decode_table,
            &mut chunk_output[12..18],
        )?;
        decode_chunk_8(
            &chunk[24..32],
            input_index + 24,
            decode_table,
            &mut chunk_output[18..24],
        )?;
    }

    // remaining quads, except for the last possibly partial one, as it may have padding
    let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
    let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
    // Scoped so the `output_after_unroll` borrow ends before `output` is passed to
    // `decode_suffix` below.
    {
        let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];

        for (chunk_index, chunk) in input
            [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
            .chunks_exact(4)
            .enumerate()
        {
            let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];

            decode_chunk_4(
                chunk,
                input_unrolled_loop_len + chunk_index * 4,
                decode_table,
                chunk_output,
            )?;
        }
    }

    // Everything before the final quad is decoded. Hand the final quad to `decode_suffix`, along
    // with how many input bytes were consumed and how many output bytes were written so far.
    super::decode_suffix::decode_suffix(
        input,
        input_complete_nonterminal_quads_len,
        output,
        output_complete_quad_len,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
| 122 | |
| 123 | /// Returns the length of complete quads, except for the last one, even if it is complete. |
| 124 | /// |
| 125 | /// Returns an error if the output len is not big enough for decoding those complete quads, or if |
| 126 | /// the input % 4 == 1, and that last byte is an invalid value other than a pad byte. |
| 127 | /// |
| 128 | /// - `input` is the base64 input |
| 129 | /// - `input_len_rem` is input len % 4 |
| 130 | /// - `output_len` is the length of the output slice |
| 131 | pub(crate) fn complete_quads_len( |
| 132 | input: &[u8], |
| 133 | input_len_rem: usize, |
| 134 | output_len: usize, |
| 135 | decode_table: &[u8; 256], |
| 136 | ) -> Result<usize, DecodeSliceError> { |
| 137 | debug_assert!(input.len() % 4 == input_len_rem); |
| 138 | |
| 139 | // detect a trailing invalid byte, like a newline, as a user convenience |
| 140 | if input_len_rem == 1 { |
| 141 | let last_byte = input[input.len() - 1]; |
| 142 | // exclude pad bytes; might be part of padding that extends from earlier in the input |
| 143 | if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { |
| 144 | return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into()); |
| 145 | } |
| 146 | }; |
| 147 | |
| 148 | // skip last quad, even if it's complete, as it may have padding |
| 149 | let input_complete_nonterminal_quads_len = input |
| 150 | .len() |
| 151 | .saturating_sub(input_len_rem) |
| 152 | // if rem was 0, subtract 4 to avoid padding |
| 153 | .saturating_sub((input_len_rem == 0) as usize * 4); |
| 154 | debug_assert!( |
| 155 | input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) |
| 156 | ); |
| 157 | |
| 158 | // check that everything except the last quad handled by decode_suffix will fit |
| 159 | if output_len < input_complete_nonterminal_quads_len / 4 * 3 { |
| 160 | return Err(DecodeSliceError::OutputSliceTooSmall); |
| 161 | }; |
| 162 | Ok(input_complete_nonterminal_quads_len) |
| 163 | } |
| 164 | |
| 165 | /// Decode 8 bytes of input into 6 bytes of output. |
| 166 | /// |
| 167 | /// `input` is the 8 bytes to decode. |
| 168 | /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors |
| 169 | /// accurately) |
| 170 | /// `decode_table` is the lookup table for the particular base64 alphabet. |
| 171 | /// `output` will have its first 6 bytes overwritten |
| 172 | // yes, really inline (worth 30-50% speedup) |
| 173 | #[inline (always)] |
| 174 | fn decode_chunk_8( |
| 175 | input: &[u8], |
| 176 | index_at_start_of_input: usize, |
| 177 | decode_table: &[u8; 256], |
| 178 | output: &mut [u8], |
| 179 | ) -> Result<(), DecodeError> { |
| 180 | let morsel = decode_table[usize::from(input[0])]; |
| 181 | if morsel == INVALID_VALUE { |
| 182 | return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); |
| 183 | } |
| 184 | let mut accum = u64::from(morsel) << 58; |
| 185 | |
| 186 | let morsel = decode_table[usize::from(input[1])]; |
| 187 | if morsel == INVALID_VALUE { |
| 188 | return Err(DecodeError::InvalidByte( |
| 189 | index_at_start_of_input + 1, |
| 190 | input[1], |
| 191 | )); |
| 192 | } |
| 193 | accum |= u64::from(morsel) << 52; |
| 194 | |
| 195 | let morsel = decode_table[usize::from(input[2])]; |
| 196 | if morsel == INVALID_VALUE { |
| 197 | return Err(DecodeError::InvalidByte( |
| 198 | index_at_start_of_input + 2, |
| 199 | input[2], |
| 200 | )); |
| 201 | } |
| 202 | accum |= u64::from(morsel) << 46; |
| 203 | |
| 204 | let morsel = decode_table[usize::from(input[3])]; |
| 205 | if morsel == INVALID_VALUE { |
| 206 | return Err(DecodeError::InvalidByte( |
| 207 | index_at_start_of_input + 3, |
| 208 | input[3], |
| 209 | )); |
| 210 | } |
| 211 | accum |= u64::from(morsel) << 40; |
| 212 | |
| 213 | let morsel = decode_table[usize::from(input[4])]; |
| 214 | if morsel == INVALID_VALUE { |
| 215 | return Err(DecodeError::InvalidByte( |
| 216 | index_at_start_of_input + 4, |
| 217 | input[4], |
| 218 | )); |
| 219 | } |
| 220 | accum |= u64::from(morsel) << 34; |
| 221 | |
| 222 | let morsel = decode_table[usize::from(input[5])]; |
| 223 | if morsel == INVALID_VALUE { |
| 224 | return Err(DecodeError::InvalidByte( |
| 225 | index_at_start_of_input + 5, |
| 226 | input[5], |
| 227 | )); |
| 228 | } |
| 229 | accum |= u64::from(morsel) << 28; |
| 230 | |
| 231 | let morsel = decode_table[usize::from(input[6])]; |
| 232 | if morsel == INVALID_VALUE { |
| 233 | return Err(DecodeError::InvalidByte( |
| 234 | index_at_start_of_input + 6, |
| 235 | input[6], |
| 236 | )); |
| 237 | } |
| 238 | accum |= u64::from(morsel) << 22; |
| 239 | |
| 240 | let morsel = decode_table[usize::from(input[7])]; |
| 241 | if morsel == INVALID_VALUE { |
| 242 | return Err(DecodeError::InvalidByte( |
| 243 | index_at_start_of_input + 7, |
| 244 | input[7], |
| 245 | )); |
| 246 | } |
| 247 | accum |= u64::from(morsel) << 16; |
| 248 | |
| 249 | output[..6].copy_from_slice(&accum.to_be_bytes()[..6]); |
| 250 | |
| 251 | Ok(()) |
| 252 | } |
| 253 | |
| 254 | /// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output. |
| 255 | #[inline (always)] |
| 256 | fn decode_chunk_4( |
| 257 | input: &[u8], |
| 258 | index_at_start_of_input: usize, |
| 259 | decode_table: &[u8; 256], |
| 260 | output: &mut [u8], |
| 261 | ) -> Result<(), DecodeError> { |
| 262 | let morsel = decode_table[usize::from(input[0])]; |
| 263 | if morsel == INVALID_VALUE { |
| 264 | return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); |
| 265 | } |
| 266 | let mut accum = u32::from(morsel) << 26; |
| 267 | |
| 268 | let morsel = decode_table[usize::from(input[1])]; |
| 269 | if morsel == INVALID_VALUE { |
| 270 | return Err(DecodeError::InvalidByte( |
| 271 | index_at_start_of_input + 1, |
| 272 | input[1], |
| 273 | )); |
| 274 | } |
| 275 | accum |= u32::from(morsel) << 20; |
| 276 | |
| 277 | let morsel = decode_table[usize::from(input[2])]; |
| 278 | if morsel == INVALID_VALUE { |
| 279 | return Err(DecodeError::InvalidByte( |
| 280 | index_at_start_of_input + 2, |
| 281 | input[2], |
| 282 | )); |
| 283 | } |
| 284 | accum |= u32::from(morsel) << 14; |
| 285 | |
| 286 | let morsel = decode_table[usize::from(input[3])]; |
| 287 | if morsel == INVALID_VALUE { |
| 288 | return Err(DecodeError::InvalidByte( |
| 289 | index_at_start_of_input + 3, |
| 290 | input[3], |
| 291 | )); |
| 292 | } |
| 293 | accum |= u32::from(morsel) << 8; |
| 294 | |
| 295 | output[..3].copy_from_slice(&accum.to_be_bytes()[..3]); |
| 296 | |
| 297 | Ok(()) |
| 298 | } |
| 299 | |
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], output);
    }

    #[test]
    fn decode_chunk_8_reports_absolute_index_of_first_invalid_byte() {
        let input = b"Zm9v*m!y";
        let mut output = [0xAA_u8; 6];

        assert_eq!(
            Err(DecodeError::InvalidByte(104, b'*')),
            decode_chunk_8(&input[..], 100, &STANDARD.decode_table, &mut output)
        );
        // nothing is written when decoding fails
        assert_eq!([0xAA_u8; 6], output);
    }

    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v"; // "foo"
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', 3], output);
    }

    #[test]
    fn decode_chunk_4_reports_absolute_index_of_first_invalid_byte() {
        let input = b"Zm\n=";
        let mut output = [0xAA_u8; 3];

        assert_eq!(
            Err(DecodeError::InvalidByte(52, b'\n')),
            decode_chunk_4(&input[..], 50, &STANDARD.decode_table, &mut output)
        );
        // nothing is written when decoding fails
        assert_eq!([0xAA_u8; 3], output);
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, decoded_len_estimate) in [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to 128 bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    (len_128 + 3) / 4 * 3,
                    estimate.conservative_decoded_len as u128
                );
            })
    }
}
| 358 | |