rustpython_vm/buffer.rs

1use crate::{
2    builtins::{PyBaseExceptionRef, PyBytesRef, PyTuple, PyTupleRef, PyTypeRef},
3    common::{static_cell, str::wchar_t},
4    convert::ToPyObject,
5    function::{ArgBytesLike, ArgIntoBool, ArgIntoFloat},
6    PyObjectRef, PyResult, TryFromObject, VirtualMachine,
7};
8use half::f16;
9use itertools::Itertools;
10use malachite_bigint::BigInt;
11use num_traits::{PrimInt, ToPrimitive};
12use std::{fmt, iter::Peekable, mem, os::raw};
13
/// Converts one Python object and writes it into the byte slice reserved for
/// a single item of a format code.
type PackFunc = fn(&VirtualMachine, PyObjectRef, &mut [u8]) -> PyResult<()>;
/// Reads a single item of a format code from a byte slice and returns it as a
/// Python object.
type UnpackFunc = fn(&VirtualMachine, &[u8]) -> PyObjectRef;

/// Error message used for any overflow while computing repeat counts, padding
/// or the running struct size during format parsing.
static OVERFLOW_MSG: &str = "total struct size too long"; // not a const to reduce code size
18
/// Byte-order / size / alignment mode selected by the optional first
/// character of a format string.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum Endianness {
    /// `@` or no prefix: native byte order, native C sizes and alignment.
    Native,
    /// `<`: little-endian, standard sizes, no alignment padding.
    Little,
    /// `>` or `!`: big-endian, standard sizes, no alignment padding.
    Big,
    /// `=`: host byte order, standard sizes, no alignment padding.
    Host,
}
26
27impl Endianness {
28    /// Parse endianness
29    /// See also: https://docs.python.org/3/library/struct.html?highlight=struct#byte-order-size-and-alignment
30    fn parse<I>(chars: &mut Peekable<I>) -> Endianness
31    where
32        I: Sized + Iterator<Item = u8>,
33    {
34        let e = match chars.peek() {
35            Some(b'@') => Endianness::Native,
36            Some(b'=') => Endianness::Host,
37            Some(b'<') => Endianness::Little,
38            Some(b'>') | Some(b'!') => Endianness::Big,
39            _ => return Endianness::Native,
40        };
41        chars.next().unwrap();
42        e
43    }
44}
45
46trait ByteOrder {
47    fn convert<I: PrimInt>(i: I) -> I;
48}
49enum BigEndian {}
50impl ByteOrder for BigEndian {
51    fn convert<I: PrimInt>(i: I) -> I {
52        i.to_be()
53    }
54}
55enum LittleEndian {}
56impl ByteOrder for LittleEndian {
57    fn convert<I: PrimInt>(i: I) -> I {
58        i.to_le()
59    }
60}
61
62#[cfg(target_endian = "big")]
63type NativeEndian = BigEndian;
64#[cfg(target_endian = "little")]
65type NativeEndian = LittleEndian;
66
/// A single format character.
///
/// Each discriminant is the character's ASCII byte, so the derived
/// `TryFromPrimitive` turns a raw format byte into a variant.
#[derive(Copy, Clone, num_enum::TryFromPrimitive)]
#[repr(u8)]
pub(crate) enum FormatType {
    Pad = b'x', // padding byte; consumes no argument
    SByte = b'b',
    UByte = b'B',
    Char = b'c',
    WideChar = b'u',
    Str = b's',    // repeat count is the byte length; one argument
    Pascal = b'p', // repeat count is the byte length; one argument
    Short = b'h',
    UShort = b'H',
    Int = b'i',
    UInt = b'I',
    Long = b'l',
    ULong = b'L',
    SSizeT = b'n', // native mode only
    SizeT = b'N',  // native mode only
    LongLong = b'q',
    ULongLong = b'Q',
    Bool = b'?',
    Half = b'e',
    Float = b'f',
    Double = b'd',
    VoidP = b'P', // native mode only
}
93
94impl fmt::Debug for FormatType {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        fmt::Debug::fmt(&(*self as u8 as char), f)
97    }
98}
99
impl FormatType {
    /// Size, alignment and per-item pack/unpack routines for this code under
    /// byte-order mode `e`.
    ///
    /// `Endianness::Native` uses the platform C type's size and alignment.
    /// The other modes use fixed standard sizes with `align: 0` (no padding).
    /// They differ only in the byte order handed to the pack/unpack routines.
    ///
    /// # Panics
    /// Panics (`unreachable!`) if `SSizeT`, `SizeT`, `VoidP` or `WideChar`
    /// is requested with any mode other than `Endianness::Native`.
    fn info(self, e: Endianness) -> &'static FormatInfo {
        use mem::{align_of, size_of};
        use FormatType::*;
        // Native C type `$t`: native size, native alignment, native byte order.
        macro_rules! native_info {
            ($t:ty) => {{
                &FormatInfo {
                    size: size_of::<$t>(),
                    align: align_of::<$t>(),
                    pack: Some(<$t as Packable>::pack::<NativeEndian>),
                    unpack: Some(<$t as Packable>::unpack::<NativeEndian>),
                }
            }};
        }
        // Standard-size type `$t` in byte order `$end`; never aligned.
        macro_rules! nonnative_info {
            ($t:ty, $end:ty) => {{
                &FormatInfo {
                    size: size_of::<$t>(),
                    align: 0,
                    pack: Some(<$t as Packable>::pack::<$end>),
                    unpack: Some(<$t as Packable>::unpack::<$end>),
                }
            }};
        }
        // Table shared by the `=`, `<` and `>`/`!` modes; only `$end` differs.
        macro_rules! match_nonnative {
            ($zelf:expr, $end:ty) => {{
                match $zelf {
                    // Pad/string codes are handled directly by `FormatSpec`,
                    // so they have no per-item routines.
                    Pad | Str | Pascal => &FormatInfo {
                        size: size_of::<u8>(),
                        align: 0,
                        pack: None,
                        unpack: None,
                    },
                    SByte => nonnative_info!(i8, $end),
                    UByte => nonnative_info!(u8, $end),
                    Char => &FormatInfo {
                        size: size_of::<u8>(),
                        align: 0,
                        pack: Some(pack_char),
                        unpack: Some(unpack_char),
                    },
                    Short => nonnative_info!(i16, $end),
                    UShort => nonnative_info!(u16, $end),
                    // Standard sizes: `i`/`l` are 4 bytes regardless of platform.
                    Int | Long => nonnative_info!(i32, $end),
                    UInt | ULong => nonnative_info!(u32, $end),
                    LongLong => nonnative_info!(i64, $end),
                    ULongLong => nonnative_info!(u64, $end),
                    Bool => nonnative_info!(bool, $end),
                    Half => nonnative_info!(f16, $end),
                    Float => nonnative_info!(f32, $end),
                    Double => nonnative_info!(f64, $end),
                    // `n`, `N`, `P` and `u` have no standard size; callers must
                    // only request them in native mode.
                    _ => unreachable!(),
                }
            }};
        }
        match e {
            Endianness::Native => match self {
                Pad | Str | Pascal => &FormatInfo {
                    size: size_of::<raw::c_char>(),
                    align: 0,
                    pack: None,
                    unpack: None,
                },
                SByte => native_info!(raw::c_schar),
                UByte => native_info!(raw::c_uchar),
                Char => &FormatInfo {
                    size: size_of::<raw::c_char>(),
                    align: 0,
                    pack: Some(pack_char),
                    unpack: Some(unpack_char),
                },
                WideChar => native_info!(wchar_t),
                Short => native_info!(raw::c_short),
                UShort => native_info!(raw::c_ushort),
                Int => native_info!(raw::c_int),
                UInt => native_info!(raw::c_uint),
                Long => native_info!(raw::c_long),
                ULong => native_info!(raw::c_ulong),
                SSizeT => native_info!(isize), // ssize_t == isize
                SizeT => native_info!(usize),  //  size_t == usize
                LongLong => native_info!(raw::c_longlong),
                ULongLong => native_info!(raw::c_ulonglong),
                Bool => native_info!(bool),
                Half => native_info!(f16),
                Float => native_info!(raw::c_float),
                Double => native_info!(raw::c_double),
                VoidP => native_info!(*mut raw::c_void),
            },
            Endianness::Big => match_nonnative!(self, BigEndian),
            Endianness::Little => match_nonnative!(self, LittleEndian),
            Endianness::Host => match_nonnative!(self, NativeEndian),
        }
    }
}
194
/// One parsed `<count><char>` item of a format string.
#[derive(Debug, Clone)]
pub(crate) struct FormatCode {
    /// Repeat count; for `s`/`p` this is the byte length of the single field.
    pub repeat: usize,
    /// The format character.
    pub code: FormatType,
    /// Size, alignment and pack/unpack routines for `code`.
    pub info: &'static FormatInfo,
    /// Number of alignment padding bytes placed before this item.
    pub pre_padding: usize,
}
202
203impl FormatCode {
204    pub fn arg_count(&self) -> usize {
205        match self.code {
206            FormatType::Pad => 0,
207            FormatType::Str | FormatType::Pascal => 1,
208            _ => self.repeat,
209        }
210    }
211
212    pub fn parse<I>(
213        chars: &mut Peekable<I>,
214        endianness: Endianness,
215    ) -> Result<(Vec<Self>, usize, usize), String>
216    where
217        I: Sized + Iterator<Item = u8>,
218    {
219        let mut offset = 0isize;
220        let mut arg_count = 0usize;
221        let mut codes = vec![];
222        while chars.peek().is_some() {
223            // determine repeat operator:
224            let repeat = match chars.peek() {
225                Some(b'0'..=b'9') => {
226                    let mut repeat = 0isize;
227                    while let Some(b'0'..=b'9') = chars.peek() {
228                        if let Some(c) = chars.next() {
229                            let current_digit = c - b'0';
230                            repeat = repeat
231                                .checked_mul(10)
232                                .and_then(|r| r.checked_add(current_digit as _))
233                                .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
234                        }
235                    }
236                    repeat
237                }
238                _ => 1,
239            };
240
241            // determine format char:
242            let c = chars
243                .next()
244                .ok_or_else(|| "repeat count given without format specifier".to_owned())?;
245            let code = FormatType::try_from(c)
246                .ok()
247                .filter(|c| match c {
248                    FormatType::SSizeT | FormatType::SizeT | FormatType::VoidP => {
249                        endianness == Endianness::Native
250                    }
251                    _ => true,
252                })
253                .ok_or_else(|| "bad char in struct format".to_owned())?;
254
255            let info = code.info(endianness);
256
257            let padding = compensate_alignment(offset as usize, info.align)
258                .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
259            offset = padding
260                .to_isize()
261                .and_then(|extra| offset.checked_add(extra))
262                .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
263
264            let code = FormatCode {
265                repeat: repeat as usize,
266                code,
267                info,
268                pre_padding: padding,
269            };
270            arg_count += code.arg_count();
271            codes.push(code);
272
273            offset = (info.size as isize)
274                .checked_mul(repeat)
275                .and_then(|item_size| offset.checked_add(item_size))
276                .ok_or_else(|| OVERFLOW_MSG.to_owned())?;
277        }
278
279        Ok((codes, offset as usize, arg_count))
280    }
281}
282
/// Number of padding bytes needed to round `offset` up to a multiple of `align`.
///
/// `align == 0` means "no alignment". Otherwise `align` is assumed to be a
/// power of two. The result is always `Some`; the `Option` lets callers chain
/// it with the parser's other checked arithmetic.
fn compensate_alignment(offset: usize, align: usize) -> Option<usize> {
    if align == 0 || offset == 0 {
        // nothing to align to, or already at the very start
        return Some(0);
    }
    let mask = align - 1;
    // a % b == a & (b-1) if b is a power of 2
    let used_in_last_slot = (offset - 1) & mask;
    mask.checked_sub(used_in_last_slot)
}
292
293pub(crate) struct FormatInfo {
294    pub size: usize,
295    pub align: usize,
296    pub pack: Option<PackFunc>,
297    pub unpack: Option<UnpackFunc>,
298}
299impl fmt::Debug for FormatInfo {
300    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301        f.debug_struct("FormatInfo")
302            .field("size", &self.size)
303            .field("align", &self.align)
304            .finish()
305    }
306}
307
/// A fully parsed struct format string.
#[derive(Debug, Clone)]
pub struct FormatSpec {
    /// Byte-order mode from the format's prefix.
    // Not read by the code in this file, hence the allow.
    #[allow(dead_code)]
    pub(crate) endianness: Endianness,
    /// The parsed items, in order.
    pub(crate) codes: Vec<FormatCode>,
    /// Total packed size in bytes, including alignment padding between items.
    pub size: usize,
    /// Number of Python values packed/unpacked by this format.
    pub arg_count: usize,
}
316
317impl FormatSpec {
318    pub fn parse(fmt: &[u8], vm: &VirtualMachine) -> PyResult<FormatSpec> {
319        let mut chars = fmt.iter().copied().peekable();
320
321        // First determine "@", "<", ">","!" or "="
322        let endianness = Endianness::parse(&mut chars);
323
324        // Now, analyze struct string further:
325        let (codes, size, arg_count) =
326            FormatCode::parse(&mut chars, endianness).map_err(|err| new_struct_error(vm, err))?;
327
328        Ok(FormatSpec {
329            endianness,
330            codes,
331            size,
332            arg_count,
333        })
334    }
335
336    pub fn pack(&self, args: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
337        // Create data vector:
338        let mut data = vec![0; self.size];
339
340        self.pack_into(&mut data, args, vm)?;
341
342        Ok(data)
343    }
344
345    pub fn pack_into(
346        &self,
347        mut buffer: &mut [u8],
348        args: Vec<PyObjectRef>,
349        vm: &VirtualMachine,
350    ) -> PyResult<()> {
351        if self.arg_count != args.len() {
352            return Err(new_struct_error(
353                vm,
354                format!(
355                    "pack expected {} items for packing (got {})",
356                    self.codes.len(),
357                    args.len()
358                ),
359            ));
360        }
361
362        let mut args = args.into_iter();
363        // Loop over all opcodes:
364        for code in &self.codes {
365            buffer = &mut buffer[code.pre_padding..];
366            debug!("code: {:?}", code);
367            match code.code {
368                FormatType::Str => {
369                    let (buf, rest) = buffer.split_at_mut(code.repeat);
370                    pack_string(vm, args.next().unwrap(), buf)?;
371                    buffer = rest;
372                }
373                FormatType::Pascal => {
374                    let (buf, rest) = buffer.split_at_mut(code.repeat);
375                    pack_pascal(vm, args.next().unwrap(), buf)?;
376                    buffer = rest;
377                }
378                FormatType::Pad => {
379                    let (pad_buf, rest) = buffer.split_at_mut(code.repeat);
380                    for el in pad_buf {
381                        *el = 0
382                    }
383                    buffer = rest;
384                }
385                _ => {
386                    let pack = code.info.pack.unwrap();
387                    for arg in args.by_ref().take(code.repeat) {
388                        let (item_buf, rest) = buffer.split_at_mut(code.info.size);
389                        pack(vm, arg, item_buf)?;
390                        buffer = rest;
391                    }
392                }
393            }
394        }
395
396        Ok(())
397    }
398
399    pub fn unpack(&self, mut data: &[u8], vm: &VirtualMachine) -> PyResult<PyTupleRef> {
400        if self.size != data.len() {
401            return Err(new_struct_error(
402                vm,
403                format!("unpack requires a buffer of {} bytes", self.size),
404            ));
405        }
406
407        let mut items = Vec::with_capacity(self.arg_count);
408        for code in &self.codes {
409            data = &data[code.pre_padding..];
410            debug!("unpack code: {:?}", code);
411            match code.code {
412                FormatType::Pad => {
413                    data = &data[code.repeat..];
414                }
415                FormatType::Str => {
416                    let (str_data, rest) = data.split_at(code.repeat);
417                    // string is just stored inline
418                    items.push(vm.ctx.new_bytes(str_data.to_vec()).into());
419                    data = rest;
420                }
421                FormatType::Pascal => {
422                    let (str_data, rest) = data.split_at(code.repeat);
423                    items.push(unpack_pascal(vm, str_data));
424                    data = rest;
425                }
426                _ => {
427                    let unpack = code.info.unpack.unwrap();
428                    for _ in 0..code.repeat {
429                        let (item_data, rest) = data.split_at(code.info.size);
430                        items.push(unpack(vm, item_data));
431                        data = rest;
432                    }
433                }
434            };
435        }
436
437        Ok(PyTuple::new_ref(items, &vm.ctx))
438    }
439
440    #[inline]
441    pub fn size(&self) -> usize {
442        self.size
443    }
444}
445
/// A Rust type whose values can be written to / read from the struct byte
/// layout in byte order `E`.
trait Packable {
    /// Convert `arg` and write it into `data`, the slice reserved for one item.
    fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()>;
    /// Read one item from `data` and return it as a Python object.
    fn unpack<E: ByteOrder>(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef;
}

/// Byte-order-aware raw (de)serialisation of a primitive integer.
trait PackInt: PrimInt {
    /// Write `self` into `data` in byte order `E`; `data` must be exactly
    /// `size_of::<Self>()` bytes long.
    fn pack_int<E: ByteOrder>(self, data: &mut [u8]);
    /// Read a value stored in byte order `E`; `data` must be exactly
    /// `size_of::<Self>()` bytes long.
    fn unpack_int<E: ByteOrder>(data: &[u8]) -> Self;
}
455
/// Implements `PackInt` and `Packable` for a primitive integer type by going
/// through its native-endian byte array.
macro_rules! make_pack_primint {
    ($T:ty) => {
        impl PackInt for $T {
            fn pack_int<E: ByteOrder>(self, data: &mut [u8]) {
                // Reorder into the target byte order, then dump the bytes as-is.
                let ordered = E::convert(self);
                let bytes = ordered.to_ne_bytes();
                data.copy_from_slice(&bytes);
            }
            #[inline]
            fn unpack_int<E: ByteOrder>(data: &[u8]) -> Self {
                let mut raw = [0u8; std::mem::size_of::<$T>()];
                raw.copy_from_slice(data);
                let stored = <$T>::from_ne_bytes(raw);
                E::convert(stored)
            }
        }

        impl Packable for $T {
            fn pack<E: ByteOrder>(
                vm: &VirtualMachine,
                arg: PyObjectRef,
                data: &mut [u8],
            ) -> PyResult<()> {
                get_int_or_index::<$T>(vm, arg)?.pack_int::<E>(data);
                Ok(())
            }

            fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
                let value = <$T>::unpack_int::<E>(rdr);
                vm.ctx.new_int(value).into()
            }
        }
    };
}
489
490fn get_int_or_index<T>(vm: &VirtualMachine, arg: PyObjectRef) -> PyResult<T>
491where
492    T: PrimInt + for<'a> TryFrom<&'a BigInt>,
493{
494    let index = arg.try_index_opt(vm).unwrap_or_else(|| {
495        Err(new_struct_error(
496            vm,
497            "required argument is not an integer".to_owned(),
498        ))
499    })?;
500    index
501        .try_to_primitive(vm)
502        .map_err(|_| new_struct_error(vm, "argument out of range".to_owned()))
503}
504
// Fixed-width integers back the standard-size codes. `isize`/`usize` back the
// native-only `n`/`N` codes, and `*mut c_void` (`P`) reuses `usize`.
make_pack_primint!(i8);
make_pack_primint!(u8);
make_pack_primint!(i16);
make_pack_primint!(u16);
make_pack_primint!(i32);
make_pack_primint!(u32);
make_pack_primint!(i64);
make_pack_primint!(u64);
make_pack_primint!(usize);
make_pack_primint!(isize);
515
516macro_rules! make_pack_float {
517    ($T:ty) => {
518        impl Packable for $T {
519            fn pack<E: ByteOrder>(
520                vm: &VirtualMachine,
521                arg: PyObjectRef,
522                data: &mut [u8],
523            ) -> PyResult<()> {
524                let f = *ArgIntoFloat::try_from_object(vm, arg)? as $T;
525                f.to_bits().pack_int::<E>(data);
526                Ok(())
527            }
528
529            fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
530                let i = PackInt::unpack_int::<E>(rdr);
531                <$T>::from_bits(i).to_pyobject(vm)
532            }
533        }
534    };
535}
536
537make_pack_float!(f32);
538make_pack_float!(f64);
539
540impl Packable for f16 {
541    fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
542        let f_64 = *ArgIntoFloat::try_from_object(vm, arg)?;
543        let f_16 = f16::from_f64(f_64);
544        if f_16.is_infinite() != f_64.is_infinite() {
545            return Err(vm.new_overflow_error("float too large to pack with e format".to_owned()));
546        }
547        f_16.to_bits().pack_int::<E>(data);
548        Ok(())
549    }
550
551    fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
552        let i = PackInt::unpack_int::<E>(rdr);
553        f16::from_bits(i).to_f64().to_pyobject(vm)
554    }
555}
556
557impl Packable for *mut raw::c_void {
558    fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
559        usize::pack::<E>(vm, arg, data)
560    }
561
562    fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
563        usize::unpack::<E>(vm, rdr)
564    }
565}
566
567impl Packable for bool {
568    fn pack<E: ByteOrder>(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
569        let v = *ArgIntoBool::try_from_object(vm, arg)? as u8;
570        v.pack_int::<E>(data);
571        Ok(())
572    }
573
574    fn unpack<E: ByteOrder>(vm: &VirtualMachine, rdr: &[u8]) -> PyObjectRef {
575        let i = u8::unpack_int::<E>(rdr);
576        vm.ctx.new_bool(i != 0).into()
577    }
578}
579
580fn pack_char(vm: &VirtualMachine, arg: PyObjectRef, data: &mut [u8]) -> PyResult<()> {
581    let v = PyBytesRef::try_from_object(vm, arg)?;
582    let ch = *v.as_bytes().iter().exactly_one().map_err(|_| {
583        new_struct_error(
584            vm,
585            "char format requires a bytes object of length 1".to_owned(),
586        )
587    })?;
588    data[0] = ch;
589    Ok(())
590}
591
592fn pack_string(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResult<()> {
593    let b = ArgBytesLike::try_from_object(vm, arg)?;
594    b.with_ref(|data| write_string(buf, data));
595    Ok(())
596}
597
598fn pack_pascal(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResult<()> {
599    if buf.is_empty() {
600        return Ok(());
601    }
602    let b = ArgBytesLike::try_from_object(vm, arg)?;
603    b.with_ref(|data| {
604        let string_length = std::cmp::min(std::cmp::min(data.len(), 255), buf.len() - 1);
605        buf[0] = string_length as u8;
606        write_string(&mut buf[1..], data);
607    });
608    Ok(())
609}
610
/// Copy as much of `data` as fits into `buf`, then zero-fill the rest of `buf`.
fn write_string(buf: &mut [u8], data: &[u8]) {
    let copied = data.len().min(buf.len());
    let (head, tail) = buf.split_at_mut(copied);
    head.copy_from_slice(&data[..copied]);
    tail.fill(0);
}
618
619fn unpack_char(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef {
620    vm.ctx.new_bytes(vec![data[0]]).into()
621}
622
623fn unpack_pascal(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef {
624    let (&len, data) = match data.split_first() {
625        Some(x) => x,
626        None => {
627            // cpython throws an internal SystemError here
628            return vm.ctx.new_bytes(vec![]).into();
629        }
630    };
631    let len = std::cmp::min(len as usize, data.len());
632    vm.ctx.new_bytes(data[..len].to_vec()).into()
633}
634
// XXX: are these functions expected to be placed here?
/// Returns the `struct.error` exception type.
///
/// The type is created on the first call and the same instance is returned
/// afterwards.
pub fn struct_error_type(vm: &VirtualMachine) -> &'static PyTypeRef {
    static_cell! {
        static INSTANCE: PyTypeRef;
    }
    INSTANCE.get_or_init(|| vm.ctx.new_exception_type("struct", "error", None))
}
642
643pub fn new_struct_error(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
644    // can't just STRUCT_ERROR.get().unwrap() cause this could be called before from buffer
645    // machinery, independent of whether _struct was ever imported
646    vm.new_exception_msg(struct_error_type(vm).clone(), msg)
647}