1 | use std::fmt; |
2 | use std::io::{self, Read}; |
3 | use std::str::{self, FromStr}; |
4 | |
/// Reasons why reading one character from a byte stream can fail.
#[derive(Debug)]
pub enum CharReadError {
    /// The stream ended after some, but not all, bytes of a character were read.
    UnexpectedEof,
    /// The bytes that were read are not valid UTF-8.
    Utf8(str::Utf8Error),
    /// An I/O-level failure: either the underlying reader reported an error,
    /// or the data was rejected as invalid for the encoding in effect.
    Io(io::Error),
}
11 | |
12 | impl From<str::Utf8Error> for CharReadError { |
13 | #[cold ] |
14 | fn from(e: str::Utf8Error) -> CharReadError { |
15 | CharReadError::Utf8(e) |
16 | } |
17 | } |
18 | |
19 | impl From<io::Error> for CharReadError { |
20 | #[cold ] |
21 | fn from(e: io::Error) -> CharReadError { |
22 | CharReadError::Io(e) |
23 | } |
24 | } |
25 | |
26 | impl fmt::Display for CharReadError { |
27 | #[cold ] |
28 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
29 | use self::CharReadError::{Io, UnexpectedEof, Utf8}; |
30 | match *self { |
31 | UnexpectedEof => write!(f, "unexpected end of stream" ), |
32 | Utf8(ref e: &Utf8Error) => write!(f, "UTF-8 decoding error: {e}" ), |
33 | Io(ref e: &Error) => write!(f, "I/O error: {e}" ), |
34 | } |
35 | } |
36 | } |
37 | |
/// Character encoding used for parsing
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Encoding {
    /// UTF-8, and nothing else
    Utf8,
    /// Treated as UTF-8 for now, but may still turn out to be another 8-bit encoding
    Default,
    /// ISO-8859-1 (every byte is one character)
    Latin1,
    /// US-ASCII (7-bit only)
    Ascii,
    /// UTF-16, big-endian byte order
    Utf16Be,
    /// UTF-16, little-endian byte order
    Utf16Le,
    /// UTF-16 whose byte order is not known yet and will be sniffed from the input
    Utf16,
    /// Nothing is known yet; the encoding will be sniffed from the input
    Unknown,
}
59 | |
/// ASCII-case-insensitive equality between a lowercase label and arbitrary input.
///
/// `lower` must already be lowercase: only `varcase` is case-folded before the
/// byte-wise comparison. The two strings must have the same length; a mere
/// common prefix (including an empty `varcase`) is not a match.
// Hand-rolled because rustc inlines `eq_ignore_ascii_case` and creates kilobytes of code!
#[inline(never)]
fn icmp(lower: &str, varcase: &str) -> bool {
    // `zip` stops at the shorter input, so the lengths have to be compared explicitly.
    lower.len() == varcase.len()
        && lower
            .bytes()
            .zip(varcase.bytes())
            .all(|(l, v)| l == v.to_ascii_lowercase())
}
65 | |
66 | impl FromStr for Encoding { |
67 | type Err = &'static str; |
68 | |
69 | fn from_str(val: &str) -> Result<Self, Self::Err> { |
70 | if ["utf-8" , "utf8" ].into_iter().any(move |label: &str| icmp(lower:label, varcase:val)) { |
71 | Ok(Encoding::Utf8) |
72 | } else if ["iso-8859-1" , "latin1" ].into_iter().any(move |label: &str| icmp(lower:label, varcase:val)) { |
73 | Ok(Encoding::Latin1) |
74 | } else if ["utf-16" , "utf16" ].into_iter().any(move |label: &str| icmp(lower:label, varcase:val)) { |
75 | Ok(Encoding::Utf16) |
76 | } else if ["ascii" , "us-ascii" ].into_iter().any(move |label: &str| icmp(lower:label, varcase:val)) { |
77 | Ok(Encoding::Ascii) |
78 | } else { |
79 | Err("unknown encoding name" ) |
80 | } |
81 | } |
82 | } |
83 | |
84 | impl fmt::Display for Encoding { |
85 | #[cold ] |
86 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
87 | f.write_str(data:match self { |
88 | Encoding::Utf8 => "UTF-8" , |
89 | Encoding::Default => "UTF-8" , |
90 | Encoding::Latin1 => "ISO-8859-1" , |
91 | Encoding::Ascii => "US-ASCII" , |
92 | Encoding::Utf16Be => "UTF-16" , |
93 | Encoding::Utf16Le => "UTF-16" , |
94 | Encoding::Utf16 => "UTF-16" , |
95 | Encoding::Unknown => "(unknown)" , |
96 | }) |
97 | } |
98 | } |
99 | |
100 | pub(crate) struct CharReader { |
101 | pub encoding: Encoding, |
102 | } |
103 | |
104 | impl CharReader { |
105 | pub fn new() -> Self { |
106 | Self { |
107 | encoding: Encoding::Unknown, |
108 | } |
109 | } |
110 | |
111 | pub fn next_char_from<R: Read>(&mut self, source: &mut R) -> Result<Option<char>, CharReadError> { |
112 | let mut bytes = source.bytes(); |
113 | const MAX_CODEPOINT_LEN: usize = 4; |
114 | |
115 | let mut buf = [0u8; MAX_CODEPOINT_LEN]; |
116 | let mut pos = 0; |
117 | loop { |
118 | let next = match bytes.next() { |
119 | Some(Ok(b)) => b, |
120 | Some(Err(e)) => return Err(e.into()), |
121 | None if pos == 0 => return Ok(None), |
122 | None => return Err(CharReadError::UnexpectedEof), |
123 | }; |
124 | |
125 | match self.encoding { |
126 | Encoding::Utf8 | Encoding::Default => { |
127 | // fast path for ASCII subset |
128 | if pos == 0 && next.is_ascii() { |
129 | return Ok(Some(next.into())); |
130 | } |
131 | |
132 | buf[pos] = next; |
133 | pos += 1; |
134 | |
135 | match str::from_utf8(&buf[..pos]) { |
136 | Ok(s) => return Ok(s.chars().next()), // always Some(..) |
137 | Err(_) if pos < MAX_CODEPOINT_LEN => continue, |
138 | Err(e) => return Err(e.into()), |
139 | } |
140 | }, |
141 | Encoding::Latin1 => { |
142 | return Ok(Some(next.into())); |
143 | }, |
144 | Encoding::Ascii => { |
145 | if next.is_ascii() { |
146 | return Ok(Some(next.into())); |
147 | } else { |
148 | return Err(CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, "char is not ASCII" ))); |
149 | } |
150 | }, |
151 | Encoding::Unknown | Encoding::Utf16 => { |
152 | buf[pos] = next; |
153 | pos += 1; |
154 | |
155 | // sniff BOM |
156 | if pos <= 3 && buf[..pos] == [0xEF, 0xBB, 0xBF][..pos] { |
157 | if pos == 3 && self.encoding != Encoding::Utf16 { |
158 | pos = 0; |
159 | self.encoding = Encoding::Utf8; |
160 | } |
161 | } else if pos <= 2 && buf[..pos] == [0xFE, 0xFF][..pos] { |
162 | if pos == 2 { |
163 | pos = 0; |
164 | self.encoding = Encoding::Utf16Be; |
165 | } |
166 | } else if pos <= 2 && buf[..pos] == [0xFF, 0xFE][..pos] { |
167 | if pos == 2 { |
168 | pos = 0; |
169 | self.encoding = Encoding::Utf16Le; |
170 | } |
171 | } else if pos == 1 && self.encoding == Encoding::Utf16 { |
172 | // sniff ASCII char in UTF-16 |
173 | self.encoding = if next == 0 { Encoding::Utf16Be } else { Encoding::Utf16Le }; |
174 | } else { |
175 | // UTF-8 is the default, but XML decl can change it to other 8-bit encoding |
176 | self.encoding = Encoding::Default; |
177 | if pos == 1 && next.is_ascii() { |
178 | return Ok(Some(next.into())); |
179 | } |
180 | } |
181 | }, |
182 | Encoding::Utf16Be => { |
183 | buf[pos] = next; |
184 | pos += 1; |
185 | if pos == 2 { |
186 | if let Some(Ok(c)) = char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap())]).next() { |
187 | return Ok(Some(c)); |
188 | } |
189 | } else if pos == 4 { // surrogate |
190 | return char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap()), u16::from_be_bytes(buf[2..4].try_into().unwrap())]) |
191 | .next().transpose() |
192 | .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e))); |
193 | } |
194 | }, |
195 | Encoding::Utf16Le => { |
196 | buf[pos] = next; |
197 | pos += 1; |
198 | if pos == 2 { |
199 | if let Some(Ok(c)) = char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap())]).next() { |
200 | return Ok(Some(c)); |
201 | } |
202 | } else if pos == 4 { // surrogate |
203 | return char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap()), u16::from_le_bytes(buf[2..4].try_into().unwrap())]) |
204 | .next().transpose() |
205 | .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e))); |
206 | } |
207 | }, |
208 | } |
209 | } |
210 | } |
211 | } |
212 | |
#[cfg(test)]
mod tests {
    use super::{CharReadError, CharReader, Encoding};

    /// Reads the first character of `input` with a fresh reader, i.e. with the
    /// encoding still unknown and sniffed from the data.
    fn sniff(mut input: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader::new().next_char_from(&mut input)
    }

    /// Reads the first character of `input` with a reader preset to `encoding`.
    fn read_as(encoding: Encoding, mut input: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader { encoding }.next_char_from(&mut input)
    }

    #[test]
    fn test_next_char_from() {
        use std::io;

        // correct ASCII
        assert_eq!(sniff("correct".as_bytes()).unwrap(), Some('c'));

        // UTF-8 BOM is skipped; the first real character is returned
        assert_eq!(sniff(b"\xEF\xBB\xBF\xE2\x80\xA2!").unwrap(), Some('•'));
        assert_eq!(sniff(b"\xEF\xBB\xBFx123").unwrap(), Some('x'));

        // nothing after the BOM
        assert_eq!(sniff(b"\xEF\xBB\xBF").unwrap(), None);

        // stream ends inside the BOM
        assert!(matches!(sniff(b"\xEF\xBB"), Err(CharReadError::UnexpectedEof)));

        // BOM prefix followed by a non-BOM byte
        assert!(matches!(sniff(b"\xEF\xBB\x42"), Err(_)));

        // UTF-16 BOMs, big-endian then little-endian
        assert_eq!(sniff(b"\xFE\xFF\x00\x42").unwrap(), Some('B'));
        assert_eq!(sniff(b"\xFF\xFE\x42\x00").unwrap(), Some('B'));

        // nothing after a UTF-16 BOM
        assert_eq!(sniff(b"\xFF\xFE").unwrap(), None);

        // half a code unit after a UTF-16 BOM
        assert!(matches!(sniff(b"\xFF\xFE\x00"), Err(CharReadError::UnexpectedEof)));

        // correct BMP
        assert_eq!(sniff("правильно".as_bytes()).unwrap(), Some('п'));

        // the same bytes (D0 BF ...) forced to UTF-16 decode to one code unit
        assert_eq!(read_as(Encoding::Utf16Be, "правильно".as_bytes()).unwrap(), Some('\u{D0BF}'));
        assert_eq!(read_as(Encoding::Utf16Le, "правильно".as_bytes()).unwrap(), Some('\u{BFD0}'));

        // lone surrogate followed by a truncated code unit
        assert!(matches!(read_as(Encoding::Utf16, b"\xD8\xD8\x80"), Err(_)));

        // UTF-16 without BOM: endianness sniffed from the first byte
        assert_eq!(read_as(Encoding::Utf16, b"\x00\x42").unwrap(), Some('B'));
        assert_eq!(read_as(Encoding::Utf16, b"\x42\x00").unwrap(), Some('B'));

        // half a code unit
        assert!(matches!(read_as(Encoding::Utf16Be, b"\x00"), Err(_)));

        // correct non-BMP
        assert_eq!(sniff("😊".as_bytes()).unwrap(), Some('😊'));

        // empty
        assert_eq!(sniff(b"").unwrap(), None);

        // incomplete code point
        match sniff(b"\xf0\x9f\x98").unwrap_err() {
            CharReadError::UnexpectedEof => {},
            e => panic!("Unexpected result: {e:?}"),
        };

        // invalid code point
        match sniff(b"\xff\x9f\x98\x32").unwrap_err() {
            CharReadError::Utf8(_) => {},
            e => panic!("Unexpected result: {e:?}"),
        };

        // error during read
        struct ErrorReader;
        impl io::Read for ErrorReader {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::Other, "test error"))
            }
        }

        let mut r = ErrorReader;
        match CharReader::new().next_char_from(&mut r).unwrap_err() {
            CharReadError::Io(ref e) if e.kind() == io::ErrorKind::Other &&
                e.to_string().contains("test error") => {},
            e => panic!("Unexpected result: {e:?}"),
        }
    }
}
306 | |