Skip to main content

rustpython_compiler_core/
marshal.rs

1use crate::{OneIndexed, SourceLocation, bytecode::*};
2use alloc::{boxed::Box, vec::Vec};
3use core::convert::Infallible;
4use malachite_bigint::{BigInt, Sign};
5use num_complex::Complex64;
6use rustpython_wtf8::Wtf8;
7
/// Version number of the marshal format handled by this module.
/// NOTE(review): not referenced by the code shown here — confirm where it is emitted/checked.
pub const FORMAT_VERSION: u32 = 5_u32;
9
/// Errors produced while decoding marshal data.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers (and tests) can compare errors
/// directly instead of pattern-matching.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MarshalError {
    /// Unexpected End Of File: the input ended before a value was complete.
    Eof,
    /// Invalid Bytecode: structurally invalid data (e.g. nesting limit
    /// exceeded or a dangling back-reference).
    InvalidBytecode,
    /// Invalid utf8 in string
    InvalidUtf8,
    /// Invalid source location
    InvalidLocation,
    /// Bad type marker: an unknown or unexpected type byte.
    BadType,
}

impl core::fmt::Display for MarshalError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Eof => f.write_str("unexpected end of data"),
            Self::InvalidBytecode => f.write_str("invalid bytecode"),
            Self::InvalidUtf8 => f.write_str("invalid utf8"),
            Self::InvalidLocation => f.write_str("invalid source location"),
            Self::BadType => f.write_str("bad type marker"),
        }
    }
}

/// Lets `?` turn UTF-8 decoding failures into [`MarshalError::InvalidUtf8`].
impl From<core::str::Utf8Error> for MarshalError {
    fn from(_: core::str::Utf8Error) -> Self {
        Self::InvalidUtf8
    }
}

impl core::error::Error for MarshalError {}

/// Result type used throughout this module; the error defaults to [`MarshalError`].
type Result<T, E = MarshalError> = core::result::Result<T, E>;
45
46#[derive(Clone, Copy)]
47#[repr(u8)]
48enum Type {
49    Null = b'0',
50    None = b'N',
51    False = b'F',
52    True = b'T',
53    StopIter = b'S',
54    Ellipsis = b'.',
55    Int = b'i',
56    Int64 = b'I',
57    Long = b'l',
58    Float = b'g',
59    FloatStr = b'f',
60    ComplexStr = b'x',
61    Complex = b'y',
62    Bytes = b's',
63    Interned = b't',
64    Ref = b'r',
65    Tuple = b'(',
66    SmallTuple = b')',
67    List = b'[',
68    Dict = b'{',
69    Code = b'c',
70    Unicode = b'u',
71    Set = b'<',
72    FrozenSet = b'>',
73    Slice = b':',
74    Ascii = b'a',
75    AsciiInterned = b'A',
76    ShortAscii = b'z',
77    ShortAsciiInterned = b'Z',
78}
79
80impl TryFrom<u8> for Type {
81    type Error = MarshalError;
82
83    fn try_from(value: u8) -> Result<Self> {
84        use Type::*;
85        Ok(match value {
86            b'0' => Null,
87            b'N' => None,
88            b'F' => False,
89            b'T' => True,
90            b'S' => StopIter,
91            b'.' => Ellipsis,
92            b'i' => Int,
93            b'I' => Int64,
94            b'l' => Long,
95            b'f' => FloatStr,
96            b'g' => Float,
97            b'x' => ComplexStr,
98            b'y' => Complex,
99            b's' => Bytes,
100            b't' => Interned,
101            b'r' => Ref,
102            b'(' => Tuple,
103            b')' => SmallTuple,
104            b'[' => List,
105            b'{' => Dict,
106            b'c' => Code,
107            b'u' => Unicode,
108            b'<' => Set,
109            b'>' => FrozenSet,
110            b':' => Slice,
111            b'a' => Ascii,
112            b'A' => AsciiInterned,
113            b'z' => ShortAscii,
114            b'Z' => ShortAsciiInterned,
115            _ => return Err(MarshalError::BadType),
116        })
117    }
118}
119
120pub trait Read {
121    fn read_slice(&mut self, n: u32) -> Result<&[u8]>;
122
123    fn read_array<const N: usize>(&mut self) -> Result<&[u8; N]> {
124        self.read_slice(N as u32).map(|s| s.try_into().unwrap())
125    }
126
127    fn read_str(&mut self, len: u32) -> Result<&str> {
128        Ok(core::str::from_utf8(self.read_slice(len)?)?)
129    }
130
131    fn read_wtf8(&mut self, len: u32) -> Result<&Wtf8> {
132        Wtf8::from_bytes(self.read_slice(len)?).ok_or(MarshalError::InvalidUtf8)
133    }
134
135    fn read_u8(&mut self) -> Result<u8> {
136        Ok(u8::from_le_bytes(*self.read_array()?))
137    }
138
139    fn read_u16(&mut self) -> Result<u16> {
140        Ok(u16::from_le_bytes(*self.read_array()?))
141    }
142
143    fn read_u32(&mut self) -> Result<u32> {
144        Ok(u32::from_le_bytes(*self.read_array()?))
145    }
146
147    fn read_u64(&mut self) -> Result<u64> {
148        Ok(u64::from_le_bytes(*self.read_array()?))
149    }
150}
151
152pub(crate) trait ReadBorrowed<'a>: Read {
153    fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]>;
154
155    fn read_str_borrow(&mut self, len: u32) -> Result<&'a str> {
156        Ok(core::str::from_utf8(self.read_slice_borrow(len)?)?)
157    }
158}
159
160impl Read for &[u8] {
161    fn read_slice(&mut self, n: u32) -> Result<&[u8]> {
162        self.read_slice_borrow(n)
163    }
164
165    fn read_array<const N: usize>(&mut self) -> Result<&[u8; N]> {
166        let (chunk, rest) = self.split_first_chunk::<N>().ok_or(MarshalError::Eof)?;
167        *self = rest;
168        Ok(chunk)
169    }
170}
171
172impl<'a> ReadBorrowed<'a> for &'a [u8] {
173    fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]> {
174        self.split_off(..n as usize).ok_or(MarshalError::Eof)
175    }
176}
177
178pub struct Cursor<B> {
179    pub data: B,
180    pub position: usize,
181}
182
183impl<B: AsRef<[u8]>> Read for Cursor<B> {
184    fn read_slice(&mut self, n: u32) -> Result<&[u8]> {
185        let data = &self.data.as_ref()[self.position..];
186        let slice = data.get(..n as usize).ok_or(MarshalError::Eof)?;
187        self.position += n as usize;
188        Ok(slice)
189    }
190}
191
192/// Deserialize a code object (CPython field order).
193pub fn deserialize_code<R: Read, Bag: ConstantBag>(
194    rdr: &mut R,
195    bag: Bag,
196) -> Result<CodeObject<Bag::Constant>> {
197    // 1–5: scalar fields
198    let arg_count = rdr.read_u32()?;
199    let posonlyarg_count = rdr.read_u32()?;
200    let kwonlyarg_count = rdr.read_u32()?;
201    let max_stackdepth = rdr.read_u32()?;
202    let flags = CodeFlags::from_bits_truncate(rdr.read_u32()?);
203
204    // 6: co_code
205    let code_bytes = read_marshal_bytes(rdr)?;
206
207    // 7: co_consts
208    let constants = read_marshal_const_tuple(rdr, bag)?;
209
210    // 8: co_names
211    let names = read_marshal_name_tuple(rdr, &bag)?;
212
213    // 9: co_localsplusnames
214    let localsplusnames = read_marshal_str_vec(rdr)?;
215
216    // 10: co_localspluskinds
217    let localspluskinds = read_marshal_bytes(rdr)?;
218
219    // 11–13: filename, name, qualname
220    let source_path = bag.make_name(&read_marshal_str(rdr)?);
221    let obj_name = bag.make_name(&read_marshal_str(rdr)?);
222    let qualname = bag.make_name(&read_marshal_str(rdr)?);
223
224    // 14: co_firstlineno
225    let first_line_raw = rdr.read_u32()? as i32;
226    let first_line_number = if first_line_raw > 0 {
227        OneIndexed::new(first_line_raw as usize)
228    } else {
229        None
230    };
231
232    // 15–16: linetable, exceptiontable
233    let linetable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice();
234    let exceptiontable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice();
235
236    // Split localsplusnames/kinds → varnames/cellvars/freevars
237    let lp = split_localplus(
238        &localsplusnames
239            .iter()
240            .map(|s| s.as_str())
241            .collect::<Vec<&str>>(),
242        &localspluskinds,
243        arg_count,
244        kwonlyarg_count,
245        flags,
246    )?;
247
248    // Bytecode already uses flat localsplus indices (no translation needed)
249    let instructions = CodeUnits::try_from(code_bytes.as_slice())?;
250    let locations = linetable_to_locations(&linetable, first_line_raw, instructions.len());
251
252    // Use original localspluskinds from marshal data (preserves CO_FAST_HIDDEN etc.)
253    let localspluskinds = localspluskinds.into_boxed_slice();
254
255    Ok(CodeObject {
256        instructions,
257        locations,
258        flags,
259        posonlyarg_count,
260        arg_count,
261        kwonlyarg_count,
262        source_path,
263        first_line_number,
264        max_stackdepth,
265        obj_name,
266        qualname,
267        constants,
268        names,
269        varnames: lp.varnames.iter().map(|s| bag.make_name(s)).collect(),
270        cellvars: lp.cellvars.iter().map(|s| bag.make_name(s)).collect(),
271        freevars: lp.freevars.iter().map(|s| bag.make_name(s)).collect(),
272        localspluskinds,
273        linetable,
274        exceptiontable,
275    })
276}
277
278/// Read a marshal bytes object (TYPE_STRING = b's').
279fn read_marshal_bytes<R: Read>(rdr: &mut R) -> Result<Vec<u8>> {
280    let type_byte = rdr.read_u8()? & !FLAG_REF;
281    if type_byte != Type::Bytes as u8 {
282        return Err(MarshalError::BadType);
283    }
284    let len = rdr.read_u32()?;
285    Ok(rdr.read_slice(len)?.to_vec())
286}
287
288/// Read a marshal string object.
289fn read_marshal_str<R: Read>(rdr: &mut R) -> Result<alloc::string::String> {
290    let type_byte = rdr.read_u8()? & !FLAG_REF;
291    let s = match type_byte {
292        b'u' | b't' | b'a' | b'A' => {
293            let len = rdr.read_u32()?;
294            rdr.read_str(len)?
295        }
296        b'z' | b'Z' => {
297            let len = rdr.read_u8()? as u32;
298            rdr.read_str(len)?
299        }
300        _ => return Err(MarshalError::BadType),
301    };
302    Ok(alloc::string::String::from(s))
303}
304
305/// Read a marshal tuple of strings, returning owned Strings.
306fn read_marshal_str_vec<R: Read>(rdr: &mut R) -> Result<Vec<alloc::string::String>> {
307    let type_byte = rdr.read_u8()? & !FLAG_REF;
308    let n = match type_byte {
309        b'(' => rdr.read_u32()? as usize,
310        b')' => rdr.read_u8()? as usize,
311        _ => return Err(MarshalError::BadType),
312    };
313    (0..n).map(|_| read_marshal_str(rdr)).collect()
314}
315
316fn read_marshal_name_tuple<R: Read, Bag: ConstantBag>(
317    rdr: &mut R,
318    bag: &Bag,
319) -> Result<Box<[<Bag::Constant as Constant>::Name]>> {
320    let type_byte = rdr.read_u8()? & !FLAG_REF;
321    let n = match type_byte {
322        b'(' => rdr.read_u32()? as usize,
323        b')' => rdr.read_u8()? as usize,
324        _ => return Err(MarshalError::BadType),
325    };
326    (0..n)
327        .map(|_| Ok(bag.make_name(&read_marshal_str(rdr)?)))
328        .collect::<Result<Vec<_>>>()
329        .map(Vec::into_boxed_slice)
330}
331
332/// Read a marshal tuple of constants.
333fn read_marshal_const_tuple<R: Read, Bag: ConstantBag>(
334    rdr: &mut R,
335    bag: Bag,
336) -> Result<Constants<Bag::Constant>> {
337    let type_byte = rdr.read_u8()? & !FLAG_REF;
338    let n = match type_byte {
339        b'(' => rdr.read_u32()? as usize,
340        b')' => rdr.read_u8()? as usize,
341        _ => return Err(MarshalError::BadType),
342    };
343    (0..n).map(|_| deserialize_value(rdr, bag)).collect()
344}
345
346pub trait MarshalBag: Copy {
347    type Value: Clone;
348    type ConstantBag: ConstantBag;
349
350    fn make_bool(&self, value: bool) -> Self::Value;
351
352    fn make_none(&self) -> Self::Value;
353
354    fn make_ellipsis(&self) -> Self::Value;
355
356    fn make_float(&self, value: f64) -> Self::Value;
357
358    fn make_complex(&self, value: Complex64) -> Self::Value;
359
360    fn make_str(&self, value: &Wtf8) -> Self::Value;
361
362    fn make_bytes(&self, value: &[u8]) -> Self::Value;
363
364    fn make_int(&self, value: BigInt) -> Self::Value;
365
366    fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value;
367
368    fn make_code(
369        &self,
370        code: CodeObject<<Self::ConstantBag as ConstantBag>::Constant>,
371    ) -> Self::Value;
372
373    fn make_stop_iter(&self) -> Result<Self::Value>;
374
375    fn make_list(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
376
377    fn make_set(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
378
379    fn make_frozenset(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
380
381    fn make_dict(
382        &self,
383        it: impl Iterator<Item = (Self::Value, Self::Value)>,
384    ) -> Result<Self::Value>;
385
386    fn make_slice(
387        &self,
388        _start: Self::Value,
389        _stop: Self::Value,
390        _step: Self::Value,
391    ) -> Result<Self::Value> {
392        Err(MarshalError::BadType)
393    }
394
395    fn constant_bag(self) -> Self::ConstantBag;
396}
397
398impl<Bag: ConstantBag> MarshalBag for Bag {
399    type Value = Bag::Constant;
400    type ConstantBag = Self;
401
402    fn make_bool(&self, value: bool) -> Self::Value {
403        self.make_constant::<Bag::Constant>(BorrowedConstant::Boolean { value })
404    }
405
406    fn make_none(&self) -> Self::Value {
407        self.make_constant::<Bag::Constant>(BorrowedConstant::None)
408    }
409
410    fn make_ellipsis(&self) -> Self::Value {
411        self.make_constant::<Bag::Constant>(BorrowedConstant::Ellipsis)
412    }
413
414    fn make_float(&self, value: f64) -> Self::Value {
415        self.make_constant::<Bag::Constant>(BorrowedConstant::Float { value })
416    }
417
418    fn make_complex(&self, value: Complex64) -> Self::Value {
419        self.make_constant::<Bag::Constant>(BorrowedConstant::Complex { value })
420    }
421
422    fn make_str(&self, value: &Wtf8) -> Self::Value {
423        self.make_constant::<Bag::Constant>(BorrowedConstant::Str { value })
424    }
425
426    fn make_bytes(&self, value: &[u8]) -> Self::Value {
427        self.make_constant::<Bag::Constant>(BorrowedConstant::Bytes { value })
428    }
429
430    fn make_int(&self, value: BigInt) -> Self::Value {
431        self.make_int(value)
432    }
433
434    fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value {
435        self.make_tuple(elements)
436    }
437
438    fn make_slice(
439        &self,
440        start: Self::Value,
441        stop: Self::Value,
442        step: Self::Value,
443    ) -> Result<Self::Value> {
444        let elements = [start, stop, step];
445        Ok(
446            self.make_constant::<Bag::Constant>(BorrowedConstant::Slice {
447                elements: &elements,
448            }),
449        )
450    }
451
452    fn make_code(
453        &self,
454        code: CodeObject<<Self::ConstantBag as ConstantBag>::Constant>,
455    ) -> Self::Value {
456        self.make_code(code)
457    }
458
459    fn make_stop_iter(&self) -> Result<Self::Value> {
460        Err(MarshalError::BadType)
461    }
462
463    fn make_list(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
464        Err(MarshalError::BadType)
465    }
466
467    fn make_set(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
468        Err(MarshalError::BadType)
469    }
470
471    fn make_frozenset(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
472        let elements: Vec<Self::Value> = it.collect();
473        Ok(
474            self.make_constant::<Bag::Constant>(BorrowedConstant::Frozenset {
475                elements: &elements,
476            }),
477        )
478    }
479
480    fn make_dict(
481        &self,
482        _: impl Iterator<Item = (Self::Value, Self::Value)>,
483    ) -> Result<Self::Value> {
484        Err(MarshalError::BadType)
485    }
486
487    fn constant_bag(self) -> Self::ConstantBag {
488        self
489    }
490}
491
/// Nesting budget given to each top-level [`deserialize_value`] call; exceeding
/// it yields `MarshalError::InvalidBytecode`.
pub const MAX_MARSHAL_STACK_DEPTH: usize = 2_000;
493
494pub fn deserialize_value<R: Read, Bag: MarshalBag>(rdr: &mut R, bag: Bag) -> Result<Bag::Value> {
495    let mut refs: Vec<Option<Bag::Value>> = Vec::new();
496    deserialize_value_depth(rdr, bag, MAX_MARSHAL_STACK_DEPTH, &mut refs)
497}
498
499fn deserialize_value_depth<R: Read, Bag: MarshalBag>(
500    rdr: &mut R,
501    bag: Bag,
502    depth: usize,
503    refs: &mut Vec<Option<Bag::Value>>,
504) -> Result<Bag::Value> {
505    if depth == 0 {
506        return Err(MarshalError::InvalidBytecode);
507    }
508    let raw = rdr.read_u8()?;
509    let flag = raw & FLAG_REF != 0;
510    let type_code = raw & !FLAG_REF;
511
512    // TYPE_REF: return previously stored object
513    if type_code == Type::Ref as u8 {
514        let idx = rdr.read_u32()? as usize;
515        return refs
516            .get(idx)
517            .and_then(|v| v.clone())
518            .ok_or(MarshalError::InvalidBytecode);
519    }
520
521    // Reserve ref slot before reading (matches write order)
522    let slot = if flag {
523        let idx = refs.len();
524        refs.push(None);
525        Some(idx)
526    } else {
527        None
528    };
529
530    let typ = Type::try_from(type_code)?;
531    let value = deserialize_value_typed(rdr, bag, depth, refs, typ)?;
532
533    if let Some(idx) = slot {
534        refs[idx] = Some(value.clone());
535    }
536    Ok(value)
537}
538
539fn deserialize_value_typed<R: Read, Bag: MarshalBag>(
540    rdr: &mut R,
541    bag: Bag,
542    depth: usize,
543    refs: &mut Vec<Option<Bag::Value>>,
544    typ: Type,
545) -> Result<Bag::Value> {
546    if depth == 0 {
547        return Err(MarshalError::InvalidBytecode);
548    }
549    let value = match typ {
550        Type::True => bag.make_bool(true),
551        Type::False => bag.make_bool(false),
552        Type::None => bag.make_none(),
553        Type::StopIter => bag.make_stop_iter()?,
554        Type::Ellipsis => bag.make_ellipsis(),
555        Type::Int => {
556            let val = rdr.read_u32()? as i32;
557            bag.make_int(BigInt::from(val))
558        }
559        Type::Int64 => {
560            let lo = rdr.read_u32()? as u64;
561            let hi = rdr.read_u32()? as u64;
562            bag.make_int(BigInt::from(((hi << 32) | lo) as i64))
563        }
564        Type::Long => bag.make_int(read_pylong(rdr)?),
565        Type::FloatStr => bag.make_float(read_float_str(rdr)?),
566        Type::Float => {
567            let value = f64::from_bits(rdr.read_u64()?);
568            bag.make_float(value)
569        }
570        Type::ComplexStr => {
571            let re = read_float_str(rdr)?;
572            let im = read_float_str(rdr)?;
573            bag.make_complex(Complex64 { re, im })
574        }
575        Type::Complex => {
576            let re = f64::from_bits(rdr.read_u64()?);
577            let im = f64::from_bits(rdr.read_u64()?);
578            let value = Complex64 { re, im };
579            bag.make_complex(value)
580        }
581        Type::Ascii | Type::AsciiInterned | Type::Unicode | Type::Interned => {
582            let len = rdr.read_u32()?;
583            let value = rdr.read_wtf8(len)?;
584            bag.make_str(value)
585        }
586        Type::ShortAscii | Type::ShortAsciiInterned => {
587            let len = rdr.read_u8()? as u32;
588            let value = rdr.read_wtf8(len)?;
589            bag.make_str(value)
590        }
591        Type::SmallTuple => {
592            let len = rdr.read_u8()? as usize;
593            let d = depth - 1;
594            let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs));
595            itertools::process_results(it, |it| bag.make_tuple(it))?
596        }
597        Type::Null => {
598            return Err(MarshalError::BadType);
599        }
600        Type::Ref => {
601            // Handled in deserialize_value_depth before calling this function
602            return Err(MarshalError::BadType);
603        }
604        Type::Tuple => {
605            let len = rdr.read_u32()?;
606            let d = depth - 1;
607            let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs));
608            itertools::process_results(it, |it| bag.make_tuple(it))?
609        }
610        Type::List => {
611            let len = rdr.read_u32()?;
612            let d = depth - 1;
613            let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs));
614            itertools::process_results(it, |it| bag.make_list(it))??
615        }
616        Type::Set => {
617            let len = rdr.read_u32()?;
618            let d = depth - 1;
619            let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs));
620            itertools::process_results(it, |it| bag.make_set(it))??
621        }
622        Type::FrozenSet => {
623            let len = rdr.read_u32()?;
624            let d = depth - 1;
625            let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs));
626            itertools::process_results(it, |it| bag.make_frozenset(it))??
627        }
628        Type::Dict => {
629            let d = depth - 1;
630            let mut pairs = Vec::new();
631            loop {
632                let raw = rdr.read_u8()?;
633                let type_code = raw & !FLAG_REF;
634                if type_code == b'0' {
635                    break;
636                }
637                // TYPE_REF for key
638                let k = if type_code == Type::Ref as u8 {
639                    let idx = rdr.read_u32()? as usize;
640                    refs.get(idx)
641                        .and_then(|v| v.clone())
642                        .ok_or(MarshalError::InvalidBytecode)?
643                } else {
644                    let flag = raw & FLAG_REF != 0;
645                    let key_slot = if flag {
646                        let idx = refs.len();
647                        refs.push(None);
648                        Some(idx)
649                    } else {
650                        None
651                    };
652                    let key_type = Type::try_from(type_code)?;
653                    let k = deserialize_value_typed(rdr, bag, d, refs, key_type)?;
654                    if let Some(idx) = key_slot {
655                        refs[idx] = Some(k.clone());
656                    }
657                    k
658                };
659                let v = deserialize_value_depth(rdr, bag, d, refs)?;
660                pairs.push((k, v));
661            }
662            bag.make_dict(pairs.into_iter())?
663        }
664        Type::Bytes => {
665            // After marshaling, byte arrays are converted into bytes.
666            let len = rdr.read_u32()?;
667            let value = rdr.read_slice(len)?;
668            bag.make_bytes(value)
669        }
670        Type::Code => bag.make_code(deserialize_code(rdr, bag.constant_bag())?),
671        Type::Slice => {
672            let d = depth - 1;
673            let start = deserialize_value_depth(rdr, bag, d, refs)?;
674            let stop = deserialize_value_depth(rdr, bag, d, refs)?;
675            let step = deserialize_value_depth(rdr, bag, d, refs)?;
676            bag.make_slice(start, stop, step)?
677        }
678    };
679    Ok(value)
680}
681
682pub trait Dumpable: Sized {
683    type Error;
684    type Constant: Constant;
685
686    fn with_dump<R>(&self, f: impl FnOnce(DumpableValue<'_, Self>) -> R) -> Result<R, Self::Error>;
687}
688
689pub enum DumpableValue<'a, D: Dumpable> {
690    Integer(&'a BigInt),
691    Float(f64),
692    Complex(Complex64),
693    Boolean(bool),
694    Str(&'a Wtf8),
695    Bytes(&'a [u8]),
696    Code(&'a CodeObject<D::Constant>),
697    Tuple(&'a [D]),
698    None,
699    Ellipsis,
700    StopIter,
701    List(&'a [D]),
702    Set(&'a [D]),
703    Frozenset(&'a [D]),
704    Dict(&'a [(D, D)]),
705    Slice(&'a D, &'a D, &'a D),
706}
707
708impl<'a, C: Constant> From<BorrowedConstant<'a, C>> for DumpableValue<'a, C> {
709    fn from(c: BorrowedConstant<'a, C>) -> Self {
710        match c {
711            BorrowedConstant::Integer { value } => Self::Integer(value),
712            BorrowedConstant::Float { value } => Self::Float(value),
713            BorrowedConstant::Complex { value } => Self::Complex(value),
714            BorrowedConstant::Boolean { value } => Self::Boolean(value),
715            BorrowedConstant::Str { value } => Self::Str(value),
716            BorrowedConstant::Bytes { value } => Self::Bytes(value),
717            BorrowedConstant::Code { code } => Self::Code(code),
718            BorrowedConstant::Tuple { elements } => Self::Tuple(elements),
719            BorrowedConstant::Slice { elements } => {
720                Self::Slice(&elements[0], &elements[1], &elements[2])
721            }
722            BorrowedConstant::Frozenset { elements } => Self::Frozenset(elements),
723            BorrowedConstant::None => Self::None,
724            BorrowedConstant::Ellipsis => Self::Ellipsis,
725        }
726    }
727}
728
729impl<C: Constant> Dumpable for C {
730    type Error = Infallible;
731    type Constant = Self;
732
733    #[inline(always)]
734    fn with_dump<R>(&self, f: impl FnOnce(DumpableValue<'_, Self>) -> R) -> Result<R, Self::Error> {
735        Ok(f(self.borrow_constant().into()))
736    }
737}
738
/// Sink for marshal output. All multi-byte integers are written little-endian.
pub trait Write {
    /// Append `slice` verbatim.
    fn write_slice(&mut self, slice: &[u8]);

    /// Append one byte.
    fn write_u8(&mut self, v: u8) {
        self.write_slice(&[v])
    }

    /// Append a little-endian `u16`.
    fn write_u16(&mut self, v: u16) {
        self.write_slice(&u16::to_le_bytes(v))
    }

    /// Append a little-endian `u32`.
    fn write_u32(&mut self, v: u32) {
        self.write_slice(&u32::to_le_bytes(v))
    }

    /// Append a little-endian `u64`.
    fn write_u64(&mut self, v: u64) {
        self.write_slice(&u64::to_le_bytes(v))
    }
}

/// Writing into a `Vec<u8>` appends to its end.
impl Write for Vec<u8> {
    fn write_slice(&mut self, slice: &[u8]) {
        Vec::extend_from_slice(self, slice)
    }
}
764
765pub(crate) fn write_len<W: Write>(buf: &mut W, len: usize) {
766    let Ok(len) = len.try_into() else {
767        panic!("too long to serialize")
768    };
769    buf.write_u32(len);
770}
771
772pub(crate) fn write_vec<W: Write>(buf: &mut W, slice: &[u8]) {
773    write_len(buf, slice.len());
774    buf.write_slice(slice);
775}
776
777pub fn serialize_value<W: Write, D: Dumpable>(
778    buf: &mut W,
779    constant: DumpableValue<'_, D>,
780) -> Result<(), D::Error> {
781    match constant {
782        DumpableValue::Integer(int) => {
783            if let Ok(val) = i32::try_from(int) {
784                buf.write_u8(Type::Int as u8); // TYPE_INT: 4-byte LE i32
785                buf.write_u32(val as u32);
786            } else {
787                buf.write_u8(Type::Long as u8);
788                let (sign, raw) = int.to_bytes_le();
789                let mut digits = alloc::vec::Vec::new();
790                let mut accum: u32 = 0;
791                let mut bits = 0u32;
792                for &byte in &raw {
793                    accum |= (byte as u32) << bits;
794                    bits += 8;
795                    while bits >= 15 {
796                        digits.push((accum & 0x7fff) as u16);
797                        accum >>= 15;
798                        bits -= 15;
799                    }
800                }
801                if accum > 0 || digits.is_empty() {
802                    digits.push(accum as u16);
803                }
804                while digits.len() > 1 && *digits.last().unwrap() == 0 {
805                    digits.pop();
806                }
807                let n = digits.len() as i32;
808                let n = if sign == Sign::Minus { -n } else { n };
809                buf.write_u32(n as u32);
810                for d in &digits {
811                    buf.write_u16(*d);
812                }
813            }
814        }
815        DumpableValue::Float(f) => {
816            buf.write_u8(Type::Float as u8);
817            buf.write_u64(f.to_bits());
818        }
819        DumpableValue::Complex(c) => {
820            buf.write_u8(Type::Complex as u8);
821            buf.write_u64(c.re.to_bits());
822            buf.write_u64(c.im.to_bits());
823        }
824        DumpableValue::Boolean(b) => {
825            buf.write_u8(if b { Type::True } else { Type::False } as u8);
826        }
827        DumpableValue::Str(s) => {
828            buf.write_u8(Type::Unicode as u8);
829            write_vec(buf, s.as_bytes());
830        }
831        DumpableValue::Bytes(b) => {
832            buf.write_u8(Type::Bytes as u8);
833            write_vec(buf, b);
834        }
835        DumpableValue::Code(c) => {
836            buf.write_u8(Type::Code as u8);
837            serialize_code(buf, c);
838        }
839        DumpableValue::Tuple(tup) => {
840            buf.write_u8(Type::Tuple as u8);
841            write_len(buf, tup.len());
842            for val in tup {
843                val.with_dump(|val| serialize_value(buf, val))??
844            }
845        }
846        DumpableValue::None => {
847            buf.write_u8(Type::None as u8);
848        }
849        DumpableValue::Ellipsis => {
850            buf.write_u8(Type::Ellipsis as u8);
851        }
852        DumpableValue::StopIter => {
853            buf.write_u8(Type::StopIter as u8);
854        }
855        DumpableValue::List(l) => {
856            buf.write_u8(Type::List as u8);
857            write_len(buf, l.len());
858            for val in l {
859                val.with_dump(|val| serialize_value(buf, val))??
860            }
861        }
862        DumpableValue::Set(set) => {
863            buf.write_u8(Type::Set as u8);
864            write_len(buf, set.len());
865            for val in set {
866                val.with_dump(|val| serialize_value(buf, val))??
867            }
868        }
869        DumpableValue::Frozenset(set) => {
870            buf.write_u8(Type::FrozenSet as u8);
871            write_len(buf, set.len());
872            for val in set {
873                val.with_dump(|val| serialize_value(buf, val))??
874            }
875        }
876        DumpableValue::Dict(d) => {
877            buf.write_u8(Type::Dict as u8);
878            for (k, v) in d {
879                k.with_dump(|val| serialize_value(buf, val))??;
880                v.with_dump(|val| serialize_value(buf, val))??;
881            }
882            buf.write_u8(b'0'); // TYPE_NULL
883        }
884        DumpableValue::Slice(start, stop, step) => {
885            buf.write_u8(Type::Slice as u8);
886            start.with_dump(|val| serialize_value(buf, val))??;
887            stop.with_dump(|val| serialize_value(buf, val))??;
888            step.with_dump(|val| serialize_value(buf, val))??;
889        }
890    }
891    Ok(())
892}
893
/// Serialize a code object in CPython field order.
///
/// Only the code object's fields are written; this function does not emit a
/// `Type::Code` marker itself (presumably the caller does — TODO confirm at
/// call sites). Fields are written in the numbered order below.
///
/// The split `varnames`/`cellvars`/`freevars` lists are reassembled into
/// `co_localsplusnames` as: varnames, then cellvars that are not already
/// varnames, then freevars. `co_localspluskinds` is NOT recomputed: it is
/// written verbatim from `code.localspluskinds`.
pub fn serialize_code<W: Write, C: Constant>(buf: &mut W, code: &CodeObject<C>) {
    // 1–5: scalar fields
    // (raw 32-bit values written with `write_u32`, no type marker)
    buf.write_u32(code.arg_count);
    buf.write_u32(code.posonlyarg_count);
    buf.write_u32(code.kwonlyarg_count);
    buf.write_u32(code.max_stackdepth);
    buf.write_u32(code.flags.bits());

    // 6: co_code (TYPE_STRING) — bytecode already uses flat localsplus indices
    let bytecode = code.instructions.original_bytes();
    buf.write_u8(Type::Bytes as u8);
    write_vec(buf, &bytecode);

    // 7: co_consts (TYPE_TUPLE)
    buf.write_u8(Type::Tuple as u8);
    write_len(buf, code.constants.len());
    for constant in &*code.constants {
        // The empty `match x {}` only compiles because the error type is
        // uninhabited, so serializing a constant here cannot fail.
        serialize_value(buf, constant.borrow_constant().into()).unwrap_or_else(|x| match x {})
    }

    // 8: co_names (tuple of strings)
    write_marshal_name_tuple(buf, &code.names);

    // 9: co_localsplusnames — varnames ++ cell_only ++ freevars
    // A cellvar that is also a varname (a LOCAL|CELL slot) already appears in
    // the varnames section, so only the remaining cellvars are emitted here.
    let cell_only_names: Vec<&str> = code
        .cellvars
        .iter()
        .filter(|cv| !code.varnames.iter().any(|v| v.as_ref() == cv.as_ref()))
        .map(|cv| cv.as_ref())
        .collect();
    // NOTE(review): this count is assumed to equal `code.localspluskinds.len()`
    // (written as field 10); nothing here checks that they agree.
    let total_lp_count = code.varnames.len() + cell_only_names.len() + code.freevars.len();
    buf.write_u8(Type::Tuple as u8);
    write_len(buf, total_lp_count);
    for n in code.varnames.iter() {
        write_marshal_str(buf, n.as_ref());
    }
    for &n in &cell_only_names {
        write_marshal_str(buf, n);
    }
    for n in code.freevars.iter() {
        write_marshal_str(buf, n.as_ref());
    }
    // 10: co_localspluskinds — use the stored kinds directly
    buf.write_u8(Type::Bytes as u8);
    write_vec(buf, &code.localspluskinds);

    // 11: co_filename
    write_marshal_str(buf, code.source_path.as_ref());
    // 12: co_name
    write_marshal_str(buf, code.obj_name.as_ref());
    // 13: co_qualname
    write_marshal_str(buf, code.qualname.as_ref());
    // 14: co_firstlineno (0 when the code object has no first line number)
    buf.write_u32(code.first_line_number.map_or(0, |x| x.get() as _));
    // 15: co_linetable
    buf.write_u8(Type::Bytes as u8);
    write_vec(buf, &code.linetable);
    // 16: co_exceptiontable
    buf.write_u8(Type::Bytes as u8);
    write_vec(buf, &code.exceptiontable);
}
959
960fn write_marshal_str<W: Write>(buf: &mut W, s: &str) {
961    let bytes = s.as_bytes();
962    if bytes.len() < 256 && bytes.is_ascii() {
963        buf.write_u8(b'z'); // TYPE_SHORT_ASCII
964        buf.write_u8(bytes.len() as u8);
965    } else {
966        buf.write_u8(Type::Unicode as u8);
967        write_len(buf, bytes.len());
968    }
969    buf.write_slice(bytes);
970}
971
972fn write_marshal_name_tuple<W: Write, N: AsRef<str>>(buf: &mut W, names: &[N]) {
973    buf.write_u8(Type::Tuple as u8);
974    write_len(buf, names.len());
975    for name in names {
976        write_marshal_str(buf, name.as_ref());
977    }
978}
979
/// High bit (`0x80`) of a marshal type-marker byte; mirrors `FLAG_REF` in
/// CPython's marshal format. Presumably marks objects recorded for later
/// `Type::Ref` back-references — confirm at use sites.
pub const FLAG_REF: u8 = 1 << 7;
981
982/// Read a signed 32-bit LE integer.
983pub fn read_i32<R: Read>(rdr: &mut R) -> Result<i32> {
984    let bytes = rdr.read_array::<4>()?;
985    Ok(i32::from_le_bytes(*bytes))
986}
987
988/// Read a TYPE_LONG arbitrary-precision integer (base-2^15 digits).
989pub fn read_pylong<R: Read>(rdr: &mut R) -> Result<BigInt> {
990    const MARSHAL_SHIFT: u32 = 15;
991    const MARSHAL_BASE: u32 = 1 << MARSHAL_SHIFT;
992    let n = read_i32(rdr)?;
993    if n == 0 {
994        return Ok(BigInt::from(0));
995    }
996    let negative = n < 0;
997    let num_digits = n.unsigned_abs() as usize;
998    let mut accum = BigInt::from(0);
999    let mut last_digit = 0u32;
1000    for i in 0..num_digits {
1001        let d = rdr.read_u16()? as u32;
1002        if d >= MARSHAL_BASE {
1003            return Err(MarshalError::InvalidBytecode);
1004        }
1005        last_digit = d;
1006        accum += BigInt::from(d) << (i as u32 * MARSHAL_SHIFT);
1007    }
1008    if num_digits > 0 && last_digit == 0 {
1009        return Err(MarshalError::InvalidBytecode);
1010    }
1011    if negative {
1012        accum = -accum;
1013    }
1014    Ok(accum)
1015}
1016
1017/// Read a text-encoded float (1-byte length + ASCII).
1018pub fn read_float_str<R: Read>(rdr: &mut R) -> Result<f64> {
1019    let n = rdr.read_u8()? as u32;
1020    let s = rdr.read_str(n)?;
1021    s.parse::<f64>().map_err(|_| MarshalError::InvalidBytecode)
1022}
1023
1024/// Read a 4-byte-length-prefixed byte string.
1025pub fn read_pstring<R: Read>(rdr: &mut R) -> Result<&[u8]> {
1026    let n = read_i32(rdr)?;
1027    if n < 0 {
1028        return Err(MarshalError::InvalidBytecode);
1029    }
1030    rdr.read_slice(n as u32)
1031}
1032
/// Localsplus kind bit: the entry is a local (collected into `varnames`).
const CO_FAST_LOCAL: u8 = 1 << 5;
/// Localsplus kind bit: the entry is a cell variable (collected into `cellvars`).
const CO_FAST_CELL: u8 = 1 << 6;
/// Localsplus kind bit: the entry is a free variable (collected into `freevars`).
const CO_FAST_FREE: u8 = 1 << 7;
1036
/// Output of `split_localplus`: CPython-style `co_localsplusnames` data split
/// into separate name lists, plus the derived index mappings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalsPlusResult<S> {
    /// Entries whose kind has `CO_FAST_LOCAL`, in localsplus order.
    pub varnames: Vec<S>,
    /// Entries whose kind has `CO_FAST_CELL` (including `LOCAL|CELL` ones),
    /// in localsplus order.
    pub cellvars: Vec<S>,
    /// Entries whose kind has `CO_FAST_FREE`, in localsplus order.
    pub freevars: Vec<S>,
    /// For each cellvar, the localsplus index of the argument slot it
    /// corresponds to, or `-1`. `None` when no cellvar is an argument.
    pub cell2arg: Option<Box<[i32]>>,
    /// For each localsplus entry, its index into `cellvars ++ freevars`, or
    /// `u32::MAX` for entries that are neither.
    pub deref_map: Vec<u32>,
}
1044
1045pub fn split_localplus<S: Clone>(
1046    names: &[S],
1047    kinds: &[u8],
1048    arg_count: u32,
1049    kwonlyarg_count: u32,
1050    flags: CodeFlags,
1051) -> Result<LocalsPlusResult<S>> {
1052    if names.len() != kinds.len() {
1053        return Err(MarshalError::InvalidBytecode);
1054    }
1055
1056    let mut varnames = Vec::new();
1057    let mut cellvars = Vec::new();
1058    let mut freevars = Vec::new();
1059
1060    // First pass: collect varnames (LOCAL entries) and freevars
1061    for (name, &kind) in names.iter().zip(kinds.iter()) {
1062        if kind & CO_FAST_LOCAL != 0 {
1063            varnames.push(name.clone());
1064        }
1065        if kind & CO_FAST_FREE != 0 {
1066            freevars.push(name.clone());
1067        }
1068    }
1069
1070    // Second pass: collect cellvars in localsplusnames order.
1071    // CELL-only vars come from non-LOCAL CELL entries.
1072    // LOCAL|CELL vars are also added to cellvars.
1073    // This preserves the original ordering from localsplusnames.
1074    let mut arg_cell_positions = Vec::new(); // (cell_idx, localplus_idx)
1075    for (i, (name, &kind)) in names.iter().zip(kinds.iter()).enumerate() {
1076        let is_local = kind & CO_FAST_LOCAL != 0;
1077        let is_cell = kind & CO_FAST_CELL != 0;
1078        if is_cell {
1079            let cell_idx = cellvars.len();
1080            cellvars.push(name.clone());
1081            if is_local {
1082                arg_cell_positions.push((cell_idx, i));
1083            }
1084        }
1085    }
1086
1087    let total_args = {
1088        let mut t = arg_count + kwonlyarg_count;
1089        if flags.contains(CodeFlags::VARARGS) {
1090            t += 1;
1091        }
1092        if flags.contains(CodeFlags::VARKEYWORDS) {
1093            t += 1;
1094        }
1095        t
1096    };
1097
1098    let cell2arg = if !cellvars.is_empty() {
1099        let mut mapping = alloc::vec![-1i32; cellvars.len()];
1100        for &(cell_idx, localplus_idx) in &arg_cell_positions {
1101            if (localplus_idx as u32) < total_args {
1102                mapping[cell_idx] = localplus_idx as i32;
1103            }
1104        }
1105        if mapping.iter().any(|&x| x >= 0) {
1106            Some(mapping.into_boxed_slice())
1107        } else {
1108            None
1109        }
1110    } else {
1111        None
1112    };
1113
1114    // Build deref_map: localsplusnames index → cellvar/freevar index
1115    let mut deref_map = alloc::vec![u32::MAX; names.len()];
1116    let mut cell_idx = 0u32;
1117    for (i, &kind) in kinds.iter().enumerate() {
1118        if kind & CO_FAST_CELL != 0 {
1119            deref_map[i] = cell_idx;
1120            cell_idx += 1;
1121        }
1122    }
1123    let ncells = cellvars.len();
1124    let mut free_idx = 0u32;
1125    for (i, &kind) in kinds.iter().enumerate() {
1126        if kind & CO_FAST_FREE != 0 {
1127            deref_map[i] = ncells as u32 + free_idx;
1128            free_idx += 1;
1129        }
1130    }
1131
1132    Ok(LocalsPlusResult {
1133        varnames,
1134        cellvars,
1135        freevars,
1136        cell2arg,
1137        deref_map,
1138    })
1139}
1140
1141pub fn linetable_to_locations(
1142    linetable: &[u8],
1143    first_line: i32,
1144    num_instructions: usize,
1145) -> Box<[(SourceLocation, SourceLocation)]> {
1146    let default_loc = || {
1147        let line = if first_line > 0 {
1148            OneIndexed::new(first_line as usize).unwrap_or(OneIndexed::MIN)
1149        } else {
1150            OneIndexed::MIN
1151        };
1152        let loc = SourceLocation {
1153            line,
1154            character_offset: OneIndexed::from_zero_indexed(0),
1155        };
1156        (loc, loc)
1157    };
1158    if linetable.is_empty() {
1159        return alloc::vec![default_loc(); num_instructions].into_boxed_slice();
1160    }
1161
1162    let mut locations = Vec::with_capacity(num_instructions);
1163    let mut pos = 0;
1164    let mut line = first_line;
1165
1166    while pos < linetable.len() && locations.len() < num_instructions {
1167        let first_byte = linetable[pos];
1168        pos += 1;
1169        if first_byte & 0x80 == 0 {
1170            break;
1171        }
1172        let code = (first_byte >> 3) & 0x0f;
1173        let length = ((first_byte & 0x07) + 1) as usize;
1174        let kind = match PyCodeLocationInfoKind::from_code(code) {
1175            Some(k) => k,
1176            None => break,
1177        };
1178
1179        let (line_delta, end_line_delta, col, end_col): (i32, i32, Option<u32>, Option<u32>) =
1180            match kind {
1181                PyCodeLocationInfoKind::None => (0, 0, None, None),
1182                PyCodeLocationInfoKind::Long => {
1183                    let d = lt_read_signed_varint(linetable, &mut pos);
1184                    let ed = lt_read_varint(linetable, &mut pos) as i32;
1185                    let c = lt_read_varint(linetable, &mut pos);
1186                    let ec = lt_read_varint(linetable, &mut pos);
1187                    (
1188                        d,
1189                        ed,
1190                        if c == 0 { None } else { Some(c - 1) },
1191                        if ec == 0 { None } else { Some(ec - 1) },
1192                    )
1193                }
1194                PyCodeLocationInfoKind::NoColumns => {
1195                    (lt_read_signed_varint(linetable, &mut pos), 0, None, None)
1196                }
1197                PyCodeLocationInfoKind::OneLine0
1198                | PyCodeLocationInfoKind::OneLine1
1199                | PyCodeLocationInfoKind::OneLine2 => {
1200                    let c = lt_byte(linetable, &mut pos) as u32;
1201                    let ec = lt_byte(linetable, &mut pos) as u32;
1202                    (kind.one_line_delta().unwrap_or(0), 0, Some(c), Some(ec))
1203                }
1204                _ if kind.is_short() => {
1205                    let d = lt_byte(linetable, &mut pos);
1206                    let g = kind.short_column_group().unwrap_or(0);
1207                    let c = ((g as u32) << 3) | ((d >> 4) as u32);
1208                    (0, 0, Some(c), Some(c + (d & 0x0f) as u32))
1209                }
1210                _ => (0, 0, None, None),
1211            };
1212
1213        line += line_delta;
1214        for _ in 0..length {
1215            if locations.len() >= num_instructions {
1216                break;
1217            }
1218            if kind == PyCodeLocationInfoKind::None {
1219                locations.push(default_loc());
1220            } else {
1221                let mk = |l: i32| {
1222                    if l > 0 {
1223                        OneIndexed::new(l as usize).unwrap_or(OneIndexed::MIN)
1224                    } else {
1225                        OneIndexed::MIN
1226                    }
1227                };
1228                locations.push((
1229                    SourceLocation {
1230                        line: mk(line),
1231                        character_offset: OneIndexed::from_zero_indexed(col.unwrap_or(0) as usize),
1232                    },
1233                    SourceLocation {
1234                        line: mk(line + end_line_delta),
1235                        character_offset: OneIndexed::from_zero_indexed(
1236                            end_col.unwrap_or(0) as usize
1237                        ),
1238                    },
1239                ));
1240            }
1241        }
1242    }
1243    while locations.len() < num_instructions {
1244        locations.push(default_loc());
1245    }
1246    locations.into_boxed_slice()
1247}
1248
/// Return the byte at `*pos` and advance `pos`. At or past the end of
/// `data`, return 0 and leave `pos` unchanged.
fn lt_byte(data: &[u8], pos: &mut usize) -> u8 {
    match data.get(*pos) {
        Some(&byte) => {
            *pos += 1;
            byte
        }
        None => 0,
    }
}
1258
/// Read an unsigned varint from a location table, advancing `pos`.
///
/// The encoding is little-endian in 6-bit groups. Each byte contributes its
/// low 6 bits, and bit 6 (`0x40`) signals that another byte follows. Reading
/// also stops at the end of `data`, returning whatever has been accumulated.
///
/// Groups that would land beyond bit 31 are consumed but discarded, so the
/// value is truncated to 32 bits. This keeps a malformed run of continuation
/// bytes from causing a shift-overflow panic.
fn lt_read_varint(data: &[u8], pos: &mut usize) -> u32 {
    let mut result: u32 = 0;
    let mut shift: u32 = 0;
    while let Some(&b) = data.get(*pos) {
        *pos += 1;
        // `checked_shl` returns None once shift >= 32; drop such groups.
        result |= u32::from(b & 0x3f).checked_shl(shift).unwrap_or(0);
        shift = shift.saturating_add(6);
        if b & 0x40 == 0 {
            break;
        }
    }
    result
}
1277
1278fn lt_read_signed_varint(data: &[u8], pos: &mut usize) -> i32 {
1279    let val = lt_read_varint(data, pos);
1280    if val & 1 != 0 {
1281        -((val >> 1) as i32)
1282    } else {
1283        (val >> 1) as i32
1284    }
1285}