// sklears_utils/probabilistic.rs

1//! Probabilistic data structures for efficient approximate algorithms
2//!
3//! This module provides memory-efficient probabilistic data structures commonly used
4//! in machine learning and big data applications for approximate computations.
5
6use scirs2_core::random::rngs::StdRng;
7use scirs2_core::random::{Rng, SeedableRng};
8use std::collections::hash_map::DefaultHasher;
9use std::f64::consts::LN_2;
10use std::hash::{Hash, Hasher};
11
/// Bloom filter for membership testing with false positives
///
/// A space-efficient probabilistic set: `contains` never yields a false
/// negative, but may yield a false positive with a tunable probability.
pub struct BloomFilter {
    /// Bit array; `true` means at least one inserted item hashed to this position.
    bit_array: Vec<bool>,
    /// Number of bits in `bit_array`.
    size: usize,
    /// Number of hash functions applied per item.
    hash_functions: usize,
    /// Number of `insert` calls so far (duplicates counted).
    inserted_count: usize,
}

impl BloomFilter {
    /// Create a new Bloom filter with optimal parameters for given capacity and false positive rate
    ///
    /// Both derived parameters are clamped to at least 1, so degenerate inputs
    /// (e.g. `capacity == 0`) can no longer produce a zero-sized bit array that
    /// would panic with a remainder-by-zero on the first `insert`.
    pub fn new(capacity: usize, false_positive_rate: f64) -> Self {
        let size = Self::optimal_size(capacity, false_positive_rate);
        let hash_functions = Self::optimal_hash_functions(size, capacity);
        Self::with_parameters(size, hash_functions)
    }

    /// Create a new Bloom filter with explicit parameters
    ///
    /// # Panics
    /// Panics if `size` or `hash_functions` is zero, since such a filter could
    /// never store or test membership correctly.
    pub fn with_parameters(size: usize, hash_functions: usize) -> Self {
        assert!(size > 0, "Bloom filter size must be positive");
        assert!(
            hash_functions > 0,
            "Bloom filter needs at least one hash function"
        );
        Self {
            bit_array: vec![false; size],
            size,
            hash_functions,
            inserted_count: 0,
        }
    }

    /// Optimal bit count: m = ceil(-n * ln(p) / (ln 2)^2), clamped to >= 1.
    fn optimal_size(capacity: usize, false_positive_rate: f64) -> usize {
        let ln2_sq = LN_2 * LN_2;
        let size = (-(capacity as f64) * false_positive_rate.ln() / ln2_sq).ceil() as usize;
        size.max(1)
    }

    /// Optimal hash count: k = ceil((m / n) * ln 2), clamped to >= 1.
    fn optimal_hash_functions(size: usize, capacity: usize) -> usize {
        if capacity == 0 {
            // Avoid 0/0 -> NaN -> 0 hash functions for an empty capacity.
            return 1;
        }
        (((size as f64 / capacity as f64) * LN_2).ceil() as usize).max(1)
    }

    /// Compute the `hash_functions` bit positions for `item` by seeding a
    /// fresh hasher with the function index (cheap double hashing).
    fn hash_values<T: Hash>(&self, item: &T) -> Vec<usize> {
        (0..self.hash_functions)
            .map(|i| {
                let mut hasher = DefaultHasher::new();
                item.hash(&mut hasher);
                i.hash(&mut hasher);
                (hasher.finish() as usize) % self.size
            })
            .collect()
    }

    /// Insert an item into the filter
    pub fn insert<T: Hash>(&mut self, item: &T) {
        for hash in self.hash_values(item) {
            self.bit_array[hash] = true;
        }
        self.inserted_count += 1;
    }

    /// Test if an item might be in the filter (no false negatives, possible false positives)
    pub fn contains<T: Hash>(&self, item: &T) -> bool {
        self.hash_values(item)
            .iter()
            .all(|&hash| self.bit_array[hash])
    }

    /// Estimate the current false positive probability from the observed
    /// fraction of set bits: (bits_set / m)^k.
    pub fn false_positive_probability(&self) -> f64 {
        let ratio = self.bits_set() as f64 / self.size as f64;
        ratio.powf(self.hash_functions as f64)
    }

    /// Get the number of items inserted (duplicates counted)
    pub fn len(&self) -> usize {
        self.inserted_count
    }

    /// Check if the filter is empty
    pub fn is_empty(&self) -> bool {
        self.inserted_count == 0
    }

    /// Clear the filter
    pub fn clear(&mut self) {
        self.bit_array.fill(false);
        self.inserted_count = 0;
    }

    /// Count the set bits (shared by the probability estimate and `stats`).
    fn bits_set(&self) -> usize {
        self.bit_array.iter().filter(|&&bit| bit).count()
    }

    /// Get filter statistics
    pub fn stats(&self) -> BloomFilterStats {
        let bits_set = self.bits_set();
        BloomFilterStats {
            size: self.size,
            hash_functions: self.hash_functions,
            inserted_count: self.inserted_count,
            bits_set,
            load_factor: bits_set as f64 / self.size as f64,
            false_positive_probability: self.false_positive_probability(),
        }
    }
}

/// Snapshot of a [`BloomFilter`]'s internal state.
#[derive(Debug, Clone)]
pub struct BloomFilterStats {
    /// Number of bits in the filter.
    pub size: usize,
    /// Number of hash functions per item.
    pub hash_functions: usize,
    /// Number of `insert` calls.
    pub inserted_count: usize,
    /// Number of bits currently set.
    pub bits_set: usize,
    /// Fraction of bits set (`bits_set / size`).
    pub load_factor: f64,
    /// Estimated false positive probability at the current load.
    pub false_positive_probability: f64,
}
127
/// Count-Min Sketch for frequency estimation
///
/// Estimates item frequencies with one-sided error: `estimate` never
/// under-counts, but may over-count due to hash collisions.
pub struct CountMinSketch {
    /// `depth` rows of `width` saturating counters.
    counts: Vec<Vec<u32>>,
    /// Counters per row.
    width: usize,
    /// Number of rows (independent hash functions).
    depth: usize,
    /// Sum of all added counts.
    total_count: u64,
}

impl CountMinSketch {
    /// Create a new Count-Min Sketch with specified dimensions
    ///
    /// # Panics
    /// Panics if `width` or `depth` is zero — such a sketch could record
    /// nothing and would panic with a remainder-by-zero on the first `add`.
    pub fn new(width: usize, depth: usize) -> Self {
        assert!(width > 0, "Count-Min Sketch width must be positive");
        assert!(depth > 0, "Count-Min Sketch depth must be positive");
        Self {
            counts: vec![vec![0; width]; depth],
            width,
            depth,
            total_count: 0,
        }
    }

    /// Create a Count-Min Sketch with optimal parameters for given error bounds
    ///
    /// `epsilon` bounds the additive over-count (as a fraction of the total
    /// count) and `delta` bounds the probability of exceeding it. Derived
    /// dimensions are clamped to at least 1 so loose bounds (e.g. `delta >= 1`
    /// or `epsilon > e`) cannot yield a zero-sized sketch.
    pub fn with_bounds(epsilon: f64, delta: f64) -> Self {
        let width = ((std::f64::consts::E / epsilon).ceil() as usize).max(1);
        let depth = ((1.0 / delta).ln().ceil() as usize).max(1);
        Self::new(width, depth)
    }

    /// Compute one column index per row for `item`, seeding each hasher with
    /// the row index so the rows behave as independent hash functions.
    fn hash_values<T: Hash>(&self, item: &T) -> Vec<usize> {
        (0..self.depth)
            .map(|i| {
                let mut hasher = DefaultHasher::new();
                item.hash(&mut hasher);
                i.hash(&mut hasher);
                (hasher.finish() as usize) % self.width
            })
            .collect()
    }

    /// Add count occurrences of an item
    ///
    /// Per-cell counters saturate at `u32::MAX` instead of overflowing.
    pub fn add<T: Hash>(&mut self, item: &T, count: u32) {
        let hashes = self.hash_values(item);
        for (i, &hash) in hashes.iter().enumerate() {
            self.counts[i][hash] = self.counts[i][hash].saturating_add(count);
        }
        self.total_count += count as u64;
    }

    /// Increment the count of an item by 1
    pub fn increment<T: Hash>(&mut self, item: &T) {
        self.add(item, 1);
    }

    /// Estimate the frequency of an item (never an under-estimate)
    pub fn estimate<T: Hash>(&self, item: &T) -> u32 {
        // The minimum across rows is the least-collided counter.
        self.hash_values(item)
            .iter()
            .enumerate()
            .map(|(i, &hash)| self.counts[i][hash])
            .min()
            .unwrap_or(0)
    }

    /// Get the total count of all items
    pub fn total_count(&self) -> u64 {
        self.total_count
    }

    /// Clear the sketch
    pub fn clear(&mut self) {
        for row in &mut self.counts {
            row.fill(0);
        }
        self.total_count = 0;
    }

    /// Get sketch statistics
    pub fn stats(&self) -> CountMinSketchStats {
        let max_count = self
            .counts
            .iter()
            .flat_map(|row| row.iter())
            .max()
            .copied()
            .unwrap_or(0);

        // `new` guarantees width * depth > 0, but keep the guard for safety.
        let avg_count = if self.width * self.depth > 0 {
            self.total_count as f64 / (self.width * self.depth) as f64
        } else {
            0.0
        };

        CountMinSketchStats {
            width: self.width,
            depth: self.depth,
            total_count: self.total_count,
            max_count,
            avg_count,
        }
    }
}

/// Snapshot of a [`CountMinSketch`]'s dimensions and counters.
#[derive(Debug, Clone)]
pub struct CountMinSketchStats {
    /// Counters per row.
    pub width: usize,
    /// Number of rows.
    pub depth: usize,
    /// Sum of all added counts.
    pub total_count: u64,
    /// Largest single counter value.
    pub max_count: u32,
    /// Mean count per cell.
    pub avg_count: f64,
}
239
/// HyperLogLog for cardinality estimation
///
/// Estimates the number of distinct items using `2^precision` one-byte
/// buckets, with a relative error of roughly `1.04 / sqrt(bucket_count)`.
pub struct HyperLogLog {
    /// Per-bucket maximum "rank" (leading-zero count + 1) observed.
    buckets: Vec<u8>,
    /// Number of buckets; always a power of two (`1 << precision`).
    bucket_count: usize,
    /// Bias-correction constant for this bucket count.
    alpha: f64,
}

impl HyperLogLog {
    /// Create a new HyperLogLog with the specified precision (4-16)
    ///
    /// # Panics
    /// Panics if `precision` is outside `4..=16`.
    pub fn new(precision: u8) -> Self {
        assert!(
            (4..=16).contains(&precision),
            "Precision must be between 4 and 16"
        );

        let bucket_count = 1 << precision;
        let alpha = Self::calculate_alpha(bucket_count);

        Self {
            buckets: vec![0; bucket_count],
            bucket_count,
            alpha,
        }
    }

    /// Standard bias-correction constants from the HyperLogLog paper.
    fn calculate_alpha(bucket_count: usize) -> f64 {
        match bucket_count {
            16 => 0.673,
            32 => 0.697,
            64 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / bucket_count as f64),
        }
    }

    /// Hash an item to a 64-bit value.
    fn hash_value<T: Hash>(&self, item: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        item.hash(&mut hasher);
        hasher.finish()
    }

    /// Leading zero bits of `value` (64 for zero). Delegates to the
    /// hardware-backed intrinsic instead of a manual shift loop.
    fn leading_zeros(value: u64) -> u8 {
        value.leading_zeros() as u8
    }

    /// Add an item to the HyperLogLog
    pub fn add<T: Hash>(&mut self, item: &T) {
        let hash = self.hash_value(item);
        // bucket_count is a power of two, so trailing_zeros() recovers the
        // precision exactly (no float log2 round-trip needed).
        let precision = self.bucket_count.trailing_zeros() as u8;
        let bucket_bits = 64 - precision;
        // Top `precision` bits select the bucket ...
        let bucket = (hash >> bucket_bits) as usize;
        // ... and the rank is the leading-zero count of the remaining bits + 1.
        let rank = Self::leading_zeros(hash << precision) + 1;

        if rank > self.buckets[bucket] {
            self.buckets[bucket] = rank;
        }
    }

    /// Estimate the cardinality
    ///
    /// Applies the standard range corrections: linear counting for small
    /// estimates while empty buckets remain, the raw estimate for the
    /// intermediate range, and a logarithmic correction near the 32-bit
    /// hash-space limit.
    pub fn cardinality(&self) -> f64 {
        let m = self.bucket_count as f64;
        let sum: f64 = self
            .buckets
            .iter()
            .map(|&bucket| 2.0_f64.powf(-(bucket as f64)))
            .sum();

        let raw_estimate = self.alpha * m * m / sum;
        let two_pow_32 = (1u64 << 32) as f64;

        if raw_estimate <= 2.5 * m {
            // Small range: linear counting while any bucket is still empty.
            let zero_buckets = self.buckets.iter().filter(|&&bucket| bucket == 0).count();
            if zero_buckets != 0 {
                return m * (m / zero_buckets as f64).ln();
            }
            // No empty buckets: the raw estimate stands. (Previously this case
            // fell through to the large-range correction by mistake.)
            return raw_estimate;
        }

        if raw_estimate <= two_pow_32 / 30.0 {
            // Intermediate range - no correction.
            return raw_estimate;
        }

        // Large range correction for hash saturation near 2^32.
        -two_pow_32 * (1.0 - raw_estimate / two_pow_32).ln()
    }

    /// Merge another HyperLogLog into this one
    ///
    /// Equivalent to having added both item streams to a single sketch.
    ///
    /// # Panics
    /// Panics if the two sketches were built with different precisions.
    pub fn merge(&mut self, other: &HyperLogLog) {
        assert_eq!(
            self.bucket_count, other.bucket_count,
            "Cannot merge HyperLogLogs with different precisions"
        );

        for i in 0..self.bucket_count {
            self.buckets[i] = self.buckets[i].max(other.buckets[i]);
        }
    }

    /// Clear the HyperLogLog
    pub fn clear(&mut self) {
        self.buckets.fill(0);
    }

    /// Get HyperLogLog statistics
    pub fn stats(&self) -> HyperLogLogStats {
        let max_bucket = *self.buckets.iter().max().unwrap_or(&0);
        let zero_buckets = self.buckets.iter().filter(|&&bucket| bucket == 0).count();
        let avg_bucket =
            self.buckets.iter().map(|&b| b as f64).sum::<f64>() / self.bucket_count as f64;

        HyperLogLogStats {
            bucket_count: self.bucket_count,
            cardinality: self.cardinality(),
            max_bucket,
            zero_buckets,
            avg_bucket,
        }
    }
}

/// Snapshot of a [`HyperLogLog`]'s buckets and current estimate.
#[derive(Debug, Clone)]
pub struct HyperLogLogStats {
    /// Number of buckets (`1 << precision`).
    pub bucket_count: usize,
    /// Current cardinality estimate.
    pub cardinality: f64,
    /// Largest rank observed in any bucket.
    pub max_bucket: u8,
    /// Buckets still at zero (untouched).
    pub zero_buckets: usize,
    /// Mean bucket rank.
    pub avg_bucket: f64,
}
374
/// MinHash signature for estimating the Jaccard similarity between sets.
pub struct MinHash {
    hashes: Vec<u64>,
    hash_functions: usize,
}

impl MinHash {
    /// Create a new MinHash with specified number of hash functions
    pub fn new(hash_functions: usize) -> Self {
        Self {
            hashes: vec![u64::MAX; hash_functions],
            hash_functions,
        }
    }

    /// Produce one seeded 64-bit hash of `item` per hash function.
    fn hash_values<T: Hash>(&self, item: &T) -> Vec<u64> {
        (0..self.hash_functions)
            .map(|seed| {
                let mut hasher = DefaultHasher::new();
                item.hash(&mut hasher);
                seed.hash(&mut hasher);
                hasher.finish()
            })
            .collect()
    }

    /// Add an item to the MinHash
    pub fn add<T: Hash>(&mut self, item: &T) {
        let item_hashes = self.hash_values(item);
        // Each slot keeps the minimum hash seen for its hash function.
        for (slot, hash) in self.hashes.iter_mut().zip(item_hashes) {
            if hash < *slot {
                *slot = hash;
            }
        }
    }

    /// Estimate Jaccard similarity with another MinHash
    ///
    /// # Panics
    /// Panics if the two signatures use different numbers of hash functions.
    pub fn jaccard_similarity(&self, other: &MinHash) -> f64 {
        assert_eq!(
            self.hash_functions, other.hash_functions,
            "MinHash objects must have the same number of hash functions"
        );

        // The fraction of matching minimum hashes approximates the Jaccard index.
        let matching = self
            .hashes
            .iter()
            .zip(&other.hashes)
            .filter(|(a, b)| a == b)
            .count();

        matching as f64 / self.hash_functions as f64
    }

    /// Reset the signature so no items are recorded.
    pub fn clear(&mut self) {
        self.hashes.fill(u64::MAX);
    }

    /// Get MinHash statistics
    pub fn stats(&self) -> MinHashStats {
        let initialized = self
            .hashes
            .iter()
            .filter(|&&slot| slot != u64::MAX)
            .count();

        MinHashStats {
            hash_functions: self.hash_functions,
            initialized_hashes: initialized,
            completion_ratio: initialized as f64 / self.hash_functions as f64,
        }
    }
}

/// Summary of a [`MinHash`] signature's state.
#[derive(Debug, Clone)]
pub struct MinHashStats {
    /// Total hash functions in the signature.
    pub hash_functions: usize,
    /// Slots that have moved off their initial `u64::MAX` value.
    pub initialized_hashes: usize,
    /// `initialized_hashes / hash_functions`.
    pub completion_ratio: f64,
}
454
455/// Locality-Sensitive Hashing for approximate nearest neighbor search
456pub struct LSHash {
457    hash_tables: Vec<Vec<Vec<usize>>>,
458    projections: Vec<Vec<f64>>,
459    table_count: usize,
460    dimension: usize,
461    bucket_width: f64,
462}
463
464impl LSHash {
465    /// Create a new LSH with specified parameters
466    pub fn new(dimension: usize, table_count: usize, bucket_width: f64) -> Self {
467        let mut projections = Vec::with_capacity(table_count);
468        let mut rng = StdRng::seed_from_u64(42);
469
470        for _ in 0..table_count {
471            let mut projection = Vec::with_capacity(dimension);
472            for _ in 0..dimension {
473                projection.push(rng.gen::<f64>() * 2.0 - 1.0); // Random values between -1 and 1
474            }
475            projections.push(projection);
476        }
477
478        Self {
479            hash_tables: vec![Vec::new(); table_count],
480            projections,
481            table_count,
482            dimension,
483            bucket_width,
484        }
485    }
486
487    fn hash_vector(&self, vector: &[f64], table_idx: usize) -> i32 {
488        let dot_product: f64 = vector
489            .iter()
490            .zip(self.projections[table_idx].iter())
491            .map(|(&v, &p)| v * p)
492            .sum();
493
494        (dot_product / self.bucket_width).floor() as i32
495    }
496
497    /// Add a vector with associated data index
498    pub fn add(&mut self, vector: &[f64], data_idx: usize) {
499        assert_eq!(
500            vector.len(),
501            self.dimension,
502            "Vector dimension must match LSH dimension"
503        );
504
505        for table_idx in 0..self.table_count {
506            let hash = self.hash_vector(vector, table_idx);
507
508            // Ensure the hash is non-negative and resize table if needed
509            if hash >= 0 {
510                let bucket_idx = hash as usize;
511                // Resize table if needed
512                if self.hash_tables[table_idx].len() <= bucket_idx {
513                    self.hash_tables[table_idx].resize(bucket_idx + 1, Vec::new());
514                }
515                self.hash_tables[table_idx][bucket_idx].push(data_idx);
516            }
517        }
518    }
519
520    /// Query for approximate nearest neighbors
521    pub fn query(&self, vector: &[f64]) -> Vec<usize> {
522        assert_eq!(
523            vector.len(),
524            self.dimension,
525            "Vector dimension must match LSH dimension"
526        );
527
528        let mut candidates = std::collections::HashSet::new();
529
530        for table_idx in 0..self.table_count {
531            let hash = self.hash_vector(vector, table_idx);
532
533            if hash >= 0 && (hash as usize) < self.hash_tables[table_idx].len() {
534                for &candidate in &self.hash_tables[table_idx][hash as usize] {
535                    candidates.insert(candidate);
536                }
537            }
538        }
539
540        candidates.into_iter().collect()
541    }
542
543    /// Clear all hash tables
544    pub fn clear(&mut self) {
545        for table in &mut self.hash_tables {
546            table.clear();
547        }
548    }
549
550    /// Get LSH statistics
551    pub fn stats(&self) -> LSHashStats {
552        let total_entries: usize = self
553            .hash_tables
554            .iter()
555            .flat_map(|table| table.iter())
556            .map(|bucket| bucket.len())
557            .sum();
558
559        let non_empty_buckets: usize = self
560            .hash_tables
561            .iter()
562            .flat_map(|table| table.iter())
563            .filter(|bucket| !bucket.is_empty())
564            .count();
565
566        let total_buckets: usize = self.hash_tables.iter().map(|table| table.len()).sum();
567
568        LSHashStats {
569            table_count: self.table_count,
570            dimension: self.dimension,
571            bucket_width: self.bucket_width,
572            total_entries,
573            total_buckets,
574            non_empty_buckets,
575            load_factor: if total_buckets > 0 {
576                non_empty_buckets as f64 / total_buckets as f64
577            } else {
578                0.0
579            },
580        }
581    }
582}
583
584#[derive(Debug, Clone)]
585pub struct LSHashStats {
586    pub table_count: usize,
587    pub dimension: usize,
588    pub bucket_width: f64,
589    pub total_entries: usize,
590    pub total_buckets: usize,
591    pub non_empty_buckets: usize,
592    pub load_factor: f64,
593}
594
// Note: the previous `#[allow(non_snake_case)]` was removed — every item in
// this module is already snake_case, so the allow only masked future lint hits.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Inserted items must always be found; an absent item should (with high
    /// probability at this sizing) not be.
    #[test]
    fn test_bloom_filter() {
        let mut filter = BloomFilter::new(1000, 0.01);

        // Insert some items
        filter.insert(&"hello");
        filter.insert(&"world");
        filter.insert(&42);

        // Test membership
        assert!(filter.contains(&"hello"));
        assert!(filter.contains(&"world"));
        assert!(filter.contains(&42));
        assert!(!filter.contains(&"not_inserted"));

        assert_eq!(filter.len(), 3);

        let stats = filter.stats();
        assert!(stats.false_positive_probability < 0.1);
    }

    /// Count-Min estimates are one-sided: never below the true count.
    #[test]
    fn test_count_min_sketch() {
        let mut sketch = CountMinSketch::new(100, 5);

        // Add some items
        sketch.increment(&"apple");
        sketch.increment(&"apple");
        sketch.add(&"banana", 3);
        sketch.increment(&"cherry");

        // Test estimates
        assert!(sketch.estimate(&"apple") >= 2);
        assert!(sketch.estimate(&"banana") >= 3);
        assert!(sketch.estimate(&"cherry") >= 1);
        assert_eq!(sketch.estimate(&"not_added"), 0);

        assert_eq!(sketch.total_count(), 6);
    }

    /// Cardinality should land near the true distinct count, and merging a
    /// partially overlapping sketch must not decrease the estimate.
    #[test]
    fn test_hyperloglog() {
        let mut hll = HyperLogLog::new(8);

        // Add many unique items
        for i in 0..1000 {
            hll.add(&i);
        }

        let cardinality = hll.cardinality();
        // HyperLogLog should estimate around 1000 with some error
        assert!(cardinality > 800.0 && cardinality < 1200.0);

        // Test merge
        let mut hll2 = HyperLogLog::new(8);
        for i in 500..1500 {
            hll2.add(&i);
        }

        hll.merge(&hll2);
        let merged_cardinality = hll.cardinality();
        assert!(merged_cardinality > cardinality);
    }

    /// Estimated Jaccard similarity should approximate the true value
    /// (|A ∩ B| / |A ∪ B| = 50/150 ≈ 0.33 here).
    #[test]
    fn test_minhash() {
        let mut mh1 = MinHash::new(128);
        let mut mh2 = MinHash::new(128);

        // Create two sets with some overlap
        let set1: HashSet<i32> = (0..100).collect();
        let set2: HashSet<i32> = (50..150).collect();

        for item in &set1 {
            mh1.add(item);
        }

        for item in &set2 {
            mh2.add(item);
        }

        let similarity = mh1.jaccard_similarity(&mh2);

        // The actual Jaccard similarity is 50/150 = 0.33
        // MinHash should approximate this
        assert!(similarity > 0.2 && similarity < 0.5);
    }

    /// Smoke test: indexing and querying complete, and basic stats reflect
    /// the construction parameters.
    #[test]
    fn test_lsh() {
        let mut lsh = LSHash::new(3, 5, 1.0);

        // Add some vectors
        lsh.add(&[1.0, 2.0, 3.0], 0);
        lsh.add(&[1.1, 2.1, 3.1], 1);
        lsh.add(&[5.0, 6.0, 7.0], 2);

        // Query for similar vectors
        let candidates = lsh.query(&[1.05, 2.05, 3.05]);

        // The candidate set depends on the random projections, so we only
        // verify the call succeeds and inspect the result.
        println!("LSH candidates: {:?}", candidates);

        let stats = lsh.stats();
        // The exact number of entries depends on the random projections and may vary
        // We just check that the basic functionality works
        assert!(stats.table_count == 5);
        assert!(stats.dimension == 3);
    }
}