toolbox_rs/
count_min_sketch.rs

1use xxhash_rust::xxh3::xxh3_128_with_seed;
2
3use crate::as_bytes::AsBytes;
4
/// Sketch depth (number of hash rows) for failure probability `delta`, using
/// the standard Count-Min formula `d = ceil(ln(1 / delta))`.
///
/// The result is clamped to at least 1. For `delta >= 1` (or `delta = +inf`)
/// the formula gives zero or a negative value, which `as usize` saturates to 0.
/// That would leave the sketch with no rows at all.
/// Callers are expected to pass `delta` in `(0, 1)`. `delta <= 0` yields an
/// enormous depth.
fn optimal_k(delta: f64) -> usize {
    ((1. / delta).ln().ceil() as usize).max(1)
}
8
/// Sketch width (counters per row) for additive error factor `epsilon`, using
/// the standard Count-Min formula `w = ceil(e / epsilon)`.
///
/// The result is clamped to at least 1. `epsilon = +inf` or NaN both cast to 0,
/// and a zero width would make every bucket computation (`hash % m`) divide by
/// zero.
fn optimal_m(epsilon: f64) -> usize {
    ((std::f64::consts::E / epsilon).ceil() as usize).max(1)
}
12
13fn get_seed() -> u64 {
14    rand::Rng::random::<u64>(&mut rand::rng())
15}
16
/// A Count-Min sketch: a fixed-size, probabilistic frequency table.
///
/// Each inserted key bumps one counter in every row. A key's estimate is the
/// minimum over its counters. It never under-counts (until a counter saturates
/// at `u32::MAX`), but it may over-count because of hash collisions.
#[derive(Debug, Clone)]
pub struct CountMinSketch {
    /// Seed for the 128-bit xxh3 key hash, drawn once when the sketch is built.
    seed: u64,
    /// `k` rows of `m` counters each.
    counter: Vec<Vec<u32>>,
    /// Width: number of counters per row.
    m: usize,
    /// Depth: number of rows, i.e. hash positions touched per key.
    k: usize,
    /// Total number of insertions performed.
    len: usize,
}
24
25impl CountMinSketch {
26    pub fn new(delta: f64, epsilon: f64) -> Self {
27        let seed = get_seed();
28
29        let m = optimal_m(delta);
30        let k = optimal_k(epsilon);
31
32        let counter = vec![vec![0u32; m]; k];
33
34        Self {
35            seed,
36            counter,
37            m,
38            k,
39            len: 0,
40        }
41    }
42}
43
44impl CountMinSketch {
45    fn hash_pair<K: Eq + AsBytes>(&self, key: &K) -> (u64, u64) {
46        let hash = xxh3_128_with_seed(key.as_bytes(), self.seed);
47        (hash as u64, (hash >> 64) as u64)
48    }
49
50    fn get_buckets<K: Eq + AsBytes>(&self, key: &K) -> Vec<usize> {
51        let (hash1, hash2) = self.hash_pair(key);
52        let mut bucket_indices = Vec::with_capacity(self.k);
53        if self.k == 1 {
54            let index = hash1 % self.m as u64;
55            bucket_indices.push(index as usize);
56        } else {
57            (0..self.k as u64).for_each(|i| {
58                let hash = hash1.wrapping_add(i.wrapping_mul(hash2));
59                let index = hash % self.m as u64;
60                bucket_indices.push(index as usize);
61            });
62        }
63        bucket_indices
64    }
65
66    pub fn insert<K: Eq + AsBytes>(&mut self, key: &K) {
67        let indices = self.get_buckets(key);
68        indices.iter().enumerate().for_each(|(k, &b)| {
69            self.counter[k][b] = self.counter[k][b].saturating_add(1);
70        });
71        self.len += 1;
72    }
73
74    pub fn estimate<K: Eq + AsBytes>(&self, key: &K) -> u32 {
75        let indices = self.get_buckets(key);
76        indices
77            .iter()
78            .enumerate()
79            .map(|(k, b)| self.counter[k][*b])
80            .fold(u32::MAX, |a, b| a.min(b))
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::CountMinSketch;

    #[test]
    fn insert_check_1m() {
        // A small epsilon keeps the table wide. That makes it negligibly likely
        // that the absent key "blah" collides with "key" in every row. With a
        // narrow table, the `== 0` check below can fail within the sketch's
        // own error probability, because the seed is random.
        let mut sketch = CountMinSketch::new(0.01, 0.001);

        for _ in 0..1_000_000 {
            sketch.insert(&"key");
        }

        assert_eq!(sketch.estimate(&"key"), 1_000_000);
        assert_eq!(sketch.estimate(&"blah"), 0);
    }

    #[test]
    fn never_underestimates() {
        let mut sketch = CountMinSketch::new(0.01, 0.2);
        let counts = [("alpha", 3u32), ("beta", 5), ("gamma", 1)];

        for &(key, n) in &counts {
            for _ in 0..n {
                sketch.insert(&key);
            }
        }

        for &(key, n) in &counts {
            assert!(
                sketch.estimate(&key) >= n,
                "estimate for {key} is below its true count {n}"
            );
        }
    }

    #[test]
    fn empty_sketch_estimates_zero() {
        let sketch = CountMinSketch::new(0.01, 0.2);
        assert_eq!(sketch.estimate(&"anything"), 0);
    }
}