rustydht_lib/storage/
buckets.rs

1use crate::common::Id;
2use std::time::Instant;
3
4/// Anything that implements this trait can be stored in Buckets
5pub trait Bucketable {
6    fn get_id(&self) -> Id;
7
8    fn get_first_seen(&self) -> Instant;
9}
10
11/// This data structure implements the bucket system described in [BEP0005](http://bittorrent.org/beps/bep_0005.html) (more or less).
12#[derive(Clone)]
13pub struct Buckets<T: Bucketable> {
14    our_id: Id,
15    buckets: Vec<Vec<T>>,
16
17    k: usize,
18}
19
20impl<T: Bucketable> Buckets<T> {
21    pub fn new(our_id: Id, k: usize) -> Buckets<T> {
22        let mut to_ret = Buckets {
23            our_id,
24            buckets: Vec::with_capacity(32),
25            k,
26        };
27
28        to_ret.buckets.push(Vec::new());
29
30        to_ret
31    }
32
33    pub fn add(&mut self, item: T, chump_list: Option<&mut Vec<T>>) {
34        // Never add our own node!
35        if item.get_id() == self.our_id {
36            return;
37        }
38
39        let dest_bucket_idx = self.get_dest_bucket_idx(&item);
40        self.buckets[dest_bucket_idx].push(item);
41        self.handle_bucket_overflow(dest_bucket_idx, chump_list);
42    }
43
44    pub fn clear(&mut self) {
45        self.buckets.clear();
46        self.buckets.push(Vec::with_capacity(2 * self.k));
47    }
48
49    pub fn contains(&self, id: &Id) -> bool {
50        let dest_bucket_idx = self.get_dest_bucket_idx_for_id(id);
51        if let Some(bucket) = self.buckets.get(dest_bucket_idx) {
52            for item in bucket.iter() {
53                if item.get_id() == *id {
54                    return true;
55                }
56            }
57        }
58        false
59    }
60
61    pub fn count(&self) -> usize {
62        let mut count = 0;
63        for bucket in &self.buckets {
64            count += bucket.len();
65        }
66
67        count
68    }
69
70    pub fn count_buckets(&self) -> usize {
71        self.buckets.len()
72    }
73
74    pub fn get_mut(&mut self, id: &Id) -> Option<&mut T> {
75        let dest_bucket_idx = self.get_dest_bucket_idx_for_id(id);
76        if let Some(bucket) = self.buckets.get_mut(dest_bucket_idx) {
77            for item in bucket.iter_mut() {
78                if item.get_id() == *id {
79                    return Some(item);
80                }
81            }
82        }
83        None
84    }
85
86    /// Get the `k` nearest nodes/items stored in the buckets
87    ///
88    /// The returned vector is sorted by distance, from nearest to farthest.
89    pub fn get_nearest_nodes(&self, id: &Id, exclude: Option<&Id>) -> Vec<&T> {
90        let mut all: Vec<&T> = self
91            .values()
92            .iter()
93            .filter(|item| exclude.is_none() || *exclude.unwrap() != item.get_id())
94            .copied()
95            .collect();
96
97        all.sort_unstable_by(|a, b| {
98            let a_dist = a.get_id().xor(id);
99            let b_dist = b.get_id().xor(id);
100            a_dist.partial_cmp(&b_dist).unwrap()
101        });
102
103        all.truncate(self.k);
104
105        all
106    }
107
108    pub fn retain<F>(&mut self, mut f: F)
109    where
110        F: FnMut(&T) -> bool,
111    {
112        for bucket in &mut self.buckets {
113            bucket.retain(|item| f(item));
114        }
115    }
116
117    pub fn remove(&mut self, id: &Id) -> Option<T> {
118        let dest_bucket_idx = self.get_dest_bucket_idx_for_id(id);
119        if let Some(bucket) = self.buckets.get_mut(dest_bucket_idx) {
120            for i in 0..bucket.len() {
121                if bucket[i].get_id() == *id {
122                    return Some(bucket.swap_remove(i));
123                }
124            }
125        }
126        None
127    }
128
129    pub fn set_id(&mut self, new_id: Id) {
130        self.clear();
131        self.our_id = new_id;
132    }
133
134    pub fn values(&self) -> Vec<&T> {
135        let mut to_ret = Vec::new();
136        for bucket in &self.buckets {
137            for item in bucket {
138                to_ret.push(item);
139            }
140        }
141        to_ret
142    }
143
144    fn get_dest_bucket_idx(&self, item: &T) -> usize {
145        self.get_dest_bucket_idx_for_id(&item.get_id())
146    }
147
148    fn get_dest_bucket_idx_for_id(&self, id: &Id) -> usize {
149        std::cmp::min(self.buckets.len() - 1, self.our_id.matching_prefix_bits(id))
150    }
151
152    fn handle_bucket_overflow(
153        &mut self,
154        mut bucket_index: usize,
155        mut chump_list: Option<&mut Vec<T>>,
156    ) {
157        while bucket_index < self.buckets.len() {
158            // Is the bucket over capacity?
159            if self.buckets[bucket_index].len() > self.k {
160                // Is this the "deepest" bucket?
161                // If so, add a new one since we're over capacity
162                if bucket_index == self.buckets.len() - 1 {
163                    self.buckets.push(Vec::with_capacity(2 * self.k));
164                }
165
166                // (Hopefully) move some nodes out of this bucket into the next one
167                for i in (0..self.buckets[bucket_index].len()).rev() {
168                    let ideal_bucket_idx = self.get_dest_bucket_idx(&self.buckets[bucket_index][i]);
169
170                    // This Node belongs in another bucket. Move it.
171                    if ideal_bucket_idx != bucket_index {
172                        let node = self.buckets[bucket_index].swap_remove(i);
173                        self.buckets[ideal_bucket_idx].push(node);
174                    }
175                }
176
177                // Sort by oldest. Move the newest excess to the chump list
178                if self.buckets[bucket_index].len() > self.k {
179                    self.buckets[bucket_index].sort_unstable_by_key(|a| a.get_first_seen());
180                    let mut remainder = self.buckets[bucket_index].split_off(self.k);
181
182                    if let Some(chump_list) = &mut chump_list {
183                        chump_list.append(&mut remainder);
184                    }
185                }
186            }
187            bucket_index += 1
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    extern crate rand_chacha;

    /// Hex form of the all-zero id used as "our id" in every test.
    const ZERO_ID: &str = "0000000000000000000000000000000000000000";

    /// Smallest possible `Bucketable`: an id plus the instant it was created (or a supplied one).
    struct TestWrapper {
        id: Id,
        first_seen: Instant,
    }

    impl TestWrapper {
        pub fn new(id: Id, first_seen: Option<Instant>) -> TestWrapper {
            TestWrapper {
                id,
                first_seen: first_seen.unwrap_or_else(Instant::now),
            }
        }
    }

    impl Bucketable for TestWrapper {
        fn get_id(&self) -> Id {
            self.id
        }

        fn get_first_seen(&self) -> Instant {
            self.first_seen
        }
    }

    impl std::fmt::Debug for TestWrapper {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            self.get_id().fmt(f)
        }
    }

    /// Parses a hex string into an `Id`, panicking on malformed input.
    fn hex_id(hex: &str) -> Id {
        Id::from_hex(hex).unwrap()
    }

    /// Adds one freshly-created `TestWrapper` per hex id, in order, without a chump list.
    fn add_hex_ids(storage: &mut Buckets<TestWrapper>, hex_ids: &[&str]) {
        for hex in hex_ids {
            storage.add(TestWrapper::new(hex_id(hex), None), None);
        }
    }

    /// Asserts the size of the nearest-nodes result for `target_hex` and which id comes first.
    fn assert_nearest(
        storage: &Buckets<TestWrapper>,
        target_hex: &str,
        expected_len: usize,
        expected_first_hex: &str,
    ) {
        let nearest = storage.get_nearest_nodes(&hex_id(target_hex), None);
        assert_eq!(nearest.len(), expected_len);
        assert_eq!(nearest[0].get_id(), hex_id(expected_first_hex));
    }

    /// Tests that items stay in the correct buckets as the number of buckets grows and that each bucket only contains the correct number of items
    #[test]
    fn test_correct_bucket() {
        let id = hex_id(ZERO_ID);
        let mut storage = Buckets::new(id, 8);

        // Create RNG with static seed to ensure tests are reproducible
        let mut rng = Box::new(rand_chacha::ChaCha8Rng::seed_from_u64(50));

        for _ in 0..2000 {
            let node_id = Id::from_random(&mut rng);
            storage.add(TestWrapper::new(node_id, None), None);
        }

        let deepest = storage.buckets.len() - 1;
        for (i, bucket) in storage.buckets.iter().enumerate() {
            assert!(bucket.len() <= 8);
            for wrapper in bucket {
                assert_eq!(
                    i,
                    std::cmp::min(
                        storage.our_id.matching_prefix_bits(&wrapper.get_id()),
                        deepest
                    )
                );
            }
        }
    }

    /// Tests that we can add and remove an item from the buckets.
    #[test]
    fn test_add_remove() {
        let mut storage = Buckets::new(hex_id(ZERO_ID), 8);

        let their_id = hex_id("0000000000000000000000000000000000000001");
        storage.add(TestWrapper::new(their_id, None), None);

        assert_eq!(storage.count(), 1);
        assert!(storage.get_mut(&their_id).is_some());

        assert!(storage.remove(&their_id).is_some());
        assert!(storage.remove(&their_id).is_none());
        assert_eq!(storage.count(), 0);
    }

    /// Tests that we only store a max of k items that have nothing in common with our id
    #[test]
    fn test_nothing_in_common() {
        let mut storage = Buckets::new(hex_id(ZERO_ID), 8);

        // First 8 should be added
        add_hex_ids(
            &mut storage,
            &[
                "f000000000000000000000000000000000000000",
                "f000000000000000000000000000000000000001",
                "f000000000000000000000000000000000000010",
                "f000000000000000000000000000000000000011",
                "f000000000000000000000000000000000000100",
                "f000000000000000000000000000000000000101",
                "f000000000000000000000000000000000000110",
                "f000000000000000000000000000000000000111",
            ],
        );
        assert_eq!(storage.buckets[0].len(), 8);

        // This one should not be added
        let rejected = "f000000000000000000000000000000000001000";
        add_hex_ids(&mut storage, &[rejected]);
        assert_eq!(storage.buckets[0].len(), 8);
        assert!(storage.get_mut(&hex_id(rejected)).is_none());
    }

    /// Test that get_nearest works
    #[test]
    fn test_get_nearest_nodes() {
        let mut storage = Buckets::new(hex_id(ZERO_ID), 8);

        add_hex_ids(
            &mut storage,
            &[
                "0000000000000000000000000000000000000001",
                "0000000000000000000000000000000000000010",
                "0000000000000000000000000000000000000011",
                "0000000000000000000000000000000000000100",
                "0000000000000000000000000000000000000101",
                "0000000000000000000000000000000000000110",
                "0000000000000000000000000000000000000111",
                "0000000000000000000000000000000000001000",
                "0000000000000000000000000000000000001001",
            ],
        );

        assert_nearest(
            &storage,
            "ffffffffffffffffffffffffffffffffffffffff",
            8,
            "0000000000000000000000000000000000001001",
        );
        assert_nearest(
            &storage,
            ZERO_ID,
            8,
            "0000000000000000000000000000000000000001",
        );
    }

    /// Test that get_nearest orders two ids that each match a different half of the target
    #[test]
    fn test_get_nearest_nodes2() {
        let mut storage = Buckets::new(hex_id(ZERO_ID), 8);

        let high_half = "5fcb695a07ad50be46f100000000000000000000";
        let low_half = "00000000000000000000fada4cd3cf6225373cb7";
        add_hex_ids(&mut storage, &[high_half, low_half]);

        assert_nearest(
            &storage,
            "5fcb695a07ad50be46f1fada4cd3cf6225373cb7",
            2,
            high_half,
        );
        assert_nearest(&storage, ZERO_ID, 2, low_half);
        assert_nearest(
            &storage,
            "ffffffffffffffffffffffffffffffffffffffff",
            2,
            high_half,
        );
    }
}