1use super::{
5 builtins_iter, IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef,
6 PyType, PyTypeRef,
7};
8use crate::{
9 atomic_func,
10 class::PyClassImpl,
11 common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc},
12 convert::ToPyResult,
13 dictdatatype::{self, DictSize},
14 function::{ArgIterable, OptionalArg, PosArgs, PyArithmeticValue, PyComparisonValue},
15 protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
16 recursion::ReprGuard,
17 types::AsNumber,
18 types::{
19 AsSequence, Comparable, Constructor, DefaultConstructor, Hashable, Initializer, IterNext,
20 Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible,
21 },
22 utils::collection_repr,
23 vm::VirtualMachine,
24 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
25};
26use once_cell::sync::Lazy;
27use std::{fmt, ops::Deref};
28
/// Backing storage shared by `set` and `frozenset`: a dict table whose values are the unit
/// type, so only the keys carry information.
pub type SetContentType = dictdatatype::Dict<()>;
30
// Payload of the mutable builtin `set`.
// Plain `//` comments are used here (not `///`) in case `#[pyclass]` turns doc comments into
// the Python-visible `__doc__`.
#[pyclass(module = false, name = "set", unhashable = true, traverse)]
#[derive(Default)]
pub struct PySet {
    // Element storage; `pub(super)` so sibling builtin modules can reach it directly.
    pub(super) inner: PySetInner,
}
36
37impl PySet {
38 pub fn new_ref(ctx: &Context) -> PyRef<Self> {
39 PyRef::new_ref(Self::default(), ctx.types.set_type.to_owned(), None)
42 }
43
44 pub fn elements(&self) -> Vec<PyObjectRef> {
45 self.inner.elements()
46 }
47
48 fn fold_op(
49 &self,
50 others: impl std::iter::Iterator<Item = ArgIterable>,
51 op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
52 vm: &VirtualMachine,
53 ) -> PyResult<Self> {
54 Ok(Self {
55 inner: self.inner.fold_op(others, op, vm)?,
56 })
57 }
58
59 fn op(
60 &self,
61 other: AnySet,
62 op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
63 vm: &VirtualMachine,
64 ) -> PyResult<Self> {
65 Ok(Self {
66 inner: self
67 .inner
68 .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?,
69 })
70 }
71}
72
// Payload of the immutable builtin `frozenset`. Hashing comes from the `Hashable` impl.
// NOTE(review): declared `unhashable = true` even though `Hashable` is implemented and listed
// in `with(...)`; presumably the trait impl takes precedence. Confirm this.
// NOTE(review): unlike `PySet`, this omits `traverse` although `PySetInner` implements
// `Traverse`. Confirm whether that is intentional.
#[pyclass(module = false, name = "frozenset", unhashable = true)]
#[derive(Default)]
pub struct PyFrozenSet {
    // Element storage. Filled during construction; no method in this file mutates it afterwards.
    inner: PySetInner,
}
78
79impl PyFrozenSet {
80 pub fn from_iter(
82 vm: &VirtualMachine,
83 it: impl IntoIterator<Item = PyObjectRef>,
84 ) -> PyResult<Self> {
85 let inner = PySetInner::default();
86 for elem in it {
87 inner.add(elem, vm)?;
88 }
89 Ok(Self { inner })
91 }
92
93 pub fn elements(&self) -> Vec<PyObjectRef> {
94 self.inner.elements()
95 }
96
97 fn fold_op(
98 &self,
99 others: impl std::iter::Iterator<Item = ArgIterable>,
100 op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
101 vm: &VirtualMachine,
102 ) -> PyResult<Self> {
103 Ok(Self {
104 inner: self.inner.fold_op(others, op, vm)?,
105 })
106 }
107
108 fn op(
109 &self,
110 other: AnySet,
111 op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
112 vm: &VirtualMachine,
113 ) -> PyResult<Self> {
114 Ok(Self {
115 inner: self
116 .inner
117 .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?,
118 })
119 }
120}
121
122impl fmt::Debug for PySet {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 f.write_str("set")
126 }
127}
128
129impl fmt::Debug for PyFrozenSet {
130 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131 f.write_str("PyFrozenSet ")?;
133 f.debug_set().entries(self.elements().iter()).finish()
134 }
135}
136
137impl PyPayload for PySet {
138 fn class(ctx: &Context) -> &'static Py<PyType> {
139 ctx.types.set_type
140 }
141}
142
143impl PyPayload for PyFrozenSet {
144 fn class(ctx: &Context) -> &'static Py<PyType> {
145 ctx.types.frozenset_type
146 }
147}
148
/// Element storage shared by `PySet` and `PyFrozenSet`.
///
/// NOTE: the derived `Clone` only clones the `PyRc` handle, so a clone *aliases* the same
/// table and mutations through either are visible in both. Use `PySetInner::copy` when an
/// independent table is required.
#[derive(Default, Clone)]
pub(super) struct PySetInner {
    content: PyRc<SetContentType>,
}
153
// SAFETY: `content` is the only field and the only holder of Python object references;
// traversal is delegated to it so the tracer is shown every element the table owns.
unsafe impl crate::object::Traverse for PySetInner {
    fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn) {
        self.content.traverse(tracer_fn)
    }
}
160
161impl PySetInner {
162 pub(super) fn from_iter<T>(iter: T, vm: &VirtualMachine) -> PyResult<Self>
163 where
164 T: IntoIterator<Item = PyResult<PyObjectRef>>,
165 {
166 let set = PySetInner::default();
167 for item in iter {
168 set.add(item?, vm)?;
169 }
170 Ok(set)
171 }
172
173 fn fold_op<O>(
174 &self,
175 others: impl std::iter::Iterator<Item = O>,
176 op: fn(&Self, O, &VirtualMachine) -> PyResult<Self>,
177 vm: &VirtualMachine,
178 ) -> PyResult<Self> {
179 let mut res = self.copy();
180 for other in others {
181 res = op(&res, other, vm)?;
182 }
183 Ok(res)
184 }
185
186 fn len(&self) -> usize {
187 self.content.len()
188 }
189
190 fn sizeof(&self) -> usize {
191 self.content.sizeof()
192 }
193
194 fn copy(&self) -> PySetInner {
195 PySetInner {
196 content: PyRc::new((*self.content).clone()),
197 }
198 }
199
200 fn contains(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
201 self.retry_op_with_frozenset(needle, vm, |needle, vm| self.content.contains(vm, needle))
202 }
203
204 fn compare(
205 &self,
206 other: &PySetInner,
207 op: PyComparisonOp,
208 vm: &VirtualMachine,
209 ) -> PyResult<bool> {
210 if op == PyComparisonOp::Ne {
211 return self.compare(other, PyComparisonOp::Eq, vm).map(|eq| !eq);
212 }
213 if !op.eval_ord(self.len().cmp(&other.len())) {
214 return Ok(false);
215 }
216 let (superset, subset) = if matches!(op, PyComparisonOp::Lt | PyComparisonOp::Le) {
217 (other, self)
218 } else {
219 (self, other)
220 };
221 for key in subset.elements() {
222 if !superset.contains(&key, vm)? {
223 return Ok(false);
224 }
225 }
226 Ok(true)
227 }
228
229 pub(super) fn union(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<PySetInner> {
230 let set = self.clone();
231 for item in other.iter(vm)? {
232 set.add(item?, vm)?;
233 }
234
235 Ok(set)
236 }
237
238 pub(super) fn intersection(
239 &self,
240 other: ArgIterable,
241 vm: &VirtualMachine,
242 ) -> PyResult<PySetInner> {
243 let set = PySetInner::default();
244 for item in other.iter(vm)? {
245 let obj = item?;
246 if self.contains(&obj, vm)? {
247 set.add(obj, vm)?;
248 }
249 }
250 Ok(set)
251 }
252
253 pub(super) fn difference(
254 &self,
255 other: ArgIterable,
256 vm: &VirtualMachine,
257 ) -> PyResult<PySetInner> {
258 let set = self.copy();
259 for item in other.iter(vm)? {
260 set.content.delete_if_exists(vm, &*item?)?;
261 }
262 Ok(set)
263 }
264
265 pub(super) fn symmetric_difference(
266 &self,
267 other: ArgIterable,
268 vm: &VirtualMachine,
269 ) -> PyResult<PySetInner> {
270 let new_inner = self.clone();
271
272 let other_set = Self::from_iter(other.iter(vm)?, vm)?;
274
275 for item in other_set.elements() {
276 new_inner.content.delete_or_insert(vm, &item, ())?
277 }
278
279 Ok(new_inner)
280 }
281
282 fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
283 for item in other.iter(vm)? {
284 if !self.contains(&*item?, vm)? {
285 return Ok(false);
286 }
287 }
288 Ok(true)
289 }
290
291 fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
292 let other_set = PySetInner::from_iter(other.iter(vm)?, vm)?;
293 self.compare(&other_set, PyComparisonOp::Le, vm)
294 }
295
296 pub(super) fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
297 for item in other.iter(vm)? {
298 if self.contains(&*item?, vm)? {
299 return Ok(false);
300 }
301 }
302 Ok(true)
303 }
304
305 fn iter(&self) -> PySetIterator {
306 PySetIterator {
307 size: self.content.size(),
308 internal: PyMutex::new(PositionIterInternal::new(self.content.clone(), 0)),
309 }
310 }
311
312 fn repr(&self, class_name: Option<&str>, vm: &VirtualMachine) -> PyResult<String> {
313 collection_repr(class_name, "{", "}", self.elements().iter(), vm)
314 }
315
316 fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
317 self.content.insert(vm, &*item, ())
318 }
319
320 fn remove(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
321 self.retry_op_with_frozenset(&item, vm, |item, vm| self.content.delete(vm, item))
322 }
323
324 fn discard(&self, item: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
325 self.retry_op_with_frozenset(item, vm, |item, vm| self.content.delete_if_exists(vm, item))
326 }
327
328 fn clear(&self) {
329 self.content.clear()
330 }
331
332 fn elements(&self) -> Vec<PyObjectRef> {
333 self.content.keys()
334 }
335
336 fn pop(&self, vm: &VirtualMachine) -> PyResult {
337 if let Some((key, _)) = self.content.pop_back() {
339 Ok(key)
340 } else {
341 let err_msg = vm.ctx.new_str(ascii!("pop from an empty set")).into();
342 Err(vm.new_key_error(err_msg))
343 }
344 }
345
346 fn update(
347 &self,
348 others: impl std::iter::Iterator<Item = ArgIterable>,
349 vm: &VirtualMachine,
350 ) -> PyResult<()> {
351 for iterable in others {
352 for item in iterable.iter(vm)? {
353 self.add(item?, vm)?;
354 }
355 }
356 Ok(())
357 }
358
359 fn update_internal(&self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
360 if let Ok(any_set) = AnySet::try_from_object(vm, iterable.to_owned()) {
362 self.merge_set(any_set, vm)
363 } else if let Ok(dict) = iterable.to_owned().downcast_exact::<PyDict>(vm) {
365 self.merge_dict(dict.into_pyref(), vm)
366 } else {
367 for item in iterable.try_into_value::<ArgIterable>(vm)?.iter(vm)? {
369 self.add(item?, vm)?;
370 }
371 Ok(())
372 }
373 }
374
375 fn merge_set(&self, any_set: AnySet, vm: &VirtualMachine) -> PyResult<()> {
376 for item in any_set.as_inner().elements() {
377 self.add(item, vm)?;
378 }
379 Ok(())
380 }
381
382 fn merge_dict(&self, dict: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
383 for (key, _value) in dict {
384 self.add(key, vm)?;
385 }
386 Ok(())
387 }
388
389 fn intersection_update(
390 &self,
391 others: impl std::iter::Iterator<Item = ArgIterable>,
392 vm: &VirtualMachine,
393 ) -> PyResult<()> {
394 let mut temp_inner = self.copy();
395 self.clear();
396 for iterable in others {
397 for item in iterable.iter(vm)? {
398 let obj = item?;
399 if temp_inner.contains(&obj, vm)? {
400 self.add(obj, vm)?;
401 }
402 }
403 temp_inner = self.copy()
404 }
405 Ok(())
406 }
407
408 fn difference_update(
409 &self,
410 others: impl std::iter::Iterator<Item = ArgIterable>,
411 vm: &VirtualMachine,
412 ) -> PyResult<()> {
413 for iterable in others {
414 for item in iterable.iter(vm)? {
415 self.content.delete_if_exists(vm, &*item?)?;
416 }
417 }
418 Ok(())
419 }
420
421 fn symmetric_difference_update(
422 &self,
423 others: impl std::iter::Iterator<Item = ArgIterable>,
424 vm: &VirtualMachine,
425 ) -> PyResult<()> {
426 for iterable in others {
427 let iterable_set = Self::from_iter(iterable.iter(vm)?, vm)?;
429 for item in iterable_set.elements() {
430 self.content.delete_or_insert(vm, &item, ())?;
431 }
432 }
433 Ok(())
434 }
435
436 fn hash(&self, vm: &VirtualMachine) -> PyResult<PyHash> {
437 fn _shuffle_bits(h: u64) -> u64 {
442 ((h ^ 89869747) ^ (h.wrapping_shl(16))).wrapping_mul(3644798167)
443 }
444 let mut hash: u64 = (self.elements().len() as u64 + 1).wrapping_mul(1927868237);
446 for element in self.elements().iter() {
449 hash ^= _shuffle_bits(element.hash(vm)? as u64);
450 }
451 hash ^= (hash >> 11) ^ (hash >> 25);
453 hash = hash.wrapping_mul(69069).wrapping_add(907133923);
454 if hash == u64::MAX {
456 hash = 590923713;
457 }
458 Ok(hash as PyHash)
459 }
460
461 fn retry_op_with_frozenset<T, F>(
465 &self,
466 item: &PyObject,
467 vm: &VirtualMachine,
468 op: F,
469 ) -> PyResult<T>
470 where
471 F: Fn(&PyObject, &VirtualMachine) -> PyResult<T>,
472 {
473 op(item, vm).or_else(|original_err| {
474 item.payload_if_subclass::<PySet>(vm)
475 .ok_or(original_err)
477 .and_then(|set| {
478 op(
479 &PyFrozenSet {
480 inner: set.inner.copy(),
481 }
482 .into_pyobject(vm),
483 vm,
484 )
485 .map_err(|op_err| {
487 if op_err.fast_isinstance(vm.ctx.exceptions.key_error) {
488 vm.new_key_error(item.to_owned())
489 } else {
490 op_err
491 }
492 })
493 })
494 })
495 }
496}
497
498fn extract_set(obj: &PyObject) -> Option<&PySetInner> {
499 match_class!(match obj {
500 ref set @ PySet => Some(&set.inner),
501 ref frozen @ PyFrozenSet => Some(&frozen.inner),
502 _ => None,
503 })
504}
505
506fn reduce_set(
507 zelf: &PyObject,
508 vm: &VirtualMachine,
509) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
510 Ok((
511 zelf.class().to_owned(),
512 vm.new_tuple((extract_set(zelf)
513 .unwrap_or(&PySetInner::default())
514 .elements(),)),
515 zelf.dict(),
516 ))
517}
518
519#[pyclass(
520 with(
521 Constructor,
522 Initializer,
523 AsSequence,
524 Comparable,
525 Iterable,
526 AsNumber,
527 Representable
528 ),
529 flags(BASETYPE)
530)]
531impl PySet {
532 #[pymethod(magic)]
533 fn len(&self) -> usize {
534 self.inner.len()
535 }
536
537 #[pymethod(magic)]
538 fn sizeof(&self) -> usize {
539 std::mem::size_of::<Self>() + self.inner.sizeof()
540 }
541
542 #[pymethod]
543 fn copy(&self) -> Self {
544 Self {
545 inner: self.inner.copy(),
546 }
547 }
548
549 #[pymethod(magic)]
550 fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
551 self.inner.contains(&needle, vm)
552 }
553
554 #[pymethod]
555 fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
556 self.fold_op(others.into_iter(), PySetInner::union, vm)
557 }
558
559 #[pymethod]
560 fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
561 self.fold_op(others.into_iter(), PySetInner::intersection, vm)
562 }
563
564 #[pymethod]
565 fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
566 self.fold_op(others.into_iter(), PySetInner::difference, vm)
567 }
568
569 #[pymethod]
570 fn symmetric_difference(
571 &self,
572 others: PosArgs<ArgIterable>,
573 vm: &VirtualMachine,
574 ) -> PyResult<Self> {
575 self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm)
576 }
577
578 #[pymethod]
579 fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
580 self.inner.issubset(other, vm)
581 }
582
583 #[pymethod]
584 fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
585 self.inner.issuperset(other, vm)
586 }
587
588 #[pymethod]
589 fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
590 self.inner.isdisjoint(other, vm)
591 }
592
593 #[pymethod(name = "__ror__")]
594 #[pymethod(magic)]
595 fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
596 if let Ok(other) = AnySet::try_from_object(vm, other) {
597 Ok(PyArithmeticValue::Implemented(self.op(
598 other,
599 PySetInner::union,
600 vm,
601 )?))
602 } else {
603 Ok(PyArithmeticValue::NotImplemented)
604 }
605 }
606
607 #[pymethod(name = "__rand__")]
608 #[pymethod(magic)]
609 fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
610 if let Ok(other) = AnySet::try_from_object(vm, other) {
611 Ok(PyArithmeticValue::Implemented(self.op(
612 other,
613 PySetInner::intersection,
614 vm,
615 )?))
616 } else {
617 Ok(PyArithmeticValue::NotImplemented)
618 }
619 }
620
621 #[pymethod(magic)]
622 fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
623 if let Ok(other) = AnySet::try_from_object(vm, other) {
624 Ok(PyArithmeticValue::Implemented(self.op(
625 other,
626 PySetInner::difference,
627 vm,
628 )?))
629 } else {
630 Ok(PyArithmeticValue::NotImplemented)
631 }
632 }
633
634 #[pymethod(magic)]
635 fn rsub(
636 zelf: PyRef<Self>,
637 other: PyObjectRef,
638 vm: &VirtualMachine,
639 ) -> PyResult<PyArithmeticValue<Self>> {
640 if let Ok(other) = AnySet::try_from_object(vm, other) {
641 Ok(PyArithmeticValue::Implemented(Self {
642 inner: other
643 .as_inner()
644 .difference(ArgIterable::try_from_object(vm, zelf.into())?, vm)?,
645 }))
646 } else {
647 Ok(PyArithmeticValue::NotImplemented)
648 }
649 }
650
651 #[pymethod(name = "__rxor__")]
652 #[pymethod(magic)]
653 fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
654 if let Ok(other) = AnySet::try_from_object(vm, other) {
655 Ok(PyArithmeticValue::Implemented(self.op(
656 other,
657 PySetInner::symmetric_difference,
658 vm,
659 )?))
660 } else {
661 Ok(PyArithmeticValue::NotImplemented)
662 }
663 }
664
665 #[pymethod]
666 pub fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
667 self.inner.add(item, vm)?;
668 Ok(())
669 }
670
671 #[pymethod]
672 fn remove(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
673 self.inner.remove(item, vm)
674 }
675
676 #[pymethod]
677 fn discard(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
678 self.inner.discard(&item, vm)?;
679 Ok(())
680 }
681
682 #[pymethod]
683 fn clear(&self) {
684 self.inner.clear()
685 }
686
687 #[pymethod]
688 fn pop(&self, vm: &VirtualMachine) -> PyResult {
689 self.inner.pop(vm)
690 }
691
692 #[pymethod(magic)]
693 fn ior(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
694 zelf.inner.update(set.into_iterable_iter(vm)?, vm)?;
695 Ok(zelf)
696 }
697
698 #[pymethod]
699 fn update(&self, others: PosArgs<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
700 for iterable in others {
701 self.inner.update_internal(iterable, vm)?;
702 }
703 Ok(())
704 }
705
706 #[pymethod]
707 fn intersection_update(
708 &self,
709 others: PosArgs<ArgIterable>,
710 vm: &VirtualMachine,
711 ) -> PyResult<()> {
712 self.inner.intersection_update(others.into_iter(), vm)?;
713 Ok(())
714 }
715
716 #[pymethod(magic)]
717 fn iand(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
718 zelf.inner
719 .intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?;
720 Ok(zelf)
721 }
722
723 #[pymethod]
724 fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
725 self.inner.difference_update(others.into_iter(), vm)?;
726 Ok(())
727 }
728
729 #[pymethod(magic)]
730 fn isub(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
731 zelf.inner
732 .difference_update(set.into_iterable_iter(vm)?, vm)?;
733 Ok(zelf)
734 }
735
736 #[pymethod]
737 fn symmetric_difference_update(
738 &self,
739 others: PosArgs<ArgIterable>,
740 vm: &VirtualMachine,
741 ) -> PyResult<()> {
742 self.inner
743 .symmetric_difference_update(others.into_iter(), vm)?;
744 Ok(())
745 }
746
747 #[pymethod(magic)]
748 fn ixor(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
749 zelf.inner
750 .symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
751 Ok(zelf)
752 }
753
754 #[pymethod(magic)]
755 fn reduce(
756 zelf: PyRef<Self>,
757 vm: &VirtualMachine,
758 ) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
759 reduce_set(zelf.as_ref(), vm)
760 }
761
762 #[pyclassmethod(magic)]
763 fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
764 PyGenericAlias::new(cls, args, vm)
765 }
766}
767
// `set.__new__` builds an empty payload (presumably via the derived `Default`); any contents
// are added afterwards by the `Initializer` impl (`__init__`).
impl DefaultConstructor for PySet {}
769
770impl Initializer for PySet {
771 type Args = OptionalArg<PyObjectRef>;
772
773 fn init(zelf: PyRef<Self>, iterable: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
774 if zelf.len() > 0 {
775 zelf.clear();
776 }
777 if let OptionalArg::Present(it) = iterable {
778 zelf.update(PosArgs::new(vec![it]), vm)?;
779 }
780 Ok(())
781 }
782}
783
784impl AsSequence for PySet {
785 fn as_sequence() -> &'static PySequenceMethods {
786 static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
787 length: atomic_func!(|seq, _vm| Ok(PySet::sequence_downcast(seq).len())),
788 contains: atomic_func!(|seq, needle, vm| PySet::sequence_downcast(seq)
789 .inner
790 .contains(needle, vm)),
791 ..PySequenceMethods::NOT_IMPLEMENTED
792 });
793 &AS_SEQUENCE
794 }
795}
796
797impl Comparable for PySet {
798 fn cmp(
799 zelf: &crate::Py<Self>,
800 other: &PyObject,
801 op: PyComparisonOp,
802 vm: &VirtualMachine,
803 ) -> PyResult<PyComparisonValue> {
804 extract_set(other).map_or(Ok(PyComparisonValue::NotImplemented), |other| {
805 Ok(zelf.inner.compare(other, op, vm)?.into())
806 })
807 }
808}
809
810impl Iterable for PySet {
811 fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
812 Ok(zelf.inner.iter().into_pyobject(vm))
813 }
814}
815
816impl AsNumber for PySet {
817 fn as_number() -> &'static PyNumberMethods {
818 static AS_NUMBER: PyNumberMethods = PyNumberMethods {
819 subtract: Some(|a, b, vm| {
820 if let Some(a) = a.downcast_ref::<PySet>() {
821 a.sub(b.to_owned(), vm).to_pyresult(vm)
822 } else {
823 Ok(vm.ctx.not_implemented())
824 }
825 }),
826 and: Some(|a, b, vm| {
827 if let Some(a) = a.downcast_ref::<PySet>() {
828 a.and(b.to_owned(), vm).to_pyresult(vm)
829 } else {
830 Ok(vm.ctx.not_implemented())
831 }
832 }),
833 xor: Some(|a, b, vm| {
834 if let Some(a) = a.downcast_ref::<PySet>() {
835 a.xor(b.to_owned(), vm).to_pyresult(vm)
836 } else {
837 Ok(vm.ctx.not_implemented())
838 }
839 }),
840 or: Some(|a, b, vm| {
841 if let Some(a) = a.downcast_ref::<PySet>() {
842 a.or(b.to_owned(), vm).to_pyresult(vm)
843 } else {
844 Ok(vm.ctx.not_implemented())
845 }
846 }),
847 inplace_subtract: Some(|a, b, vm| {
848 if let Some(a) = a.downcast_ref::<PySet>() {
849 PySet::isub(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
850 .to_pyresult(vm)
851 } else {
852 Ok(vm.ctx.not_implemented())
853 }
854 }),
855 inplace_and: Some(|a, b, vm| {
856 if let Some(a) = a.downcast_ref::<PySet>() {
857 PySet::iand(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
858 .to_pyresult(vm)
859 } else {
860 Ok(vm.ctx.not_implemented())
861 }
862 }),
863 inplace_xor: Some(|a, b, vm| {
864 if let Some(a) = a.downcast_ref::<PySet>() {
865 PySet::ixor(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
866 .to_pyresult(vm)
867 } else {
868 Ok(vm.ctx.not_implemented())
869 }
870 }),
871 inplace_or: Some(|a, b, vm| {
872 if let Some(a) = a.downcast_ref::<PySet>() {
873 PySet::ior(a.to_owned(), AnySet::try_from_object(vm, b.to_owned())?, vm)
874 .to_pyresult(vm)
875 } else {
876 Ok(vm.ctx.not_implemented())
877 }
878 }),
879 ..PyNumberMethods::NOT_IMPLEMENTED
880 };
881 &AS_NUMBER
882 }
883}
884
885impl Representable for PySet {
886 #[inline]
887 fn repr_str(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
888 let class = zelf.class();
889 let borrowed_name = class.name();
890 let class_name = borrowed_name.deref();
891 let s = if zelf.inner.len() == 0 {
892 format!("{class_name}()")
893 } else if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) {
894 let name = if class_name != "set" {
895 Some(class_name)
896 } else {
897 None
898 };
899 zelf.inner.repr(name, vm)?
900 } else {
901 format!("{class_name}(...)")
902 };
903 Ok(s)
904 }
905}
906
907impl Constructor for PyFrozenSet {
908 type Args = OptionalArg<PyObjectRef>;
909
910 fn py_new(cls: PyTypeRef, iterable: Self::Args, vm: &VirtualMachine) -> PyResult {
911 let elements = if let OptionalArg::Present(iterable) = iterable {
912 let iterable = if cls.is(vm.ctx.types.frozenset_type) {
913 match iterable.downcast_exact::<Self>(vm) {
914 Ok(fs) => return Ok(fs.into_pyref().into()),
915 Err(iterable) => iterable,
916 }
917 } else {
918 iterable
919 };
920 iterable.try_to_value(vm)?
921 } else {
922 vec![]
923 };
924
925 if elements.is_empty() && cls.is(vm.ctx.types.frozenset_type) {
927 Ok(vm.ctx.empty_frozenset.clone().into())
928 } else {
929 Self::from_iter(vm, elements)
930 .and_then(|o| o.into_ref_with_type(vm, cls).map(Into::into))
931 }
932 }
933}
934
935#[pyclass(
936 flags(BASETYPE),
937 with(
938 Constructor,
939 AsSequence,
940 Hashable,
941 Comparable,
942 Iterable,
943 AsNumber,
944 Representable
945 )
946)]
947impl PyFrozenSet {
948 #[pymethod(magic)]
949 fn len(&self) -> usize {
950 self.inner.len()
951 }
952
953 #[pymethod(magic)]
954 fn sizeof(&self) -> usize {
955 std::mem::size_of::<Self>() + self.inner.sizeof()
956 }
957
958 #[pymethod]
959 fn copy(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyRef<Self> {
960 if zelf.class().is(vm.ctx.types.frozenset_type) {
961 zelf
962 } else {
963 Self {
964 inner: zelf.inner.copy(),
965 }
966 .into_ref(&vm.ctx)
967 }
968 }
969
970 #[pymethod(magic)]
971 fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
972 self.inner.contains(&needle, vm)
973 }
974
975 #[pymethod]
976 fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
977 self.fold_op(others.into_iter(), PySetInner::union, vm)
978 }
979
980 #[pymethod]
981 fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
982 self.fold_op(others.into_iter(), PySetInner::intersection, vm)
983 }
984
985 #[pymethod]
986 fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
987 self.fold_op(others.into_iter(), PySetInner::difference, vm)
988 }
989
990 #[pymethod]
991 fn symmetric_difference(
992 &self,
993 others: PosArgs<ArgIterable>,
994 vm: &VirtualMachine,
995 ) -> PyResult<Self> {
996 self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm)
997 }
998
999 #[pymethod]
1000 fn issubset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1001 self.inner.issubset(other, vm)
1002 }
1003
1004 #[pymethod]
1005 fn issuperset(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1006 self.inner.issuperset(other, vm)
1007 }
1008
1009 #[pymethod]
1010 fn isdisjoint(&self, other: ArgIterable, vm: &VirtualMachine) -> PyResult<bool> {
1011 self.inner.isdisjoint(other, vm)
1012 }
1013
1014 #[pymethod(name = "__ror__")]
1015 #[pymethod(magic)]
1016 fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1017 if let Ok(set) = AnySet::try_from_object(vm, other) {
1018 Ok(PyArithmeticValue::Implemented(self.op(
1019 set,
1020 PySetInner::union,
1021 vm,
1022 )?))
1023 } else {
1024 Ok(PyArithmeticValue::NotImplemented)
1025 }
1026 }
1027
1028 #[pymethod(name = "__rand__")]
1029 #[pymethod(magic)]
1030 fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1031 if let Ok(other) = AnySet::try_from_object(vm, other) {
1032 Ok(PyArithmeticValue::Implemented(self.op(
1033 other,
1034 PySetInner::intersection,
1035 vm,
1036 )?))
1037 } else {
1038 Ok(PyArithmeticValue::NotImplemented)
1039 }
1040 }
1041
1042 #[pymethod(magic)]
1043 fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1044 if let Ok(other) = AnySet::try_from_object(vm, other) {
1045 Ok(PyArithmeticValue::Implemented(self.op(
1046 other,
1047 PySetInner::difference,
1048 vm,
1049 )?))
1050 } else {
1051 Ok(PyArithmeticValue::NotImplemented)
1052 }
1053 }
1054
1055 #[pymethod(magic)]
1056 fn rsub(
1057 zelf: PyRef<Self>,
1058 other: PyObjectRef,
1059 vm: &VirtualMachine,
1060 ) -> PyResult<PyArithmeticValue<Self>> {
1061 if let Ok(other) = AnySet::try_from_object(vm, other) {
1062 Ok(PyArithmeticValue::Implemented(Self {
1063 inner: other
1064 .as_inner()
1065 .difference(ArgIterable::try_from_object(vm, zelf.into())?, vm)?,
1066 }))
1067 } else {
1068 Ok(PyArithmeticValue::NotImplemented)
1069 }
1070 }
1071
1072 #[pymethod(name = "__rxor__")]
1073 #[pymethod(magic)]
1074 fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1075 if let Ok(other) = AnySet::try_from_object(vm, other) {
1076 Ok(PyArithmeticValue::Implemented(self.op(
1077 other,
1078 PySetInner::symmetric_difference,
1079 vm,
1080 )?))
1081 } else {
1082 Ok(PyArithmeticValue::NotImplemented)
1083 }
1084 }
1085
1086 #[pymethod(magic)]
1087 fn reduce(
1088 zelf: PyRef<Self>,
1089 vm: &VirtualMachine,
1090 ) -> PyResult<(PyTypeRef, PyTupleRef, Option<PyDictRef>)> {
1091 reduce_set(zelf.as_ref(), vm)
1092 }
1093
1094 #[pyclassmethod(magic)]
1095 fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
1096 PyGenericAlias::new(cls, args, vm)
1097 }
1098}
1099
1100impl AsSequence for PyFrozenSet {
1101 fn as_sequence() -> &'static PySequenceMethods {
1102 static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
1103 length: atomic_func!(|seq, _vm| Ok(PyFrozenSet::sequence_downcast(seq).len())),
1104 contains: atomic_func!(|seq, needle, vm| PyFrozenSet::sequence_downcast(seq)
1105 .inner
1106 .contains(needle, vm)),
1107 ..PySequenceMethods::NOT_IMPLEMENTED
1108 });
1109 &AS_SEQUENCE
1110 }
1111}
1112
1113impl Hashable for PyFrozenSet {
1114 #[inline]
1115 fn hash(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
1116 zelf.inner.hash(vm)
1117 }
1118}
1119
1120impl Comparable for PyFrozenSet {
1121 fn cmp(
1122 zelf: &crate::Py<Self>,
1123 other: &PyObject,
1124 op: PyComparisonOp,
1125 vm: &VirtualMachine,
1126 ) -> PyResult<PyComparisonValue> {
1127 extract_set(other).map_or(Ok(PyComparisonValue::NotImplemented), |other| {
1128 Ok(zelf.inner.compare(other, op, vm)?.into())
1129 })
1130 }
1131}
1132
1133impl Iterable for PyFrozenSet {
1134 fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
1135 Ok(zelf.inner.iter().into_pyobject(vm))
1136 }
1137}
1138
1139impl AsNumber for PyFrozenSet {
1140 fn as_number() -> &'static PyNumberMethods {
1141 static AS_NUMBER: PyNumberMethods = PyNumberMethods {
1142 subtract: Some(|a, b, vm| {
1143 if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1144 a.sub(b.to_owned(), vm).to_pyresult(vm)
1145 } else {
1146 Ok(vm.ctx.not_implemented())
1147 }
1148 }),
1149 and: Some(|a, b, vm| {
1150 if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1151 a.and(b.to_owned(), vm).to_pyresult(vm)
1152 } else {
1153 Ok(vm.ctx.not_implemented())
1154 }
1155 }),
1156 xor: Some(|a, b, vm| {
1157 if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1158 a.xor(b.to_owned(), vm).to_pyresult(vm)
1159 } else {
1160 Ok(vm.ctx.not_implemented())
1161 }
1162 }),
1163 or: Some(|a, b, vm| {
1164 if let Some(a) = a.downcast_ref::<PyFrozenSet>() {
1165 a.or(b.to_owned(), vm).to_pyresult(vm)
1166 } else {
1167 Ok(vm.ctx.not_implemented())
1168 }
1169 }),
1170 ..PyNumberMethods::NOT_IMPLEMENTED
1171 };
1172 &AS_NUMBER
1173 }
1174}
1175
1176impl Representable for PyFrozenSet {
1177 #[inline]
1178 fn repr_str(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
1179 let inner = &zelf.inner;
1180 let class = zelf.class();
1181 let class_name = class.name();
1182 let s = if inner.len() == 0 {
1183 format!("{class_name}()")
1184 } else if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) {
1185 inner.repr(Some(&class_name), vm)?
1186 } else {
1187 format!("{class_name}(...)")
1188 };
1189 Ok(s)
1190 }
1191}
1192
/// An object known to be an instance of `set` or `frozenset` (or of a subclass of either).
///
/// Built by its `TryFromObject` impl, which performs that type check.
struct AnySet {
    object: PyObjectRef,
}
1196
1197impl AnySet {
1198 fn into_iterable(self, vm: &VirtualMachine) -> PyResult<ArgIterable> {
1199 self.object.try_into_value(vm)
1200 }
1201
1202 fn into_iterable_iter(
1203 self,
1204 vm: &VirtualMachine,
1205 ) -> PyResult<impl std::iter::Iterator<Item = ArgIterable>> {
1206 Ok(std::iter::once(self.into_iterable(vm)?))
1207 }
1208
1209 fn as_inner(&self) -> &PySetInner {
1210 match_class!(match self.object.as_object() {
1211 ref set @ PySet => &set.inner,
1212 ref frozen @ PyFrozenSet => &frozen.inner,
1213 _ => unreachable!("AnySet is always PySet or PyFrozenSet"), })
1215 }
1216}
1217
1218impl TryFromObject for AnySet {
1219 fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
1220 let class = obj.class();
1221 if class.fast_issubclass(vm.ctx.types.set_type)
1222 || class.fast_issubclass(vm.ctx.types.frozenset_type)
1223 {
1224 Ok(AnySet { object: obj })
1225 } else {
1226 Err(vm.new_type_error(format!("{class} is not a subtype of set or frozenset")))
1227 }
1228 }
1229}
1230
// Iterator over a set's or frozenset's table; exposed to Python as `set_iterator`.
#[pyclass(module = false, name = "set_iterator")]
pub(crate) struct PySetIterator {
    // Table size captured when the iterator was created. `__next__` compares it against the
    // live table to raise "set changed size during iteration".
    size: DictSize,
    // Shared handle to the table plus the current entry position, behind a mutex.
    internal: PyMutex<PositionIterInternal<PyRc<SetContentType>>>,
}
1236
1237impl fmt::Debug for PySetIterator {
1238 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1239 f.write_str("set_iterator")
1241 }
1242}
1243
1244impl PyPayload for PySetIterator {
1245 fn class(ctx: &Context) -> &'static Py<PyType> {
1246 ctx.types.set_iterator_type
1247 }
1248}
1249
1250#[pyclass(with(Unconstructible, IterNext, Iterable))]
1251impl PySetIterator {
1252 #[pymethod(magic)]
1253 fn length_hint(&self) -> usize {
1254 self.internal.lock().length_hint(|_| self.size.entries_size)
1255 }
1256
1257 #[pymethod(magic)]
1258 fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> {
1259 let internal = zelf.internal.lock();
1260 Ok((
1261 builtins_iter(vm).to_owned(),
1262 (vm.ctx
1263 .new_list(match &internal.status {
1264 IterStatus::Exhausted => vec![],
1265 IterStatus::Active(dict) => {
1266 dict.keys().into_iter().skip(internal.position).collect()
1267 }
1268 })
1269 .into(),),
1270 ))
1271 }
1272}
// `set_iterator` objects cannot be created directly from Python; they come from the
// `Iterable` impls of `set`/`frozenset`.
impl Unconstructible for PySetIterator {}
1274
// Marks the iterator as its own iterable, so presumably `iter(it) is it`. This relies on the
// `SelfIter` machinery defined elsewhere.
impl SelfIter for PySetIterator {}
1276impl IterNext for PySetIterator {
1277 fn next(zelf: &crate::Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
1278 let mut internal = zelf.internal.lock();
1279 let next = if let IterStatus::Active(dict) = &internal.status {
1280 if dict.has_changed_size(&zelf.size) {
1281 internal.status = IterStatus::Exhausted;
1282 return Err(vm.new_runtime_error("set changed size during iteration".to_owned()));
1283 }
1284 match dict.next_entry(internal.position) {
1285 Some((position, key, _)) => {
1286 internal.position = position;
1287 PyIterReturn::Return(key)
1288 }
1289 None => {
1290 internal.status = IterStatus::Exhausted;
1291 PyIterReturn::StopIteration(None)
1292 }
1293 }
1294 } else {
1295 PyIterReturn::StopIteration(None)
1296 };
1297 Ok(next)
1298 }
1299}
1300
1301pub fn init(context: &Context) {
1302 PySet::extend_class(context, context.types.set_type);
1303 PyFrozenSet::extend_class(context, context.types.frozenset_type);
1304 PySetIterator::extend_class(context, context.types.set_iterator_type);
1305}