1use std::collections::btree_map::Range;
2use std::collections::BTreeMap;
3use std::iter::Chain;
4use std::ops::Bound::Excluded;
5use std::ops::Bound::Included;
6use std::ops::Bound::Unbounded;
7
/// An ordered key/value container implemented as a newtype over [`BTreeMap`].
///
/// The inner map is private, so all access goes through the methods defined
/// on this type.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct STree<K, V>(BTreeMap<K, V>);
10
11impl<K: Ord, V> STree<K, V> {
12 pub fn new() -> Self {
14 Self(BTreeMap::new())
15 }
16
17 pub fn get(&self, key: K) -> Option<&V> {
19 self.0.get(&key)
20 }
21
22 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
26 self.0.insert(key, value)
27 }
28
29 pub fn remove(&mut self, key: K) -> Option<V> {
32 self.0.remove(&key)
33 }
34}
35
36impl<V> STree<u64, V> {
37 pub fn find_best_single(&self, target: u64, common_bits: u32) -> Option<(u64, &V)> {
39 let mut elements_iterator = self.find(target, common_bits);
40
41 let mut best_element = match elements_iterator.next() {
42 Some((&k, v)) => (k, v),
43 None => {
44 return None;
45 }
46 };
47 let mut best_element_diff = (best_element.0 as i128 - target as i128).abs();
48
49 for (&k, v) in elements_iterator {
50 let element_diff = (k as i128 - target as i128).abs();
51 if element_diff < best_element_diff {
52 best_element = (k, v);
53 best_element_diff = element_diff;
54 }
55 }
56
57 Some(best_element)
58 }
59
60 pub fn find_best_sorted(&self, target: u64, common_bits: u32) -> Vec<(u64, &V)> {
63 let mut sorted_elements: Vec<(i128, u64, &V)> = self
64 .find(target, common_bits)
65 .map(|(&k, v)| ((k as i128 - target as i128).abs(), k, v))
66 .collect();
67 sorted_elements.sort_by(|(a, _, _), (b, _, _)| a.cmp(b));
68 sorted_elements
69 .into_iter()
70 .map(|(_, k, v)| (k, v))
71 .collect()
72 }
73
74 fn find(&self, target: u64, common_bits: u32) -> Chain<Range<u64, V>, Range<u64, V>> {
75 let deviation = 2u64.pow(64 - common_bits) / 2;
76 let start = target.wrapping_sub(deviation);
77 let end = target.wrapping_add(deviation);
78
79 if start > end {
80 self.0
82 .range((Included(start), Unbounded))
83 .chain(self.0.range((Unbounded, Included(end))))
84 } else {
85 self.0
87 .range((Included(start), Excluded(target)))
88 .chain(self.0.range((Included(target), Included(end))))
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insertion, lookup and both proximity queries on a small tree
    /// holding four keys, including one close to `u64::max_value()`.
    #[test]
    fn basic_test() {
        // Keys stored in the tree.
        let low = 10u64;
        let mid = 15u64;
        let near_max = u64::max_value() - 5;
        let high = 999u64;
        // Probe targets that are never inserted.
        let probe_below_max = u64::max_value() - 999;
        let probe_2000 = 2000u64;
        let probe_12 = 12u64;
        let probe_zero = 0u64;

        let stored = [low, mid, near_max, high];

        let mut tree = STree::<u64, ()>::new();
        for &key in stored.iter() {
            assert!(tree.insert(key, ()).is_none());
        }
        for &key in stored.iter() {
            assert!(tree.get(key).is_some());
        }

        // (target, common_bits, key of the expected best match)
        let single_cases: [(u64, u32, Option<u64>); 9] = [
            (low, 64 - 1, Some(low)),
            (mid, 64 - 1, Some(mid)),
            (probe_below_max, 64 - 1, None),
            (probe_below_max, 64 - 11, Some(near_max)),
            (probe_2000, 64 - 1, None),
            (probe_2000, 64 - 11, Some(high)),
            (probe_12, 64 - 2, Some(low)),
            (probe_zero, 64 - 3, None),
            (probe_zero, 64 - 4, Some(near_max)),
        ];
        for &(target, common_bits, expected) in single_cases.iter() {
            assert_eq!(
                tree.find_best_single(target, common_bits),
                expected.map(|key| (key, &()))
            );
        }

        assert_eq!(tree.find_best_sorted(low, 64 - 1).len(), 1);

        let two_nearest = tree.find_best_sorted(low, 64 - 4);
        assert_eq!(two_nearest.len(), 2);
        assert_eq!(two_nearest, vec![(low, &()), (mid, &())]);

        let three_nearest = tree.find_best_sorted(low, 64 - 10);
        assert_eq!(three_nearest.len(), 3);
        assert_eq!(
            three_nearest,
            vec![(low, &()), (mid, &()), (near_max, &())]
        );
    }
}