// rustpython_vm/codecs.rs

1use crate::{
2    builtins::{PyBaseExceptionRef, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef},
3    common::{ascii, lock::PyRwLock},
4    convert::ToPyObject,
5    function::PyMethodDef,
6    AsObject, Context, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine,
7};
8use std::{borrow::Cow, collections::HashMap, fmt::Write, ops::Range};
9
/// Registry of codec search functions, looked-up codecs, and named error handlers.
pub struct CodecsRegistry {
    // All mutable state lives behind a read/write lock so methods can take `&self`.
    inner: PyRwLock<RegistryInner>,
}

/// Lock-protected state of a [`CodecsRegistry`].
struct RegistryInner {
    /// Search functions in registration order; `lookup` calls them in this order.
    search_path: Vec<PyObjectRef>,
    /// Cached codecs keyed by encoding name. `lookup`/`forget` use the normalized name
    /// (see `normalize_encoding_name`); `register_manual` inserts the name as given.
    search_cache: HashMap<String, PyCodec>,
    /// Error handlers keyed by name (e.g. "strict", "ignore").
    errors: HashMap<String, PyObjectRef>,
}

/// Name of the default text encoding.
pub const DEFAULT_ENCODING: &str = "utf-8";
21
22#[derive(Clone)]
23#[repr(transparent)]
24pub struct PyCodec(PyTupleRef);
25impl PyCodec {
26    #[inline]
27    pub fn from_tuple(tuple: PyTupleRef) -> Result<Self, PyTupleRef> {
28        if tuple.len() == 4 {
29            Ok(PyCodec(tuple))
30        } else {
31            Err(tuple)
32        }
33    }
34    #[inline]
35    pub fn into_tuple(self) -> PyTupleRef {
36        self.0
37    }
38    #[inline]
39    pub fn as_tuple(&self) -> &PyTupleRef {
40        &self.0
41    }
42
43    #[inline]
44    pub fn get_encode_func(&self) -> &PyObject {
45        &self.0[0]
46    }
47    #[inline]
48    pub fn get_decode_func(&self) -> &PyObject {
49        &self.0[1]
50    }
51
52    pub fn is_text_codec(&self, vm: &VirtualMachine) -> PyResult<bool> {
53        let is_text = vm.get_attribute_opt(self.0.clone().into(), "_is_text_encoding")?;
54        is_text.map_or(Ok(true), |is_text| is_text.try_to_bool(vm))
55    }
56
57    pub fn encode(
58        &self,
59        obj: PyObjectRef,
60        errors: Option<PyStrRef>,
61        vm: &VirtualMachine,
62    ) -> PyResult {
63        let args = match errors {
64            Some(errors) => vec![obj, errors.into()],
65            None => vec![obj],
66        };
67        let res = self.get_encode_func().call(args, vm)?;
68        let res = res
69            .downcast::<PyTuple>()
70            .ok()
71            .filter(|tuple| tuple.len() == 2)
72            .ok_or_else(|| {
73                vm.new_type_error("encoder must return a tuple (object, integer)".to_owned())
74            })?;
75        // we don't actually care about the integer
76        Ok(res[0].clone())
77    }
78
79    pub fn decode(
80        &self,
81        obj: PyObjectRef,
82        errors: Option<PyStrRef>,
83        vm: &VirtualMachine,
84    ) -> PyResult {
85        let args = match errors {
86            Some(errors) => vec![obj, errors.into()],
87            None => vec![obj],
88        };
89        let res = self.get_decode_func().call(args, vm)?;
90        let res = res
91            .downcast::<PyTuple>()
92            .ok()
93            .filter(|tuple| tuple.len() == 2)
94            .ok_or_else(|| {
95                vm.new_type_error("decoder must return a tuple (object,integer)".to_owned())
96            })?;
97        // we don't actually care about the integer
98        Ok(res[0].clone())
99    }
100
101    pub fn get_incremental_encoder(
102        &self,
103        errors: Option<PyStrRef>,
104        vm: &VirtualMachine,
105    ) -> PyResult {
106        let args = match errors {
107            Some(e) => vec![e.into()],
108            None => vec![],
109        };
110        vm.call_method(self.0.as_object(), "incrementalencoder", args)
111    }
112
113    pub fn get_incremental_decoder(
114        &self,
115        errors: Option<PyStrRef>,
116        vm: &VirtualMachine,
117    ) -> PyResult {
118        let args = match errors {
119            Some(e) => vec![e.into()],
120            None => vec![],
121        };
122        vm.call_method(self.0.as_object(), "incrementaldecoder", args)
123    }
124}
125
126impl TryFromObject for PyCodec {
127    fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
128        obj.downcast::<PyTuple>()
129            .ok()
130            .and_then(|tuple| PyCodec::from_tuple(tuple).ok())
131            .ok_or_else(|| {
132                vm.new_type_error("codec search functions must return 4-tuples".to_owned())
133            })
134    }
135}
136
137impl ToPyObject for PyCodec {
138    #[inline]
139    fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef {
140        self.0.into()
141    }
142}
143
144impl CodecsRegistry {
145    pub(crate) fn new(ctx: &Context) -> Self {
146        ::rustpython_vm::common::static_cell! {
147            static METHODS: Box<[PyMethodDef]>;
148        }
149
150        let methods = METHODS.get_or_init(|| {
151            crate::define_methods![
152                "strict_errors" => strict_errors as EMPTY,
153                "ignore_errors" => ignore_errors as EMPTY,
154                "replace_errors" => replace_errors as EMPTY,
155                "xmlcharrefreplace_errors" => xmlcharrefreplace_errors as EMPTY,
156                "backslashreplace_errors" => backslashreplace_errors as EMPTY,
157                "namereplace_errors" => namereplace_errors as EMPTY,
158                "surrogatepass_errors" => surrogatepass_errors as EMPTY,
159                "surrogateescape_errors" => surrogateescape_errors as EMPTY
160            ]
161            .into_boxed_slice()
162        });
163
164        let errors = [
165            ("strict", methods[0].build_function(ctx)),
166            ("ignore", methods[1].build_function(ctx)),
167            ("replace", methods[2].build_function(ctx)),
168            ("xmlcharrefreplace", methods[3].build_function(ctx)),
169            ("backslashreplace", methods[4].build_function(ctx)),
170            ("namereplace", methods[5].build_function(ctx)),
171            ("surrogatepass", methods[6].build_function(ctx)),
172            ("surrogateescape", methods[7].build_function(ctx)),
173        ];
174        let errors = errors
175            .into_iter()
176            .map(|(name, f)| (name.to_owned(), f.into()))
177            .collect();
178        let inner = RegistryInner {
179            search_path: Vec::new(),
180            search_cache: HashMap::new(),
181            errors,
182        };
183        CodecsRegistry {
184            inner: PyRwLock::new(inner),
185        }
186    }
187
188    pub fn register(&self, search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
189        if !search_function.is_callable() {
190            return Err(vm.new_type_error("argument must be callable".to_owned()));
191        }
192        self.inner.write().search_path.push(search_function);
193        Ok(())
194    }
195
196    pub fn unregister(&self, search_function: PyObjectRef) -> PyResult<()> {
197        let mut inner = self.inner.write();
198        // Do nothing if search_path is not created yet or was cleared.
199        if inner.search_path.is_empty() {
200            return Ok(());
201        }
202        for (i, item) in inner.search_path.iter().enumerate() {
203            if item.get_id() == search_function.get_id() {
204                if !inner.search_cache.is_empty() {
205                    inner.search_cache.clear();
206                }
207                inner.search_path.remove(i);
208                return Ok(());
209            }
210        }
211        Ok(())
212    }
213
214    pub(crate) fn register_manual(&self, name: &str, codec: PyCodec) -> PyResult<()> {
215        self.inner
216            .write()
217            .search_cache
218            .insert(name.to_owned(), codec);
219        Ok(())
220    }
221
222    pub fn lookup(&self, encoding: &str, vm: &VirtualMachine) -> PyResult<PyCodec> {
223        let encoding = normalize_encoding_name(encoding);
224        let search_path = {
225            let inner = self.inner.read();
226            if let Some(codec) = inner.search_cache.get(encoding.as_ref()) {
227                // hit cache
228                return Ok(codec.clone());
229            }
230            inner.search_path.clone()
231        };
232        let encoding = PyStr::from(encoding.into_owned()).into_ref(&vm.ctx);
233        for func in search_path {
234            let res = func.call((encoding.clone(),), vm)?;
235            let res: Option<PyCodec> = res.try_into_value(vm)?;
236            if let Some(codec) = res {
237                let mut inner = self.inner.write();
238                // someone might have raced us to this, so use theirs
239                let codec = inner
240                    .search_cache
241                    .entry(encoding.as_str().to_owned())
242                    .or_insert(codec);
243                return Ok(codec.clone());
244            }
245        }
246        Err(vm.new_lookup_error(format!("unknown encoding: {encoding}")))
247    }
248
249    fn _lookup_text_encoding(
250        &self,
251        encoding: &str,
252        generic_func: &str,
253        vm: &VirtualMachine,
254    ) -> PyResult<PyCodec> {
255        let codec = self.lookup(encoding, vm)?;
256        if codec.is_text_codec(vm)? {
257            Ok(codec)
258        } else {
259            Err(vm.new_lookup_error(format!(
260                "'{encoding}' is not a text encoding; use {generic_func} to handle arbitrary codecs"
261            )))
262        }
263    }
264
265    pub fn forget(&self, encoding: &str) -> Option<PyCodec> {
266        let encoding = normalize_encoding_name(encoding);
267        self.inner.write().search_cache.remove(encoding.as_ref())
268    }
269
270    pub fn encode(
271        &self,
272        obj: PyObjectRef,
273        encoding: &str,
274        errors: Option<PyStrRef>,
275        vm: &VirtualMachine,
276    ) -> PyResult {
277        let codec = self.lookup(encoding, vm)?;
278        codec.encode(obj, errors, vm)
279    }
280
281    pub fn decode(
282        &self,
283        obj: PyObjectRef,
284        encoding: &str,
285        errors: Option<PyStrRef>,
286        vm: &VirtualMachine,
287    ) -> PyResult {
288        let codec = self.lookup(encoding, vm)?;
289        codec.decode(obj, errors, vm)
290    }
291
292    pub fn encode_text(
293        &self,
294        obj: PyStrRef,
295        encoding: &str,
296        errors: Option<PyStrRef>,
297        vm: &VirtualMachine,
298    ) -> PyResult<PyBytesRef> {
299        let codec = self._lookup_text_encoding(encoding, "codecs.encode()", vm)?;
300        codec
301            .encode(obj.into(), errors, vm)?
302            .downcast()
303            .map_err(|obj| {
304                vm.new_type_error(format!(
305                    "'{}' encoder returned '{}' instead of 'bytes'; use codecs.encode() to \
306                     encode arbitrary types",
307                    encoding,
308                    obj.class().name(),
309                ))
310            })
311    }
312
313    pub fn decode_text(
314        &self,
315        obj: PyObjectRef,
316        encoding: &str,
317        errors: Option<PyStrRef>,
318        vm: &VirtualMachine,
319    ) -> PyResult<PyStrRef> {
320        let codec = self._lookup_text_encoding(encoding, "codecs.decode()", vm)?;
321        codec.decode(obj, errors, vm)?.downcast().map_err(|obj| {
322            vm.new_type_error(format!(
323                "'{}' decoder returned '{}' instead of 'str'; use codecs.decode() \
324                 to encode arbitrary types",
325                encoding,
326                obj.class().name(),
327            ))
328        })
329    }
330
331    pub fn register_error(&self, name: String, handler: PyObjectRef) -> Option<PyObjectRef> {
332        self.inner.write().errors.insert(name, handler)
333    }
334
335    pub fn lookup_error_opt(&self, name: &str) -> Option<PyObjectRef> {
336        self.inner.read().errors.get(name).cloned()
337    }
338
339    pub fn lookup_error(&self, name: &str, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
340        self.lookup_error_opt(name)
341            .ok_or_else(|| vm.new_lookup_error(format!("unknown error handler name '{name}'")))
342    }
343}
344
/// Normalizes an encoding name for cache lookups: ASCII letters are lowercased and spaces
/// become `-`. Borrows the input unchanged when nothing needs rewriting.
fn normalize_encoding_name(encoding: &str) -> Cow<'_, str> {
    let needs_rewrite = |c: char| c == ' ' || c.is_ascii_uppercase();
    match encoding.find(needs_rewrite) {
        None => Cow::Borrowed(encoding),
        Some(first) => {
            // Everything before `first` is already normalized; copy it verbatim.
            let (head, tail) = encoding.split_at(first);
            let mut normalized = String::with_capacity(encoding.len());
            normalized.push_str(head);
            normalized.extend(tail.chars().map(|c| {
                if c == ' ' {
                    '-'
                } else {
                    c.to_ascii_lowercase()
                }
            }));
            Cow::Owned(normalized)
        }
    }
}
360
361// TODO: exceptions with custom payloads
362fn extract_unicode_error_range(err: &PyObject, vm: &VirtualMachine) -> PyResult<Range<usize>> {
363    let start = err.get_attr("start", vm)?;
364    let start = start.try_into_value(vm)?;
365    let end = err.get_attr("end", vm)?;
366    let end = end.try_into_value(vm)?;
367    Ok(Range { start, end })
368}
369
370#[inline]
371fn is_decode_err(err: &PyObject, vm: &VirtualMachine) -> bool {
372    err.fast_isinstance(vm.ctx.exceptions.unicode_decode_error)
373}
374#[inline]
375fn is_encode_ish_err(err: &PyObject, vm: &VirtualMachine) -> bool {
376    err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error)
377        || err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error)
378}
379
380fn bad_err_type(err: PyObjectRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
381    vm.new_type_error(format!(
382        "don't know how to handle {} in error callback",
383        err.class().name()
384    ))
385}
386
387fn strict_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult {
388    let err = err
389        .downcast()
390        .unwrap_or_else(|_| vm.new_type_error("codec must pass exception instance".to_owned()));
391    Err(err)
392}
393
394fn ignore_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
395    if is_encode_ish_err(&err, vm) || is_decode_err(&err, vm) {
396        let range = extract_unicode_error_range(&err, vm)?;
397        Ok((vm.ctx.new_str(ascii!("")).into(), range.end))
398    } else {
399        Err(bad_err_type(err, vm))
400    }
401}
402
403fn replace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> {
404    // char::REPLACEMENT_CHARACTER as a str
405    let replacement_char = "\u{FFFD}";
406    let replace = if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) {
407        "?"
408    } else if err.fast_isinstance(vm.ctx.exceptions.unicode_decode_error) {
409        let range = extract_unicode_error_range(&err, vm)?;
410        return Ok((replacement_char.to_owned(), range.end));
411    } else if err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error) {
412        replacement_char
413    } else {
414        return Err(bad_err_type(err, vm));
415    };
416    let range = extract_unicode_error_range(&err, vm)?;
417    let replace = replace.repeat(range.end - range.start);
418    Ok((replace, range.end))
419}
420
421fn xmlcharrefreplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> {
422    if !is_encode_ish_err(&err, vm) {
423        return Err(bad_err_type(err, vm));
424    }
425    let range = extract_unicode_error_range(&err, vm)?;
426    let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
427    let s_after_start = crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or("");
428    let num_chars = range.len();
429    // capacity rough guess; assuming that the codepoints are 3 digits in decimal + the &#;
430    let mut out = String::with_capacity(num_chars * 6);
431    for c in s_after_start.chars().take(num_chars) {
432        write!(out, "&#{};", c as u32).unwrap()
433    }
434    Ok((out, range.end))
435}
436
437fn backslashreplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> {
438    if is_decode_err(&err, vm) {
439        let range = extract_unicode_error_range(&err, vm)?;
440        let b = PyBytesRef::try_from_object(vm, err.get_attr("object", vm)?)?;
441        let mut replace = String::with_capacity(4 * range.len());
442        for &c in &b[range.clone()] {
443            write!(replace, "\\x{c:02x}").unwrap();
444        }
445        return Ok((replace, range.end));
446    } else if !is_encode_ish_err(&err, vm) {
447        return Err(bad_err_type(err, vm));
448    }
449    let range = extract_unicode_error_range(&err, vm)?;
450    let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
451    let s_after_start = crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or("");
452    let num_chars = range.len();
453    // minimum 4 output bytes per char: \xNN
454    let mut out = String::with_capacity(num_chars * 4);
455    for c in s_after_start.chars().take(num_chars) {
456        let c = c as u32;
457        if c >= 0x10000 {
458            write!(out, "\\U{c:08x}").unwrap();
459        } else if c >= 0x100 {
460            write!(out, "\\u{c:04x}").unwrap();
461        } else {
462            write!(out, "\\x{c:02x}").unwrap();
463        }
464    }
465    Ok((out, range.end))
466}
467
468fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> {
469    if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) {
470        let range = extract_unicode_error_range(&err, vm)?;
471        let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
472        let s_after_start =
473            crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or("");
474        let num_chars = range.len();
475        let mut out = String::with_capacity(num_chars * 4);
476        for c in s_after_start.chars().take(num_chars) {
477            let c_u32 = c as u32;
478            if let Some(c_name) = unicode_names2::name(c) {
479                write!(out, "\\N{{{c_name}}}").unwrap();
480            } else if c_u32 >= 0x10000 {
481                write!(out, "\\U{c_u32:08x}").unwrap();
482            } else if c_u32 >= 0x100 {
483                write!(out, "\\u{c_u32:04x}").unwrap();
484            } else {
485                write!(out, "\\x{c_u32:02x}").unwrap();
486            }
487        }
488        Ok((out, range.end))
489    } else {
490        Err(bad_err_type(err, vm))
491    }
492}
493
/// Encoding families recognized by [`get_standard_encoding`]; `Unknown` means "not one of
/// these".
#[derive(Eq, PartialEq)]
enum StandardEncoding {
    Utf8,
    Utf16Be,
    Utf16Le,
    Utf32Be,
    Utf32Le,
    Unknown,
}

/// Classifies an encoding name as a UTF-8/16/32 variant.
///
/// A recognized name is `utf` (any case), an optional `-`/`_`, and then one of:
/// * `8`;
/// * `16` or `32`, optionally followed by a `-`/`_` and `be`/`le`.
///
/// A bare `16`/`32` selects the platform's native byte order. The exact name `CP_UTF8` also
/// means UTF-8.
///
/// Returns the byte size of one encoded surrogate (3 for UTF-8, 2 for UTF-16, 4 for UTF-32)
/// together with the variant. Unrecognized names give `StandardEncoding::Unknown`, and the
/// size is then meaningless.
fn get_standard_encoding(encoding: &str) -> (usize, StandardEncoding) {
    // Drops at most one leading `-` or `_`.
    fn skip_sep(s: &str) -> &str {
        s.strip_prefix(|c: char| c == '-' || c == '_').unwrap_or(s)
    }

    let lowered = encoding.to_lowercase();
    let suffix = match lowered.strip_prefix("utf") {
        Some(suffix) => skip_sep(suffix),
        None if encoding == "CP_UTF8" => return (3, StandardEncoding::Utf8),
        None => return (0, StandardEncoding::Unknown),
    };

    if suffix == "8" {
        return (3, StandardEncoding::Utf8);
    }

    let (byte_length, tail, big, little) = if let Some(tail) = suffix.strip_prefix("16") {
        (2, tail, StandardEncoding::Utf16Be, StandardEncoding::Utf16Le)
    } else if let Some(tail) = suffix.strip_prefix("32") {
        (4, tail, StandardEncoding::Utf32Be, StandardEncoding::Utf32Le)
    } else {
        return (0, StandardEncoding::Unknown);
    };

    let standard_encoding = if tail.is_empty() {
        // No explicit byte order: use the native one.
        if cfg!(target_endian = "big") {
            big
        } else {
            little
        }
    } else {
        match skip_sep(tail) {
            "be" => big,
            "le" => little,
            _ => StandardEncoding::Unknown,
        }
    };
    (byte_length, standard_encoding)
}
561
562fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
563    if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) {
564        let range = extract_unicode_error_range(&err, vm)?;
565        let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
566        let s_encoding = PyStrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?;
567        let (_, standard_encoding) = get_standard_encoding(s_encoding.as_str());
568        if let StandardEncoding::Unknown = standard_encoding {
569            // Not supported, fail with original exception
570            return Err(err.downcast().unwrap());
571        }
572        let s_after_start =
573            crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or("");
574        let num_chars = range.len();
575        let mut out: Vec<u8> = Vec::with_capacity(num_chars * 4);
576        for c in s_after_start.chars().take(num_chars).map(|x| x as u32) {
577            if !(0xd800..=0xdfff).contains(&c) {
578                // Not a surrogate, fail with original exception
579                return Err(err.downcast().unwrap());
580            }
581            match standard_encoding {
582                StandardEncoding::Utf8 => {
583                    out.push((0xe0 | (c >> 12)) as u8);
584                    out.push((0x80 | ((c >> 6) & 0x3f)) as u8);
585                    out.push((0x80 | (c & 0x3f)) as u8);
586                }
587                StandardEncoding::Utf16Le => {
588                    out.push(c as u8);
589                    out.push((c >> 8) as u8);
590                }
591                StandardEncoding::Utf16Be => {
592                    out.push((c >> 8) as u8);
593                    out.push(c as u8);
594                }
595                StandardEncoding::Utf32Le => {
596                    out.push(c as u8);
597                    out.push((c >> 8) as u8);
598                    out.push((c >> 16) as u8);
599                    out.push((c >> 24) as u8);
600                }
601                StandardEncoding::Utf32Be => {
602                    out.push((c >> 24) as u8);
603                    out.push((c >> 16) as u8);
604                    out.push((c >> 8) as u8);
605                    out.push(c as u8);
606                }
607                StandardEncoding::Unknown => {
608                    unreachable!("NOTE: RUSTPYTHON, should've bailed out earlier")
609                }
610            }
611        }
612        Ok((vm.ctx.new_bytes(out).into(), range.end))
613    } else if is_decode_err(&err, vm) {
614        let range = extract_unicode_error_range(&err, vm)?;
615        let s = PyBytesRef::try_from_object(vm, err.get_attr("object", vm)?)?;
616        let s_encoding = PyStrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?;
617        let (byte_length, standard_encoding) = get_standard_encoding(s_encoding.as_str());
618        if let StandardEncoding::Unknown = standard_encoding {
619            // Not supported, fail with original exception
620            return Err(err.downcast().unwrap());
621        }
622        let mut c: u32 = 0;
623        // Try decoding a single surrogate character. If there are more,
624        // let the codec call us again.
625        let p = &s.as_bytes()[range.start..];
626        if p.len() - range.start >= byte_length {
627            match standard_encoding {
628                StandardEncoding::Utf8 => {
629                    if (p[0] as u32 & 0xf0) == 0xe0
630                        && (p[1] as u32 & 0xc0) == 0x80
631                        && (p[2] as u32 & 0xc0) == 0x80
632                    {
633                        // it's a three-byte code
634                        c = ((p[0] as u32 & 0x0f) << 12)
635                            + ((p[1] as u32 & 0x3f) << 6)
636                            + (p[2] as u32 & 0x3f);
637                    }
638                }
639                StandardEncoding::Utf16Le => {
640                    c = (p[1] as u32) << 8 | p[0] as u32;
641                }
642                StandardEncoding::Utf16Be => {
643                    c = (p[0] as u32) << 8 | p[1] as u32;
644                }
645                StandardEncoding::Utf32Le => {
646                    c = ((p[3] as u32) << 24)
647                        | ((p[2] as u32) << 16)
648                        | ((p[1] as u32) << 8)
649                        | p[0] as u32;
650                }
651                StandardEncoding::Utf32Be => {
652                    c = ((p[0] as u32) << 24)
653                        | ((p[1] as u32) << 16)
654                        | ((p[2] as u32) << 8)
655                        | p[3] as u32;
656                }
657                StandardEncoding::Unknown => {
658                    unreachable!("NOTE: RUSTPYTHON, should've bailed out earlier")
659                }
660            }
661        }
662        // !Py_UNICODE_IS_SURROGATE
663        if !(0xd800..=0xdfff).contains(&c) {
664            // Not a surrogate, fail with original exception
665            return Err(err.downcast().unwrap());
666        }
667
668        Ok((
669            vm.new_pyobj(format!("\\x{c:x?}")),
670            range.start + byte_length,
671        ))
672    } else {
673        Err(bad_err_type(err, vm))
674    }
675}
676
677fn surrogateescape_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
678    if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) {
679        let range = extract_unicode_error_range(&err, vm)?;
680        let object = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
681        let s_after_start =
682            crate::common::str::try_get_chars(object.as_str(), range.start..).unwrap_or("");
683        let mut out: Vec<u8> = Vec::with_capacity(range.len());
684        for ch in s_after_start.chars().take(range.len()) {
685            let ch = ch as u32;
686            if !(0xdc80..=0xdcff).contains(&ch) {
687                // Not a UTF-8b surrogate, fail with original exception
688                return Err(err.downcast().unwrap());
689            }
690            out.push((ch - 0xdc00) as u8);
691        }
692        let out = vm.ctx.new_bytes(out);
693        Ok((out.into(), range.end))
694    } else if is_decode_err(&err, vm) {
695        let range = extract_unicode_error_range(&err, vm)?;
696        let object = err.get_attr("object", vm)?;
697        let object = PyBytesRef::try_from_object(vm, object)?;
698        let p = &object.as_bytes()[range.clone()];
699        let mut consumed = 0;
700        let mut replace = String::with_capacity(4 * range.len());
701        while consumed < 4 && consumed < range.len() {
702            let c = p[consumed] as u32;
703            // Refuse to escape ASCII bytes
704            if c < 128 {
705                break;
706            }
707            write!(replace, "#{}", 0xdc00 + c).unwrap();
708            consumed += 1;
709        }
710        if consumed == 0 {
711            return Err(err.downcast().unwrap());
712        }
713        Ok((vm.new_pyobj(replace), range.start + consumed))
714    } else {
715        Err(bad_err_type(err, vm))
716    }
717}