rustpython_vm/protocol/
mapping.rs

1use crate::{
2    builtins::{
3        dict::{PyDictItems, PyDictKeys, PyDictValues},
4        type_::PointerSlot,
5        PyDict, PyStrInterned,
6    },
7    convert::ToPyResult,
8    object::{Traverse, TraverseFn},
9    AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine,
10};
11use crossbeam_utils::atomic::AtomicCell;
12
13// Mapping protocol
14// https://docs.python.org/3/c-api/mapping.html
15
16impl PyObject {
17    pub fn to_mapping(&self) -> PyMapping<'_> {
18        PyMapping::from(self)
19    }
20}
21
22#[allow(clippy::type_complexity)]
23#[derive(Default)]
24pub struct PyMappingMethods {
25    pub length: AtomicCell<Option<fn(PyMapping, &VirtualMachine) -> PyResult<usize>>>,
26    pub subscript: AtomicCell<Option<fn(PyMapping, &PyObject, &VirtualMachine) -> PyResult>>,
27    pub ass_subscript: AtomicCell<
28        Option<fn(PyMapping, &PyObject, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>>,
29    >,
30}
31
32impl std::fmt::Debug for PyMappingMethods {
33    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34        write!(f, "mapping methods")
35    }
36}
37
38impl PyMappingMethods {
39    fn check(&self) -> bool {
40        self.subscript.load().is_some()
41    }
42
43    #[allow(clippy::declare_interior_mutable_const)]
44    pub const NOT_IMPLEMENTED: PyMappingMethods = PyMappingMethods {
45        length: AtomicCell::new(None),
46        subscript: AtomicCell::new(None),
47        ass_subscript: AtomicCell::new(None),
48    };
49}
50
51impl<'a> From<&'a PyObject> for PyMapping<'a> {
52    fn from(obj: &'a PyObject) -> Self {
53        static GLOBAL_NOT_IMPLEMENTED: PyMappingMethods = PyMappingMethods::NOT_IMPLEMENTED;
54        let methods = Self::find_methods(obj)
55            .map_or(&GLOBAL_NOT_IMPLEMENTED, |x| unsafe { x.borrow_static() });
56        Self { obj, methods }
57    }
58}
59
60#[derive(Copy, Clone)]
61pub struct PyMapping<'a> {
62    pub obj: &'a PyObject,
63    pub methods: &'static PyMappingMethods,
64}
65
66unsafe impl Traverse for PyMapping<'_> {
67    fn traverse(&self, tracer_fn: &mut TraverseFn) {
68        self.obj.traverse(tracer_fn)
69    }
70}
71
72impl AsRef<PyObject> for PyMapping<'_> {
73    #[inline(always)]
74    fn as_ref(&self) -> &PyObject {
75        self.obj
76    }
77}
78
79impl<'a> PyMapping<'a> {
80    pub fn try_protocol(obj: &'a PyObject, vm: &VirtualMachine) -> PyResult<Self> {
81        if let Some(methods) = Self::find_methods(obj) {
82            if methods.as_ref().check() {
83                return Ok(Self {
84                    obj,
85                    methods: unsafe { methods.borrow_static() },
86                });
87            }
88        }
89
90        Err(vm.new_type_error(format!("{} is not a mapping object", obj.class())))
91    }
92}
93
94impl PyMapping<'_> {
95    // PyMapping::Check
96    #[inline]
97    pub fn check(obj: &PyObject) -> bool {
98        Self::find_methods(obj).map_or(false, |x| x.as_ref().check())
99    }
100
101    pub fn find_methods(obj: &PyObject) -> Option<PointerSlot<PyMappingMethods>> {
102        obj.class().mro_find_map(|cls| cls.slots.as_mapping.load())
103    }
104
105    pub fn length_opt(self, vm: &VirtualMachine) -> Option<PyResult<usize>> {
106        self.methods.length.load().map(|f| f(self, vm))
107    }
108
109    pub fn length(self, vm: &VirtualMachine) -> PyResult<usize> {
110        self.length_opt(vm).ok_or_else(|| {
111            vm.new_type_error(format!(
112                "object of type '{}' has no len() or not a mapping",
113                self.obj.class()
114            ))
115        })?
116    }
117
118    pub fn subscript(self, needle: &impl AsObject, vm: &VirtualMachine) -> PyResult {
119        self._subscript(needle.as_object(), vm)
120    }
121
122    pub fn ass_subscript(
123        self,
124        needle: &impl AsObject,
125        value: Option<PyObjectRef>,
126        vm: &VirtualMachine,
127    ) -> PyResult<()> {
128        self._ass_subscript(needle.as_object(), value, vm)
129    }
130
131    fn _subscript(self, needle: &PyObject, vm: &VirtualMachine) -> PyResult {
132        let f =
133            self.methods.subscript.load().ok_or_else(|| {
134                vm.new_type_error(format!("{} is not a mapping", self.obj.class()))
135            })?;
136        f(self, needle, vm)
137    }
138
139    fn _ass_subscript(
140        self,
141        needle: &PyObject,
142        value: Option<PyObjectRef>,
143        vm: &VirtualMachine,
144    ) -> PyResult<()> {
145        let f = self.methods.ass_subscript.load().ok_or_else(|| {
146            vm.new_type_error(format!(
147                "'{}' object does not support item assignment",
148                self.obj.class()
149            ))
150        })?;
151        f(self, needle, value, vm)
152    }
153
154    pub fn keys(self, vm: &VirtualMachine) -> PyResult {
155        if let Some(dict) = self.obj.downcast_ref_if_exact::<PyDict>(vm) {
156            PyDictKeys::new(dict.to_owned()).to_pyresult(vm)
157        } else {
158            self.method_output_as_list(identifier!(vm, keys), vm)
159        }
160    }
161
162    pub fn values(self, vm: &VirtualMachine) -> PyResult {
163        if let Some(dict) = self.obj.downcast_ref_if_exact::<PyDict>(vm) {
164            PyDictValues::new(dict.to_owned()).to_pyresult(vm)
165        } else {
166            self.method_output_as_list(identifier!(vm, values), vm)
167        }
168    }
169
170    pub fn items(self, vm: &VirtualMachine) -> PyResult {
171        if let Some(dict) = self.obj.downcast_ref_if_exact::<PyDict>(vm) {
172            PyDictItems::new(dict.to_owned()).to_pyresult(vm)
173        } else {
174            self.method_output_as_list(identifier!(vm, items), vm)
175        }
176    }
177
178    fn method_output_as_list(
179        self,
180        method_name: &'static PyStrInterned,
181        vm: &VirtualMachine,
182    ) -> PyResult {
183        let meth_output = vm.call_method(self.obj, method_name.as_str(), ())?;
184        if meth_output.is(vm.ctx.types.list_type) {
185            return Ok(meth_output);
186        }
187
188        let iter = meth_output.get_iter(vm).map_err(|_| {
189            vm.new_type_error(format!(
190                "{}.{}() returned a non-iterable (type {})",
191                self.obj.class(),
192                method_name.as_str(),
193                meth_output.class()
194            ))
195        })?;
196
197        // TODO
198        // PySequence::from(&iter).list(vm).map(|x| x.into())
199        vm.ctx.new_list(iter.try_to_value(vm)?).to_pyresult(vm)
200    }
201}