1 | use crate::{ |
2 | engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode}, |
3 | DecodeError, DecodeSliceError, PAD_BYTE, |
4 | }; |
5 | |
6 | /// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided |
7 | /// parameters. |
8 | /// |
9 | /// Returns the decode metadata representing the total number of bytes decoded, including the ones |
10 | /// indicated as already written by `output_index`. |
11 | pub(crate) fn decode_suffix( |
12 | input: &[u8], |
13 | input_index: usize, |
14 | output: &mut [u8], |
15 | mut output_index: usize, |
16 | decode_table: &[u8; 256], |
17 | decode_allow_trailing_bits: bool, |
18 | padding_mode: DecodePaddingMode, |
19 | ) -> Result<DecodeMetadata, DecodeSliceError> { |
20 | debug_assert!((input.len() - input_index) <= 4); |
21 | |
22 | // Decode any leftovers that might not be a complete input chunk of 4 bytes. |
23 | // Use a u32 as a stack-resident 4 byte buffer. |
24 | let mut morsels_in_leftover = 0; |
25 | let mut padding_bytes_count = 0; |
26 | // offset from input_index |
27 | let mut first_padding_offset: usize = 0; |
28 | let mut last_symbol = 0_u8; |
29 | let mut morsels = [0_u8; 4]; |
30 | |
31 | for (leftover_index, &b) in input[input_index..].iter().enumerate() { |
32 | // '=' padding |
33 | if b == PAD_BYTE { |
34 | // There can be bad padding bytes in a few ways: |
35 | // 1 - Padding with non-padding characters after it |
36 | // 2 - Padding after zero or one characters in the current quad (should only |
37 | // be after 2 or 3 chars) |
38 | // 3 - More than two characters of padding. If 3 or 4 padding chars |
39 | // are in the same quad, that implies it will be caught by #2. |
40 | // If it spreads from one quad to another, it will be an invalid byte |
41 | // in the first quad. |
42 | // 4 - Non-canonical padding -- 1 byte when it should be 2, etc. |
43 | // Per config, non-canonical but still functional non- or partially-padded base64 |
44 | // may be treated as an error condition. |
45 | |
46 | if leftover_index < 2 { |
47 | // Check for error #2. |
48 | // Either the previous byte was padding, in which case we would have already hit |
49 | // this case, or it wasn't, in which case this is the first such error. |
50 | debug_assert!( |
51 | leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0) |
52 | ); |
53 | let bad_padding_index = input_index + leftover_index; |
54 | return Err(DecodeError::InvalidByte(bad_padding_index, b).into()); |
55 | } |
56 | |
57 | if padding_bytes_count == 0 { |
58 | first_padding_offset = leftover_index; |
59 | } |
60 | |
61 | padding_bytes_count += 1; |
62 | continue; |
63 | } |
64 | |
65 | // Check for case #1. |
66 | // To make '=' handling consistent with the main loop, don't allow |
67 | // non-suffix '=' in trailing chunk either. Report error as first |
68 | // erroneous padding. |
69 | if padding_bytes_count > 0 { |
70 | return Err( |
71 | DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(), |
72 | ); |
73 | } |
74 | |
75 | last_symbol = b; |
76 | |
77 | // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. |
78 | // Pack the leftovers from left to right. |
79 | let morsel = decode_table[b as usize]; |
80 | if morsel == INVALID_VALUE { |
81 | return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into()); |
82 | } |
83 | |
84 | morsels[morsels_in_leftover] = morsel; |
85 | morsels_in_leftover += 1; |
86 | } |
87 | |
88 | // If there was 1 trailing byte, and it was valid, and we got to this point without hitting |
89 | // an invalid byte, now we can report invalid length |
90 | if !input.is_empty() && morsels_in_leftover < 2 { |
91 | return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into()); |
92 | } |
93 | |
94 | match padding_mode { |
95 | DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } |
96 | DecodePaddingMode::RequireCanonical => { |
97 | // allow empty input |
98 | if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { |
99 | return Err(DecodeError::InvalidPadding.into()); |
100 | } |
101 | } |
102 | DecodePaddingMode::RequireNone => { |
103 | if padding_bytes_count > 0 { |
104 | // check at the end to make sure we let the cases of padding that should be InvalidByte |
105 | // get hit |
106 | return Err(DecodeError::InvalidPadding.into()); |
107 | } |
108 | } |
109 | } |
110 | |
111 | // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed. |
112 | // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits |
113 | // of bottom 6 bits set). |
114 | // When decoding two symbols back to one trailing byte, any final symbol higher than |
115 | // w would still decode to the original byte because we only care about the top two |
116 | // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a |
117 | // mask based on how many bits are used for just the canonical encoding, and optionally |
118 | // error if any other bits are set. In the example of one encoded byte -> 2 symbols, |
119 | // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and |
120 | // useless since there are no more symbols to provide the necessary 4 additional bits |
121 | // to finish the second original byte. |
122 | |
123 | let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; |
124 | // Put the up to 6 complete bytes as the high bytes. |
125 | // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. |
126 | let mut leftover_num = (u32::from(morsels[0]) << 26) |
127 | | (u32::from(morsels[1]) << 20) |
128 | | (u32::from(morsels[2]) << 14) |
129 | | (u32::from(morsels[3]) << 8); |
130 | |
131 | // if there are bits set outside the bits we care about, last symbol encodes trailing bits that |
132 | // will not be included in the output |
133 | let mask = !0_u32 >> (leftover_bytes_to_append * 8); |
134 | if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { |
135 | // last morsel is at `morsels_in_leftover` - 1 |
136 | return Err(DecodeError::InvalidLastSymbol( |
137 | input_index + morsels_in_leftover - 1, |
138 | last_symbol, |
139 | ) |
140 | .into()); |
141 | } |
142 | |
143 | // Strangely, this approach benchmarks better than writing bytes one at a time, |
144 | // or copy_from_slice into output. |
145 | for _ in 0..leftover_bytes_to_append { |
146 | let hi_byte = (leftover_num >> 24) as u8; |
147 | leftover_num <<= 8; |
148 | *output |
149 | .get_mut(output_index) |
150 | .ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte; |
151 | output_index += 1; |
152 | } |
153 | |
154 | Ok(DecodeMetadata::new( |
155 | output_index, |
156 | if padding_bytes_count > 0 { |
157 | Some(input_index + first_padding_offset) |
158 | } else { |
159 | None |
160 | }, |
161 | )) |
162 | } |
163 | |