Skip to main content

scry_index/
set.rs

1//! A sorted set backed by a learned index.
2
3use std::ops::RangeBounds;
4
5use crate::config::Config;
6use crate::error::Result;
7use crate::key::Key;
8use crate::map::{Guard, LearnedMap, MapRef};
9
10/// A sorted set backed by a learned index.
11///
12/// This is a thin wrapper around [`LearnedMap<K, ()>`].
13///
14/// All operations take `&self` and are safe to call from multiple threads.
15#[derive(Debug)]
16pub struct LearnedSet<K: Key> {
17    inner: LearnedMap<K, ()>,
18}
19
20/// A convenience handle that bundles a set reference with an epoch guard.
21pub struct SetRef<'a, K: Key> {
22    inner: MapRef<'a, K, ()>,
23}
24
25impl<K: Key> SetRef<'_, K> {
26    /// Insert a key. Returns `true` if the key was newly inserted.
27    pub fn insert(&self, key: K) -> bool {
28        self.inner.insert(key, ())
29    }
30
31    /// Remove a key. Returns `true` if the key was present.
32    pub fn remove(&self, key: &K) -> bool {
33        self.inner.remove(key)
34    }
35
36    /// Check whether the set contains a key.
37    pub fn contains(&self, key: &K) -> bool {
38        self.inner.contains_key(key)
39    }
40
41    /// Return the approximate number of elements in the set.
42    ///
43    /// See [`LearnedMap::len`](crate::LearnedMap::len) for details on
44    /// relaxed-atomic staleness under concurrency.
45    pub fn len(&self) -> usize {
46        self.inner.len()
47    }
48
49    /// Return `true` if the set is empty.
50    ///
51    /// Subject to the same relaxed-atomic staleness as [`len`](Self::len).
52    pub fn is_empty(&self) -> bool {
53        self.inner.is_empty()
54    }
55
56    /// Return an iterator over keys within the given range, in ascending order.
57    pub fn range<R: RangeBounds<K>>(&self, range: R) -> impl Iterator<Item = &K> {
58        self.inner.range(range).map(|(k, ())| k)
59    }
60
61    /// Return the first (minimum) key.
62    pub fn first(&self) -> Option<&K> {
63        self.inner.first_key_value().map(|(k, ())| k)
64    }
65
66    /// Return the last (maximum) key.
67    pub fn last(&self) -> Option<&K> {
68        self.inner.last_key_value().map(|(k, ())| k)
69    }
70}
71
72impl<K: Key> LearnedSet<K> {
73    /// Create a new empty set.
74    pub fn new() -> Self {
75        Self {
76            inner: LearnedMap::new(),
77        }
78    }
79
80    /// Create a new set with the given configuration.
81    pub fn with_config(config: Config) -> Self {
82        Self {
83            inner: LearnedMap::with_config(config),
84        }
85    }
86
87    /// Create a set from sorted keys, deduplicating any repeated keys.
88    ///
89    /// Keys must be in ascending order but duplicates are allowed and will
90    /// be silently removed (sets are idempotent by definition).
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if `keys` is empty (after dedup) or not sorted.
95    pub fn bulk_load(keys: &[K]) -> Result<Self> {
96        let pairs: Vec<(K, ())> = keys.iter().map(|k| (k.clone(), ())).collect();
97        Ok(Self {
98            inner: LearnedMap::bulk_load_dedup(&pairs)?,
99        })
100    }
101
102    /// Acquire an epoch guard.
103    pub fn guard(&self) -> Guard {
104        self.inner.guard()
105    }
106
107    /// Pin the current epoch and return a [`SetRef`] convenience handle.
108    pub fn pin(&self) -> SetRef<'_, K> {
109        SetRef {
110            inner: self.inner.pin(),
111        }
112    }
113
114    /// Insert a key. Returns `true` if the key was newly inserted.
115    pub fn insert(&self, key: K, guard: &Guard) -> bool {
116        self.inner.insert(key, (), guard)
117    }
118
119    /// Remove a key. Returns `true` if the key was present.
120    pub fn remove(&self, key: &K, guard: &Guard) -> bool {
121        self.inner.remove(key, guard)
122    }
123
124    /// Check whether the set contains a key.
125    pub fn contains(&self, key: &K, guard: &Guard) -> bool {
126        self.inner.contains_key(key, guard)
127    }
128
129    /// Return the approximate number of elements in the set.
130    ///
131    /// See [`LearnedMap::len`](crate::LearnedMap::len) for details on
132    /// relaxed-atomic staleness under concurrency.
133    pub fn len(&self) -> usize {
134        self.inner.len()
135    }
136
137    /// Return `true` if the set is empty.
138    ///
139    /// Subject to the same relaxed-atomic staleness as [`len`](Self::len).
140    pub fn is_empty(&self) -> bool {
141        self.inner.is_empty()
142    }
143
144    /// Return an iterator over keys within the given range, in ascending order.
145    pub fn range<'g, R: RangeBounds<K>>(
146        &self,
147        range: R,
148        guard: &'g Guard,
149    ) -> impl Iterator<Item = &'g K> {
150        self.inner.range(range, guard).map(|(k, ())| k)
151    }
152
153    /// Return the first (minimum) key.
154    pub fn first<'g>(&self, guard: &'g Guard) -> Option<&'g K> {
155        self.inner.first_key_value(guard).map(|(k, ())| k)
156    }
157
158    /// Return the last (maximum) key.
159    pub fn last<'g>(&self, guard: &'g Guard) -> Option<&'g K> {
160        self.inner.last_key_value(guard).map(|(k, ())| k)
161    }
162}
163
#[cfg(feature = "serde")]
impl<K> serde::Serialize for LearnedSet<K>
where
    K: Key + serde::Serialize,
{
    /// Serialize the set as a sequence of its keys, in iteration order.
    ///
    /// All keys are gathered under a single epoch guard before anything is
    /// written. The length announced to the serializer is therefore exactly
    /// the number of elements emitted.
    ///
    /// `len()` is only an approximate count (relaxed atomics) and can disagree
    /// with the iterator under concurrent modification. Announcing it directly
    /// would corrupt length-prefixed output formats.
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::SerializeSeq;

        let guard = self.guard();
        // `len()` is used only as a capacity hint, never as the announced length.
        let mut keys: Vec<&K> = Vec::with_capacity(self.len());
        keys.extend(self.inner.iter(&guard).map(|(k, ())| k));

        let mut seq = serializer.serialize_seq(Some(keys.len()))?;
        for k in keys {
            seq.serialize_element(k)?;
        }
        seq.end()
    }
}
184
#[cfg(feature = "serde")]
impl<'de, K> serde::Deserialize<'de> for LearnedSet<K>
where
    K: Key + serde::Deserialize<'de>,
{
    /// Deserialize a set from a sequence of keys.
    ///
    /// Ascending input, as produced by this type's `Serialize` impl, is
    /// bulk-loaded. Any other ordering is still a valid description of a set.
    /// Rather than failing, that input is inserted key by key, with duplicates
    /// collapsing naturally.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        let keys: Vec<K> = Vec::deserialize(deserializer)?;
        if keys.is_empty() {
            // `bulk_load` rejects empty input, but an empty set is valid.
            return Ok(Self::new());
        }
        match Self::bulk_load(&keys) {
            Ok(set) => Ok(set),
            // Non-empty input, so the documented failure left is "not sorted";
            // fall back to incremental insertion, which accepts any order.
            Err(_) => Ok(keys.into_iter().collect()),
        }
    }
}
200
201impl<K: Key> Default for LearnedSet<K> {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207impl<K: Key> FromIterator<K> for LearnedSet<K> {
208    fn from_iter<I: IntoIterator<Item = K>>(iter: I) -> Self {
209        let set = Self::new();
210        let guard = set.guard();
211        for k in iter {
212            set.insert(k, &guard);
213        }
214        set
215    }
216}
217
218impl<K: Key> Extend<K> for LearnedSet<K> {
219    fn extend<I: IntoIterator<Item = K>>(&mut self, iter: I) {
220        let guard = self.guard();
221        for k in iter {
222            self.insert(k, &guard);
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_set_ops() {
        let set = LearnedSet::new();
        let g = set.guard();
        assert!(set.insert(1u64, &g));
        assert!(set.insert(2, &g));
        assert!(!set.insert(1, &g)); // duplicate
        assert_eq!(set.len(), 2);
        assert!(set.contains(&1, &g));
        assert!(set.remove(&1, &g));
        assert!(!set.contains(&1, &g));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn from_iterator() {
        let set: LearnedSet<u64> = vec![3, 1, 2].into_iter().collect();
        let g = set.guard();
        assert_eq!(set.len(), 3);
        assert!(set.contains(&1, &g));
        assert!(set.contains(&2, &g));
        assert!(set.contains(&3, &g));
    }

    #[test]
    fn from_iterator_collapses_duplicates() {
        let set: LearnedSet<u64> = vec![2, 2, 1, 2].into_iter().collect();
        let g = set.guard();
        assert_eq!(set.len(), 2);
        assert!(set.contains(&1, &g));
        assert!(set.contains(&2, &g));
    }

    #[test]
    fn extend_adds_keys() {
        let mut set: LearnedSet<u64> = LearnedSet::new();
        set.extend(vec![5, 1, 5, 3]);
        let g = set.guard();
        assert_eq!(set.len(), 3);
        for k in [1u64, 3, 5] {
            assert!(set.contains(&k, &g), "key {k} missing after extend");
        }
    }

    #[test]
    fn bulk_load_set() {
        let keys: Vec<u64> = (0..100).collect();
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let g = set.guard();
        assert_eq!(set.len(), 100);
        for k in &keys {
            assert!(set.contains(k, &g));
        }
    }

    #[test]
    fn bulk_load_deduplicates() {
        let keys: Vec<u64> = vec![1, 1, 2, 3, 3, 3, 4, 5];
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let g = set.guard();
        assert_eq!(set.len(), 5);
        for k in 1..=5u64 {
            assert!(set.contains(&k, &g), "key {k} missing after dedup");
        }
    }

    #[test]
    fn bulk_load_rejects_empty() {
        // Documented on `bulk_load`: empty input is an error.
        assert!(LearnedSet::<u64>::bulk_load(&[]).is_err());
    }

    #[test]
    fn empty_set() {
        let set: LearnedSet<u64> = LearnedSet::default();
        let g = set.guard();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert_eq!(set.first(&g), None);
        assert_eq!(set.last(&g), None);
        assert!(!set.contains(&42, &g));
    }

    #[test]
    fn range_first_last() {
        let keys: Vec<u64> = (0..10).collect();
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let g = set.guard();
        let mid: Vec<u64> = set.range(3..6, &g).copied().collect();
        assert_eq!(mid, vec![3, 4, 5]);
        let inclusive: Vec<u64> = set.range(7..=9, &g).copied().collect();
        assert_eq!(inclusive, vec![7, 8, 9]);
        let all: Vec<u64> = set.range(.., &g).copied().collect();
        assert_eq!(all, keys);
        assert_eq!(set.first(&g), Some(&0));
        assert_eq!(set.last(&g), Some(&9));
    }

    #[test]
    fn set_ref_convenience() {
        let set = LearnedSet::new();
        let s = set.pin();
        assert!(s.insert(10u64));
        assert!(s.insert(20));
        assert!(!s.insert(10));
        assert_eq!(s.len(), 2);
        assert!(s.contains(&10));
        assert!(s.remove(&10));
        assert!(!s.contains(&10));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn set_ref_range_first_last() {
        let set = LearnedSet::new();
        let s = set.pin();
        for k in [30u64, 10, 20] {
            assert!(s.insert(k));
        }
        assert!(!s.is_empty());
        let upper: Vec<u64> = s.range(15..=30).copied().collect();
        assert_eq!(upper, vec![20, 30]);
        assert_eq!(s.first(), Some(&10));
        assert_eq!(s.last(), Some(&30));
    }
}