1use alloc::collections::btree_map::Entry;
4use alloc::collections::BTreeMap;
5#[cfg(feature = "serialize-borsh")]
6use alloc::{format, string::ToString};
7#[cfg(feature = "serialize-borsh")]
8use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
9#[cfg(feature = "serialize-serde")]
10use serde::{Deserialize, Serialize};
11
12use super::calculate_map_and_set_indices;
13use super::macros::*;
14use super::storage;
15use super::IndexSet;
16
#[cfg(feature = "serialize-borsh")]
mod borsh_deserialize {
    use alloc::vec::Vec;

    use super::*;

    /// Borsh deserializer for the `bit_sets` field of `BTreeIndexSet`.
    ///
    /// Reads a sequence of `(key, word)` pairs from `reader` and collects them into a
    /// [`BTreeMap`].
    ///
    /// # Errors
    ///
    /// * Propagates any error produced while decoding the pair sequence.
    /// * Returns an [`ErrorKind::Other`](borsh::io::ErrorKind::Other) error when a key is
    ///   greater than the key that follows it. Equal adjacent keys are not rejected; when
    ///   collected into the map, the later pair replaces the earlier one.
    pub fn from<R, S>(reader: &mut R) -> Result<BTreeMap<usize, S>, borsh::io::Error>
    where
        R: borsh::io::Read,
        S: borsh::de::BorshDeserialize,
    {
        let entries = <Vec<(usize, S)> as borsh::BorshDeserialize>::deserialize_reader(reader)?;

        // `windows(2)` only ever yields slices of length two, so indexing 0 and 1 is safe.
        let out_of_order = entries
            .windows(2)
            .any(|pair| pair[0].0 > pair[1].0);
        if out_of_order {
            return Err(borsh::io::Error::new(
                borsh::io::ErrorKind::Other,
                "BTreeIndexSet should have been sorted",
            ));
        }

        Ok(entries.into_iter().collect())
    }
}
44
/// A set of `usize` indices stored as bit-set words inside a [`BTreeMap`].
///
/// Each map entry pairs a word key with a storage word of type `S` (`u64` by default).
/// As reconstructed by `iter`, bit `b` of the word stored under key `k` stands for the
/// index `k * S::WIDTH + b`. `remove` drops a word from the map once it becomes zero.
///
/// With the `serialize-borsh` feature the map is decoded through
/// `borsh_deserialize::from`, which rejects keys that appear in decreasing order.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(
    feature = "serialize-borsh",
    derive(BorshSerialize, BorshDeserialize, BorshSchema)
)]
#[cfg_attr(feature = "serialize-serde", derive(Serialize, Deserialize))]
#[repr(transparent)]
pub struct BTreeIndexSet<S = u64> {
    /// Word key -> bit-set word holding the membership bits for that key.
    #[cfg_attr(
        feature = "serialize-borsh",
        borsh(deserialize_with = "borsh_deserialize::from")
    )]
    bit_sets: BTreeMap<usize, S>,
}
65
66impl<S> BTreeIndexSet<S> {
67 pub const fn new() -> Self {
69 Self {
70 bit_sets: BTreeMap::new(),
71 }
72 }
73
74 #[inline]
82 pub fn with_capacity(_capacity: usize) -> Self {
83 Self::new()
84 }
85}
86
87impl<S: storage::Storage> IndexSet for BTreeIndexSet<S> {
88 #[inline]
89 fn len(&self) -> usize {
90 self.bit_sets
91 .values()
92 .map(|set| set.num_of_high_bits())
93 .sum::<usize>()
94 }
95
96 #[inline]
97 fn is_empty(&self) -> bool {
98 self.bit_sets.is_empty()
99 }
100
101 fn insert(&mut self, index: usize) {
102 let (map_index, bit_set_index) = calculate_map_and_set_indices::<S>(index);
103 let set = self.bit_sets.entry(map_index).or_insert(S::ZERO);
104 *set |= S::from_usize(1 << bit_set_index);
105 }
106
107 fn remove(&mut self, index: usize) {
108 let (map_index, bit_set_index) = calculate_map_and_set_indices::<S>(index);
109 let entry = self.bit_sets.entry(map_index).and_modify(|set| {
110 *set &= !S::from_usize(1 << bit_set_index);
111 });
112 match entry {
113 Entry::Occupied(e) if *e.get() == S::ZERO => {
114 e.remove();
115 }
116 _ => {}
117 }
118 }
119
120 fn contains(&self, index: usize) -> bool {
121 let (map_index, bit_set_index) = calculate_map_and_set_indices::<S>(index);
122 self.bit_sets
123 .get(&map_index)
124 .map(|&set| set & S::from_usize(1 << bit_set_index) != S::ZERO)
125 .unwrap_or(false)
126 }
127
128 #[inline]
129 fn iter(&self) -> impl Iterator<Item = usize> + '_ {
130 self.bit_sets.iter().flat_map(|(&map_index, &set)| {
131 (0..S::WIDTH).filter_map(move |bit_set_index| {
132 let is_bit_set = (set & S::from_usize(1 << bit_set_index)) != S::ZERO;
133 if is_bit_set {
134 Some(map_index * S::WIDTH + bit_set_index)
135 } else {
136 None
137 }
138 })
139 })
140 }
141
142 #[inline]
143 fn union(&mut self, other: &BTreeIndexSet<S>) {
144 for (&map_index, &other_set) in other.bit_sets.iter() {
145 let set = self.bit_sets.entry(map_index).or_insert(S::ZERO);
146 *set |= other_set;
147 }
148 }
149}
150
// Macro-generated items for `BTreeIndexSet`. The macros are presumably the ones brought in
// by the `use super::macros::*` glob import — confirm against that module.
// NOTE(review): judging by the macro names only, these expand to `From`, `FromIterator`
// and `Extend` implementations plus a shared test suite; verify in the macro definitions.
151index_set_impl_from!(crate::btree::BTreeIndexSet);
152index_set_impl_from_iterator!(crate::btree::BTreeIndexSet);
153index_set_impl_extend!(crate::btree::BTreeIndexSet);
154index_set_tests!(crate::btree::BTreeIndexSet);