toolbox_rs/
count_min_sketch.rs1use xxhash_rust::xxh3::xxh3_128_with_seed;
2
3use crate::as_bytes::AsBytes;
4
/// Number of rows (independent hash positions) for a Count-Min sketch whose
/// estimates may exceed their error bound with probability at most `delta`:
/// `ceil(ln(1 / delta))`, but never less than 1.
///
/// `delta` should lie in `(0, 1)`. For `delta >= 1`, a negative `delta`, or NaN,
/// the plain formula yields zero rows. A zero-row sketch records nothing, so the
/// result is clamped to 1. `delta == 0.0` asks for unbounded confidence and
/// saturates to `usize::MAX`.
fn optimal_k(delta: f64) -> usize {
    // Float-to-int `as` saturates: negative/NaN -> 0, +inf -> usize::MAX.
    let rows = (1. / delta).ln().ceil() as usize;
    rows.max(1)
}
8
/// Row width for a Count-Min sketch whose estimates should overshoot the true
/// count by at most `epsilon` times the total number of insertions:
/// `ceil(e / epsilon)`, but never less than 1.
///
/// `epsilon` should be positive and finite. An infinite, negative, or NaN
/// `epsilon` would produce a width of 0, so the result is clamped to 1.
/// `epsilon == 0.0` saturates to `usize::MAX`.
fn optimal_m(epsilon: f64) -> usize {
    let width = (std::f64::consts::E / epsilon).ceil() as usize;
    // A zero-width row would make bucket selection (`hash % m`) divide by zero.
    width.max(1)
}
12
13fn get_seed() -> u64 {
14 rand::Rng::random::<u64>(&mut rand::rng())
15}
16
/// A Count-Min sketch: `k` rows of `m` `u32` counters used to estimate how
/// often each key has been inserted, in space independent of the number of
/// distinct keys.
///
/// Counters are only ever incremented, so hash collisions can only inflate an
/// estimate, never shrink it.
#[derive(Debug, Clone)]
pub struct CountMinSketch {
    /// Seed for the keyed hash; never modified after construction.
    seed: u64,
    /// `k` rows, each holding `m` counters.
    counter: Vec<Vec<u32>>,
    /// Width: number of counters per row.
    m: usize,
    /// Depth: number of rows (one bucket index is derived per row).
    k: usize,
    /// Number of insertions performed so far.
    len: usize,
}
24
25impl CountMinSketch {
26 pub fn new(delta: f64, epsilon: f64) -> Self {
27 let seed = get_seed();
28
29 let m = optimal_m(delta);
30 let k = optimal_k(epsilon);
31
32 let counter = vec![vec![0u32; m]; k];
33
34 Self {
35 seed,
36 counter,
37 m,
38 k,
39 len: 0,
40 }
41 }
42}
43
44impl CountMinSketch {
45 fn hash_pair<K: Eq + AsBytes>(&self, key: &K) -> (u64, u64) {
46 let hash = xxh3_128_with_seed(key.as_bytes(), self.seed);
47 (hash as u64, (hash >> 64) as u64)
48 }
49
50 fn get_buckets<K: Eq + AsBytes>(&self, key: &K) -> Vec<usize> {
51 let (hash1, hash2) = self.hash_pair(key);
52 let mut bucket_indices = Vec::with_capacity(self.k);
53 if self.k == 1 {
54 let index = hash1 % self.m as u64;
55 bucket_indices.push(index as usize);
56 } else {
57 (0..self.k as u64).for_each(|i| {
58 let hash = hash1.wrapping_add(i.wrapping_mul(hash2));
59 let index = hash % self.m as u64;
60 bucket_indices.push(index as usize);
61 });
62 }
63 bucket_indices
64 }
65
66 pub fn insert<K: Eq + AsBytes>(&mut self, key: &K) {
67 let indices = self.get_buckets(key);
68 indices.iter().enumerate().for_each(|(k, &b)| {
69 self.counter[k][b] = self.counter[k][b].saturating_add(1);
70 });
71 self.len += 1;
72 }
73
74 pub fn estimate<K: Eq + AsBytes>(&self, key: &K) -> u32 {
75 let indices = self.get_buckets(key);
76 indices
77 .iter()
78 .enumerate()
79 .map(|(k, b)| self.counter[k][*b])
80 .fold(u32::MAX, |a, b| a.min(b))
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::{optimal_k, optimal_m, CountMinSketch};

    /// The sizing helpers match the Count-Min formulas
    /// `m = ceil(e / epsilon)` and `k = ceil(ln(1 / delta))`.
    #[test]
    fn sizing_helpers() {
        assert_eq!(optimal_m(0.2), 14);
        assert_eq!(optimal_m(0.001), 2719);
        assert_eq!(optimal_k(0.01), 5);
    }

    #[test]
    fn insert_check_1m() {
        // The seed is random, so "blah" reads 0 only if it does not collide with
        // "key" in every row. A small epsilon (wide rows) makes that full
        // collision vanishingly unlikely, keeping the test deterministic in practice.
        let mut sketch = CountMinSketch::new(0.01, 0.001);

        for _ in 0..1_000_000 {
            sketch.insert(&"key");
        }

        assert_eq!(sketch.estimate(&"key"), 1_000_000);
        assert_eq!(sketch.estimate(&"blah"), 0);
    }

    /// Counters only grow, so an estimate is never below the true count,
    /// whatever the seed.
    #[test]
    fn never_underestimates() {
        let mut sketch = CountMinSketch::new(0.01, 0.2);
        let keys = ["alpha", "beta", "gamma", "delta", "epsilon"];

        for (i, key) in keys.iter().enumerate() {
            for _ in 0..=i {
                sketch.insert(key);
            }
        }

        for (i, key) in keys.iter().enumerate() {
            assert!(sketch.estimate(key) >= (i + 1) as u32);
        }
    }
}