1use ndarray::Array1;
2use rand::Rng;
3use rand_distr::StandardNormal;
4
5#[derive(Debug, Clone)]
11#[cfg_attr(
12 feature = "persistence",
13 derive(serde::Serialize, serde::Deserialize)
14)]
15pub struct RandomProjectionHasher {
16 projections: Vec<Array1<f32>>,
17 num_hashes: usize,
18}
19
20impl RandomProjectionHasher {
21 pub fn new(dim: usize, num_hashes: usize, rng: &mut impl Rng) -> Self {
23 let projections = (0..num_hashes)
24 .map(|_| {
25 let v: Vec<f32> = (0..dim).map(|_| rng.sample(StandardNormal)).collect();
26 Array1::from_vec(v)
27 })
28 .collect();
29 Self {
30 projections,
31 num_hashes,
32 }
33 }
34
35 pub fn hash_vector(&self, vector: &ndarray::ArrayView1<f32>) -> (u64, Vec<(usize, f32)>) {
40 let mut hash: u64 = 0;
41 let mut margins: Vec<(usize, f32)> = Vec::with_capacity(self.num_hashes);
42
43 for (i, proj) in self.projections.iter().enumerate() {
44 let dot = vector.dot(proj);
45 if dot >= 0.0 {
46 hash |= 1u64 << i;
47 }
48 margins.push((i, dot.abs()));
49 }
50
51 margins.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
52 (hash, margins)
53 }
54
55 pub fn hash_vector_fast(&self, vector: &ndarray::ArrayView1<f32>) -> u64 {
57 let mut hash: u64 = 0;
58 for (i, proj) in self.projections.iter().enumerate() {
59 if vector.dot(proj) >= 0.0 {
60 hash |= 1u64 << i;
61 }
62 }
63 hash
64 }
65
66 pub fn num_hashes(&self) -> usize {
68 self.num_hashes
69 }
70}
71
/// Builds the list of bucket keys to probe for a query.
///
/// The first key is always `base_hash`. Each following key is `base_hash`
/// with exactly one bit flipped, taking bit indices from `margins` in the
/// order given until `num_probes` extra keys have been produced or `margins`
/// is exhausted. Passing the margins returned by
/// `RandomProjectionHasher::hash_vector` (sorted by ascending margin) flips
/// the least certain bits first.
///
/// Only the bit index of each `margins` entry is used; the margin value is
/// ignored here. Entries whose bit index does not fit in a `u64` (`>= 64`)
/// are skipped rather than overflowing the shift.
///
/// `num_probes` may be arbitrarily large (e.g. `usize::MAX` for "all"): the
/// result never holds more than `1 + margins.len()` keys.
pub fn multi_probe_keys(
    base_hash: u64,
    margins: &[(usize, f32)],
    num_probes: usize,
) -> Vec<u64> {
    const KEY_BITS: usize = u64::BITS as usize;

    // Reserve for what can actually be produced; `1 + num_probes` would
    // overflow or request an impossible allocation for huge probe budgets.
    let max_probes = num_probes.min(margins.len());
    let mut keys = Vec::with_capacity(max_probes + 1);
    keys.push(base_hash);

    keys.extend(
        margins
            .iter()
            .filter(|&&(bit_idx, _)| bit_idx < KEY_BITS)
            .take(num_probes)
            .map(|&(bit_idx, _)| base_hash ^ (1u64 << bit_idx)),
    );

    keys
}
90
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Seed shared by every test so the projections are reproducible.
    const SEED: u64 = 42;

    /// Builds a hasher from a freshly seeded RNG.
    fn seeded_hasher(dim: usize, num_hashes: usize) -> RandomProjectionHasher {
        let mut rng = StdRng::seed_from_u64(SEED);
        RandomProjectionHasher::new(dim, num_hashes, &mut rng)
    }

    /// Hashing the same vector twice with one hasher yields the same key.
    #[test]
    fn test_deterministic_hash() {
        let hasher = seeded_hasher(4, 8);
        let input = array![1.0, 2.0, 3.0, 4.0];

        let first = hasher.hash_vector_fast(&input.view());
        let second = hasher.hash_vector_fast(&input.view());

        assert_eq!(first, second);
    }

    /// Two almost identical vectors land in the same bucket for this seed.
    #[test]
    fn test_similar_vectors_likely_same_hash() {
        let hasher = seeded_hasher(4, 4);
        let original = array![1.0, 2.0, 3.0, 4.0];
        let nudged = array![1.01, 2.01, 3.01, 4.01];

        let original_hash = hasher.hash_vector_fast(&original.view());
        let nudged_hash = hasher.hash_vector_fast(&nudged.view());

        assert_eq!(original_hash, nudged_hash);
    }

    /// The base key comes first, followed by one key per flipped bit, taken
    /// from the front of the margin list.
    #[test]
    fn test_multi_probe_keys() {
        let base = 0b1010u64;
        let margins = vec![(0, 0.1), (2, 0.5), (1, 0.8), (3, 1.2)];

        let keys = multi_probe_keys(base, &margins, 2);

        // Base hash, then bit 0 flipped, then bit 2 flipped.
        let expected: Vec<u64> = vec![0b1010, 0b1011, 0b1110];
        assert_eq!(keys, expected);
    }
}