1use crate::common::lock::LazyLock;
2use crate::{
3 AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
4 builtins::{PyBaseExceptionRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef},
5 class::{PyClassImpl, StaticType},
6 function::{Either, FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags},
7 iter::PyExactSizeIterator,
8 protocol::{PyMappingMethods, PySequenceMethods},
9 sliceable::{SequenceIndex, SliceableSequenceOp},
10 types::PyComparisonOp,
11 vm::Context,
12};
13
14const DEFAULT_STRUCTSEQ_REDUCE: PyMethodDef = PyMethodDef::new_const(
15 "__reduce__",
16 |zelf: PyRef<PyTuple>, vm: &VirtualMachine| -> PyTupleRef {
17 vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),)))
18 },
19 PyMethodFlags::METHOD,
20 None,
21);
22
23pub fn struct_sequence_new(cls: PyTypeRef, seq: PyObjectRef, vm: &VirtualMachine) -> PyResult {
28 #[cold]
31 fn length_error(
32 tp_name: &str,
33 min_len: usize,
34 max_len: usize,
35 len: usize,
36 vm: &VirtualMachine,
37 ) -> PyBaseExceptionRef {
38 if min_len == max_len {
39 vm.new_type_error(format!(
40 "{tp_name}() takes a {min_len}-sequence ({len}-sequence given)"
41 ))
42 } else if len < min_len {
43 vm.new_type_error(format!(
44 "{tp_name}() takes an at least {min_len}-sequence ({len}-sequence given)"
45 ))
46 } else {
47 vm.new_type_error(format!(
48 "{tp_name}() takes an at most {max_len}-sequence ({len}-sequence given)"
49 ))
50 }
51 }
52
53 let min_len: usize = cls
54 .get_attr(identifier!(vm.ctx, n_sequence_fields))
55 .ok_or_else(|| vm.new_type_error("missing n_sequence_fields attribute"))?
56 .try_into_value(vm)?;
57 let max_len: usize = cls
58 .get_attr(identifier!(vm.ctx, n_fields))
59 .ok_or_else(|| vm.new_type_error("missing n_fields attribute"))?
60 .try_into_value(vm)?;
61
62 let seq: Vec<PyObjectRef> = seq.try_into_value(vm)?;
63 let len = seq.len();
64
65 if len < min_len || len > max_len {
66 return Err(length_error(&cls.slot_name(), min_len, max_len, len, vm));
67 }
68
69 let mut items = seq;
71 items.resize_with(max_len, || vm.ctx.none());
72
73 PyTuple::new_unchecked(items.into_boxed_slice())
74 .into_ref_with_type(vm, cls)
75 .map(Into::into)
76}
77
78fn get_visible_len(obj: &PyObject, vm: &VirtualMachine) -> PyResult<usize> {
79 obj.class()
80 .get_attr(identifier!(vm.ctx, n_sequence_fields))
81 .ok_or_else(|| vm.new_type_error("missing n_sequence_fields"))?
82 .try_into_value(vm)
83}
84
85static STRUCT_SEQUENCE_AS_SEQUENCE: LazyLock<PySequenceMethods> =
88 LazyLock::new(|| PySequenceMethods {
89 length: atomic_func!(|seq, vm| get_visible_len(seq.obj, vm)),
90 concat: atomic_func!(|seq, other, vm| {
91 let n_seq = get_visible_len(seq.obj, vm)?;
93 let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
94 let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
95 let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
96 visible_tuple
98 .as_object()
99 .sequence_unchecked()
100 .concat(other, vm)
101 }),
102 repeat: atomic_func!(|seq, n, vm| {
103 let n_seq = get_visible_len(seq.obj, vm)?;
105 let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
106 let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
107 let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
108 visible_tuple.as_object().sequence_unchecked().repeat(n, vm)
110 }),
111 item: atomic_func!(|seq, i, vm| {
112 let n_seq = get_visible_len(seq.obj, vm)?;
113 let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
114 let idx = if i < 0 {
115 let pos_i = n_seq as isize + i;
116 if pos_i < 0 {
117 return Err(vm.new_index_error("tuple index out of range"));
118 }
119 pos_i as usize
120 } else {
121 i as usize
122 };
123 if idx >= n_seq {
124 return Err(vm.new_index_error("tuple index out of range"));
125 }
126 Ok(tuple[idx].clone())
127 }),
128 contains: atomic_func!(|seq, needle, vm| {
129 let n_seq = get_visible_len(seq.obj, vm)?;
130 let tuple = seq.obj.downcast_ref::<PyTuple>().unwrap();
131 for item in tuple.iter().take(n_seq) {
132 if item.rich_compare_bool(needle, PyComparisonOp::Eq, vm)? {
133 return Ok(true);
134 }
135 }
136 Ok(false)
137 }),
138 ..PySequenceMethods::NOT_IMPLEMENTED
139 });
140
141static STRUCT_SEQUENCE_AS_MAPPING: LazyLock<PyMappingMethods> =
144 LazyLock::new(|| PyMappingMethods {
145 length: atomic_func!(|mapping, vm| get_visible_len(mapping.obj, vm)),
146 subscript: atomic_func!(|mapping, needle, vm| {
147 let n_seq = get_visible_len(mapping.obj, vm)?;
148 let tuple = mapping.obj.downcast_ref::<PyTuple>().unwrap();
149 let visible_elements = &tuple.as_slice()[..n_seq];
150
151 match SequenceIndex::try_from_borrowed_object(vm, needle, "tuple")? {
152 SequenceIndex::Int(i) => visible_elements.getitem_by_index(vm, i),
153 SequenceIndex::Slice(slice) => visible_elements
154 .getitem_by_slice(vm, slice)
155 .map(|x| vm.ctx.new_tuple(x).into()),
156 }
157 }),
158 ..PyMappingMethods::NOT_IMPLEMENTED
159 });
160
161pub trait PyStructSequenceData: Sized {
166 const REQUIRED_FIELD_NAMES: &'static [&'static str];
168
169 const OPTIONAL_FIELD_NAMES: &'static [&'static str];
171
172 const UNNAMED_FIELDS_LEN: usize = 0;
174
175 fn into_tuple(self, vm: &VirtualMachine) -> PyTuple;
177
178 fn try_from_elements(_elements: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Self> {
182 Err(vm.new_type_error("This struct sequence does not support construction from elements"))
183 }
184}
185
186#[pyclass]
191pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
192 type Data: PyStructSequenceData;
194
195 fn from_data(data: Self::Data, vm: &VirtualMachine) -> PyTupleRef {
197 let tuple =
198 <Self::Data as ::rustpython_vm::types::PyStructSequenceData>::into_tuple(data, vm);
199 let typ = Self::static_type();
200 tuple
201 .into_ref_with_type(vm, typ.to_owned())
202 .expect("Every PyStructSequence must be a valid tuple. This is a RustPython bug.")
203 }
204
205 #[pyslot]
206 fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyStrRef> {
207 let zelf = zelf
208 .downcast_ref::<PyTuple>()
209 .ok_or_else(|| vm.new_type_error("unexpected payload for __repr__"))?;
210
211 let field_names = Self::Data::REQUIRED_FIELD_NAMES;
212 let format_field = |(value, name): (&PyObject, _)| {
213 let s = value.repr(vm)?;
214 Ok(format!("{name}={s}"))
215 };
216 let (body, suffix) =
217 if let Some(_guard) = rustpython_vm::recursion::ReprGuard::enter(vm, zelf.as_ref()) {
218 let fields: PyResult<Vec<_>> = zelf
219 .iter()
220 .map(|value| value.as_ref())
221 .zip(field_names.iter().copied())
222 .map(format_field)
223 .collect();
224 (fields?.join(", "), "")
225 } else {
226 (String::new(), "...")
227 };
228 let type_name = if Self::MODULE_NAME.is_some() {
231 alloc::borrow::Cow::Borrowed(Self::TP_NAME)
232 } else {
233 let typ = zelf.class();
234 match typ.get_attr(identifier!(vm.ctx, __module__)) {
235 Some(module) if module.downcastable::<PyStr>() => {
236 let module_str = module.downcast_ref::<PyStr>().unwrap();
237 alloc::borrow::Cow::Owned(format!("{}.{}", module_str.as_wtf8(), Self::NAME))
238 }
239 _ => alloc::borrow::Cow::Borrowed(Self::TP_NAME),
240 }
241 };
242 let repr_str = format!("{}({}{})", type_name, body, suffix);
243 Ok(vm.ctx.new_str(repr_str))
244 }
245
246 #[pymethod]
247 fn __replace__(zelf: PyRef<PyTuple>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
248 if !args.args.is_empty() {
249 return Err(vm.new_type_error("__replace__() takes no positional arguments"));
250 }
251
252 if Self::Data::UNNAMED_FIELDS_LEN > 0 {
253 return Err(vm.new_type_error(format!(
254 "__replace__() is not supported for {} because it has unnamed field(s)",
255 zelf.class().slot_name()
256 )));
257 }
258
259 let n_fields =
260 Self::Data::REQUIRED_FIELD_NAMES.len() + Self::Data::OPTIONAL_FIELD_NAMES.len();
261 let mut items: Vec<PyObjectRef> = zelf.as_slice()[..n_fields].to_vec();
262
263 let mut kwargs = args.kwargs.clone();
264
265 let all_field_names: Vec<&str> = Self::Data::REQUIRED_FIELD_NAMES
267 .iter()
268 .chain(Self::Data::OPTIONAL_FIELD_NAMES.iter())
269 .copied()
270 .collect();
271 for (i, &name) in all_field_names.iter().enumerate() {
272 if let Some(val) = kwargs.shift_remove(name) {
273 items[i] = val;
274 }
275 }
276
277 if !kwargs.is_empty() {
279 let names: Vec<&str> = kwargs.keys().map(|k| k.as_str()).collect();
280 return Err(vm.new_type_error(format!("Got unexpected field name(s): {:?}", names)));
281 }
282
283 PyTuple::new_unchecked(items.into_boxed_slice())
284 .into_ref_with_type(vm, zelf.class().to_owned())
285 .map(Into::into)
286 }
287
288 #[pymethod]
289 fn __getitem__(zelf: PyRef<PyTuple>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
290 let n_seq = get_visible_len(zelf.as_ref(), vm)?;
291 let visible_elements = &zelf.as_slice()[..n_seq];
292
293 match SequenceIndex::try_from_borrowed_object(vm, &needle, "tuple")? {
294 SequenceIndex::Int(i) => visible_elements.getitem_by_index(vm, i),
295 SequenceIndex::Slice(slice) => visible_elements
296 .getitem_by_slice(vm, slice)
297 .map(|x| vm.ctx.new_tuple(x).into()),
298 }
299 }
300
301 #[extend_class]
302 fn extend_pyclass(ctx: &Context, class: &'static Py<PyType>) {
303 for (i, &name) in Self::Data::REQUIRED_FIELD_NAMES.iter().enumerate() {
305 let i = i as u8;
308 class.set_attr(
309 ctx.intern_str(name),
310 ctx.new_readonly_getset(name, class, move |zelf: &PyTuple| {
311 zelf[i as usize].to_owned()
312 })
313 .into(),
314 );
315 }
316
317 let visible_count = Self::Data::REQUIRED_FIELD_NAMES.len() + Self::Data::UNNAMED_FIELDS_LEN;
319 for (i, &name) in Self::Data::OPTIONAL_FIELD_NAMES.iter().enumerate() {
320 let idx = (visible_count + i) as u8;
321 class.set_attr(
322 ctx.intern_str(name),
323 ctx.new_readonly_getset(name, class, move |zelf: &PyTuple| {
324 zelf[idx as usize].to_owned()
325 })
326 .into(),
327 );
328 }
329
330 class.set_attr(
331 identifier!(ctx, __match_args__),
332 ctx.new_tuple(
333 Self::Data::REQUIRED_FIELD_NAMES
334 .iter()
335 .map(|&name| ctx.new_str(name).into())
336 .collect::<Vec<_>>(),
337 )
338 .into(),
339 );
340
341 let n_unnamed_fields = Self::Data::UNNAMED_FIELDS_LEN;
346 let n_sequence_fields = Self::Data::REQUIRED_FIELD_NAMES.len() + n_unnamed_fields;
347 let n_fields = n_sequence_fields + Self::Data::OPTIONAL_FIELD_NAMES.len();
348 class.set_attr(
349 identifier!(ctx, n_sequence_fields),
350 ctx.new_int(n_sequence_fields).into(),
351 );
352 class.set_attr(identifier!(ctx, n_fields), ctx.new_int(n_fields).into());
353 class.set_attr(
354 identifier!(ctx, n_unnamed_fields),
355 ctx.new_int(n_unnamed_fields).into(),
356 );
357
358 class
360 .slots
361 .as_sequence
362 .copy_from(&STRUCT_SEQUENCE_AS_SEQUENCE);
363 class
364 .slots
365 .as_mapping
366 .copy_from(&STRUCT_SEQUENCE_AS_MAPPING);
367
368 class.slots.iter.store(Some(struct_sequence_iter));
370
371 class.slots.hash.store(Some(struct_sequence_hash));
373
374 class
376 .slots
377 .richcompare
378 .store(Some(struct_sequence_richcompare));
379
380 if !class
384 .attributes
385 .read()
386 .contains_key(ctx.intern_str("__reduce__"))
387 {
388 class.set_attr(
389 ctx.intern_str("__reduce__"),
390 DEFAULT_STRUCTSEQ_REDUCE.to_proper_method(class, ctx),
391 );
392 }
393 }
394}
395
396fn struct_sequence_iter(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult {
398 let tuple = zelf
399 .downcast_ref::<PyTuple>()
400 .ok_or_else(|| vm.new_type_error("expected tuple"))?;
401 let n_seq = get_visible_len(&zelf, vm)?;
402 let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
403 let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
404 visible_tuple
405 .as_object()
406 .to_owned()
407 .get_iter(vm)
408 .map(Into::into)
409}
410
411fn struct_sequence_hash(
413 zelf: &PyObject,
414 vm: &VirtualMachine,
415) -> PyResult<crate::common::hash::PyHash> {
416 let tuple = zelf
417 .downcast_ref::<PyTuple>()
418 .ok_or_else(|| vm.new_type_error("expected tuple"))?;
419 let n_seq = get_visible_len(zelf, vm)?;
420 let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect();
422 let visible_tuple = PyTuple::new_ref(visible, &vm.ctx);
423 visible_tuple.as_object().hash(vm)
424}
425
426fn struct_sequence_richcompare(
428 zelf: &PyObject,
429 other: &PyObject,
430 op: PyComparisonOp,
431 vm: &VirtualMachine,
432) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
433 let zelf_tuple = zelf
434 .downcast_ref::<PyTuple>()
435 .ok_or_else(|| vm.new_type_error("expected tuple"))?;
436
437 let Some(other_tuple) = other.downcast_ref::<PyTuple>() else {
439 return Ok(Either::B(PyComparisonValue::NotImplemented));
440 };
441
442 let zelf_len = get_visible_len(zelf, vm)?;
443 let other_len = get_visible_len(other, vm).unwrap_or(other_tuple.len());
445
446 let zelf_visible = &zelf_tuple.as_slice()[..zelf_len];
447 let other_visible = &other_tuple.as_slice()[..other_len];
448
449 zelf_visible
451 .iter()
452 .richcompare(other_visible.iter(), op, vm)
453 .map(|v| Either::B(PyComparisonValue::Implemented(v)))
454}