rustpython_vm/stdlib/
codecs.rs

1pub(crate) use _codecs::make_module;
2
3#[pymodule]
4mod _codecs {
5    use crate::common::encodings;
6    use crate::{
7        builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple},
8        codecs,
9        function::{ArgBytesLike, FuncArgs},
10        AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine,
11    };
12    use std::ops::Range;
13
14    #[pyfunction]
15    fn register(search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
16        vm.state.codec_registry.register(search_function, vm)
17    }
18
19    #[pyfunction]
20    fn unregister(search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
21        vm.state.codec_registry.unregister(search_function)
22    }
23
24    #[pyfunction]
25    fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult {
26        vm.state
27            .codec_registry
28            .lookup(encoding.as_str(), vm)
29            .map(|codec| codec.into_tuple().into())
30    }
31
    // Arguments shared by the generic `encode` and `decode` entry points.
    #[derive(FromArgs)]
    struct CodeArgs {
        // The object to be encoded or decoded.
        obj: PyObjectRef,
        // Codec name; when omitted the callers fall back to `codecs::DEFAULT_ENCODING`.
        #[pyarg(any, optional)]
        encoding: Option<PyStrRef>,
        // Error-handling scheme name; forwarded to the registry as-is (possibly `None`).
        #[pyarg(any, optional)]
        errors: Option<PyStrRef>,
    }
40
41    #[pyfunction]
42    fn encode(args: CodeArgs, vm: &VirtualMachine) -> PyResult {
43        let encoding = args
44            .encoding
45            .as_ref()
46            .map_or(codecs::DEFAULT_ENCODING, |s| s.as_str());
47        vm.state
48            .codec_registry
49            .encode(args.obj, encoding, args.errors, vm)
50    }
51
52    #[pyfunction]
53    fn decode(args: CodeArgs, vm: &VirtualMachine) -> PyResult {
54        let encoding = args
55            .encoding
56            .as_ref()
57            .map_or(codecs::DEFAULT_ENCODING, |s| s.as_str());
58        vm.state
59            .codec_registry
60            .decode(args.obj, encoding, args.errors, vm)
61    }
62
63    #[pyfunction]
64    fn _forget_codec(encoding: PyStrRef, vm: &VirtualMachine) {
65        vm.state.codec_registry.forget(encoding.as_str());
66    }
67
68    #[pyfunction]
69    fn register_error(name: PyStrRef, handler: PyObjectRef, vm: &VirtualMachine) {
70        vm.state
71            .codec_registry
72            .register_error(name.as_str().to_owned(), handler);
73    }
74
75    #[pyfunction]
76    fn lookup_error(name: PyStrRef, vm: &VirtualMachine) -> PyResult {
77        vm.state.codec_registry.lookup_error(name.as_str(), vm)
78    }
79
    // Bridges the Rust codec implementations in `encodings` to Python-level
    // error-handling callbacks.
    struct ErrorsHandler<'a> {
        // VM used to build exceptions and to call the Python handler.
        vm: &'a VirtualMachine,
        // Codec name placed in the Unicode error objects that get raised.
        encoding: &'a str,
        // Requested error scheme name; `None` is treated as "strict" on lookup.
        errors: Option<PyStrRef>,
        // Handler callable, looked up at most once and then reused.
        handler: once_cell::unsync::OnceCell<PyObjectRef>,
    }
86    impl<'a> ErrorsHandler<'a> {
87        #[inline]
88        fn new(encoding: &'a str, errors: Option<PyStrRef>, vm: &'a VirtualMachine) -> Self {
89            ErrorsHandler {
90                vm,
91                encoding,
92                errors,
93                handler: Default::default(),
94            }
95        }
96        #[inline]
97        fn handler_func(&self) -> PyResult<&PyObject> {
98            let vm = self.vm;
99            Ok(self.handler.get_or_try_init(|| {
100                let errors = self.errors.as_ref().map_or("strict", |s| s.as_str());
101                vm.state.codec_registry.lookup_error(errors, vm)
102            })?)
103        }
104    }
105    impl encodings::StrBuffer for PyStrRef {
106        fn is_ascii(&self) -> bool {
107            PyStr::is_ascii(self)
108        }
109    }
110    impl<'vm> encodings::ErrorHandler for ErrorsHandler<'vm> {
111        type Error = PyBaseExceptionRef;
112        type StrBuf = PyStrRef;
113        type BytesBuf = PyBytesRef;
114
115        fn handle_encode_error(
116            &self,
117            data: &str,
118            char_range: Range<usize>,
119            reason: &str,
120        ) -> PyResult<(encodings::EncodeReplace<PyStrRef, PyBytesRef>, usize)> {
121            let vm = self.vm;
122            let data_str = vm.ctx.new_str(data).into();
123            let encode_exc = vm.new_exception(
124                vm.ctx.exceptions.unicode_encode_error.to_owned(),
125                vec![
126                    vm.ctx.new_str(self.encoding).into(),
127                    data_str,
128                    vm.ctx.new_int(char_range.start).into(),
129                    vm.ctx.new_int(char_range.end).into(),
130                    vm.ctx.new_str(reason).into(),
131                ],
132            );
133            let res = self.handler_func()?.call((encode_exc,), vm)?;
134            let tuple_err = || {
135                vm.new_type_error(
136                    "encoding error handler must return (str/bytes, int) tuple".to_owned(),
137                )
138            };
139            let (replace, restart) = match res.payload::<PyTuple>().map(|tup| tup.as_slice()) {
140                Some([replace, restart]) => (replace.clone(), restart),
141                _ => return Err(tuple_err()),
142            };
143            let replace = match_class!(match replace {
144                s @ PyStr => encodings::EncodeReplace::Str(s),
145                b @ PyBytes => encodings::EncodeReplace::Bytes(b),
146                _ => return Err(tuple_err()),
147            });
148            let restart = isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
149            let restart = if restart < 0 {
150                // will still be out of bounds if it underflows ¯\_(ツ)_/¯
151                data.len().wrapping_sub(restart.unsigned_abs())
152            } else {
153                restart as usize
154            };
155            Ok((replace, restart))
156        }
157
158        fn handle_decode_error(
159            &self,
160            data: &[u8],
161            byte_range: Range<usize>,
162            reason: &str,
163        ) -> PyResult<(PyStrRef, Option<PyBytesRef>, usize)> {
164            let vm = self.vm;
165            let data_bytes: PyObjectRef = vm.ctx.new_bytes(data.to_vec()).into();
166            let decode_exc = vm.new_exception(
167                vm.ctx.exceptions.unicode_decode_error.to_owned(),
168                vec![
169                    vm.ctx.new_str(self.encoding).into(),
170                    data_bytes.clone(),
171                    vm.ctx.new_int(byte_range.start).into(),
172                    vm.ctx.new_int(byte_range.end).into(),
173                    vm.ctx.new_str(reason).into(),
174                ],
175            );
176            let handler = self.handler_func()?;
177            let res = handler.call((decode_exc.clone(),), vm)?;
178            let new_data = decode_exc
179                .get_arg(1)
180                .ok_or_else(|| vm.new_type_error("object attribute not set".to_owned()))?;
181            let new_data = if new_data.is(&data_bytes) {
182                None
183            } else {
184                let new_data: PyBytesRef = new_data
185                    .downcast()
186                    .map_err(|_| vm.new_type_error("object attribute must be bytes".to_owned()))?;
187                Some(new_data)
188            };
189            let data = new_data.as_ref().map_or(data, |s| s.as_ref());
190            let tuple_err = || {
191                vm.new_type_error("decoding error handler must return (str, int) tuple".to_owned())
192            };
193            match res.payload::<PyTuple>().map(|tup| tup.as_slice()) {
194                Some([replace, restart]) => {
195                    let replace = replace
196                        .downcast_ref::<PyStr>()
197                        .ok_or_else(tuple_err)?
198                        .to_owned();
199                    let restart =
200                        isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
201                    let restart = if restart < 0 {
202                        // will still be out of bounds if it underflows ¯\_(ツ)_/¯
203                        data.len().wrapping_sub(restart.unsigned_abs())
204                    } else {
205                        restart as usize
206                    };
207                    Ok((replace, new_data, restart))
208                }
209                _ => Err(tuple_err()),
210            }
211        }
212
213        fn error_oob_restart(&self, i: usize) -> PyBaseExceptionRef {
214            self.vm
215                .new_index_error(format!("position {i} from error handler out of bounds"))
216        }
217
218        fn error_encoding(
219            &self,
220            data: &str,
221            char_range: Range<usize>,
222            reason: &str,
223        ) -> Self::Error {
224            let vm = self.vm;
225            vm.new_exception(
226                vm.ctx.exceptions.unicode_encode_error.to_owned(),
227                vec![
228                    vm.ctx.new_str(self.encoding).into(),
229                    vm.ctx.new_str(data).into(),
230                    vm.ctx.new_int(char_range.start).into(),
231                    vm.ctx.new_int(char_range.end).into(),
232                    vm.ctx.new_str(reason).into(),
233                ],
234            )
235        }
236    }
237
    // Result of a codec encode call: the encoded bytes plus a length figure
    // (the callers in this module pass the input string's length).
    type EncodeResult = PyResult<(Vec<u8>, usize)>;
239
    // Positional arguments accepted by the `*_encode` functions implemented in Rust.
    #[derive(FromArgs)]
    struct EncodeArgs {
        // Text to encode.
        #[pyarg(positional)]
        s: PyStrRef,
        // Optional error scheme name, passed on to `ErrorsHandler`.
        #[pyarg(positional, optional)]
        errors: Option<PyStrRef>,
    }
247
248    impl EncodeArgs {
249        #[inline]
250        fn encode<'a, F>(self, name: &'a str, encode: F, vm: &'a VirtualMachine) -> EncodeResult
251        where
252            F: FnOnce(&str, &ErrorsHandler<'a>) -> PyResult<Vec<u8>>,
253        {
254            let errors = ErrorsHandler::new(name, self.errors, vm);
255            let encoded = encode(self.s.as_str(), &errors)?;
256            Ok((encoded, self.s.char_len()))
257        }
258    }
259
    // Result of a codec decode call: the decoded text plus a `usize` produced
    // by the underlying decoder (presumably the number of bytes consumed — confirm
    // against `encodings`).
    type DecodeResult = PyResult<(String, usize)>;
261
    // Positional arguments for decoders that take a "final" flag.
    #[derive(FromArgs)]
    struct DecodeArgs {
        // Bytes-like input to decode.
        #[pyarg(positional)]
        data: ArgBytesLike,
        // Optional error scheme name, passed on to `ErrorsHandler`.
        #[pyarg(positional, optional)]
        errors: Option<PyStrRef>,
        // Forwarded to the decoder as its third argument; defaults to `false`.
        #[pyarg(positional, default = "false")]
        final_decode: bool,
    }
271
272    impl DecodeArgs {
273        #[inline]
274        fn decode<'a, F>(self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
275        where
276            F: FnOnce(&[u8], &ErrorsHandler<'a>, bool) -> DecodeResult,
277        {
278            let data = self.data.borrow_buf();
279            let errors = ErrorsHandler::new(name, self.errors, vm);
280            decode(&data, &errors, self.final_decode)
281        }
282    }
283
    // Positional arguments for decoders that do not take a "final" flag.
    #[derive(FromArgs)]
    struct DecodeArgsNoFinal {
        // Bytes-like input to decode.
        #[pyarg(positional)]
        data: ArgBytesLike,
        // Optional error scheme name, passed on to `ErrorsHandler`.
        #[pyarg(positional, optional)]
        errors: Option<PyStrRef>,
    }
291
292    impl DecodeArgsNoFinal {
293        #[inline]
294        fn decode<'a, F>(self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
295        where
296            F: FnOnce(&[u8], &ErrorsHandler<'a>) -> DecodeResult,
297        {
298            let data = self.data.borrow_buf();
299            let errors = ErrorsHandler::new(name, self.errors, vm);
300            decode(&data, &errors)
301        }
302    }
303
304    macro_rules! do_codec {
305        ($module:ident :: $func:ident, $args: expr, $vm:expr) => {{
306            use encodings::$module as codec;
307            $args.$func(codec::ENCODING_NAME, codec::$func, $vm)
308        }};
309    }
310
    // UTF-8 encoder: runs `encodings::utf8::encode` through `EncodeArgs::encode`.
    #[pyfunction]
    fn utf_8_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult {
        do_codec!(utf8::encode, args, vm)
    }
315
    // UTF-8 decoder: runs `encodings::utf8::decode` through `DecodeArgs::decode`,
    // which also forwards the `final_decode` flag.
    #[pyfunction]
    fn utf_8_decode(args: DecodeArgs, vm: &VirtualMachine) -> DecodeResult {
        do_codec!(utf8::decode, args, vm)
    }
320
321    #[pyfunction]
322    fn latin_1_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult {
323        if args.s.is_ascii() {
324            return Ok((args.s.as_str().as_bytes().to_vec(), args.s.byte_len()));
325        }
326        do_codec!(latin_1::encode, args, vm)
327    }
328
    // Latin-1 decoder: runs `encodings::latin_1::decode` through
    // `DecodeArgsNoFinal::decode`.
    #[pyfunction]
    fn latin_1_decode(args: DecodeArgsNoFinal, vm: &VirtualMachine) -> DecodeResult {
        do_codec!(latin_1::decode, args, vm)
    }
333
334    #[pyfunction]
335    fn ascii_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult {
336        if args.s.is_ascii() {
337            return Ok((args.s.as_str().as_bytes().to_vec(), args.s.byte_len()));
338        }
339        do_codec!(ascii::encode, args, vm)
340    }
341
    // ASCII decoder: runs `encodings::ascii::decode` through
    // `DecodeArgsNoFinal::decode`.
    #[pyfunction]
    fn ascii_decode(args: DecodeArgsNoFinal, vm: &VirtualMachine) -> DecodeResult {
        do_codec!(ascii::decode, args, vm)
    }
346
347    // TODO: implement these codecs in Rust!
348
349    use crate::common::static_cell::StaticCell;
350    #[inline]
351    fn delegate_pycodecs(
352        cell: &'static StaticCell<PyObjectRef>,
353        name: &'static str,
354        args: FuncArgs,
355        vm: &VirtualMachine,
356    ) -> PyResult {
357        let f = cell.get_or_try_init(|| {
358            let module = vm.import("_pycodecs", 0)?;
359            module.get_attr(name, vm)
360        })?;
361        f.call(args, vm)
362    }
    // `delegate_pycodecs!(name, args, vm)` declares a static cell local to the
    // expansion and calls the `delegate_pycodecs` function with it, using the
    // stringified `name` as the attribute to fetch from `_pycodecs`.
    macro_rules! delegate_pycodecs {
        ($name:ident, $args:ident, $vm:ident) => {{
            rustpython_common::static_cell!(
                static FUNC: PyObjectRef;
            );
            delegate_pycodecs(&FUNC, stringify!($name), $args, $vm)
        }};
    }
371
    // Forwards all arguments to `_pycodecs.mbcs_encode`.
    #[pyfunction]
    fn mbcs_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(mbcs_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.mbcs_decode`.
    #[pyfunction]
    fn mbcs_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(mbcs_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.readbuffer_encode`.
    #[pyfunction]
    fn readbuffer_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(readbuffer_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.escape_encode`.
    #[pyfunction]
    fn escape_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(escape_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.escape_decode`.
    #[pyfunction]
    fn escape_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(escape_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.unicode_escape_encode`.
    #[pyfunction]
    fn unicode_escape_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(unicode_escape_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.unicode_escape_decode`.
    #[pyfunction]
    fn unicode_escape_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(unicode_escape_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.raw_unicode_escape_encode`.
    #[pyfunction]
    fn raw_unicode_escape_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(raw_unicode_escape_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.raw_unicode_escape_decode`.
    #[pyfunction]
    fn raw_unicode_escape_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(raw_unicode_escape_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_7_encode`.
    #[pyfunction]
    fn utf_7_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_7_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_7_decode`.
    #[pyfunction]
    fn utf_7_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_7_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_encode`.
    #[pyfunction]
    fn utf_16_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_decode`.
    #[pyfunction]
    fn utf_16_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.charmap_encode`.
    #[pyfunction]
    fn charmap_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(charmap_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.charmap_decode`.
    #[pyfunction]
    fn charmap_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(charmap_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.charmap_build`.
    #[pyfunction]
    fn charmap_build(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(charmap_build, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_le_encode`.
    #[pyfunction]
    fn utf_16_le_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_le_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_le_decode`.
    #[pyfunction]
    fn utf_16_le_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_le_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_be_encode`.
    #[pyfunction]
    fn utf_16_be_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_be_encode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_be_decode`.
    #[pyfunction]
    fn utf_16_be_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_be_decode, args, vm)
    }
    // Forwards all arguments to `_pycodecs.utf_16_ex_decode`.
    #[pyfunction]
    fn utf_16_ex_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        delegate_pycodecs!(utf_16_ex_decode, args, vm)
    }
456    // TODO: utf-32 functions
457}