1use std::ops::RangeBounds;
4
5use crate::config::Config;
6use crate::error::Result;
7use crate::key::Key;
8use crate::map::{Guard, LearnedMap, MapRef};
9
/// A set of keys backed by a [`LearnedMap`] whose values are all `()`.
///
/// Besides insert/remove/contains it offers `range`, `first`, and `last`
/// queries. The guard-based methods take an explicit [`Guard`] obtained from
/// [`LearnedSet::guard`]; [`LearnedSet::pin`] returns a [`SetRef`] whose
/// methods need no guard argument.
#[derive(Debug)]
pub struct LearnedSet<K: Key> {
    // Underlying map; every key is stored with the unit value `()`.
    inner: LearnedMap<K, ()>,
}
19
/// Handle returned by [`LearnedSet::pin`].
///
/// Wraps a [`MapRef`] so set operations can be called without passing a
/// [`Guard`] to each method.
pub struct SetRef<'a, K: Key> {
    // Pinned view of the underlying map; values are always `()`.
    inner: MapRef<'a, K, ()>,
}
24
25impl<K: Key> SetRef<'_, K> {
26 pub fn insert(&self, key: K) -> bool {
28 self.inner.insert(key, ())
29 }
30
31 pub fn remove(&self, key: &K) -> bool {
33 self.inner.remove(key)
34 }
35
36 pub fn contains(&self, key: &K) -> bool {
38 self.inner.contains_key(key)
39 }
40
41 pub fn len(&self) -> usize {
46 self.inner.len()
47 }
48
49 pub fn is_empty(&self) -> bool {
53 self.inner.is_empty()
54 }
55
56 pub fn range<R: RangeBounds<K>>(&self, range: R) -> impl Iterator<Item = &K> {
58 self.inner.range(range).map(|(k, ())| k)
59 }
60
61 pub fn first(&self) -> Option<&K> {
63 self.inner.first_key_value().map(|(k, ())| k)
64 }
65
66 pub fn last(&self) -> Option<&K> {
68 self.inner.last_key_value().map(|(k, ())| k)
69 }
70}
71
72impl<K: Key> LearnedSet<K> {
73 pub fn new() -> Self {
75 Self {
76 inner: LearnedMap::new(),
77 }
78 }
79
80 pub fn with_config(config: Config) -> Self {
82 Self {
83 inner: LearnedMap::with_config(config),
84 }
85 }
86
87 pub fn bulk_load(keys: &[K]) -> Result<Self> {
96 let pairs: Vec<(K, ())> = keys.iter().map(|k| (k.clone(), ())).collect();
97 Ok(Self {
98 inner: LearnedMap::bulk_load_dedup(&pairs)?,
99 })
100 }
101
102 pub fn guard(&self) -> Guard {
104 self.inner.guard()
105 }
106
107 pub fn pin(&self) -> SetRef<'_, K> {
109 SetRef {
110 inner: self.inner.pin(),
111 }
112 }
113
114 pub fn insert(&self, key: K, guard: &Guard) -> bool {
116 self.inner.insert(key, (), guard)
117 }
118
119 pub fn remove(&self, key: &K, guard: &Guard) -> bool {
121 self.inner.remove(key, guard)
122 }
123
124 pub fn contains(&self, key: &K, guard: &Guard) -> bool {
126 self.inner.contains_key(key, guard)
127 }
128
129 pub fn len(&self) -> usize {
134 self.inner.len()
135 }
136
137 pub fn is_empty(&self) -> bool {
141 self.inner.is_empty()
142 }
143
144 pub fn range<'g, R: RangeBounds<K>>(
146 &self,
147 range: R,
148 guard: &'g Guard,
149 ) -> impl Iterator<Item = &'g K> {
150 self.inner.range(range, guard).map(|(k, ())| k)
151 }
152
153 pub fn first<'g>(&self, guard: &'g Guard) -> Option<&'g K> {
155 self.inner.first_key_value(guard).map(|(k, ())| k)
156 }
157
158 pub fn last<'g>(&self, guard: &'g Guard) -> Option<&'g K> {
160 self.inner.last_key_value(guard).map(|(k, ())| k)
161 }
162}
163
#[cfg(feature = "serde")]
impl<K> serde::Serialize for LearnedSet<K>
where
    K: Key + serde::Serialize,
{
    /// Serializes the set as a sequence of its keys, in the map's iteration order.
    ///
    /// The keys are first gathered under a single guard, and the length passed to
    /// `serialize_seq` is the number actually gathered. Reading `len()` separately
    /// from the iteration can disagree with the emitted count, because `insert` and
    /// `remove` take `&self` and the set may change between the two calls. A
    /// declared length that does not match the element count produces malformed
    /// output in formats that write the length up front. The cost is one temporary
    /// `Vec` of key references.
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::SerializeSeq;

        let guard = self.guard();
        let keys: Vec<&K> = self.inner.iter(&guard).map(|(k, ())| k).collect();

        let mut seq = serializer.serialize_seq(Some(keys.len()))?;
        for k in keys {
            seq.serialize_element(k)?;
        }
        seq.end()
    }
}
184
#[cfg(feature = "serde")]
impl<'de, K> serde::Deserialize<'de> for LearnedSet<K>
where
    K: Key + serde::Deserialize<'de>,
{
    /// Reads a sequence of keys and builds a set from them.
    ///
    /// An empty sequence yields `LearnedSet::new()` without calling `bulk_load`.
    /// Otherwise the keys go through `bulk_load`, and any error it returns is
    /// reported as a custom deserializer error. Only keys are read, so a custom
    /// `Config` is not restored.
    // NOTE(review): whether `bulk_load` accepts keys in arbitrary order is not
    // visible here; confirm before relying on unsorted input sequences.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        let keys = <Vec<K> as serde::Deserialize<'de>>::deserialize(deserializer)?;
        if !keys.is_empty() {
            Self::bulk_load(&keys).map_err(serde::de::Error::custom)
        } else {
            Ok(Self::new())
        }
    }
}
200
201impl<K: Key> Default for LearnedSet<K> {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl<K: Key> FromIterator<K> for LearnedSet<K> {
208 fn from_iter<I: IntoIterator<Item = K>>(iter: I) -> Self {
209 let set = Self::new();
210 let guard = set.guard();
211 for k in iter {
212 set.insert(k, &guard);
213 }
214 set
215 }
216}
217
218impl<K: Key> Extend<K> for LearnedSet<K> {
219 fn extend<I: IntoIterator<Item = K>>(&mut self, iter: I) {
220 let guard = self.guard();
221 for k in iter {
222 self.insert(k, &guard);
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_set_ops() {
        let set = LearnedSet::new();
        let guard = set.guard();

        // Fresh keys report `true`; re-inserting an existing key reports `false`.
        assert!(set.insert(1u64, &guard));
        assert!(set.insert(2, &guard));
        assert!(!set.insert(1, &guard));
        assert_eq!(set.len(), 2);

        // Removing a present key succeeds and makes it invisible to `contains`.
        assert!(set.contains(&1, &guard));
        assert!(set.remove(&1, &guard));
        assert!(!set.contains(&1, &guard));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn from_iterator() {
        let set: LearnedSet<u64> = vec![3, 1, 2].into_iter().collect();
        let guard = set.guard();
        assert_eq!(set.len(), 3);
        for key in 1..=3u64 {
            assert!(set.contains(&key, &guard));
        }
    }

    #[test]
    fn bulk_load_set() {
        let keys: Vec<u64> = (0..100).collect();
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let guard = set.guard();
        assert_eq!(set.len(), keys.len());
        assert!(keys.iter().all(|key| set.contains(key, &guard)));
    }

    #[test]
    fn bulk_load_deduplicates() {
        let keys = vec![1u64, 1, 2, 3, 3, 3, 4, 5];
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let guard = set.guard();
        assert_eq!(set.len(), 5);
        for k in 1..=5u64 {
            assert!(set.contains(&k, &guard), "key {k} missing after dedup");
        }
    }

    #[test]
    fn set_ref_convenience() {
        let set = LearnedSet::new();
        let pinned = set.pin();

        assert!(pinned.insert(10u64));
        assert!(pinned.insert(20));
        assert!(!pinned.insert(10));
        assert_eq!(pinned.len(), 2);

        assert!(pinned.contains(&10));
        assert!(pinned.remove(&10));
        assert!(!pinned.contains(&10));
        assert_eq!(pinned.len(), 1);
    }
}