// rustpython_vm/protocol/sequence.rs

1use crate::{
2    PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
3    builtins::{PyList, PyListRef, PySlice, PyTuple, PyTupleRef},
4    convert::ToPyObject,
5    function::PyArithmeticValue,
6    object::{Traverse, TraverseFn},
7    protocol::PyNumberBinaryOp,
8};
9use crossbeam_utils::atomic::AtomicCell;
10use itertools::Itertools;
11
// Sequence Protocol
// https://docs.python.org/3/c-api/sequence.html
14
15#[allow(clippy::type_complexity)]
16#[derive(Default)]
17pub struct PySequenceSlots {
18    pub length: AtomicCell<Option<fn(PySequence<'_>, &VirtualMachine) -> PyResult<usize>>>,
19    pub concat: AtomicCell<Option<fn(PySequence<'_>, &PyObject, &VirtualMachine) -> PyResult>>,
20    pub repeat: AtomicCell<Option<fn(PySequence<'_>, isize, &VirtualMachine) -> PyResult>>,
21    pub item: AtomicCell<Option<fn(PySequence<'_>, isize, &VirtualMachine) -> PyResult>>,
22    pub ass_item: AtomicCell<
23        Option<fn(PySequence<'_>, isize, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
24    >,
25    pub contains:
26        AtomicCell<Option<fn(PySequence<'_>, &PyObject, &VirtualMachine) -> PyResult<bool>>>,
27    pub inplace_concat:
28        AtomicCell<Option<fn(PySequence<'_>, &PyObject, &VirtualMachine) -> PyResult>>,
29    pub inplace_repeat: AtomicCell<Option<fn(PySequence<'_>, isize, &VirtualMachine) -> PyResult>>,
30}
31
32impl core::fmt::Debug for PySequenceSlots {
33    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34        f.write_str("PySequenceSlots")
35    }
36}
37
38impl PySequenceSlots {
39    pub fn has_item(&self) -> bool {
40        self.item.load().is_some()
41    }
42
43    /// Copy from static PySequenceMethods
44    pub fn copy_from(&self, methods: &PySequenceMethods) {
45        if let Some(f) = methods.length {
46            self.length.store(Some(f));
47        }
48        if let Some(f) = methods.concat {
49            self.concat.store(Some(f));
50        }
51        if let Some(f) = methods.repeat {
52            self.repeat.store(Some(f));
53        }
54        if let Some(f) = methods.item {
55            self.item.store(Some(f));
56        }
57        if let Some(f) = methods.ass_item {
58            self.ass_item.store(Some(f));
59        }
60        if let Some(f) = methods.contains {
61            self.contains.store(Some(f));
62        }
63        if let Some(f) = methods.inplace_concat {
64            self.inplace_concat.store(Some(f));
65        }
66        if let Some(f) = methods.inplace_repeat {
67            self.inplace_repeat.store(Some(f));
68        }
69    }
70}
71
72#[allow(clippy::type_complexity)]
73#[derive(Default)]
74pub struct PySequenceMethods {
75    pub length: Option<fn(PySequence<'_>, &VirtualMachine) -> PyResult<usize>>,
76    pub concat: Option<fn(PySequence<'_>, &PyObject, &VirtualMachine) -> PyResult>,
77    pub repeat: Option<fn(PySequence<'_>, isize, &VirtualMachine) -> PyResult>,
78    pub item: Option<fn(PySequence<'_>, isize, &VirtualMachine) -> PyResult>,
79    pub ass_item:
80        Option<fn(PySequence<'_>, isize, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
81    pub contains: Option<fn(PySequence<'_>, &PyObject, &VirtualMachine) -> PyResult<bool>>,
82    pub inplace_concat: Option<fn(PySequence<'_>, &PyObject, &VirtualMachine) -> PyResult>,
83    pub inplace_repeat: Option<fn(PySequence<'_>, isize, &VirtualMachine) -> PyResult>,
84}
85
86impl core::fmt::Debug for PySequenceMethods {
87    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
88        f.write_str("PySequenceMethods")
89    }
90}
91
92impl PySequenceMethods {
93    pub const NOT_IMPLEMENTED: Self = Self {
94        length: None,
95        concat: None,
96        repeat: None,
97        item: None,
98        ass_item: None,
99        contains: None,
100        inplace_concat: None,
101        inplace_repeat: None,
102    };
103}
104
105impl PyObject {
106    #[inline]
107    pub fn sequence_unchecked(&self) -> PySequence<'_> {
108        PySequence { obj: self }
109    }
110
111    pub fn try_sequence(&self, vm: &VirtualMachine) -> PyResult<PySequence<'_>> {
112        let seq = self.sequence_unchecked();
113        if seq.check() {
114            Ok(seq)
115        } else {
116            Err(vm.new_type_error(format!("'{}' is not a sequence", self.class())))
117        }
118    }
119}
120
121#[derive(Copy, Clone)]
122pub struct PySequence<'a> {
123    pub obj: &'a PyObject,
124}
125
126unsafe impl Traverse for PySequence<'_> {
127    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
128        self.obj.traverse(tracer_fn)
129    }
130}
131
132impl PySequence<'_> {
133    #[inline]
134    pub fn slots(&self) -> &PySequenceSlots {
135        &self.obj.class().slots.as_sequence
136    }
137
138    pub fn check(&self) -> bool {
139        self.slots().has_item()
140    }
141
142    pub fn length_opt(self, vm: &VirtualMachine) -> Option<PyResult<usize>> {
143        self.slots().length.load().map(|f| f(self, vm))
144    }
145
146    pub fn length(self, vm: &VirtualMachine) -> PyResult<usize> {
147        self.length_opt(vm).ok_or_else(|| {
148            vm.new_type_error(format!(
149                "'{}' is not a sequence or has no len()",
150                self.obj.class()
151            ))
152        })?
153    }
154
155    pub fn concat(self, other: &PyObject, vm: &VirtualMachine) -> PyResult {
156        if let Some(f) = self.slots().concat.load() {
157            return f(self, other, vm);
158        }
159
160        // if both arguments appear to be sequences, try fallback to __add__
161        if self.check() && other.sequence_unchecked().check() {
162            let ret = vm.binary_op1(self.obj, other, PyNumberBinaryOp::Add)?;
163            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
164                return Ok(ret);
165            }
166        }
167
168        Err(vm.new_type_error(format!(
169            "'{}' object can't be concatenated",
170            self.obj.class()
171        )))
172    }
173
174    pub fn repeat(self, n: isize, vm: &VirtualMachine) -> PyResult {
175        if let Some(f) = self.slots().repeat.load() {
176            return f(self, n, vm);
177        }
178
179        // fallback to __mul__
180        if self.check() {
181            let ret = vm.binary_op1(self.obj, &n.to_pyobject(vm), PyNumberBinaryOp::Multiply)?;
182            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
183                return Ok(ret);
184            }
185        }
186
187        Err(vm.new_type_error(format!("'{}' object can't be repeated", self.obj.class())))
188    }
189
190    pub fn inplace_concat(self, other: &PyObject, vm: &VirtualMachine) -> PyResult {
191        if let Some(f) = self.slots().inplace_concat.load() {
192            return f(self, other, vm);
193        }
194        if let Some(f) = self.slots().concat.load() {
195            return f(self, other, vm);
196        }
197
198        // if both arguments appear to be sequences, try fallback to __iadd__
199        if self.check() && other.sequence_unchecked().check() {
200            let ret = vm._iadd(self.obj, other)?;
201            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
202                return Ok(ret);
203            }
204        }
205
206        Err(vm.new_type_error(format!(
207            "'{}' object can't be concatenated",
208            self.obj.class()
209        )))
210    }
211
212    pub fn inplace_repeat(self, n: isize, vm: &VirtualMachine) -> PyResult {
213        if let Some(f) = self.slots().inplace_repeat.load() {
214            return f(self, n, vm);
215        }
216        if let Some(f) = self.slots().repeat.load() {
217            return f(self, n, vm);
218        }
219
220        if self.check() {
221            let ret = vm._imul(self.obj, &n.to_pyobject(vm))?;
222            if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) {
223                return Ok(ret);
224            }
225        }
226
227        Err(vm.new_type_error(format!("'{}' object can't be repeated", self.obj.class())))
228    }
229
230    pub fn get_item(self, i: isize, vm: &VirtualMachine) -> PyResult {
231        if let Some(f) = self.slots().item.load() {
232            return f(self, i, vm);
233        }
234        Err(vm.new_type_error(format!(
235            "'{}' is not a sequence or does not support indexing",
236            self.obj.class()
237        )))
238    }
239
240    fn _ass_item(self, i: isize, value: Option<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
241        if let Some(f) = self.slots().ass_item.load() {
242            return f(self, i, value, vm);
243        }
244        Err(vm.new_type_error(format!(
245            "'{}' is not a sequence or doesn't support item {}",
246            self.obj.class(),
247            if value.is_some() {
248                "assignment"
249            } else {
250                "deletion"
251            }
252        )))
253    }
254
255    pub fn set_item(self, i: isize, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
256        self._ass_item(i, Some(value), vm)
257    }
258
259    pub fn del_item(self, i: isize, vm: &VirtualMachine) -> PyResult<()> {
260        self._ass_item(i, None, vm)
261    }
262
263    pub fn get_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult {
264        if let Ok(mapping) = self.obj.try_mapping(vm) {
265            let slice = PySlice {
266                start: Some(start.to_pyobject(vm)),
267                stop: stop.to_pyobject(vm),
268                step: None,
269            };
270            mapping.subscript(&slice.into_pyobject(vm), vm)
271        } else {
272            Err(vm.new_type_error(format!("'{}' object is unsliceable", self.obj.class())))
273        }
274    }
275
276    fn _ass_slice(
277        &self,
278        start: isize,
279        stop: isize,
280        value: Option<PyObjectRef>,
281        vm: &VirtualMachine,
282    ) -> PyResult<()> {
283        let mapping = self.obj.mapping_unchecked();
284        if let Some(f) = mapping.slots().ass_subscript.load() {
285            let slice = PySlice {
286                start: Some(start.to_pyobject(vm)),
287                stop: stop.to_pyobject(vm),
288                step: None,
289            };
290            f(mapping, &slice.into_pyobject(vm), value, vm)
291        } else {
292            Err(vm.new_type_error(format!(
293                "'{}' object doesn't support slice {}",
294                self.obj.class(),
295                if value.is_some() {
296                    "assignment"
297                } else {
298                    "deletion"
299                }
300            )))
301        }
302    }
303
304    pub fn set_slice(
305        &self,
306        start: isize,
307        stop: isize,
308        value: PyObjectRef,
309        vm: &VirtualMachine,
310    ) -> PyResult<()> {
311        self._ass_slice(start, stop, Some(value), vm)
312    }
313
314    pub fn del_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult<()> {
315        self._ass_slice(start, stop, None, vm)
316    }
317
318    pub fn tuple(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
319        if let Some(tuple) = self.obj.downcast_ref_if_exact::<PyTuple>(vm) {
320            Ok(tuple.to_owned())
321        } else if let Some(list) = self.obj.downcast_ref_if_exact::<PyList>(vm) {
322            Ok(vm.ctx.new_tuple(list.borrow_vec().to_vec()))
323        } else {
324            let iter = self.obj.to_owned().get_iter(vm)?;
325            let iter = iter.iter(vm)?;
326            Ok(vm.ctx.new_tuple(iter.try_collect()?))
327        }
328    }
329
330    pub fn list(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
331        let list = vm.ctx.new_list(self.obj.try_to_value(vm)?);
332        Ok(list)
333    }
334
335    pub fn count(&self, target: &PyObject, vm: &VirtualMachine) -> PyResult<usize> {
336        let mut n = 0;
337
338        let iter = self.obj.to_owned().get_iter(vm)?;
339        let iter = iter.iter::<PyObjectRef>(vm)?;
340
341        for elem in iter {
342            let elem = elem?;
343            if vm.bool_eq(&elem, target)? {
344                if n == isize::MAX as usize {
345                    return Err(vm.new_overflow_error("index exceeds C integer size"));
346                }
347                n += 1;
348            }
349        }
350
351        Ok(n)
352    }
353
354    pub fn index(&self, target: &PyObject, vm: &VirtualMachine) -> PyResult<usize> {
355        let mut index: isize = -1;
356
357        let iter = self.obj.to_owned().get_iter(vm)?;
358        let iter = iter.iter::<PyObjectRef>(vm)?;
359
360        for elem in iter {
361            if index == isize::MAX {
362                return Err(vm.new_overflow_error("index exceeds C integer size"));
363            }
364            index += 1;
365
366            let elem = elem?;
367            if vm.bool_eq(&elem, target)? {
368                return Ok(index as usize);
369            }
370        }
371
372        Err(vm.new_value_error("sequence.index(x): x not in sequence"))
373    }
374
375    pub fn extract<F, R>(&self, mut f: F, vm: &VirtualMachine) -> PyResult<Vec<R>>
376    where
377        F: FnMut(&PyObject) -> PyResult<R>,
378    {
379        if let Some(tuple) = self.obj.downcast_ref_if_exact::<PyTuple>(vm) {
380            tuple.iter().map(|x| f(x.as_ref())).collect()
381        } else if let Some(list) = self.obj.downcast_ref_if_exact::<PyList>(vm) {
382            list.borrow_vec().iter().map(|x| f(x.as_ref())).collect()
383        } else {
384            let iter = self.obj.to_owned().get_iter(vm)?;
385            let iter = iter.iter::<PyObjectRef>(vm)?;
386            let len = self.length(vm).unwrap_or(0);
387            let mut v = Vec::with_capacity(len);
388            for x in iter {
389                v.push(f(x?.as_ref())?);
390            }
391            v.shrink_to_fit();
392            Ok(v)
393        }
394    }
395
396    pub fn contains(self, target: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
397        if let Some(f) = self.slots().contains.load() {
398            return f(self, target, vm);
399        }
400
401        let iter = self.obj.to_owned().get_iter(vm)?;
402        let iter = iter.iter::<PyObjectRef>(vm)?;
403
404        for elem in iter {
405            let elem = elem?;
406            if vm.bool_eq(&elem, target)? {
407                return Ok(true);
408            }
409        }
410        Ok(false)
411    }
412}