1use crate::bytecode::*;
2use malachite_bigint::{BigInt, Sign};
3use num_complex::Complex64;
4use rustpython_parser_core::source_code::{OneIndexed, SourceLocation};
5use std::convert::Infallible;
6
/// Version number associated with the marshal format implemented in this file.
// NOTE(review): not referenced by the visible code; presumably callers compare
// it against stored data and it is bumped when the byte layout changes — confirm.
pub const FORMAT_VERSION: u32 = 4;
8
/// Errors that can be reported while decoding marshalled data.
#[derive(Debug)]
pub enum MarshalError {
    /// The reader had fewer bytes left than the decoder asked for.
    Eof,
    /// The instruction stream could not be decoded.
    // NOTE(review): never constructed in the visible code; presumably produced
    // by `Instruction::try_from` — confirm.
    InvalidBytecode,
    /// A string field did not contain valid UTF-8.
    InvalidUtf8,
    /// A source location could not be represented (its row was rejected).
    InvalidLocation,
    /// The type tag byte is unknown, or the value kind is unsupported by the bag.
    BadType,
}

impl std::fmt::Display for MarshalError {
    /// Writes a short, fixed, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the message first so there is a single write at the end.
        let message = match self {
            Self::Eof => "unexpected end of data",
            Self::InvalidBytecode => "invalid bytecode",
            Self::InvalidUtf8 => "invalid utf8",
            Self::InvalidLocation => "invalid source location",
            Self::BadType => "bad type marker",
        };
        f.write_str(message)
    }
}

impl From<std::str::Utf8Error> for MarshalError {
    /// Lets `?` turn a UTF-8 decoding failure into [`MarshalError::InvalidUtf8`];
    /// the details carried by the original error are discarded.
    fn from(_: std::str::Utf8Error) -> Self {
        Self::InvalidUtf8
    }
}

impl std::error::Error for MarshalError {}

/// Module-local `Result` whose error type defaults to [`MarshalError`].
type Result<T, E = MarshalError> = std::result::Result<T, E>;
44
/// One-byte tag that precedes every marshalled value and identifies its kind.
///
/// `serialize_value` writes the tag as `Type::X as u8`, and `deserialize_value`
/// reads it back through `Type::try_from`. Each discriminant is the ASCII byte
/// shown next to the variant.
#[repr(u8)]
enum Type {
    None = b'N',
    False = b'F',
    True = b'T',
    StopIter = b'S',
    Ellipsis = b'.',
    Int = b'i',
    Float = b'g',
    Complex = b'y',
    Bytes = b's',
    Tuple = b'(',
    List = b'[',
    Dict = b'{',
    Code = b'c',
    Unicode = b'u',
    Set = b'<',
    FrozenSet = b'>',
    // Decoded exactly like `Unicode`; the visible `serialize_value` never emits it.
    Ascii = b'a',
}
74impl TryFrom<u8> for Type {
77 type Error = MarshalError;
78 fn try_from(value: u8) -> Result<Self> {
79 use Type::*;
80 Ok(match value {
81 b'N' => None,
83 b'F' => False,
84 b'T' => True,
85 b'S' => StopIter,
86 b'.' => Ellipsis,
87 b'i' => Int,
88 b'g' => Float,
89 b'y' => Complex,
90 b's' => Bytes,
92 b'(' => Tuple,
95 b'[' => List,
96 b'{' => Dict,
97 b'c' => Code,
98 b'u' => Unicode,
99 b'<' => Set,
101 b'>' => FrozenSet,
102 b'a' => Ascii,
103 _ => return Err(MarshalError::BadType),
108 })
109 }
110}
111
112pub trait Read {
113 fn read_slice(&mut self, n: u32) -> Result<&[u8]>;
114 fn read_array<const N: usize>(&mut self) -> Result<&[u8; N]> {
115 self.read_slice(N as u32).map(|s| s.try_into().unwrap())
116 }
117 fn read_str(&mut self, len: u32) -> Result<&str> {
118 Ok(std::str::from_utf8(self.read_slice(len)?)?)
119 }
120 fn read_u8(&mut self) -> Result<u8> {
121 Ok(u8::from_le_bytes(*self.read_array()?))
122 }
123 fn read_u16(&mut self) -> Result<u16> {
124 Ok(u16::from_le_bytes(*self.read_array()?))
125 }
126 fn read_u32(&mut self) -> Result<u32> {
127 Ok(u32::from_le_bytes(*self.read_array()?))
128 }
129 fn read_u64(&mut self) -> Result<u64> {
130 Ok(u64::from_le_bytes(*self.read_array()?))
131 }
132}
133
134pub(crate) trait ReadBorrowed<'a>: Read {
135 fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]>;
136 fn read_str_borrow(&mut self, len: u32) -> Result<&'a str> {
137 Ok(std::str::from_utf8(self.read_slice_borrow(len)?)?)
138 }
139}
140
141impl Read for &[u8] {
142 fn read_slice(&mut self, n: u32) -> Result<&[u8]> {
143 self.read_slice_borrow(n)
144 }
145}
146
147impl<'a> ReadBorrowed<'a> for &'a [u8] {
148 fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]> {
149 let data = self.get(..n as usize).ok_or(MarshalError::Eof)?;
150 *self = &self[n as usize..];
151 Ok(data)
152 }
153}
154
/// A reader over an in-memory buffer that remembers how far it has read.
pub struct Cursor<B> {
    /// The underlying buffer; the `Read` impl accesses it through `AsRef<[u8]>`.
    pub data: B,
    /// Offset into `data` of the next unread byte; advanced by `read_slice`.
    pub position: usize,
}
159
160impl<B: AsRef<[u8]>> Read for Cursor<B> {
161 fn read_slice(&mut self, n: u32) -> Result<&[u8]> {
162 let data = &self.data.as_ref()[self.position..];
163 let slice = data.get(..n as usize).ok_or(MarshalError::Eof)?;
164 self.position += n as usize;
165 Ok(slice)
166 }
167}
168
169pub fn deserialize_code<R: Read, Bag: ConstantBag>(
170 rdr: &mut R,
171 bag: Bag,
172) -> Result<CodeObject<Bag::Constant>> {
173 let len = rdr.read_u32()?;
174 let instructions = rdr.read_slice(len * 2)?;
175 let instructions = instructions
176 .chunks_exact(2)
177 .map(|cu| {
178 let op = Instruction::try_from(cu[0])?;
179 let arg = OpArgByte(cu[1]);
180 Ok(CodeUnit { op, arg })
181 })
182 .collect::<Result<Box<[CodeUnit]>>>()?;
183
184 let len = rdr.read_u32()?;
185 let locations = (0..len)
186 .map(|_| {
187 Ok(SourceLocation {
188 row: OneIndexed::new(rdr.read_u32()?).ok_or(MarshalError::InvalidLocation)?,
189 column: OneIndexed::from_zero_indexed(rdr.read_u32()?),
190 })
191 })
192 .collect::<Result<Box<[SourceLocation]>>>()?;
193
194 let flags = CodeFlags::from_bits_truncate(rdr.read_u16()?);
195
196 let posonlyarg_count = rdr.read_u32()?;
197 let arg_count = rdr.read_u32()?;
198 let kwonlyarg_count = rdr.read_u32()?;
199
200 let len = rdr.read_u32()?;
201 let source_path = bag.make_name(rdr.read_str(len)?);
202
203 let first_line_number = OneIndexed::new(rdr.read_u32()?);
204 let max_stackdepth = rdr.read_u32()?;
205
206 let len = rdr.read_u32()?;
207 let obj_name = bag.make_name(rdr.read_str(len)?);
208
209 let len = rdr.read_u32()?;
210 let cell2arg = (len != 0)
211 .then(|| {
212 (0..len)
213 .map(|_| Ok(rdr.read_u32()? as i32))
214 .collect::<Result<Box<[i32]>>>()
215 })
216 .transpose()?;
217
218 let len = rdr.read_u32()?;
219 let constants = (0..len)
220 .map(|_| deserialize_value(rdr, bag))
221 .collect::<Result<Box<[_]>>>()?;
222
223 let mut read_names = || {
224 let len = rdr.read_u32()?;
225 (0..len)
226 .map(|_| {
227 let len = rdr.read_u32()?;
228 Ok(bag.make_name(rdr.read_str(len)?))
229 })
230 .collect::<Result<Box<[_]>>>()
231 };
232
233 let names = read_names()?;
234 let varnames = read_names()?;
235 let cellvars = read_names()?;
236 let freevars = read_names()?;
237
238 Ok(CodeObject {
239 instructions,
240 locations,
241 flags,
242 posonlyarg_count,
243 arg_count,
244 kwonlyarg_count,
245 source_path,
246 first_line_number,
247 max_stackdepth,
248 obj_name,
249 cell2arg,
250 constants,
251 names,
252 varnames,
253 cellvars,
254 freevars,
255 })
256}
257
/// Factory used by `deserialize_value` to build each decoded value.
///
/// The `make_*` methods returning `Self::Value` directly cannot fail. Those
/// returning `Result` may reject the value kind; the blanket implementation for
/// `ConstantBag` types in this file returns `MarshalError::BadType` for them.
pub trait MarshalBag: Copy {
    /// The value type produced by this bag.
    type Value;
    /// Builds a boolean value.
    fn make_bool(&self, value: bool) -> Self::Value;
    /// Builds the `None` value.
    fn make_none(&self) -> Self::Value;
    /// Builds the ellipsis value.
    fn make_ellipsis(&self) -> Self::Value;
    /// Builds a floating-point value.
    fn make_float(&self, value: f64) -> Self::Value;
    /// Builds a complex-number value.
    fn make_complex(&self, value: Complex64) -> Self::Value;
    /// Builds a string value from borrowed text.
    fn make_str(&self, value: &str) -> Self::Value;
    /// Builds a bytes value from a borrowed slice.
    fn make_bytes(&self, value: &[u8]) -> Self::Value;
    /// Builds an arbitrary-precision integer value.
    fn make_int(&self, value: BigInt) -> Self::Value;
    /// Builds a tuple from already-decoded elements.
    fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value;
    /// Wraps a decoded code object as a value.
    fn make_code(
        &self,
        code: CodeObject<<Self::ConstantBag as ConstantBag>::Constant>,
    ) -> Self::Value;
    /// Builds the value decoded for the `StopIter` tag, if supported.
    fn make_stop_iter(&self) -> Result<Self::Value>;
    /// Builds a list from already-decoded elements, if supported.
    fn make_list(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
    /// Builds a set from already-decoded elements, if supported.
    fn make_set(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
    /// Builds a frozenset from already-decoded elements, if supported.
    fn make_frozenset(&self, it: impl Iterator<Item = Self::Value>) -> Result<Self::Value>;
    /// Builds a dict from already-decoded key/value pairs, if supported.
    fn make_dict(
        &self,
        it: impl Iterator<Item = (Self::Value, Self::Value)>,
    ) -> Result<Self::Value>;
    /// The constant bag handed to `deserialize_code` for nested code objects.
    type ConstantBag: ConstantBag;
    /// Returns the constant bag used to decode nested code objects.
    fn constant_bag(self) -> Self::ConstantBag;
}
284
285impl<Bag: ConstantBag> MarshalBag for Bag {
286 type Value = Bag::Constant;
287 fn make_bool(&self, value: bool) -> Self::Value {
288 self.make_constant::<Bag::Constant>(BorrowedConstant::Boolean { value })
289 }
290 fn make_none(&self) -> Self::Value {
291 self.make_constant::<Bag::Constant>(BorrowedConstant::None)
292 }
293 fn make_ellipsis(&self) -> Self::Value {
294 self.make_constant::<Bag::Constant>(BorrowedConstant::Ellipsis)
295 }
296 fn make_float(&self, value: f64) -> Self::Value {
297 self.make_constant::<Bag::Constant>(BorrowedConstant::Float { value })
298 }
299 fn make_complex(&self, value: Complex64) -> Self::Value {
300 self.make_constant::<Bag::Constant>(BorrowedConstant::Complex { value })
301 }
302 fn make_str(&self, value: &str) -> Self::Value {
303 self.make_constant::<Bag::Constant>(BorrowedConstant::Str { value })
304 }
305 fn make_bytes(&self, value: &[u8]) -> Self::Value {
306 self.make_constant::<Bag::Constant>(BorrowedConstant::Bytes { value })
307 }
308 fn make_int(&self, value: BigInt) -> Self::Value {
309 self.make_int(value)
310 }
311 fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value {
312 self.make_tuple(elements)
313 }
314 fn make_code(
315 &self,
316 code: CodeObject<<Self::ConstantBag as ConstantBag>::Constant>,
317 ) -> Self::Value {
318 self.make_code(code)
319 }
320 fn make_stop_iter(&self) -> Result<Self::Value> {
321 Err(MarshalError::BadType)
322 }
323 fn make_list(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
324 Err(MarshalError::BadType)
325 }
326 fn make_set(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
327 Err(MarshalError::BadType)
328 }
329 fn make_frozenset(&self, _: impl Iterator<Item = Self::Value>) -> Result<Self::Value> {
330 Err(MarshalError::BadType)
331 }
332 fn make_dict(
333 &self,
334 _: impl Iterator<Item = (Self::Value, Self::Value)>,
335 ) -> Result<Self::Value> {
336 Err(MarshalError::BadType)
337 }
338 type ConstantBag = Self;
339 fn constant_bag(self) -> Self::ConstantBag {
340 self
341 }
342}
343
344pub fn deserialize_value<R: Read, Bag: MarshalBag>(rdr: &mut R, bag: Bag) -> Result<Bag::Value> {
345 let typ = Type::try_from(rdr.read_u8()?)?;
346 let value = match typ {
347 Type::True => bag.make_bool(true),
348 Type::False => bag.make_bool(false),
349 Type::None => bag.make_none(),
350 Type::StopIter => bag.make_stop_iter()?,
351 Type::Ellipsis => bag.make_ellipsis(),
352 Type::Int => {
353 let len = rdr.read_u32()? as i32;
354 let sign = if len < 0 { Sign::Minus } else { Sign::Plus };
355 let bytes = rdr.read_slice(len.unsigned_abs())?;
356 let int = BigInt::from_bytes_le(sign, bytes);
357 bag.make_int(int)
358 }
359 Type::Float => {
360 let value = f64::from_bits(rdr.read_u64()?);
361 bag.make_float(value)
362 }
363 Type::Complex => {
364 let re = f64::from_bits(rdr.read_u64()?);
365 let im = f64::from_bits(rdr.read_u64()?);
366 let value = Complex64 { re, im };
367 bag.make_complex(value)
368 }
369 Type::Ascii | Type::Unicode => {
370 let len = rdr.read_u32()?;
371 let value = rdr.read_str(len)?;
372 bag.make_str(value)
373 }
374 Type::Tuple => {
375 let len = rdr.read_u32()?;
376 let it = (0..len).map(|_| deserialize_value(rdr, bag));
377 itertools::process_results(it, |it| bag.make_tuple(it))?
378 }
379 Type::List => {
380 let len = rdr.read_u32()?;
381 let it = (0..len).map(|_| deserialize_value(rdr, bag));
382 itertools::process_results(it, |it| bag.make_list(it))??
383 }
384 Type::Set => {
385 let len = rdr.read_u32()?;
386 let it = (0..len).map(|_| deserialize_value(rdr, bag));
387 itertools::process_results(it, |it| bag.make_set(it))??
388 }
389 Type::FrozenSet => {
390 let len = rdr.read_u32()?;
391 let it = (0..len).map(|_| deserialize_value(rdr, bag));
392 itertools::process_results(it, |it| bag.make_frozenset(it))??
393 }
394 Type::Dict => {
395 let len = rdr.read_u32()?;
396 let it = (0..len).map(|_| {
397 let k = deserialize_value(rdr, bag)?;
398 let v = deserialize_value(rdr, bag)?;
399 Ok::<_, MarshalError>((k, v))
400 });
401 itertools::process_results(it, |it| bag.make_dict(it))??
402 }
403 Type::Bytes => {
404 let len = rdr.read_u32()?;
406 let value = rdr.read_slice(len)?;
407 bag.make_bytes(value)
408 }
409 Type::Code => bag.make_code(deserialize_code(rdr, bag.constant_bag())?),
410 };
411 Ok(value)
412}
413
/// A value that can expose itself as a [`DumpableValue`] for serialization.
pub trait Dumpable: Sized {
    /// Error type returned by [`Dumpable::with_dump`] (and propagated by
    /// `serialize_value`).
    type Error;
    /// Constant type of the code objects carried by `DumpableValue::Code`.
    type Constant: Constant;
    /// Calls `f` with a borrowed view of `self` and returns what `f` returned,
    /// or an error if the view could not be produced.
    fn with_dump<R>(&self, f: impl FnOnce(DumpableValue<'_, Self>) -> R) -> Result<R, Self::Error>;
}
419
/// Borrowed view of a value, one variant per kind that `serialize_value` can
/// write.
///
/// Containers hold slices of `D`; each element is expanded through
/// `Dumpable::with_dump` when it is serialized.
pub enum DumpableValue<'a, D: Dumpable> {
    /// Arbitrary-precision integer.
    Integer(&'a BigInt),
    /// Floating-point number.
    Float(f64),
    /// Complex number.
    Complex(Complex64),
    /// Boolean.
    Boolean(bool),
    /// Text; written with the `Unicode` tag.
    Str(&'a str),
    /// Raw bytes.
    Bytes(&'a [u8]),
    /// Code object.
    Code(&'a CodeObject<D::Constant>),
    /// Tuple of values.
    Tuple(&'a [D]),
    /// The `None` value.
    None,
    /// The ellipsis value.
    Ellipsis,
    /// Written as the bare `StopIter` tag.
    StopIter,
    /// List of values.
    List(&'a [D]),
    /// Set, given as a slice of its elements.
    Set(&'a [D]),
    /// Frozenset, given as a slice of its elements.
    Frozenset(&'a [D]),
    /// Dict, given as a slice of key/value pairs.
    Dict(&'a [(D, D)]),
}
437
438impl<'a, C: Constant> From<BorrowedConstant<'a, C>> for DumpableValue<'a, C> {
439 fn from(c: BorrowedConstant<'a, C>) -> Self {
440 match c {
441 BorrowedConstant::Integer { value } => Self::Integer(value),
442 BorrowedConstant::Float { value } => Self::Float(value),
443 BorrowedConstant::Complex { value } => Self::Complex(value),
444 BorrowedConstant::Boolean { value } => Self::Boolean(value),
445 BorrowedConstant::Str { value } => Self::Str(value),
446 BorrowedConstant::Bytes { value } => Self::Bytes(value),
447 BorrowedConstant::Code { code } => Self::Code(code),
448 BorrowedConstant::Tuple { elements } => Self::Tuple(elements),
449 BorrowedConstant::None => Self::None,
450 BorrowedConstant::Ellipsis => Self::Ellipsis,
451 }
452 }
453}
454
455impl<C: Constant> Dumpable for C {
456 type Error = Infallible;
457 type Constant = Self;
458 #[inline(always)]
459 fn with_dump<R>(&self, f: impl FnOnce(DumpableValue<'_, Self>) -> R) -> Result<R, Self::Error> {
460 Ok(f(self.borrow_constant().into()))
461 }
462}
463
/// A byte sink for the serializer.
///
/// Implementors only provide [`Write::write_slice`]; the integer helpers encode
/// their argument as little-endian bytes and forward to it. Writing cannot fail.
pub trait Write {
    /// Appends `slice` to the output.
    fn write_slice(&mut self, slice: &[u8]);

    /// Writes a single byte.
    fn write_u8(&mut self, v: u8) {
        self.write_slice(&[v]);
    }

    /// Writes `v` as two little-endian bytes.
    fn write_u16(&mut self, v: u16) {
        let bytes = v.to_le_bytes();
        self.write_slice(&bytes);
    }

    /// Writes `v` as four little-endian bytes.
    fn write_u32(&mut self, v: u32) {
        let bytes = v.to_le_bytes();
        self.write_slice(&bytes);
    }

    /// Writes `v` as eight little-endian bytes.
    fn write_u64(&mut self, v: u64) {
        let bytes = v.to_le_bytes();
        self.write_slice(&bytes);
    }
}

impl Write for Vec<u8> {
    /// Appends `slice` to the end of the vector.
    fn write_slice(&mut self, slice: &[u8]) {
        self.extend_from_slice(slice);
    }
}
485
486pub(crate) fn write_len<W: Write>(buf: &mut W, len: usize) {
487 let Ok(len) = len.try_into() else {
488 panic!("too long to serialize")
489 };
490 buf.write_u32(len);
491}
492
493pub(crate) fn write_vec<W: Write>(buf: &mut W, slice: &[u8]) {
494 write_len(buf, slice.len());
495 buf.write_slice(slice);
496}
497
498pub fn serialize_value<W: Write, D: Dumpable>(
499 buf: &mut W,
500 constant: DumpableValue<'_, D>,
501) -> Result<(), D::Error> {
502 match constant {
503 DumpableValue::Integer(int) => {
504 buf.write_u8(Type::Int as u8);
505 let (sign, bytes) = int.to_bytes_le();
506 let len: i32 = bytes.len().try_into().expect("too long to serialize");
507 let len = if sign == Sign::Minus { -len } else { len };
508 buf.write_u32(len as u32);
509 buf.write_slice(&bytes);
510 }
511 DumpableValue::Float(f) => {
512 buf.write_u8(Type::Float as u8);
513 buf.write_u64(f.to_bits());
514 }
515 DumpableValue::Complex(c) => {
516 buf.write_u8(Type::Complex as u8);
517 buf.write_u64(c.re.to_bits());
518 buf.write_u64(c.im.to_bits());
519 }
520 DumpableValue::Boolean(b) => {
521 buf.write_u8(if b { Type::True } else { Type::False } as u8);
522 }
523 DumpableValue::Str(s) => {
524 buf.write_u8(Type::Unicode as u8);
525 write_vec(buf, s.as_bytes());
526 }
527 DumpableValue::Bytes(b) => {
528 buf.write_u8(Type::Bytes as u8);
529 write_vec(buf, b);
530 }
531 DumpableValue::Code(c) => {
532 buf.write_u8(Type::Code as u8);
533 serialize_code(buf, c);
534 }
535 DumpableValue::Tuple(tup) => {
536 buf.write_u8(Type::Tuple as u8);
537 write_len(buf, tup.len());
538 for val in tup {
539 val.with_dump(|val| serialize_value(buf, val))??
540 }
541 }
542 DumpableValue::None => {
543 buf.write_u8(Type::None as u8);
544 }
545 DumpableValue::Ellipsis => {
546 buf.write_u8(Type::Ellipsis as u8);
547 }
548 DumpableValue::StopIter => {
549 buf.write_u8(Type::StopIter as u8);
550 }
551 DumpableValue::List(l) => {
552 buf.write_u8(Type::List as u8);
553 write_len(buf, l.len());
554 for val in l {
555 val.with_dump(|val| serialize_value(buf, val))??
556 }
557 }
558 DumpableValue::Set(set) => {
559 buf.write_u8(Type::Set as u8);
560 write_len(buf, set.len());
561 for val in set {
562 val.with_dump(|val| serialize_value(buf, val))??
563 }
564 }
565 DumpableValue::Frozenset(set) => {
566 buf.write_u8(Type::FrozenSet as u8);
567 write_len(buf, set.len());
568 for val in set {
569 val.with_dump(|val| serialize_value(buf, val))??
570 }
571 }
572 DumpableValue::Dict(d) => {
573 buf.write_u8(Type::Dict as u8);
574 write_len(buf, d.len());
575 for (k, v) in d {
576 k.with_dump(|val| serialize_value(buf, val))??;
577 v.with_dump(|val| serialize_value(buf, val))??;
578 }
579 }
580 }
581 Ok(())
582}
583
584pub fn serialize_code<W: Write, C: Constant>(buf: &mut W, code: &CodeObject<C>) {
585 write_len(buf, code.instructions.len());
586 let (_, instructions_bytes, _) = unsafe { code.instructions.align_to() };
588 buf.write_slice(instructions_bytes);
589
590 write_len(buf, code.locations.len());
591 for loc in &*code.locations {
592 buf.write_u32(loc.row.get());
593 buf.write_u32(loc.column.to_zero_indexed());
594 }
595
596 buf.write_u16(code.flags.bits());
597
598 buf.write_u32(code.posonlyarg_count);
599 buf.write_u32(code.arg_count);
600 buf.write_u32(code.kwonlyarg_count);
601
602 write_vec(buf, code.source_path.as_ref().as_bytes());
603
604 buf.write_u32(code.first_line_number.map_or(0, |x| x.get()));
605 buf.write_u32(code.max_stackdepth);
606
607 write_vec(buf, code.obj_name.as_ref().as_bytes());
608
609 let cell2arg = code.cell2arg.as_deref().unwrap_or(&[]);
610 write_len(buf, cell2arg.len());
611 for &i in cell2arg {
612 buf.write_u32(i as u32)
613 }
614
615 write_len(buf, code.constants.len());
616 for constant in &*code.constants {
617 serialize_value(buf, constant.borrow_constant().into()).unwrap_or_else(|x| match x {})
618 }
619
620 let mut write_names = |names: &[C::Name]| {
621 write_len(buf, names.len());
622 for name in names {
623 write_vec(buf, name.as_ref().as_bytes());
624 }
625 };
626
627 write_names(&code.names);
628 write_names(&code.varnames);
629 write_names(&code.cellvars);
630 write_names(&code.freevars);
631}