//! Forgiving base64 decoding, as specified by the WHATWG Infra Standard:
//! <https://infra.spec.whatwg.org/#forgiving-base64-decode>
2 | |
3 | use alloc::vec::Vec; |
4 | use core::fmt; |
5 | |
/// Error returned when the input is not valid forgiving-base64.
///
/// The precise reason is kept in the private [`InvalidBase64Details`] and is
/// only observable through the `Display` (and `Debug`) output.
#[derive(Debug)]
pub struct InvalidBase64(InvalidBase64Details);

impl fmt::Display for InvalidBase64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fixed_message = match self.0 {
            // The only reason whose message carries data: the offending byte value.
            InvalidBase64Details::UnexpectedSymbol(byte) => {
                return write!(f, "symbol with codepoint {} not expected", byte);
            }
            InvalidBase64Details::AlphabetSymbolAfterPadding => {
                "alphabet symbol present after padding"
            }
            InvalidBase64Details::LoneAlphabetSymbol => "lone alphabet symbol present",
            InvalidBase64Details::Padding => "incorrect padding",
        };
        f.write_str(fixed_message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for InvalidBase64 {}

/// Why decoding failed.
#[derive(Debug)]
enum InvalidBase64Details {
    /// A byte that is not in the alphabet, not ASCII whitespace, and not `=`.
    UnexpectedSymbol(u8),
    /// An alphabet symbol appeared after a `=` padding symbol.
    AlphabetSymbolAfterPadding,
    /// The input ended with a single leftover alphabet symbol (only 6 bits).
    LoneAlphabetSymbol,
    /// The number of `=` symbols does not fit the amount of trailing data.
    Padding,
}
34 | |
35 | #[derive (Debug)] |
36 | pub enum DecodeError<E> { |
37 | InvalidBase64(InvalidBase64), |
38 | WriteError(E), |
39 | } |
40 | |
41 | impl<E: fmt::Display> fmt::Display for DecodeError<E> { |
42 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
43 | match self { |
44 | Self::InvalidBase64(inner: &InvalidBase64) => write!(f, "base64 not valid: {}" , inner), |
45 | Self::WriteError(err: &E) => write!(f, "write error: {}" , err), |
46 | } |
47 | } |
48 | } |
49 | |
50 | #[cfg (feature = "std" )] |
51 | impl<E: std::error::Error> std::error::Error for DecodeError<E> {} |
52 | |
53 | impl<E> From<InvalidBase64Details> for DecodeError<E> { |
54 | fn from(e: InvalidBase64Details) -> Self { |
55 | DecodeError::InvalidBase64(InvalidBase64(e)) |
56 | } |
57 | } |
58 | |
59 | pub(crate) enum Impossible {} |
60 | |
61 | impl From<DecodeError<Impossible>> for InvalidBase64 { |
62 | fn from(e: DecodeError<Impossible>) -> Self { |
63 | match e { |
64 | DecodeError::InvalidBase64(e: InvalidBase64) => e, |
65 | DecodeError::WriteError(e: Impossible) => match e {}, |
66 | } |
67 | } |
68 | } |
69 | |
70 | /// `input` is assumed to be in an ASCII-compatible encoding |
71 | pub fn decode_to_vec(input: &[u8]) -> Result<Vec<u8>, InvalidBase64> { |
72 | let mut v: Vec = Vec::new(); |
73 | { |
74 | let mut decoder: Decoder …, …> = Decoder::new(|bytes: &[u8]| { |
75 | v.extend_from_slice(bytes); |
76 | Ok(()) |
77 | }); |
78 | decoder.feed(input)?; |
79 | decoder.finish()?; |
80 | } |
81 | Ok(v) |
82 | } |
83 | |
84 | /// <https://infra.spec.whatwg.org/#forgiving-base64-decode> |
85 | pub struct Decoder<F, E> |
86 | where |
87 | F: FnMut(&[u8]) -> Result<(), E>, |
88 | { |
89 | write_bytes: F, |
90 | bit_buffer: u32, |
91 | buffer_bit_length: u8, |
92 | padding_symbols: u8, |
93 | } |
94 | |
95 | impl<F, E> Decoder<F, E> |
96 | where |
97 | F: FnMut(&[u8]) -> Result<(), E>, |
98 | { |
99 | pub fn new(write_bytes: F) -> Self { |
100 | Self { |
101 | write_bytes, |
102 | bit_buffer: 0, |
103 | buffer_bit_length: 0, |
104 | padding_symbols: 0, |
105 | } |
106 | } |
107 | |
108 | /// Feed to the decoder partial input in an ASCII-compatible encoding |
109 | pub fn feed(&mut self, input: &[u8]) -> Result<(), DecodeError<E>> { |
110 | for &byte in input.iter() { |
111 | let value = BASE64_DECODE_TABLE[byte as usize]; |
112 | if value < 0 { |
113 | // A character that’s not part of the alphabet |
114 | |
115 | // Remove ASCII whitespace |
116 | if matches!(byte, b' ' | b' \t' | b' \n' | b' \r' | b' \x0C' ) { |
117 | continue; |
118 | } |
119 | |
120 | if byte == b'=' { |
121 | self.padding_symbols = self.padding_symbols.saturating_add(1); |
122 | continue; |
123 | } |
124 | |
125 | return Err(InvalidBase64Details::UnexpectedSymbol(byte).into()); |
126 | } |
127 | if self.padding_symbols > 0 { |
128 | return Err(InvalidBase64Details::AlphabetSymbolAfterPadding.into()); |
129 | } |
130 | self.bit_buffer <<= 6; |
131 | self.bit_buffer |= value as u32; |
132 | // 18 before incrementing means we’ve just reached 24 |
133 | if self.buffer_bit_length < 18 { |
134 | self.buffer_bit_length += 6; |
135 | } else { |
136 | // We’ve accumulated four times 6 bits, which equals three times 8 bits. |
137 | let byte_buffer = [ |
138 | (self.bit_buffer >> 16) as u8, |
139 | (self.bit_buffer >> 8) as u8, |
140 | self.bit_buffer as u8, |
141 | ]; |
142 | (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?; |
143 | self.buffer_bit_length = 0; |
144 | // No need to reset bit_buffer, |
145 | // since next time we’re only gonna read relevant bits. |
146 | } |
147 | } |
148 | Ok(()) |
149 | } |
150 | |
151 | /// Call this to signal the end of the input |
152 | pub fn finish(mut self) -> Result<(), DecodeError<E>> { |
153 | match (self.buffer_bit_length, self.padding_symbols) { |
154 | (0, 0) => { |
155 | // A multiple of four of alphabet symbols, and nothing else. |
156 | } |
157 | (12, 2) | (12, 0) => { |
158 | // A multiple of four of alphabet symbols, followed by two more symbols, |
159 | // optionally followed by two padding characters (which make a total multiple of four). |
160 | let byte_buffer = [(self.bit_buffer >> 4) as u8]; |
161 | (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?; |
162 | } |
163 | (18, 1) | (18, 0) => { |
164 | // A multiple of four of alphabet symbols, followed by three more symbols, |
165 | // optionally followed by one padding character (which make a total multiple of four). |
166 | let byte_buffer = [(self.bit_buffer >> 10) as u8, (self.bit_buffer >> 2) as u8]; |
167 | (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?; |
168 | } |
169 | (6, _) => return Err(InvalidBase64Details::LoneAlphabetSymbol.into()), |
170 | _ => return Err(InvalidBase64Details::Padding.into()), |
171 | } |
172 | Ok(()) |
173 | } |
174 | } |
175 | |
/// Reverse lookup for the base64 alphabet ("Table 1: The Base 64 Alphabet" at
/// <https://tools.ietf.org/html/rfc4648#section-4>).
///
/// Array indices are byte values. Each entry is that byte's position in the
/// alphabet, which contributes 6 bits to the decoded bytes, or -1 if the byte
/// is not part of the alphabet.
///
/// The table is computed at compile time from the alphabet string, so the two
/// cannot drift apart.
const BASE64_DECODE_TABLE: [i8; 256] = {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut table = [-1_i8; 256];
    let mut position = 0;
    while position < ALPHABET.len() {
        table[ALPHABET[position] as usize] = position as i8;
        position += 1;
    }
    table
};
202 | |