| 1 | use crate::{ |
| 2 | engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode}, |
| 3 | DecodeError, PAD_BYTE, |
| 4 | }; |
| 5 | |
| 6 | /// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided |
| 7 | /// parameters. |
| 8 | /// |
| 9 | /// Returns the decode metadata representing the total number of bytes decoded, including the ones |
| 10 | /// indicated as already written by `output_index`. |
| 11 | pub(crate) fn decode_suffix( |
| 12 | input: &[u8], |
| 13 | input_index: usize, |
| 14 | output: &mut [u8], |
| 15 | mut output_index: usize, |
| 16 | decode_table: &[u8; 256], |
| 17 | decode_allow_trailing_bits: bool, |
| 18 | padding_mode: DecodePaddingMode, |
| 19 | ) -> Result<DecodeMetadata, DecodeError> { |
| 20 | // Decode any leftovers that aren't a complete input block of 8 bytes. |
| 21 | // Use a u64 as a stack-resident 8 byte buffer. |
| 22 | let mut leftover_bits: u64 = 0; |
| 23 | let mut morsels_in_leftover = 0; |
| 24 | let mut padding_bytes = 0; |
| 25 | let mut first_padding_index: usize = 0; |
| 26 | let mut last_symbol = 0_u8; |
| 27 | let start_of_leftovers = input_index; |
| 28 | |
| 29 | for (i, &b) in input[start_of_leftovers..].iter().enumerate() { |
| 30 | // '=' padding |
| 31 | if b == PAD_BYTE { |
| 32 | // There can be bad padding bytes in a few ways: |
| 33 | // 1 - Padding with non-padding characters after it |
| 34 | // 2 - Padding after zero or one characters in the current quad (should only |
| 35 | // be after 2 or 3 chars) |
| 36 | // 3 - More than two characters of padding. If 3 or 4 padding chars |
| 37 | // are in the same quad, that implies it will be caught by #2. |
| 38 | // If it spreads from one quad to another, it will be an invalid byte |
| 39 | // in the first quad. |
| 40 | // 4 - Non-canonical padding -- 1 byte when it should be 2, etc. |
| 41 | // Per config, non-canonical but still functional non- or partially-padded base64 |
| 42 | // may be treated as an error condition. |
| 43 | |
| 44 | if i % 4 < 2 { |
| 45 | // Check for case #2. |
| 46 | let bad_padding_index = start_of_leftovers |
| 47 | + if padding_bytes > 0 { |
| 48 | // If we've already seen padding, report the first padding index. |
| 49 | // This is to be consistent with the normal decode logic: it will report an |
| 50 | // error on the first padding character (since it doesn't expect to see |
| 51 | // anything but actual encoded data). |
| 52 | // This could only happen if the padding started in the previous quad since |
| 53 | // otherwise this case would have been hit at i % 4 == 0 if it was the same |
| 54 | // quad. |
| 55 | first_padding_index |
| 56 | } else { |
| 57 | // haven't seen padding before, just use where we are now |
| 58 | i |
| 59 | }; |
| 60 | return Err(DecodeError::InvalidByte(bad_padding_index, b)); |
| 61 | } |
| 62 | |
| 63 | if padding_bytes == 0 { |
| 64 | first_padding_index = i; |
| 65 | } |
| 66 | |
| 67 | padding_bytes += 1; |
| 68 | continue; |
| 69 | } |
| 70 | |
| 71 | // Check for case #1. |
| 72 | // To make '=' handling consistent with the main loop, don't allow |
| 73 | // non-suffix '=' in trailing chunk either. Report error as first |
| 74 | // erroneous padding. |
| 75 | if padding_bytes > 0 { |
| 76 | return Err(DecodeError::InvalidByte( |
| 77 | start_of_leftovers + first_padding_index, |
| 78 | PAD_BYTE, |
| 79 | )); |
| 80 | } |
| 81 | |
| 82 | last_symbol = b; |
| 83 | |
| 84 | // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. |
| 85 | // Pack the leftovers from left to right. |
| 86 | let shift = 64 - (morsels_in_leftover + 1) * 6; |
| 87 | let morsel = decode_table[b as usize]; |
| 88 | if morsel == INVALID_VALUE { |
| 89 | return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); |
| 90 | } |
| 91 | |
| 92 | leftover_bits |= (morsel as u64) << shift; |
| 93 | morsels_in_leftover += 1; |
| 94 | } |
| 95 | |
| 96 | match padding_mode { |
| 97 | DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } |
| 98 | DecodePaddingMode::RequireCanonical => { |
| 99 | if (padding_bytes + morsels_in_leftover) % 4 != 0 { |
| 100 | return Err(DecodeError::InvalidPadding); |
| 101 | } |
| 102 | } |
| 103 | DecodePaddingMode::RequireNone => { |
| 104 | if padding_bytes > 0 { |
| 105 | // check at the end to make sure we let the cases of padding that should be InvalidByte |
| 106 | // get hit |
| 107 | return Err(DecodeError::InvalidPadding); |
| 108 | } |
| 109 | } |
| 110 | } |
| 111 | |
| 112 | // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed. |
| 113 | // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits |
| 114 | // of bottom 6 bits set). |
| 115 | // When decoding two symbols back to one trailing byte, any final symbol higher than |
| 116 | // w would still decode to the original byte because we only care about the top two |
| 117 | // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a |
| 118 | // mask based on how many bits are used for just the canonical encoding, and optionally |
| 119 | // error if any other bits are set. In the example of one encoded byte -> 2 symbols, |
| 120 | // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and |
| 121 | // useless since there are no more symbols to provide the necessary 4 additional bits |
| 122 | // to finish the second original byte. |
| 123 | |
| 124 | let leftover_bits_ready_to_append = match morsels_in_leftover { |
| 125 | 0 => 0, |
| 126 | 2 => 8, |
| 127 | 3 => 16, |
| 128 | 4 => 24, |
| 129 | 6 => 32, |
| 130 | 7 => 40, |
| 131 | 8 => 48, |
| 132 | // can also be detected as case #2 bad padding above |
| 133 | _ => unreachable!( |
| 134 | "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" |
| 135 | ), |
| 136 | }; |
| 137 | |
| 138 | // if there are bits set outside the bits we care about, last symbol encodes trailing bits that |
| 139 | // will not be included in the output |
| 140 | let mask = !0 >> leftover_bits_ready_to_append; |
| 141 | if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { |
| 142 | // last morsel is at `morsels_in_leftover` - 1 |
| 143 | return Err(DecodeError::InvalidLastSymbol( |
| 144 | start_of_leftovers + morsels_in_leftover - 1, |
| 145 | last_symbol, |
| 146 | )); |
| 147 | } |
| 148 | |
| 149 | // TODO benchmark simply converting to big endian bytes |
| 150 | let mut leftover_bits_appended_to_buf = 0; |
| 151 | while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { |
| 152 | // `as` simply truncates the higher bits, which is what we want here |
| 153 | let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; |
| 154 | output[output_index] = selected_bits; |
| 155 | output_index += 1; |
| 156 | |
| 157 | leftover_bits_appended_to_buf += 8; |
| 158 | } |
| 159 | |
| 160 | Ok(DecodeMetadata::new( |
| 161 | output_index, |
| 162 | if padding_bytes > 0 { |
| 163 | Some(input_index + first_padding_index) |
| 164 | } else { |
| 165 | None |
| 166 | }, |
| 167 | )) |
| 168 | } |
| 169 | |