Skip to main content

rustpython_vm/builtins/map.rs

1use super::PyType;
2use crate::{
3    AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
4    builtins::PyTupleRef,
5    class::PyClassImpl,
6    function::{ArgIntoBool, OptionalArg, PosArgs},
7    protocol::{PyIter, PyIterReturn},
8    types::{Constructor, IterNext, Iterable, SelfIter},
9};
10use rustpython_common::atomic::{self, PyAtomic, Radium};
11
// The builtin `map` iterator object. Each step draws one item from every entry of
// `iterators` and passes them, in order, to `mapper`.
// NOTE: plain `//` comments are used here on purpose; `///` on a `#[pyclass]` item
// would become the Python-visible `__doc__`.
#[pyclass(module = false, name = "map", traverse)]
#[derive(Debug)]
pub struct PyMap {
    // Callable invoked with one item from each input iterator per step.
    mapper: PyObjectRef,
    // Input iterators, advanced left to right on every `next` call.
    iterators: Vec<PyIter>,
    // When true, inputs of unequal length raise `ValueError` instead of silently
    // stopping at the shortest one. Skipped by the `traverse` derive because it holds
    // no Python object references.
    #[pytraverse(skip)]
    strict: PyAtomic<bool>,
}
20
// Associates the `PyMap` payload with the interpreter's builtin `map` type object.
impl PyPayload for PyMap {
    /// Returns the `map` type object stored in the context's type table.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.map_type
    }
}
27
// Extra arguments accepted by `map(...)` after the callable and the iterables.
#[derive(FromArgs)]
pub struct PyMapNewArgs {
    // `strict=` passed by keyword (`pyarg(named)`); may be omitted, in which case the
    // constructor treats it as `false` (truncate at the shortest input).
    #[pyarg(named, optional)]
    strict: OptionalArg<bool>,
}
33
34impl Constructor for PyMap {
35    type Args = (PyObjectRef, PosArgs<PyIter>, PyMapNewArgs);
36
37    fn py_new(
38        _cls: &Py<PyType>,
39        (mapper, iterators, args): Self::Args,
40        _vm: &VirtualMachine,
41    ) -> PyResult<Self> {
42        let iterators = iterators.into_vec();
43        let strict = Radium::new(args.strict.unwrap_or(false));
44        Ok(Self {
45            mapper,
46            iterators,
47            strict,
48        })
49    }
50}
51
52#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
53impl PyMap {
54    #[pymethod]
55    fn __length_hint__(&self, vm: &VirtualMachine) -> PyResult<usize> {
56        self.iterators.iter().try_fold(0, |prev, cur| {
57            let cur = cur.as_ref().to_owned().length_hint(0, vm)?;
58            let max = core::cmp::max(prev, cur);
59            Ok(max)
60        })
61    }
62
63    #[pymethod]
64    fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
65        let cls = zelf.class().to_owned();
66        let mut vec = vec![zelf.mapper.clone()];
67        vec.extend(zelf.iterators.iter().map(|o| o.clone().into()));
68        let tuple_args = vm.ctx.new_tuple(vec);
69        Ok(if zelf.strict.load(atomic::Ordering::Acquire) {
70            vm.new_tuple((cls, tuple_args, true))
71        } else {
72            vm.new_tuple((cls, tuple_args))
73        })
74    }
75
76    #[pymethod]
77    fn __setstate__(zelf: PyRef<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
78        if let Ok(obj) = ArgIntoBool::try_from_object(vm, state) {
79            zelf.strict.store(obj.into(), atomic::Ordering::Release);
80        }
81        Ok(())
82    }
83}
84
// Marker impl: presumably makes `iter()` on a map object return the object itself
// (see the `SelfIter` trait) — confirm against its definition.
impl SelfIter for PyMap {}
86
87impl IterNext for PyMap {
88    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
89        let mut next_objs = Vec::new();
90        for (idx, iterator) in zelf.iterators.iter().enumerate() {
91            let item = match iterator.next(vm)? {
92                PyIterReturn::Return(obj) => obj,
93                PyIterReturn::StopIteration(v) => {
94                    if zelf.strict.load(atomic::Ordering::Acquire) {
95                        if idx > 0 {
96                            let plural = if idx == 1 { " " } else { "s 1-" };
97                            return Err(vm.new_value_error(format!(
98                                "map() argument {} is shorter than argument{}{}",
99                                idx + 1,
100                                plural,
101                                idx,
102                            )));
103                        }
104                        for (idx, iterator) in zelf.iterators[1..].iter().enumerate() {
105                            if let PyIterReturn::Return(_) = iterator.next(vm)? {
106                                let plural = if idx == 0 { " " } else { "s 1-" };
107                                return Err(vm.new_value_error(format!(
108                                    "map() argument {} is longer than argument{}{}",
109                                    idx + 2,
110                                    plural,
111                                    idx + 1,
112                                )));
113                            }
114                        }
115                    }
116                    return Ok(PyIterReturn::StopIteration(v));
117                }
118            };
119            next_objs.push(item);
120        }
121
122        // the mapper itself can raise StopIteration which does stop the map iteration
123        PyIterReturn::from_pyresult(zelf.mapper.call(next_objs, vm), vm)
124    }
125}
126
127pub fn init(context: &'static Context) {
128    PyMap::extend_class(context, context.types.map_type);
129}