rustpython_common/
encodings.rs

1use std::ops::Range;
2
3pub type EncodeErrorResult<S, B, E> = Result<(EncodeReplace<S, B>, usize), E>;
4
/// Outcome of [`ErrorHandler::handle_decode_error`].
///
/// The success tuple is `(replacement text, optional replacement input, restart index)`: the
/// decoder in this file appends the text to its output, switches to the new input buffer when
/// one is supplied, and continues decoding at the restart byte index.
pub type DecodeErrorResult<Text, Raw, Failure> = Result<(Text, Option<Raw>, usize), Failure>;
6
/// String-like buffer that an [`ErrorHandler`] can hand back as replacement text.
pub trait StrBuffer: AsRef<str> {
    /// Reports whether the buffer contains only ASCII characters.
    ///
    /// The default implementation forwards to [`str::is_ascii`] on the borrowed text;
    /// implementors may override it.
    fn is_ascii(&self) -> bool {
        let text: &str = self.as_ref();
        text.is_ascii()
    }
}
12
13pub trait ErrorHandler {
14    type Error;
15    type StrBuf: StrBuffer;
16    type BytesBuf: AsRef<[u8]>;
17    fn handle_encode_error(
18        &self,
19        data: &str,
20        char_range: Range<usize>,
21        reason: &str,
22    ) -> EncodeErrorResult<Self::StrBuf, Self::BytesBuf, Self::Error>;
23    fn handle_decode_error(
24        &self,
25        data: &[u8],
26        byte_range: Range<usize>,
27        reason: &str,
28    ) -> DecodeErrorResult<Self::StrBuf, Self::BytesBuf, Self::Error>;
29    fn error_oob_restart(&self, i: usize) -> Self::Error;
30    fn error_encoding(&self, data: &str, char_range: Range<usize>, reason: &str) -> Self::Error;
31}
/// Replacement supplied by an encode error handler.
pub enum EncodeReplace<Text, Raw> {
    /// Replacement text; the encoders in this file validate it against the target encoding
    /// before writing it out.
    Str(Text),
    /// Replacement bytes; the encoders in this file copy them to the output unchanged.
    Bytes(Raw),
}
36
/// Description of a failed decode attempt over a byte slice.
struct DecodeError<'src> {
    /// Leading part of the input that decoded successfully.
    valid_prefix: &'src str,
    /// The input from the first undecodable byte onwards.
    rest: &'src [u8],
    /// Length in bytes of the offending sequence, if known. When it is `None` the error range
    /// reported to the handler runs to the end of the input.
    err_len: Option<usize>,
}
42/// # Safety
43/// `v[..valid_up_to]` must be valid utf8
44unsafe fn make_decode_err(v: &[u8], valid_up_to: usize, err_len: Option<usize>) -> DecodeError<'_> {
45    let valid_prefix = core::str::from_utf8_unchecked(v.get_unchecked(..valid_up_to));
46    let rest = v.get_unchecked(valid_up_to..);
47    DecodeError {
48        valid_prefix,
49        rest,
50        err_len,
51    }
52}
53
/// Verdict of a codec-specific classifier about a decode failure.
enum HandleResult<'msg> {
    /// Stop decoding here without reporting an error; the undecoded tail is left unconsumed.
    Done,
    /// Report the failure to the [`ErrorHandler`].
    Error {
        /// Length in bytes of the invalid sequence; `None` makes the error range extend to the
        /// end of the input.
        err_len: Option<usize>,
        /// Message passed on to the error handler.
        reason: &'msg str,
    },
}
61fn decode_utf8_compatible<E: ErrorHandler, DecodeF, ErrF>(
62    data: &[u8],
63    errors: &E,
64    decode: DecodeF,
65    handle_error: ErrF,
66) -> Result<(String, usize), E::Error>
67where
68    DecodeF: Fn(&[u8]) -> Result<&str, DecodeError<'_>>,
69    ErrF: Fn(&[u8], Option<usize>) -> HandleResult<'_>,
70{
71    if data.is_empty() {
72        return Ok((String::new(), 0));
73    }
74    // we need to coerce the lifetime to that of the function body rather than the
75    // anonymous input lifetime, so that we can assign it data borrowed from data_from_err
76    let mut data = data;
77    let mut data_from_err: E::BytesBuf;
78    let mut out = String::with_capacity(data.len());
79    let mut remaining_index = 0;
80    let mut remaining_data = data;
81    loop {
82        match decode(remaining_data) {
83            Ok(decoded) => {
84                out.push_str(decoded);
85                remaining_index += decoded.len();
86                break;
87            }
88            Err(e) => {
89                out.push_str(e.valid_prefix);
90                match handle_error(e.rest, e.err_len) {
91                    HandleResult::Done => {
92                        remaining_index += e.valid_prefix.len();
93                        break;
94                    }
95                    HandleResult::Error { err_len, reason } => {
96                        let err_idx = remaining_index + e.valid_prefix.len();
97                        let err_range =
98                            err_idx..err_len.map_or_else(|| data.len(), |len| err_idx + len);
99                        let (replace, new_data, restart) =
100                            errors.handle_decode_error(data, err_range, reason)?;
101                        out.push_str(replace.as_ref());
102                        if let Some(new_data) = new_data {
103                            data_from_err = new_data;
104                            data = data_from_err.as_ref();
105                        }
106                        remaining_data = data
107                            .get(restart..)
108                            .ok_or_else(|| errors.error_oob_restart(restart))?;
109                        remaining_index = restart;
110                        continue;
111                    }
112                }
113            }
114        }
115    }
116    Ok((out, remaining_index))
117}
118
119pub mod utf8 {
120    use super::*;
121
122    pub const ENCODING_NAME: &str = "utf-8";
123
124    #[inline]
125    pub fn encode<E: ErrorHandler>(s: &str, _errors: &E) -> Result<Vec<u8>, E::Error> {
126        Ok(s.as_bytes().to_vec())
127    }
128
129    pub fn decode<E: ErrorHandler>(
130        data: &[u8],
131        errors: &E,
132        final_decode: bool,
133    ) -> Result<(String, usize), E::Error> {
134        decode_utf8_compatible(
135            data,
136            errors,
137            |v| {
138                core::str::from_utf8(v).map_err(|e| {
139                    // SAFETY: as specified in valid_up_to's documentation, input[..e.valid_up_to()]
140                    //         is valid utf8
141                    unsafe { make_decode_err(v, e.valid_up_to(), e.error_len()) }
142                })
143            },
144            |rest, err_len| {
145                let first_err = rest[0];
146                if matches!(first_err, 0x80..=0xc1 | 0xf5..=0xff) {
147                    HandleResult::Error {
148                        err_len: Some(1),
149                        reason: "invalid start byte",
150                    }
151                } else if err_len.is_none() {
152                    // error_len() == None means unexpected eof
153                    if final_decode {
154                        HandleResult::Error {
155                            err_len,
156                            reason: "unexpected end of data",
157                        }
158                    } else {
159                        HandleResult::Done
160                    }
161                } else if !final_decode && matches!(rest, [0xed, 0xa0..=0xbf]) {
162                    // truncated surrogate
163                    HandleResult::Done
164                } else {
165                    HandleResult::Error {
166                        err_len,
167                        reason: "invalid continuation byte",
168                    }
169                }
170            },
171        )
172    }
173}
174
175pub mod latin_1 {
176    use super::*;
177
178    pub const ENCODING_NAME: &str = "latin-1";
179
180    const ERR_REASON: &str = "ordinal not in range(256)";
181
182    #[inline]
183    pub fn encode<E: ErrorHandler>(s: &str, errors: &E) -> Result<Vec<u8>, E::Error> {
184        let full_data = s;
185        let mut data = s;
186        let mut char_data_index = 0;
187        let mut out = Vec::<u8>::new();
188        loop {
189            match data
190                .char_indices()
191                .enumerate()
192                .find(|(_, (_, c))| !c.is_ascii())
193            {
194                None => {
195                    out.extend_from_slice(data.as_bytes());
196                    break;
197                }
198                Some((char_i, (byte_i, ch))) => {
199                    out.extend_from_slice(&data.as_bytes()[..byte_i]);
200                    let char_start = char_data_index + char_i;
201                    if (ch as u32) <= 255 {
202                        out.push(ch as u8);
203                        let char_restart = char_start + 1;
204                        data = crate::str::try_get_chars(full_data, char_restart..)
205                            .ok_or_else(|| errors.error_oob_restart(char_restart))?;
206                        char_data_index = char_restart;
207                    } else {
208                        // number of non-latin_1 chars between the first non-latin_1 char and the next latin_1 char
209                        let non_latin_1_run_length = data[byte_i..]
210                            .chars()
211                            .take_while(|c| (*c as u32) > 255)
212                            .count();
213                        let char_range = char_start..char_start + non_latin_1_run_length;
214                        let (replace, char_restart) = errors.handle_encode_error(
215                            full_data,
216                            char_range.clone(),
217                            ERR_REASON,
218                        )?;
219                        match replace {
220                            EncodeReplace::Str(s) => {
221                                if s.as_ref().chars().any(|c| (c as u32) > 255) {
222                                    return Err(
223                                        errors.error_encoding(full_data, char_range, ERR_REASON)
224                                    );
225                                }
226                                out.extend_from_slice(s.as_ref().as_bytes());
227                            }
228                            EncodeReplace::Bytes(b) => {
229                                out.extend_from_slice(b.as_ref());
230                            }
231                        }
232                        data = crate::str::try_get_chars(full_data, char_restart..)
233                            .ok_or_else(|| errors.error_oob_restart(char_restart))?;
234                        char_data_index = char_restart;
235                    }
236                    continue;
237                }
238            }
239        }
240        Ok(out)
241    }
242
243    pub fn decode<E: ErrorHandler>(data: &[u8], _errors: &E) -> Result<(String, usize), E::Error> {
244        let out: String = data.iter().map(|c| *c as char).collect();
245        let out_len = out.len();
246        Ok((out, out_len))
247    }
248}
249
250pub mod ascii {
251    use super::*;
252    use ::ascii::AsciiStr;
253
254    pub const ENCODING_NAME: &str = "ascii";
255
256    const ERR_REASON: &str = "ordinal not in range(128)";
257
258    #[inline]
259    pub fn encode<E: ErrorHandler>(s: &str, errors: &E) -> Result<Vec<u8>, E::Error> {
260        let full_data = s;
261        let mut data = s;
262        let mut char_data_index = 0;
263        let mut out = Vec::<u8>::new();
264        loop {
265            match data
266                .char_indices()
267                .enumerate()
268                .find(|(_, (_, c))| !c.is_ascii())
269            {
270                None => {
271                    out.extend_from_slice(data.as_bytes());
272                    break;
273                }
274                Some((char_i, (byte_i, _))) => {
275                    out.extend_from_slice(&data.as_bytes()[..byte_i]);
276                    let char_start = char_data_index + char_i;
277                    // number of non-ascii chars between the first non-ascii char and the next ascii char
278                    let non_ascii_run_length =
279                        data[byte_i..].chars().take_while(|c| !c.is_ascii()).count();
280                    let char_range = char_start..char_start + non_ascii_run_length;
281                    let (replace, char_restart) =
282                        errors.handle_encode_error(full_data, char_range.clone(), ERR_REASON)?;
283                    match replace {
284                        EncodeReplace::Str(s) => {
285                            if !s.is_ascii() {
286                                return Err(
287                                    errors.error_encoding(full_data, char_range, ERR_REASON)
288                                );
289                            }
290                            out.extend_from_slice(s.as_ref().as_bytes());
291                        }
292                        EncodeReplace::Bytes(b) => {
293                            out.extend_from_slice(b.as_ref());
294                        }
295                    }
296                    data = crate::str::try_get_chars(full_data, char_restart..)
297                        .ok_or_else(|| errors.error_oob_restart(char_restart))?;
298                    char_data_index = char_restart;
299                    continue;
300                }
301            }
302        }
303        Ok(out)
304    }
305
306    pub fn decode<E: ErrorHandler>(data: &[u8], errors: &E) -> Result<(String, usize), E::Error> {
307        decode_utf8_compatible(
308            data,
309            errors,
310            |v| {
311                AsciiStr::from_ascii(v).map(|s| s.as_str()).map_err(|e| {
312                    // SAFETY: as specified in valid_up_to's documentation, input[..e.valid_up_to()]
313                    //         is valid ascii & therefore valid utf8
314                    unsafe { make_decode_err(v, e.valid_up_to(), Some(1)) }
315                })
316            },
317            |_rest, err_len| HandleResult::Error {
318                err_len,
319                reason: ERR_REASON,
320            },
321        )
322    }
323}