Skip to main content

rustpython_vm/protocol/
mapping.rs

1use crate::{
2    AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine,
3    builtins::{
4        PyDict, PyStrInterned,
5        dict::{PyDictItems, PyDictKeys, PyDictValues},
6    },
7    convert::ToPyResult,
8    object::{Traverse, TraverseFn},
9};
10use crossbeam_utils::atomic::AtomicCell;
11
12// Mapping protocol
13// https://docs.python.org/3/c-api/mapping.html
14
/// Per-type mapping-protocol slot table, reached through a type's
/// `slots.as_mapping` (see `PyMapping::slots`).
///
/// Every slot is wrapped in an `AtomicCell`, so it can be read (`load`) and replaced
/// (`store`) through a shared reference, as `copy_from` does. A slot holding `None`
/// means the type does not provide that operation; the derived `Default` leaves
/// every slot empty.
// The fn-pointer signatures are spelled out inline and trip clippy's
// `type_complexity` lint, hence the allow.
#[allow(clippy::type_complexity)]
#[derive(Default)]
pub struct PyMappingSlots {
    /// Length operation; invoked by `PyMapping::length_opt` / `PyMapping::length`.
    pub length: AtomicCell<Option<fn(PyMapping<'_>, &VirtualMachine) -> PyResult<usize>>>,
    /// Item lookup (`obj[key]`); invoked by `PyMapping::subscript`. Its presence is
    /// what `PyMapping::check` tests.
    pub subscript: AtomicCell<Option<fn(PyMapping<'_>, &PyObject, &VirtualMachine) -> PyResult>>,
    /// Item assignment; invoked by `PyMapping::ass_subscript`.
    /// NOTE(review): a `None` value presumably requests deletion — confirm with implementors.
    pub ass_subscript: AtomicCell<
        Option<fn(PyMapping<'_>, &PyObject, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
    >,
}
24
25impl core::fmt::Debug for PyMappingSlots {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        f.write_str("PyMappingSlots")
28    }
29}
30
31impl PyMappingSlots {
32    pub fn has_subscript(&self) -> bool {
33        self.subscript.load().is_some()
34    }
35
36    /// Copy from static PyMappingMethods
37    pub fn copy_from(&self, methods: &PyMappingMethods) {
38        if let Some(f) = methods.length {
39            self.length.store(Some(f));
40        }
41        if let Some(f) = methods.subscript {
42            self.subscript.store(Some(f));
43        }
44        if let Some(f) = methods.ass_subscript {
45            self.ass_subscript.store(Some(f));
46        }
47    }
48}
49
/// Plain (non-atomic) table of mapping operations a type can provide.
///
/// Its contents are installed into a type's `PyMappingSlots` by
/// `PyMappingSlots::copy_from`; a `None` entry means "not provided". The derived
/// `Default` is equivalent to `PyMappingMethods::NOT_IMPLEMENTED`.
// The fn-pointer signatures are spelled out inline and trip clippy's
// `type_complexity` lint, hence the allow.
#[allow(clippy::type_complexity)]
#[derive(Default)]
pub struct PyMappingMethods {
    /// Length operation; copied into `PyMappingSlots::length`.
    pub length: Option<fn(PyMapping<'_>, &VirtualMachine) -> PyResult<usize>>,
    /// Item lookup (`obj[key]`); copied into `PyMappingSlots::subscript`.
    pub subscript: Option<fn(PyMapping<'_>, &PyObject, &VirtualMachine) -> PyResult>,
    /// Item assignment; copied into `PyMappingSlots::ass_subscript`.
    /// NOTE(review): a `None` value presumably requests deletion — confirm with implementors.
    pub ass_subscript:
        Option<fn(PyMapping<'_>, &PyObject, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
}
58
59impl core::fmt::Debug for PyMappingMethods {
60    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
61        f.write_str("PyMappingMethods")
62    }
63}
64
65impl PyMappingMethods {
66    pub const NOT_IMPLEMENTED: Self = Self {
67        length: None,
68        subscript: None,
69        ass_subscript: None,
70    };
71}
72
73impl PyObject {
74    pub fn mapping_unchecked(&self) -> PyMapping<'_> {
75        PyMapping { obj: self }
76    }
77
78    pub fn try_mapping(&self, vm: &VirtualMachine) -> PyResult<PyMapping<'_>> {
79        let mapping = self.mapping_unchecked();
80        if mapping.check() {
81            Ok(mapping)
82        } else {
83            Err(vm.new_type_error(format!("{} is not a mapping object", self.class())))
84        }
85    }
86}
87
/// A borrowed view of a Python object through the mapping protocol.
///
/// Obtained from `PyObject::mapping_unchecked` or `PyObject::try_mapping`. The view is
/// `Copy` (it only holds a reference), which is why most of its methods take `self`
/// by value.
#[derive(Copy, Clone)]
pub struct PyMapping<'a> {
    /// The object being viewed as a mapping.
    pub obj: &'a PyObject,
}
92
// SAFETY: NOTE(review): `PyMapping` holds nothing but the `obj` reference, and traversal
// is delegated entirely to that object's own `traverse`; presumably this satisfies the
// `Traverse` contract — confirm against the trait's documented requirements.
unsafe impl Traverse for PyMapping<'_> {
    /// Forwards the traversal to the wrapped object.
    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
        self.obj.traverse(tracer_fn)
    }
}
98
99impl AsRef<PyObject> for PyMapping<'_> {
100    #[inline(always)]
101    fn as_ref(&self) -> &PyObject {
102        self.obj
103    }
104}
105
106impl PyMapping<'_> {
107    #[inline]
108    pub fn slots(&self) -> &PyMappingSlots {
109        &self.obj.class().slots.as_mapping
110    }
111
112    #[inline]
113    pub fn check(&self) -> bool {
114        self.slots().has_subscript()
115    }
116
117    pub fn length_opt(self, vm: &VirtualMachine) -> Option<PyResult<usize>> {
118        self.slots().length.load().map(|f| f(self, vm))
119    }
120
121    pub fn length(self, vm: &VirtualMachine) -> PyResult<usize> {
122        self.length_opt(vm).ok_or_else(|| {
123            vm.new_type_error(format!(
124                "object of type '{}' has no len() or not a mapping",
125                self.obj.class()
126            ))
127        })?
128    }
129
130    pub fn subscript(self, needle: &impl AsObject, vm: &VirtualMachine) -> PyResult {
131        self._subscript(needle.as_object(), vm)
132    }
133
134    pub fn ass_subscript(
135        self,
136        needle: &impl AsObject,
137        value: Option<PyObjectRef>,
138        vm: &VirtualMachine,
139    ) -> PyResult<()> {
140        self._ass_subscript(needle.as_object(), value, vm)
141    }
142
143    fn _subscript(self, needle: &PyObject, vm: &VirtualMachine) -> PyResult {
144        let f =
145            self.slots().subscript.load().ok_or_else(|| {
146                vm.new_type_error(format!("{} is not a mapping", self.obj.class()))
147            })?;
148        f(self, needle, vm)
149    }
150
151    fn _ass_subscript(
152        self,
153        needle: &PyObject,
154        value: Option<PyObjectRef>,
155        vm: &VirtualMachine,
156    ) -> PyResult<()> {
157        let f = self.slots().ass_subscript.load().ok_or_else(|| {
158            vm.new_type_error(format!(
159                "'{}' object does not support item assignment",
160                self.obj.class()
161            ))
162        })?;
163        f(self, needle, value, vm)
164    }
165
166    pub fn keys(self, vm: &VirtualMachine) -> PyResult {
167        if let Some(dict) = self.obj.downcast_ref_if_exact::<PyDict>(vm) {
168            PyDictKeys::new(dict.to_owned()).to_pyresult(vm)
169        } else {
170            self.method_output_as_list(identifier!(vm, keys), vm)
171        }
172    }
173
174    pub fn values(self, vm: &VirtualMachine) -> PyResult {
175        if let Some(dict) = self.obj.downcast_ref_if_exact::<PyDict>(vm) {
176            PyDictValues::new(dict.to_owned()).to_pyresult(vm)
177        } else {
178            self.method_output_as_list(identifier!(vm, values), vm)
179        }
180    }
181
182    pub fn items(self, vm: &VirtualMachine) -> PyResult {
183        if let Some(dict) = self.obj.downcast_ref_if_exact::<PyDict>(vm) {
184            PyDictItems::new(dict.to_owned()).to_pyresult(vm)
185        } else {
186            self.method_output_as_list(identifier!(vm, items), vm)
187        }
188    }
189
190    fn method_output_as_list(
191        self,
192        method_name: &'static PyStrInterned,
193        vm: &VirtualMachine,
194    ) -> PyResult {
195        let meth_output = vm.call_method(self.obj, method_name.as_str(), ())?;
196        if meth_output.is(vm.ctx.types.list_type) {
197            return Ok(meth_output);
198        }
199
200        let iter = meth_output.get_iter(vm).map_err(|_| {
201            vm.new_type_error(format!(
202                "{}.{}() returned a non-iterable (type {})",
203                self.obj.class(),
204                method_name.as_str(),
205                meth_output.class()
206            ))
207        })?;
208
209        // TODO
210        // PySequence::from(&iter).list(vm).map(|x| x.into())
211        vm.ctx.new_list(iter.try_to_value(vm)?).to_pyresult(vm)
212    }
213}