| 1 | use crate::{ |
| 2 | engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, |
| 3 | DecodeError, PAD_BYTE, |
| 4 | }; |
| 5 | |
// decode logic operates on chunks of 8 input bytes without padding
const INPUT_CHUNK_LEN: usize = 8;
// 8 input bytes of 6 bits each decode to 6 full output bytes (48 bits)
const DECODED_CHUNK_LEN: usize = 6;

// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last
// 2 bytes of any output u64 should not be counted as written to (but must be available in a
// slice).
const DECODED_CHUNK_SUFFIX: usize = 2;

// how many u64's of input to handle at a time
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;

// input bytes consumed per iteration of the unrolled fast loop (stage 1)
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;

// includes the trailing 2 bytes for the final u64 write
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
| 23 | |
#[doc (hidden)]
pub struct GeneralPurposeEstimate {
    /// Total number of decode chunks, including a possibly partial last chunk
    num_chunks: usize,
    /// Estimated decoded length: 3 output bytes per (possibly partial) 4-byte
    /// input quad, computed without regard to padding (see `new`).
    decoded_len_estimate: usize,
}
| 30 | |
| 31 | impl GeneralPurposeEstimate { |
| 32 | pub(crate) fn new(encoded_len: usize) -> Self { |
| 33 | // Formulas that won't overflow |
| 34 | Self { |
| 35 | num_chunks: encoded_len / INPUT_CHUNK_LEN |
| 36 | + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, |
| 37 | decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, |
| 38 | } |
| 39 | } |
| 40 | } |
| 41 | |
impl DecodeEstimate for GeneralPurposeEstimate {
    /// Returns the estimate precomputed in [`GeneralPurposeEstimate::new`].
    fn decoded_len_estimate(&self) -> usize {
        self.decoded_len_estimate
    }
}
| 47 | |
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the decode metadata, or an error.
///
/// `estimate` must have been computed from `input.len()`: its `num_chunks` determines how many
/// chunks the per-chunk stage (stage 3) processes after the fast loops.
/// The fast loops write 2 garbage bytes past each chunk's 6 decoded bytes in `output`; the
/// `trailing_bytes_to_skip` logic below guarantees later writes always overwrite them.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline ]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeError> {
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = estimate.num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // everything before this cutoff is safe for the u64-at-a-time fast loops
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                // one bounds check per block; the fixed-offset sub-slices below are free
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // advance only past the 24 valid bytes; the 2-byte suffix gets overwritten
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    // (starting the range at 1 rather than 0 is what leaves that final chunk for stage 4)
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4: decode the final chunk, which may be partial and/or padded
    super::decode_suffix::decode_suffix(
        input,
        input_index,
        output,
        output_index,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
| 200 | |
| 201 | /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the |
| 202 | /// first 6 of those contain meaningful data. |
| 203 | /// |
| 204 | /// `input` is the bytes to decode, of which the first 8 bytes will be processed. |
| 205 | /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors |
| 206 | /// accurately) |
| 207 | /// `decode_table` is the lookup table for the particular base64 alphabet. |
| 208 | /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded |
| 209 | /// data. |
| 210 | // yes, really inline (worth 30-50% speedup) |
| 211 | #[inline (always)] |
| 212 | fn decode_chunk( |
| 213 | input: &[u8], |
| 214 | index_at_start_of_input: usize, |
| 215 | decode_table: &[u8; 256], |
| 216 | output: &mut [u8], |
| 217 | ) -> Result<(), DecodeError> { |
| 218 | let morsel = decode_table[input[0] as usize]; |
| 219 | if morsel == INVALID_VALUE { |
| 220 | return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); |
| 221 | } |
| 222 | let mut accum = (morsel as u64) << 58; |
| 223 | |
| 224 | let morsel = decode_table[input[1] as usize]; |
| 225 | if morsel == INVALID_VALUE { |
| 226 | return Err(DecodeError::InvalidByte( |
| 227 | index_at_start_of_input + 1, |
| 228 | input[1], |
| 229 | )); |
| 230 | } |
| 231 | accum |= (morsel as u64) << 52; |
| 232 | |
| 233 | let morsel = decode_table[input[2] as usize]; |
| 234 | if morsel == INVALID_VALUE { |
| 235 | return Err(DecodeError::InvalidByte( |
| 236 | index_at_start_of_input + 2, |
| 237 | input[2], |
| 238 | )); |
| 239 | } |
| 240 | accum |= (morsel as u64) << 46; |
| 241 | |
| 242 | let morsel = decode_table[input[3] as usize]; |
| 243 | if morsel == INVALID_VALUE { |
| 244 | return Err(DecodeError::InvalidByte( |
| 245 | index_at_start_of_input + 3, |
| 246 | input[3], |
| 247 | )); |
| 248 | } |
| 249 | accum |= (morsel as u64) << 40; |
| 250 | |
| 251 | let morsel = decode_table[input[4] as usize]; |
| 252 | if morsel == INVALID_VALUE { |
| 253 | return Err(DecodeError::InvalidByte( |
| 254 | index_at_start_of_input + 4, |
| 255 | input[4], |
| 256 | )); |
| 257 | } |
| 258 | accum |= (morsel as u64) << 34; |
| 259 | |
| 260 | let morsel = decode_table[input[5] as usize]; |
| 261 | if morsel == INVALID_VALUE { |
| 262 | return Err(DecodeError::InvalidByte( |
| 263 | index_at_start_of_input + 5, |
| 264 | input[5], |
| 265 | )); |
| 266 | } |
| 267 | accum |= (morsel as u64) << 28; |
| 268 | |
| 269 | let morsel = decode_table[input[6] as usize]; |
| 270 | if morsel == INVALID_VALUE { |
| 271 | return Err(DecodeError::InvalidByte( |
| 272 | index_at_start_of_input + 6, |
| 273 | input[6], |
| 274 | )); |
| 275 | } |
| 276 | accum |= (morsel as u64) << 22; |
| 277 | |
| 278 | let morsel = decode_table[input[7] as usize]; |
| 279 | if morsel == INVALID_VALUE { |
| 280 | return Err(DecodeError::InvalidByte( |
| 281 | index_at_start_of_input + 7, |
| 282 | input[7], |
| 283 | )); |
| 284 | } |
| 285 | accum |= (morsel as u64) << 16; |
| 286 | |
| 287 | write_u64(output, accum); |
| 288 | |
| 289 | Ok(()) |
| 290 | } |
| 291 | |
| 292 | /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 |
| 293 | /// trailing garbage bytes. |
| 294 | #[inline ] |
| 295 | fn decode_chunk_precise( |
| 296 | input: &[u8], |
| 297 | index_at_start_of_input: usize, |
| 298 | decode_table: &[u8; 256], |
| 299 | output: &mut [u8], |
| 300 | ) -> Result<(), DecodeError> { |
| 301 | let mut tmp_buf: [u8; 8] = [0_u8; 8]; |
| 302 | |
| 303 | decode_chunk( |
| 304 | input, |
| 305 | index_at_start_of_input, |
| 306 | decode_table, |
| 307 | &mut tmp_buf[..], |
| 308 | )?; |
| 309 | |
| 310 | output[0..6].copy_from_slice(&tmp_buf[0..6]); |
| 311 | |
| 312 | Ok(()) |
| 313 | } |
| 314 | |
/// Write `value` into the first 8 bytes of `output`, big-endian, so byte 0
/// holds the most-significant (first-decoded) bits.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let bytes = value.to_be_bytes();
    output[..bytes.len()].copy_from_slice(&bytes);
}
| 319 | |
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        // trailing 2 bytes must be untouched; compare against a plain array
        // rather than a needlessly heap-allocated `vec![]`
        assert_eq!(&[b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        // trailing 2 bytes are overwritten with the (zero) low bits of the u64 write
        assert_eq!(&[b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, (num_chunks, decoded_len_estimate)) in [
            (0..=0, (0, 0)),
            (1..=4, (1, 3)),
            (5..=8, (1, 6)),
            (9..=12, (2, 9)),
            (13..=16, (2, 12)),
            (17..=20, (3, 15)),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(num_chunks, estimate.num_chunks);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate);
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to 128 bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128))
                        as usize,
                    estimate.num_chunks
                );
                assert_eq!(
                    ((len_128 + 3) / 4 * 3) as usize,
                    estimate.decoded_len_estimate
                );
            })
    }
}
| 384 | |