Skip to main content

rustpython_vm/builtins/
filter.rs

use super::{PyType, PyTypeRef};
use crate::{
    AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
    class::PyClassImpl,
    protocol::{PyIter, PyIterReturn},
    raise_if_stop,
    types::{Constructor, IterNext, Iterable, SelfIter},
};
9
10#[pyclass(module = false, name = "filter", traverse)]
11#[derive(Debug)]
12pub struct PyFilter {
13    predicate: PyObjectRef,
14    iterator: PyIter,
15}
16
17impl PyPayload for PyFilter {
18    #[inline]
19    fn class(ctx: &Context) -> &'static Py<PyType> {
20        ctx.types.filter_type
21    }
22}
23
24impl Constructor for PyFilter {
25    type Args = (PyObjectRef, PyIter);
26
27    fn py_new(
28        _cls: &Py<PyType>,
29        (function, iterator): Self::Args,
30        _vm: &VirtualMachine,
31    ) -> PyResult<Self> {
32        Ok(Self {
33            predicate: function,
34            iterator,
35        })
36    }
37}
38
39#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
40impl PyFilter {
41    #[pymethod]
42    fn __reduce__(&self, vm: &VirtualMachine) -> (PyTypeRef, (PyObjectRef, PyIter)) {
43        (
44            vm.ctx.types.filter_type.to_owned(),
45            (self.predicate.clone(), self.iterator.clone()),
46        )
47    }
48}
49
50impl SelfIter for PyFilter {}
51
52impl IterNext for PyFilter {
53    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
54        let predicate = &zelf.predicate;
55        loop {
56            let next_obj = raise_if_stop!(zelf.iterator.next(vm)?);
57            let predicate_value = if vm.is_none(predicate) {
58                next_obj.clone()
59            } else {
60                // the predicate itself can raise StopIteration which does stop the filter iteration
61                raise_if_stop!(PyIterReturn::from_pyresult(
62                    predicate.call((next_obj.clone(),), vm),
63                    vm
64                )?)
65            };
66
67            if predicate_value.try_to_bool(vm)? {
68                return Ok(PyIterReturn::Return(next_obj));
69            }
70        }
71    }
72}
73
74pub fn init(context: &'static Context) {
75    PyFilter::extend_class(context, context.types.filter_type);
76}