1 | use crate::{ |
2 | engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode}, |
3 | DecodeError, PAD_BYTE, |
4 | }; |
5 | |
6 | /// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided |
7 | /// parameters. |
8 | /// |
9 | /// Returns the decode metadata representing the total number of bytes decoded, including the ones |
10 | /// indicated as already written by `output_index`. |
11 | pub(crate) fn decode_suffix( |
12 | input: &[u8], |
13 | input_index: usize, |
14 | output: &mut [u8], |
15 | mut output_index: usize, |
16 | decode_table: &[u8; 256], |
17 | decode_allow_trailing_bits: bool, |
18 | padding_mode: DecodePaddingMode, |
19 | ) -> Result<DecodeMetadata, DecodeError> { |
20 | // Decode any leftovers that aren't a complete input block of 8 bytes. |
21 | // Use a u64 as a stack-resident 8 byte buffer. |
22 | let mut leftover_bits: u64 = 0; |
23 | let mut morsels_in_leftover = 0; |
24 | let mut padding_bytes = 0; |
25 | let mut first_padding_index: usize = 0; |
26 | let mut last_symbol = 0_u8; |
27 | let start_of_leftovers = input_index; |
28 | |
29 | for (i, &b) in input[start_of_leftovers..].iter().enumerate() { |
30 | // '=' padding |
31 | if b == PAD_BYTE { |
32 | // There can be bad padding bytes in a few ways: |
33 | // 1 - Padding with non-padding characters after it |
34 | // 2 - Padding after zero or one characters in the current quad (should only |
35 | // be after 2 or 3 chars) |
36 | // 3 - More than two characters of padding. If 3 or 4 padding chars |
37 | // are in the same quad, that implies it will be caught by #2. |
38 | // If it spreads from one quad to another, it will be an invalid byte |
39 | // in the first quad. |
40 | // 4 - Non-canonical padding -- 1 byte when it should be 2, etc. |
41 | // Per config, non-canonical but still functional non- or partially-padded base64 |
42 | // may be treated as an error condition. |
43 | |
44 | if i % 4 < 2 { |
45 | // Check for case #2. |
46 | let bad_padding_index = start_of_leftovers |
47 | + if padding_bytes > 0 { |
48 | // If we've already seen padding, report the first padding index. |
49 | // This is to be consistent with the normal decode logic: it will report an |
50 | // error on the first padding character (since it doesn't expect to see |
51 | // anything but actual encoded data). |
52 | // This could only happen if the padding started in the previous quad since |
53 | // otherwise this case would have been hit at i % 4 == 0 if it was the same |
54 | // quad. |
55 | first_padding_index |
56 | } else { |
57 | // haven't seen padding before, just use where we are now |
58 | i |
59 | }; |
60 | return Err(DecodeError::InvalidByte(bad_padding_index, b)); |
61 | } |
62 | |
63 | if padding_bytes == 0 { |
64 | first_padding_index = i; |
65 | } |
66 | |
67 | padding_bytes += 1; |
68 | continue; |
69 | } |
70 | |
71 | // Check for case #1. |
72 | // To make '=' handling consistent with the main loop, don't allow |
73 | // non-suffix '=' in trailing chunk either. Report error as first |
74 | // erroneous padding. |
75 | if padding_bytes > 0 { |
76 | return Err(DecodeError::InvalidByte( |
77 | start_of_leftovers + first_padding_index, |
78 | PAD_BYTE, |
79 | )); |
80 | } |
81 | |
82 | last_symbol = b; |
83 | |
84 | // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. |
85 | // Pack the leftovers from left to right. |
86 | let shift = 64 - (morsels_in_leftover + 1) * 6; |
87 | let morsel = decode_table[b as usize]; |
88 | if morsel == INVALID_VALUE { |
89 | return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); |
90 | } |
91 | |
92 | leftover_bits |= (morsel as u64) << shift; |
93 | morsels_in_leftover += 1; |
94 | } |
95 | |
96 | match padding_mode { |
97 | DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } |
98 | DecodePaddingMode::RequireCanonical => { |
99 | if (padding_bytes + morsels_in_leftover) % 4 != 0 { |
100 | return Err(DecodeError::InvalidPadding); |
101 | } |
102 | } |
103 | DecodePaddingMode::RequireNone => { |
104 | if padding_bytes > 0 { |
105 | // check at the end to make sure we let the cases of padding that should be InvalidByte |
106 | // get hit |
107 | return Err(DecodeError::InvalidPadding); |
108 | } |
109 | } |
110 | } |
111 | |
112 | // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed. |
113 | // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits |
114 | // of bottom 6 bits set). |
115 | // When decoding two symbols back to one trailing byte, any final symbol higher than |
116 | // w would still decode to the original byte because we only care about the top two |
117 | // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a |
118 | // mask based on how many bits are used for just the canonical encoding, and optionally |
119 | // error if any other bits are set. In the example of one encoded byte -> 2 symbols, |
120 | // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and |
121 | // useless since there are no more symbols to provide the necessary 4 additional bits |
122 | // to finish the second original byte. |
123 | |
124 | let leftover_bits_ready_to_append = match morsels_in_leftover { |
125 | 0 => 0, |
126 | 2 => 8, |
127 | 3 => 16, |
128 | 4 => 24, |
129 | 6 => 32, |
130 | 7 => 40, |
131 | 8 => 48, |
132 | // can also be detected as case #2 bad padding above |
133 | _ => unreachable!( |
134 | "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" |
135 | ), |
136 | }; |
137 | |
138 | // if there are bits set outside the bits we care about, last symbol encodes trailing bits that |
139 | // will not be included in the output |
140 | let mask = !0 >> leftover_bits_ready_to_append; |
141 | if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { |
142 | // last morsel is at `morsels_in_leftover` - 1 |
143 | return Err(DecodeError::InvalidLastSymbol( |
144 | start_of_leftovers + morsels_in_leftover - 1, |
145 | last_symbol, |
146 | )); |
147 | } |
148 | |
149 | // TODO benchmark simply converting to big endian bytes |
150 | let mut leftover_bits_appended_to_buf = 0; |
151 | while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { |
152 | // `as` simply truncates the higher bits, which is what we want here |
153 | let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; |
154 | output[output_index] = selected_bits; |
155 | output_index += 1; |
156 | |
157 | leftover_bits_appended_to_buf += 8; |
158 | } |
159 | |
160 | Ok(DecodeMetadata::new( |
161 | output_index, |
162 | if padding_bytes > 0 { |
163 | Some(input_index + first_padding_index) |
164 | } else { |
165 | None |
166 | }, |
167 | )) |
168 | } |
169 | |