1 | use std::fmt; |
2 | use std::io::{self, Read}; |
3 | use std::str::{self, FromStr}; |
4 | |
/// Failure to decode the next character from a byte stream.
#[derive(Debug)]
pub enum CharReadError {
    /// The input ended part-way through a multi-byte character.
    UnexpectedEof,
    /// The bytes read do not form valid UTF-8.
    Utf8(std::str::Utf8Error),
    /// An I/O error from the underlying reader; also used for bytes that are invalid in
    /// the current single-byte or UTF-16 encoding.
    Io(std::io::Error),
}
11 | |
12 | impl From<str::Utf8Error> for CharReadError { |
13 | #[cold ] |
14 | fn from(e: str::Utf8Error) -> Self { |
15 | Self::Utf8(e) |
16 | } |
17 | } |
18 | |
19 | impl From<io::Error> for CharReadError { |
20 | #[cold ] |
21 | fn from(e: io::Error) -> Self { |
22 | Self::Io(e) |
23 | } |
24 | } |
25 | |
26 | impl fmt::Display for CharReadError { |
27 | #[cold ] |
28 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
29 | use self::CharReadError::{Io, UnexpectedEof, Utf8}; |
30 | match *self { |
31 | UnexpectedEof => write!(f, "unexpected end of stream" ), |
32 | Utf8(ref e: &Utf8Error) => write!(f, "UTF-8 decoding error: {e}" ), |
33 | Io(ref e: &Error) => write!(f, "I/O error: {e}" ), |
34 | } |
35 | } |
36 | } |
37 | |
/// Character encoding used for parsing.
///
/// Variants other than `Utf16` and `Unknown` are definite; those two are placeholders
/// that the character reader refines by sniffing the first bytes of input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Encoding {
    /// UTF-8 and nothing else.
    Utf8,
    /// Decoded as UTF-8 by default, but may turn out to be any 8-bit encoding.
    Default,
    /// ISO-8859-1 (Latin-1): every byte is the code point with the same value.
    Latin1,
    /// US-ASCII: bytes above 0x7F are rejected.
    Ascii,
    /// UTF-16, big-endian byte order.
    Utf16Be,
    /// UTF-16, little-endian byte order.
    Utf16Le,
    /// UTF-16 whose byte order is not known yet and will be sniffed.
    Utf16,
    /// Not determined yet; may be sniffed to be anything.
    Unknown,
}
59 | |
/// Returns `true` if `varcase` equals `lower` when ASCII-case is ignored.
///
/// `lower` must already be lowercase ASCII; only `varcase` is folded. The lengths are
/// compared first because `zip` stops at the shorter input: without that check any
/// prefix (`""`, `"utf"`) would match `"utf-8"`, and `"iso-8859-15"` would match
/// `"iso-8859-1"`.
// Rustc inlines eq_ignore_ascii_case and creates kilobytes of code!
#[inline(never)]
fn icmp(lower: &str, varcase: &str) -> bool {
    lower.len() == varcase.len()
        && lower.bytes().zip(varcase.bytes()).all(|(l, v)| l == v.to_ascii_lowercase())
}
65 | |
66 | impl FromStr for Encoding { |
67 | type Err = &'static str; |
68 | |
69 | fn from_str(val: &str) -> Result<Self, Self::Err> { |
70 | if ["utf-8" , "utf8" ].into_iter().any(move |label: &'static str| icmp(lower:label, varcase:val)) { |
71 | Ok(Self::Utf8) |
72 | } else if ["iso-8859-1" , "latin1" ].into_iter().any(move |label: &'static str| icmp(lower:label, varcase:val)) { |
73 | Ok(Self::Latin1) |
74 | } else if ["utf-16" , "utf16" ].into_iter().any(move |label: &'static str| icmp(lower:label, varcase:val)) { |
75 | Ok(Self::Utf16) |
76 | } else if ["ascii" , "us-ascii" ].into_iter().any(move |label: &'static str| icmp(lower:label, varcase:val)) { |
77 | Ok(Self::Ascii) |
78 | } else { |
79 | Err("unknown encoding name" ) |
80 | } |
81 | } |
82 | } |
83 | |
84 | impl fmt::Display for Encoding { |
85 | #[cold ] |
86 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
87 | f.write_str(data:match self { |
88 | Self::Utf8 | |
89 | Self::Default => "UTF-8" , |
90 | Self::Latin1 => "ISO-8859-1" , |
91 | Self::Ascii => "US-ASCII" , |
92 | Self::Utf16Be | |
93 | Self::Utf16Le | |
94 | Self::Utf16 => "UTF-16" , |
95 | Self::Unknown => "(unknown)" , |
96 | }) |
97 | } |
98 | } |
99 | |
100 | pub(crate) struct CharReader { |
101 | pub encoding: Encoding, |
102 | } |
103 | |
104 | impl CharReader { |
105 | pub const fn new() -> Self { |
106 | Self { encoding: Encoding::Unknown } |
107 | } |
108 | |
109 | pub fn next_char_from<R: Read>(&mut self, source: &mut R) -> Result<Option<char>, CharReadError> { |
110 | let mut bytes = source.bytes(); |
111 | const MAX_CODEPOINT_LEN: usize = 4; |
112 | |
113 | let mut buf = [0u8; MAX_CODEPOINT_LEN]; |
114 | let mut pos = 0; |
115 | while pos < MAX_CODEPOINT_LEN { |
116 | let next = match bytes.next() { |
117 | Some(Ok(b)) => b, |
118 | Some(Err(e)) => return Err(e.into()), |
119 | None if pos == 0 => return Ok(None), |
120 | None => return Err(CharReadError::UnexpectedEof), |
121 | }; |
122 | |
123 | match self.encoding { |
124 | Encoding::Utf8 | Encoding::Default => { |
125 | // fast path for ASCII subset |
126 | if pos == 0 && next.is_ascii() { |
127 | return Ok(Some(next.into())); |
128 | } |
129 | |
130 | buf[pos] = next; |
131 | pos += 1; |
132 | |
133 | match str::from_utf8(&buf[..pos]) { |
134 | Ok(s) => return Ok(s.chars().next()), // always Some(..) |
135 | Err(_) if pos < MAX_CODEPOINT_LEN => continue, |
136 | Err(e) => return Err(e.into()), |
137 | } |
138 | }, |
139 | Encoding::Latin1 => { |
140 | return Ok(Some(next.into())); |
141 | }, |
142 | Encoding::Ascii => { |
143 | return if next.is_ascii() { |
144 | Ok(Some(next.into())) |
145 | } else { |
146 | Err(CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, "char is not ASCII" ))) |
147 | }; |
148 | }, |
149 | Encoding::Unknown | Encoding::Utf16 => { |
150 | buf[pos] = next; |
151 | pos += 1; |
152 | if let Some(value) = self.sniff_bom(&buf[..pos], &mut pos) { |
153 | return value; |
154 | } |
155 | }, |
156 | Encoding::Utf16Be => { |
157 | buf[pos] = next; |
158 | pos += 1; |
159 | if pos == 2 { |
160 | if let Some(Ok(c)) = char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap())]).next() { |
161 | return Ok(Some(c)); |
162 | } |
163 | } else if pos == 4 { |
164 | return Self::surrogate([u16::from_be_bytes(buf[..2].try_into().unwrap()), u16::from_be_bytes(buf[2..4].try_into().unwrap())]); |
165 | } |
166 | }, |
167 | Encoding::Utf16Le => { |
168 | buf[pos] = next; |
169 | pos += 1; |
170 | if pos == 2 { |
171 | if let Some(Ok(c)) = char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap())]).next() { |
172 | return Ok(Some(c)); |
173 | } |
174 | } else if pos == 4 { |
175 | return Self::surrogate([u16::from_le_bytes(buf[..2].try_into().unwrap()), u16::from_le_bytes(buf[2..4].try_into().unwrap())]); |
176 | } |
177 | }, |
178 | } |
179 | } |
180 | Err(CharReadError::Io(io::ErrorKind::InvalidData.into())) |
181 | } |
182 | |
183 | #[cold ] |
184 | fn sniff_bom(&mut self, buf: &[u8], pos: &mut usize) -> Option<Result<Option<char>, CharReadError>> { |
185 | // sniff BOM |
186 | if buf.len() <= 3 && [0xEF, 0xBB, 0xBF].starts_with(buf) { |
187 | if buf.len() == 3 && self.encoding != Encoding::Utf16 { |
188 | *pos = 0; |
189 | self.encoding = Encoding::Utf8; |
190 | } |
191 | } else if buf.len() <= 2 && [0xFE, 0xFF].starts_with(buf) { |
192 | if buf.len() == 2 { |
193 | *pos = 0; |
194 | self.encoding = Encoding::Utf16Be; |
195 | } |
196 | } else if buf.len() <= 2 && [0xFF, 0xFE].starts_with(buf) { |
197 | if buf.len() == 2 { |
198 | *pos = 0; |
199 | self.encoding = Encoding::Utf16Le; |
200 | } |
201 | } else if buf.len() == 1 && self.encoding == Encoding::Utf16 { |
202 | // sniff ASCII char in UTF-16 |
203 | self.encoding = if buf[0] == 0 { Encoding::Utf16Be } else { Encoding::Utf16Le }; |
204 | } else { |
205 | // UTF-8 is the default, but XML decl can change it to other 8-bit encoding |
206 | self.encoding = Encoding::Default; |
207 | if buf.len() == 1 && buf[0].is_ascii() { |
208 | return Some(Ok(Some(buf[0].into()))); |
209 | } |
210 | } |
211 | None |
212 | } |
213 | |
214 | fn surrogate(buf: [u16; 2]) -> Result<Option<char>, CharReadError> { |
215 | char::decode_utf16(buf).next().transpose() |
216 | .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e))) |
217 | } |
218 | } |
219 | |
#[cfg(test)]
mod tests {
    use super::{CharReadError, CharReader, Encoding};
    use std::io;

    /// Decodes the first character of `input` with a freshly created reader.
    fn read_first(mut input: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader::new().next_char_from(&mut input)
    }

    /// Decodes the first character of `input` with a reader forced to `encoding`.
    fn read_first_as(encoding: Encoding, mut input: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader { encoding }.next_char_from(&mut input)
    }

    #[test]
    fn test_next_char_from() {
        // correct ASCII
        assert_eq!(read_first(b"correct").unwrap(), Some('c'));

        // UTF-8 BOM is skipped
        assert_eq!(read_first(b"\xEF\xBB\xBF\xE2\x80\xA2!").unwrap(), Some('•'));
        assert_eq!(read_first(b"\xEF\xBB\xBFx123").unwrap(), Some('x'));
        // Nothing after BOM
        assert_eq!(read_first(b"\xEF\xBB\xBF").unwrap(), None);
        // Truncated BOM
        assert!(matches!(read_first(b"\xEF\xBB"), Err(CharReadError::UnexpectedEof)));
        assert!(read_first(b"\xEF\xBB\x42").is_err());

        // UTF-16 BOMs
        assert_eq!(read_first(b"\xFE\xFF\x00\x42").unwrap(), Some('B'));
        assert_eq!(read_first(b"\xFF\xFE\x42\x00").unwrap(), Some('B'));
        assert_eq!(read_first(b"\xFF\xFE").unwrap(), None);
        assert!(matches!(read_first(b"\xFF\xFE\x00"), Err(CharReadError::UnexpectedEof)));

        // correct BMP text, and the same bytes forced into either UTF-16 byte order
        let cyrillic = "правильно".as_bytes();
        assert_eq!(read_first(cyrillic).unwrap(), Some('п'));
        assert_eq!(read_first_as(Encoding::Utf16Be, cyrillic).unwrap(), Some('킿'));
        assert_eq!(read_first_as(Encoding::Utf16Le, cyrillic).unwrap(), Some('뿐'));

        // UTF-16 with the byte order still to be sniffed
        assert!(read_first_as(Encoding::Utf16, b"\xD8\xD8\x80").is_err());
        assert_eq!(read_first_as(Encoding::Utf16, b"\x00\x42").unwrap(), Some('B'));
        assert_eq!(read_first_as(Encoding::Utf16, b"\x42\x00").unwrap(), Some('B'));
        assert!(read_first_as(Encoding::Utf16, &[0xEF, 0xBB, 0xBF, 0xFF, 0xFF]).is_err());
        assert!(read_first_as(Encoding::Utf16Be, b"\x00").is_err());

        // correct non-BMP
        assert_eq!(read_first("😊".as_bytes()).unwrap(), Some('😊'));

        // empty
        assert_eq!(read_first(b"").unwrap(), None);

        // incomplete code point
        match read_first(b"\xf0\x9f\x98").unwrap_err() {
            CharReadError::UnexpectedEof => {},
            e => panic!("Unexpected result: {e:?}"),
        };

        // invalid code point
        match read_first(b"\xff\x9f\x98\x32").unwrap_err() {
            CharReadError::Utf8(_) => {},
            e => panic!("Unexpected result: {e:?}"),
        };

        // error during read
        struct ErrorReader;
        impl io::Read for ErrorReader {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::Other, "test error"))
            }
        }

        let mut reader = ErrorReader;
        match CharReader::new().next_char_from(&mut reader).unwrap_err() {
            CharReadError::Io(ref e)
                if e.kind() == io::ErrorKind::Other && e.to_string().contains("test error") => {},
            e => panic!("Unexpected result: {e:?}"),
        }
    }
}
316 | |