Skip to main content

rustpython_vm/builtins/
iter.rs

1/*
2 * iterator types
3 */
4
5use super::{PyInt, PyTupleRef, PyType};
6use crate::{
7    Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
8    class::PyClassImpl,
9    function::ArgCallable,
10    object::{Traverse, TraverseFn},
11    protocol::PyIterReturn,
12    types::{IterNext, Iterable, SelfIter},
13};
14use rustpython_common::lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard};
15
/// Marks status of iterator.
///
/// `Active` carries the value of type `T` the iterator keeps while it can
/// still produce items; `Exhausted` carries nothing.
#[derive(Debug, Clone)]
pub enum IterStatus<T> {
    /// Iterator hasn't raised StopIteration.
    Active(T),
    /// Iterator has raised StopIteration.
    Exhausted,
}
24
25unsafe impl<T: Traverse> Traverse for IterStatus<T> {
26    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
27        match self {
28            Self::Active(r) => r.traverse(tracer_fn),
29            Self::Exhausted => (),
30        }
31    }
32}
33
/// Shared state for iterators that walk an object by position.
///
/// Pairs the iterator's [`IterStatus`] with a `usize` cursor. The methods on
/// this type advance the cursor forwards (`next`) or backwards (`rev_next`).
#[derive(Debug)]
pub struct PositionIterInternal<T> {
    /// Whether the iterator is still live, and if so the object it reads from.
    pub status: IterStatus<T>,
    /// Cursor handed to the item-fetching closure on the next step.
    pub position: usize,
}
39
40unsafe impl<T: Traverse> Traverse for PositionIterInternal<T> {
41    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
42        self.status.traverse(tracer_fn)
43    }
44}
45
46impl<T> PositionIterInternal<T> {
47    pub const fn new(obj: T, position: usize) -> Self {
48        Self {
49            status: IterStatus::Active(obj),
50            position,
51        }
52    }
53
54    pub fn set_state<F>(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()>
55    where
56        F: FnOnce(&T, usize) -> usize,
57    {
58        if let IterStatus::Active(obj) = &self.status {
59            if let Some(i) = state.downcast_ref::<PyInt>() {
60                let i = i.try_to_primitive(vm).unwrap_or(0);
61                self.position = f(obj, i);
62                Ok(())
63            } else {
64                Err(vm.new_type_error("an integer is required."))
65            }
66        } else {
67            Ok(())
68        }
69    }
70
71    /// Build a pickle-compatible reduce tuple.
72    ///
73    /// `func` must be resolved **before** acquiring any lock that guards this
74    /// `PositionIterInternal`, so that the builtins lookup cannot trigger
75    /// reentrant iterator access and deadlock.
76    pub fn reduce<F, E>(
77        &self,
78        func: PyObjectRef,
79        active: F,
80        empty: E,
81        vm: &VirtualMachine,
82    ) -> PyTupleRef
83    where
84        F: FnOnce(&T) -> PyObjectRef,
85        E: FnOnce(&VirtualMachine) -> PyObjectRef,
86    {
87        if let IterStatus::Active(obj) = &self.status {
88            vm.new_tuple((func, (active(obj),), self.position))
89        } else {
90            vm.new_tuple((func, (empty(vm),)))
91        }
92    }
93
94    fn _next<F, OP>(&mut self, f: F, op: OP) -> PyResult<PyIterReturn>
95    where
96        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
97        OP: FnOnce(&mut Self),
98    {
99        if let IterStatus::Active(obj) = &self.status {
100            let ret = f(obj, self.position);
101            if let Ok(PyIterReturn::Return(_)) = ret {
102                op(self);
103            } else {
104                self.status = IterStatus::Exhausted;
105            }
106            ret
107        } else {
108            Ok(PyIterReturn::StopIteration(None))
109        }
110    }
111
112    pub fn next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
113    where
114        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
115    {
116        self._next(f, |zelf| zelf.position += 1)
117    }
118
119    pub fn rev_next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
120    where
121        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
122    {
123        self._next(f, |zelf| {
124            if zelf.position == 0 {
125                zelf.status = IterStatus::Exhausted;
126            } else {
127                zelf.position -= 1;
128            }
129        })
130    }
131
132    pub fn length_hint<F>(&self, f: F) -> usize
133    where
134        F: FnOnce(&T) -> usize,
135    {
136        if let IterStatus::Active(obj) = &self.status {
137            f(obj).saturating_sub(self.position)
138        } else {
139            0
140        }
141    }
142
143    pub fn rev_length_hint<F>(&self, f: F) -> usize
144    where
145        F: FnOnce(&T) -> usize,
146    {
147        if let IterStatus::Active(obj) = &self.status
148            && self.position <= f(obj)
149        {
150            return self.position + 1;
151        }
152        0
153    }
154}
155
156pub fn builtins_iter(vm: &VirtualMachine) -> PyObjectRef {
157    vm.builtins.get_attr("iter", vm).unwrap()
158}
159
160pub fn builtins_reversed(vm: &VirtualMachine) -> PyObjectRef {
161    vm.builtins.get_attr("reversed", vm).unwrap()
162}
163
// Python-visible `iterator` type: walks an object by integer index, asking
// it for the item at the current position on every step.
#[pyclass(module = false, name = "iterator", traverse)]
#[derive(Debug)]
pub struct PySequenceIterator {
    // Iterated object plus cursor, guarded by a mutex so the state can be
    // updated through `&self`.
    internal: PyMutex<PositionIterInternal<PyObjectRef>>,
}
169
170impl PyPayload for PySequenceIterator {
171    #[inline]
172    fn class(ctx: &Context) -> &'static Py<PyType> {
173        ctx.types.iter_type
174    }
175}
176
177#[pyclass(with(IterNext, Iterable))]
178impl PySequenceIterator {
179    pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
180        let _seq = obj.try_sequence(vm)?;
181        Ok(Self {
182            internal: PyMutex::new(PositionIterInternal::new(obj, 0)),
183        })
184    }
185
186    #[pymethod]
187    fn __length_hint__(&self, vm: &VirtualMachine) -> PyObjectRef {
188        let internal = self.internal.lock();
189        if let IterStatus::Active(obj) = &internal.status {
190            let seq = obj.sequence_unchecked();
191            seq.length(vm)
192                .map(|x| PyInt::from(x).into_pyobject(vm))
193                .unwrap_or_else(|_| vm.ctx.not_implemented())
194        } else {
195            PyInt::from(0).into_pyobject(vm)
196        }
197    }
198
199    #[pymethod]
200    fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
201        let func = builtins_iter(vm);
202        self.internal.lock().reduce(
203            func,
204            |x| x.clone(),
205            |vm| vm.ctx.empty_tuple.clone().into(),
206            vm,
207        )
208    }
209
210    #[pymethod]
211    fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
212        self.internal.lock().set_state(state, |_, pos| pos, vm)
213    }
214}
215
// Marker impl with no overrides. NOTE(review): presumably this is what
// supplies the `Iterable` behaviour listed in `#[pyclass(with(...))]` by
// returning the iterator itself -- confirm against `crate::types::SelfIter`.
impl SelfIter for PySequenceIterator {}
217impl IterNext for PySequenceIterator {
218    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
219        zelf.internal.lock().next(|obj, pos| {
220            let seq = obj.sequence_unchecked();
221            PyIterReturn::from_getitem_result(seq.get_item(pos as isize, vm), vm)
222        })
223    }
224}
225
// Python-visible `callable_iterator` type: each step invokes a callable
// with no arguments and stops once the result compares equal to `sentinel`.
#[pyclass(module = false, name = "callable_iterator", traverse)]
#[derive(Debug)]
pub struct PyCallableIterator {
    // Value that ends the iteration when the callable returns something equal to it.
    sentinel: PyObjectRef,
    // The callable while the iterator is live; `Exhausted` afterwards.
    status: PyRwLock<IterStatus<ArgCallable>>,
}
232
233impl PyPayload for PyCallableIterator {
234    #[inline]
235    fn class(ctx: &Context) -> &'static Py<PyType> {
236        ctx.types.callable_iterator
237    }
238}
239
240#[pyclass(with(IterNext, Iterable))]
241impl PyCallableIterator {
242    pub const fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self {
243        Self {
244            sentinel,
245            status: PyRwLock::new(IterStatus::Active(callable)),
246        }
247    }
248
249    #[pymethod]
250    fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
251        let func = builtins_iter(vm);
252        let status = self.status.read();
253        if let IterStatus::Active(callable) = &*status {
254            let callable_obj: PyObjectRef = callable.clone().into();
255            vm.new_tuple((func, (callable_obj, self.sentinel.clone())))
256        } else {
257            vm.new_tuple((func, (vm.ctx.empty_tuple.clone(),)))
258        }
259    }
260}
261
// Marker impl with no overrides. NOTE(review): presumably this is what
// supplies the `Iterable` behaviour listed in `#[pyclass(with(...))]` by
// returning the iterator itself -- confirm against `crate::types::SelfIter`.
impl SelfIter for PyCallableIterator {}
263impl IterNext for PyCallableIterator {
264    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
265        // Clone the callable and release the lock before invoking,
266        // so that reentrant next() calls don't deadlock.
267        let callable = {
268            let status = zelf.status.read();
269            match &*status {
270                IterStatus::Active(callable) => callable.clone(),
271                IterStatus::Exhausted => return Ok(PyIterReturn::StopIteration(None)),
272            }
273        };
274
275        let ret = callable.invoke((), vm)?;
276
277        // Re-check: a reentrant call may have exhausted the iterator.
278        let status = zelf.status.upgradable_read();
279        if !matches!(&*status, IterStatus::Active(_)) {
280            return Ok(PyIterReturn::StopIteration(None));
281        }
282
283        if vm.bool_eq(&ret, &zelf.sentinel)? {
284            *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted;
285            Ok(PyIterReturn::StopIteration(None))
286        } else {
287            Ok(PyIterReturn::Return(ret))
288        }
289    }
290}
291
292pub fn init(context: &'static Context) {
293    PySequenceIterator::extend_class(context, context.types.iter_type);
294    PyCallableIterator::extend_class(context, context.types.callable_iterator);
295}