Skip to main content

superbit/
hash.rs

1use ndarray::Array1;
2use rand::Rng;
3use rand_distr::StandardNormal;
4
/// A random-projection hash family for one hash table.
///
/// Uses sign-of-random-projection (SimHash / hyperplane LSH) to map vectors
/// to bit signatures. Each bit corresponds to the sign of the dot product
/// with a random Gaussian vector.
///
/// Signatures are packed into a `u64`, with bit `i` coming from projection `i`,
/// so at most 64 projections can be represented in one signature.
// NOTE(review): with the `persistence` feature the serde derive encodes these
// fields; renaming or reordering them may break previously persisted hashers.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "persistence",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RandomProjectionHasher {
    /// One projection vector per signature bit, each component drawn from a
    /// standard normal distribution; bit `i` is set when the input's dot
    /// product with `projections[i]` is `>= 0.0`.
    projections: Vec<Array1<f32>>,
    /// Number of signature bits; `new` sets this equal to `projections.len()`.
    num_hashes: usize,
}
19
20impl RandomProjectionHasher {
21    /// Create a new hasher with `num_hashes` random projection vectors of dimension `dim`.
22    pub fn new(dim: usize, num_hashes: usize, rng: &mut impl Rng) -> Self {
23        let projections = (0..num_hashes)
24            .map(|_| {
25                let v: Vec<f32> = (0..dim).map(|_| rng.sample(StandardNormal)).collect();
26                Array1::from_vec(v)
27            })
28            .collect();
29        Self {
30            projections,
31            num_hashes,
32        }
33    }
34
35    /// Compute the hash key for a vector, along with margin information for multi-probe.
36    ///
37    /// Returns `(hash_key, margins)` where margins is a vec of `(bit_index, |dot_product|)`
38    /// sorted by ascending margin (most uncertain bits first).
39    pub fn hash_vector(&self, vector: &ndarray::ArrayView1<f32>) -> (u64, Vec<(usize, f32)>) {
40        let mut hash: u64 = 0;
41        let mut margins: Vec<(usize, f32)> = Vec::with_capacity(self.num_hashes);
42
43        for (i, proj) in self.projections.iter().enumerate() {
44            let dot = vector.dot(proj);
45            if dot >= 0.0 {
46                hash |= 1u64 << i;
47            }
48            margins.push((i, dot.abs()));
49        }
50
51        margins.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
52        (hash, margins)
53    }
54
55    /// Compute just the hash key (fast path, no margin data).
56    pub fn hash_vector_fast(&self, vector: &ndarray::ArrayView1<f32>) -> u64 {
57        let mut hash: u64 = 0;
58        for (i, proj) in self.projections.iter().enumerate() {
59            if vector.dot(proj) >= 0.0 {
60                hash |= 1u64 << i;
61            }
62        }
63        hash
64    }
65
66    /// Number of hash functions (bits in the signature).
67    pub fn num_hashes(&self) -> usize {
68        self.num_hashes
69    }
70}
71
/// Generate multi-probe hash keys by flipping the most uncertain bits.
///
/// `margins` is expected to be sorted ascending by uncertainty, as returned
/// by `RandomProjectionHasher::hash_vector`. Given that and the base hash,
/// this returns `base_hash` followed by one perturbed key for each of the
/// first `num_probes` margin entries. Each perturbed key differs from
/// `base_hash` in exactly the entry's `bit_index`.
///
/// The result has length `1 + min(num_probes, margins.len())`. A large
/// `num_probes` is simply capped by the number of available margins.
///
/// # Panics
///
/// Panics if a consumed `bit_index` is 64 or greater, since a `u64` key has
/// no such bit.
pub fn multi_probe_keys(
    base_hash: u64,
    margins: &[(usize, f32)],
    num_probes: usize,
) -> Vec<u64> {
    // Cap before allocating: `1 + num_probes` could overflow or over-reserve
    // when callers pass a large probe budget.
    let probes = num_probes.min(margins.len());
    let mut keys = Vec::with_capacity(1 + probes);
    keys.push(base_hash);

    keys.extend(margins.iter().take(probes).map(|&(bit_idx, _)| {
        // An explicit check instead of relying on shift overflow, which
        // panics in debug builds but silently wraps in release builds.
        assert!(
            bit_idx < u64::BITS as usize,
            "bit index {} does not fit in a u64 hash key",
            bit_idx
        );
        base_hash ^ (1u64 << bit_idx)
    }));

    keys
}
90
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    #[test]
    fn test_deterministic_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(4, 8, &mut rng);
        let v = array![1.0, 2.0, 3.0, 4.0];
        let h1 = hasher.hash_vector_fast(&v.view());
        let h2 = hasher.hash_vector_fast(&v.view());
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_same_seed_same_hasher() {
        // Two hashers built from identically seeded RNGs must agree bit for bit.
        let v = array![1.0f32, -2.0, 3.0, -4.0];
        let a = RandomProjectionHasher::new(4, 16, &mut StdRng::seed_from_u64(7));
        let b = RandomProjectionHasher::new(4, 16, &mut StdRng::seed_from_u64(7));
        assert_eq!(a.hash_vector_fast(&v.view()), b.hash_vector_fast(&v.view()));
    }

    #[test]
    fn test_similar_vectors_likely_same_hash() {
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(4, 4, &mut rng);
        let v1 = array![1.0, 2.0, 3.0, 4.0];
        let v2 = array![1.01, 2.01, 3.01, 4.01];
        let h1 = hasher.hash_vector_fast(&v1.view());
        let h2 = hasher.hash_vector_fast(&v2.view());
        // Nearly parallel vectors usually, but not always, share a signature.
        // This exact equality is pinned to seed 42 and the current `StdRng`
        // stream, which rand does not promise to keep across versions. If this
        // fails after a `rand` upgrade, re-pick the seed before suspecting the
        // hasher; `test_positive_scaling_preserves_hash` is seed-independent.
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_positive_scaling_preserves_hash() {
        // Multiplying by a power of two scales every product and partial sum
        // exactly, so no dot-product sign (hence no bit) can change,
        // whatever projections the RNG produced.
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(4, 32, &mut rng);
        let v = array![0.3f32, -1.7, 2.2, 0.05];
        let scaled = v.mapv(|x| x * 2.0);
        assert_eq!(
            hasher.hash_vector_fast(&v.view()),
            hasher.hash_vector_fast(&scaled.view())
        );
    }

    #[test]
    fn test_hash_vector_matches_fast_path_and_sorts_margins() {
        let mut rng = StdRng::seed_from_u64(42);
        let hasher = RandomProjectionHasher::new(8, 16, &mut rng);
        let v = array![0.5f32, -1.0, 2.0, 0.0, -3.5, 1.25, 0.75, -0.1];
        let (hash, margins) = hasher.hash_vector(&v.view());

        assert_eq!(hash, hasher.hash_vector_fast(&v.view()));
        assert_eq!(margins.len(), hasher.num_hashes());
        assert!(margins.windows(2).all(|w| w[0].1 <= w[1].1));

        // Every bit index appears exactly once.
        let mut bits: Vec<usize> = margins.iter().map(|&(i, _)| i).collect();
        bits.sort_unstable();
        assert_eq!(bits, (0..hasher.num_hashes()).collect::<Vec<_>>());
    }

    #[test]
    fn test_multi_probe_keys() {
        let base = 0b1010u64;
        let margins = vec![(0, 0.1), (2, 0.5), (1, 0.8), (3, 1.2)];
        let keys = multi_probe_keys(base, &margins, 2);
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0], 0b1010); // base
        assert_eq!(keys[1], 0b1011); // flip bit 0
        assert_eq!(keys[2], 0b1110); // flip bit 2
    }

    #[test]
    fn test_multi_probe_keys_capped_by_margins() {
        // Asking for more probes than there are margins yields one key per margin.
        let margins = vec![(1, 0.2), (0, 0.4)];
        let keys = multi_probe_keys(0b00, &margins, 5);
        assert_eq!(keys, vec![0b00, 0b10, 0b01]);
    }
}