1use scirs2_core::random::rngs::StdRng;
7use scirs2_core::random::{Rng, SeedableRng};
8use std::collections::hash_map::DefaultHasher;
9use std::f64::consts::LN_2;
10use std::hash::{Hash, Hasher};
11
/// Space-efficient probabilistic set-membership filter.
///
/// `contains` may return false positives (claiming membership for an item
/// that was never inserted) but never false negatives.
pub struct BloomFilter {
    /// Bit array backing the filter; `true` means the bit is set.
    bit_array: Vec<bool>,
    /// Number of bits in `bit_array`.
    size: usize,
    /// Number of hash functions applied per item.
    hash_functions: usize,
    /// Number of `insert` calls performed (duplicates counted).
    inserted_count: usize,
}

impl BloomFilter {
    /// Creates a filter sized for `capacity` expected items at the target
    /// `false_positive_rate`.
    ///
    /// A `capacity` of 0 is clamped to 1 so the filter stays usable instead
    /// of panicking (or hanging) on the first `insert`.
    ///
    /// # Panics
    /// Panics if `false_positive_rate` is not in the open interval (0, 1).
    pub fn new(capacity: usize, false_positive_rate: f64) -> Self {
        assert!(
            false_positive_rate > 0.0 && false_positive_rate < 1.0,
            "false_positive_rate must be in (0, 1)"
        );
        let size = Self::optimal_size(capacity, false_positive_rate);
        let hash_functions = Self::optimal_hash_functions(size, capacity);
        Self::with_parameters(size, hash_functions)
    }

    /// Creates a filter with an explicit bit count and hash-function count.
    ///
    /// # Panics
    /// Panics if `size` or `hash_functions` is zero (a zero-sized filter
    /// would hit a modulo-by-zero on every subsequent operation).
    pub fn with_parameters(size: usize, hash_functions: usize) -> Self {
        assert!(size > 0, "size must be non-zero");
        assert!(hash_functions > 0, "hash_functions must be non-zero");
        Self {
            bit_array: vec![false; size],
            size,
            hash_functions,
            inserted_count: 0,
        }
    }

    /// Optimal bit count m = -n * ln(p) / ln(2)^2, clamped to at least 1.
    fn optimal_size(capacity: usize, false_positive_rate: f64) -> usize {
        let ln2_sq = LN_2 * LN_2;
        // Clamp capacity so a zero-capacity request still yields a valid filter.
        let n = capacity.max(1) as f64;
        ((-n * false_positive_rate.ln() / ln2_sq).ceil() as usize).max(1)
    }

    /// Optimal hash count k = (m / n) * ln(2), clamped to at least 1.
    fn optimal_hash_functions(size: usize, capacity: usize) -> usize {
        // Clamp capacity to avoid division by zero: an f64 infinity here
        // saturates to usize::MAX on the cast, which would make every
        // insert/contains iterate (effectively) forever.
        let n = capacity.max(1) as f64;
        (((size as f64 / n) * LN_2).ceil() as usize).max(1)
    }

    /// Derives `hash_functions` bit positions for `item` by salting a
    /// `DefaultHasher` with the hash-function index.
    fn hash_values<T: Hash>(&self, item: &T) -> Vec<usize> {
        let mut hashes = Vec::with_capacity(self.hash_functions);

        for i in 0..self.hash_functions {
            let mut hasher = DefaultHasher::new();
            item.hash(&mut hasher);
            i.hash(&mut hasher);
            hashes.push((hasher.finish() as usize) % self.size);
        }

        hashes
    }

    /// Inserts `item` by setting every one of its hash positions.
    pub fn insert<T: Hash>(&mut self, item: &T) {
        for hash in self.hash_values(item) {
            self.bit_array[hash] = true;
        }
        self.inserted_count += 1;
    }

    /// Returns `true` if `item` is possibly in the set (false positives are
    /// possible); `false` means the item is definitely absent.
    pub fn contains<T: Hash>(&self, item: &T) -> bool {
        self.hash_values(item)
            .iter()
            .all(|&hash| self.bit_array[hash])
    }

    /// Approximates the current false-positive probability as
    /// `(bits_set / size) ^ hash_functions`.
    pub fn false_positive_probability(&self) -> f64 {
        let bits_set = self.bit_array.iter().filter(|&&bit| bit).count() as f64;
        let ratio = bits_set / self.size as f64;
        ratio.powf(self.hash_functions as f64)
    }

    /// Number of `insert` calls performed (duplicates are counted).
    pub fn len(&self) -> usize {
        self.inserted_count
    }

    /// Returns `true` if nothing has been inserted.
    pub fn is_empty(&self) -> bool {
        self.inserted_count == 0
    }

    /// Resets the filter to its empty state.
    pub fn clear(&mut self) {
        self.bit_array.fill(false);
        self.inserted_count = 0;
    }

    /// Snapshot of the filter's configuration and occupancy.
    pub fn stats(&self) -> BloomFilterStats {
        let bits_set = self.bit_array.iter().filter(|&&bit| bit).count();
        BloomFilterStats {
            size: self.size,
            hash_functions: self.hash_functions,
            inserted_count: self.inserted_count,
            bits_set,
            load_factor: bits_set as f64 / self.size as f64,
            false_positive_probability: self.false_positive_probability(),
        }
    }
}

/// Point-in-time statistics for a [`BloomFilter`].
#[derive(Debug, Clone)]
pub struct BloomFilterStats {
    pub size: usize,
    pub hash_functions: usize,
    pub inserted_count: usize,
    pub bits_set: usize,
    pub load_factor: f64,
    pub false_positive_probability: f64,
}
127
/// Count-Min Sketch: a sublinear-space frequency estimator.
///
/// `estimate` never under-counts an item; it may over-count due to hash
/// collisions.
pub struct CountMinSketch {
    /// `depth` rows of `width` saturating counters.
    counts: Vec<Vec<u32>>,
    width: usize,
    depth: usize,
    /// Sum of all counts passed to `add` since creation or last `clear`.
    total_count: u64,
}

impl CountMinSketch {
    /// Creates a sketch with explicit dimensions.
    ///
    /// Zero dimensions are clamped to 1 so `add`/`estimate` cannot hit a
    /// modulo-by-zero panic or become silent no-ops.
    pub fn new(width: usize, depth: usize) -> Self {
        let width = width.max(1);
        let depth = depth.max(1);
        Self {
            counts: vec![vec![0; width]; depth],
            width,
            depth,
            total_count: 0,
        }
    }

    /// Creates a sketch from error bounds: estimates are within
    /// `epsilon * total_count` of the truth with probability `1 - delta`.
    pub fn with_bounds(epsilon: f64, delta: f64) -> Self {
        // w = ceil(e / epsilon), d = ceil(ln(1 / delta)); `new` clamps
        // degenerate results (e.g. delta >= 1 yielding depth 0) to 1.
        let width = (std::f64::consts::E / epsilon).ceil() as usize;
        let depth = (1.0 / delta).ln().ceil() as usize;
        Self::new(width, depth)
    }

    /// One column index per row, derived by salting the hash with the row index.
    fn hash_values<T: Hash>(&self, item: &T) -> Vec<usize> {
        let mut hashes = Vec::with_capacity(self.depth);

        for i in 0..self.depth {
            let mut hasher = DefaultHasher::new();
            item.hash(&mut hasher);
            i.hash(&mut hasher);
            hashes.push((hasher.finish() as usize) % self.width);
        }

        hashes
    }

    /// Adds `count` occurrences of `item`; counters saturate at `u32::MAX`.
    pub fn add<T: Hash>(&mut self, item: &T, count: u32) {
        let hashes = self.hash_values(item);
        for (i, &hash) in hashes.iter().enumerate() {
            self.counts[i][hash] = self.counts[i][hash].saturating_add(count);
        }
        self.total_count += count as u64;
    }

    /// Adds a single occurrence of `item`.
    pub fn increment<T: Hash>(&mut self, item: &T) {
        self.add(item, 1);
    }

    /// Estimated count of `item`: the minimum over its cells across rows.
    /// Never an undercount (collisions can only inflate cells).
    pub fn estimate<T: Hash>(&self, item: &T) -> u32 {
        let hashes = self.hash_values(item);
        hashes
            .iter()
            .enumerate()
            .map(|(i, &hash)| self.counts[i][hash])
            .min()
            .unwrap_or(0)
    }

    /// Total of all counts added since creation or last `clear`.
    pub fn total_count(&self) -> u64 {
        self.total_count
    }

    /// Resets all counters to zero.
    pub fn clear(&mut self) {
        for row in &mut self.counts {
            row.fill(0);
        }
        self.total_count = 0;
    }

    /// Aggregate counter statistics for the sketch.
    pub fn stats(&self) -> CountMinSketchStats {
        let max_count = self
            .counts
            .iter()
            .flat_map(|row| row.iter())
            .max()
            .copied()
            .unwrap_or(0);

        // Mean count per cell (dimensions are always >= 1 after clamping,
        // but keep the guard for defense in depth).
        let avg_count = if self.width * self.depth > 0 {
            self.total_count as f64 / (self.width * self.depth) as f64
        } else {
            0.0
        };

        CountMinSketchStats {
            width: self.width,
            depth: self.depth,
            total_count: self.total_count,
            max_count,
            avg_count,
        }
    }
}

/// Point-in-time statistics for a [`CountMinSketch`].
#[derive(Debug, Clone)]
pub struct CountMinSketchStats {
    pub width: usize,
    pub depth: usize,
    pub total_count: u64,
    pub max_count: u32,
    pub avg_count: f64,
}
239
/// HyperLogLog cardinality estimator (typical relative error ~1.04/sqrt(m)).
pub struct HyperLogLog {
    /// One register per bucket holding the maximum observed rank.
    buckets: Vec<u8>,
    /// Number of registers; always a power of two (2^precision).
    bucket_count: usize,
    /// Bias-correction constant alpha_m applied to the raw estimate.
    alpha: f64,
}

impl HyperLogLog {
    /// Creates an estimator with `2^precision` registers.
    ///
    /// # Panics
    /// Panics if `precision` is outside `4..=16`.
    pub fn new(precision: u8) -> Self {
        assert!(
            (4..=16).contains(&precision),
            "Precision must be between 4 and 16"
        );

        let bucket_count = 1 << precision;
        let alpha = Self::calculate_alpha(bucket_count);

        Self {
            buckets: vec![0; bucket_count],
            bucket_count,
            alpha,
        }
    }

    /// Standard alpha_m constants from the HyperLogLog paper; the general
    /// formula applies for m >= 128.
    fn calculate_alpha(bucket_count: usize) -> f64 {
        match bucket_count {
            16 => 0.673,
            32 => 0.697,
            64 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / bucket_count as f64),
        }
    }

    fn hash_value<T: Hash>(&self, item: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        item.hash(&mut hasher);
        hasher.finish()
    }

    /// Observes `item`: the top `precision` bits of its hash select a
    /// register; the remaining bits determine the rank.
    pub fn add<T: Hash>(&mut self, item: &T) {
        let hash = self.hash_value(item);
        // bucket_count is a power of two, so trailing_zeros() recovers the
        // precision exactly (avoids the fragile f64 log2 round-trip).
        let precision = self.bucket_count.trailing_zeros();
        let bucket = (hash >> (64 - precision)) as usize;
        // Rank = 1-based position of the first set bit in the remaining
        // 64 - precision bits. When that suffix is all zeros the rank is
        // capped at (64 - precision) + 1 (the old code returned 65 here).
        let suffix = hash << precision;
        let rank = if suffix == 0 {
            (64 - precision + 1) as u8
        } else {
            suffix.leading_zeros() as u8 + 1
        };

        if rank > self.buckets[bucket] {
            self.buckets[bucket] = rank;
        }
    }

    /// Estimated number of distinct items added so far.
    pub fn cardinality(&self) -> f64 {
        // Harmonic mean of 2^-register over all registers.
        let sum: f64 = self
            .buckets
            .iter()
            .map(|&bucket| 2.0_f64.powf(-(bucket as f64)))
            .sum();

        let raw_estimate = self.alpha * (self.bucket_count as f64).powi(2) / sum;

        if raw_estimate <= 2.5 * self.bucket_count as f64 {
            // Small-range correction: linear counting over empty registers.
            let zero_buckets = self.buckets.iter().filter(|&&bucket| bucket == 0).count();
            if zero_buckets != 0 {
                return (self.bucket_count as f64)
                    * (self.bucket_count as f64 / zero_buckets as f64).ln();
            }
        } else if raw_estimate <= (1.0 / 30.0) * (1u64 << 32) as f64 {
            // Intermediate range: the raw estimate is used unmodified.
            return raw_estimate;
        }

        // Large-range correction from the 32-bit variant of the algorithm.
        // NOTE(review): hashes here are 64-bit, so this correction kicks in
        // far earlier than strictly necessary — preserved as-is.
        -((1u64 << 32) as f64) * (1.0 - raw_estimate / ((1u64 << 32) as f64)).ln()
    }

    /// Merges `other` into `self` by taking per-register maxima; the result
    /// estimates the cardinality of the union of both input streams.
    ///
    /// # Panics
    /// Panics if the two estimators have different precisions.
    pub fn merge(&mut self, other: &HyperLogLog) {
        assert_eq!(
            self.bucket_count, other.bucket_count,
            "Cannot merge HyperLogLogs with different precisions"
        );

        for i in 0..self.bucket_count {
            self.buckets[i] = self.buckets[i].max(other.buckets[i]);
        }
    }

    /// Resets all registers to zero.
    pub fn clear(&mut self) {
        self.buckets.fill(0);
    }

    /// Register-level statistics plus the current cardinality estimate.
    pub fn stats(&self) -> HyperLogLogStats {
        let max_bucket = *self.buckets.iter().max().unwrap_or(&0);
        let zero_buckets = self.buckets.iter().filter(|&&bucket| bucket == 0).count();
        let avg_bucket =
            self.buckets.iter().map(|&b| b as f64).sum::<f64>() / self.bucket_count as f64;

        HyperLogLogStats {
            bucket_count: self.bucket_count,
            cardinality: self.cardinality(),
            max_bucket,
            zero_buckets,
            avg_bucket,
        }
    }
}

/// Point-in-time statistics for a [`HyperLogLog`].
#[derive(Debug, Clone)]
pub struct HyperLogLogStats {
    pub bucket_count: usize,
    pub cardinality: f64,
    pub max_bucket: u8,
    pub zero_buckets: usize,
    pub avg_bucket: f64,
}
374
/// MinHash signature for estimating the Jaccard similarity of sets.
pub struct MinHash {
    /// Per-function minimum hash observed so far (`u64::MAX` = untouched).
    hashes: Vec<u64>,
    /// Number of independent hash functions in the signature.
    hash_functions: usize,
}

impl MinHash {
    /// Creates an empty signature using `hash_functions` hash functions.
    pub fn new(hash_functions: usize) -> Self {
        Self {
            hashes: vec![u64::MAX; hash_functions],
            hash_functions,
        }
    }

    /// Produces one hash of `item` per function by salting a
    /// `DefaultHasher` with the function index.
    fn hash_values<T: Hash>(&self, item: &T) -> Vec<u64> {
        (0..self.hash_functions)
            .map(|salt| {
                let mut hasher = DefaultHasher::new();
                item.hash(&mut hasher);
                salt.hash(&mut hasher);
                hasher.finish()
            })
            .collect()
    }

    /// Folds `item` into the signature, keeping each per-function minimum.
    pub fn add<T: Hash>(&mut self, item: &T) {
        let candidates = self.hash_values(item);

        for (slot, candidate) in self.hashes.iter_mut().zip(candidates) {
            if candidate < *slot {
                *slot = candidate;
            }
        }
    }

    /// Estimates Jaccard similarity as the fraction of matching signature
    /// slots between the two objects.
    ///
    /// # Panics
    /// Panics if the signatures use different numbers of hash functions.
    pub fn jaccard_similarity(&self, other: &MinHash) -> f64 {
        assert_eq!(
            self.hash_functions, other.hash_functions,
            "MinHash objects must have the same number of hash functions"
        );

        let matching = self
            .hashes
            .iter()
            .zip(&other.hashes)
            .filter(|(mine, theirs)| mine == theirs)
            .count();

        matching as f64 / self.hash_functions as f64
    }

    /// Resets the signature to its empty state.
    pub fn clear(&mut self) {
        self.hashes.fill(u64::MAX);
    }

    /// Reports how many signature slots have been touched by at least one item.
    pub fn stats(&self) -> MinHashStats {
        let initialized = self
            .hashes
            .iter()
            .filter(|&&slot| slot != u64::MAX)
            .count();

        MinHashStats {
            hash_functions: self.hash_functions,
            initialized_hashes: initialized,
            completion_ratio: initialized as f64 / self.hash_functions as f64,
        }
    }
}

/// Point-in-time statistics for a [`MinHash`] signature.
#[derive(Debug, Clone)]
pub struct MinHashStats {
    pub hash_functions: usize,
    pub initialized_hashes: usize,
    pub completion_ratio: f64,
}
454
455pub struct LSHash {
457 hash_tables: Vec<Vec<Vec<usize>>>,
458 projections: Vec<Vec<f64>>,
459 table_count: usize,
460 dimension: usize,
461 bucket_width: f64,
462}
463
464impl LSHash {
465 pub fn new(dimension: usize, table_count: usize, bucket_width: f64) -> Self {
467 let mut projections = Vec::with_capacity(table_count);
468 let mut rng = StdRng::seed_from_u64(42);
469
470 for _ in 0..table_count {
471 let mut projection = Vec::with_capacity(dimension);
472 for _ in 0..dimension {
473 projection.push(rng.gen::<f64>() * 2.0 - 1.0); }
475 projections.push(projection);
476 }
477
478 Self {
479 hash_tables: vec![Vec::new(); table_count],
480 projections,
481 table_count,
482 dimension,
483 bucket_width,
484 }
485 }
486
487 fn hash_vector(&self, vector: &[f64], table_idx: usize) -> i32 {
488 let dot_product: f64 = vector
489 .iter()
490 .zip(self.projections[table_idx].iter())
491 .map(|(&v, &p)| v * p)
492 .sum();
493
494 (dot_product / self.bucket_width).floor() as i32
495 }
496
497 pub fn add(&mut self, vector: &[f64], data_idx: usize) {
499 assert_eq!(
500 vector.len(),
501 self.dimension,
502 "Vector dimension must match LSH dimension"
503 );
504
505 for table_idx in 0..self.table_count {
506 let hash = self.hash_vector(vector, table_idx);
507
508 if hash >= 0 {
510 let bucket_idx = hash as usize;
511 if self.hash_tables[table_idx].len() <= bucket_idx {
513 self.hash_tables[table_idx].resize(bucket_idx + 1, Vec::new());
514 }
515 self.hash_tables[table_idx][bucket_idx].push(data_idx);
516 }
517 }
518 }
519
520 pub fn query(&self, vector: &[f64]) -> Vec<usize> {
522 assert_eq!(
523 vector.len(),
524 self.dimension,
525 "Vector dimension must match LSH dimension"
526 );
527
528 let mut candidates = std::collections::HashSet::new();
529
530 for table_idx in 0..self.table_count {
531 let hash = self.hash_vector(vector, table_idx);
532
533 if hash >= 0 && (hash as usize) < self.hash_tables[table_idx].len() {
534 for &candidate in &self.hash_tables[table_idx][hash as usize] {
535 candidates.insert(candidate);
536 }
537 }
538 }
539
540 candidates.into_iter().collect()
541 }
542
543 pub fn clear(&mut self) {
545 for table in &mut self.hash_tables {
546 table.clear();
547 }
548 }
549
550 pub fn stats(&self) -> LSHashStats {
552 let total_entries: usize = self
553 .hash_tables
554 .iter()
555 .flat_map(|table| table.iter())
556 .map(|bucket| bucket.len())
557 .sum();
558
559 let non_empty_buckets: usize = self
560 .hash_tables
561 .iter()
562 .flat_map(|table| table.iter())
563 .filter(|bucket| !bucket.is_empty())
564 .count();
565
566 let total_buckets: usize = self.hash_tables.iter().map(|table| table.len()).sum();
567
568 LSHashStats {
569 table_count: self.table_count,
570 dimension: self.dimension,
571 bucket_width: self.bucket_width,
572 total_entries,
573 total_buckets,
574 non_empty_buckets,
575 load_factor: if total_buckets > 0 {
576 non_empty_buckets as f64 / total_buckets as f64
577 } else {
578 0.0
579 },
580 }
581 }
582}
583
584#[derive(Debug, Clone)]
585pub struct LSHashStats {
586 pub table_count: usize,
587 pub dimension: usize,
588 pub bucket_width: f64,
589 pub total_entries: usize,
590 pub total_buckets: usize,
591 pub non_empty_buckets: usize,
592 pub load_factor: f64,
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_bloom_filter() {
        let mut filter = BloomFilter::new(1000, 0.01);

        filter.insert(&"hello");
        filter.insert(&"world");
        filter.insert(&42);

        // Inserted items are always reported present (no false negatives).
        assert!(filter.contains(&"hello"));
        assert!(filter.contains(&"world"));
        assert!(filter.contains(&42));
        // DefaultHasher is deterministic, so this assertion cannot flake.
        assert!(!filter.contains(&"not_inserted"));

        assert_eq!(filter.len(), 3);

        let stats = filter.stats();
        assert!(stats.false_positive_probability < 0.1);
    }

    #[test]
    fn test_count_min_sketch() {
        let mut sketch = CountMinSketch::new(100, 5);

        sketch.increment(&"apple");
        sketch.increment(&"apple");
        sketch.add(&"banana", 3);
        sketch.increment(&"cherry");

        // Estimates never undercount; collisions may only inflate them.
        assert!(sketch.estimate(&"apple") >= 2);
        assert!(sketch.estimate(&"banana") >= 3);
        assert!(sketch.estimate(&"cherry") >= 1);
        assert_eq!(sketch.estimate(&"not_added"), 0);

        assert_eq!(sketch.total_count(), 6);
    }

    #[test]
    fn test_hyperloglog() {
        let mut hll = HyperLogLog::new(8);

        for i in 0..1000 {
            hll.add(&i);
        }

        // ~20% tolerance around the true distinct count of 1000.
        let cardinality = hll.cardinality();
        assert!(cardinality > 800.0 && cardinality < 1200.0);

        // Merging a partially overlapping stream must raise the estimate.
        let mut hll2 = HyperLogLog::new(8);
        for i in 500..1500 {
            hll2.add(&i);
        }

        hll.merge(&hll2);
        let merged_cardinality = hll.cardinality();
        assert!(merged_cardinality > cardinality);
    }

    #[test]
    fn test_minhash() {
        let mut mh1 = MinHash::new(128);
        let mut mh2 = MinHash::new(128);

        // True Jaccard similarity of the two sets is 50/150 = 1/3.
        let set1: HashSet<i32> = (0..100).collect();
        let set2: HashSet<i32> = (50..150).collect();

        for item in &set1 {
            mh1.add(item);
        }

        for item in &set2 {
            mh2.add(item);
        }

        let similarity = mh1.jaccard_similarity(&mh2);

        assert!(similarity > 0.2 && similarity < 0.5);
    }

    #[test]
    fn test_lsh() {
        let mut lsh = LSHash::new(3, 5, 1.0);

        lsh.add(&[1.0, 2.0, 3.0], 0);
        lsh.add(&[1.1, 2.1, 3.1], 1);
        lsh.add(&[5.0, 6.0, 7.0], 2);

        // Candidates must be a subset of the inserted data indices.
        let candidates = lsh.query(&[1.05, 2.05, 3.05]);
        assert!(candidates.iter().all(|&c| c < 3));

        let stats = lsh.stats();
        assert!(stats.table_count == 5);
        assert!(stats.dimension == 3);
    }
}