| 1 | use crate::{ |
| 2 | engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode}, |
| 3 | DecodeError, DecodeSliceError, PAD_BYTE, |
| 4 | }; |
| 5 | |
| 6 | /// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided |
| 7 | /// parameters. |
| 8 | /// |
| 9 | /// Returns the decode metadata representing the total number of bytes decoded, including the ones |
| 10 | /// indicated as already written by `output_index`. |
| 11 | pub(crate) fn decode_suffix( |
| 12 | input: &[u8], |
| 13 | input_index: usize, |
| 14 | output: &mut [u8], |
| 15 | mut output_index: usize, |
| 16 | decode_table: &[u8; 256], |
| 17 | decode_allow_trailing_bits: bool, |
| 18 | padding_mode: DecodePaddingMode, |
| 19 | ) -> Result<DecodeMetadata, DecodeSliceError> { |
| 20 | debug_assert!((input.len() - input_index) <= 4); |
| 21 | |
| 22 | // Decode any leftovers that might not be a complete input chunk of 4 bytes. |
| 23 | // Use a u32 as a stack-resident 4 byte buffer. |
| 24 | let mut morsels_in_leftover = 0; |
| 25 | let mut padding_bytes_count = 0; |
| 26 | // offset from input_index |
| 27 | let mut first_padding_offset: usize = 0; |
| 28 | let mut last_symbol = 0_u8; |
| 29 | let mut morsels = [0_u8; 4]; |
| 30 | |
| 31 | for (leftover_index, &b) in input[input_index..].iter().enumerate() { |
| 32 | // '=' padding |
| 33 | if b == PAD_BYTE { |
| 34 | // There can be bad padding bytes in a few ways: |
| 35 | // 1 - Padding with non-padding characters after it |
| 36 | // 2 - Padding after zero or one characters in the current quad (should only |
| 37 | // be after 2 or 3 chars) |
| 38 | // 3 - More than two characters of padding. If 3 or 4 padding chars |
| 39 | // are in the same quad, that implies it will be caught by #2. |
| 40 | // If it spreads from one quad to another, it will be an invalid byte |
| 41 | // in the first quad. |
| 42 | // 4 - Non-canonical padding -- 1 byte when it should be 2, etc. |
| 43 | // Per config, non-canonical but still functional non- or partially-padded base64 |
| 44 | // may be treated as an error condition. |
| 45 | |
| 46 | if leftover_index < 2 { |
| 47 | // Check for error #2. |
| 48 | // Either the previous byte was padding, in which case we would have already hit |
| 49 | // this case, or it wasn't, in which case this is the first such error. |
| 50 | debug_assert!( |
| 51 | leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0) |
| 52 | ); |
| 53 | let bad_padding_index = input_index + leftover_index; |
| 54 | return Err(DecodeError::InvalidByte(bad_padding_index, b).into()); |
| 55 | } |
| 56 | |
| 57 | if padding_bytes_count == 0 { |
| 58 | first_padding_offset = leftover_index; |
| 59 | } |
| 60 | |
| 61 | padding_bytes_count += 1; |
| 62 | continue; |
| 63 | } |
| 64 | |
| 65 | // Check for case #1. |
| 66 | // To make '=' handling consistent with the main loop, don't allow |
| 67 | // non-suffix '=' in trailing chunk either. Report error as first |
| 68 | // erroneous padding. |
| 69 | if padding_bytes_count > 0 { |
| 70 | return Err( |
| 71 | DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(), |
| 72 | ); |
| 73 | } |
| 74 | |
| 75 | last_symbol = b; |
| 76 | |
| 77 | // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. |
| 78 | // Pack the leftovers from left to right. |
| 79 | let morsel = decode_table[b as usize]; |
| 80 | if morsel == INVALID_VALUE { |
| 81 | return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into()); |
| 82 | } |
| 83 | |
| 84 | morsels[morsels_in_leftover] = morsel; |
| 85 | morsels_in_leftover += 1; |
| 86 | } |
| 87 | |
| 88 | // If there was 1 trailing byte, and it was valid, and we got to this point without hitting |
| 89 | // an invalid byte, now we can report invalid length |
| 90 | if !input.is_empty() && morsels_in_leftover < 2 { |
| 91 | return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into()); |
| 92 | } |
| 93 | |
| 94 | match padding_mode { |
| 95 | DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } |
| 96 | DecodePaddingMode::RequireCanonical => { |
| 97 | // allow empty input |
| 98 | if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { |
| 99 | return Err(DecodeError::InvalidPadding.into()); |
| 100 | } |
| 101 | } |
| 102 | DecodePaddingMode::RequireNone => { |
| 103 | if padding_bytes_count > 0 { |
| 104 | // check at the end to make sure we let the cases of padding that should be InvalidByte |
| 105 | // get hit |
| 106 | return Err(DecodeError::InvalidPadding.into()); |
| 107 | } |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed. |
| 112 | // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits |
| 113 | // of bottom 6 bits set). |
| 114 | // When decoding two symbols back to one trailing byte, any final symbol higher than |
| 115 | // w would still decode to the original byte because we only care about the top two |
| 116 | // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a |
| 117 | // mask based on how many bits are used for just the canonical encoding, and optionally |
| 118 | // error if any other bits are set. In the example of one encoded byte -> 2 symbols, |
| 119 | // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and |
| 120 | // useless since there are no more symbols to provide the necessary 4 additional bits |
| 121 | // to finish the second original byte. |
| 122 | |
| 123 | let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; |
| 124 | // Put the up to 6 complete bytes as the high bytes. |
| 125 | // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. |
| 126 | let mut leftover_num = (u32::from(morsels[0]) << 26) |
| 127 | | (u32::from(morsels[1]) << 20) |
| 128 | | (u32::from(morsels[2]) << 14) |
| 129 | | (u32::from(morsels[3]) << 8); |
| 130 | |
| 131 | // if there are bits set outside the bits we care about, last symbol encodes trailing bits that |
| 132 | // will not be included in the output |
| 133 | let mask = !0_u32 >> (leftover_bytes_to_append * 8); |
| 134 | if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { |
| 135 | // last morsel is at `morsels_in_leftover` - 1 |
| 136 | return Err(DecodeError::InvalidLastSymbol( |
| 137 | input_index + morsels_in_leftover - 1, |
| 138 | last_symbol, |
| 139 | ) |
| 140 | .into()); |
| 141 | } |
| 142 | |
| 143 | // Strangely, this approach benchmarks better than writing bytes one at a time, |
| 144 | // or copy_from_slice into output. |
| 145 | for _ in 0..leftover_bytes_to_append { |
| 146 | let hi_byte = (leftover_num >> 24) as u8; |
| 147 | leftover_num <<= 8; |
| 148 | *output |
| 149 | .get_mut(output_index) |
| 150 | .ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte; |
| 151 | output_index += 1; |
| 152 | } |
| 153 | |
| 154 | Ok(DecodeMetadata::new( |
| 155 | output_index, |
| 156 | if padding_bytes_count > 0 { |
| 157 | Some(input_index + first_padding_offset) |
| 158 | } else { |
| 159 | None |
| 160 | }, |
| 161 | )) |
| 162 | } |
| 163 | |