| 1 | //! <https://infra.spec.whatwg.org/#forgiving-base64-decode> |
| 2 | |
| 3 | use alloc::vec::Vec; |
| 4 | use core::fmt; |
| 5 | |
/// Error returned when the input is not valid forgiving-base64.
///
/// The [`fmt::Display`] output names the rule that the input violated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidBase64(InvalidBase64Details);

impl fmt::Display for InvalidBase64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.0 {
            InvalidBase64Details::UnexpectedSymbol(code_point) => {
                write!(f, "symbol with codepoint {} not expected", code_point)
            }
            InvalidBase64Details::AlphabetSymbolAfterPadding => {
                write!(f, "alphabet symbol present after padding")
            }
            InvalidBase64Details::LoneAlphabetSymbol => write!(f, "lone alphabet symbol present"),
            InvalidBase64Details::Padding => write!(f, "incorrect padding"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for InvalidBase64 {}

/// The specific forgiving-base64 rule that the input broke.
#[derive(Debug, Clone, PartialEq, Eq)]
enum InvalidBase64Details {
    /// A byte that is not in the base64 alphabet, is not ASCII whitespace,
    /// and is not `=`. Carries the offending byte value.
    UnexpectedSymbol(u8),
    /// An alphabet symbol appeared after one or more `=` padding characters.
    AlphabetSymbolAfterPadding,
    /// The alphabet symbols end with a single leftover symbol.
    /// Its 6 bits cannot form a whole byte.
    LoneAlphabetSymbol,
    /// The number of `=` characters is not valid for the number of trailing
    /// alphabet symbols.
    Padding,
}
| 34 | |
| 35 | #[derive (Debug)] |
| 36 | pub enum DecodeError<E> { |
| 37 | InvalidBase64(InvalidBase64), |
| 38 | WriteError(E), |
| 39 | } |
| 40 | |
| 41 | impl<E: fmt::Display> fmt::Display for DecodeError<E> { |
| 42 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 43 | match self { |
| 44 | Self::InvalidBase64(inner: &InvalidBase64) => write!(f, "base64 not valid: {}" , inner), |
| 45 | Self::WriteError(err: &E) => write!(f, "write error: {}" , err), |
| 46 | } |
| 47 | } |
| 48 | } |
| 49 | |
| 50 | #[cfg (feature = "std" )] |
| 51 | impl<E: std::error::Error> std::error::Error for DecodeError<E> {} |
| 52 | |
| 53 | impl<E> From<InvalidBase64Details> for DecodeError<E> { |
| 54 | fn from(e: InvalidBase64Details) -> Self { |
| 55 | DecodeError::InvalidBase64(InvalidBase64(e)) |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | pub(crate) enum Impossible {} |
| 60 | |
| 61 | impl From<DecodeError<Impossible>> for InvalidBase64 { |
| 62 | fn from(e: DecodeError<Impossible>) -> Self { |
| 63 | match e { |
| 64 | DecodeError::InvalidBase64(e: InvalidBase64) => e, |
| 65 | DecodeError::WriteError(e: Impossible) => match e {}, |
| 66 | } |
| 67 | } |
| 68 | } |
| 69 | |
| 70 | /// `input` is assumed to be in an ASCII-compatible encoding |
| 71 | pub fn decode_to_vec(input: &[u8]) -> Result<Vec<u8>, InvalidBase64> { |
| 72 | let mut v: Vec = Vec::new(); |
| 73 | { |
| 74 | let mut decoder: Decoder …, …> = Decoder::new(|bytes: &[u8]| { |
| 75 | v.extend_from_slice(bytes); |
| 76 | Ok(()) |
| 77 | }); |
| 78 | decoder.feed(input)?; |
| 79 | decoder.finish()?; |
| 80 | } |
| 81 | Ok(v) |
| 82 | } |
| 83 | |
/// Streaming decoder for <https://infra.spec.whatwg.org/#forgiving-base64-decode>.
///
/// Input arrives in pieces through `feed` and is terminated by `finish`.
/// Decoded bytes are passed to the `write_bytes` callback as soon as they are
/// complete.
pub struct Decoder<F, E>
where
    F: FnMut(&[u8]) -> Result<(), E>,
{
    /// Holds the 6-bit values of the alphabet symbols read since the last
    /// flush. Only the low `buffer_bit_length` bits are meaningful.
    bit_buffer: u32,
    /// Number of meaningful bits in `bit_buffer`: always 0, 6, 12 or 18.
    buffer_bit_length: u8,
    /// Number of `=` characters seen so far, saturating at `u8::MAX`.
    padding_symbols: u8,
    /// Sink for decoded bytes. An `Err` from it aborts decoding.
    write_bytes: F,
}
| 94 | |
| 95 | impl<F, E> Decoder<F, E> |
| 96 | where |
| 97 | F: FnMut(&[u8]) -> Result<(), E>, |
| 98 | { |
| 99 | pub fn new(write_bytes: F) -> Self { |
| 100 | Self { |
| 101 | write_bytes, |
| 102 | bit_buffer: 0, |
| 103 | buffer_bit_length: 0, |
| 104 | padding_symbols: 0, |
| 105 | } |
| 106 | } |
| 107 | |
| 108 | /// Feed to the decoder partial input in an ASCII-compatible encoding |
| 109 | pub fn feed(&mut self, input: &[u8]) -> Result<(), DecodeError<E>> { |
| 110 | for &byte in input.iter() { |
| 111 | let value = BASE64_DECODE_TABLE[byte as usize]; |
| 112 | if value < 0 { |
| 113 | // A character that’s not part of the alphabet |
| 114 | |
| 115 | // Remove ASCII whitespace |
| 116 | if matches!(byte, b' ' | b' \t' | b' \n' | b' \r' | b' \x0C' ) { |
| 117 | continue; |
| 118 | } |
| 119 | |
| 120 | if byte == b'=' { |
| 121 | self.padding_symbols = self.padding_symbols.saturating_add(1); |
| 122 | continue; |
| 123 | } |
| 124 | |
| 125 | return Err(InvalidBase64Details::UnexpectedSymbol(byte).into()); |
| 126 | } |
| 127 | if self.padding_symbols > 0 { |
| 128 | return Err(InvalidBase64Details::AlphabetSymbolAfterPadding.into()); |
| 129 | } |
| 130 | self.bit_buffer <<= 6; |
| 131 | self.bit_buffer |= value as u32; |
| 132 | // 18 before incrementing means we’ve just reached 24 |
| 133 | if self.buffer_bit_length < 18 { |
| 134 | self.buffer_bit_length += 6; |
| 135 | } else { |
| 136 | // We’ve accumulated four times 6 bits, which equals three times 8 bits. |
| 137 | let byte_buffer = [ |
| 138 | (self.bit_buffer >> 16) as u8, |
| 139 | (self.bit_buffer >> 8) as u8, |
| 140 | self.bit_buffer as u8, |
| 141 | ]; |
| 142 | (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?; |
| 143 | self.buffer_bit_length = 0; |
| 144 | // No need to reset bit_buffer, |
| 145 | // since next time we’re only gonna read relevant bits. |
| 146 | } |
| 147 | } |
| 148 | Ok(()) |
| 149 | } |
| 150 | |
| 151 | /// Call this to signal the end of the input |
| 152 | pub fn finish(mut self) -> Result<(), DecodeError<E>> { |
| 153 | match (self.buffer_bit_length, self.padding_symbols) { |
| 154 | (0, 0) => { |
| 155 | // A multiple of four of alphabet symbols, and nothing else. |
| 156 | } |
| 157 | (12, 2) | (12, 0) => { |
| 158 | // A multiple of four of alphabet symbols, followed by two more symbols, |
| 159 | // optionally followed by two padding characters (which make a total multiple of four). |
| 160 | let byte_buffer = [(self.bit_buffer >> 4) as u8]; |
| 161 | (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?; |
| 162 | } |
| 163 | (18, 1) | (18, 0) => { |
| 164 | // A multiple of four of alphabet symbols, followed by three more symbols, |
| 165 | // optionally followed by one padding character (which make a total multiple of four). |
| 166 | let byte_buffer = [(self.bit_buffer >> 10) as u8, (self.bit_buffer >> 2) as u8]; |
| 167 | (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?; |
| 168 | } |
| 169 | (6, _) => return Err(InvalidBase64Details::LoneAlphabetSymbol.into()), |
| 170 | _ => return Err(InvalidBase64Details::Padding.into()), |
| 171 | } |
| 172 | Ok(()) |
| 173 | } |
| 174 | } |
| 175 | |
/// Lookup table from a byte value to its position in the base64 alphabet.
///
/// The alphabet is "Table 1: The Base 64 Alphabet" at
/// <https://tools.ietf.org/html/rfc4648#section-4>.
///
/// - The index is the byte value of a symbol.
/// - The value is the symbol's position in the alphabet, or -1 for bytes outside
///   the alphabet.
/// - Each non-negative value contributes 6 bits to the decoded bytes.
///
/// The table is computed at compile time from the alphabet string, so it
/// cannot drift from it.
const BASE64_DECODE_TABLE: [i8; 256] = {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut table = [-1_i8; 256];
    let mut position = 0;
    while position < ALPHABET.len() {
        // `position` < 64, so it always fits in an `i8`.
        table[ALPHABET[position] as usize] = position as i8;
        position += 1;
    }
    table
};
| 202 | |