Skip to main content

rustpython_vm/types/
structseq.rs

1use crate::common::lock::LazyLock;
2use crate::{
3    AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
4    builtins::{PyBaseExceptionRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef},
5    class::{PyClassImpl, StaticType},
6    function::{Either, FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags},
7    iter::PyExactSizeIterator,
8    protocol::{PyMappingMethods, PySequenceMethods},
9    sliceable::{SequenceIndex, SliceableSequenceOp},
10    types::PyComparisonOp,
11    vm::Context,
12};
13
14const DEFAULT_STRUCTSEQ_REDUCE: PyMethodDef = PyMethodDef::new_const(
15    "__reduce__",
16    |zelf: PyRef<PyTuple>, vm: &VirtualMachine| -> PyTupleRef {
17        vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),)))
18    },
19    PyMethodFlags::METHOD,
20    None,
21);
22
23/// Create a new struct sequence instance from a sequence.
24///
25/// The class must have `n_sequence_fields` and `n_fields` attributes set
26/// (done automatically by `PyStructSequence::extend_pyclass`).
27pub fn struct_sequence_new(cls: PyTypeRef, seq: PyObjectRef, vm: &VirtualMachine) -> PyResult {
28    // = structseq_new
29
30    #[cold]
31    fn length_error(
32        tp_name: &str,
33        min_len: usize,
34        max_len: usize,
35        len: usize,
36        vm: &VirtualMachine,
37    ) -> PyBaseExceptionRef {
38        if min_len == max_len {
39            vm.new_type_error(format!(
40                "{tp_name}() takes a {min_len}-sequence ({len}-sequence given)"
41            ))
42        } else if len < min_len {
43            vm.new_type_error(format!(
44                "{tp_name}() takes an at least {min_len}-sequence ({len}-sequence given)"
45            ))
46        } else {
47            vm.new_type_error(format!(
48                "{tp_name}() takes an at most {max_len}-sequence ({len}-sequence given)"
49            ))
50        }
51    }
52
53    let min_len: usize = cls
54        .get_attr(identifier!(vm.ctx, n_sequence_fields))
55        .ok_or_else(|| vm.new_type_error("missing n_sequence_fields attribute"))?
56        .try_into_value(vm)?;
57    let max_len: usize = cls
58        .get_attr(identifier!(vm.ctx, n_fields))
59        .ok_or_else(|| vm.new_type_error("missing n_fields attribute"))?
60        .try_into_value(vm)?;
61
62    let seq: Vec<PyObjectRef> = seq.try_into_value(vm)?;
63    let len = seq.len();
64
65    if len < min_len || len > max_len {
66        return Err(length_error(&cls.slot_name(), min_len, max_len, len, vm));
67    }
68
69    // Copy items and pad with None
70    let mut items = seq;
71    items.resize_with(max_len, || vm.ctx.none());
72
73    PyTuple::new_unchecked(items.into_boxed_slice())
74        .into_ref_with_type(vm, cls)
75        .map(Into::into)
76}
77
78fn get_visible_len(obj: &PyObject, vm: &VirtualMachine) -> PyResult<usize> {
79    obj.class()
80        .get_attr(identifier!(vm.ctx, n_sequence_fields))
81        .ok_or_else(|| vm.new_type_error("missing n_sequence_fields"))?
82        .try_into_value(vm)
83}
84
85/// Sequence methods for struct sequences.
86/// Uses n_sequence_fields to determine visible length.
87static STRUCT_SEQUENCE_AS_SEQUENCE: LazyLock<PySequenceMethods> =
88    LazyLock::new(|| PySequenceMethods {
89        length: atomic_func!(|seq, vm| get_visible_len(seq.obj, vm)),
90        concat: atomic_func!(|seq, other, vm| {
91            // Convert to visible-only tuple, then use regular tuple concat
92            let n_seq = get_visible_len(seq.obj, vm)?;
93            let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
94            let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
95            let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
96            // Use tuple's concat implementation
97            visible_tuple
98                .as_object()
99                .sequence_unchecked()
100                .concat(other, vm)
101        }),
102        repeat: atomic_func!(|seq, n, vm| {
103            // Convert to visible-only tuple, then use regular tuple repeat
104            let n_seq = get_visible_len(seq.obj, vm)?;
105            let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
106            let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
107            let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
108            // Use tuple's repeat implementation
109            visible_tuple.as_object().sequence_unchecked().repeat(n, vm)
110        }),
111        item: atomic_func!(|seq, i, vm| {
112            let n_seq = get_visible_len(seq.obj, vm)?;
113            let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
114            let idx = if i < 0 {
115                let pos_i = n_seq as isize + i;
116                if pos_i < 0 {
117                    return Err(vm.new_index_error("tuple index out of range"));
118                }
119                pos_i as usize
120            } else {
121                i as usize
122            };
123            if idx >= n_seq {
124                return Err(vm.new_index_error("tuple index out of range"));
125            }
126            Ok(tuple[idx].clone())
127        }),
128        contains: atomic_func!(|seq, needle, vm| {
129            let n_seq = get_visible_len(seq.obj, vm)?;
130            let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
131            for item in tuple.iter().take(n_seq) {
132                if item.rich_compare_bool(needle, PyComparisonOp::Eq, vm)? {
133                    return Ok(true);
134                }
135            }
136            Ok(false)
137        }),
138        ..PySequenceMethods::NOT_IMPLEMENTED
139    });
140
141/// Mapping methods for struct sequences.
142/// Handles subscript (indexing) with visible length bounds.
143static STRUCT_SEQUENCE_AS_MAPPING: LazyLock<PyMappingMethods> =
144    LazyLock::new(|| PyMappingMethods {
145        length: atomic_func!(|mapping, vm| get_visible_len(mapping.obj, vm)),
146        subscript: atomic_func!(|mapping, needle, vm| {
147            let n_seq = get_visible_len(mapping.obj, vm)?;
148            let tuple = mapping.obj.downcast_ref::<PyTuple>().unwrap();
149            let visible_elements = &tuple.as_slice()[..n_seq];
150
151            match SequenceIndex::try_from_borrowed_object(vm, needle, "tuple")? {
152                SequenceIndex::Int(i) => visible_elements.getitem_by_index(vm, i),
153                SequenceIndex::Slice(slice) => visible_elements
154                    .getitem_by_slice(vm, slice)
155                    .map(|x| vm.ctx.new_tuple(x).into()),
156            }
157        }),
158        ..PyMappingMethods::NOT_IMPLEMENTED
159    });
160
161/// Trait for Data structs that back a PyStructSequence.
162///
163/// This trait is implemented by `#[pystruct_sequence_data]` on the Data struct.
164/// It provides field information, tuple conversion, and element parsing.
165pub trait PyStructSequenceData: Sized {
166    /// Names of required fields (in order). Shown in repr.
167    const REQUIRED_FIELD_NAMES: &'static [&'static str];
168
169    /// Names of optional/skipped fields (in order, after required fields).
170    const OPTIONAL_FIELD_NAMES: &'static [&'static str];
171
172    /// Number of unnamed fields (visible but index-only access).
173    const UNNAMED_FIELDS_LEN: usize = 0;
174
175    /// Convert this Data struct into a PyTuple.
176    fn into_tuple(self, vm: &VirtualMachine) -> PyTuple;
177
178    /// Construct this Data struct from tuple elements.
179    /// Default implementation returns an error.
180    /// Override with `#[pystruct_sequence_data(try_from_object)]` to enable.
181    fn try_from_elements(_elements: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Self> {
182        Err(vm.new_type_error("This struct sequence does not support construction from elements"))
183    }
184}
185
186/// Trait for Python struct sequence types.
187///
188/// This trait is implemented by the `#[pystruct_sequence]` macro on the Python type struct.
189/// It connects to the Data struct and provides Python-level functionality.
190#[pyclass]
191pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
192    /// The Data struct that provides field definitions.
193    type Data: PyStructSequenceData;
194
195    /// Convert a Data struct into a PyStructSequence instance.
196    fn from_data(data: Self::Data, vm: &VirtualMachine) -> PyTupleRef {
197        let tuple =
198            <Self::Data as ::rustpython_vm::types::PyStructSequenceData>::into_tuple(data, vm);
199        let typ = Self::static_type();
200        tuple
201            .into_ref_with_type(vm, typ.to_owned())
202            .expect("Every PyStructSequence must be a valid tuple. This is a RustPython bug.")
203    }
204
205    #[pyslot]
206    fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyStrRef> {
207        let zelf = zelf
208            .downcast_ref::<PyTuple>()
209            .ok_or_else(|| vm.new_type_error("unexpected payload for __repr__"))?;
210
211        let field_names = Self::Data::REQUIRED_FIELD_NAMES;
212        let format_field = |(value, name): (&PyObject, _)| {
213            let s = value.repr(vm)?;
214            Ok(format!("{name}={s}"))
215        };
216        let (body, suffix) =
217            if let Some(_guard) = rustpython_vm::recursion::ReprGuard::enter(vm, zelf.as_ref()) {
218                let fields: PyResult<Vec<_>> = zelf
219                    .iter()
220                    .map(|value| value.as_ref())
221                    .zip(field_names.iter().copied())
222                    .map(format_field)
223                    .collect();
224                (fields?.join(", "), "")
225            } else {
226                (String::new(), "...")
227            };
228        // Build qualified name: if MODULE_NAME is already in TP_NAME, use it directly.
229        // Otherwise, check __module__ attribute (set by #[pymodule] at runtime).
230        let type_name = if Self::MODULE_NAME.is_some() {
231            alloc::borrow::Cow::Borrowed(Self::TP_NAME)
232        } else {
233            let typ = zelf.class();
234            match typ.get_attr(identifier!(vm.ctx, __module__)) {
235                Some(module) if module.downcastable::<PyStr>() => {
236                    let module_str = module.downcast_ref::<PyStr>().unwrap();
237                    alloc::borrow::Cow::Owned(format!("{}.{}", module_str.as_wtf8(), Self::NAME))
238                }
239                _ => alloc::borrow::Cow::Borrowed(Self::TP_NAME),
240            }
241        };
242        let repr_str = format!("{}({}{})", type_name, body, suffix);
243        Ok(vm.ctx.new_str(repr_str))
244    }
245
246    #[pymethod]
247    fn __replace__(zelf: PyRef<PyTuple>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
248        if !args.args.is_empty() {
249            return Err(vm.new_type_error("__replace__() takes no positional arguments"));
250        }
251
252        if Self::Data::UNNAMED_FIELDS_LEN > 0 {
253            return Err(vm.new_type_error(format!(
254                "__replace__() is not supported for {} because it has unnamed field(s)",
255                zelf.class().slot_name()
256            )));
257        }
258
259        let n_fields =
260            Self::Data::REQUIRED_FIELD_NAMES.len() + Self::Data::OPTIONAL_FIELD_NAMES.len();
261        let mut items: Vec<PyObjectRef> = zelf.as_slice()[..n_fields].to_vec();
262
263        let mut kwargs = args.kwargs.clone();
264
265        // Replace fields from kwargs
266        let all_field_names: Vec<&str> = Self::Data::REQUIRED_FIELD_NAMES
267            .iter()
268            .chain(Self::Data::OPTIONAL_FIELD_NAMES.iter())
269            .copied()
270            .collect();
271        for (i, &name) in all_field_names.iter().enumerate() {
272            if let Some(val) = kwargs.shift_remove(name) {
273                items[i] = val;
274            }
275        }
276
277        // Check for unexpected keyword arguments
278        if !kwargs.is_empty() {
279            let names: Vec<&str> = kwargs.keys().map(|k| k.as_str()).collect();
280            return Err(vm.new_type_error(format!("Got unexpected field name(s): {:?}", names)));
281        }
282
283        PyTuple::new_unchecked(items.into_boxed_slice())
284            .into_ref_with_type(vm, zelf.class().to_owned())
285            .map(Into::into)
286    }
287
288    #[pymethod]
289    fn __getitem__(zelf: PyRef<PyTuple>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
290        let n_seq = get_visible_len(zelf.as_ref(), vm)?;
291        let visible_elements = &zelf.as_slice()[..n_seq];
292
293        match SequenceIndex::try_from_borrowed_object(vm, &needle, "tuple")? {
294            SequenceIndex::Int(i) => visible_elements.getitem_by_index(vm, i),
295            SequenceIndex::Slice(slice) => visible_elements
296                .getitem_by_slice(vm, slice)
297                .map(|x| vm.ctx.new_tuple(x).into()),
298        }
299    }
300
301    #[extend_class]
302    fn extend_pyclass(ctx: &Context, class: &'static Py<PyType>) {
303        // Getters for named visible fields (indices 0 to REQUIRED_FIELD_NAMES.len() - 1)
304        for (i, &name) in Self::Data::REQUIRED_FIELD_NAMES.iter().enumerate() {
305            // cast i to a u8 so there's less to store in the getter closure.
306            // Hopefully there's not struct sequences with >=256 elements :P
307            let i = i as u8;
308            class.set_attr(
309                ctx.intern_str(name),
310                ctx.new_readonly_getset(name, class, move |zelf: &PyTuple| {
311                    zelf[i as usize].to_owned()
312                })
313                .into(),
314            );
315        }
316
317        // Getters for hidden/skipped fields (indices after visible fields)
318        let visible_count = Self::Data::REQUIRED_FIELD_NAMES.len() + Self::Data::UNNAMED_FIELDS_LEN;
319        for (i, &name) in Self::Data::OPTIONAL_FIELD_NAMES.iter().enumerate() {
320            let idx = (visible_count + i) as u8;
321            class.set_attr(
322                ctx.intern_str(name),
323                ctx.new_readonly_getset(name, class, move |zelf: &PyTuple| {
324                    zelf[idx as usize].to_owned()
325                })
326                .into(),
327            );
328        }
329
330        class.set_attr(
331            identifier!(ctx, __match_args__),
332            ctx.new_tuple(
333                Self::Data::REQUIRED_FIELD_NAMES
334                    .iter()
335                    .map(|&name| ctx.new_str(name).into())
336                    .collect::<Vec<_>>(),
337            )
338            .into(),
339        );
340
341        // special fields:
342        // n_sequence_fields = visible fields (named + unnamed)
343        // n_fields = all fields (visible + hidden/skipped)
344        // n_unnamed_fields
345        let n_unnamed_fields = Self::Data::UNNAMED_FIELDS_LEN;
346        let n_sequence_fields = Self::Data::REQUIRED_FIELD_NAMES.len() + n_unnamed_fields;
347        let n_fields = n_sequence_fields + Self::Data::OPTIONAL_FIELD_NAMES.len();
348        class.set_attr(
349            identifier!(ctx, n_sequence_fields),
350            ctx.new_int(n_sequence_fields).into(),
351        );
352        class.set_attr(identifier!(ctx, n_fields), ctx.new_int(n_fields).into());
353        class.set_attr(
354            identifier!(ctx, n_unnamed_fields),
355            ctx.new_int(n_unnamed_fields).into(),
356        );
357
358        // Override as_sequence and as_mapping slots to use visible length
359        class
360            .slots
361            .as_sequence
362            .copy_from(&STRUCT_SEQUENCE_AS_SEQUENCE);
363        class
364            .slots
365            .as_mapping
366            .copy_from(&STRUCT_SEQUENCE_AS_MAPPING);
367
368        // Override iter slot to return only visible elements
369        class.slots.iter.store(Some(struct_sequence_iter));
370
371        // Override hash slot to hash only visible elements
372        class.slots.hash.store(Some(struct_sequence_hash));
373
374        // Override richcompare slot to compare only visible elements
375        class
376            .slots
377            .richcompare
378            .store(Some(struct_sequence_richcompare));
379
380        // Default __reduce__: only set if not already overridden by the impl's extend_class.
381        // This allows struct sequences like sched_param to provide a custom __reduce__
382        // (equivalent to METH_COEXIST in structseq.c).
383        if !class
384            .attributes
385            .read()
386            .contains_key(ctx.intern_str("__reduce__"))
387        {
388            class.set_attr(
389                ctx.intern_str("__reduce__"),
390                DEFAULT_STRUCTSEQ_REDUCE.to_proper_method(class, ctx),
391            );
392        }
393    }
394}
395
396/// Iterator function for struct sequences - returns only visible elements
397fn struct_sequence_iter(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult {
398    let tuple = zelf
399        .downcast_ref::<PyTuple>()
400        .ok_or_else(|| vm.new_type_error("expected tuple"))?;
401    let n_seq = get_visible_len(&zelf, vm)?;
402    let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
403    let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
404    visible_tuple
405        .as_object()
406        .to_owned()
407        .get_iter(vm)
408        .map(Into::into)
409}
410
411/// Hash function for struct sequences - hashes only visible elements
412fn struct_sequence_hash(
413    zelf: &PyObject,
414    vm: &VirtualMachine,
415) -> PyResult<crate::common::hash::PyHash> {
416    let tuple = zelf
417        .downcast_ref::<PyTuple>()
418        .ok_or_else(|| vm.new_type_error("expected tuple"))?;
419    let n_seq = get_visible_len(zelf, vm)?;
420    // Create a visible-only tuple and hash it
421    let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
422    let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
423    visible_tuple.as_object().hash(vm)
424}
425
426/// Rich comparison for struct sequences - compares only visible elements
427fn struct_sequence_richcompare(
428    zelf: &PyObject,
429    other: &PyObject,
430    op: PyComparisonOp,
431    vm: &VirtualMachine,
432) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
433    let zelf_tuple = zelf
434        .downcast_ref::<PyTuple>()
435        .ok_or_else(|| vm.new_type_error("expected tuple"))?;
436
437    // If other is not a tuple, return NotImplemented
438    let Some(other_tuple) = other.downcast_ref::<PyTuple>() else {
439        return Ok(Either::B(PyComparisonValue::NotImplemented));
440    };
441
442    let zelf_len = get_visible_len(zelf, vm)?;
443    // For other, try to get visible len; if it fails (not a struct sequence), use full length
444    let other_len = get_visible_len(other, vm).unwrap_or(other_tuple.len());
445
446    let zelf_visible = &zelf_tuple.as_slice()[..zelf_len];
447    let other_visible = &other_tuple.as_slice()[..other_len];
448
449    // Use the same comparison logic as regular tuples
450    zelf_visible
451        .iter()
452        .richcompare(other_visible.iter(), op, vm)
453        .map(|v| Either::B(PyComparisonValue::Implemented(v)))
454}