rustpython_vm/builtins/
set.rs

1/*
2 * Builtin set type with a sequence of unique items.
3 */
4use super::{
5    builtins_iter, IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef,
6    PyType, PyTypeRef,
7};
8use crate::{
9    atomic_func,
10    class::PyClassImpl,
11    common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc},
12    convert::ToPyResult,
13    dictdatatype::{self, DictSize},
14    function::{ArgIterable, OptionalArg, PosArgs, PyArithmeticValue, PyComparisonValue},
15    protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
16    recursion::ReprGuard,
17    types::AsNumber,
18    types::{
19        AsSequence, Comparable, Constructor, DefaultConstructor, Hashable, Initializer, IterNext,
20        Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible,
21    },
22    utils::collection_repr,
23    vm::VirtualMachine,
24    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
25};
26use once_cell::sync::Lazy;
27use std::{fmt, ops::Deref};
28
29pub type SetContentType = dictdatatype::Dict<()>;
30
31#[pyclass(module = false, name = "set", unhashable = true, traverse)]
32#[derive(Default)]
33pub struct PySet {
34    pub(super) inner: PySetInner,
35}
36
37impl PySet {
38    pub fn new_ref(ctx: &Context) -> PyRef<Self> {
39        // Initialized empty, as calling __hash__ is required for adding each object to the set
40        // which requires a VM context - this is done in the set code itself.
41        PyRef::new_ref(Self::default(), ctx.types.set_type.to_owned(), None)
42    }
43
44    pub fn elements(&self) -> Vec<PyObjectRef> {
45        self.inner.elements()
46    }
47
48    fn fold_op(
49        &self,
50        others: impl std::iter::Iterator<Item = ArgIterable>,
51        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
52        vm: &VirtualMachine,
53    ) -> PyResult<Self> {
54        Ok(Self {
55            inner: self.inner.fold_op(others, op, vm)?,
56        })
57    }
58
59    fn op(
60        &self,
61        other: AnySet,
62        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
63        vm: &VirtualMachine,
64    ) -> PyResult<Self> {
65        Ok(Self {
66            inner: self
67                .inner
68                .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?,
69        })
70    }
71}
72
73#[pyclass(module = false, name = "frozenset", unhashable = true)]
74#[derive(Default)]
75pub struct PyFrozenSet {
76    inner: PySetInner,
77}
78
79impl PyFrozenSet {
80    // Also used by ssl.rs windows.
81    pub fn from_iter(
82        vm: &VirtualMachine,
83        it: impl IntoIterator<Item = PyObjectRef>,
84    ) -> PyResult<Self> {
85        let inner = PySetInner::default();
86        for elem in it {
87            inner.add(elem, vm)?;
88        }
89        // FIXME: empty set check
90        Ok(Self { inner })
91    }
92
93    pub fn elements(&self) -> Vec<PyObjectRef> {
94        self.inner.elements()
95    }
96
97    fn fold_op(
98        &self,
99        others: impl std::iter::Iterator<Item = ArgIterable>,
100        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
101        vm: &VirtualMachine,
102    ) -> PyResult<Self> {
103        Ok(Self {
104            inner: self.inner.fold_op(others, op, vm)?,
105        })
106    }
107
108    fn op(
109        &self,
110        other: AnySet,
111        op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
112        vm: &VirtualMachine,
113    ) -> PyResult<Self> {
114        Ok(Self {
115            inner: self
116                .inner
117                .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?,
118        })
119    }
120}
121
122impl fmt::Debug for PySet {
123    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124        // TODO: implement more detailed, non-recursive Debug formatter
125        f.write_str("set")
126    }
127}
128
129impl fmt::Debug for PyFrozenSet {
130    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131        // TODO: implement more detailed, non-recursive Debug formatter
132        f.write_str("PyFrozenSet ")?;
133        f.debug_set().entries(self.elements().iter()).finish()
134    }
135}
136
137impl PyPayload for PySet {
138    fn class(ctx: &Context) -> &'static Py<PyType> {
139        ctx.types.set_type
140    }
141}
142
143impl PyPayload for PyFrozenSet {
144    fn class(ctx: &Context) -> &'static Py<PyType> {
145        ctx.types.frozenset_type
146    }
147}
148
149#[derive(Default, Clone)]
150pub(super) struct PySetInner {
151    content: PyRc<SetContentType>,
152}
153
154unsafe impl crate::object::Traverse for PySetInner {
155    fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn) {
156        // FIXME(discord9): Rc means shared ref, so should it be traced?
157        self.content.traverse(tracer_fn)
158    }
159}
160
161impl PySetInner {
162    pub(super) fn from_iter<T>(iter: T, vm: &VirtualMachine) -> PyResult<Self>
163    where
164        T: IntoIterator<Item = PyResult<PyObjectRef>>,
165    {
166        let set = PySetInner::default();
167        for item in iter {
168            set.add(item?, vm)?;
169        }
170        Ok(set)
171    }
172
173    fn fold_op<O>(
174        &self,
175        others: impl std::iter::Iterator<Item = O>,
176        op: fn(&Self, O, &VirtualMachine) -> PyResult<Self>,
177        vm: &VirtualMachine,
178    ) -> PyResult<Self> {
179        let mut res = self.copy();
180        for other in others {
181            res = op(&res, other, vm)?;
182        }
183        Ok(res)
184    }
185
186    fn len(&self) -> usize {
187        self.content.len()
188    }
189
190    fn sizeof(&self) -> usize {
191        self.content.sizeof()
192    }
193
194    fn copy(&self) -> PySetInner {
195        PySetInner {
196            content: PyRc::new((*self.content).clone()),
197        }
198    }
199
200    fn contains(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
201        self.retry_op_with_frozenset(needle, vm, |needle, vm| self.content.contains(vm, needle))
202    }
203
204    fn compare(
205        &self,
206        other: &PySetInner,
207        op: PyComparisonOp,
208        vm: &VirtualMachine,
209    ) -> PyResult<bool> {
210        if op == PyComparisonOp::Ne {
211            return self.compare(other, PyComparisonOp::Eq, vm).map(|eq| !eq);
212        }
213        if !op.eval_ord(self.len().cmp(&other.len())) {
214            return Ok(false);
215        }
216        let (superset, subset) = if matches!(op, PyComparisonOp::Lt | PyComparisonOp::Le) {
217            (other, self)
218        } else {
219            (self, other)
220        };
221        for key in subset.elements() {
222            if !superset.contains(&key, vm)? {
223                return Ok(false);
224            }
225        }
226        Ok(true)
227    }
228
229    pub(super) fn union(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySetInner> {
230        let set = self.clone();
231        for item in other.iter(vm)? {
232            set.add(item?, vm)?;
233        }
234
235        Ok(set)
236    }
237
238    pub(super) fn intersection(
239        &self,
240        other: ArgIterable,
241        vm: &VirtualMachine,
242    ) -> PyResult<PySetInner> {
243        let set = PySetInner::default();
244        for item in other.iter(vm)? {
245            let obj = item?;
246            if self.contains(&obj, vm)? {
247                set.add(obj, vm)?;
248            }
249        }
250        Ok(set)
251    }
252
253    pub(super) fn difference(
254        &self,
255        other: ArgIterable,
256        vm: &VirtualMachine,
257    ) -> PyResult<PySetInner> {
258        let set = self.copy();
259        for item in other.iter(vm)? {
260            set.content.delete_if_exists(vm, &*item?)?;
261        }
262        Ok(set)
263    }
264
265    pub(super) fn symmetric_difference(
266        &self,
267        other: ArgIterable,
268        vm: &VirtualMachine,
269    ) -> PyResult<PySetInner> {
270        let new_inner = self.clone();
271
272        // We want to remove duplicates in other
273        let other_set = Self::from_iter(other.iter(vm)?, vm)?;
274
275        for item in other_set.elements() {
276            new_inner.content.delete_or_insert(vm, &item, ())?
277        }
278
279        Ok(new_inner)
280    }
281
282    fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
283        for item in other.iter(vm)? {
284            if !self.contains(&*item?, vm)? {
285                return Ok(false);
286            }
287        }
288        Ok(true)
289    }
290
291    fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
292        let other_set = PySetInner::from_iter(other.iter(vm)?, vm)?;
293        self.compare(&other_set, PyComparisonOp::Le, vm)
294    }
295
296    pub(super) fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
297        for item in other.iter(vm)? {
298            if self.contains(&*item?, vm)? {
299                return Ok(false);
300            }
301        }
302        Ok(true)
303    }
304
305    fn iter(&self) -> PySetIterator {
306        PySetIterator {
307            size: self.content.size(),
308            internal: PyMutex::new(PositionIterInternal::new(self.content.clone(), 0)),
309        }
310    }
311
312    fn repr(&self, class_name: Option<&str>, vm: &VirtualMachine) -> PyResult<String> {
313        collection_repr(class_name, "{", "}", self.elements().iter(), vm)
314    }
315
316    fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
317        self.content.insert(vm, &*item, ())
318    }
319
320    fn remove(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
321        self.retry_op_with_frozenset(&item, vm, |item, vm| self.content.delete(vm, item))
322    }
323
324    fn discard(&self, item: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
325        self.retry_op_with_frozenset(item, vm, |item, vm| self.content.delete_if_exists(vm, item))
326    }
327
328    fn clear(&self) {
329        self.content.clear()
330    }
331
332    fn elements(&self) -> Vec<PyObjectRef> {
333        self.content.keys()
334    }
335
336    fn pop(&self, vm: &VirtualMachine) -> PyResult {
337        // TODO: should be pop_front, but that requires rearranging every index
338        if let Some((key, _)) = self.content.pop_back() {
339            Ok(key)
340        } else {
341            let err_msg = vm.ctx.new_str(ascii!("pop from an empty set")).into();
342            Err(vm.new_key_error(err_msg))
343        }
344    }
345
346    fn update(
347        &self,
348        others: impl std::iter::Iterator<Item = ArgIterable>,
349        vm: &VirtualMachine,
350    ) -> PyResult<()> {
351        for iterable in others {
352            for item in iterable.iter(vm)? {
353                self.add(item?, vm)?;
354            }
355        }
356        Ok(())
357    }
358
359    fn update_internal(&self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
360        // check AnySet
361        if let Ok(any_set) = AnySet::try_from_object(vm, iterable.to_owned()) {
362            self.merge_set(any_set, vm)
363        // check Dict
364        } else if let Ok(dict) = iterable.to_owned().downcast_exact::<PyDict>(vm) {
365            self.merge_dict(dict.into_pyref(), vm)
366        } else {
367            // add iterable that is not AnySet or Dict
368            for item in iterable.try_into_value::<ArgIterable>(vm)?.iter(vm)? {
369                self.add(item?, vm)?;
370            }
371            Ok(())
372        }
373    }
374
375    fn merge_set(&self, any_set: AnySet, vm: &VirtualMachine) -> PyResult<()> {
376        for item in any_set.as_inner().elements() {
377            self.add(item, vm)?;
378        }
379        Ok(())
380    }
381
382    fn merge_dict(&self, dict: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
383        for (key, _value) in dict {
384            self.add(key, vm)?;
385        }
386        Ok(())
387    }
388
389    fn intersection_update(
390        &self,
391        others: impl std::iter::Iterator<Item = ArgIterable>,
392        vm: &VirtualMachine,
393    ) -> PyResult<()> {
394        let mut temp_inner = self.copy();
395        self.clear();
396        for iterable in others {
397            for item in iterable.iter(vm)? {
398                let obj = item?;
399                if temp_inner.contains(&obj, vm)? {
400                    self.add(obj, vm)?;
401                }
402            }
403            temp_inner = self.copy()
404        }
405        Ok(())
406    }
407
408    fn difference_update(
409        &self,
410        others: impl std::iter::Iterator<Item = ArgIterable>,
411        vm: &VirtualMachine,
412    ) -> PyResult<()> {
413        for iterable in others {
414            for item in iterable.iter(vm)? {
415                self.content.delete_if_exists(vm, &*item?)?;
416            }
417        }
418        Ok(())
419    }
420
421    fn symmetric_difference_update(
422        &self,
423        others: impl std::iter::Iterator<Item = ArgIterable>,
424        vm: &VirtualMachine,
425    ) -> PyResult<()> {
426        for iterable in others {
427            // We want to remove duplicates in iterable
428            let iterable_set = Self::from_iter(iterable.iter(vm)?, vm)?;
429            for item in iterable_set.elements() {
430                self.content.delete_or_insert(vm, &item, ())?;
431            }
432        }
433        Ok(())
434    }
435
436    fn hash(&self, vm: &VirtualMachine) -> PyResult<PyHash> {
437        // Work to increase the bit dispersion for closely spaced hash values.
438        // This is important because some use cases have many combinations of a
439        // small number of elements with nearby hashes so that many distinct
440        // combinations collapse to only a handful of distinct hash values.
441        fn _shuffle_bits(h: u64) -> u64 {
442            ((h ^ 89869747) ^ (h.wrapping_shl(16))).wrapping_mul(3644798167)
443        }
444        // Factor in the number of active entries
445        let mut hash: u64 = (self.elements().len() as u64 + 1).wrapping_mul(1927868237);
446        // Xor-in shuffled bits from every entry's hash field because xor is
447        // commutative and a frozenset hash should be independent of order.
448        for element in self.elements().iter() {
449            hash ^= _shuffle_bits(element.hash(vm)? as u64);
450        }
451        // Disperse patterns arising in nested frozensets
452        hash ^= (hash >> 11) ^ (hash >> 25);
453        hash = hash.wrapping_mul(69069).wrapping_add(907133923);
454        // -1 is reserved as an error code
455        if hash == u64::MAX {
456            hash = 590923713;
457        }
458        Ok(hash as PyHash)
459    }
460
461    // Run operation, on failure, if item is a set/set subclass, convert it
462    // into a frozenset and try the operation again. Propagates original error
463    // on failure to convert and restores item in KeyError on failure (remove).
464    fn retry_op_with_frozenset<T, F>(
465        &self,
466        item: &PyObject,
467        vm: &VirtualMachine,
468        op: F,
469    ) -> PyResult<T>
470    where
471        F: Fn(&PyObject, &VirtualMachine) -> PyResult<T>,
472    {
473        op(item, vm).or_else(|original_err| {
474            item.payload_if_subclass::<PySet>(vm)
475                // Keep original error around.
476                .ok_or(original_err)
477                .and_then(|set| {
478                    op(
479                        &PyFrozenSet {
480                            inner: set.inner.copy(),
481                        }
482                        .into_pyobject(vm),
483                        vm,
484                    )
485                    // If operation raised KeyError, report original set (set.remove)
486                    .map_err(|op_err| {
487                        if op_err.fast_isinstance(vm.ctx.exceptions.key_error) {
488                            vm.new_key_error(item.to_owned())
489                        } else {
490                            op_err
491                        }
492                    })
493                })
494        })
495    }
496}
497
498fn extract_set(obj: &PyObject) -> Option<&PySetInner> {
499    match_class!(match obj {
500        ref set @ PySet => Some(&set.inner),
501        ref frozen @ PyFrozenSet => Some(&frozen.inner),
502        _ => None,
503    })
504}
505
506fn reduce_set(
507    zelf: &PyObject,
508    vm: &VirtualMachine,
509) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
510    Ok((
511        zelf.class().to_owned(),
512        vm.new_tuple((extract_set(zelf)
513            .unwrap_or(&PySetInner::default())
514            .elements(),)),
515        zelf.dict(),
516    ))
517}
518
519#[pyclass(
520    with(
521        Constructor,
522        Initializer,
523        AsSequence,
524        Comparable,
525        Iterable,
526        AsNumber,
527        Representable
528    ),
529    flags(BASETYPE)
530)]
531impl PySet {
532    #[pymethod(magic)]
533    fn len(&self) -> usize {
534        self.inner.len()
535    }
536
537    #[pymethod(magic)]
538    fn sizeof(&self) -> usize {
539        std::mem::size_of::<Self>() + self.inner.sizeof()
540    }
541
542    #[pymethod]
543    fn copy(&self) -> Self {
544        Self {
545            inner: self.inner.copy(),
546        }
547    }
548
549    #[pymethod(magic)]
550    fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
551        self.inner.contains(&needle, vm)
552    }
553
554    #[pymethod]
555    fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
556        self.fold_op(others.into_iter(), PySetInner::union, vm)
557    }
558
559    #[pymethod]
560    fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
561        self.fold_op(others.into_iter(), PySetInner::intersection, vm)
562    }
563
564    #[pymethod]
565    fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
566        self.fold_op(others.into_iter(), PySetInner::difference, vm)
567    }
568
569    #[pymethod]
570    fn symmetric_difference(
571        &self,
572        others: PosArgs<ArgIterable>,
573        vm: &VirtualMachine,
574    ) -> PyResult<Self> {
575        self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm)
576    }
577
578    #[pymethod]
579    fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
580        self.inner.issubset(other, vm)
581    }
582
583    #[pymethod]
584    fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
585        self.inner.issuperset(other, vm)
586    }
587
588    #[pymethod]
589    fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
590        self.inner.isdisjoint(other, vm)
591    }
592
593    #[pymethod(name = "__ror__")]
594    #[pymethod(magic)]
595    fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
596        if let Ok(other) = AnySet::try_from_object(vm, other) {
597            Ok(PyArithmeticValue::Implemented(self.op(
598                other,
599                PySetInner::union,
600                vm,
601            )?))
602        } else {
603            Ok(PyArithmeticValue::NotImplemented)
604        }
605    }
606
607    #[pymethod(name = "__rand__")]
608    #[pymethod(magic)]
609    fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
610        if let Ok(other) = AnySet::try_from_object(vm, other) {
611            Ok(PyArithmeticValue::Implemented(self.op(
612                other,
613                PySetInner::intersection,
614                vm,
615            )?))
616        } else {
617            Ok(PyArithmeticValue::NotImplemented)
618        }
619    }
620
621    #[pymethod(magic)]
622    fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
623        if let Ok(other) = AnySet::try_from_object(vm, other) {
624            Ok(PyArithmeticValue::Implemented(self.op(
625                other,
626                PySetInner::difference,
627                vm,
628            )?))
629        } else {
630            Ok(PyArithmeticValue::NotImplemented)
631        }
632    }
633
634    #[pymethod(magic)]
635    fn rsub(
636        zelf: PyRef<Self>,
637        other: PyObjectRef,
638        vm: &VirtualMachine,
639    ) -> PyResult<PyArithmeticValue<Self>> {
640        if let Ok(other) = AnySet::try_from_object(vm, other) {
641            Ok(PyArithmeticValue::Implemented(Self {
642                inner: other
643                    .as_inner()
644                    .difference(ArgIterable::try_from_object(vm, zelf.into())?, vm)?,
645            }))
646        } else {
647            Ok(PyArithmeticValue::NotImplemented)
648        }
649    }
650
651    #[pymethod(name = "__rxor__")]
652    #[pymethod(magic)]
653    fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
654        if let Ok(other) = AnySet::try_from_object(vm, other) {
655            Ok(PyArithmeticValue::Implemented(self.op(
656                other,
657                PySetInner::symmetric_difference,
658                vm,
659            )?))
660        } else {
661            Ok(PyArithmeticValue::NotImplemented)
662        }
663    }
664
665    #[pymethod]
666    pub fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
667        self.inner.add(item, vm)?;
668        Ok(())
669    }
670
671    #[pymethod]
672    fn remove(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
673        self.inner.remove(item, vm)
674    }
675
676    #[pymethod]
677    fn discard(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
678        self.inner.discard(&item, vm)?;
679        Ok(())
680    }
681
682    #[pymethod]
683    fn clear(&self) {
684        self.inner.clear()
685    }
686
687    #[pymethod]
688    fn pop(&self, vm: &VirtualMachine) -> PyResult {
689        self.inner.pop(vm)
690    }
691
692    #[pymethod(magic)]
693    fn ior(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
694        zelf.inner.update(set.into_iterable_iter(vm)?, vm)?;
695        Ok(zelf)
696    }
697
698    #[pymethod]
699    fn update(&self, others: PosArgs<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
700        for iterable in others {
701            self.inner.update_internal(iterable, vm)?;
702        }
703        Ok(())
704    }
705
706    #[pymethod]
707    fn intersection_update(
708        &self,
709        others: PosArgs<ArgIterable>,
710        vm: &VirtualMachine,
711    ) -> PyResult<()> {
712        self.inner.intersection_update(others.into_iter(), vm)?;
713        Ok(())
714    }
715
716    #[pymethod(magic)]
717    fn iand(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
718        zelf.inner
719            .intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?;
720        Ok(zelf)
721    }
722
723    #[pymethod]
724    fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
725        self.inner.difference_update(others.into_iter(), vm)?;
726        Ok(())
727    }
728
729    #[pymethod(magic)]
730    fn isub(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
731        zelf.inner
732            .difference_update(set.into_iterable_iter(vm)?, vm)?;
733        Ok(zelf)
734    }
735
736    #[pymethod]
737    fn symmetric_difference_update(
738        &self,
739        others: PosArgs<ArgIterable>,
740        vm: &VirtualMachine,
741    ) -> PyResult<()> {
742        self.inner
743            .symmetric_difference_update(others.into_iter(), vm)?;
744        Ok(())
745    }
746
747    #[pymethod(magic)]
748    fn ixor(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
749        zelf.inner
750            .symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
751        Ok(zelf)
752    }
753
754    #[pymethod(magic)]
755    fn reduce(
756        zelf: PyRef<Self>,
757        vm: &VirtualMachine,
758    ) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
759        reduce_set(zelf.as_ref(), vm)
760    }
761
762    #[pyclassmethod(magic)]
763    fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
764        PyGenericAlias::new(cls, args, vm)
765    }
766}
767
768impl DefaultConstructor for PySet {}
769
770impl Initializer for PySet {
771    type Args = OptionalArg<PyObjectRef>;
772
773    fn init(zelf: PyRef<Self>, iterable: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
774        if zelf.len() > 0 {
775            zelf.clear();
776        }
777        if let OptionalArg::Present(it) = iterable {
778            zelf.update(PosArgs::new(vec![it]), vm)?;
779        }
780        Ok(())
781    }
782}
783
784impl AsSequence for PySet {
785    fn as_sequence() -> &'static PySequenceMethods {
786        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
787            length: atomic_func!(|seq, _vm| Ok(PySet::sequence_downcast(seq).len())),
788            contains: atomic_func!(|seq, needle, vm| PySet::sequence_downcast(seq)
789                .inner
790                .contains(needle, vm)),
791            ..PySequenceMethods::NOT_IMPLEMENTED
792        });
793        &AS_SEQUENCE
794    }
795}
796
797impl Comparable for PySet {
798    fn cmp(
799        zelf: &crate::Py<Self>,
800        other: &PyObject,
801        op: PyComparisonOp,
802        vm: &VirtualMachine,
803    ) -> PyResult<PyComparisonValue> {
804        extract_set(other).map_or(Ok(PyComparisonValue::NotImplemented), |other| {
805            Ok(zelf.inner.compare(other, op, vm)?.into())
806        })
807    }
808}
809
810impl Iterable for PySet {
811    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
812        Ok(zelf.inner.iter().into_pyobject(vm))
813    }
814}
815
816impl AsNumber for PySet {
817    fn as_number() -> &'static PyNumberMethods {
818        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
819            subtract: Some(|a, b, vm| {
820                if let Some(a) = a.downcast_ref::<PySet>() {
821                    a.sub(b.to_owned(), vm).to_pyresult(vm)
822                } else {
823                    Ok(vm.ctx.not_implemented())
824                }
825            }),
826            and: Some(|a, b, vm| {
827                if let Some(a) = a.downcast_ref::<PySet>() {
828                    a.and(b.to_owned(), vm).to_pyresult(vm)
829                } else {
830                    Ok(vm.ctx.not_implemented())
831                }
832            }),
833            xor: Some(|a, b, vm| {
834                if let Some(a) = a.downcast_ref::<PySet>() {
835                    a.xor(b.to_owned(), vm).to_pyresult(vm)
836                } else {
837                    Ok(vm.ctx.not_implemented())
838                }
839            }),
840            or: Some(|a, b, vm| {
841                if let Some(a) = a.downcast_ref::<PySet>() {
842                    a.or(b.to_owned(), vm).to_pyresult(vm)
843                } else {
844                    Ok(vm.ctx.not_implemented())
845                }
846            }),
847            inplace_subtract: Some(|a, b, vm| {
848                if let Some(a) = a.downcast_ref::<PySet>() {
849                    PySet::isub(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
850                        .to_pyresult(vm)
851                } else {
852                    Ok(vm.ctx.not_implemented())
853                }
854            }),
855            inplace_and: Some(|a, b, vm| {
856                if let Some(a) = a.downcast_ref::<PySet>() {
857                    PySet::iand(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
858                        .to_pyresult(vm)
859                } else {
860                    Ok(vm.ctx.not_implemented())
861                }
862            }),
863            inplace_xor: Some(|a, b, vm| {
864                if let Some(a) = a.downcast_ref::<PySet>() {
865                    PySet::ixor(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
866                        .to_pyresult(vm)
867                } else {
868                    Ok(vm.ctx.not_implemented())
869                }
870            }),
871            inplace_or: Some(|a, b, vm| {
872                if let Some(a) = a.downcast_ref::<PySet>() {
873                    PySet::ior(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
874                        .to_pyresult(vm)
875                } else {
876                    Ok(vm.ctx.not_implemented())
877                }
878            }),
879            ..PyNumberMethods::NOT_IMPLEMENTED
880        };
881        &AS_NUMBER
882    }
883}
884
885impl Representable for PySet {
886    #[inline]
887    fn repr_str(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
888        let class = zelf.class();
889        let borrowed_name = class.name();
890        let class_name = borrowed_name.deref();
891        let s = if zelf.inner.len() == 0 {
892            format!("{class_name}()")
893        } else if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) {
894            let name = if class_name != "set" {
895                Some(class_name)
896            } else {
897                None
898            };
899            zelf.inner.repr(name, vm)?
900        } else {
901            format!("{class_name}(...)")
902        };
903        Ok(s)
904    }
905}
906
907impl Constructor for PyFrozenSet {
908    type Args = OptionalArg<PyObjectRef>;
909
910    fn py_new(cls: PyTypeRef, iterable: Self::Args, vm: &VirtualMachine) -> PyResult {
911        let elements = if let OptionalArg::Present(iterable) = iterable {
912            let iterable = if cls.is(vm.ctx.types.frozenset_type) {
913                match iterable.downcast_exact::<Self>(vm) {
914                    Ok(fs) => return Ok(fs.into_pyref().into()),
915                    Err(iterable) => iterable,
916                }
917            } else {
918                iterable
919            };
920            iterable.try_to_value(vm)?
921        } else {
922            vec![]
923        };
924
925        // Return empty fs if iterable passed is empty and only for exact fs types.
926        if elements.is_empty() && cls.is(vm.ctx.types.frozenset_type) {
927            Ok(vm.ctx.empty_frozenset.clone().into())
928        } else {
929            Self::from_iter(vm, elements)
930                .and_then(|o| o.into_ref_with_type(vm, cls).map(Into::into))
931        }
932    }
933}
934
935#[pyclass(
936    flags(BASETYPE),
937    with(
938        Constructor,
939        AsSequence,
940        Hashable,
941        Comparable,
942        Iterable,
943        AsNumber,
944        Representable
945    )
946)]
947impl PyFrozenSet {
948    #[pymethod(magic)]
949    fn len(&self) -> usize {
950        self.inner.len()
951    }
952
953    #[pymethod(magic)]
954    fn sizeof(&self) -> usize {
955        std::mem::size_of::<Self>() + self.inner.sizeof()
956    }
957
958    #[pymethod]
959    fn copy(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyRef<Self> {
960        if zelf.class().is(vm.ctx.types.frozenset_type) {
961            zelf
962        } else {
963            Self {
964                inner: zelf.inner.copy(),
965            }
966            .into_ref(&vm.ctx)
967        }
968    }
969
970    #[pymethod(magic)]
971    fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
972        self.inner.contains(&needle, vm)
973    }
974
975    #[pymethod]
976    fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
977        self.fold_op(others.into_iter(), PySetInner::union, vm)
978    }
979
980    #[pymethod]
981    fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
982        self.fold_op(others.into_iter(), PySetInner::intersection, vm)
983    }
984
985    #[pymethod]
986    fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
987        self.fold_op(others.into_iter(), PySetInner::difference, vm)
988    }
989
990    #[pymethod]
991    fn symmetric_difference(
992        &self,
993        others: PosArgs<ArgIterable>,
994        vm: &VirtualMachine,
995    ) -> PyResult<Self> {
996        self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm)
997    }
998
999    #[pymethod]
1000    fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1001        self.inner.issubset(other, vm)
1002    }
1003
1004    #[pymethod]
1005    fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1006        self.inner.issuperset(other, vm)
1007    }
1008
1009    #[pymethod]
1010    fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1011        self.inner.isdisjoint(other, vm)
1012    }
1013
1014    #[pymethod(name = "__ror__")]
1015    #[pymethod(magic)]
1016    fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1017        if let Ok(set) = AnySet::try_from_object(vm, other) {
1018            Ok(PyArithmeticValue::Implemented(self.op(
1019                set,
1020                PySetInner::union,
1021                vm,
1022            )?))
1023        } else {
1024            Ok(PyArithmeticValue::NotImplemented)
1025        }
1026    }
1027
1028    #[pymethod(name = "__rand__")]
1029    #[pymethod(magic)]
1030    fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1031        if let Ok(other) = AnySet::try_from_object(vm, other) {
1032            Ok(PyArithmeticValue::Implemented(self.op(
1033                other,
1034                PySetInner::intersection,
1035                vm,
1036            )?))
1037        } else {
1038            Ok(PyArithmeticValue::NotImplemented)
1039        }
1040    }
1041
1042    #[pymethod(magic)]
1043    fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1044        if let Ok(other) = AnySet::try_from_object(vm, other) {
1045            Ok(PyArithmeticValue::Implemented(self.op(
1046                other,
1047                PySetInner::difference,
1048                vm,
1049            )?))
1050        } else {
1051            Ok(PyArithmeticValue::NotImplemented)
1052        }
1053    }
1054
1055    #[pymethod(magic)]
1056    fn rsub(
1057        zelf: PyRef<Self>,
1058        other: PyObjectRef,
1059        vm: &VirtualMachine,
1060    ) -> PyResult<PyArithmeticValue<Self>> {
1061        if let Ok(other) = AnySet::try_from_object(vm, other) {
1062            Ok(PyArithmeticValue::Implemented(Self {
1063                inner: other
1064                    .as_inner()
1065                    .difference(ArgIterable::try_from_object(vm, zelf.into())?, vm)?,
1066            }))
1067        } else {
1068            Ok(PyArithmeticValue::NotImplemented)
1069        }
1070    }
1071
1072    #[pymethod(name = "__rxor__")]
1073    #[pymethod(magic)]
1074    fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1075        if let Ok(other) = AnySet::try_from_object(vm, other) {
1076            Ok(PyArithmeticValue::Implemented(self.op(
1077                other,
1078                PySetInner::symmetric_difference,
1079                vm,
1080            )?))
1081        } else {
1082            Ok(PyArithmeticValue::NotImplemented)
1083        }
1084    }
1085
1086    #[pymethod(magic)]
1087    fn reduce(
1088        zelf: PyRef<Self>,
1089        vm: &VirtualMachine,
1090    ) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
1091        reduce_set(zelf.as_ref(), vm)
1092    }
1093
1094    #[pyclassmethod(magic)]
1095    fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
1096        PyGenericAlias::new(cls, args, vm)
1097    }
1098}
1099
1100impl AsSequence for PyFrozenSet {
1101    fn as_sequence() -> &'static PySequenceMethods {
1102        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
1103            length: atomic_func!(|seq, _vm| Ok(PyFrozenSet::sequence_downcast(seq).len())),
1104            contains: atomic_func!(|seq, needle, vm| PyFrozenSet::sequence_downcast(seq)
1105                .inner
1106                .contains(needle, vm)),
1107            ..PySequenceMethods::NOT_IMPLEMENTED
1108        });
1109        &AS_SEQUENCE
1110    }
1111}
1112
1113impl Hashable for PyFrozenSet {
1114    #[inline]
1115    fn hash(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
1116        zelf.inner.hash(vm)
1117    }
1118}
1119
1120impl Comparable for PyFrozenSet {
1121    fn cmp(
1122        zelf: &crate::Py<Self>,
1123        other: &PyObject,
1124        op: PyComparisonOp,
1125        vm: &VirtualMachine,
1126    ) -> PyResult<PyComparisonValue> {
1127        extract_set(other).map_or(Ok(PyComparisonValue::NotImplemented), |other| {
1128            Ok(zelf.inner.compare(other, op, vm)?.into())
1129        })
1130    }
1131}
1132
1133impl Iterable for PyFrozenSet {
1134    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
1135        Ok(zelf.inner.iter().into_pyobject(vm))
1136    }
1137}
1138
1139impl AsNumber for PyFrozenSet {
1140    fn as_number() -> &'static PyNumberMethods {
1141        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
1142            subtract: Some(|a, b, vm| {
1143                if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1144                    a.sub(b.to_owned(), vm).to_pyresult(vm)
1145                } else {
1146                    Ok(vm.ctx.not_implemented())
1147                }
1148            }),
1149            and: Some(|a, b, vm| {
1150                if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1151                    a.and(b.to_owned(), vm).to_pyresult(vm)
1152                } else {
1153                    Ok(vm.ctx.not_implemented())
1154                }
1155            }),
1156            xor: Some(|a, b, vm| {
1157                if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1158                    a.xor(b.to_owned(), vm).to_pyresult(vm)
1159                } else {
1160                    Ok(vm.ctx.not_implemented())
1161                }
1162            }),
1163            or: Some(|a, b, vm| {
1164                if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1165                    a.or(b.to_owned(), vm).to_pyresult(vm)
1166                } else {
1167                    Ok(vm.ctx.not_implemented())
1168                }
1169            }),
1170            ..PyNumberMethods::NOT_IMPLEMENTED
1171        };
1172        &AS_NUMBER
1173    }
1174}
1175
1176impl Representable for PyFrozenSet {
1177    #[inline]
1178    fn repr_str(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
1179        let inner = &zelf.inner;
1180        let class = zelf.class();
1181        let class_name = class.name();
1182        let s = if inner.len() == 0 {
1183            format!("{class_name}()")
1184        } else if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) {
1185            inner.repr(Some(&class_name), vm)?
1186        } else {
1187            format!("{class_name}(...)")
1188        };
1189        Ok(s)
1190    }
1191}
1192
/// Holder for an object that is used as a set-like operand.
///
/// Values are created through the `TryFromObject` impl for this type, which only
/// accepts objects whose class is a subclass of the builtin set or frozenset type.
struct AnySet {
    /// The wrapped object, kept as a generic object reference.
    object: PyObjectRef,
}
1196
1197impl AnySet {
1198    fn into_iterable(self, vm: &VirtualMachine) -> PyResult<ArgIterable> {
1199        self.object.try_into_value(vm)
1200    }
1201
1202    fn into_iterable_iter(
1203        self,
1204        vm: &VirtualMachine,
1205    ) -> PyResult<impl std::iter::Iterator<Item = ArgIterable>> {
1206        Ok(std::iter::once(self.into_iterable(vm)?))
1207    }
1208
1209    fn as_inner(&self) -> &PySetInner {
1210        match_class!(match self.object.as_object() {
1211            ref set @ PySet => &set.inner,
1212            ref frozen @ PyFrozenSet => &frozen.inner,
1213            _ => unreachable!("AnySet is always PySet or PyFrozenSet"), // should not be called.
1214        })
1215    }
1216}
1217
1218impl TryFromObject for AnySet {
1219    fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
1220        let class = obj.class();
1221        if class.fast_issubclass(vm.ctx.types.set_type)
1222            || class.fast_issubclass(vm.ctx.types.frozenset_type)
1223        {
1224            Ok(AnySet { object: obj })
1225        } else {
1226            Err(vm.new_type_error(format!("{class} is not a subtype of set or frozenset")))
1227        }
1228    }
1229}
1230
// Iterator object exposed to Python under the name "set_iterator".
#[pyclass(module = false, name = "set_iterator")]
pub(crate) struct PySetIterator {
    // Size record that `next` passes to `has_changed_size` to detect that the
    // underlying storage changed size during iteration.
    size: DictSize,
    // Current position and active/exhausted status, together with the shared
    // storage being iterated; accessed through the mutex.
    internal: PyMutex<PositionIterInternal<PyRc<SetContentType>>>,
}
1236
1237impl fmt::Debug for PySetIterator {
1238    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1239        // TODO: implement more detailed, non-recursive Debug formatter
1240        f.write_str("set_iterator")
1241    }
1242}
1243
1244impl PyPayload for PySetIterator {
1245    fn class(ctx: &Context) -> &'static Py<PyType> {
1246        ctx.types.set_iterator_type
1247    }
1248}
1249
1250#[pyclass(with(Unconstructible, IterNext, Iterable))]
1251impl PySetIterator {
1252    #[pymethod(magic)]
1253    fn length_hint(&self) -> usize {
1254        self.internal.lock().length_hint(|_| self.size.entries_size)
1255    }
1256
1257    #[pymethod(magic)]
1258    fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> {
1259        let internal = zelf.internal.lock();
1260        Ok((
1261            builtins_iter(vm).to_owned(),
1262            (vm.ctx
1263                .new_list(match &internal.status {
1264                    IterStatus::Exhausted => vec![],
1265                    IterStatus::Active(dict) => {
1266                        dict.keys().into_iter().skip(internal.position).collect()
1267                    }
1268                })
1269                .into(),),
1270        ))
1271    }
1272}
// Empty impl: nothing is overridden here, so only the items that the
// `Unconstructible` trait itself provides apply to `PySetIterator`.
impl Unconstructible for PySetIterator {}

// Empty impl: nothing is overridden here, so only the items that the
// `SelfIter` trait itself provides apply to `PySetIterator`.
impl SelfIter for PySetIterator {}
1276impl IterNext for PySetIterator {
1277    fn next(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
1278        let mut internal = zelf.internal.lock();
1279        let next = if let IterStatus::Active(dict) = &internal.status {
1280            if dict.has_changed_size(&zelf.size) {
1281                internal.status = IterStatus::Exhausted;
1282                return Err(vm.new_runtime_error("set changed size during iteration".to_owned()));
1283            }
1284            match dict.next_entry(internal.position) {
1285                Some((position, key, _)) => {
1286                    internal.position = position;
1287                    PyIterReturn::Return(key)
1288                }
1289                None => {
1290                    internal.status = IterStatus::Exhausted;
1291                    PyIterReturn::StopIteration(None)
1292                }
1293            }
1294        } else {
1295            PyIterReturn::StopIteration(None)
1296        };
1297        Ok(next)
1298    }
1299}
1300
1301pub fn init(context: &Context) {
1302    PySet::extend_class(context, context.types.set_type);
1303    PyFrozenSet::extend_class(context, context.types.frozenset_type);
1304    PySetIterator::extend_class(context, context.types.set_iterator_type);
1305}