xml_no_std/
util.rs

1extern crate alloc;
2
3use alloc::string::ToString;
4
5use core::fmt;
6use core::str::{self, FromStr};
7
/// Reasons why reading the next character from a byte source can fail.
#[derive(Debug)]
pub enum CharReadError {
    /// The input ended in the middle of a character.
    UnexpectedEof,
    /// The buffered bytes are not valid UTF-8.
    Utf8(str::Utf8Error),
    /// Any other decoding failure, described by a human-readable message.
    Io(alloc::string::String),
}
14
15impl From<str::Utf8Error> for CharReadError {
16    #[cold]
17    fn from(e: str::Utf8Error) -> CharReadError {
18        CharReadError::Utf8(e)
19    }
20}
21
22
23impl fmt::Display for CharReadError {
24    #[cold]
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        use self::CharReadError::{Io, UnexpectedEof, Utf8};
27        match *self {
28            UnexpectedEof => write!(f, "unexpected end of stream"),
29            Utf8(ref e) => write!(f, "UTF-8 decoding error: {e}"),
30            Io(ref e) => write!(f, "I/O error: {e}"),
31        }
32    }
33}
34
/// Character encoding used while decoding the input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Encoding {
    /// UTF-8, explicitly selected.
    Utf8,
    /// UTF-8 used as a fallback; may still be replaced by another 8-bit encoding.
    Default,
    /// ISO-8859-1.
    Latin1,
    /// US-ASCII.
    Ascii,
    /// UTF-16, big-endian.
    Utf16Be,
    /// UTF-16, little-endian.
    Utf16Le,
    /// UTF-16 whose endianness is not known yet and will be sniffed.
    Utf16,
    /// Nothing determined yet; sniffing may pick any encoding.
    Unknown,
}
56
/// Case-insensitive (ASCII only) equality of an encoding label.
///
/// `lower` must already be lowercase; only `varcase` is case-folded.
/// The two strings must have the same length: `zip` alone stops at the
/// shorter input, which would make empty, truncated or over-long names
/// (`""`, `"utf"`, `"utf-8-sig"`) compare equal to `"utf-8"`.
// Rustc inlines eq_ignore_ascii_case and creates kilobytes of code!
#[inline(never)]
fn icmp(lower: &str, varcase: &str) -> bool {
    lower.len() == varcase.len()
        && lower.bytes().zip(varcase.bytes()).all(|(l, v)| l == v.to_ascii_lowercase())
}
62
63impl FromStr for Encoding {
64    type Err = &'static str;
65
66    fn from_str(val: &str) -> Result<Self, Self::Err> {
67        if ["utf-8", "utf8"].into_iter().any(move |label| icmp(label, val)) {
68            Ok(Encoding::Utf8)
69        } else if ["iso-8859-1", "latin1"].into_iter().any(move |label| icmp(label, val)) {
70            Ok(Encoding::Latin1)
71        } else if ["utf-16", "utf16"].into_iter().any(move |label| icmp(label, val)) {
72            Ok(Encoding::Utf16)
73        } else if ["ascii", "us-ascii"].into_iter().any(move |label| icmp(label, val)) {
74            Ok(Encoding::Ascii)
75        } else {
76            Err("unknown encoding name")
77        }
78    }
79}
80
81impl fmt::Display for Encoding {
82    #[cold]
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.write_str(match self {
85            Encoding::Utf8 |
86            Encoding::Default => "UTF-8",
87            Encoding::Latin1 => "ISO-8859-1",
88            Encoding::Ascii => "US-ASCII",
89            Encoding::Utf16Be |
90            Encoding::Utf16Le |
91            Encoding::Utf16 => "UTF-16",
92            Encoding::Unknown => "(unknown)",
93        })
94    }
95}
96
97pub(crate) struct CharReader {
98    pub encoding: Encoding,
99}
100
101impl CharReader {
102    pub const fn new() -> Self {
103        Self {
104            encoding: Encoding::Unknown,
105        }
106    }
107
108    pub fn next_char_from<'a, S: Iterator<Item = &'a u8>>(&mut self, source: &mut S) -> Result<Option<char>, CharReadError> {
109        const MAX_CODEPOINT_LEN: usize = 4;
110
111        let mut buf = [0u8; MAX_CODEPOINT_LEN];
112        let mut pos = 0;
113        while pos < MAX_CODEPOINT_LEN {
114            let next = match source.next() {
115                Some(b) => *b,
116                None if pos == 0 => return Ok(None),
117                None => return Err(CharReadError::UnexpectedEof),
118            };
119
120            match self.encoding {
121                Encoding::Utf8 | Encoding::Default => {
122                    // fast path for ASCII subset
123                    if pos == 0 && next.is_ascii() {
124                        return Ok(Some(next.into()));
125                    }
126
127                    buf[pos] = next;
128                    pos += 1;
129
130                    match str::from_utf8(&buf[..pos]) {
131                        Ok(s) => return Ok(s.chars().next()), // always Some(..)
132                        Err(_) if pos < MAX_CODEPOINT_LEN => continue,
133                        Err(e) => return Err(e.into()),
134                    }
135                },
136                Encoding::Latin1 => {
137                    return Ok(Some(next.into()));
138                },
139                Encoding::Ascii => {
140                    return if next.is_ascii() {
141                        Ok(Some(next.into()))
142                    } else {
143                        Err(CharReadError::Io("char is not ASCII".to_string()))
144                    }
145                },
146                Encoding::Unknown | Encoding::Utf16 => {
147                    buf[pos] = next;
148                    pos += 1;
149                    if let Some(value) = self.sniff_bom(&buf[..pos], &mut pos) {
150                        return value;
151                    }
152                },
153                Encoding::Utf16Be => {
154                    buf[pos] = next;
155                    pos += 1;
156                    if pos == 2 {
157                        if let Some(Ok(c)) = char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap())]).next() {
158                            return Ok(Some(c));
159                        }
160                    } else if pos == 4 { // surrogate
161                        return char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap()), u16::from_be_bytes(buf[2..4].try_into().unwrap())])
162                            .next().transpose()
163                            .map_err(|e| CharReadError::Io(alloc::format!("Invalid data: {e:?}")));
164                    }
165                },
166                Encoding::Utf16Le => {
167                    buf[pos] = next;
168                    pos += 1;
169                    if pos == 2 {
170                        if let Some(Ok(c)) = char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap())]).next() {
171                            return Ok(Some(c));
172                        }
173                    } else if pos == 4 { // surrogate
174                        return char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap()), u16::from_le_bytes(buf[2..4].try_into().unwrap())])
175                            .next().transpose()
176                            .map_err(|e| CharReadError::Io(alloc::format!("Invalid data: {e:?}")));
177                    }
178                },
179            }
180        }
181        Err(CharReadError::Io(alloc::string::String::from("InvalidData")))
182    }
183
184    #[cold]
185    fn sniff_bom(&mut self, buf: &[u8], pos: &mut usize) -> Option<Result<Option<char>, CharReadError>> {
186        // sniff BOM
187        if buf.len() <= 3 && [0xEF, 0xBB, 0xBF].starts_with(buf) {
188            if buf.len() == 3 && self.encoding != Encoding::Utf16 {
189                *pos = 0;
190                self.encoding = Encoding::Utf8;
191            }
192        } else if buf.len() <= 2 && [0xFE, 0xFF].starts_with(buf) {
193            if buf.len() == 2 {
194                *pos = 0;
195                self.encoding = Encoding::Utf16Be;
196            }
197        } else if buf.len() <= 2 && [0xFF, 0xFE].starts_with(buf) {
198            if buf.len() == 2 {
199                *pos = 0;
200                self.encoding = Encoding::Utf16Le;
201            }
202        } else if buf.len() == 1 && self.encoding == Encoding::Utf16 {
203            // sniff ASCII char in UTF-16
204            self.encoding = if buf[0] == 0 { Encoding::Utf16Be } else { Encoding::Utf16Le };
205        } else {
206            // UTF-8 is the default, but XML decl can change it to other 8-bit encoding
207            self.encoding = Encoding::Default;
208            if buf.len() == 1 && buf[0].is_ascii() {
209                return Some(Ok(Some(buf[0].into())));
210            }
211        }
212        None
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::{CharReadError, CharReader, Encoding};

    /// Decodes the first character of `bytes` with a reader that still has to sniff the encoding.
    fn sniffed(bytes: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader::new().next_char_from(&mut bytes.iter())
    }

    /// Decodes the first character of `bytes` with a reader that starts out in `encoding`.
    fn decoded_as(encoding: Encoding, bytes: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader { encoding }.next_char_from(&mut bytes.iter())
    }

    // The source is a plain byte iterator, which cannot fail, so there is no
    // "error during read" case to exercise here.
    #[test]
    fn test_next_char_from() {
        // correct ASCII
        assert_eq!(sniffed("correct".as_bytes()).unwrap(), Some('c'));

        // UTF-8 BOM is skipped
        assert_eq!(sniffed(b"\xEF\xBB\xBF\xE2\x80\xA2!").unwrap(), Some('\u{2022}'));
        assert_eq!(sniffed(b"\xEF\xBB\xBFx123").unwrap(), Some('x'));

        // nothing after BOM
        assert_eq!(sniffed(b"\xEF\xBB\xBF").unwrap(), None);

        // truncated BOM
        assert!(matches!(sniffed(b"\xEF\xBB"), Err(CharReadError::UnexpectedEof)));

        // BOM prefix followed by a byte that makes the sequence invalid
        assert!(sniffed(b"\xEF\xBB\x42").is_err());

        // UTF-16 BOMs, both byte orders
        assert_eq!(sniffed(b"\xFE\xFF\x00\x42").unwrap(), Some('B'));
        assert_eq!(sniffed(b"\xFF\xFE\x42\x00").unwrap(), Some('B'));

        // nothing after UTF-16 BOM
        assert_eq!(sniffed(b"\xFF\xFE").unwrap(), None);

        // half a UTF-16 unit after the BOM
        assert!(matches!(sniffed(b"\xFF\xFE\x00"), Err(CharReadError::UnexpectedEof)));

        // correct BMP
        assert_eq!(sniffed("правильно".as_bytes()).unwrap(), Some('п'));

        // the same bytes (D0 BF ...) read as UTF-16 give a different character
        assert_eq!(decoded_as(Encoding::Utf16Be, "правильно".as_bytes()).unwrap(), Some('\u{D0BF}'));
        assert_eq!(decoded_as(Encoding::Utf16Le, "правильно".as_bytes()).unwrap(), Some('\u{BFD0}'));

        // lone lead surrogate followed by a truncated unit
        assert!(decoded_as(Encoding::Utf16, b"\xD8\xD8\x80").is_err());

        // endianness sniffed from an ASCII character
        assert_eq!(decoded_as(Encoding::Utf16, b"\x00\x42").unwrap(), Some('B'));
        assert_eq!(decoded_as(Encoding::Utf16, b"\x42\x00").unwrap(), Some('B'));

        // half a UTF-16 unit
        assert!(decoded_as(Encoding::Utf16Be, b"\x00").is_err());

        // correct non-BMP
        assert_eq!(sniffed("😊".as_bytes()).unwrap(), Some('😊'));

        // empty
        assert_eq!(sniffed(b"").unwrap(), None);

        // incomplete code point
        match sniffed(b"\xf0\x9f\x98").unwrap_err() {
            CharReadError::UnexpectedEof => {},
            e => panic!("Unexpected result: {e:?}"),
        };

        // invalid code point
        match sniffed(b"\xff\x9f\x98\x32").unwrap_err() {
            CharReadError::Utf8(_) => {},
            e => panic!("Unexpected result: {e:?}"),
        };
    }
}