Skip to main content

rustpython_vm/builtins/
weakproxy.rs

1use super::{PyStr, PyStrRef, PyType, PyTypeRef, PyWeak};
2use crate::common::lock::LazyLock;
3use crate::{
4    Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
5    class::PyClassImpl,
6    common::hash::PyHash,
7    function::{OptionalArg, PyComparisonValue, PySetterValue},
8    protocol::{PyIter, PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
9    stdlib::builtins::reversed,
10    types::{
11        AsMapping, AsNumber, AsSequence, Comparable, Constructor, GetAttr, Hashable, IterNext,
12        Iterable, PyComparisonOp, Representable, SetAttr,
13    },
14};
15
16#[pyclass(module = false, name = "weakproxy", unhashable = true, traverse)]
17#[derive(Debug)]
18pub struct PyWeakProxy {
19    weak: PyRef<PyWeak>,
20}
21
22impl PyPayload for PyWeakProxy {
23    #[inline]
24    fn class(ctx: &Context) -> &'static Py<PyType> {
25        ctx.types.weakproxy_type
26    }
27}
28
29#[derive(FromArgs)]
30pub struct WeakProxyNewArgs {
31    #[pyarg(positional)]
32    referent: PyObjectRef,
33    #[pyarg(positional, optional)]
34    callback: OptionalArg<PyObjectRef>,
35}
36
37impl Constructor for PyWeakProxy {
38    type Args = WeakProxyNewArgs;
39
40    fn py_new(
41        _cls: &Py<PyType>,
42        Self::Args { referent, callback }: Self::Args,
43        vm: &VirtualMachine,
44    ) -> PyResult<Self> {
45        // using an internal subclass as the class prevents us from getting the generic weakref,
46        // which would mess up the weakref count
47        let weak_cls = WEAK_SUBCLASS.get_or_init(|| {
48            vm.ctx.new_class(
49                None,
50                "__weakproxy",
51                vm.ctx.types.weakref_type.to_owned(),
52                super::PyWeak::make_slots(),
53            )
54        });
55        // TODO: PyWeakProxy should use the same payload as PyWeak
56        Ok(Self {
57            weak: referent.downgrade_with_typ(callback.into_option(), weak_cls.clone(), vm)?,
58        })
59    }
60}
61
62crate::common::static_cell! {
63    static WEAK_SUBCLASS: PyTypeRef;
64}
65
66#[pyclass(with(
67    GetAttr,
68    SetAttr,
69    Constructor,
70    Comparable,
71    AsNumber,
72    AsSequence,
73    AsMapping,
74    Representable,
75    IterNext
76))]
77impl PyWeakProxy {
78    fn try_upgrade(&self, vm: &VirtualMachine) -> PyResult {
79        self.weak.upgrade().ok_or_else(|| new_reference_error(vm))
80    }
81
82    #[pymethod]
83    fn __str__(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
84        zelf.try_upgrade(vm)?.str(vm)
85    }
86
87    fn len(&self, vm: &VirtualMachine) -> PyResult<usize> {
88        self.try_upgrade(vm)?.length(vm)
89    }
90
91    #[pymethod]
92    fn __bytes__(&self, vm: &VirtualMachine) -> PyResult {
93        self.try_upgrade(vm)?.bytes(vm)
94    }
95
96    #[pymethod]
97    fn __reversed__(&self, vm: &VirtualMachine) -> PyResult {
98        let obj = self.try_upgrade(vm)?;
99        reversed(obj, vm)
100    }
101    fn __contains__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
102        self.try_upgrade(vm)?
103            .sequence_unchecked()
104            .contains(&needle, vm)
105    }
106
107    fn getitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
108        let obj = self.try_upgrade(vm)?;
109        obj.get_item(&*needle, vm)
110    }
111
112    fn setitem(
113        &self,
114        needle: PyObjectRef,
115        value: PyObjectRef,
116        vm: &VirtualMachine,
117    ) -> PyResult<()> {
118        let obj = self.try_upgrade(vm)?;
119        obj.set_item(&*needle, value, vm)
120    }
121
122    fn delitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
123        let obj = self.try_upgrade(vm)?;
124        obj.del_item(&*needle, vm)
125    }
126}
127
128impl Iterable for PyWeakProxy {
129    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
130        let obj = zelf.try_upgrade(vm)?;
131        Ok(obj.get_iter(vm)?.into())
132    }
133}
134
135impl IterNext for PyWeakProxy {
136    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
137        let obj = zelf.try_upgrade(vm)?;
138        if obj.class().slots.iternext.load().is_none() {
139            return Err(vm.new_type_error("Weakref proxy referenced a non-iterator".to_owned()));
140        }
141        PyIter::new(obj).next(vm)
142    }
143}
144
145fn new_reference_error(vm: &VirtualMachine) -> PyRef<super::PyBaseException> {
146    vm.new_exception_msg(
147        vm.ctx.exceptions.reference_error.to_owned(),
148        "weakly-referenced object no longer exists".into(),
149    )
150}
151
152impl GetAttr for PyWeakProxy {
153    // TODO: callbacks
154    fn getattro(zelf: &Py<Self>, name: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
155        let obj = zelf.try_upgrade(vm)?;
156        obj.get_attr(name, vm)
157    }
158}
159
160impl SetAttr for PyWeakProxy {
161    fn setattro(
162        zelf: &Py<Self>,
163        attr_name: &Py<PyStr>,
164        value: PySetterValue,
165        vm: &VirtualMachine,
166    ) -> PyResult<()> {
167        let obj = zelf.try_upgrade(vm)?;
168        obj.call_set_attr(vm, attr_name, value)
169    }
170}
171
172fn proxy_upgrade(obj: &PyObject, vm: &VirtualMachine) -> PyResult {
173    obj.downcast_ref::<PyWeakProxy>()
174        .expect("proxy_upgrade called on non-PyWeakProxy object")
175        .try_upgrade(vm)
176}
177
178fn proxy_upgrade_opt(obj: &PyObject, vm: &VirtualMachine) -> PyResult<Option<PyObjectRef>> {
179    match obj.downcast_ref::<PyWeakProxy>() {
180        Some(proxy) => Ok(Some(proxy.try_upgrade(vm)?)),
181        None => Ok(None),
182    }
183}
184
185fn proxy_unary_op(
186    obj: &PyObject,
187    vm: &VirtualMachine,
188    op: fn(&VirtualMachine, &PyObject) -> PyResult,
189) -> PyResult {
190    let upgraded = proxy_upgrade(obj, vm)?;
191    op(vm, &upgraded)
192}
193
// Expands to a unary number-slot closure: the proxy is unwrapped and
// `VirtualMachine::$vm_method` is applied to its referent.
macro_rules! proxy_unary_slot {
    ($vm_method:ident) => {
        Some(|number, vm| {
            proxy_unary_op(number.obj, vm, |vm, operand| vm.$vm_method(operand))
        })
    };
}
199
200fn proxy_binary_op(
201    a: &PyObject,
202    b: &PyObject,
203    vm: &VirtualMachine,
204    op: fn(&VirtualMachine, &PyObject, &PyObject) -> PyResult,
205) -> PyResult {
206    let a_up = proxy_upgrade_opt(a, vm)?;
207    let b_up = proxy_upgrade_opt(b, vm)?;
208    let a_ref = a_up.as_deref().unwrap_or(a);
209    let b_ref = b_up.as_deref().unwrap_or(b);
210    op(vm, a_ref, b_ref)
211}
212
// Expands to a binary number-slot closure that unwraps proxy operands and then
// dispatches to `VirtualMachine::$vm_method`.
macro_rules! proxy_binary_slot {
    ($vm_method:ident) => {
        Some(|lhs, rhs, vm| {
            proxy_binary_op(lhs, rhs, vm, |vm, lhs, rhs| vm.$vm_method(lhs, rhs))
        })
    };
}
218
219fn proxy_ternary_op(
220    a: &PyObject,
221    b: &PyObject,
222    c: &PyObject,
223    vm: &VirtualMachine,
224    op: fn(&VirtualMachine, &PyObject, &PyObject, &PyObject) -> PyResult,
225) -> PyResult {
226    let a_up = proxy_upgrade_opt(a, vm)?;
227    let b_up = proxy_upgrade_opt(b, vm)?;
228    let c_up = proxy_upgrade_opt(c, vm)?;
229    let a_ref = a_up.as_deref().unwrap_or(a);
230    let b_ref = b_up.as_deref().unwrap_or(b);
231    let c_ref = c_up.as_deref().unwrap_or(c);
232    op(vm, a_ref, b_ref, c_ref)
233}
234
// Expands to a ternary number-slot closure that unwraps proxy operands and then
// dispatches to `VirtualMachine::$vm_method`.
macro_rules! proxy_ternary_slot {
    ($vm_method:ident) => {
        Some(|x, y, z, vm| {
            proxy_ternary_op(x, y, z, vm, |vm, x, y, z| vm.$vm_method(x, y, z))
        })
    };
}
240
241impl AsNumber for PyWeakProxy {
242    fn as_number() -> &'static PyNumberMethods {
243        static AS_NUMBER: LazyLock<PyNumberMethods> = LazyLock::new(|| PyNumberMethods {
244            boolean: Some(|number, vm| {
245                let obj = proxy_upgrade(number.obj, vm)?;
246                obj.is_true(vm)
247            }),
248            int: Some(|number, vm| {
249                let obj = proxy_upgrade(number.obj, vm)?;
250                obj.try_int(vm).map(Into::into)
251            }),
252            float: Some(|number, vm| {
253                let obj = proxy_upgrade(number.obj, vm)?;
254                obj.try_float(vm).map(Into::into)
255            }),
256            index: Some(|number, vm| {
257                let obj = proxy_upgrade(number.obj, vm)?;
258                obj.try_index(vm).map(Into::into)
259            }),
260            negative: proxy_unary_slot!(_neg),
261            positive: proxy_unary_slot!(_pos),
262            absolute: proxy_unary_slot!(_abs),
263            invert: proxy_unary_slot!(_invert),
264            add: proxy_binary_slot!(_add),
265            subtract: proxy_binary_slot!(_sub),
266            multiply: proxy_binary_slot!(_mul),
267            remainder: proxy_binary_slot!(_mod),
268            divmod: proxy_binary_slot!(_divmod),
269            lshift: proxy_binary_slot!(_lshift),
270            rshift: proxy_binary_slot!(_rshift),
271            and: proxy_binary_slot!(_and),
272            xor: proxy_binary_slot!(_xor),
273            or: proxy_binary_slot!(_or),
274            floor_divide: proxy_binary_slot!(_floordiv),
275            true_divide: proxy_binary_slot!(_truediv),
276            matrix_multiply: proxy_binary_slot!(_matmul),
277            inplace_add: proxy_binary_slot!(_iadd),
278            inplace_subtract: proxy_binary_slot!(_isub),
279            inplace_multiply: proxy_binary_slot!(_imul),
280            inplace_remainder: proxy_binary_slot!(_imod),
281            inplace_lshift: proxy_binary_slot!(_ilshift),
282            inplace_rshift: proxy_binary_slot!(_irshift),
283            inplace_and: proxy_binary_slot!(_iand),
284            inplace_xor: proxy_binary_slot!(_ixor),
285            inplace_or: proxy_binary_slot!(_ior),
286            inplace_floor_divide: proxy_binary_slot!(_ifloordiv),
287            inplace_true_divide: proxy_binary_slot!(_itruediv),
288            inplace_matrix_multiply: proxy_binary_slot!(_imatmul),
289            power: proxy_ternary_slot!(_pow),
290            inplace_power: proxy_ternary_slot!(_ipow),
291        });
292        &AS_NUMBER
293    }
294}
295
296impl Comparable for PyWeakProxy {
297    fn cmp(
298        zelf: &Py<Self>,
299        other: &PyObject,
300        op: PyComparisonOp,
301        vm: &VirtualMachine,
302    ) -> PyResult<PyComparisonValue> {
303        let obj = zelf.try_upgrade(vm)?;
304        Ok(PyComparisonValue::Implemented(
305            obj.rich_compare_bool(other, op, vm)?,
306        ))
307    }
308}
309
310impl AsSequence for PyWeakProxy {
311    fn as_sequence() -> &'static PySequenceMethods {
312        static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
313            length: atomic_func!(|seq, vm| PyWeakProxy::sequence_downcast(seq).len(vm)),
314            contains: atomic_func!(|seq, needle, vm| {
315                PyWeakProxy::sequence_downcast(seq).__contains__(needle.to_owned(), vm)
316            }),
317            ..PySequenceMethods::NOT_IMPLEMENTED
318        });
319        &AS_SEQUENCE
320    }
321}
322
323impl AsMapping for PyWeakProxy {
324    fn as_mapping() -> &'static PyMappingMethods {
325        static AS_MAPPING: PyMappingMethods = PyMappingMethods {
326            length: atomic_func!(|mapping, vm| PyWeakProxy::mapping_downcast(mapping).len(vm)),
327            subscript: atomic_func!(|mapping, needle, vm| {
328                PyWeakProxy::mapping_downcast(mapping).getitem(needle.to_owned(), vm)
329            }),
330            ass_subscript: atomic_func!(|mapping, needle, value, vm| {
331                let zelf = PyWeakProxy::mapping_downcast(mapping);
332                if let Some(value) = value {
333                    zelf.setitem(needle.to_owned(), value, vm)
334                } else {
335                    zelf.delitem(needle.to_owned(), vm)
336                }
337            }),
338        };
339        &AS_MAPPING
340    }
341}
342
343impl Representable for PyWeakProxy {
344    #[inline]
345    fn repr(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyStrRef> {
346        zelf.try_upgrade(vm)?.repr(vm)
347    }
348
349    #[cold]
350    fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
351        unreachable!("use repr instead")
352    }
353}
354
355pub fn init(context: &'static Context) {
356    PyWeakProxy::extend_class(context, context.types.weakproxy_type);
357}
358
359impl Hashable for PyWeakProxy {
360    fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
361        zelf.try_upgrade(vm)?.hash(vm)
362    }
363}