Skip to main content

rustpython_vm/stdlib/
marshal.rs

1// spell-checker:ignore pyfrozen pycomplex
2pub(crate) use decl::module_def;
3
4#[pymodule(name = "marshal")]
5mod decl {
6    use crate::builtins::code::{CodeObject, Literal, PyObjBag};
7    use crate::class::StaticType;
8    use crate::common::wtf8::Wtf8;
9    use crate::{
10        PyObjectRef, PyResult, TryFromObject, VirtualMachine,
11        builtins::{
12            PyBool, PyByteArray, PyBytes, PyCode, PyComplex, PyDict, PyEllipsis, PyFloat,
13            PyFrozenSet, PyInt, PyList, PyNone, PySet, PyStopIteration, PyStr, PyTuple,
14        },
15        convert::ToPyObject,
16        function::{ArgBytesLike, OptionalArg},
17        object::{AsObject, PyPayload},
18        protocol::PyBuffer,
19    };
20    use malachite_bigint::BigInt;
21    use num_traits::Zero;
22    use rustpython_compiler_core::marshal;
23
24    #[pyattr(name = "version")]
25    use marshal::FORMAT_VERSION;
26
    /// Error returned by the `marshal::Dumpable` impl below when a value has no marshal
    /// representation (the catch-all arm of `with_dump`).
    pub struct DumpError;

    /// Lets the shared compiler-core marshal serializer inspect Python objects.
    ///
    /// `with_dump` classifies `self` as one of the `marshal::DumpableValue` variants and
    /// hands it to `f`. Borrowed contents (list items, set/frozenset snapshots, collected
    /// dict entries, the bytearray buffer) only live for the duration of the `f` call.
    ///
    /// NOTE(review): `dumps` in this module serializes through `write_object_depth`
    /// rather than this impl; confirm which callers still depend on it before changing it.
    impl marshal::Dumpable for PyObjectRef {
        type Error = DumpError;
        type Constant = Literal;

        fn with_dump<R>(
            &self,
            f: impl FnOnce(marshal::DumpableValue<'_, Self>) -> R,
        ) -> Result<R, Self::Error> {
            use marshal::DumpableValue::*;
            // The `StopIteration` type object itself (compared by identity, not an
            // instance) has its own marshal representation.
            if self.is(PyStopIteration::static_type()) {
                return Ok(f(StopIter));
            }
            let ret = match_class!(match self {
                PyNone => f(None),
                PyEllipsis => f(Ellipsis),
                // `bool` objects are matched by this `PyInt` arm; the exact-class check
                // tells them apart from plain ints.
                ref pyint @ PyInt => {
                    if self.class().is(PyBool::static_type()) {
                        f(Boolean(!pyint.as_bigint().is_zero()))
                    } else {
                        f(Integer(pyint.as_bigint()))
                    }
                }
                ref pyfloat @ PyFloat => {
                    f(Float(pyfloat.to_f64()))
                }
                ref pycomplex @ PyComplex => {
                    f(Complex(pycomplex.to_complex64()))
                }
                ref pystr @ PyStr => {
                    f(Str(pystr.as_wtf8()))
                }
                ref pylist @ PyList => {
                    f(List(&pylist.borrow_vec()))
                }
                // Snapshot the elements so a slice can be lent to `f`.
                ref pyset @ PySet => {
                    let elements = pyset.elements();
                    f(Set(&elements))
                }
                ref pyfrozen @ PyFrozenSet => {
                    let elements = pyfrozen.elements();
                    f(Frozenset(&elements))
                }
                ref pytuple @ PyTuple => {
                    f(Tuple(pytuple.as_slice()))
                }
                // Collect (key, value) pairs into a Vec so a slice can be lent to `f`.
                ref pydict @ PyDict => {
                    let entries = pydict.into_iter().collect::<Vec<_>>();
                    f(Dict(&entries))
                }
                ref bytes @ PyBytes => {
                    f(Bytes(bytes.as_bytes()))
                }
                // bytearray is dumped with the same `Bytes` variant as bytes.
                ref bytes @ PyByteArray => {
                    f(Bytes(&bytes.borrow_buf()))
                }
                ref co @ PyCode => {
                    f(Code(co))
                }
                // Anything else is not marshallable.
                _ => return Err(DumpError),
            });
            Ok(ret)
        }
    }
92
    // Arguments accepted by `marshal.dumps`.
    #[derive(FromArgs)]
    struct DumpsArgs {
        // The object to serialize.
        value: PyObjectRef,
        // Optional format version; `dumps` falls back to `marshal::FORMAT_VERSION`.
        // NOTE(review): the Python-visible keyword name is derived from this field name;
        // confirm whether the leading underscore leaks into it (CPython calls it `version`).
        #[pyarg(any, optional)]
        _version: OptionalArg<i32>,
        // When false, `dumps` rejects values that contain code objects.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
101
102    #[pyfunction]
103    fn dumps(args: DumpsArgs, vm: &VirtualMachine) -> PyResult<PyBytes> {
104        let DumpsArgs {
105            value,
106            allow_code,
107            _version,
108        } = args;
109        let version = _version.unwrap_or(marshal::FORMAT_VERSION as i32);
110        if !allow_code {
111            check_no_code(&value, vm)?;
112        }
113        check_exact_type(&value, vm)?;
114        let mut buf = Vec::new();
115        let mut refs = if version >= 3 {
116            Some(WriterRefTable::new())
117        } else {
118            None
119        };
120        write_object(&mut buf, &value, &mut refs, version, vm)?;
121        Ok(PyBytes::from(buf))
122    }
123
124    struct WriterRefTable {
125        map: std::collections::HashMap<usize, u32>,
126        next_idx: u32,
127    }
128
129    impl WriterRefTable {
130        fn new() -> Self {
131            Self {
132                map: std::collections::HashMap::new(),
133                next_idx: 0,
134            }
135        }
136        fn try_ref(&mut self, buf: &mut Vec<u8>, obj: &PyObjectRef) -> bool {
137            use marshal::Write;
138            let id = obj.get_id();
139            if let Some(&idx) = self.map.get(&id) {
140                buf.write_u8(b'r');
141                buf.write_u32(idx);
142                true
143            } else {
144                false
145            }
146        }
147        fn reserve(&mut self, obj: &PyObjectRef) -> u32 {
148            let idx = self.next_idx;
149            self.map.insert(obj.get_id(), idx);
150            self.next_idx += 1;
151            idx
152        }
153    }
154
155    fn write_object(
156        buf: &mut Vec<u8>,
157        obj: &PyObjectRef,
158        refs: &mut Option<WriterRefTable>,
159        version: i32,
160        vm: &VirtualMachine,
161    ) -> PyResult<()> {
162        write_object_depth(
163            buf,
164            obj,
165            refs,
166            version,
167            vm,
168            marshal::MAX_MARSHAL_STACK_DEPTH,
169        )
170    }
171
172    fn write_object_depth(
173        buf: &mut Vec<u8>,
174        obj: &PyObjectRef,
175        refs: &mut Option<WriterRefTable>,
176        version: i32,
177        vm: &VirtualMachine,
178        depth: usize,
179    ) -> PyResult<()> {
180        use marshal::Write;
181        if depth == 0 {
182            return Err(vm.new_value_error("object too deeply nested to marshal".to_string()));
183        }
184
185        // Singletons: no FLAG_REF needed
186        let is_singleton = vm.is_none(obj)
187            || obj.class().is(PyBool::static_type())
188            || obj.is(PyStopIteration::static_type())
189            || obj.downcast_ref::<crate::builtins::PyEllipsis>().is_some();
190
191        // FLAG_REF: check if already written, otherwise reserve slot
192        if !is_singleton
193            && let Some(rt) = refs.as_mut()
194            && rt.try_ref(buf, obj)
195        {
196            return Ok(());
197        }
198        let type_pos = buf.len();
199        let use_ref = refs.is_some() && !is_singleton;
200        if use_ref {
201            refs.as_mut().unwrap().reserve(obj);
202        }
203
204        if vm.is_none(obj) {
205            buf.write_u8(b'N');
206        } else if obj.is(PyStopIteration::static_type()) {
207            buf.write_u8(b'S');
208        } else if obj.class().is(PyBool::static_type()) {
209            let val = obj
210                .downcast_ref::<PyInt>()
211                .is_some_and(|i| !i.as_bigint().is_zero());
212            buf.write_u8(if val { b'T' } else { b'F' });
213        } else if obj.downcast_ref::<crate::builtins::PyEllipsis>().is_some() {
214            buf.write_u8(b'.');
215        } else if let Some(i) = obj.downcast_ref::<PyInt>() {
216            // TYPE_INT for i32 range, TYPE_LONG for larger
217            if let Ok(val) = i32::try_from(i.as_bigint()) {
218                buf.write_u8(b'i');
219                buf.write_u32(val as u32);
220            } else {
221                buf.write_u8(b'l');
222                let (sign, raw) = i.as_bigint().to_bytes_le();
223                let mut digits = Vec::new();
224                let mut accum: u32 = 0;
225                let mut bits = 0u32;
226                for &byte in &raw {
227                    accum |= (byte as u32) << bits;
228                    bits += 8;
229                    while bits >= 15 {
230                        digits.push((accum & 0x7fff) as u16);
231                        accum >>= 15;
232                        bits -= 15;
233                    }
234                }
235                if accum > 0 || digits.is_empty() {
236                    digits.push(accum as u16);
237                }
238                while digits.len() > 1 && *digits.last().unwrap() == 0 {
239                    digits.pop();
240                }
241                let n = digits.len() as i32;
242                let n = if sign == malachite_bigint::Sign::Minus {
243                    -n
244                } else {
245                    n
246                };
247                buf.write_u32(n as u32);
248                for d in &digits {
249                    buf.write_u16(*d);
250                }
251            }
252        } else if let Some(f) = obj.downcast_ref::<PyFloat>() {
253            buf.write_u8(b'g');
254            buf.write_u64(f.to_f64().to_bits());
255        } else if let Some(c) = obj.downcast_ref::<PyComplex>() {
256            buf.write_u8(b'y');
257            let cv = c.to_complex64();
258            buf.write_u64(cv.re.to_bits());
259            buf.write_u64(cv.im.to_bits());
260        } else if let Some(s) = obj.downcast_ref::<PyStr>() {
261            let bytes = s.as_wtf8().as_bytes();
262            let interned = version >= 3;
263            if bytes.len() < 256 && bytes.is_ascii() {
264                buf.write_u8(if interned { b'Z' } else { b'z' });
265                buf.write_u8(bytes.len() as u8);
266            } else {
267                buf.write_u8(if interned { b't' } else { b'u' });
268                buf.write_u32(bytes.len() as u32);
269            }
270            buf.write_slice(bytes);
271        } else if let Some(b) = obj.downcast_ref::<PyBytes>() {
272            buf.write_u8(b's');
273            let data = b.as_bytes();
274            buf.write_u32(data.len() as u32);
275            buf.write_slice(data);
276        } else if let Some(b) = obj.downcast_ref::<PyByteArray>() {
277            buf.write_u8(b's');
278            let data = b.borrow_buf();
279            buf.write_u32(data.len() as u32);
280            buf.write_slice(&data);
281        } else if let Some(t) = obj.downcast_ref::<PyTuple>() {
282            buf.write_u8(b'(');
283            buf.write_u32(t.len() as u32);
284            for elem in t.as_slice() {
285                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
286            }
287        } else if let Some(l) = obj.downcast_ref::<PyList>() {
288            buf.write_u8(b'[');
289            let items = l.borrow_vec();
290            buf.write_u32(items.len() as u32);
291            for elem in items.iter() {
292                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
293            }
294        } else if let Some(d) = obj.downcast_ref::<PyDict>() {
295            buf.write_u8(b'{');
296            for (k, v) in d.into_iter() {
297                write_object_depth(buf, &k, refs, version, vm, depth - 1)?;
298                write_object_depth(buf, &v, refs, version, vm, depth - 1)?;
299            }
300            buf.write_u8(b'0'); // TYPE_NULL terminator
301        } else if let Some(s) = obj.downcast_ref::<PySet>() {
302            buf.write_u8(b'<');
303            let elems = s.elements();
304            buf.write_u32(elems.len() as u32);
305            for elem in &elems {
306                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
307            }
308        } else if let Some(s) = obj.downcast_ref::<PyFrozenSet>() {
309            buf.write_u8(b'>');
310            let elems = s.elements();
311            buf.write_u32(elems.len() as u32);
312            for elem in &elems {
313                write_object_depth(buf, elem, refs, version, vm, depth - 1)?;
314            }
315        } else if let Some(co) = obj.downcast_ref::<PyCode>() {
316            buf.write_u8(b'c');
317            marshal::serialize_code(buf, &co.code);
318        } else if let Some(sl) = obj.downcast_ref::<crate::builtins::PySlice>() {
319            if version < 5 {
320                return Err(vm.new_value_error("unmarshallable object".to_string()));
321            }
322            buf.write_u8(b':');
323            let none: PyObjectRef = vm.ctx.none();
324            write_object_depth(
325                buf,
326                sl.start.as_ref().unwrap_or(&none),
327                refs,
328                version,
329                vm,
330                depth - 1,
331            )?;
332            write_object_depth(buf, &sl.stop, refs, version, vm, depth - 1)?;
333            write_object_depth(
334                buf,
335                sl.step.as_ref().unwrap_or(&none),
336                refs,
337                version,
338                vm,
339                depth - 1,
340            )?;
341        } else if let Ok(bytes_like) = ArgBytesLike::try_from_object(vm, obj.clone()) {
342            buf.write_u8(b's');
343            let data = bytes_like.borrow_buf();
344            buf.write_u32(data.len() as u32);
345            buf.write_slice(&data);
346        } else {
347            return Err(vm.new_value_error("unmarshallable object".to_string()));
348        }
349
350        if use_ref {
351            buf[type_pos] |= marshal::FLAG_REF;
352        }
353        Ok(())
354    }
355
    // Arguments accepted by `marshal.dump`.
    #[derive(FromArgs)]
    struct DumpArgs {
        // The object to serialize.
        value: PyObjectRef,
        // File-like object; `dump` passes the serialized bytes to its `write` method.
        f: PyObjectRef,
        // Optional format version, forwarded unchanged to `dumps`.
        // NOTE(review): the Python-visible keyword name is derived from this field name;
        // confirm whether the leading underscore leaks into it (CPython calls it `version`).
        #[pyarg(any, optional)]
        _version: OptionalArg<i32>,
        // When false, values containing code objects are rejected (forwarded to `dumps`).
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
365
366    #[pyfunction]
367    fn dump(args: DumpArgs, vm: &VirtualMachine) -> PyResult<()> {
368        let dumped = dumps(
369            DumpsArgs {
370                value: args.value,
371                _version: args._version,
372                allow_code: args.allow_code,
373            },
374            vm,
375        )?;
376        vm.call_method(&args.f, "write", (dumped,))?;
377        Ok(())
378    }
379
380    #[derive(Copy, Clone)]
381    struct PyMarshalBag<'a>(&'a VirtualMachine);
382
383    impl<'a> marshal::MarshalBag for PyMarshalBag<'a> {
384        type Value = PyObjectRef;
385        type ConstantBag = PyObjBag<'a>;
386
387        fn make_bool(&self, value: bool) -> Self::Value {
388            self.0.ctx.new_bool(value).into()
389        }
390        fn make_none(&self) -> Self::Value {
391            self.0.ctx.none()
392        }
393        fn make_ellipsis(&self) -> Self::Value {
394            self.0.ctx.ellipsis.clone().into()
395        }
396        fn make_float(&self, value: f64) -> Self::Value {
397            self.0.ctx.new_float(value).into()
398        }
399        fn make_complex(&self, value: num_complex::Complex64) -> Self::Value {
400            self.0.ctx.new_complex(value).into()
401        }
402        fn make_str(&self, value: &Wtf8) -> Self::Value {
403            self.0.ctx.new_str(value).into()
404        }
405        fn make_bytes(&self, value: &[u8]) -> Self::Value {
406            self.0.ctx.new_bytes(value.to_vec()).into()
407        }
408        fn make_int(&self, value: BigInt) -> Self::Value {
409            self.0.ctx.new_int(value).into()
410        }
411        fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value {
412            self.0.ctx.new_tuple(elements.collect()).into()
413        }
414        fn make_code(&self, code: CodeObject) -> Self::Value {
415            self.0.ctx.new_code(code).into()
416        }
417        fn make_stop_iter(&self) -> Result<Self::Value, marshal::MarshalError> {
418            Ok(self.0.ctx.exceptions.stop_iteration.to_owned().into())
419        }
420        fn make_list(
421            &self,
422            it: impl Iterator<Item = Self::Value>,
423        ) -> Result<Self::Value, marshal::MarshalError> {
424            Ok(self.0.ctx.new_list(it.collect()).into())
425        }
426        fn make_set(
427            &self,
428            it: impl Iterator<Item = Self::Value>,
429        ) -> Result<Self::Value, marshal::MarshalError> {
430            let set = PySet::default().into_ref(&self.0.ctx);
431            for elem in it {
432                set.add(elem, self.0).unwrap()
433            }
434            Ok(set.into())
435        }
436        fn make_frozenset(
437            &self,
438            it: impl Iterator<Item = Self::Value>,
439        ) -> Result<Self::Value, marshal::MarshalError> {
440            Ok(PyFrozenSet::from_iter(self.0, it)
441                .unwrap()
442                .to_pyobject(self.0))
443        }
444        fn make_dict(
445            &self,
446            it: impl Iterator<Item = (Self::Value, Self::Value)>,
447        ) -> Result<Self::Value, marshal::MarshalError> {
448            let dict = self.0.ctx.new_dict();
449            for (k, v) in it {
450                dict.set_item(&*k, v, self.0).unwrap()
451            }
452            Ok(dict.into())
453        }
454        fn make_slice(
455            &self,
456            start: Self::Value,
457            stop: Self::Value,
458            step: Self::Value,
459        ) -> Result<Self::Value, marshal::MarshalError> {
460            use crate::builtins::PySlice;
461            let vm = self.0;
462            Ok(PySlice {
463                start: if vm.is_none(&start) {
464                    None
465                } else {
466                    Some(start)
467                },
468                stop,
469                step: if vm.is_none(&step) { None } else { Some(step) },
470            }
471            .into_ref(&vm.ctx)
472            .into())
473        }
474        fn constant_bag(self) -> Self::ConstantBag {
475            PyObjBag(&self.0.ctx)
476        }
477    }
478
    // Arguments accepted by `marshal.loads`.
    #[derive(FromArgs)]
    struct LoadsArgs {
        // Serialized data, as any buffer-protocol object; `loads` requires it to be
        // contiguous.
        #[pyarg(any)]
        data: PyBuffer,
        // When false, `loads` rejects results that contain code objects.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
486
487    #[pyfunction]
488    fn loads(args: LoadsArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
489        let LoadsArgs {
490            data: pybuffer,
491            allow_code,
492        } = args;
493        let buf = pybuffer.as_contiguous().ok_or_else(|| {
494            vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous")
495        })?;
496
497        let result =
498            marshal::deserialize_value(&mut &buf[..], PyMarshalBag(vm)).map_err(|e| match e {
499                marshal::MarshalError::Eof => vm.new_exception_msg(
500                    vm.ctx.exceptions.eof_error.to_owned(),
501                    "marshal data too short".into(),
502                ),
503                _ => vm.new_value_error("bad marshal data"),
504            })?;
505        if !allow_code {
506            check_no_code(&result, vm)?;
507        }
508        Ok(result)
509    }
510
    // Arguments accepted by `marshal.load`.
    #[derive(FromArgs)]
    struct LoadArgs {
        // Binary file-like object; `load` calls its `tell`, `read` and `seek` methods.
        f: PyObjectRef,
        // When false, `load` rejects results that contain code objects.
        #[pyarg(named, default = true)]
        allow_code: bool,
    }
517
518    #[pyfunction]
519    fn load(args: LoadArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
520        // Read from file object into a buffer, one object at a time.
521        // We read all available data, deserialize one object, then seek
522        // back to just after the consumed bytes.
523        let tell_before = vm
524            .call_method(&args.f, "tell", ())?
525            .try_into_value::<i64>(vm)?;
526        let read_res = vm.call_method(&args.f, "read", ())?;
527        let bytes = ArgBytesLike::try_from_object(vm, read_res)?;
528        let buf = bytes.borrow_buf();
529
530        let mut rdr: &[u8] = &buf;
531        let len_before = rdr.len();
532        let result =
533            marshal::deserialize_value(&mut rdr, PyMarshalBag(vm)).map_err(|e| match e {
534                marshal::MarshalError::Eof => vm.new_exception_msg(
535                    vm.ctx.exceptions.eof_error.to_owned(),
536                    "marshal data too short".into(),
537                ),
538                _ => vm.new_value_error("bad marshal data"),
539            })?;
540        let consumed = len_before - rdr.len();
541
542        // Seek file to just after the consumed bytes
543        let new_pos = tell_before + consumed as i64;
544        vm.call_method(&args.f, "seek", (new_pos,))?;
545
546        if !args.allow_code {
547            check_no_code(&result, vm)?;
548        }
549        Ok(result)
550    }
551
552    /// Reject subclasses of marshallable types (int, float, complex, tuple, etc.).
553    /// Recursively check that no code objects are present.
554    fn check_no_code(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
555        if obj.downcast_ref::<PyCode>().is_some() {
556            return Err(vm.new_value_error("unmarshalling code objects is disallowed".to_string()));
557        }
558        if let Some(tup) = obj.downcast_ref::<PyTuple>() {
559            for elem in tup.as_slice() {
560                check_no_code(elem, vm)?;
561            }
562        } else if let Some(list) = obj.downcast_ref::<PyList>() {
563            for elem in list.borrow_vec().iter() {
564                check_no_code(elem, vm)?;
565            }
566        } else if let Some(set) = obj.downcast_ref::<PySet>() {
567            for elem in set.elements() {
568                check_no_code(&elem, vm)?;
569            }
570        } else if let Some(fset) = obj.downcast_ref::<PyFrozenSet>() {
571            for elem in fset.elements() {
572                check_no_code(&elem, vm)?;
573            }
574        } else if let Some(dict) = obj.downcast_ref::<PyDict>() {
575            for (k, v) in dict.into_iter() {
576                check_no_code(&k, vm)?;
577                check_no_code(&v, vm)?;
578            }
579        }
580        Ok(())
581    }
582
583    fn check_exact_type(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
584        let cls = obj.class();
585        // bool is a subclass of int but is marshallable
586        if cls.is(PyBool::static_type()) {
587            return Ok(());
588        }
589        for base in [
590            PyInt::static_type(),
591            PyFloat::static_type(),
592            PyComplex::static_type(),
593            PyTuple::static_type(),
594            PyList::static_type(),
595            PyDict::static_type(),
596            PySet::static_type(),
597            PyFrozenSet::static_type(),
598        ] {
599            if cls.fast_issubclass(base) && !cls.is(base) {
600                return Err(vm.new_value_error("unmarshallable object".to_string()));
601            }
602        }
603        Ok(())
604    }
605}