s_tree/
lib.rs

1use std::collections::btree_map::Range;
2use std::collections::BTreeMap;
3use std::iter::Chain;
4use std::ops::Bound::Excluded;
5use std::ops::Bound::Included;
6use std::ops::Bound::Unbounded;
7
/// An ordered key-value tree backed by a [`BTreeMap`].
///
/// For `u64` keys it additionally offers proximity lookups (`find_best_single` /
/// `find_best_sorted`) that search a window of keys around a target key.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct STree<K, V>(BTreeMap<K, V>);
10
11impl<K: Ord, V> STree<K, V> {
12    /// Makes a empty tree
13    pub fn new() -> Self {
14        Self(BTreeMap::new())
15    }
16
17    /// Returns a reference to the value corresponding to the key
18    pub fn get(&self, key: K) -> Option<&V> {
19        self.0.get(&key)
20    }
21
22    /// Inserts a key-value pair into the tree
23    ///
24    /// If the tree did not have this key present, `None` is returned
25    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
26        self.0.insert(key, value)
27    }
28
29    /// Removes a key from the tree, returning the value at the key if the key
30    /// was previously in the tree
31    pub fn remove(&mut self, key: K) -> Option<V> {
32        self.0.remove(&key)
33    }
34}
35
36impl<V> STree<u64, V> {
37    /// Finds the best element by proximity to the target within common bits threshold
38    pub fn find_best_single(&self, target: u64, common_bits: u32) -> Option<(u64, &V)> {
39        let mut elements_iterator = self.find(target, common_bits);
40
41        let mut best_element = match elements_iterator.next() {
42            Some((&k, v)) => (k, v),
43            None => {
44                return None;
45            }
46        };
47        let mut best_element_diff = (best_element.0 as i128 - target as i128).abs();
48
49        for (&k, v) in elements_iterator {
50            let element_diff = (k as i128 - target as i128).abs();
51            if element_diff < best_element_diff {
52                best_element = (k, v);
53                best_element_diff = element_diff;
54            }
55        }
56
57        Some(best_element)
58    }
59
60    /// Finds elements within common bits threshold and returns them sorted by proximity to the
61    /// target, the first one being the closest one
62    pub fn find_best_sorted(&self, target: u64, common_bits: u32) -> Vec<(u64, &V)> {
63        let mut sorted_elements: Vec<(i128, u64, &V)> = self
64            .find(target, common_bits)
65            .map(|(&k, v)| ((k as i128 - target as i128).abs(), k, v))
66            .collect();
67        sorted_elements.sort_by(|(a, _, _), (b, _, _)| a.cmp(b));
68        sorted_elements
69            .into_iter()
70            .map(|(_, k, v)| (k, v))
71            .collect()
72    }
73
74    fn find(&self, target: u64, common_bits: u32) -> Chain<Range<u64, V>, Range<u64, V>> {
75        let deviation = 2u64.pow(64 - common_bits) / 2;
76        let start = target.wrapping_sub(deviation);
77        let end = target.wrapping_add(deviation);
78
79        if start > end {
80            // Wrapping range over the end of the map key space
81            self.0
82                .range((Included(start), Unbounded))
83                .chain(self.0.range((Unbounded, Included(end))))
84        } else {
85            // Wrapping is not necessary here, but having the same type allows things to compile
86            self.0
87                .range((Included(start), Excluded(target)))
88                .chain(self.0.range((Included(target), Included(end))))
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_test() {
        let mut tree = STree::<u64, ()>::new();

        let num_1 = 10u64;
        let num_2 = 15u64;
        let num_3 = u64::MAX - 5;
        let num_4 = 999u64;
        let num_x = u64::MAX - 999;
        let num_y = 2000u64;
        let num_z = 12u64;
        let num_w = 0u64;

        for &num in &[num_1, num_2, num_3, num_4] {
            assert!(tree.insert(num, ()).is_none());
        }
        for &num in &[num_1, num_2, num_3, num_4] {
            assert!(tree.get(num).is_some());
        }

        assert_eq!(tree.find_best_single(num_1, 64 - 1), Some((num_1, &())));
        assert_eq!(tree.find_best_single(num_2, 64 - 1), Some((num_2, &())));
        assert_eq!(tree.find_best_single(num_x, 64 - 1), None);
        assert_eq!(tree.find_best_single(num_x, 64 - 11), Some((num_3, &())));
        assert_eq!(tree.find_best_single(num_y, 64 - 1), None);
        assert_eq!(tree.find_best_single(num_y, 64 - 11), Some((num_4, &())));
        assert_eq!(tree.find_best_single(num_z, 64 - 2), Some((num_1, &())));
        assert_eq!(tree.find_best_single(num_w, 64 - 3), None);
        assert_eq!(tree.find_best_single(num_w, 64 - 4), Some((num_3, &())));

        assert_eq!(tree.find_best_sorted(num_1, 64 - 1).len(), 1);
        assert_eq!(tree.find_best_sorted(num_1, 64 - 4).len(), 2);
        assert_eq!(
            tree.find_best_sorted(num_1, 64 - 4),
            vec![(num_1, &()), (num_2, &())]
        );
        assert_eq!(tree.find_best_sorted(num_1, 64 - 10).len(), 3);
        assert_eq!(
            tree.find_best_sorted(num_1, 64 - 10),
            vec![(num_1, &()), (num_2, &()), (num_3, &())]
        );
    }

    #[test]
    fn full_common_bits_is_exact_match() {
        let mut tree = STree::<u64, ()>::new();
        assert!(tree.insert(10, ()).is_none());

        assert_eq!(tree.find_best_single(10, 64), Some((10u64, &())));
        assert_eq!(tree.find_best_single(11, 64), None);
        assert_eq!(tree.find_best_sorted(10, 64), vec![(10u64, &())]);
    }

    #[test]
    fn insert_replace_and_remove() {
        let mut tree = STree::<u64, &str>::new();

        assert_eq!(tree.insert(1, "a"), None);
        // Re-inserting an existing key replaces the value and hands back the old one.
        assert_eq!(tree.insert(1, "b"), Some("a"));
        assert_eq!(tree.get(1), Some(&"b"));

        assert_eq!(tree.remove(1), Some("b"));
        assert_eq!(tree.remove(1), None);
        assert_eq!(tree.get(1), None);
        assert_eq!(tree.find_best_single(1, 60), None);
    }
}