1use super::{PyInt, PyTupleRef, PyType};
6use crate::{
7 Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
8 class::PyClassImpl,
9 function::ArgCallable,
10 object::{Traverse, TraverseFn},
11 protocol::PyIterReturn,
12 types::{IterNext, Iterable, SelfIter},
13};
14use rustpython_common::lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard};
15
/// Lifecycle of an iterator: either still holding the object it iterates
/// over, or finished.
#[derive(Clone, Debug)]
pub enum IterStatus<T> {
    /// Iteration can continue; holds the object being iterated.
    Active(T),
    /// Iteration has ended; the source object is no longer held.
    Exhausted,
}
24
25unsafe impl<T: Traverse> Traverse for IterStatus<T> {
26 fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
27 match self {
28 Self::Active(r) => r.traverse(tracer_fn),
29 Self::Exhausted => (),
30 }
31 }
32}
33
34#[derive(Debug)]
35pub struct PositionIterInternal<T> {
36 pub status: IterStatus<T>,
37 pub position: usize,
38}
39
40unsafe impl<T: Traverse> Traverse for PositionIterInternal<T> {
41 fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
42 self.status.traverse(tracer_fn)
43 }
44}
45
46impl<T> PositionIterInternal<T> {
47 pub const fn new(obj: T, position: usize) -> Self {
48 Self {
49 status: IterStatus::Active(obj),
50 position,
51 }
52 }
53
54 pub fn set_state<F>(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()>
55 where
56 F: FnOnce(&T, usize) -> usize,
57 {
58 if let IterStatus::Active(obj) = &self.status {
59 if let Some(i) = state.downcast_ref::<PyInt>() {
60 let i = i.try_to_primitive(vm).unwrap_or(0);
61 self.position = f(obj, i);
62 Ok(())
63 } else {
64 Err(vm.new_type_error("an integer is required."))
65 }
66 } else {
67 Ok(())
68 }
69 }
70
71 pub fn reduce<F, E>(
77 &self,
78 func: PyObjectRef,
79 active: F,
80 empty: E,
81 vm: &VirtualMachine,
82 ) -> PyTupleRef
83 where
84 F: FnOnce(&T) -> PyObjectRef,
85 E: FnOnce(&VirtualMachine) -> PyObjectRef,
86 {
87 if let IterStatus::Active(obj) = &self.status {
88 vm.new_tuple((func, (active(obj),), self.position))
89 } else {
90 vm.new_tuple((func, (empty(vm),)))
91 }
92 }
93
94 fn _next<F, OP>(&mut self, f: F, op: OP) -> PyResult<PyIterReturn>
95 where
96 F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
97 OP: FnOnce(&mut Self),
98 {
99 if let IterStatus::Active(obj) = &self.status {
100 let ret = f(obj, self.position);
101 if let Ok(PyIterReturn::Return(_)) = ret {
102 op(self);
103 } else {
104 self.status = IterStatus::Exhausted;
105 }
106 ret
107 } else {
108 Ok(PyIterReturn::StopIteration(None))
109 }
110 }
111
112 pub fn next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
113 where
114 F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
115 {
116 self._next(f, |zelf| zelf.position += 1)
117 }
118
119 pub fn rev_next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
120 where
121 F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
122 {
123 self._next(f, |zelf| {
124 if zelf.position == 0 {
125 zelf.status = IterStatus::Exhausted;
126 } else {
127 zelf.position -= 1;
128 }
129 })
130 }
131
132 pub fn length_hint<F>(&self, f: F) -> usize
133 where
134 F: FnOnce(&T) -> usize,
135 {
136 if let IterStatus::Active(obj) = &self.status {
137 f(obj).saturating_sub(self.position)
138 } else {
139 0
140 }
141 }
142
143 pub fn rev_length_hint<F>(&self, f: F) -> usize
144 where
145 F: FnOnce(&T) -> usize,
146 {
147 if let IterStatus::Active(obj) = &self.status
148 && self.position <= f(obj)
149 {
150 return self.position + 1;
151 }
152 0
153 }
154}
155
156pub fn builtins_iter(vm: &VirtualMachine) -> PyObjectRef {
157 vm.builtins.get_attr("iter", vm).unwrap()
158}
159
160pub fn builtins_reversed(vm: &VirtualMachine) -> PyObjectRef {
161 vm.builtins.get_attr("reversed", vm).unwrap()
162}
163
// Python-level `iterator` type: walks an object through the sequence protocol,
// fetching items by increasing index with `get_item`.
// NOTE(review): plain `//` comments are used on purpose. `#[pyclass]` presumably
// turns `///` docs into the Python `__doc__`; confirm before converting.
#[pyclass(module = false, name = "iterator", traverse)]
#[derive(Debug)]
pub struct PySequenceIterator {
    // The sequence being iterated plus the index of the next item, behind a
    // mutex because `next` needs mutable access through a shared reference.
    internal: PyMutex<PositionIterInternal<PyObjectRef>>,
}
169
170impl PyPayload for PySequenceIterator {
171 #[inline]
172 fn class(ctx: &Context) -> &'static Py<PyType> {
173 ctx.types.iter_type
174 }
175}
176
177#[pyclass(with(IterNext, Iterable))]
178impl PySequenceIterator {
179 pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
180 let _seq = obj.try_sequence(vm)?;
181 Ok(Self {
182 internal: PyMutex::new(PositionIterInternal::new(obj, 0)),
183 })
184 }
185
186 #[pymethod]
187 fn __length_hint__(&self, vm: &VirtualMachine) -> PyObjectRef {
188 let internal = self.internal.lock();
189 if let IterStatus::Active(obj) = &internal.status {
190 let seq = obj.sequence_unchecked();
191 seq.length(vm)
192 .map(|x| PyInt::from(x).into_pyobject(vm))
193 .unwrap_or_else(|_| vm.ctx.not_implemented())
194 } else {
195 PyInt::from(0).into_pyobject(vm)
196 }
197 }
198
199 #[pymethod]
200 fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
201 let func = builtins_iter(vm);
202 self.internal.lock().reduce(
203 func,
204 |x| x.clone(),
205 |vm| vm.ctx.empty_tuple.clone().into(),
206 vm,
207 )
208 }
209
210 #[pymethod]
211 fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
212 self.internal.lock().set_state(state, |_, pos| pos, vm)
213 }
214}
215
216impl SelfIter for PySequenceIterator {}
217impl IterNext for PySequenceIterator {
218 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
219 zelf.internal.lock().next(|obj, pos| {
220 let seq = obj.sequence_unchecked();
221 PyIterReturn::from_getitem_result(seq.get_item(pos as isize, vm), vm)
222 })
223 }
224}
225
// Python-level `callable_iterator` type, the iterator produced by
// `iter(callable, sentinel)` (see `__reduce__`). It calls the callable with no
// arguments until the returned value equals `sentinel`.
// NOTE(review): plain `//` comments are used on purpose. `#[pyclass]` presumably
// turns `///` docs into the Python `__doc__`; confirm before converting.
#[pyclass(module = false, name = "callable_iterator", traverse)]
#[derive(Debug)]
pub struct PyCallableIterator {
    // Value that ends the iteration when the callable returns something equal to it.
    sentinel: PyObjectRef,
    // `Active(callable)` until the iteration ends, then `Exhausted`.
    status: PyRwLock<IterStatus<ArgCallable>>,
}
232
233impl PyPayload for PyCallableIterator {
234 #[inline]
235 fn class(ctx: &Context) -> &'static Py<PyType> {
236 ctx.types.callable_iterator
237 }
238}
239
240#[pyclass(with(IterNext, Iterable))]
241impl PyCallableIterator {
242 pub const fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self {
243 Self {
244 sentinel,
245 status: PyRwLock::new(IterStatus::Active(callable)),
246 }
247 }
248
249 #[pymethod]
250 fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
251 let func = builtins_iter(vm);
252 let status = self.status.read();
253 if let IterStatus::Active(callable) = &*status {
254 let callable_obj: PyObjectRef = callable.clone().into();
255 vm.new_tuple((func, (callable_obj, self.sentinel.clone())))
256 } else {
257 vm.new_tuple((func, (vm.ctx.empty_tuple.clone(),)))
258 }
259 }
260}
261
262impl SelfIter for PyCallableIterator {}
263impl IterNext for PyCallableIterator {
264 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
265 let callable = {
268 let status = zelf.status.read();
269 match &*status {
270 IterStatus::Active(callable) => callable.clone(),
271 IterStatus::Exhausted => return Ok(PyIterReturn::StopIteration(None)),
272 }
273 };
274
275 let ret = callable.invoke((), vm)?;
276
277 let status = zelf.status.upgradable_read();
279 if !matches!(&*status, IterStatus::Active(_)) {
280 return Ok(PyIterReturn::StopIteration(None));
281 }
282
283 if vm.bool_eq(&ret, &zelf.sentinel)? {
284 *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted;
285 Ok(PyIterReturn::StopIteration(None))
286 } else {
287 Ok(PyIterReturn::Return(ret))
288 }
289 }
290}
291
292pub fn init(context: &'static Context) {
293 PySequenceIterator::extend_class(context, context.types.iter_type);
294 PyCallableIterator::extend_class(context, context.types.callable_iterator);
295}