1use crate::common::Id;
2use std::time::Instant;
3
4pub trait Bucketable {
6 fn get_id(&self) -> Id;
7
8 fn get_first_seen(&self) -> Instant;
9}
10
11#[derive(Clone)]
13pub struct Buckets<T: Bucketable> {
14 our_id: Id,
15 buckets: Vec<Vec<T>>,
16
17 k: usize,
18}
19
20impl<T: Bucketable> Buckets<T> {
21 pub fn new(our_id: Id, k: usize) -> Buckets<T> {
22 let mut to_ret = Buckets {
23 our_id,
24 buckets: Vec::with_capacity(32),
25 k,
26 };
27
28 to_ret.buckets.push(Vec::new());
29
30 to_ret
31 }
32
33 pub fn add(&mut self, item: T, chump_list: Option<&mut Vec<T>>) {
34 if item.get_id() == self.our_id {
36 return;
37 }
38
39 let dest_bucket_idx = self.get_dest_bucket_idx(&item);
40 self.buckets[dest_bucket_idx].push(item);
41 self.handle_bucket_overflow(dest_bucket_idx, chump_list);
42 }
43
44 pub fn clear(&mut self) {
45 self.buckets.clear();
46 self.buckets.push(Vec::with_capacity(2 * self.k));
47 }
48
49 pub fn contains(&self, id: &Id) -> bool {
50 let dest_bucket_idx = self.get_dest_bucket_idx_for_id(id);
51 if let Some(bucket) = self.buckets.get(dest_bucket_idx) {
52 for item in bucket.iter() {
53 if item.get_id() == *id {
54 return true;
55 }
56 }
57 }
58 false
59 }
60
61 pub fn count(&self) -> usize {
62 let mut count = 0;
63 for bucket in &self.buckets {
64 count += bucket.len();
65 }
66
67 count
68 }
69
70 pub fn count_buckets(&self) -> usize {
71 self.buckets.len()
72 }
73
74 pub fn get_mut(&mut self, id: &Id) -> Option<&mut T> {
75 let dest_bucket_idx = self.get_dest_bucket_idx_for_id(id);
76 if let Some(bucket) = self.buckets.get_mut(dest_bucket_idx) {
77 for item in bucket.iter_mut() {
78 if item.get_id() == *id {
79 return Some(item);
80 }
81 }
82 }
83 None
84 }
85
86 pub fn get_nearest_nodes(&self, id: &Id, exclude: Option<&Id>) -> Vec<&T> {
90 let mut all: Vec<&T> = self
91 .values()
92 .iter()
93 .filter(|item| exclude.is_none() || *exclude.unwrap() != item.get_id())
94 .copied()
95 .collect();
96
97 all.sort_unstable_by(|a, b| {
98 let a_dist = a.get_id().xor(id);
99 let b_dist = b.get_id().xor(id);
100 a_dist.partial_cmp(&b_dist).unwrap()
101 });
102
103 all.truncate(self.k);
104
105 all
106 }
107
108 pub fn retain<F>(&mut self, mut f: F)
109 where
110 F: FnMut(&T) -> bool,
111 {
112 for bucket in &mut self.buckets {
113 bucket.retain(|item| f(item));
114 }
115 }
116
117 pub fn remove(&mut self, id: &Id) -> Option<T> {
118 let dest_bucket_idx = self.get_dest_bucket_idx_for_id(id);
119 if let Some(bucket) = self.buckets.get_mut(dest_bucket_idx) {
120 for i in 0..bucket.len() {
121 if bucket[i].get_id() == *id {
122 return Some(bucket.swap_remove(i));
123 }
124 }
125 }
126 None
127 }
128
129 pub fn set_id(&mut self, new_id: Id) {
130 self.clear();
131 self.our_id = new_id;
132 }
133
134 pub fn values(&self) -> Vec<&T> {
135 let mut to_ret = Vec::new();
136 for bucket in &self.buckets {
137 for item in bucket {
138 to_ret.push(item);
139 }
140 }
141 to_ret
142 }
143
144 fn get_dest_bucket_idx(&self, item: &T) -> usize {
145 self.get_dest_bucket_idx_for_id(&item.get_id())
146 }
147
148 fn get_dest_bucket_idx_for_id(&self, id: &Id) -> usize {
149 std::cmp::min(self.buckets.len() - 1, self.our_id.matching_prefix_bits(id))
150 }
151
152 fn handle_bucket_overflow(
153 &mut self,
154 mut bucket_index: usize,
155 mut chump_list: Option<&mut Vec<T>>,
156 ) {
157 while bucket_index < self.buckets.len() {
158 if self.buckets[bucket_index].len() > self.k {
160 if bucket_index == self.buckets.len() - 1 {
163 self.buckets.push(Vec::with_capacity(2 * self.k));
164 }
165
166 for i in (0..self.buckets[bucket_index].len()).rev() {
168 let ideal_bucket_idx = self.get_dest_bucket_idx(&self.buckets[bucket_index][i]);
169
170 if ideal_bucket_idx != bucket_index {
172 let node = self.buckets[bucket_index].swap_remove(i);
173 self.buckets[ideal_bucket_idx].push(node);
174 }
175 }
176
177 if self.buckets[bucket_index].len() > self.k {
179 self.buckets[bucket_index].sort_unstable_by_key(|a| a.get_first_seen());
180 let mut remainder = self.buckets[bucket_index].split_off(self.k);
181
182 if let Some(chump_list) = &mut chump_list {
183 chump_list.append(&mut remainder);
184 }
185 }
186 }
187 bucket_index += 1
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use std::time::Duration;
    extern crate rand_chacha;

    /// Minimal `Bucketable` used throughout these tests.
    struct TestWrapper {
        id: Id,
        first_seen: Instant,
    }

    impl TestWrapper {
        /// Wraps `id`, stamped with `first_seen` or, when absent, the current instant.
        pub fn new(id: Id, first_seen: Option<Instant>) -> TestWrapper {
            TestWrapper {
                id,
                first_seen: first_seen.unwrap_or_else(Instant::now),
            }
        }
    }

    impl Bucketable for TestWrapper {
        fn get_id(&self) -> Id {
            self.id
        }

        fn get_first_seen(&self) -> Instant {
            self.first_seen
        }
    }

    impl std::fmt::Debug for TestWrapper {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            self.get_id().fmt(f)
        }
    }

    /// Parses a hex id literal; a malformed literal is a bug in the test itself.
    fn parse_id(hex_id: &'static str) -> Id {
        Id::from_hex(hex_id).unwrap()
    }

    /// Adds a node with id `hex_id`, discarding anything evicted by the insert.
    fn add_hex(
        storage: &mut Buckets<TestWrapper>,
        hex_id: &'static str,
        first_seen: Option<Instant>,
    ) {
        storage.add(TestWrapper::new(parse_id(hex_id), first_seen), None);
    }

    #[test]
    fn test_correct_bucket() {
        let id = parse_id("0000000000000000000000000000000000000000");
        let mut storage = Buckets::new(id, 8);

        let mut rng = Box::new(rand_chacha::ChaCha8Rng::seed_from_u64(50));

        for _ in 0..2000 {
            let node_id = Id::from_random(&mut rng);
            storage.add(TestWrapper::new(node_id, None), None);
        }

        let last_bucket = storage.buckets.len() - 1;
        for (i, bucket) in storage.buckets.iter().enumerate() {
            assert!(bucket.len() <= 8);
            for wrapper in bucket {
                let expected = std::cmp::min(
                    storage.our_id.matching_prefix_bits(&wrapper.get_id()),
                    last_bucket,
                );
                assert_eq!(i, expected);
            }
        }
    }

    #[test]
    fn test_add_remove() {
        let mut storage = Buckets::new(parse_id("0000000000000000000000000000000000000000"), 8);

        let their_id = parse_id("0000000000000000000000000000000000000001");
        storage.add(TestWrapper::new(their_id, None), None);

        assert_eq!(storage.count(), 1);
        assert!(storage.get_mut(&their_id).is_some());

        assert!(storage.remove(&their_id).is_some());
        assert!(storage.remove(&their_id).is_none());
        assert_eq!(storage.count(), 0);
    }

    #[test]
    fn test_nothing_in_common() {
        let mut storage = Buckets::new(parse_id("0000000000000000000000000000000000000000"), 8);

        // Give every node a distinct, strictly increasing first-seen time. Consecutive
        // `Instant::now()` calls can tie, and tied entries have no defined order after the
        // unstable sort used for eviction. That could make the eviction assertion flaky.
        let base = Instant::now();
        let seen = |secs: u64| Some(base + Duration::from_secs(secs));

        let full_bucket = [
            "f000000000000000000000000000000000000000",
            "f000000000000000000000000000000000000001",
            "f000000000000000000000000000000000000010",
            "f000000000000000000000000000000000000011",
            "f000000000000000000000000000000000000100",
            "f000000000000000000000000000000000000101",
            "f000000000000000000000000000000000000110",
            "f000000000000000000000000000000000000111",
        ];
        for (secs, hex_id) in (0u64..).zip(full_bucket.iter().copied()) {
            add_hex(&mut storage, hex_id, seen(secs));
        }
        assert_eq!(storage.buckets[0].len(), 8);

        // The newest node loses against a full bucket of older ones and is handed back.
        let newest = parse_id("f000000000000000000000000000000000001000");
        let mut chump_list = Vec::new();
        storage.add(TestWrapper::new(newest, seen(8)), Some(&mut chump_list));
        assert_eq!(storage.buckets[0].len(), 8);
        assert!(storage.get_mut(&newest).is_none());
        assert_eq!(chump_list.len(), 1);
        assert_eq!(chump_list[0].get_id(), newest);
    }

    #[test]
    fn test_get_nearest_nodes() {
        let mut storage = Buckets::new(parse_id("0000000000000000000000000000000000000000"), 8);

        let ids = [
            "0000000000000000000000000000000000000001",
            "0000000000000000000000000000000000000010",
            "0000000000000000000000000000000000000011",
            "0000000000000000000000000000000000000100",
            "0000000000000000000000000000000000000101",
            "0000000000000000000000000000000000000110",
            "0000000000000000000000000000000000000111",
            "0000000000000000000000000000000000001000",
            "0000000000000000000000000000000000001001",
        ];
        for hex_id in ids.iter().copied() {
            add_hex(&mut storage, hex_id, None);
        }

        let nearest = storage
            .get_nearest_nodes(&parse_id("ffffffffffffffffffffffffffffffffffffffff"), None);
        assert_eq!(nearest.len(), 8);
        assert_eq!(
            nearest[0].get_id(),
            parse_id("0000000000000000000000000000000000001001")
        );

        let nearest = storage
            .get_nearest_nodes(&parse_id("0000000000000000000000000000000000000000"), None);
        assert_eq!(nearest.len(), 8);
        assert_eq!(
            nearest[0].get_id(),
            parse_id("0000000000000000000000000000000000000001")
        );
    }

    #[test]
    fn test_get_nearest_nodes2() {
        let mut storage = Buckets::new(parse_id("0000000000000000000000000000000000000000"), 8);

        add_hex(&mut storage, "5fcb695a07ad50be46f100000000000000000000", None);
        add_hex(&mut storage, "00000000000000000000fada4cd3cf6225373cb7", None);

        let nearest = storage
            .get_nearest_nodes(&parse_id("5fcb695a07ad50be46f1fada4cd3cf6225373cb7"), None);
        assert_eq!(nearest.len(), 2);
        assert_eq!(
            nearest[0].get_id(),
            parse_id("5fcb695a07ad50be46f100000000000000000000")
        );

        let nearest = storage
            .get_nearest_nodes(&parse_id("0000000000000000000000000000000000000000"), None);
        assert_eq!(nearest.len(), 2);
        assert_eq!(
            nearest[0].get_id(),
            parse_id("00000000000000000000fada4cd3cf6225373cb7")
        );

        let nearest = storage
            .get_nearest_nodes(&parse_id("ffffffffffffffffffffffffffffffffffffffff"), None);
        assert_eq!(nearest.len(), 2);
        assert_eq!(
            nearest[0].get_id(),
            parse_id("5fcb695a07ad50be46f100000000000000000000")
        );
    }
}