Skip to main content

rustpython_vm/protocol/
iter.rs

1use crate::{
2    AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine,
3    builtins::iter::PySequenceIterator,
4    convert::{ToPyObject, ToPyResult},
5    object::{Traverse, TraverseFn},
6};
7use core::borrow::Borrow;
8use core::ops::Deref;
9
10/// Iterator Protocol
11// https://docs.python.org/3/c-api/iter.html
12#[derive(Debug, Clone)]
13#[repr(transparent)]
14pub struct PyIter<O = PyObjectRef>(O)
15where
16    O: Borrow<PyObject>;
17
18unsafe impl<O: Borrow<PyObject>> Traverse for PyIter<O> {
19    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
20        self.0.borrow().traverse(tracer_fn);
21    }
22}
23
24impl PyIter<PyObjectRef> {
25    pub fn check(obj: &PyObject) -> bool {
26        obj.class().slots.iternext.load().is_some()
27    }
28}
29
30impl<O> PyIter<O>
31where
32    O: Borrow<PyObject>,
33{
34    pub const fn new(obj: O) -> Self {
35        Self(obj)
36    }
37    pub fn next(&self, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
38        let iternext = self
39            .0
40            .borrow()
41            .class()
42            .slots
43            .iternext
44            .load()
45            .ok_or_else(|| {
46                vm.new_type_error(format!(
47                    "'{}' object is not an iterator",
48                    self.0.borrow().class().name()
49                ))
50            })?;
51        iternext(self.0.borrow(), vm)
52    }
53
54    pub fn iter<'a, 'b, U>(
55        &'b self,
56        vm: &'a VirtualMachine,
57    ) -> PyResult<PyIterIter<'a, U, &'b PyObject>> {
58        let length_hint = vm.length_hint_opt(self.as_ref().to_owned())?;
59        Ok(PyIterIter::new(vm, self.0.borrow(), length_hint))
60    }
61
62    pub fn iter_without_hint<'a, 'b, U>(
63        &'b self,
64        vm: &'a VirtualMachine,
65    ) -> PyResult<PyIterIter<'a, U, &'b PyObject>> {
66        Ok(PyIterIter::new(vm, self.0.borrow(), None))
67    }
68}
69
70impl PyIter<PyObjectRef> {
71    /// Returns an iterator over this sequence of objects.
72    pub fn into_iter<U>(self, vm: &VirtualMachine) -> PyResult<PyIterIter<'_, U, PyObjectRef>> {
73        let length_hint = vm.length_hint_opt(self.as_object().to_owned())?;
74        Ok(PyIterIter::new(vm, self.0, length_hint))
75    }
76}
77
78impl From<PyIter<Self>> for PyObjectRef {
79    fn from(value: PyIter<Self>) -> Self {
80        value.0
81    }
82}
83
84impl<O> Borrow<PyObject> for PyIter<O>
85where
86    O: Borrow<PyObject>,
87{
88    #[inline(always)]
89    fn borrow(&self) -> &PyObject {
90        self.0.borrow()
91    }
92}
93
94impl<O> AsRef<PyObject> for PyIter<O>
95where
96    O: Borrow<PyObject>,
97{
98    #[inline(always)]
99    fn as_ref(&self) -> &PyObject {
100        self.0.borrow()
101    }
102}
103
104impl<O> Deref for PyIter<O>
105where
106    O: Borrow<PyObject>,
107{
108    type Target = PyObject;
109
110    #[inline(always)]
111    fn deref(&self) -> &Self::Target {
112        self.0.borrow()
113    }
114}
115
116impl ToPyObject for PyIter<PyObjectRef> {
117    #[inline(always)]
118    fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef {
119        self.into()
120    }
121}
122
123impl TryFromObject for PyIter<PyObjectRef> {
124    // This helper function is called at multiple places. First, it is called
125    // in the vm when a for loop is entered. Next, it is used when the builtin
126    // function 'iter' is called.
127    fn try_from_object(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult<Self> {
128        let get_iter = iter_target.class().slots.iter.load();
129        if let Some(get_iter) = get_iter {
130            let iter = get_iter(iter_target, vm)?;
131            if Self::check(&iter) {
132                Ok(Self(iter))
133            } else {
134                Err(vm.new_type_error(format!(
135                    "iter() returned non-iterator of type '{}'",
136                    iter.class().name()
137                )))
138            }
139        } else if let Ok(seq_iter) = PySequenceIterator::new(iter_target.clone(), vm) {
140            Ok(Self(seq_iter.into_pyobject(vm)))
141        } else {
142            Err(vm.new_type_error(format!(
143                "'{}' object is not iterable",
144                iter_target.class().name()
145            )))
146        }
147    }
148}
149
/// Result of advancing a Python iterator by one step.
///
/// * `Return(value)`: the iterator produced `value`.
/// * `StopIteration(value)`: the iterator is exhausted; `value` is the optional
///   payload of the stop (e.g. the first argument of a caught `StopIteration`,
///   see `PyIterReturn::from_pyresult`).
///
/// `#[derive(ResultLike)]` generates `Result`-style helpers such as
/// `into_result()`, under which `Return` corresponds to `Ok`.
/// NOTE(review): presumably the derive maps variants by position; confirm
/// before reordering or renaming them.
#[derive(result_like::ResultLike)]
pub enum PyIterReturn<T = PyObjectRef> {
    /// A value produced by the iterator.
    Return(T),
    /// Exhaustion, with an optional value attached to the stop signal.
    StopIteration(Option<PyObjectRef>),
}
155
156unsafe impl<T: Traverse> Traverse for PyIterReturn<T> {
157    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
158        match self {
159            Self::Return(r) => r.traverse(tracer_fn),
160            Self::StopIteration(Some(obj)) => obj.traverse(tracer_fn),
161            _ => (),
162        }
163    }
164}
165
166impl PyIterReturn {
167    pub fn from_pyresult(result: PyResult, vm: &VirtualMachine) -> PyResult<Self> {
168        match result {
169            Ok(obj) => Ok(Self::Return(obj)),
170            Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
171                let args = err.get_arg(0);
172                Ok(Self::StopIteration(args))
173            }
174            Err(err) => Err(err),
175        }
176    }
177
178    pub fn from_getitem_result(result: PyResult, vm: &VirtualMachine) -> PyResult<Self> {
179        match result {
180            Ok(obj) => Ok(Self::Return(obj)),
181            Err(err) if err.fast_isinstance(vm.ctx.exceptions.index_error) => {
182                Ok(Self::StopIteration(None))
183            }
184            Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
185                let args = err.get_arg(0);
186                Ok(Self::StopIteration(args))
187            }
188            Err(err) => Err(err),
189        }
190    }
191
192    pub fn into_async_pyresult(self, vm: &VirtualMachine) -> PyResult {
193        match self {
194            Self::Return(obj) => Ok(obj),
195            Self::StopIteration(v) => Err({
196                let args = if let Some(v) = v { vec![v] } else { Vec::new() };
197                vm.new_exception(vm.ctx.exceptions.stop_async_iteration.to_owned(), args)
198            }),
199        }
200    }
201}
202
203impl ToPyResult for PyIterReturn {
204    fn to_pyresult(self, vm: &VirtualMachine) -> PyResult {
205        match self {
206            Self::Return(obj) => Ok(obj),
207            Self::StopIteration(v) => Err(vm.new_stop_iteration(v)),
208        }
209    }
210}
211
212impl ToPyResult for PyResult<PyIterReturn> {
213    fn to_pyresult(self, vm: &VirtualMachine) -> PyResult {
214        self?.to_pyresult(vm)
215    }
216}
217
218// Typical rust `Iter` object for `PyIter`
219pub struct PyIterIter<'a, T, O = PyObjectRef>
220where
221    O: Borrow<PyObject>,
222{
223    vm: &'a VirtualMachine,
224    obj: O, // creating PyIter<O> is zero-cost
225    length_hint: Option<usize>,
226    _phantom: core::marker::PhantomData<T>,
227}
228
229unsafe impl<T, O> Traverse for PyIterIter<'_, T, O>
230where
231    O: Traverse + Borrow<PyObject>,
232{
233    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
234        self.obj.traverse(tracer_fn)
235    }
236}
237
238impl<'a, T, O> PyIterIter<'a, T, O>
239where
240    O: Borrow<PyObject>,
241{
242    pub const fn new(vm: &'a VirtualMachine, obj: O, length_hint: Option<usize>) -> Self {
243        Self {
244            vm,
245            obj,
246            length_hint,
247            _phantom: core::marker::PhantomData,
248        }
249    }
250}
251
252impl<T, O> Iterator for PyIterIter<'_, T, O>
253where
254    T: TryFromObject,
255    O: Borrow<PyObject>,
256{
257    type Item = PyResult<T>;
258
259    fn next(&mut self) -> Option<Self::Item> {
260        let imp = |next: PyResult<PyIterReturn>| -> PyResult<Option<T>> {
261            let Some(obj) = next?.into_result().ok() else {
262                return Ok(None);
263            };
264            Ok(Some(T::try_from_object(self.vm, obj)?))
265        };
266        let next = PyIter::new(self.obj.borrow()).next(self.vm);
267        imp(next).transpose()
268    }
269
270    #[inline]
271    fn size_hint(&self) -> (usize, Option<usize>) {
272        (self.length_hint.unwrap_or(0), self.length_hint)
273    }
274}
275
/// Unwraps a `PyIterReturn` inside an iterator implementation, returning early on
/// exhaustion.
///
/// Evaluates to `obj` for `PyIterReturn::Return(obj)`. For
/// `PyIterReturn::StopIteration(v)` it executes
/// `return Ok(PyIterReturn::StopIteration(v))` from the *enclosing* function,
/// so it may only be used where that return type fits, typically
/// `PyResult<PyIterReturn>`. Despite the name, no exception object is created;
/// the stop signal is simply forwarded.
#[macro_export]
macro_rules! raise_if_stop {
    ($input:expr) => {
        match $input {
            $crate::protocol::PyIterReturn::StopIteration(v) => {
                return Ok($crate::protocol::PyIterReturn::StopIteration(v))
            }
            $crate::protocol::PyIterReturn::Return(obj) => obj,
        }
    };
}