Skip to main content

rustpython_vm/builtins/
object.rs

1use super::{PyDictRef, PyList, PyStr, PyStrRef, PyType, PyTypeRef, PyUtf8StrRef};
2use crate::common::hash::PyHash;
3use crate::types::PyTypeFlags;
4use crate::{
5    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
6    class::PyClassImpl,
7    convert::ToPyResult,
8    function::{Either, FuncArgs, PyArithmeticValue, PyComparisonValue, PySetterValue},
9    types::{Constructor, Initializer, PyComparisonOp},
10};
11use itertools::Itertools;
12
// NOTE(review): the `///` text below is presumably what `#[pyclass]` exposes as
// `object.__doc__` (the `object()` / `--` header looks like CPython's
// text-signature convention), so treat it as runtime text, not only rustdoc.
/// object()
/// --
///
/// The base class of the class hierarchy.
///
/// When called, it accepts no arguments and returns a new featureless
/// instance that has no instance attributes and cannot be given any.
#[pyclass(module = false, name = "object")]
#[derive(Debug)]
// Unit struct: an `object` payload carries no Rust-side data of its own.
pub struct PyBaseObject;
23
24impl PyPayload for PyBaseObject {
25    #[inline]
26    fn class(ctx: &Context) -> &'static Py<PyType> {
27        ctx.types.object_type
28    }
29}
30
31impl Constructor for PyBaseObject {
32    type Args = FuncArgs;
33
34    // = object_new
35    fn slot_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
36        if !args.args.is_empty() || !args.kwargs.is_empty() {
37            // Check if type's __new__ != object.__new__
38            let tp_new = cls.get_attr(identifier!(vm, __new__));
39            let object_new = vm.ctx.types.object_type.get_attr(identifier!(vm, __new__));
40
41            if let (Some(tp_new), Some(object_new)) = (tp_new, object_new) {
42                if !tp_new.is(&object_new) {
43                    // Type has its own __new__, so object.__new__ is being called
44                    // with excess args. This is the first error case in CPython
45                    return Err(vm.new_type_error(
46                        "object.__new__() takes exactly one argument (the type to instantiate)",
47                    ));
48                }
49
50                // If we reach here, tp_new == object_new
51                // Now check if type's __init__ == object.__init__
52                let tp_init = cls.get_attr(identifier!(vm, __init__));
53                let object_init = vm.ctx.types.object_type.get_attr(identifier!(vm, __init__));
54
55                if let (Some(tp_init), Some(object_init)) = (tp_init, object_init)
56                    && tp_init.is(&object_init)
57                {
58                    // Both __new__ and __init__ are object's versions,
59                    // so the type accepts no arguments
60                    return Err(vm.new_type_error(format!("{}() takes no arguments", cls.name())));
61                }
62                // If tp_init != object_init, then the type has custom __init__
63                // which might accept arguments, so we allow it
64            }
65        }
66
67        // Ensure that all abstract methods are implemented before instantiating instance.
68        if let Some(abs_methods) = cls.get_attr(identifier!(vm, __abstractmethods__))
69            && let Some(unimplemented_abstract_method_count) = abs_methods.length_opt(vm)
70        {
71            let methods: Vec<PyUtf8StrRef> = abs_methods.try_to_value(vm)?;
72            let methods: String = Itertools::intersperse(
73                methods.iter().map(|name| name.as_str().to_owned()),
74                "', '".to_owned(),
75            )
76            .collect();
77
78            let unimplemented_abstract_method_count = unimplemented_abstract_method_count?;
79            let name = cls.name().to_string();
80
81            match unimplemented_abstract_method_count {
82                0 => {}
83                1 => {
84                    return Err(vm.new_type_error(format!(
85                        "class {name} without an implementation for abstract method '{methods}'"
86                    )));
87                }
88                2.. => {
89                    return Err(vm.new_type_error(format!(
90                        "class {name} without an implementation for abstract methods '{methods}'"
91                    )));
92                }
93                // TODO: remove `allow` when redox build doesn't complain about it
94                #[allow(unreachable_patterns)]
95                _ => unreachable!(),
96            }
97        }
98
99        generic_alloc(cls, 0, vm)
100    }
101
102    fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
103        unimplemented!("use slot_new")
104    }
105}
106
107pub(crate) fn generic_alloc(cls: PyTypeRef, _nitems: usize, vm: &VirtualMachine) -> PyResult {
108    // Only create dict if the class has HAS_DICT flag (i.e., __slots__ was not defined
109    // or __dict__ is in __slots__)
110    let dict = if cls
111        .slots
112        .flags
113        .has_feature(crate::types::PyTypeFlags::HAS_DICT)
114    {
115        Some(vm.ctx.new_dict())
116    } else {
117        None
118    };
119    Ok(crate::PyRef::new_ref(PyBaseObject, cls, dict).into())
120}
121
122impl Initializer for PyBaseObject {
123    type Args = FuncArgs;
124
125    // object_init: excess_args validation
126    fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
127        if args.is_empty() {
128            return Ok(());
129        }
130
131        let typ = zelf.class();
132        let object_type = &vm.ctx.types.object_type;
133
134        let typ_init = typ.slots.init.load().map(|f| f as usize);
135        let object_init = object_type.slots.init.load().map(|f| f as usize);
136
137        // if (type->tp_init != object_init) → first error
138        if typ_init != object_init {
139            return Err(vm.new_type_error(
140                "object.__init__() takes exactly one argument (the instance to initialize)",
141            ));
142        }
143
144        // if (type->tp_new == object_new) → second error
145        if let (Some(typ_new), Some(object_new)) = (
146            typ.get_attr(identifier!(vm, __new__)),
147            object_type.get_attr(identifier!(vm, __new__)),
148        ) && typ_new.is(&object_new)
149        {
150            return Err(vm.new_type_error(format!(
151                "{}.__init__() takes exactly one argument (the instance to initialize)",
152                typ.name()
153            )));
154        }
155
156        // Both conditions false → OK (e.g., tuple, dict with custom __new__)
157        Ok(())
158    }
159
160    fn init(_zelf: PyRef<Self>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> {
161        unreachable!("slot_init is defined")
162    }
163}
164
165// TODO: implement _PyType_GetSlotNames properly
166fn type_slot_names(typ: &Py<PyType>, vm: &VirtualMachine) -> PyResult<Option<super::PyListRef>> {
167    // let attributes = typ.attributes.read();
168    // if let Some(slot_names) = attributes.get(identifier!(vm.ctx, __slotnames__)) {
169    //     return match_class!(match slot_names.clone() {
170    //         l @ super::PyList => Ok(Some(l)),
171    //         _n @ super::PyNone => Ok(None),
172    //         _ => Err(vm.new_type_error(format!(
173    //             "{:.200}.__slotnames__ should be a list or None, not {:.200}",
174    //             typ.name(),
175    //             slot_names.class().name()
176    //         ))),
177    //     });
178    // }
179
180    let copyreg = vm.import("copyreg", 0)?;
181    let copyreg_slotnames = copyreg.get_attr("_slotnames", vm)?;
182    let slot_names = copyreg_slotnames.call((typ.to_owned(),), vm)?;
183    let result = match_class!(match slot_names {
184        l @ super::PyList => Some(l),
185        _n @ super::PyNone => None,
186        _ => return Err(vm.new_type_error("copyreg._slotnames didn't return a list or None")),
187    });
188    Ok(result)
189}
190
191// object_getstate_default
192fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
193    // Check itemsize
194    if required && obj.class().slots.itemsize > 0 {
195        return Err(vm.new_type_error(format!("cannot pickle {:.200} objects", obj.class().name())));
196    }
197
198    let state = if obj.dict().is_none_or(|d| d.is_empty()) {
199        vm.ctx.none()
200    } else {
201        // let state = object_get_dict(obj.clone(), obj.ctx()).unwrap();
202        let Some(state) = obj.dict() else {
203            return Ok(vm.ctx.none());
204        };
205        state.into()
206    };
207
208    let slot_names =
209        type_slot_names(obj.class(), vm).map_err(|_| vm.new_type_error("cannot pickle object"))?;
210
211    if required {
212        // Start with PyBaseObject_Type's basicsize
213        let mut basicsize = vm.ctx.types.object_type.slots.basicsize;
214
215        // Add __dict__ size if type has dict
216        if obj.class().slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
217            basicsize += core::mem::size_of::<PyObjectRef>();
218        }
219
220        // Add __weakref__ size if type has weakref support
221        let has_weakref = if let Some(ref ext) = obj.class().heaptype_ext {
222            match &ext.slots {
223                None => true, // Heap type without __slots__ has automatic weakref
224                Some(slots) => slots.iter().any(|s| s.as_bytes() == b"__weakref__"),
225            }
226        } else {
227            let weakref_name = vm.ctx.intern_str("__weakref__");
228            obj.class().attributes.read().contains_key(weakref_name)
229        };
230        if has_weakref {
231            basicsize += core::mem::size_of::<PyObjectRef>();
232        }
233
234        // Add slots size
235        if let Some(ref slot_names) = slot_names {
236            basicsize += core::mem::size_of::<PyObjectRef>() * slot_names.__len__();
237        }
238
239        // Fail if actual type's basicsize > expected basicsize
240        if obj.class().slots.basicsize > basicsize {
241            return Err(vm.new_type_error(format!("cannot pickle '{}' object", obj.class().name())));
242        }
243    }
244
245    if let Some(slot_names) = slot_names {
246        let slot_names_len = slot_names.__len__();
247        if slot_names_len > 0 {
248            let slots = vm.ctx.new_dict();
249            for i in 0..slot_names_len {
250                let borrowed_names = slot_names.borrow_vec();
251                // Check if slotnames changed during iteration
252                if borrowed_names.len() != slot_names_len {
253                    return Err(vm.new_runtime_error("__slotnames__ changed size during iteration"));
254                }
255                let name = borrowed_names[i].downcast_ref::<PyStr>().unwrap();
256                let Ok(value) = obj.get_attr(name, vm) else {
257                    continue;
258                };
259                slots.set_item(name.as_wtf8(), value, vm).unwrap();
260            }
261
262            if !slots.is_empty() {
263                return (state, slots).to_pyresult(vm);
264            }
265        }
266    }
267
268    Ok(state)
269}
270
271// object_getstate
272// fn object_getstate(
273//     obj: &PyObject,
274//     required: bool,
275//     vm: &VirtualMachine,
276// ) -> PyResult {
277//     let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
278//     if vm.is_none(&getstate) {
279//         return Ok(None);
280//     }
281
282//     let getstate = match getstate.downcast_exact::<PyNativeFunction>(vm) {
283//         Ok(getstate)
284//             if getstate
285//                 .get_self()
286//                 .map_or(false, |self_obj| self_obj.is(obj))
287//                 && std::ptr::addr_eq(
288//                     getstate.as_func() as *const _,
289//                     &PyBaseObject::__getstate__ as &dyn crate::function::PyNativeFn as *const _,
290//                 ) =>
291//         {
292//             return object_getstate_default(obj, required, vm);
293//         }
294//         Ok(getstate) => getstate.into_pyref().into(),
295//         Err(getstate) => getstate,
296//     };
297//     getstate.call((), vm)
298// }
299
300#[pyclass(with(Constructor, Initializer), flags(BASETYPE))]
301impl PyBaseObject {
302    #[pymethod(raw)]
303    fn __getstate__(vm: &VirtualMachine, args: FuncArgs) -> PyResult {
304        let (zelf,): (PyObjectRef,) = args.bind(vm)?;
305        object_getstate_default(&zelf, false, vm)
306    }
307
308    #[pyslot]
309    fn slot_richcompare(
310        zelf: &PyObject,
311        other: &PyObject,
312        op: PyComparisonOp,
313        vm: &VirtualMachine,
314    ) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
315        Self::cmp(zelf, other, op, vm).map(Either::B)
316    }
317
318    #[inline(always)]
319    fn cmp(
320        zelf: &PyObject,
321        other: &PyObject,
322        op: PyComparisonOp,
323        vm: &VirtualMachine,
324    ) -> PyResult<PyComparisonValue> {
325        let res = match op {
326            PyComparisonOp::Eq => {
327                if zelf.is(other) {
328                    PyComparisonValue::Implemented(true)
329                } else {
330                    PyComparisonValue::NotImplemented
331                }
332            }
333            PyComparisonOp::Ne => {
334                let cmp = zelf.class().slots.richcompare.load().unwrap();
335                let value = match cmp(zelf, other, PyComparisonOp::Eq, vm)? {
336                    Either::A(obj) => PyArithmeticValue::from_object(vm, obj)
337                        .map(|obj| obj.try_to_bool(vm))
338                        .transpose()?,
339                    Either::B(value) => value,
340                };
341                value.map(|v| !v)
342            }
343            _ => PyComparisonValue::NotImplemented,
344        };
345        Ok(res)
346    }
347
348    /// Implement setattr(self, name, value).
349    #[pymethod]
350    fn __setattr__(
351        obj: PyObjectRef,
352        name: PyStrRef,
353        value: PyObjectRef,
354        vm: &VirtualMachine,
355    ) -> PyResult<()> {
356        obj.generic_setattr(&name, PySetterValue::Assign(value), vm)
357    }
358
359    /// Implement delattr(self, name).
360    #[pymethod]
361    fn __delattr__(obj: PyObjectRef, name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
362        obj.generic_setattr(&name, PySetterValue::Delete, vm)
363    }
364
365    #[pyslot]
366    pub(crate) fn slot_setattro(
367        obj: &PyObject,
368        attr_name: &Py<PyStr>,
369        value: PySetterValue,
370        vm: &VirtualMachine,
371    ) -> PyResult<()> {
372        obj.generic_setattr(attr_name, value, vm)
373    }
374
375    /// Return str(self).
376    #[pyslot]
377    fn slot_str(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyStrRef> {
378        // FIXME: try tp_repr first and fallback to object.__repr__
379        zelf.repr(vm)
380    }
381
382    #[pyslot]
383    fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyStrRef> {
384        let class = zelf.class();
385        match (
386            class
387                .__qualname__(vm)
388                .downcast_ref::<PyStr>()
389                .map(|n| n.as_wtf8()),
390            class
391                .__module__(vm)
392                .downcast_ref::<PyStr>()
393                .map(|m| m.as_wtf8()),
394        ) {
395            (None, _) => Err(vm.new_type_error("Unknown qualified name")),
396            (Some(qualname), Some(module)) if module != "builtins" => Ok(PyStr::from(format!(
397                "<{}.{} object at {:#x}>",
398                module,
399                qualname,
400                zelf.get_id()
401            ))
402            .into_ref(&vm.ctx)),
403            _ => Ok(PyStr::from(format!(
404                "<{} object at {:#x}>",
405                class.slot_name(),
406                zelf.get_id()
407            ))
408            .into_ref(&vm.ctx)),
409        }
410    }
411
412    #[pyclassmethod]
413    fn __subclasshook__(_args: FuncArgs, vm: &VirtualMachine) -> PyObjectRef {
414        vm.ctx.not_implemented()
415    }
416
417    #[pyclassmethod]
418    fn __init_subclass__(_cls: PyTypeRef) {}
419
420    #[pymethod]
421    pub fn __dir__(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyList> {
422        obj.dir(vm)
423    }
424
425    #[pymethod]
426    fn __format__(
427        obj: PyObjectRef,
428        format_spec: PyStrRef,
429        vm: &VirtualMachine,
430    ) -> PyResult<PyStrRef> {
431        if !format_spec.is_empty() {
432            return Err(vm.new_type_error(format!(
433                "unsupported format string passed to {}.__format__",
434                obj.class().name()
435            )));
436        }
437        obj.str(vm)
438    }
439
440    #[pygetset]
441    fn __class__(obj: PyObjectRef) -> PyTypeRef {
442        obj.class().to_owned()
443    }
444
445    #[pygetset(setter)]
446    fn set___class__(
447        instance: PyObjectRef,
448        value: PyObjectRef,
449        vm: &VirtualMachine,
450    ) -> PyResult<()> {
451        match value.downcast::<PyType>() {
452            Ok(cls) => {
453                let current_cls = instance.class();
454                let both_module = current_cls.fast_issubclass(vm.ctx.types.module_type)
455                    && cls.fast_issubclass(vm.ctx.types.module_type);
456                let both_mutable = !current_cls
457                    .slots
458                    .flags
459                    .has_feature(PyTypeFlags::IMMUTABLETYPE)
460                    && !cls.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE);
461                // FIXME(#1979) cls instances might have a payload
462                if both_mutable || both_module {
463                    let has_dict =
464                        |typ: &Py<PyType>| typ.slots.flags.has_feature(PyTypeFlags::HAS_DICT);
465                    let has_weakref =
466                        |typ: &Py<PyType>| typ.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF);
467                    // Compare slots tuples
468                    let slots_equal = match (
469                        current_cls
470                            .heaptype_ext
471                            .as_ref()
472                            .and_then(|e| e.slots.as_ref()),
473                        cls.heaptype_ext.as_ref().and_then(|e| e.slots.as_ref()),
474                    ) {
475                        (Some(a), Some(b)) => {
476                            a.len() == b.len()
477                                && a.iter()
478                                    .zip(b.iter())
479                                    .all(|(x, y)| x.as_wtf8() == y.as_wtf8())
480                        }
481                        (None, None) => true,
482                        _ => false,
483                    };
484                    if current_cls.slots.basicsize != cls.slots.basicsize
485                        || !slots_equal
486                        || has_dict(current_cls) != has_dict(&cls)
487                        || has_weakref(current_cls) != has_weakref(&cls)
488                        || current_cls.slots.member_count != cls.slots.member_count
489                    {
490                        return Err(vm.new_type_error(format!(
491                            "__class__ assignment: '{}' object layout differs from '{}'",
492                            cls.name(),
493                            current_cls.name()
494                        )));
495                    }
496                    instance.set_class(cls, vm);
497                    Ok(())
498                } else {
499                    Err(vm.new_type_error(
500                        "__class__ assignment only supported for mutable types or ModuleType subclasses",
501                    ))
502                }
503            }
504            Err(value) => {
505                let value_class = value.class();
506                let type_repr = &value_class.name();
507                Err(vm.new_type_error(format!(
508                    "__class__ must be set to a class, not '{type_repr}' object"
509                )))
510            }
511        }
512    }
513
514    /// Return getattr(self, name).
515    #[pyslot]
516    pub(crate) fn getattro(obj: &PyObject, name: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
517        vm_trace!("object.__getattribute__({:?}, {:?})", obj, name);
518        obj.as_object().generic_getattr(name, vm)
519    }
520
521    #[pymethod]
522    fn __getattribute__(obj: PyObjectRef, name: PyStrRef, vm: &VirtualMachine) -> PyResult {
523        Self::getattro(&obj, &name, vm)
524    }
525
526    #[pymethod]
527    fn __reduce__(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
528        common_reduce(obj, 0, vm)
529    }
530
531    #[pymethod]
532    fn __reduce_ex__(obj: PyObjectRef, proto: usize, vm: &VirtualMachine) -> PyResult {
533        let __reduce__ = identifier!(vm, __reduce__);
534        if let Some(reduce) = vm.get_attribute_opt(obj.clone(), __reduce__)? {
535            let object_reduce = vm.ctx.types.object_type.get_attr(__reduce__).unwrap();
536            let typ_obj: PyObjectRef = obj.class().to_owned().into();
537            let class_reduce = typ_obj.get_attr(__reduce__, vm)?;
538            if !class_reduce.is(&object_reduce) {
539                return reduce.call((), vm);
540            }
541        }
542        common_reduce(obj, proto, vm)
543    }
544
545    #[pyslot]
546    fn slot_hash(zelf: &PyObject, _vm: &VirtualMachine) -> PyResult<PyHash> {
547        Ok(zelf.get_id() as _)
548    }
549
550    #[pymethod]
551    fn __sizeof__(zelf: PyObjectRef) -> usize {
552        zelf.class().slots.basicsize
553    }
554}
555
556pub fn object_get_dict(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyDictRef> {
557    obj.dict()
558        .ok_or_else(|| vm.new_attribute_error("This object has no __dict__"))
559}
560pub fn object_set_dict(obj: PyObjectRef, dict: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
561    obj.set_dict(dict)
562        .map_err(|_| vm.new_attribute_error("This object has no __dict__"))
563}
564
565pub fn init(ctx: &'static Context) {
566    // Manually set alloc/init slots - derive macro doesn't generate extend_slots
567    // for trait impl that overrides #[pyslot] method
568    ctx.types.object_type.slots.alloc.store(Some(generic_alloc));
569    ctx.types
570        .object_type
571        .slots
572        .init
573        .store(Some(<PyBaseObject as Initializer>::slot_init));
574    PyBaseObject::extend_class(ctx, ctx.types.object_type);
575}
576
577/// Get arguments for __new__ from __getnewargs_ex__ or __getnewargs__
578/// Returns (args, kwargs) tuple where either can be None
579fn get_new_arguments(
580    obj: &PyObject,
581    vm: &VirtualMachine,
582) -> PyResult<(Option<super::PyTupleRef>, Option<super::PyDictRef>)> {
583    // First try __getnewargs_ex__
584    if let Some(getnewargs_ex) = vm.get_special_method(obj, identifier!(vm, __getnewargs_ex__))? {
585        let newargs = getnewargs_ex.invoke((), vm)?;
586
587        let newargs_tuple: PyRef<super::PyTuple> = newargs.downcast().map_err(|obj| {
588            vm.new_type_error(format!(
589                "__getnewargs_ex__ should return a tuple, not '{}'",
590                obj.class().name()
591            ))
592        })?;
593
594        if newargs_tuple.len() != 2 {
595            return Err(vm.new_value_error(format!(
596                "__getnewargs_ex__ should return a tuple of length 2, not {}",
597                newargs_tuple.len()
598            )));
599        }
600
601        let args = newargs_tuple.as_slice()[0].clone();
602        let kwargs = newargs_tuple.as_slice()[1].clone();
603
604        let args_tuple: PyRef<super::PyTuple> = args.downcast().map_err(|obj| {
605            vm.new_type_error(format!(
606                "first item of the tuple returned by __getnewargs_ex__ must be a tuple, not '{}'",
607                obj.class().name()
608            ))
609        })?;
610
611        let kwargs_dict: PyRef<super::PyDict> = kwargs.downcast().map_err(|obj| {
612            vm.new_type_error(format!(
613                "second item of the tuple returned by __getnewargs_ex__ must be a dict, not '{}'",
614                obj.class().name()
615            ))
616        })?;
617
618        return Ok((Some(args_tuple), Some(kwargs_dict)));
619    }
620
621    // Fall back to __getnewargs__
622    if let Some(getnewargs) = vm.get_special_method(obj, identifier!(vm, __getnewargs__))? {
623        let args = getnewargs.invoke((), vm)?;
624
625        let args_tuple: PyRef<super::PyTuple> = args.downcast().map_err(|obj| {
626            vm.new_type_error(format!(
627                "__getnewargs__ should return a tuple, not '{}'",
628                obj.class().name()
629            ))
630        })?;
631
632        return Ok((Some(args_tuple), None));
633    }
634
635    // No __getnewargs_ex__ or __getnewargs__
636    Ok((None, None))
637}
638
639/// Check if __getstate__ is overridden by comparing with object.__getstate__
640fn is_getstate_overridden(obj: &PyObject, vm: &VirtualMachine) -> bool {
641    let obj_cls = obj.class();
642    let object_type = vm.ctx.types.object_type;
643
644    // If the class is object itself, not overridden
645    if obj_cls.is(object_type) {
646        return false;
647    }
648
649    // Check if __getstate__ in the MRO comes from object or elsewhere
650    // If the type has its own __getstate__, it's overridden
651    if let Some(getstate) = obj_cls.get_attr(identifier!(vm, __getstate__))
652        && let Some(obj_getstate) = object_type.get_attr(identifier!(vm, __getstate__))
653    {
654        return !getstate.is(&obj_getstate);
655    }
656    false
657}
658
659/// object_getstate - calls __getstate__ method or default implementation
660fn object_getstate(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
661    // If __getstate__ is not overridden, use the default implementation with required flag
662    if !is_getstate_overridden(obj, vm) {
663        return object_getstate_default(obj, required, vm);
664    }
665
666    // __getstate__ is overridden, call it without required
667    let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
668    getstate.call((), vm)
669}
670
671/// Get list items iterator if obj is a list (or subclass), None iterator otherwise
672fn get_items_iter(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, PyObjectRef)> {
673    let listitems: PyObjectRef = if obj.fast_isinstance(vm.ctx.types.list_type) {
674        obj.get_iter(vm)?.into()
675    } else {
676        vm.ctx.none()
677    };
678
679    let dictitems: PyObjectRef = if obj.fast_isinstance(vm.ctx.types.dict_type) {
680        let items = vm.call_method(obj, "items", ())?;
681        items.get_iter(vm)?.into()
682    } else {
683        vm.ctx.none()
684    };
685
686    Ok((listitems, dictitems))
687}
688
689/// reduce_newobj - creates reduce tuple for protocol >= 2
690fn reduce_newobj(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
691    // Check if type has tp_new
692    let cls = obj.class();
693    if cls.slots.new.load().is_none() {
694        return Err(vm.new_type_error(format!("cannot pickle '{}' object", cls.name())));
695    }
696
697    let (args, kwargs) = get_new_arguments(&obj, vm)?;
698
699    let copyreg = vm.import("copyreg", 0)?;
700
701    let has_args = args.is_some();
702
703    let (newobj, newargs): (PyObjectRef, PyObjectRef) = if kwargs.is_none()
704        || kwargs.as_ref().is_some_and(|k| k.is_empty())
705    {
706        // Use copyreg.__newobj__
707        let newobj = copyreg.get_attr("__newobj__", vm)?;
708
709        let args_vec: Vec<PyObjectRef> = args.map(|a| a.as_slice().to_vec()).unwrap_or_default();
710
711        // Create (cls, *args) tuple
712        let mut newargs_vec: Vec<PyObjectRef> = vec![cls.to_owned().into()];
713        newargs_vec.extend(args_vec);
714        let newargs = vm.ctx.new_tuple(newargs_vec);
715
716        (newobj, newargs.into())
717    } else {
718        // args == NULL with non-empty kwargs is BadInternalCall
719        let Some(args) = args else {
720            return Err(vm.new_system_error("bad internal call"));
721        };
722        // Use copyreg.__newobj_ex__
723        let newobj = copyreg.get_attr("__newobj_ex__", vm)?;
724        let args_tuple: PyObjectRef = args.into();
725        let kwargs_dict: PyObjectRef = kwargs
726            .map(|k| k.into())
727            .unwrap_or_else(|| vm.ctx.new_dict().into());
728
729        let newargs = vm
730            .ctx
731            .new_tuple(vec![cls.to_owned().into(), args_tuple, kwargs_dict]);
732        (newobj, newargs.into())
733    };
734
735    // Determine if state is required
736    // required = !(has_args || is_list || is_dict)
737    let is_list = obj.fast_isinstance(vm.ctx.types.list_type);
738    let is_dict = obj.fast_isinstance(vm.ctx.types.dict_type);
739    let required = !(has_args || is_list || is_dict);
740
741    let state = object_getstate(&obj, required, vm)?;
742
743    let (listitems, dictitems) = get_items_iter(&obj, vm)?;
744
745    let result = vm
746        .ctx
747        .new_tuple(vec![newobj, newargs, state, listitems, dictitems]);
748    Ok(result.into())
749}
750
751fn common_reduce(obj: PyObjectRef, proto: usize, vm: &VirtualMachine) -> PyResult {
752    if proto >= 2 {
753        reduce_newobj(obj, vm)
754    } else {
755        let copyreg = vm.import("copyreg", 0)?;
756        let reduce_ex = copyreg.get_attr("_reduce_ex", vm)?;
757        reduce_ex.call((obj, proto), vm)
758    }
759}