// rustpython_compiler_core/marshal.rs

1use crate::bytecode::*;
2use malachite_bigint::{BigInt, Sign};
3use num_complex::Complex64;
4use rustpython_parser_core::source_code::{OneIndexed, SourceLocation};
5use std::convert::Infallible;
6
/// Version tag of this module's marshal format.
///
/// NOTE(review): not referenced inside this file; presumably callers embed or check it
/// alongside serialized data — confirm before relying on it.
pub const FORMAT_VERSION: u32 = 4;
8
/// Errors reported while reading and decoding marshaled data.
#[derive(Debug)]
pub enum MarshalError {
    /// The input ended before the requested bytes could be read.
    Eof,
    /// The bytecode is malformed.
    InvalidBytecode,
    /// A string payload was not valid UTF-8.
    InvalidUtf8,
    /// A source location could not be decoded.
    InvalidLocation,
    /// A type marker byte was unknown or not accepted.
    BadType,
}

impl std::fmt::Display for MarshalError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            MarshalError::Eof => "unexpected end of data",
            MarshalError::InvalidBytecode => "invalid bytecode",
            MarshalError::InvalidUtf8 => "invalid utf8",
            MarshalError::InvalidLocation => "invalid source location",
            MarshalError::BadType => "bad type marker",
        };
        f.write_str(message)
    }
}

impl From<std::str::Utf8Error> for MarshalError {
    /// Every UTF-8 failure collapses to [`MarshalError::InvalidUtf8`]; the position
    /// details carried by the source error are discarded.
    fn from(_err: std::str::Utf8Error) -> Self {
        MarshalError::InvalidUtf8
    }
}

impl std::error::Error for MarshalError {}

/// Module-local result alias whose error type defaults to [`MarshalError`].
type Result<T, E = MarshalError> = std::result::Result<T, E>;
44
/// One-byte markers that prefix every marshaled value.
///
/// Each variant's discriminant is the marker byte itself. These markers are not
/// supported here: `'0'` (null), `'l'` (32-bit long), `'t'` (interned), `'r'` (ref),
/// `'?'` (unknown), `'A'` (interned ascii), `')'` (small tuple), `'z'` (short ascii),
/// `'Z'` (short interned ascii). The `0x80` reference flag is not supported either.
#[repr(u8)]
enum Type {
    None = b'N',
    False = b'F',
    True = b'T',
    StopIter = b'S',
    Ellipsis = b'.',
    Int = b'i',
    Float = b'g',
    Complex = b'y',
    /// Also known as `TYPE_STRING`.
    Bytes = b's',
    Tuple = b'(',
    List = b'[',
    Dict = b'{',
    Code = b'c',
    Unicode = b'u',
    Set = b'<',
    FrozenSet = b'>',
    Ascii = b'a',
}
75
76impl TryFrom<u8> for Type {
77    type Error = MarshalError;
78    fn try_from(value: u8) -> Result<Self> {
79        use Type::*;
80        Ok(match value {
81            // b'0' => Null,
82            b'N' => None,
83            b'F' => False,
84            b'T' => True,
85            b'S' => StopIter,
86            b'.' => Ellipsis,
87            b'i' => Int,
88            b'g' => Float,
89            b'y' => Complex,
90            // b'l' => Long,
91            b's' => Bytes,
92            // b't' => Interned,
93            // b'r' => Ref,
94            b'(' => Tuple,
95            b'[' => List,
96            b'{' => Dict,
97            b'c' => Code,
98            b'u' => Unicode,
99            // b'?' => Unknown,
100            b'<' => Set,
101            b'>' => FrozenSet,
102            b'a' => Ascii,
103            // b'A' => AsciiInterned,
104            // b')' => SmallTuple,
105            // b'z' => ShortAscii,
106            // b'Z' => ShortAsciiInterned,
107            _ => return Err(MarshalError::BadType),
108        })
109    }
110}
111
112pub trait Read {
113    fn read_slice(&mut self, n: u32) -> Result<&[u8]>;
114    fn read_array<const N: usize>(&mut self) -> Result<&[u8; N]> {
115        self.read_slice(N as u32).map(|s| s.try_into().unwrap())
116    }
117    fn read_str(&mut self, len: u32) -> Result<&str> {
118        Ok(std::str::from_utf8(self.read_slice(len)?)?)
119    }
120    fn read_u8(&mut self) -> Result<u8> {
121        Ok(u8::from_le_bytes(*self.read_array()?))
122    }
123    fn read_u16(&mut self) -> Result<u16> {
124        Ok(u16::from_le_bytes(*self.read_array()?))
125    }
126    fn read_u32(&mut self) -> Result<u32> {
127        Ok(u32::from_le_bytes(*self.read_array()?))
128    }
129    fn read_u64(&mut self) -> Result<u64> {
130        Ok(u64::from_le_bytes(*self.read_array()?))
131    }
132}
133
134pub(crate) trait ReadBorrowed<'a>: Read {
135    fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]>;
136    fn read_str_borrow(&mut self, len: u32) -> Result<&'a str> {
137        Ok(std::str::from_utf8(self.read_slice_borrow(len)?)?)
138    }
139}
140
141impl Read for &[u8] {
142    fn read_slice(&mut self, n: u32) -> Result<&[u8]> {
143        self.read_slice_borrow(n)
144    }
145}
146
147impl<'a> ReadBorrowed<'a> for &'a [u8] {
148    fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]> {
149        let data = self.get(..n as usize).ok_or(MarshalError::Eof)?;
150        *self = &self[n as usize..];
151        Ok(data)
152    }
153}
154
/// A reader over a byte buffer `B` that tracks how far it has read.
pub struct Cursor<B> {
    /// The underlying bytes (anything exposing `&[u8]` when read).
    pub data: B,
    /// Offset in `data` of the next byte to read.
    pub position: usize,
}
159
160impl<B: AsRef<[u8]>> Read for Cursor<B> {
161    fn read_slice(&mut self, n: u32) -> Result<&[u8]> {
162        let data = &self.data.as_ref()[self.position..];
163        let slice = data.get(..n as usize).ok_or(MarshalError::Eof)?;
164        self.position += n as usize;
165        Ok(slice)
166    }
167}
168
169pub fn deserialize_code<R: Read, Bag: ConstantBag>(
170    rdr: &mut R,
171    bag: Bag,
172) -> Result<CodeObject<Bag::Constant>> {
173    let len = rdr.read_u32()?;
174    let instructions = rdr.read_slice(len * 2)?;
175    let instructions = instructions
176        .chunks_exact(2)
177        .map(|cu| {
178            let op = Instruction::try_from(cu[0])?;
179            let arg = OpArgByte(cu[1]);
180            Ok(CodeUnit { op, arg })
181        })
182        .collect::<Result<Box<[CodeUnit]>>>()?;
183
184    let len = rdr.read_u32()?;
185    let locations = (0..len)
186        .map(|_| {
187            Ok(SourceLocation {
188                row: OneIndexed::new(rdr.read_u32()?).ok_or(MarshalError::InvalidLocation)?,
189                column: OneIndexed::from_zero_indexed(rdr.read_u32()?),
190            })
191        })
192        .collect::<Result<Box<[SourceLocation]>>>()?;
193
194    let flags = CodeFlags::from_bits_truncate(rdr.read_u16()?);
195
196    let posonlyarg_count = rdr.read_u32()?;
197    let arg_count = rdr.read_u32()?;
198    let kwonlyarg_count = rdr.read_u32()?;
199
200    let len = rdr.read_u32()?;
201    let source_path = bag.make_name(rdr.read_str(len)?);
202
203    let first_line_number = OneIndexed::new(rdr.read_u32()?);
204    let max_stackdepth = rdr.read_u32()?;
205
206    let len = rdr.read_u32()?;
207    let obj_name = bag.make_name(rdr.read_str(len)?);
208
209    let len = rdr.read_u32()?;
210    let cell2arg = (len != 0)
211        .then(|| {
212            (0..len)
213                .map(|_| Ok(rdr.read_u32()? as i32))
214                .collect::<Result<Box<[i32]>>>()
215        })
216        .transpose()?;
217
218    let len = rdr.read_u32()?;
219    let constants = (0..len)
220        .map(|_| deserialize_value(rdr, bag))
221        .collect::<Result<Box<[_]>>>()?;
222
223    let mut read_names = || {
224        let len = rdr.read_u32()?;
225        (0..len)
226            .map(|_| {
227                let len = rdr.read_u32()?;
228                Ok(bag.make_name(rdr.read_str(len)?))
229            })
230            .collect::<Result<Box<[_]>>>()
231    };
232
233    let names = read_names()?;
234    let varnames = read_names()?;
235    let cellvars = read_names()?;
236    let freevars = read_names()?;
237
238    Ok(CodeObject {
239        instructions,
240        locations,
241        flags,
242        posonlyarg_count,
243        arg_count,
244        kwonlyarg_count,
245        source_path,
246        first_line_number,
247        max_stackdepth,
248        obj_name,
249        cell2arg,
250        constants,
251        names,
252        varnames,
253        cellvars,
254        freevars,
255    })
256}
257
258pub trait MarshalBag: Copy {
259    type Value;
260    fn make_bool(&self, value: bool) -> Self::Value;
261    fn make_none(&self) -> Self::Value;
262    fn make_ellipsis(&self) -> Self::Value;
263    fn make_float(&self, value: f64) -> Self::Value;
264    fn make_complex(&self, value: Complex64) -> Self::Value;
265    fn make_str(&self, value: &str) -> Self::Value;
266    fn make_bytes(&self, value: &[u8]) -> Self::Value;
267    fn make_int(&self, value: BigInt) -> Self::Value;
268    fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value;
269    fn make_code(
270        &self,
271        code: CodeObject<<Self::ConstantBag as ConstantBag>::Constant>,
272    ) -> Self::Value;
273    fn make_stop_iter(&self) -> Result<Self::Value>;
274    fn make_list(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
275    fn make_set(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
276    fn make_frozenset(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
277    fn make_dict(
278        &self,
279        it: impl Iterator<Item = (Self::Value, Self::Value)>,
280    ) -> Result<Self::Value>;
281    type ConstantBag: ConstantBag;
282    fn constant_bag(self) -> Self::ConstantBag;
283}
284
285impl<Bag: ConstantBag> MarshalBag for Bag {
286    type Value = Bag::Constant;
287    fn make_bool(&self, value: bool) -> Self::Value {
288        self.make_constant::<Bag::Constant>(BorrowedConstant::Boolean { value })
289    }
290    fn make_none(&self) -> Self::Value {
291        self.make_constant::<Bag::Constant>(BorrowedConstant::None)
292    }
293    fn make_ellipsis(&self) -> Self::Value {
294        self.make_constant::<Bag::Constant>(BorrowedConstant::Ellipsis)
295    }
296    fn make_float(&self, value: f64) -> Self::Value {
297        self.make_constant::<Bag::Constant>(BorrowedConstant::Float { value })
298    }
299    fn make_complex(&self, value: Complex64) -> Self::Value {
300        self.make_constant::<Bag::Constant>(BorrowedConstant::Complex { value })
301    }
302    fn make_str(&self, value: &str) -> Self::Value {
303        self.make_constant::<Bag::Constant>(BorrowedConstant::Str { value })
304    }
305    fn make_bytes(&self, value: &[u8]) -> Self::Value {
306        self.make_constant::<Bag::Constant>(BorrowedConstant::Bytes { value })
307    }
308    fn make_int(&self, value: BigInt) -> Self::Value {
309        self.make_int(value)
310    }
311    fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value {
312        self.make_tuple(elements)
313    }
314    fn make_code(
315        &self,
316        code: CodeObject<<Self::ConstantBag as ConstantBag>::Constant>,
317    ) -> Self::Value {
318        self.make_code(code)
319    }
320    fn make_stop_iter(&self) -> Result<Self::Value> {
321        Err(MarshalError::BadType)
322    }
323    fn make_list(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
324        Err(MarshalError::BadType)
325    }
326    fn make_set(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
327        Err(MarshalError::BadType)
328    }
329    fn make_frozenset(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
330        Err(MarshalError::BadType)
331    }
332    fn make_dict(
333        &self,
334        _: impl Iterator<Item = (Self::Value, Self::Value)>,
335    ) -> Result<Self::Value> {
336        Err(MarshalError::BadType)
337    }
338    type ConstantBag = Self;
339    fn constant_bag(self) -> Self::ConstantBag {
340        self
341    }
342}
343
344pub fn deserialize_value<R: Read, Bag: MarshalBag>(rdr: &mut R, bag: Bag) -> Result<Bag::Value> {
345    let typ = Type::try_from(rdr.read_u8()?)?;
346    let value = match typ {
347        Type::True => bag.make_bool(true),
348        Type::False => bag.make_bool(false),
349        Type::None => bag.make_none(),
350        Type::StopIter => bag.make_stop_iter()?,
351        Type::Ellipsis => bag.make_ellipsis(),
352        Type::Int => {
353            let len = rdr.read_u32()? as i32;
354            let sign = if len < 0 { Sign::Minus } else { Sign::Plus };
355            let bytes = rdr.read_slice(len.unsigned_abs())?;
356            let int = BigInt::from_bytes_le(sign, bytes);
357            bag.make_int(int)
358        }
359        Type::Float => {
360            let value = f64::from_bits(rdr.read_u64()?);
361            bag.make_float(value)
362        }
363        Type::Complex => {
364            let re = f64::from_bits(rdr.read_u64()?);
365            let im = f64::from_bits(rdr.read_u64()?);
366            let value = Complex64 { re, im };
367            bag.make_complex(value)
368        }
369        Type::Ascii | Type::Unicode => {
370            let len = rdr.read_u32()?;
371            let value = rdr.read_str(len)?;
372            bag.make_str(value)
373        }
374        Type::Tuple => {
375            let len = rdr.read_u32()?;
376            let it = (0..len).map(|_| deserialize_value(rdr, bag));
377            itertools::process_results(it, |it| bag.make_tuple(it))?
378        }
379        Type::List => {
380            let len = rdr.read_u32()?;
381            let it = (0..len).map(|_| deserialize_value(rdr, bag));
382            itertools::process_results(it, |it| bag.make_list(it))??
383        }
384        Type::Set => {
385            let len = rdr.read_u32()?;
386            let it = (0..len).map(|_| deserialize_value(rdr, bag));
387            itertools::process_results(it, |it| bag.make_set(it))??
388        }
389        Type::FrozenSet => {
390            let len = rdr.read_u32()?;
391            let it = (0..len).map(|_| deserialize_value(rdr, bag));
392            itertools::process_results(it, |it| bag.make_frozenset(it))??
393        }
394        Type::Dict => {
395            let len = rdr.read_u32()?;
396            let it = (0..len).map(|_| {
397                let k = deserialize_value(rdr, bag)?;
398                let v = deserialize_value(rdr, bag)?;
399                Ok::<_, MarshalError>((k, v))
400            });
401            itertools::process_results(it, |it| bag.make_dict(it))??
402        }
403        Type::Bytes => {
404            // Following CPython, after marshaling, byte arrays are converted into bytes.
405            let len = rdr.read_u32()?;
406            let value = rdr.read_slice(len)?;
407            bag.make_bytes(value)
408        }
409        Type::Code => bag.make_code(deserialize_code(rdr, bag.constant_bag())?),
410    };
411    Ok(value)
412}
413
414pub trait Dumpable: Sized {
415    type Error;
416    type Constant: Constant;
417    fn with_dump<R>(&self, f: impl FnOnce(DumpableValue<'_, Self>) -> R) -> Result<R, Self::Error>;
418}
419
420pub enum DumpableValue<'a, D: Dumpable> {
421    Integer(&'a BigInt),
422    Float(f64),
423    Complex(Complex64),
424    Boolean(bool),
425    Str(&'a str),
426    Bytes(&'a [u8]),
427    Code(&'a CodeObject<D::Constant>),
428    Tuple(&'a [D]),
429    None,
430    Ellipsis,
431    StopIter,
432    List(&'a [D]),
433    Set(&'a [D]),
434    Frozenset(&'a [D]),
435    Dict(&'a [(D, D)]),
436}
437
438impl<'a, C: Constant> From<BorrowedConstant<'a, C>> for DumpableValue<'a, C> {
439    fn from(c: BorrowedConstant<'a, C>) -> Self {
440        match c {
441            BorrowedConstant::Integer { value } => Self::Integer(value),
442            BorrowedConstant::Float { value } => Self::Float(value),
443            BorrowedConstant::Complex { value } => Self::Complex(value),
444            BorrowedConstant::Boolean { value } => Self::Boolean(value),
445            BorrowedConstant::Str { value } => Self::Str(value),
446            BorrowedConstant::Bytes { value } => Self::Bytes(value),
447            BorrowedConstant::Code { code } => Self::Code(code),
448            BorrowedConstant::Tuple { elements } => Self::Tuple(elements),
449            BorrowedConstant::None => Self::None,
450            BorrowedConstant::Ellipsis => Self::Ellipsis,
451        }
452    }
453}
454
455impl<C: Constant> Dumpable for C {
456    type Error = Infallible;
457    type Constant = Self;
458    #[inline(always)]
459    fn with_dump<R>(&self, f: impl FnOnce(DumpableValue<'_, Self>) -> R) -> Result<R, Self::Error> {
460        Ok(f(self.borrow_constant().into()))
461    }
462}
463
/// A byte sink for marshaled output.
///
/// Only [`Write::write_slice`] must be implemented. The integer writers emit
/// fixed-width little-endian encodings through it.
pub trait Write {
    /// Appends `slice` to the output.
    fn write_slice(&mut self, slice: &[u8]);

    /// Writes a single byte.
    fn write_u8(&mut self, v: u8) {
        self.write_slice(&[v]);
    }

    /// Writes a `u16` as 2 little-endian bytes.
    fn write_u16(&mut self, v: u16) {
        let bytes = v.to_le_bytes();
        self.write_slice(&bytes);
    }

    /// Writes a `u32` as 4 little-endian bytes.
    fn write_u32(&mut self, v: u32) {
        let bytes = v.to_le_bytes();
        self.write_slice(&bytes);
    }

    /// Writes a `u64` as 8 little-endian bytes.
    fn write_u64(&mut self, v: u64) {
        let bytes = v.to_le_bytes();
        self.write_slice(&bytes);
    }
}
479
480impl Write for Vec<u8> {
481    fn write_slice(&mut self, slice: &[u8]) {
482        self.extend_from_slice(slice)
483    }
484}
485
486pub(crate) fn write_len<W: Write>(buf: &mut W, len: usize) {
487    let Ok(len) = len.try_into() else {
488        panic!("too long to serialize")
489    };
490    buf.write_u32(len);
491}
492
493pub(crate) fn write_vec<W: Write>(buf: &mut W, slice: &[u8]) {
494    write_len(buf, slice.len());
495    buf.write_slice(slice);
496}
497
498pub fn serialize_value<W: Write, D: Dumpable>(
499    buf: &mut W,
500    constant: DumpableValue<'_, D>,
501) -> Result<(), D::Error> {
502    match constant {
503        DumpableValue::Integer(int) => {
504            buf.write_u8(Type::Int as u8);
505            let (sign, bytes) = int.to_bytes_le();
506            let len: i32 = bytes.len().try_into().expect("too long to serialize");
507            let len = if sign == Sign::Minus { -len } else { len };
508            buf.write_u32(len as u32);
509            buf.write_slice(&bytes);
510        }
511        DumpableValue::Float(f) => {
512            buf.write_u8(Type::Float as u8);
513            buf.write_u64(f.to_bits());
514        }
515        DumpableValue::Complex(c) => {
516            buf.write_u8(Type::Complex as u8);
517            buf.write_u64(c.re.to_bits());
518            buf.write_u64(c.im.to_bits());
519        }
520        DumpableValue::Boolean(b) => {
521            buf.write_u8(if b { Type::True } else { Type::False } as u8);
522        }
523        DumpableValue::Str(s) => {
524            buf.write_u8(Type::Unicode as u8);
525            write_vec(buf, s.as_bytes());
526        }
527        DumpableValue::Bytes(b) => {
528            buf.write_u8(Type::Bytes as u8);
529            write_vec(buf, b);
530        }
531        DumpableValue::Code(c) => {
532            buf.write_u8(Type::Code as u8);
533            serialize_code(buf, c);
534        }
535        DumpableValue::Tuple(tup) => {
536            buf.write_u8(Type::Tuple as u8);
537            write_len(buf, tup.len());
538            for val in tup {
539                val.with_dump(|val| serialize_value(buf, val))??
540            }
541        }
542        DumpableValue::None => {
543            buf.write_u8(Type::None as u8);
544        }
545        DumpableValue::Ellipsis => {
546            buf.write_u8(Type::Ellipsis as u8);
547        }
548        DumpableValue::StopIter => {
549            buf.write_u8(Type::StopIter as u8);
550        }
551        DumpableValue::List(l) => {
552            buf.write_u8(Type::List as u8);
553            write_len(buf, l.len());
554            for val in l {
555                val.with_dump(|val| serialize_value(buf, val))??
556            }
557        }
558        DumpableValue::Set(set) => {
559            buf.write_u8(Type::Set as u8);
560            write_len(buf, set.len());
561            for val in set {
562                val.with_dump(|val| serialize_value(buf, val))??
563            }
564        }
565        DumpableValue::Frozenset(set) => {
566            buf.write_u8(Type::FrozenSet as u8);
567            write_len(buf, set.len());
568            for val in set {
569                val.with_dump(|val| serialize_value(buf, val))??
570            }
571        }
572        DumpableValue::Dict(d) => {
573            buf.write_u8(Type::Dict as u8);
574            write_len(buf, d.len());
575            for (k, v) in d {
576                k.with_dump(|val| serialize_value(buf, val))??;
577                v.with_dump(|val| serialize_value(buf, val))??;
578            }
579        }
580    }
581    Ok(())
582}
583
584pub fn serialize_code<W: Write, C: Constant>(buf: &mut W, code: &CodeObject<C>) {
585    write_len(buf, code.instructions.len());
586    // SAFETY: it's ok to transmute CodeUnit to [u8; 2]
587    let (_, instructions_bytes, _) = unsafe { code.instructions.align_to() };
588    buf.write_slice(instructions_bytes);
589
590    write_len(buf, code.locations.len());
591    for loc in &*code.locations {
592        buf.write_u32(loc.row.get());
593        buf.write_u32(loc.column.to_zero_indexed());
594    }
595
596    buf.write_u16(code.flags.bits());
597
598    buf.write_u32(code.posonlyarg_count);
599    buf.write_u32(code.arg_count);
600    buf.write_u32(code.kwonlyarg_count);
601
602    write_vec(buf, code.source_path.as_ref().as_bytes());
603
604    buf.write_u32(code.first_line_number.map_or(0, |x| x.get()));
605    buf.write_u32(code.max_stackdepth);
606
607    write_vec(buf, code.obj_name.as_ref().as_bytes());
608
609    let cell2arg = code.cell2arg.as_deref().unwrap_or(&[]);
610    write_len(buf, cell2arg.len());
611    for &i in cell2arg {
612        buf.write_u32(i as u32)
613    }
614
615    write_len(buf, code.constants.len());
616    for constant in &*code.constants {
617        serialize_value(buf, constant.borrow_constant().into()).unwrap_or_else(|x| match x {})
618    }
619
620    let mut write_names = |names: &[C::Name]| {
621        write_len(buf, names.len());
622        for name in names {
623            write_vec(buf, name.as_ref().as_bytes());
624        }
625    };
626
627    write_names(&code.names);
628    write_names(&code.varnames);
629    write_names(&code.cellvars);
630    write_names(&code.freevars);
631}