1use crate::{
2 PyObjectRef, PyResult, TryFromObject, VirtualMachine,
3 builtins::{PyBaseExceptionRef, PyBytesRef, PyTuple, PyTupleRef, PyTypeRef},
4 common::{static_cell, str::wchar_t},
5 convert::ToPyObject,
6 function::{ArgBytesLike, ArgIntoBool, ArgIntoFloat},
7};
8use alloc::fmt;
9use core::{iter::Peekable, mem};
10use half::f16;
11use itertools::Itertools;
12use malachite_bigint::BigInt;
13use num_traits::{PrimInt, ToPrimitive};
14use std::os::raw;
15
16type PackFunc = fn(&VirtualMachine, PyObjectRef, &mut [u8]) -> PyResult<()>;
17type UnpackFunc = fn(&VirtualMachine, &[u8]) -> PyObjectRef;
18
/// Message used whenever a repeat count or the running struct size overflows.
static OVERFLOW_MSG: &str = "total struct size too long";

/// Byte-order / sizing mode selected by the optional first character of a
/// format string.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Endianness {
    /// `@`, or no prefix at all.
    Native,
    /// `<`.
    Little,
    /// `>` or `!`.
    Big,
    /// `=`.
    Host,
}
28
29impl Endianness {
30 fn parse<I>(chars: &mut Peekable<I>) -> Self
33 where
34 I: Sized + Iterator<Item = u8>,
35 {
36 let e = match chars.peek() {
37 Some(b'@') => Self::Native,
38 Some(b'=') => Self::Host,
39 Some(b'<') => Self::Little,
40 Some(b'>') | Some(b'!') => Self::Big,
41 _ => return Self::Native,
42 };
43 chars.next().unwrap();
44 e
45 }
46}
47
48trait ByteOrder {
49 fn convert<I: PrimInt>(i: I) -> I;
50}
51enum BigEndian {}
52impl ByteOrder for BigEndian {
53 fn convert<I: PrimInt>(i: I) -> I {
54 i.to_be()
55 }
56}
57enum LittleEndian {}
58impl ByteOrder for LittleEndian {
59 fn convert<I: PrimInt>(i: I) -> I {
60 i.to_le()
61 }
62}
63
64#[cfg(target_endian = "big")]
65type NativeEndian = BigEndian;
66#[cfg(target_endian = "little")]
67type NativeEndian = LittleEndian;
68
69#[derive(Copy, Clone, num_enum::TryFromPrimitive)]
70#[repr(u8)]
71pub(crate) enum FormatType {
72 Pad = b'x',
73 SByte = b'b',
74 UByte = b'B',
75 Char = b'c',
76 WideChar = b'u',
77 Str = b's',
78 Pascal = b'p',
79 Short = b'h',
80 UShort = b'H',
81 Int = b'i',
82 UInt = b'I',
83 Long = b'l',
84 ULong = b'L',
85 SSizeT = b'n',
86 SizeT = b'N',
87 LongLong = b'q',
88 ULongLong = b'Q',
89 Bool = b'?',
90 Half = b'e',
91 Float = b'f',
92 Double = b'd',
93 LongDouble = b'g',
94 VoidP = b'P',
95 PyObject = b'O',
96}
97
98impl fmt::Debug for FormatType {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 fmt::Debug::fmt(&(*self as u8 as char), f)
101 }
102}
103
104impl FormatType {
105 fn info(self, e: Endianness) -> &'static FormatInfo {
106 use FormatType::*;
107 use mem::{align_of, size_of};
108 macro_rules! native_info {
109 ($t:ty) => {{
110 &FormatInfo {
111 size: size_of::<$t>(),
112 align: align_of::<$t>(),
113 pack: Some(<$t as Packable>::pack::<NativeEndian>),
114 unpack: Some(<$t as Packable>::unpack::<NativeEndian>),
115 }
116 }};
117 }
118 macro_rules! nonnative_info {
119 ($t:ty, $end:ty) => {{
120 &FormatInfo {
121 size: size_of::<$t>(),
122 align: 0,
123 pack: Some(<$t as Packable>::pack::<$end>),
124 unpack: Some(<$t as Packable>::unpack::<$end>),
125 }
126 }};
127 }
128 macro_rules! match_nonnative {
129 ($zelf:expr, $end:ty) => {{
130 match $zelf {
131 Pad | Str | Pascal => &FormatInfo {
132 size: size_of::<u8>(),
133 align: 0,
134 pack: None,
135 unpack: None,
136 },
137 SByte => nonnative_info!(i8, $end),
138 UByte => nonnative_info!(u8, $end),
139 Char => &FormatInfo {
140 size: size_of::<u8>(),
141 align: 0,
142 pack: Some(pack_char),
143 unpack: Some(unpack_char),
144 },
145 Short => nonnative_info!(i16, $end),
146 UShort => nonnative_info!(u16, $end),
147 Int | Long => nonnative_info!(i32, $end),
148 UInt | ULong => nonnative_info!(u32, $end),
149 LongLong => nonnative_info!(i64, $end),
150 ULongLong => nonnative_info!(u64, $end),
151 Bool => nonnative_info!(bool, $end),
152 Half => nonnative_info!(f16, $end),
153 Float => nonnative_info!(f32, $end),
154 Double => nonnative_info!(f64, $end),
155 LongDouble => nonnative_info!(f64, $end), PyObject => nonnative_info!(usize, $end), _ => unreachable!(), }
159 }};
160 }
161 match e {
162 Endianness::Native => match self {
163 Pad | Str | Pascal => &FormatInfo {
164 size: size_of::<raw::c_char>(),
165 align: 0,
166 pack: None,
167 unpack: None,
168 },
169 SByte => native_info!(raw::c_schar),
170 UByte => native_info!(raw::c_uchar),
171 Char => &FormatInfo {
172 size: size_of::<raw::c_char>(),
173 align: 0,
174 pack: Some(pack_char),
175 unpack: Some(unpack_char),
176 },
177 WideChar => native_info!(wchar_t),
178 Short => native_info!(raw::c_short),
179 UShort => native_info!(raw::c_ushort),
180 Int => native_info!(raw::c_int),
181 UInt => native_info!(raw::c_uint),
182 Long => native_info!(raw::c_long),
183 ULong => native_info!(raw::c_ulong),
184 SSizeT => native_info!(isize), SizeT => native_info!(usize), LongLong => native_info!(raw::c_longlong),
187 ULongLong => native_info!(raw::c_ulonglong),
188 Bool => native_info!(bool),
189 Half => native_info!(f16),
190 Float => native_info!(raw::c_float),
191 Double => native_info!(raw::c_double),
192 LongDouble => native_info!(raw::c_double), VoidP => native_info!(*mut raw::c_void),
194 PyObject => native_info!(*mut raw::c_void), },
196 Endianness::Big => match_nonnative!(self, BigEndian),
197 Endianness::Little => match_nonnative!(self, LittleEndian),
198 Endianness::Host => match_nonnative!(self, NativeEndian),
199 }
200 }
201}
202
203#[derive(Debug, Clone)]
204pub(crate) struct FormatCode {
205 pub repeat: usize,
206 pub code: FormatType,
207 pub info: &'static FormatInfo,
208 pub pre_padding: usize,
209}
210
211impl FormatCode {
212 pub const fn arg_count(&self) -> usize {
213 match self.code {
214 FormatType::Pad => 0,
215 FormatType::Str | FormatType::Pascal => 1,
216 _ => self.repeat,
217 }
218 }
219
220 pub fn parse<I>(
221 chars: &mut Peekable<I>,
222 endianness: Endianness,
223 ) -> Result<(Vec<Self>, usize, usize), String>
224 where
225 I: Sized + Iterator<Item = u8>,
226 {
227 let mut offset = 0isize;
228 let mut arg_count = 0usize;
229 let mut codes = vec![];
230 while chars.peek().is_some() {
231 while let Some(b' ' | b'\t' | b'\n' | b'\r') = chars.peek() {
233 chars.next();
234 }
235
236 let repeat = match chars.peek() {
238 Some(b'0'..=b'9') => {
239 let mut repeat = 0isize;
240 while let Some(b'0'..=b'9') = chars.peek() {
241 if let Some(c) = chars.next() {
242 let current_digit = c - b'0';
243 repeat = repeat
244 .checked_mul(10)
245 .and_then(|r| r.checked_add(current_digit as _))
246 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
247 }
248 }
249 repeat
250 }
251 _ => 1,
252 };
253
254 let c = match chars.next() {
256 Some(c) => c,
257 None => {
258 if repeat != 1 {
260 return Err("repeat count given without format specifier".to_owned());
261 }
262 break;
264 }
265 };
266
267 if c == 0 {
269 return Err("embedded null character".to_owned());
270 }
271
272 if c == b'T' || c == b'X' {
275 if chars.peek() == Some(&b'{') {
277 chars.next(); let mut depth = 1;
279 while depth > 0 {
280 match chars.next() {
281 Some(b'{') => depth += 1,
282 Some(b'}') => depth -= 1,
283 None => return Err("unmatched '{' in format".to_owned()),
284 _ => {}
285 }
286 }
287 continue;
288 }
289 }
290
291 if c == b'(' {
292 let mut depth = 1;
294 while depth > 0 {
295 match chars.next() {
296 Some(b'(') => depth += 1,
297 Some(b')') => depth -= 1,
298 None => return Err("unmatched '(' in format".to_owned()),
299 _ => {}
300 }
301 }
302 continue;
303 }
304
305 if c == b':' {
306 loop {
308 match chars.next() {
309 Some(b':') => break,
310 None => return Err("unmatched ':' in format".to_owned()),
311 _ => {}
312 }
313 }
314 continue;
315 }
316
317 if c == b'{'
318 || c == b'}'
319 || c == b'&'
320 || c == b'<'
321 || c == b'>'
322 || c == b'@'
323 || c == b'='
324 || c == b'!'
325 {
326 continue;
328 }
329
330 let code = FormatType::try_from(c)
331 .ok()
332 .filter(|c| match c {
333 FormatType::SSizeT | FormatType::SizeT | FormatType::VoidP => {
334 endianness == Endianness::Native
335 }
336 _ => true,
337 })
338 .ok_or_else(|| "bad char in struct format".to_owned())?;
339
340 let info = code.info(endianness);
341
342 let padding = compensate_alignment(offset as usize, info.align)
343 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
344 offset = padding
345 .to_isize()
346 .and_then(|extra| offset.checked_add(extra))
347 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
348
349 let code = Self {
350 repeat: repeat as usize,
351 code,
352 info,
353 pre_padding: padding,
354 };
355 arg_count += code.arg_count();
356 codes.push(code);
357
358 offset = (info.size as isize)
359 .checked_mul(repeat)
360 .and_then(|item_size| offset.checked_add(item_size))
361 .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
362 }
363
364 Ok((codes, offset as usize, arg_count))
365 }
366}
367
/// Number of padding bytes needed to advance `offset` to the next multiple
/// of `align`.
///
/// An `align` of 0 means "no alignment" and an `offset` of 0 is always
/// aligned; both yield `Some(0)`. The bit-mask arithmetic assumes `align` is
/// a power of two.
const fn compensate_alignment(offset: usize, align: usize) -> Option<usize> {
    if align == 0 || offset == 0 {
        return Some(0);
    }
    let mask = align - 1;
    mask.checked_sub((offset - 1) & mask)
}
377
378pub(crate) struct FormatInfo {
379 pub size: usize,
380 pub align: usize,
381 pub pack: Option<PackFunc>,
382 pub unpack: Option<UnpackFunc>,
383}
384impl fmt::Debug for FormatInfo {
385 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386 f.debug_struct("FormatInfo")
387 .field("size", &self.size)
388 .field("align", &self.align)
389 .finish()
390 }
391}
392
393#[derive(Debug, Clone)]
394pub struct FormatSpec {
395 #[allow(dead_code)]
396 pub(crate) endianness: Endianness,
397 pub(crate) codes: Vec<FormatCode>,
398 pub size: usize,
399 pub arg_count: usize,
400}
401
402impl FormatSpec {
403 pub fn parse(fmt: &[u8], vm: &VirtualMachine) -> PyResult<Self> {
404 let mut chars = fmt.iter().copied().peekable();
405
406 let endianness = Endianness::parse(&mut chars);
408
409 let (codes, size, arg_count) =
411 FormatCode::parse(&mut chars, endianness).map_err(|err| new_struct_error(vm, err))?;
412
413 Ok(Self {
414 endianness,
415 codes,
416 size,
417 arg_count,
418 })
419 }
420
421 pub fn pack(&self, args: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
422 let mut data = vec![0; self.size];
424
425 self.pack_into(&mut data, args, vm)?;
426
427 Ok(data)
428 }
429
430 pub fn pack_into(
431 &self,
432 mut buffer: &mut [u8],
433 args: Vec<PyObjectRef>,
434 vm: &VirtualMachine,
435 ) -> PyResult<()> {
436 if self.arg_count != args.len() {
437 return Err(new_struct_error(
438 vm,
439 format!(
440 "pack expected {} items for packing (got {})",
441 self.codes.len(),
442 args.len()
443 ),
444 ));
445 }
446
447 let mut args = args.into_iter();
448 for code in &self.codes {
450 buffer = &mut buffer[code.pre_padding..];
451 debug!("code: {code:?}");
452 match code.code {
453 FormatType::Str => {
454 let (buf, rest) = buffer.split_at_mut(code.repeat);
455 pack_string(vm, args.next().unwrap(), buf)?;
456 buffer = rest;
457 }
458 FormatType::Pascal => {
459 let (buf, rest) = buffer.split_at_mut(code.repeat);
460 pack_pascal(vm, args.next().unwrap(), buf)?;
461 buffer = rest;
462 }
463 FormatType::Pad => {
464 let (pad_buf, rest) = buffer.split_at_mut(code.repeat);
465 for el in pad_buf {
466 *el = 0
467 }
468 buffer = rest;
469 }
470 _ => {
471 let pack = code.info.pack.unwrap();
472 for arg in args.by_ref().take(code.repeat) {
473 let (item_buf, rest) = buffer.split_at_mut(code.info.size);
474 pack(vm, arg, item_buf)?;
475 buffer = rest;
476 }
477 }
478 }
479 }
480
481 Ok(())
482 }
483
484 pub fn unpack(&self, mut data: &[u8], vm: &VirtualMachine) -> PyResult<PyTupleRef> {
485 if self.size != data.len() {
486 return Err(new_struct_error(
487 vm,
488 format!("unpack requires a buffer of {} bytes", self.size),
489 ));
490 }
491
492 let mut items = Vec::with_capacity(self.arg_count);
493 for code in &self.codes {
494 data = &data[code.pre_padding..];
495 debug!("unpack code: {code:?}");
496 match code.code {
497 FormatType::Pad => {
498 data = &data[code.repeat..];
499 }
500 FormatType::Str => {
501 let (str_data, rest) = data.split_at(code.repeat);
502 items.push(vm.ctx.new_bytes(str_data.to_vec()).into());
504 data = rest;
505 }
506 FormatType::Pascal => {
507 let (str_data, rest) = data.split_at(code.repeat);
508 items.push(unpack_pascal(vm, str_data));
509 data = rest;
510 }
511 _ => {
512 let unpack = code.info.unpack.unwrap();
513 for _ in 0..code.repeat {
514 let (item_data, rest) = data.split_at(code.info.size);
515 items.push(unpack(vm, item_data));
516 data = rest;
517 }
518 }
519 };
520 }
521
522 Ok(PyTuple::new_ref(items, &vm.ctx))
523 }
524
525 #[inline]
526 pub const fn size(&self) -> usize {
527 self.size
528 }
529}
530
531trait Packable {
532 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()>;
533 fn unpack<E: ByteOrder>(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef;
534}
535
536trait PackInt: PrimInt {
537 fn pack_int<E: ByteOrder>(self, data: &mut [u8]);
538 fn unpack_int<E: ByteOrder>(data: &[u8]) -> Self;
539}
540
/// Implements [`PackInt`] and [`Packable`] for a primitive integer type.
macro_rules! make_pack_prim_int {
    ($int:ty) => {
        impl PackInt for $int {
            fn pack_int<E: ByteOrder>(self, data: &mut [u8]) {
                // Reorder the value first, then dump its in-memory bytes.
                let ordered = E::convert(self);
                data.copy_from_slice(&ordered.to_ne_bytes());
            }

            #[inline]
            fn unpack_int<E: ByteOrder>(data: &[u8]) -> Self {
                let mut bytes = [0u8; core::mem::size_of::<$int>()];
                bytes.copy_from_slice(data);
                E::convert(<$int>::from_ne_bytes(bytes))
            }
        }

        impl Packable for $int {
            fn pack<E: ByteOrder>(
                vm: &VirtualMachine,
                arg: PyObjectRef,
                data: &mut [u8],
            ) -> PyResult<()> {
                get_int_or_index::<$int>(vm, arg)?.pack_int::<E>(data);
                Ok(())
            }

            fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
                vm.ctx.new_int(<$int>::unpack_int::<E>(rdr)).into()
            }
        }
    };
}
574
575fn get_int_or_index<T>(vm: &VirtualMachine, arg: PyObjectRef) -> PyResult<T>
576where
577 T: PrimInt + for<'a> TryFrom<&'a BigInt>,
578{
579 let index = arg
580 .try_index_opt(vm)
581 .unwrap_or_else(|| Err(new_struct_error(vm, "required argument is not an integer")))?;
582 index
583 .try_to_primitive(vm)
584 .map_err(|_| new_struct_error(vm, "argument out of range"))
585}
586
587make_pack_prim_int!(i8);
588make_pack_prim_int!(u8);
589make_pack_prim_int!(i16);
590make_pack_prim_int!(u16);
591make_pack_prim_int!(i32);
592make_pack_prim_int!(u32);
593make_pack_prim_int!(i64);
594make_pack_prim_int!(u64);
595make_pack_prim_int!(usize);
596make_pack_prim_int!(isize);
597
598macro_rules! make_pack_float {
599 ($T:ty) => {
600 impl Packable for $T {
601 fn pack<E: ByteOrder>(
602 vm: &VirtualMachine,
603 arg: PyObjectRef,
604 data: &mut [u8],
605 ) -> PyResult<()> {
606 let f = ArgIntoFloat::try_from_object(vm, arg)?.into_float() as $T;
607 f.to_bits().pack_int::<E>(data);
608 Ok(())
609 }
610
611 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
612 let i = PackInt::unpack_int::<E>(rdr);
613 <$T>::from_bits(i).to_pyobject(vm)
614 }
615 }
616 };
617}
618
619make_pack_float!(f32);
620make_pack_float!(f64);
621
622impl Packable for f16 {
623 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
624 let f_64 = ArgIntoFloat::try_from_object(vm, arg)?.into_float();
625 let f_16 = Self::from_f64_const(f_64);
627 if f_16.is_infinite() != f_64.is_infinite() {
628 return Err(vm.new_overflow_error("float too large to pack with e format"));
629 }
630 f_16.to_bits().pack_int::<E>(data);
631 Ok(())
632 }
633
634 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
635 let i = PackInt::unpack_int::<E>(rdr);
636 Self::from_bits(i).to_f64().to_pyobject(vm)
637 }
638}
639
640impl Packable for *mut raw::c_void {
641 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
642 usize::pack::<E>(vm, arg, data)
643 }
644
645 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
646 usize::unpack::<E>(vm, rdr)
647 }
648}
649
650impl Packable for bool {
651 fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
652 let v = ArgIntoBool::try_from_object(vm, arg)?.into_bool() as u8;
653 v.pack_int::<E>(data);
654 Ok(())
655 }
656
657 fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
658 let i = u8::unpack_int::<E>(rdr);
659 vm.ctx.new_bool(i != 0).into()
660 }
661}
662
663fn pack_char(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
664 let v = PyBytesRef::try_from_object(vm, arg)?;
665 let ch = *v
666 .as_bytes()
667 .iter()
668 .exactly_one()
669 .map_err(|_| new_struct_error(vm, "char format requires a bytes object of length 1"))?;
670 data[0] = ch;
671 Ok(())
672}
673
674fn pack_string(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResult<()> {
675 let b = ArgBytesLike::try_from_object(vm, arg)?;
676 b.with_ref(|data| write_string(buf, data));
677 Ok(())
678}
679
680fn pack_pascal(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResult<()> {
681 if buf.is_empty() {
682 return Ok(());
683 }
684 let b = ArgBytesLike::try_from_object(vm, arg)?;
685 b.with_ref(|data| {
686 let string_length = core::cmp::min(core::cmp::min(data.len(), 255), buf.len() - 1);
687 buf[0] = string_length as u8;
688 write_string(&mut buf[1..], data);
689 });
690 Ok(())
691}
692
/// Copies as much of `data` as fits into the front of `buf` and zero-fills
/// whatever is left of `buf`.
fn write_string(buf: &mut [u8], data: &[u8]) {
    let copied = buf.len().min(data.len());
    let (head, tail) = buf.split_at_mut(copied);
    head.copy_from_slice(&data[..copied]);
    tail.fill(0);
}
700
701fn unpack_char(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef {
702 vm.ctx.new_bytes(vec![data[0]]).into()
703}
704
705fn unpack_pascal(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef {
706 let (&len, data) = match data.split_first() {
707 Some(x) => x,
708 None => {
709 return vm.ctx.new_bytes(vec![]).into();
711 }
712 };
713 let len = core::cmp::min(len as usize, data.len());
714 vm.ctx.new_bytes(data[..len].to_vec()).into()
715}
716
717pub fn struct_error_type(vm: &VirtualMachine) -> &'static PyTypeRef {
719 static_cell! {
720 static INSTANCE: PyTypeRef;
721 }
722 INSTANCE.get_or_init(|| vm.ctx.new_exception_type("struct", "error", None))
723}
724
725pub fn new_struct_error(vm: &VirtualMachine, msg: impl Into<String>) -> PyBaseExceptionRef {
726 let msg: String = msg.into();
729 vm.new_exception_msg(struct_error_type(vm).clone(), msg.into())
730}