1use rustpython_common::{
2 borrow::BorrowedValue,
3 encodings::{
4 CodecContext, DecodeContext, DecodeErrorHandler, EncodeContext, EncodeErrorHandler,
5 EncodeReplace, StrBuffer, StrSize, errors,
6 },
7 str::StrKind,
8 wtf8::{CodePoint, Wtf8, Wtf8Buf},
9};
10
11use crate::common::lock::OnceCell;
12use crate::{
13 AsObject, Context, Py, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject,
14 VirtualMachine,
15 builtins::{
16 PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyUtf8Str,
17 PyUtf8StrRef,
18 },
19 common::{ascii, lock::PyRwLock},
20 convert::ToPyObject,
21 function::{ArgBytesLike, PyMethodDef},
22};
23use alloc::borrow::Cow;
24use core::ops::{self, Range};
25use std::collections::HashMap;
26
/// Registry of codec search functions, cached codec lookups and named error
/// handlers. All state lives in a single [`RegistryInner`] behind a
/// read/write lock.
pub struct CodecsRegistry {
    /// Lock-protected registry state.
    inner: PyRwLock<RegistryInner>,
}
30
/// The lock-protected state of a [`CodecsRegistry`].
struct RegistryInner {
    /// Codec search functions, tried in registration order by `lookup`.
    search_path: Vec<PyObjectRef>,
    /// Codecs that were already resolved or manually registered, keyed by
    /// normalized encoding name.
    search_cache: HashMap<String, PyCodec>,
    /// Error handler callables keyed by handler name (e.g. `"strict"`).
    errors: HashMap<String, PyObjectRef>,
}
36
/// Encoding name exported as the default (`"utf-8"`).
// NOTE(review): the users of this constant are outside this file's view.
pub const DEFAULT_ENCODING: &str = "utf-8";
38
/// A codec record: a thin, transparent wrapper around a Python tuple.
///
/// [`PyCodec::from_tuple`] only accepts tuples of length 4; index 0 is used
/// as the encode callable and index 1 as the decode callable.
#[derive(Clone)]
#[repr(transparent)]
pub struct PyCodec(PyTupleRef);
42impl PyCodec {
43 #[inline]
44 pub fn from_tuple(tuple: PyTupleRef) -> Result<Self, PyTupleRef> {
45 if tuple.len() == 4 {
46 Ok(Self(tuple))
47 } else {
48 Err(tuple)
49 }
50 }
51 #[inline]
52 pub fn into_tuple(self) -> PyTupleRef {
53 self.0
54 }
55 #[inline]
56 pub fn as_tuple(&self) -> &Py<PyTuple> {
57 &self.0
58 }
59
60 #[inline]
61 pub fn get_encode_func(&self) -> &PyObject {
62 &self.0[0]
63 }
64 #[inline]
65 pub fn get_decode_func(&self) -> &PyObject {
66 &self.0[1]
67 }
68
69 pub fn is_text_codec(&self, vm: &VirtualMachine) -> PyResult<bool> {
70 let is_text = vm.get_attribute_opt(self.0.clone().into(), "_is_text_encoding")?;
71 is_text.map_or(Ok(true), |is_text| is_text.try_to_bool(vm))
72 }
73
74 pub fn encode(
75 &self,
76 obj: PyObjectRef,
77 errors: Option<PyUtf8StrRef>,
78 vm: &VirtualMachine,
79 ) -> PyResult {
80 let args = match errors {
81 Some(errors) => vec![obj, errors.into_wtf8().into()],
82 None => vec![obj],
83 };
84 let res = self.get_encode_func().call(args, vm)?;
85 let res = res
86 .downcast::<PyTuple>()
87 .ok()
88 .filter(|tuple| tuple.len() == 2)
89 .ok_or_else(|| vm.new_type_error("encoder must return a tuple (object, integer)"))?;
90 Ok(res[0].clone())
92 }
93
94 pub fn decode(
95 &self,
96 obj: PyObjectRef,
97 errors: Option<PyUtf8StrRef>,
98 vm: &VirtualMachine,
99 ) -> PyResult {
100 let args = match errors {
101 Some(errors) => vec![obj, errors.into_wtf8().into()],
102 None => vec![obj],
103 };
104 let res = self.get_decode_func().call(args, vm)?;
105 let res = res
106 .downcast::<PyTuple>()
107 .ok()
108 .filter(|tuple| tuple.len() == 2)
109 .ok_or_else(|| vm.new_type_error("decoder must return a tuple (object,integer)"))?;
110 Ok(res[0].clone())
112 }
113
114 pub fn get_incremental_encoder(
115 &self,
116 errors: Option<PyStrRef>,
117 vm: &VirtualMachine,
118 ) -> PyResult {
119 let args = match errors {
120 Some(e) => vec![e.into()],
121 None => vec![],
122 };
123 vm.call_method(self.0.as_object(), "incrementalencoder", args)
124 }
125
126 pub fn get_incremental_decoder(
127 &self,
128 errors: Option<PyStrRef>,
129 vm: &VirtualMachine,
130 ) -> PyResult {
131 let args = match errors {
132 Some(e) => vec![e.into()],
133 None => vec![],
134 };
135 vm.call_method(self.0.as_object(), "incrementaldecoder", args)
136 }
137}
138
139impl TryFromObject for PyCodec {
140 fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
141 obj.downcast::<PyTuple>()
142 .ok()
143 .and_then(|tuple| Self::from_tuple(tuple).ok())
144 .ok_or_else(|| vm.new_type_error("codec search functions must return 4-tuples"))
145 }
146}
147
148impl ToPyObject for PyCodec {
149 #[inline]
150 fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef {
151 self.0.into()
152 }
153}
154
impl CodecsRegistry {
    /// Re-initializes the registry's lock by forwarding to
    /// `crate::common::lock::reinit_rwlock_after_fork`.
    ///
    /// # Safety
    ///
    /// The caller must uphold whatever `reinit_rwlock_after_fork` requires;
    /// that contract is not visible in this file.
    // NOTE(review): presumably only valid in the child process right after
    // `fork`, while no other thread can hold the lock; confirm against the
    // helper's documentation.
    #[cfg(all(unix, feature = "threading"))]
    pub(crate) unsafe fn reinit_after_fork(&self) {
        unsafe { crate::common::lock::reinit_rwlock_after_fork(&self.inner) };
    }

    /// Creates a registry with an empty search path, an empty codec cache and
    /// the eight built-in error handlers registered under their names.
    pub(crate) fn new(ctx: &Context) -> Self {
        // The method definitions are built once and kept in a static cell.
        ::rustpython_vm::common::static_cell! {
            static METHODS: Box<[PyMethodDef]>;
        }

        let methods = METHODS.get_or_init(|| {
            crate::define_methods![
                "strict_errors" => strict_errors as EMPTY,
                "ignore_errors" => ignore_errors as EMPTY,
                "replace_errors" => replace_errors as EMPTY,
                "xmlcharrefreplace_errors" => xmlcharrefreplace_errors as EMPTY,
                "backslashreplace_errors" => backslashreplace_errors as EMPTY,
                "namereplace_errors" => namereplace_errors as EMPTY,
                "surrogatepass_errors" => surrogatepass_errors as EMPTY,
                "surrogateescape_errors" => surrogateescape_errors as EMPTY
            ]
            .into_boxed_slice()
        });

        // The indices below must stay in step with the order of the
        // definitions above.
        let errors = [
            ("strict", methods[0].build_function(ctx)),
            ("ignore", methods[1].build_function(ctx)),
            ("replace", methods[2].build_function(ctx)),
            ("xmlcharrefreplace", methods[3].build_function(ctx)),
            ("backslashreplace", methods[4].build_function(ctx)),
            ("namereplace", methods[5].build_function(ctx)),
            ("surrogatepass", methods[6].build_function(ctx)),
            ("surrogateescape", methods[7].build_function(ctx)),
        ];
        let errors = errors
            .into_iter()
            .map(|(name, f)| (name.to_owned(), f.into()))
            .collect();
        let inner = RegistryInner {
            search_path: Vec::new(),
            search_cache: HashMap::new(),
            errors,
        };
        Self {
            inner: PyRwLock::new(inner),
        }
    }

    /// Appends `search_function` to the search path.
    ///
    /// # Errors
    ///
    /// Raises `TypeError` if `search_function` is not callable.
    pub fn register(&self, search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        if !search_function.is_callable() {
            return Err(vm.new_type_error("argument must be callable"));
        }
        self.inner.write().search_path.push(search_function);
        Ok(())
    }

    /// Removes the first search-path entry that is the same object (by id)
    /// as `search_function`, clearing the codec cache when one is found.
    ///
    /// Does nothing, and still returns `Ok(())`, if the function is not
    /// registered.
    pub fn unregister(&self, search_function: PyObjectRef) -> PyResult<()> {
        let mut inner = self.inner.write();
        if inner.search_path.is_empty() {
            return Ok(());
        }
        for (i, item) in inner.search_path.iter().enumerate() {
            if item.get_id() == search_function.get_id() {
                // Cached codecs may have come from the function being
                // removed, so the whole cache is dropped.
                if !inner.search_cache.is_empty() {
                    inner.search_cache.clear();
                }
                inner.search_path.remove(i);
                return Ok(());
            }
        }
        Ok(())
    }

    /// Inserts `codec` directly into the cache under the normalized form of
    /// `name`, without consulting any search function. Always returns
    /// `Ok(())`.
    pub(crate) fn register_manual(&self, name: &str, codec: PyCodec) -> PyResult<()> {
        let name = normalize_encoding_name(name);
        self.inner
            .write()
            .search_cache
            .insert(name.into_owned(), codec);
        Ok(())
    }

    /// Resolves `encoding` to a codec.
    ///
    /// The name is normalized first. A cached codec is returned directly;
    /// otherwise each search function is called in order with the normalized
    /// name, and the first one that yields a codec has its result cached and
    /// returned.
    ///
    /// # Errors
    ///
    /// Propagates exceptions from search functions and from converting their
    /// results; raises `LookupError` if no search function yields a codec.
    pub fn lookup(&self, encoding: &str, vm: &VirtualMachine) -> PyResult<PyCodec> {
        let encoding = normalize_encoding_name(encoding);
        // The read guard is confined to this block, so it is released before
        // any search function runs.
        let search_path = {
            let inner = self.inner.read();
            if let Some(codec) = inner.search_cache.get(encoding.as_ref()) {
                return Ok(codec.clone());
            }
            inner.search_path.clone()
        };
        let encoding: PyUtf8StrRef = vm.ctx.new_utf8_str(encoding.as_ref());
        for func in search_path {
            let res = func.call((encoding.clone(),), vm)?;
            let res: Option<PyCodec> = res.try_into_value(vm)?;
            if let Some(codec) = res {
                let mut inner = self.inner.write();
                // `or_insert` keeps an entry that was cached for this name in
                // the meantime, and that existing codec is the one returned.
                let codec = inner
                    .search_cache
                    .entry(encoding.as_str().to_owned())
                    .or_insert(codec);
                return Ok(codec.clone());
            }
        }
        Err(vm.new_lookup_error(format!("unknown encoding: {encoding}")))
    }

    /// Like [`lookup`](Self::lookup), but raises `LookupError` when the codec
    /// reports that it is not a text encoding. `generic_func` is the
    /// alternative named in the error message.
    fn _lookup_text_encoding(
        &self,
        encoding: &str,
        generic_func: &str,
        vm: &VirtualMachine,
    ) -> PyResult<PyCodec> {
        let codec = self.lookup(encoding, vm)?;
        if codec.is_text_codec(vm)? {
            Ok(codec)
        } else {
            Err(vm.new_lookup_error(format!(
                "'{encoding}' is not a text encoding; use {generic_func} to handle arbitrary codecs"
            )))
        }
    }

    /// Removes the cached codec for the normalized form of `encoding`,
    /// returning it if one was cached.
    pub fn forget(&self, encoding: &str) -> Option<PyCodec> {
        let encoding = normalize_encoding_name(encoding);
        self.inner.write().search_cache.remove(encoding.as_ref())
    }

    /// Looks up `encoding` and encodes `obj` with it. If the codec's encoder
    /// fails, a note naming the codec is attached to the exception.
    pub fn encode(
        &self,
        obj: PyObjectRef,
        encoding: &str,
        errors: Option<PyUtf8StrRef>,
        vm: &VirtualMachine,
    ) -> PyResult {
        let codec = self.lookup(encoding, vm)?;
        codec.encode(obj, errors, vm).inspect_err(|exc| {
            Self::add_codec_note(exc, "encoding", encoding, vm);
        })
    }

    /// Looks up `encoding` and decodes `obj` with it. If the codec's decoder
    /// fails, a note naming the codec is attached to the exception.
    pub fn decode(
        &self,
        obj: PyObjectRef,
        encoding: &str,
        errors: Option<PyUtf8StrRef>,
        vm: &VirtualMachine,
    ) -> PyResult {
        let codec = self.lookup(encoding, vm)?;
        codec.decode(obj, errors, vm).inspect_err(|exc| {
            Self::add_codec_note(exc, "decoding", encoding, vm);
        })
    }

    /// Encodes a string with a text encoding and requires a `bytes` result.
    ///
    /// # Errors
    ///
    /// `LookupError` if the codec is unknown or not a text encoding, any
    /// exception raised by the encoder, and `TypeError` if the encoder's
    /// result is not `bytes`.
    pub fn encode_text(
        &self,
        obj: PyStrRef,
        encoding: &str,
        errors: Option<PyUtf8StrRef>,
        vm: &VirtualMachine,
    ) -> PyResult<PyBytesRef> {
        let codec = self._lookup_text_encoding(encoding, "codecs.encode()", vm)?;
        codec
            .encode(obj.into(), errors, vm)
            .inspect_err(|exc| {
                Self::add_codec_note(exc, "encoding", encoding, vm);
            })?
            .downcast()
            .map_err(|obj| {
                vm.new_type_error(format!(
                    "'{}' encoder returned '{}' instead of 'bytes'; use codecs.encode() to \
                     encode to arbitrary types",
                    encoding,
                    obj.class().name(),
                ))
            })
    }

    /// Decodes an object with a text encoding and requires a `str` result.
    ///
    /// # Errors
    ///
    /// `LookupError` if the codec is unknown or not a text encoding, any
    /// exception raised by the decoder, and `TypeError` if the decoder's
    /// result is not `str`.
    pub fn decode_text(
        &self,
        obj: PyObjectRef,
        encoding: &str,
        errors: Option<PyUtf8StrRef>,
        vm: &VirtualMachine,
    ) -> PyResult<PyStrRef> {
        let codec = self._lookup_text_encoding(encoding, "codecs.decode()", vm)?;
        codec
            .decode(obj, errors, vm)
            .inspect_err(|exc| {
                Self::add_codec_note(exc, "decoding", encoding, vm);
            })?
            .downcast()
            .map_err(|obj| {
                vm.new_type_error(format!(
                    "'{}' decoder returned '{}' instead of 'str'; use codecs.decode() to \
                     decode to arbitrary types",
                    encoding,
                    obj.class().name(),
                ))
            })
    }

    /// Calls `add_note` on `exc` with a message naming the failed operation
    /// and codec. A failure of `add_note` itself is deliberately ignored.
    fn add_codec_note(
        exc: &crate::builtins::PyBaseExceptionRef,
        operation: &str,
        encoding: &str,
        vm: &VirtualMachine,
    ) {
        let note = format!("{operation} with '{encoding}' codec failed");
        let _ = vm.call_method(exc.as_object(), "add_note", (vm.ctx.new_str(note),));
    }

    /// Registers `handler` under `name`, returning the handler it replaced,
    /// if any.
    pub fn register_error(&self, name: String, handler: PyObjectRef) -> Option<PyObjectRef> {
        self.inner.write().errors.insert(name, handler)
    }

    /// Removes the error handler registered under `name`.
    ///
    /// Returns whether a handler was actually removed.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` for any of the built-in handler names.
    pub fn unregister_error(&self, name: &str, vm: &VirtualMachine) -> PyResult<bool> {
        // Same names that `new` registers.
        const BUILTIN_ERROR_HANDLERS: &[&str] = &[
            "strict",
            "ignore",
            "replace",
            "xmlcharrefreplace",
            "backslashreplace",
            "namereplace",
            "surrogatepass",
            "surrogateescape",
        ];
        if BUILTIN_ERROR_HANDLERS.contains(&name) {
            return Err(vm.new_value_error(format!(
                "cannot un-register built-in error handler '{name}'"
            )));
        }
        Ok(self.inner.write().errors.remove(name).is_some())
    }

    /// Returns the error handler registered under `name`, if any.
    pub fn lookup_error_opt(&self, name: &str) -> Option<PyObjectRef> {
        self.inner.read().errors.get(name).cloned()
    }

    /// Returns the error handler registered under `name`, raising
    /// `LookupError` if there is none.
    pub fn lookup_error(&self, name: &str, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        self.lookup_error_opt(name)
            .ok_or_else(|| vm.new_lookup_error(format!("unknown error handler name '{name}'")))
    }
}
408
/// Normalizes an encoding name for use as a registry key.
///
/// ASCII letters are lowercased, ASCII digits and `.` are kept, and every
/// run of other characters between kept characters becomes a single `_`.
/// Such runs at the very start or end are dropped. Names that are already
/// in normal form are returned borrowed, without allocating.
fn normalize_encoding_name(encoding: &str) -> Cow<'_, str> {
    let already_normalized = encoding
        .bytes()
        .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'.');
    if already_normalized {
        return Cow::Borrowed(encoding);
    }

    let mut normalized = String::with_capacity(encoding.len());
    let mut pending_separator = false;
    for ch in encoding.chars() {
        if !(ch.is_ascii_alphanumeric() || ch == '.') {
            // Remember that a separator is owed, but emit it lazily so that
            // leading and trailing runs produce nothing.
            pending_separator = true;
            continue;
        }
        if pending_separator && !normalized.is_empty() {
            normalized.push('_');
        }
        pending_separator = false;
        normalized.push(ch.to_ascii_lowercase());
    }
    Cow::Owned(normalized)
}
433
/// The Unicode transformation formats that `SurrogatePass` knows how to
/// handle byte-for-byte.
#[derive(Eq, PartialEq)]
enum StandardEncoding {
    Utf8,
    Utf16Be,
    Utf16Le,
    Utf32Be,
    Utf32Le,
}

impl StandardEncoding {
    /// UTF-16 in the byte order of the compilation target.
    #[cfg(target_endian = "little")]
    const UTF_16_NE: Self = Self::Utf16Le;
    /// UTF-16 in the byte order of the compilation target.
    #[cfg(target_endian = "big")]
    const UTF_16_NE: Self = Self::Utf16Be;

    /// UTF-32 in the byte order of the compilation target.
    #[cfg(target_endian = "little")]
    const UTF_32_NE: Self = Self::Utf32Le;
    /// UTF-32 in the byte order of the compilation target.
    #[cfg(target_endian = "big")]
    const UTF_32_NE: Self = Self::Utf32Be;

    /// Recognizes an encoding name of the form `utf[-_]?8`,
    /// `utf[-_]?(16|32)` or `utf[-_]?(16|32)[-_]?(be|le)`, matched
    /// case-insensitively, plus the exact name `cp65001` (UTF-8).
    ///
    /// A width without a byte-order suffix means native byte order. Any
    /// other name yields `None`.
    fn parse(encoding: &str) -> Option<Self> {
        let lowered = encoding.to_lowercase();
        let Some(rest) = lowered.strip_prefix("utf") else {
            // Only the exact, case-sensitive spelling is accepted here.
            return (encoding == "cp65001").then_some(Self::Utf8);
        };
        let rest = rest.strip_prefix(['-', '_']).unwrap_or(rest);
        if rest == "8" {
            return Some(Self::Utf8);
        }

        // Choose the code-unit width, then interpret what follows it.
        let (suffix, native, big, little) = if let Some(tail) = rest.strip_prefix("16") {
            (tail, Self::UTF_16_NE, Self::Utf16Be, Self::Utf16Le)
        } else if let Some(tail) = rest.strip_prefix("32") {
            (tail, Self::UTF_32_NE, Self::Utf32Be, Self::Utf32Le)
        } else {
            return None;
        };
        if suffix.is_empty() {
            return Some(native);
        }
        match suffix.strip_prefix(['-', '_']).unwrap_or(suffix) {
            "be" => Some(big),
            "le" => Some(little),
            _ => None,
        }
    }
}
491
/// Native `surrogatepass` error handler; implements both the encode and the
/// decode handler traits below.
struct SurrogatePass;
493
494impl<'a> EncodeErrorHandler<PyEncodeContext<'a>> for SurrogatePass {
495 fn handle_encode_error(
496 &self,
497 ctx: &mut PyEncodeContext<'a>,
498 range: Range<StrSize>,
499 reason: Option<&str>,
500 ) -> PyResult<(EncodeReplace<PyEncodeContext<'a>>, StrSize)> {
501 let standard_encoding = StandardEncoding::parse(ctx.encoding)
502 .ok_or_else(|| ctx.error_encoding(range.clone(), reason))?;
503 let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
504 let num_chars = range.end.chars - range.start.chars;
505 let mut out: Vec<u8> = Vec::with_capacity(num_chars * 4);
506 for ch in err_str.code_points() {
507 let c = ch.to_u32();
508 let 0xd800..=0xdfff = c else {
509 return Err(ctx.error_encoding(range, reason));
511 };
512 match standard_encoding {
513 StandardEncoding::Utf8 => out.extend(ch.encode_wtf8(&mut [0; 4]).as_bytes()),
514 StandardEncoding::Utf16Le => out.extend((c as u16).to_le_bytes()),
515 StandardEncoding::Utf16Be => out.extend((c as u16).to_be_bytes()),
516 StandardEncoding::Utf32Le => out.extend(c.to_le_bytes()),
517 StandardEncoding::Utf32Be => out.extend(c.to_be_bytes()),
518 }
519 }
520 Ok((EncodeReplace::Bytes(ctx.bytes(out)), range.end))
521 }
522}
523
524impl<'a> DecodeErrorHandler<PyDecodeContext<'a>> for SurrogatePass {
525 fn handle_decode_error(
526 &self,
527 ctx: &mut PyDecodeContext<'a>,
528 byte_range: Range<usize>,
529 reason: Option<&str>,
530 ) -> PyResult<(PyStrRef, usize)> {
531 let standard_encoding = StandardEncoding::parse(ctx.encoding)
532 .ok_or_else(|| ctx.error_decoding(byte_range.clone(), reason))?;
533
534 let s = ctx.full_data();
535 debug_assert!(byte_range.start <= 0.max(s.len() - 1));
536 debug_assert!(byte_range.end >= 1.min(s.len()));
537 debug_assert!(byte_range.end <= s.len());
538
539 let p = &s[byte_range.start..];
542
543 fn slice<const N: usize>(p: &[u8]) -> Option<[u8; N]> {
544 p.first_chunk().copied()
545 }
546
547 let c = match standard_encoding {
548 StandardEncoding::Utf8 => {
549 slice::<3>(p)
551 .filter(|&[a, b, c]| {
552 (u32::from(a) & 0xf0) == 0xe0
553 && (u32::from(b) & 0xc0) == 0x80
554 && (u32::from(c) & 0xc0) == 0x80
555 })
556 .map(|[a, b, c]| {
557 ((u32::from(a) & 0x0f) << 12)
558 + ((u32::from(b) & 0x3f) << 6)
559 + (u32::from(c) & 0x3f)
560 })
561 }
562 StandardEncoding::Utf16Le => slice(p).map(u16::from_le_bytes).map(u32::from),
563 StandardEncoding::Utf16Be => slice(p).map(u16::from_be_bytes).map(u32::from),
564 StandardEncoding::Utf32Le => slice(p).map(u32::from_le_bytes),
565 StandardEncoding::Utf32Be => slice(p).map(u32::from_be_bytes),
566 };
567 let byte_length = match standard_encoding {
568 StandardEncoding::Utf8 => 3,
569 StandardEncoding::Utf16Be | StandardEncoding::Utf16Le => 2,
570 StandardEncoding::Utf32Be | StandardEncoding::Utf32Le => 4,
571 };
572
573 let c = c
575 .and_then(CodePoint::from_u32)
576 .filter(|c| matches!(c.to_u32(), 0xd800..=0xdfff))
577 .ok_or_else(|| ctx.error_decoding(byte_range.clone(), reason))?;
578
579 Ok((ctx.string(c.into()), byte_range.start + byte_length))
580 }
581}
582
/// State shared between an encoder and its error handlers while encoding a
/// Python string.
pub struct PyEncodeContext<'a> {
    /// VM used to create Python objects and exceptions.
    vm: &'a VirtualMachine,
    /// Name of the encoding in use; reported in encode errors.
    encoding: &'a str,
    /// The string being encoded.
    data: &'a Py<PyStr>,
    /// Current position, tracked in both bytes and characters.
    pos: StrSize,
    /// The encode exception, created on first error and reused afterwards.
    exception: OnceCell<PyBaseExceptionRef>,
}
590
591impl<'a> PyEncodeContext<'a> {
592 pub fn new(encoding: &'a str, data: &'a Py<PyStr>, vm: &'a VirtualMachine) -> Self {
593 Self {
594 vm,
595 encoding,
596 data,
597 pos: StrSize::default(),
598 exception: OnceCell::new(),
599 }
600 }
601}
602
603impl CodecContext for PyEncodeContext<'_> {
604 type Error = PyBaseExceptionRef;
605 type StrBuf = PyStrRef;
606 type BytesBuf = PyBytesRef;
607
608 fn string(&self, s: Wtf8Buf) -> Self::StrBuf {
609 self.vm.ctx.new_str(s)
610 }
611
612 fn bytes(&self, b: Vec<u8>) -> Self::BytesBuf {
613 self.vm.ctx.new_bytes(b)
614 }
615}
616impl EncodeContext for PyEncodeContext<'_> {
617 fn full_data(&self) -> &Wtf8 {
618 self.data.as_wtf8()
619 }
620
621 fn data_len(&self) -> StrSize {
622 StrSize {
623 bytes: self.data.byte_len(),
624 chars: self.data.char_len(),
625 }
626 }
627
628 fn remaining_data(&self) -> &Wtf8 {
629 &self.full_data()[self.pos.bytes..]
630 }
631
632 fn position(&self) -> StrSize {
633 self.pos
634 }
635
636 fn restart_from(&mut self, pos: StrSize) -> Result<(), Self::Error> {
637 if pos.chars > self.data.char_len() {
638 return Err(self.vm.new_index_error(format!(
639 "position {} from error handler out of bounds",
640 pos.chars
641 )));
642 }
643 assert!(
644 self.data.as_wtf8().is_code_point_boundary(pos.bytes),
645 "invalid pos {pos:?} for {:?}",
646 self.data.as_wtf8()
647 );
648 self.pos = pos;
649 Ok(())
650 }
651
652 fn error_encoding(&self, range: Range<StrSize>, reason: Option<&str>) -> Self::Error {
653 let vm = self.vm;
654 match self.exception.get() {
655 Some(exc) => {
656 match update_unicode_error_attrs(
657 exc.as_object(),
658 range.start.chars,
659 range.end.chars,
660 reason,
661 vm,
662 ) {
663 Ok(()) => exc.clone(),
664 Err(e) => e,
665 }
666 }
667 None => self
668 .exception
669 .get_or_init(|| {
670 let reason = reason.expect(
671 "should only ever pass reason: None if an exception is already set",
672 );
673 vm.new_unicode_encode_error_real(
674 vm.ctx.new_str(self.encoding),
675 self.data.to_owned(),
676 range.start.chars,
677 range.end.chars,
678 vm.ctx.new_str(reason),
679 )
680 })
681 .clone(),
682 }
683 }
684}
685
/// State shared between a decoder and its error handlers while decoding a
/// bytes-like object.
pub struct PyDecodeContext<'a> {
    /// VM used to create Python objects and exceptions.
    vm: &'a VirtualMachine,
    /// Name of the encoding in use; reported in decode errors.
    encoding: &'a str,
    /// The bytes being decoded; may be swapped out by a Python-level error
    /// handler (see [`PyDecodeData`]).
    data: PyDecodeData<'a>,
    /// The input object itself when it is a `bytes` instance, so that the
    /// exception can reference it without copying.
    orig_bytes: Option<&'a Py<PyBytes>>,
    /// Current byte offset into `data`.
    pos: usize,
    /// The decode exception, created on first error and reused afterwards.
    exception: OnceCell<PyBaseExceptionRef>,
}
/// The input of a decode operation.
enum PyDecodeData<'a> {
    /// The buffer borrowed from the caller's bytes-like object.
    Original(BorrowedValue<'a, [u8]>),
    /// A replacement `bytes` object taken from the exception's `object`
    /// attribute after an error handler changed it.
    Modified(PyBytesRef),
}
698impl ops::Deref for PyDecodeData<'_> {
699 type Target = [u8];
700 fn deref(&self) -> &Self::Target {
701 match self {
702 PyDecodeData::Original(data) => data,
703 PyDecodeData::Modified(data) => data,
704 }
705 }
706}
707
708impl<'a> PyDecodeContext<'a> {
709 pub fn new(encoding: &'a str, data: &'a ArgBytesLike, vm: &'a VirtualMachine) -> Self {
710 Self {
711 vm,
712 encoding,
713 data: PyDecodeData::Original(data.borrow_buf()),
714 orig_bytes: data.as_object().downcast_ref(),
715 pos: 0,
716 exception: OnceCell::new(),
717 }
718 }
719}
720
721impl CodecContext for PyDecodeContext<'_> {
722 type Error = PyBaseExceptionRef;
723 type StrBuf = PyStrRef;
724 type BytesBuf = PyBytesRef;
725
726 fn string(&self, s: Wtf8Buf) -> Self::StrBuf {
727 self.vm.ctx.new_str(s)
728 }
729
730 fn bytes(&self, b: Vec<u8>) -> Self::BytesBuf {
731 self.vm.ctx.new_bytes(b)
732 }
733}
734impl DecodeContext for PyDecodeContext<'_> {
735 fn full_data(&self) -> &[u8] {
736 &self.data
737 }
738
739 fn remaining_data(&self) -> &[u8] {
740 &self.data[self.pos..]
741 }
742
743 fn position(&self) -> usize {
744 self.pos
745 }
746
747 fn advance(&mut self, by: usize) {
748 self.pos += by;
749 }
750
751 fn restart_from(&mut self, pos: usize) -> Result<(), Self::Error> {
752 if pos > self.data.len() {
753 return Err(self
754 .vm
755 .new_index_error(format!("position {pos} from error handler out of bounds",)));
756 }
757 self.pos = pos;
758 Ok(())
759 }
760
761 fn error_decoding(&self, byte_range: Range<usize>, reason: Option<&str>) -> Self::Error {
762 let vm = self.vm;
763
764 match self.exception.get() {
765 Some(exc) => {
766 match update_unicode_error_attrs(
767 exc.as_object(),
768 byte_range.start,
769 byte_range.end,
770 reason,
771 vm,
772 ) {
773 Ok(()) => exc.clone(),
774 Err(e) => e,
775 }
776 }
777 None => self
778 .exception
779 .get_or_init(|| {
780 let reason = reason.expect(
781 "should only ever pass reason: None if an exception is already set",
782 );
783 let data = if let Some(bytes) = self.orig_bytes {
784 bytes.to_owned()
785 } else {
786 vm.ctx.new_bytes(self.data.to_vec())
787 };
788 vm.new_unicode_decode_error_real(
789 vm.ctx.new_str(self.encoding),
790 data,
791 byte_range.start,
792 byte_range.end,
793 vm.ctx.new_str(reason),
794 )
795 })
796 .clone(),
797 }
798 }
799}
800
/// Error handlers that are dispatched natively, without going through the
/// registry's callables.
///
/// The `EnumString` derive with `serialize_all = "lowercase"` lets a handler
/// name be parsed into a variant with `str::parse`.
// NOTE(review): presumably each variant parses from its all-lowercase name
// (e.g. "xmlcharrefreplace"); confirm against strum's casing rules.
#[derive(strum_macros::EnumString)]
#[strum(serialize_all = "lowercase")]
enum StandardError {
    Strict,
    Ignore,
    Replace,
    XmlCharRefReplace,
    BackslashReplace,
    SurrogatePass,
    SurrogateEscape,
}
812
813impl<'a> EncodeErrorHandler<PyEncodeContext<'a>> for StandardError {
814 fn handle_encode_error(
815 &self,
816 ctx: &mut PyEncodeContext<'a>,
817 range: Range<StrSize>,
818 reason: Option<&str>,
819 ) -> PyResult<(EncodeReplace<PyEncodeContext<'a>>, StrSize)> {
820 use StandardError::*;
821 match self {
823 Strict => errors::Strict.handle_encode_error(ctx, range, reason),
824 Ignore => errors::Ignore.handle_encode_error(ctx, range, reason),
825 Replace => errors::Replace.handle_encode_error(ctx, range, reason),
826 XmlCharRefReplace => errors::XmlCharRefReplace.handle_encode_error(ctx, range, reason),
827 BackslashReplace => errors::BackslashReplace.handle_encode_error(ctx, range, reason),
828 SurrogatePass => SurrogatePass.handle_encode_error(ctx, range, reason),
829 SurrogateEscape => errors::SurrogateEscape.handle_encode_error(ctx, range, reason),
830 }
831 }
832}
833
834impl<'a> DecodeErrorHandler<PyDecodeContext<'a>> for StandardError {
835 fn handle_decode_error(
836 &self,
837 ctx: &mut PyDecodeContext<'a>,
838 byte_range: Range<usize>,
839 reason: Option<&str>,
840 ) -> PyResult<(PyStrRef, usize)> {
841 use StandardError::*;
842 match self {
843 Strict => errors::Strict.handle_decode_error(ctx, byte_range, reason),
844 Ignore => errors::Ignore.handle_decode_error(ctx, byte_range, reason),
845 Replace => errors::Replace.handle_decode_error(ctx, byte_range, reason),
846 XmlCharRefReplace => Err(ctx
847 .vm
848 .new_type_error("don't know how to handle UnicodeDecodeError in error callback")),
849 BackslashReplace => {
850 errors::BackslashReplace.handle_decode_error(ctx, byte_range, reason)
851 }
852 SurrogatePass => self::SurrogatePass.handle_decode_error(ctx, byte_range, reason),
853 SurrogateEscape => errors::SurrogateEscape.handle_decode_error(ctx, byte_range, reason),
854 }
855 }
856}
857
/// An error handler identified by name and resolved lazily on first use.
pub struct ErrorsHandler<'a> {
    /// Name of the requested error handler.
    errors: &'a Py<PyUtf8Str>,
    /// The resolved handler, filled in the first time it is needed.
    resolved: OnceCell<ResolvedError>,
}
/// The outcome of resolving an error handler name.
enum ResolvedError {
    /// A handler that is dispatched natively.
    Standard(StandardError),
    /// A callable obtained from the codec registry.
    Handler(PyObjectRef),
}
866
867impl<'a> ErrorsHandler<'a> {
868 #[inline]
869 pub fn new(errors: Option<&'a Py<PyUtf8Str>>, vm: &VirtualMachine) -> Self {
870 match errors {
871 Some(errors) => Self {
872 errors,
873 resolved: OnceCell::new(),
874 },
875 None => Self {
876 errors: identifier_utf8!(vm, strict),
877 resolved: OnceCell::from(ResolvedError::Standard(StandardError::Strict)),
878 },
879 }
880 }
881 #[inline]
882 fn resolve(&self, vm: &VirtualMachine) -> PyResult<&ResolvedError> {
883 if let Some(val) = self.resolved.get() {
884 return Ok(val);
885 }
886 let errors_str = self.errors.as_str();
887 let val = if let Ok(standard) = errors_str.parse() {
888 ResolvedError::Standard(standard)
889 } else {
890 vm.state
891 .codec_registry
892 .lookup_error(errors_str, vm)
893 .map(ResolvedError::Handler)?
894 };
895 let _ = self.resolved.set(val);
896 Ok(self.resolved.get().unwrap())
897 }
898}
899impl StrBuffer for PyStrRef {
900 fn is_compatible_with(&self, kind: StrKind) -> bool {
901 self.kind() <= kind
902 }
903}
904impl<'a> EncodeErrorHandler<PyEncodeContext<'a>> for ErrorsHandler<'_> {
905 fn handle_encode_error(
906 &self,
907 ctx: &mut PyEncodeContext<'a>,
908 range: Range<StrSize>,
909 reason: Option<&str>,
910 ) -> PyResult<(EncodeReplace<PyEncodeContext<'a>>, StrSize)> {
911 let vm = ctx.vm;
912 let handler = match self.resolve(vm)? {
913 ResolvedError::Standard(standard) => {
914 return standard.handle_encode_error(ctx, range, reason);
915 }
916 ResolvedError::Handler(handler) => handler,
917 };
918 let encode_exc = ctx.error_encoding(range.clone(), reason);
919 let res = handler.call((encode_exc.clone(),), vm)?;
920 let tuple_err =
921 || vm.new_type_error("encoding error handler must return (str/bytes, int) tuple");
922 let (replace, restart) = match res.downcast_ref::<PyTuple>().map(|tup| tup.as_slice()) {
923 Some([replace, restart]) => (replace.clone(), restart),
924 _ => return Err(tuple_err()),
925 };
926 let replace = match_class!(match replace {
927 s @ PyStr => EncodeReplace::Str(s),
928 b @ PyBytes => EncodeReplace::Bytes(b),
929 _ => return Err(tuple_err()),
930 });
931 let restart = isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
932 let restart = if restart < 0 {
933 ctx.data.char_len().wrapping_sub(restart.unsigned_abs())
935 } else {
936 restart as usize
937 };
938 let restart = if restart == range.end.chars {
939 range.end
940 } else {
941 StrSize {
942 chars: restart,
943 bytes: ctx
944 .data
945 .as_wtf8()
946 .code_point_indices()
947 .nth(restart)
948 .map_or(ctx.data.byte_len(), |(i, _)| i),
949 }
950 };
951 Ok((replace, restart))
952 }
953}
954impl<'a> DecodeErrorHandler<PyDecodeContext<'a>> for ErrorsHandler<'_> {
955 fn handle_decode_error(
956 &self,
957 ctx: &mut PyDecodeContext<'a>,
958 byte_range: Range<usize>,
959 reason: Option<&str>,
960 ) -> PyResult<(PyStrRef, usize)> {
961 let vm = ctx.vm;
962 let handler = match self.resolve(vm)? {
963 ResolvedError::Standard(standard) => {
964 return standard.handle_decode_error(ctx, byte_range, reason);
965 }
966 ResolvedError::Handler(handler) => handler,
967 };
968 let decode_exc = ctx.error_decoding(byte_range.clone(), reason);
969 let data_bytes: PyObjectRef = decode_exc.as_object().get_attr("object", vm)?;
970 let res = handler.call((decode_exc.clone(),), vm)?;
971 let new_data = decode_exc.as_object().get_attr("object", vm)?;
972 if !new_data.is(&data_bytes) {
973 let new_data: PyBytesRef = new_data
974 .downcast()
975 .map_err(|_| vm.new_type_error("object attribute must be bytes"))?;
976 ctx.data = PyDecodeData::Modified(new_data);
977 }
978 let data = &*ctx.data;
979 let tuple_err = || vm.new_type_error("decoding error handler must return (str, int) tuple");
980 match res.downcast_ref::<PyTuple>().map(|tup| tup.as_slice()) {
981 Some([replace, restart]) => {
982 let replace = replace
983 .downcast_ref::<PyStr>()
984 .ok_or_else(tuple_err)?
985 .to_owned();
986 let restart =
987 isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
988 let restart = if restart < 0 {
989 data.len().wrapping_sub(restart.unsigned_abs())
991 } else {
992 restart as usize
993 };
994 Ok((replace, restart))
995 }
996 _ => Err(tuple_err()),
997 }
998 }
999}
1000
1001fn call_native_encode_error<E>(
1002 handler: E,
1003 err: PyObjectRef,
1004 vm: &VirtualMachine,
1005) -> PyResult<(PyObjectRef, usize)>
1006where
1007 for<'a> E: EncodeErrorHandler<PyEncodeContext<'a>>,
1008{
1009 let range = extract_unicode_error_range(&err, vm)?;
1011 let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
1012 let s_encoding = PyUtf8StrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?;
1013 let mut ctx = PyEncodeContext {
1014 vm,
1015 encoding: s_encoding.as_str(),
1016 data: &s,
1017 pos: StrSize::default(),
1018 exception: OnceCell::from(err.downcast().unwrap()),
1019 };
1020 let mut iter = s.as_wtf8().code_point_indices();
1021 let start = StrSize {
1022 chars: range.start,
1023 bytes: iter.nth(range.start).unwrap().0,
1024 };
1025 let end = StrSize {
1026 chars: range.end,
1027 bytes: if let Some(n) = range.len().checked_sub(1) {
1028 iter.nth(n).map_or(s.byte_len(), |(i, _)| i)
1029 } else {
1030 start.bytes
1031 },
1032 };
1033 let (replace, restart) = handler.handle_encode_error(&mut ctx, start..end, None)?;
1034 let replace = match replace {
1035 EncodeReplace::Str(s) => s.into(),
1036 EncodeReplace::Bytes(b) => b.into(),
1037 };
1038 Ok((replace, restart.chars))
1039}
1040
1041fn call_native_decode_error<E>(
1042 handler: E,
1043 err: PyObjectRef,
1044 vm: &VirtualMachine,
1045) -> PyResult<(PyObjectRef, usize)>
1046where
1047 for<'a> E: DecodeErrorHandler<PyDecodeContext<'a>>,
1048{
1049 let range = extract_unicode_error_range(&err, vm)?;
1050 let s = ArgBytesLike::try_from_object(vm, err.get_attr("object", vm)?)?;
1051 let s_encoding = PyUtf8StrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?;
1052 let mut ctx = PyDecodeContext {
1053 vm,
1054 encoding: s_encoding.as_str(),
1055 data: PyDecodeData::Original(s.borrow_buf()),
1056 orig_bytes: s.as_object().downcast_ref(),
1057 pos: 0,
1058 exception: OnceCell::from(err.downcast().unwrap()),
1059 };
1060 let (replace, restart) = handler.handle_decode_error(&mut ctx, range, None)?;
1061 Ok((replace.into(), restart))
1062}
1063
1064fn call_native_translate_error<E>(
1066 handler: E,
1067 err: PyObjectRef,
1068 vm: &VirtualMachine,
1069) -> PyResult<(PyObjectRef, usize)>
1070where
1071 for<'a> E: EncodeErrorHandler<PyEncodeContext<'a>>,
1072{
1073 let range = extract_unicode_error_range(&err, vm)?;
1075 let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
1076 let mut ctx = PyEncodeContext {
1077 vm,
1078 encoding: "",
1079 data: &s,
1080 pos: StrSize::default(),
1081 exception: OnceCell::from(err.downcast().unwrap()),
1082 };
1083 let mut iter = s.as_wtf8().code_point_indices();
1084 let start = StrSize {
1085 chars: range.start,
1086 bytes: iter.nth(range.start).unwrap().0,
1087 };
1088 let end = StrSize {
1089 chars: range.end,
1090 bytes: if let Some(n) = range.len().checked_sub(1) {
1091 iter.nth(n).map_or(s.byte_len(), |(i, _)| i)
1092 } else {
1093 start.bytes
1094 },
1095 };
1096 let (replace, restart) = handler.handle_encode_error(&mut ctx, start..end, None)?;
1097 let replace = match replace {
1098 EncodeReplace::Str(s) => s.into(),
1099 EncodeReplace::Bytes(b) => b.into(),
1100 };
1101 Ok((replace, restart.chars))
1102}
1103
1104fn extract_unicode_error_range(err: &PyObject, vm: &VirtualMachine) -> PyResult<Range<usize>> {
1106 let start = err.get_attr("start", vm)?;
1107 let start = start.try_into_value(vm)?;
1108 let end = err.get_attr("end", vm)?;
1109 let end = end.try_into_value(vm)?;
1110 Ok(Range { start, end })
1111}
1112
1113fn update_unicode_error_attrs(
1114 err: &PyObject,
1115 start: usize,
1116 end: usize,
1117 reason: Option<&str>,
1118 vm: &VirtualMachine,
1119) -> PyResult<()> {
1120 err.set_attr("start", start.to_pyobject(vm), vm)?;
1121 err.set_attr("end", end.to_pyobject(vm), vm)?;
1122 if let Some(reason) = reason {
1123 err.set_attr("reason", reason.to_pyobject(vm), vm)?;
1124 }
1125 Ok(())
1126}
1127
1128#[inline]
1129fn is_encode_err(err: &PyObject, vm: &VirtualMachine) -> bool {
1130 err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error)
1131}
1132#[inline]
1133fn is_decode_err(err: &PyObject, vm: &VirtualMachine) -> bool {
1134 err.fast_isinstance(vm.ctx.exceptions.unicode_decode_error)
1135}
1136#[inline]
1137fn is_translate_err(err: &PyObject, vm: &VirtualMachine) -> bool {
1138 err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error)
1139}
1140
1141fn bad_err_type(err: PyObjectRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
1142 vm.new_type_error(format!(
1143 "don't know how to handle {} in error callback",
1144 err.class().name()
1145 ))
1146}
1147
1148fn strict_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult {
1149 let err = err
1150 .downcast()
1151 .unwrap_or_else(|_| vm.new_type_error("codec must pass exception instance"));
1152 Err(err)
1153}
1154
1155fn ignore_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1156 if is_encode_err(&err, vm) || is_decode_err(&err, vm) || is_translate_err(&err, vm) {
1157 let range = extract_unicode_error_range(&err, vm)?;
1158 Ok((vm.ctx.new_str(ascii!("")).into(), range.end))
1159 } else {
1160 Err(bad_err_type(err, vm))
1161 }
1162}
1163
1164fn replace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1165 if is_encode_err(&err, vm) {
1166 call_native_encode_error(errors::Replace, err, vm)
1167 } else if is_decode_err(&err, vm) {
1168 call_native_decode_error(errors::Replace, err, vm)
1169 } else if is_translate_err(&err, vm) {
1170 let replacement_char = "\u{FFFD}";
1172 let range = extract_unicode_error_range(&err, vm)?;
1173 let replace = replacement_char.repeat(range.end - range.start);
1174 Ok((replace.to_pyobject(vm), range.end))
1175 } else {
1176 Err(bad_err_type(err, vm))
1177 }
1178}
1179
1180fn xmlcharrefreplace_errors(
1181 err: PyObjectRef,
1182 vm: &VirtualMachine,
1183) -> PyResult<(PyObjectRef, usize)> {
1184 if is_encode_err(&err, vm) {
1185 call_native_encode_error(errors::XmlCharRefReplace, err, vm)
1186 } else {
1187 Err(bad_err_type(err, vm))
1188 }
1189}
1190
1191fn backslashreplace_errors(
1192 err: PyObjectRef,
1193 vm: &VirtualMachine,
1194) -> PyResult<(PyObjectRef, usize)> {
1195 if is_decode_err(&err, vm) {
1196 call_native_decode_error(errors::BackslashReplace, err, vm)
1197 } else if is_encode_err(&err, vm) {
1198 call_native_encode_error(errors::BackslashReplace, err, vm)
1199 } else if is_translate_err(&err, vm) {
1200 call_native_translate_error(errors::BackslashReplace, err, vm)
1201 } else {
1202 Err(bad_err_type(err, vm))
1203 }
1204}
1205
1206fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1207 if is_encode_err(&err, vm) {
1208 call_native_encode_error(errors::NameReplace, err, vm)
1209 } else {
1210 Err(bad_err_type(err, vm))
1211 }
1212}
1213
1214fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1215 if is_encode_err(&err, vm) {
1216 call_native_encode_error(SurrogatePass, err, vm)
1217 } else if is_decode_err(&err, vm) {
1218 call_native_decode_error(SurrogatePass, err, vm)
1219 } else {
1220 Err(bad_err_type(err, vm))
1221 }
1222}
1223
1224fn surrogateescape_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1225 if is_encode_err(&err, vm) {
1226 call_native_encode_error(errors::SurrogateEscape, err, vm)
1227 } else if is_decode_err(&err, vm) {
1228 call_native_decode_error(errors::SurrogateEscape, err, vm)
1229 } else {
1230 Err(bad_err_type(err, vm))
1231 }
1232}