1use serde::{Deserialize, Serialize};
24use std::{
25 borrow::Borrow, cmp::max, convert::TryFrom, fmt, hash::{Hash, Hasher}, marker::PhantomData, ops
26};
27use twox_hash::XxHash;
28
29use super::f64_to_usize;
30use crate::traits::{Intersect, IntersectPlusUnionIsPlus, New, UnionAssign};
31
32#[derive(Serialize, Deserialize)]
38#[serde(bound(
39 serialize = "C: Serialize, <C as New>::Config: Serialize",
40 deserialize = "C: Deserialize<'de>, <C as New>::Config: Deserialize<'de>"
41))]
42pub struct CountMinSketch<K: ?Sized, C: New> {
43 counters: Vec<Vec<C>>,
44 offsets: Vec<usize>, mask: usize,
46 k_num: usize,
47 config: <C as New>::Config,
48 marker: PhantomData<fn(K)>,
49}
50
51impl<K: ?Sized, C> CountMinSketch<K, C>
52where
53 K: Hash,
54 C: New + for<'a> UnionAssign<&'a C> + Intersect,
55{
56 pub fn new(probability: f64, tolerance: f64, config: C::Config) -> Self {
58 let width = Self::optimal_width(tolerance);
59 let k_num = Self::optimal_k_num(probability);
60 let counters: Vec<Vec<C>> = (0..k_num)
61 .map(|_| (0..width).map(|_| C::new(&config)).collect())
62 .collect();
63 let offsets = vec![0; k_num];
64 Self {
65 counters,
66 offsets,
67 mask: Self::mask(width),
68 k_num,
69 config,
70 marker: PhantomData,
71 }
72 }
73
74 pub fn push<Q: ?Sized, V: ?Sized>(&mut self, key: &Q, value: &V) -> C
76 where
77 Q: Hash,
78 K: Borrow<Q>,
79 C: for<'a> ops::AddAssign<&'a V> + IntersectPlusUnionIsPlus,
80 {
81 if !<C as IntersectPlusUnionIsPlus>::VAL {
82 let offsets = self.offsets(key);
83 self.offsets
84 .iter_mut()
85 .zip(offsets)
86 .for_each(|(offset, offset_new)| {
87 *offset = offset_new;
88 });
89 let mut lowest = C::intersect(
90 self.offsets
91 .iter()
92 .enumerate()
93 .map(|(k_i, &offset)| &self.counters[k_i][offset]),
94 )
95 .unwrap();
96 lowest += value;
97 self.counters
98 .iter_mut()
99 .zip(self.offsets.iter())
100 .for_each(|(counters, &offset)| {
101 counters[offset].union_assign(&lowest);
102 });
103 lowest
104 } else {
105 let offsets = self.offsets(key);
106 C::intersect(
107 self.counters
108 .iter_mut()
109 .zip(offsets)
110 .map(|(counters, offset)| {
111 counters[offset] += value;
112 &counters[offset]
113 }),
114 )
115 .unwrap()
116 }
117 }
118
119 pub fn union_assign<Q: ?Sized>(&mut self, key: &Q, value: &C)
121 where
122 Q: Hash,
123 K: Borrow<Q>,
124 {
125 let offsets = self.offsets(key);
126 self.counters
127 .iter_mut()
128 .zip(offsets)
129 .for_each(|(counters, offset)| {
130 counters[offset].union_assign(value);
131 })
132 }
133
134 pub fn get<Q: ?Sized>(&self, key: &Q) -> C
136 where
137 Q: Hash,
138 K: Borrow<Q>,
139 {
140 C::intersect(
141 self.counters
142 .iter()
143 .zip(self.offsets(key))
144 .map(|(counters, offset)| &counters[offset]),
145 )
146 .unwrap()
147 }
148
149 pub fn clear(&mut self) {
159 let config = &self.config;
160 self.counters
161 .iter_mut()
162 .flat_map(|x| x.iter_mut())
163 .for_each(|counter| {
164 *counter = C::new(config);
165 })
166 }
167
168 fn optimal_width(tolerance: f64) -> usize {
169 let e = tolerance;
170 let width = f64_to_usize((2.0 / e).round());
171 max(2, width)
172 .checked_next_power_of_two()
173 .expect("Width would be way too large")
174 }
175
176 fn mask(width: usize) -> usize {
177 assert!(width > 1);
178 assert_eq!(width & (width - 1), 0);
179 width - 1
180 }
181
182 fn optimal_k_num(probability: f64) -> usize {
183 max(
184 1,
185 f64_to_usize(((1.0 - probability).ln() / 0.5_f64.ln()).floor()),
186 )
187 }
188
189 fn offsets<Q: ?Sized>(&self, key: &Q) -> impl Iterator<Item = usize>
190 where
191 Q: Hash,
192 K: Borrow<Q>,
193 {
194 let mask = self.mask;
195 hashes(key).map(move |hash| usize::try_from(hash & u64::try_from(mask).unwrap()).unwrap())
196 }
197}
198
199fn hashes<Q: ?Sized>(key: &Q) -> impl Iterator<Item = u64>
200where
201 Q: Hash,
202{
203 #[allow(missing_copy_implementations, missing_debug_implementations)]
204 struct X(XxHash);
205 impl Iterator for X {
206 type Item = u64;
207 fn next(&mut self) -> Option<Self::Item> {
208 let ret = self.0.finish();
209 self.0.write(&[123]);
210 Some(ret)
211 }
212 }
213 let mut hasher = XxHash::default();
214 key.hash(&mut hasher);
215 X(hasher)
216}
217
218impl<K: ?Sized, C: New + Clone> Clone for CountMinSketch<K, C> {
219 fn clone(&self) -> Self {
220 Self {
221 counters: self.counters.clone(),
222 offsets: vec![0; self.offsets.len()],
223 mask: self.mask,
224 k_num: self.k_num,
225 config: self.config.clone(),
226 marker: PhantomData,
227 }
228 }
229}
230impl<K: ?Sized, C: New> fmt::Debug for CountMinSketch<K, C> {
231 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
232 fmt.debug_struct("CountMinSketch")
233 .finish()
235 }
236}
237
#[cfg(test)]
mod tests {
    // Shorthand for sketches over the primitive counter widths used below.
    type CountMinSketch8<K> = super::CountMinSketch<K, u8>;
    type CountMinSketch16<K> = super::CountMinSketch<K, u16>;
    type CountMinSketch64<K> = super::CountMinSketch<K, u64>;

    /// Pushes 300 unit increments into `u8` counters, which cannot hold a
    /// count above 255, and expects a panic.
    ///
    /// NOTE(review): presumably ignored because integer overflow only panics
    /// when overflow checks are enabled (e.g. debug builds); confirm.
    #[ignore]
    #[test]
    #[should_panic]
    fn test_overflow() {
        let mut cms = CountMinSketch8::<&str>::new(0.95, 10.0 / 100.0, ());
        (0..300).for_each(|_| {
            let _ = cms.push("key", &1);
        });
    }

    /// A single key pushed 300 times into `u16` counters reads back as 300.
    #[test]
    fn test_increment() {
        let mut cms = CountMinSketch16::<&str>::new(0.95, 10.0 / 100.0, ());
        (0..300).for_each(|_| {
            let _ = cms.push("key", &1);
        });
        assert_eq!(cms.get("key"), 300);
    }

    /// One million pushes spread evenly over 100 keys (10,000 each): every
    /// key's estimate must be at least 9,000.
    #[test]
    fn test_increment_multi() {
        let mut cms = CountMinSketch64::<u64>::new(0.99, 2.0 / 100.0, ());
        (0..1_000_000).for_each(|i| {
            let _ = cms.push(&(i % 100), &1);
        });
        (0..100).for_each(|key| {
            assert!(cms.get(&key) >= 9_000);
        });
    }
}