// rustpython_vm/protocol/iter.rs

1use crate::{
2    builtins::iter::PySequenceIterator,
3    convert::{ToPyObject, ToPyResult},
4    object::{Traverse, TraverseFn},
5    AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine,
6};
7use std::borrow::Borrow;
8use std::ops::Deref;
9
10/// Iterator Protocol
11// https://docs.python.org/3/c-api/iter.html
12#[derive(Debug, Clone)]
13#[repr(transparent)]
14pub struct PyIter<O = PyObjectRef>(O)
15where
16    O: Borrow<PyObject>;
17
18unsafe impl<O: Borrow<PyObject>> Traverse for PyIter<O> {
19    fn traverse(&self, tracer_fn: &mut TraverseFn) {
20        self.0.borrow().traverse(tracer_fn);
21    }
22}
23
24impl PyIter<PyObjectRef> {
25    pub fn check(obj: &PyObject) -> bool {
26        obj.class()
27            .mro_find_map(|x| x.slots.iternext.load())
28            .is_some()
29    }
30}
31
32impl<O> PyIter<O>
33where
34    O: Borrow<PyObject>,
35{
36    pub fn new(obj: O) -> Self {
37        Self(obj)
38    }
39    pub fn next(&self, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
40        let iternext = {
41            self.0
42                .borrow()
43                .class()
44                .mro_find_map(|x| x.slots.iternext.load())
45                .ok_or_else(|| {
46                    vm.new_type_error(format!(
47                        "'{}' object is not an iterator",
48                        self.0.borrow().class().name()
49                    ))
50                })?
51        };
52        iternext(self.0.borrow(), vm)
53    }
54
55    pub fn iter<'a, 'b, U>(
56        &'b self,
57        vm: &'a VirtualMachine,
58    ) -> PyResult<PyIterIter<'a, U, &'b PyObject>> {
59        let length_hint = vm.length_hint_opt(self.as_ref().to_owned())?;
60        Ok(PyIterIter::new(vm, self.0.borrow(), length_hint))
61    }
62
63    pub fn iter_without_hint<'a, 'b, U>(
64        &'b self,
65        vm: &'a VirtualMachine,
66    ) -> PyResult<PyIterIter<'a, U, &'b PyObject>> {
67        Ok(PyIterIter::new(vm, self.0.borrow(), None))
68    }
69}
70
71impl PyIter<PyObjectRef> {
72    /// Returns an iterator over this sequence of objects.
73    pub fn into_iter<U>(self, vm: &VirtualMachine) -> PyResult<PyIterIter<U, PyObjectRef>> {
74        let length_hint = vm.length_hint_opt(self.as_object().to_owned())?;
75        Ok(PyIterIter::new(vm, self.0, length_hint))
76    }
77}
78
79impl From<PyIter<PyObjectRef>> for PyObjectRef {
80    fn from(value: PyIter<PyObjectRef>) -> PyObjectRef {
81        value.0
82    }
83}
84
85impl<O> Borrow<PyObject> for PyIter<O>
86where
87    O: Borrow<PyObject>,
88{
89    #[inline(always)]
90    fn borrow(&self) -> &PyObject {
91        self.0.borrow()
92    }
93}
94
95impl<O> AsRef<PyObject> for PyIter<O>
96where
97    O: Borrow<PyObject>,
98{
99    #[inline(always)]
100    fn as_ref(&self) -> &PyObject {
101        self.0.borrow()
102    }
103}
104
105impl<O> Deref for PyIter<O>
106where
107    O: Borrow<PyObject>,
108{
109    type Target = PyObject;
110    #[inline(always)]
111    fn deref(&self) -> &Self::Target {
112        self.0.borrow()
113    }
114}
115
116impl ToPyObject for PyIter<PyObjectRef> {
117    #[inline(always)]
118    fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef {
119        self.into()
120    }
121}
122
123impl TryFromObject for PyIter<PyObjectRef> {
124    // This helper function is called at multiple places. First, it is called
125    // in the vm when a for loop is entered. Next, it is used when the builtin
126    // function 'iter' is called.
127    fn try_from_object(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult<Self> {
128        let getiter = {
129            let cls = iter_target.class();
130            cls.mro_find_map(|x| x.slots.iter.load())
131        };
132        if let Some(getiter) = getiter {
133            let iter = getiter(iter_target, vm)?;
134            if PyIter::check(&iter) {
135                Ok(Self(iter))
136            } else {
137                Err(vm.new_type_error(format!(
138                    "iter() returned non-iterator of type '{}'",
139                    iter.class().name()
140                )))
141            }
142        } else if let Ok(seq_iter) = PySequenceIterator::new(iter_target.clone(), vm) {
143            Ok(Self(seq_iter.into_pyobject(vm)))
144        } else {
145            Err(vm.new_type_error(format!(
146                "'{}' object is not iterable",
147                iter_target.class().name()
148            )))
149        }
150    }
151}
152
153#[derive(result_like::ResultLike)]
154pub enum PyIterReturn<T = PyObjectRef> {
155    Return(T),
156    StopIteration(Option<PyObjectRef>),
157}
158
159unsafe impl<T: Traverse> Traverse for PyIterReturn<T> {
160    fn traverse(&self, tracer_fn: &mut TraverseFn) {
161        match self {
162            PyIterReturn::Return(r) => r.traverse(tracer_fn),
163            PyIterReturn::StopIteration(Some(obj)) => obj.traverse(tracer_fn),
164            _ => (),
165        }
166    }
167}
168
169impl PyIterReturn {
170    pub fn from_pyresult(result: PyResult, vm: &VirtualMachine) -> PyResult<Self> {
171        match result {
172            Ok(obj) => Ok(Self::Return(obj)),
173            Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
174                let args = err.get_arg(0);
175                Ok(Self::StopIteration(args))
176            }
177            Err(err) => Err(err),
178        }
179    }
180
181    pub fn from_getitem_result(result: PyResult, vm: &VirtualMachine) -> PyResult<Self> {
182        match result {
183            Ok(obj) => Ok(Self::Return(obj)),
184            Err(err) if err.fast_isinstance(vm.ctx.exceptions.index_error) => {
185                Ok(Self::StopIteration(None))
186            }
187            Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
188                let args = err.get_arg(0);
189                Ok(Self::StopIteration(args))
190            }
191            Err(err) => Err(err),
192        }
193    }
194
195    pub fn into_async_pyresult(self, vm: &VirtualMachine) -> PyResult {
196        match self {
197            Self::Return(obj) => Ok(obj),
198            Self::StopIteration(v) => Err({
199                let args = if let Some(v) = v { vec![v] } else { Vec::new() };
200                vm.new_exception(vm.ctx.exceptions.stop_async_iteration.to_owned(), args)
201            }),
202        }
203    }
204}
205
206impl ToPyResult for PyIterReturn {
207    fn to_pyresult(self, vm: &VirtualMachine) -> PyResult {
208        match self {
209            Self::Return(obj) => Ok(obj),
210            Self::StopIteration(v) => Err(vm.new_stop_iteration(v)),
211        }
212    }
213}
214
215impl ToPyResult for PyResult<PyIterReturn> {
216    fn to_pyresult(self, vm: &VirtualMachine) -> PyResult {
217        self?.to_pyresult(vm)
218    }
219}
220
221// Typical rust `Iter` object for `PyIter`
222pub struct PyIterIter<'a, T, O = PyObjectRef>
223where
224    O: Borrow<PyObject>,
225{
226    vm: &'a VirtualMachine,
227    obj: O, // creating PyIter<O> is zero-cost
228    length_hint: Option<usize>,
229    _phantom: std::marker::PhantomData<T>,
230}
231
232unsafe impl<'a, T, O> Traverse for PyIterIter<'a, T, O>
233where
234    O: Traverse + Borrow<PyObject>,
235{
236    fn traverse(&self, tracer_fn: &mut TraverseFn) {
237        self.obj.traverse(tracer_fn)
238    }
239}
240
241impl<'a, T, O> PyIterIter<'a, T, O>
242where
243    O: Borrow<PyObject>,
244{
245    pub fn new(vm: &'a VirtualMachine, obj: O, length_hint: Option<usize>) -> Self {
246        Self {
247            vm,
248            obj,
249            length_hint,
250            _phantom: std::marker::PhantomData,
251        }
252    }
253}
254
255impl<'a, T, O> Iterator for PyIterIter<'a, T, O>
256where
257    T: TryFromObject,
258    O: Borrow<PyObject>,
259{
260    type Item = PyResult<T>;
261
262    fn next(&mut self) -> Option<Self::Item> {
263        let imp = |next: PyResult<PyIterReturn>| -> PyResult<Option<T>> {
264            let Some(obj) = next?.into_result().ok() else {
265                return Ok(None);
266            };
267            Ok(Some(T::try_from_object(self.vm, obj)?))
268        };
269        let next = PyIter::new(self.obj.borrow()).next(self.vm);
270        imp(next).transpose()
271    }
272
273    #[inline]
274    fn size_hint(&self) -> (usize, Option<usize>) {
275        (self.length_hint.unwrap_or(0), self.length_hint)
276    }
277}