1use std::fmt;
2use std::io::{self, Read};
3use std::str::{self, FromStr};
4
/// Error returned when a character cannot be read from a byte stream.
#[derive(Debug)]
pub enum CharReadError {
    /// The stream ended part-way through a character (or a byte-order mark).
    UnexpectedEof,
    /// The bytes read are not valid UTF-8.
    Utf8(str::Utf8Error),
    /// The underlying reader failed, or the bytes are invalid for the selected
    /// encoding (those are reported with `io::ErrorKind::InvalidData`).
    Io(io::Error),
}

impl From<str::Utf8Error> for CharReadError {
    #[cold]
    fn from(e: str::Utf8Error) -> Self {
        Self::Utf8(e)
    }
}

impl From<io::Error> for CharReadError {
    #[cold]
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

impl fmt::Display for CharReadError {
    #[cold]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnexpectedEof => f.write_str("unexpected end of stream"),
            Self::Utf8(e) => write!(f, "UTF-8 decoding error: {e}"),
            Self::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}
37
/// Character encoding used for parsing
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Encoding {
    /// Explicitly UTF-8 only
    Utf8,
    /// UTF-8 fallback, but can be any 8-bit encoding
    Default,
    /// ISO-8859-1
    Latin1,
    /// US-ASCII
    Ascii,
    /// UTF-16, big-endian
    Utf16Be,
    /// UTF-16, little-endian
    Utf16Le,
    /// UTF-16 whose byte order is not known yet; it will be sniffed
    Utf16,
    /// Not determined yet, may be sniffed to be anything
    Unknown,
}

// Rustc inlines eq_ignore_ascii_case and creates kilobytes of code!
/// Returns whether `varcase` equals the ASCII-lowercase label `lower`, ignoring
/// ASCII case in `varcase`.
///
/// Like `eq_ignore_ascii_case`, the lengths must match: without that check the
/// byte-wise `zip` would accept any prefix or extension of `lower` (e.g. `""` or
/// `"utf"` as `"utf-8"`, `"iso-8859-15"` as `"iso-8859-1"`).
#[inline(never)]
fn icmp(lower: &str, varcase: &str) -> bool {
    lower.len() == varcase.len()
        && lower.bytes().zip(varcase.bytes()).all(|(l, v)| l == v.to_ascii_lowercase())
}

impl FromStr for Encoding {
    type Err = &'static str;

    /// Parses an encoding label (as found in an XML declaration), ignoring ASCII case.
    ///
    /// Accepted labels: `utf-8`/`utf8`, `iso-8859-1`/`latin1`,
    /// `utf-16`/`utf16`/`utf-16le`/`utf-16be` (all of which leave the byte order to
    /// be sniffed), and `ascii`/`us-ascii`.
    ///
    /// # Errors
    ///
    /// Returns `"unknown encoding name"` for any other label, including mere
    /// prefixes or extensions of the accepted ones.
    fn from_str(val: &str) -> Result<Self, Self::Err> {
        // Checked in order; the first group containing a matching label wins.
        const LABELS: [(&[&str], Encoding); 4] = [
            (&["utf-8", "utf8"], Encoding::Utf8),
            (&["iso-8859-1", "latin1"], Encoding::Latin1),
            // The explicit-endianness labels used to be accepted only by accident
            // (prefix match on "utf-16"); keep them working as before.
            (&["utf-16", "utf16", "utf-16le", "utf-16be"], Encoding::Utf16),
            (&["ascii", "us-ascii"], Encoding::Ascii),
        ];

        LABELS
            .iter()
            .find(|(labels, _)| labels.iter().any(|label| icmp(label, val)))
            .map(|&(_, encoding)| encoding)
            .ok_or("unknown encoding name")
    }
}

impl fmt::Display for Encoding {
    /// Writes the canonical label of the encoding family; both UTF-8 variants
    /// print as `UTF-8` and all UTF-16 variants as `UTF-16`.
    #[cold]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Utf8 | Self::Default => "UTF-8",
            Self::Latin1 => "ISO-8859-1",
            Self::Ascii => "US-ASCII",
            Self::Utf16Be | Self::Utf16Le | Self::Utf16 => "UTF-16",
            Self::Unknown => "(unknown)",
        })
    }
}
99
100pub(crate) struct CharReader {
101 pub encoding: Encoding,
102}
103
104impl CharReader {
105 pub fn new() -> Self {
106 Self {
107 encoding: Encoding::Unknown,
108 }
109 }
110
111 pub fn next_char_from<R: Read>(&mut self, source: &mut R) -> Result<Option<char>, CharReadError> {
112 let mut bytes = source.bytes();
113 const MAX_CODEPOINT_LEN: usize = 4;
114
115 let mut buf = [0u8; MAX_CODEPOINT_LEN];
116 let mut pos = 0;
117 loop {
118 let next = match bytes.next() {
119 Some(Ok(b)) => b,
120 Some(Err(e)) => return Err(e.into()),
121 None if pos == 0 => return Ok(None),
122 None => return Err(CharReadError::UnexpectedEof),
123 };
124
125 match self.encoding {
126 Encoding::Utf8 | Encoding::Default => {
127 // fast path for ASCII subset
128 if pos == 0 && next.is_ascii() {
129 return Ok(Some(next.into()));
130 }
131
132 buf[pos] = next;
133 pos += 1;
134
135 match str::from_utf8(&buf[..pos]) {
136 Ok(s) => return Ok(s.chars().next()), // always Some(..)
137 Err(_) if pos < MAX_CODEPOINT_LEN => continue,
138 Err(e) => return Err(e.into()),
139 }
140 },
141 Encoding::Latin1 => {
142 return Ok(Some(next.into()));
143 },
144 Encoding::Ascii => {
145 if next.is_ascii() {
146 return Ok(Some(next.into()));
147 } else {
148 return Err(CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, "char is not ASCII")));
149 }
150 },
151 Encoding::Unknown | Encoding::Utf16 => {
152 buf[pos] = next;
153 pos += 1;
154
155 // sniff BOM
156 if pos <= 3 && buf[..pos] == [0xEF, 0xBB, 0xBF][..pos] {
157 if pos == 3 && self.encoding != Encoding::Utf16 {
158 pos = 0;
159 self.encoding = Encoding::Utf8;
160 }
161 } else if pos <= 2 && buf[..pos] == [0xFE, 0xFF][..pos] {
162 if pos == 2 {
163 pos = 0;
164 self.encoding = Encoding::Utf16Be;
165 }
166 } else if pos <= 2 && buf[..pos] == [0xFF, 0xFE][..pos] {
167 if pos == 2 {
168 pos = 0;
169 self.encoding = Encoding::Utf16Le;
170 }
171 } else if pos == 1 && self.encoding == Encoding::Utf16 {
172 // sniff ASCII char in UTF-16
173 self.encoding = if next == 0 { Encoding::Utf16Be } else { Encoding::Utf16Le };
174 } else {
175 // UTF-8 is the default, but XML decl can change it to other 8-bit encoding
176 self.encoding = Encoding::Default;
177 if pos == 1 && next.is_ascii() {
178 return Ok(Some(next.into()));
179 }
180 }
181 },
182 Encoding::Utf16Be => {
183 buf[pos] = next;
184 pos += 1;
185 if pos == 2 {
186 if let Some(Ok(c)) = char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap())]).next() {
187 return Ok(Some(c));
188 }
189 } else if pos == 4 { // surrogate
190 return char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap()), u16::from_be_bytes(buf[2..4].try_into().unwrap())])
191 .next().transpose()
192 .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
193 }
194 },
195 Encoding::Utf16Le => {
196 buf[pos] = next;
197 pos += 1;
198 if pos == 2 {
199 if let Some(Ok(c)) = char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap())]).next() {
200 return Ok(Some(c));
201 }
202 } else if pos == 4 { // surrogate
203 return char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap()), u16::from_le_bytes(buf[2..4].try_into().unwrap())])
204 .next().transpose()
205 .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
206 }
207 },
208 }
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::{CharReadError, CharReader, Encoding};

    // Each case builds a fresh reader and checks only the first character (or the
    // error) that `next_char_from` produces for the given bytes.
    #[test]
    fn test_next_char_from() {
        use std::io;

        let mut bytes: &[u8] = "correct".as_bytes(); // correct ASCII
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('c'));

        let mut bytes: &[u8] = b"\xEF\xBB\xBF\xE2\x80\xA2!"; // UTF-8 BOM, then a 3-byte character
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('•'));

        let mut bytes: &[u8] = b"\xEF\xBB\xBFx123"; // UTF-8 BOM, then ASCII
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('x'));

        let mut bytes: &[u8] = b"\xEF\xBB\xBF"; // Nothing after BOM: the BOM is consumed, no char
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), None);

        let mut bytes: &[u8] = b"\xEF\xBB"; // Truncated UTF-8 BOM
        assert!(matches!(CharReader::new().next_char_from(&mut bytes), Err(CharReadError::UnexpectedEof)));

        let mut bytes: &[u8] = b"\xEF\xBB\x42"; // Truncated UTF-8 BOM followed by a non-BOM byte
        assert!(matches!(CharReader::new().next_char_from(&mut bytes), Err(_)));

        let mut bytes: &[u8] = b"\xFE\xFF\x00\x42"; // UTF-16BE BOM, then 'B'
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('B'));

        let mut bytes: &[u8] = b"\xFF\xFE\x42\x00"; // UTF-16LE BOM, then 'B'
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('B'));

        let mut bytes: &[u8] = b"\xFF\xFE"; // UTF-16LE BOM alone
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), None);

        let mut bytes: &[u8] = b"\xFF\xFE\x00"; // UTF-16LE BOM, then half a code unit
        assert!(matches!(CharReader::new().next_char_from(&mut bytes), Err(CharReadError::UnexpectedEof)));

        let mut bytes: &[u8] = "правильно".as_bytes(); // correct BMP
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('п'));

        // The UTF-8 bytes of 'п' (D0 BF) read as a single UTF-16 code unit, big-endian...
        let mut bytes: &[u8] = "правильно".as_bytes();
        assert_eq!(CharReader { encoding: Encoding::Utf16Be }.next_char_from(&mut bytes).unwrap(), Some('킿'));

        // ...and little-endian
        let mut bytes: &[u8] = "правильно".as_bytes();
        assert_eq!(CharReader { encoding: Encoding::Utf16Le }.next_char_from(&mut bytes).unwrap(), Some('뿐'));

        // UTF-16 of unknown byte order: the nonzero first byte selects little-endian,
        // 0xD8D8 is a high surrogate, and the input ends before its partner
        let mut bytes: &[u8] = b"\xD8\xD8\x80";
        assert!(matches!(CharReader { encoding: Encoding::Utf16 }.next_char_from(&mut bytes), Err(_)));

        // UTF-16 without BOM: a leading zero byte selects big-endian...
        let mut bytes: &[u8] = b"\x00\x42";
        assert_eq!(CharReader { encoding: Encoding::Utf16 }.next_char_from(&mut bytes).unwrap(), Some('B'));

        // ...and a nonzero first byte little-endian
        let mut bytes: &[u8] = b"\x42\x00";
        assert_eq!(CharReader { encoding: Encoding::Utf16 }.next_char_from(&mut bytes).unwrap(), Some('B'));

        // Half of a UTF-16 code unit
        let mut bytes: &[u8] = b"\x00";
        assert!(matches!(CharReader { encoding: Encoding::Utf16Be }.next_char_from(&mut bytes), Err(_)));

        let mut bytes: &[u8] = "😊".as_bytes(); // correct non-BMP
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), Some('😊'));

        let mut bytes: &[u8] = b""; // empty
        assert_eq!(CharReader::new().next_char_from(&mut bytes).unwrap(), None);

        let mut bytes: &[u8] = b"\xf0\x9f\x98"; // incomplete code point
        match CharReader::new().next_char_from(&mut bytes).unwrap_err() {
            super::CharReadError::UnexpectedEof => {},
            e => panic!("Unexpected result: {e:?}")
        };

        let mut bytes: &[u8] = b"\xff\x9f\x98\x32"; // invalid code point
        match CharReader::new().next_char_from(&mut bytes).unwrap_err() {
            super::CharReadError::Utf8(_) => {},
            e => panic!("Unexpected result: {e:?}")
        };

        // error during read: the reader's I/O error is passed through as `CharReadError::Io`
        struct ErrorReader;
        impl io::Read for ErrorReader {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::Other, "test error"))
            }
        }

        let mut r = ErrorReader;
        match CharReader::new().next_char_from(&mut r).unwrap_err() {
            super::CharReadError::Io(ref e) if e.kind() == io::ErrorKind::Other &&
                e.to_string().contains("test error") => {},
            e => panic!("Unexpected result: {e:?}")
        }
    }
}
306