use crate::types::VectorId;
use ndarray::{Array1, Array2};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

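/// Hash functions that map float vectors to compact binary codes; codes for
/// similar vectors should be close in Hamming distance.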
pub trait NeuralHash {
    /// Encode a float vector into a packed binary code.
    fn encode(&self, vector: &[f32]) -> Vec<u8>;

    /// Count the number of differing bits between two codes.
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32;

    /// Convert a Hamming distance over `code_bits` bits into an approximate
    /// similarity score.
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32;
}

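/// A small stack of random projection layers whose sign-thresholded outputs
/// form a binary hash code; `train` nudges the weights so that similar
/// vectors collide more often.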
#[derive(Clone, Serialize, Deserialize)]
pub struct DeepHashEmbedding {
    /// Weight matrix for each layer, shaped (out_dim, in_dim).
    projections: Vec<Array2<f32>>,
    /// Bias vector for each layer.
    biases: Vec<Array1<f32>>,
    /// Number of bits in the output code.
    output_bits: usize,
    /// Expected dimensionality of input vectors.
    input_dims: usize,
}

impl DeepHashEmbedding {
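    /// Build the layer stack `input_dims -> hidden_dims... -> output_bits`,
    /// with Xavier/Glorot-style uniform weight initialization and zero biases.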
    pub fn new(input_dims: usize, hidden_dims: Vec<usize>, output_bits: usize) -> Self {
        let mut rng = rand::thread_rng();
        let mut projections = Vec::new();
        let mut biases = Vec::new();

        let mut layer_dims = vec![input_dims];
        layer_dims.extend(&hidden_dims);
        layer_dims.push(output_bits);

        for i in 0..layer_dims.len() - 1 {
            let in_dim = layer_dims[i];
            let out_dim = layer_dims[i + 1];

            // Uniform weights in [-scale, scale], scaled by fan-in and fan-out.
            let scale = (2.0 / (in_dim + out_dim) as f32).sqrt();
            let proj = Array2::from_shape_fn((out_dim, in_dim), |_| {
                rng.gen::<f32>() * 2.0 * scale - scale
            });

            let bias = Array1::zeros(out_dim);

            projections.push(proj);
            biases.push(bias);
        }

        Self {
            projections,
            biases,
            output_bits,
            input_dims,
        }
    }

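    /// Forward pass: affine transform per layer, ReLU on hidden layers, and
    /// raw logits out of the final layer (they are sign-thresholded in `encode`).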
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut activations = Array1::from_vec(input.to_vec());

        for (layer, (proj, bias)) in self.projections.iter().zip(self.biases.iter()).enumerate() {
            activations = proj.dot(&activations) + bias;

            // ReLU on hidden layers only; the final layer keeps its raw logits.
            if layer + 1 < self.projections.len() {
                activations.mapv_inplace(|x| x.max(0.0));
            }
        }

        activations.to_vec()
    }

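    /// Heuristic contrastive training: positive pairs whose codes still differ
    /// are pulled together, negative pairs whose codes are still close are
    /// pushed apart. This is a coarse weight nudge, not true backpropagation.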
    pub fn train(
        &mut self,
        positive_pairs: &[(Vec<f32>, Vec<f32>)],
        negative_pairs: &[(Vec<f32>, Vec<f32>)],
        learning_rate: f32,
        epochs: usize,
    ) {
        for _ in 0..epochs {
            // Pull positive pairs together while their codes disagree on more
            // than 30% of the bits.
            for (a, b) in positive_pairs {
                let code_a = self.encode(a);
                let code_b = self.encode(b);
                let dist = self.hamming_distance(&code_a, &code_b);

                if dist as f32 > self.output_bits as f32 * 0.3 {
                    self.update_weights(a, b, learning_rate, true);
                }
            }

            // Push negative pairs apart while their codes disagree on fewer
            // than 60% of the bits.
            for (a, b) in negative_pairs {
                let code_a = self.encode(a);
                let code_b = self.encode(b);
                let dist = self.hamming_distance(&code_a, &code_b);

                if (dist as f32) < self.output_bits as f32 * 0.6 {
                    self.update_weights(a, b, learning_rate, false);
                }
            }
        }
    }

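    /// Nudge the first projection layer along the difference between the two
    /// inputs: toward it when attracting, away from it when repelling.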
    fn update_weights(&mut self, a: &[f32], b: &[f32], lr: f32, attract: bool) {
        // +1 pulls the pair's codes together, -1 pushes them apart.
        let direction = if attract { 1.0 } else { -1.0 };

        // The raw inputs feed the first layer, so the update is applied there;
        // updating the last layer with raw inputs would index out of bounds
        // whenever hidden layers change the dimensionality.
        if let Some(first_proj) = self.projections.first_mut() {
            if a.len() != first_proj.ncols() || b.len() != first_proj.ncols() {
                return;
            }

            for i in 0..first_proj.nrows() {
                for j in 0..first_proj.ncols() {
                    let grad = direction * lr * (a[j] - b[j]);
                    first_proj[[i, j]] += grad * 0.001;
                }
            }
        }
    }

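    /// Returns `(input_dims, output_bits)`.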
    pub fn dimensions(&self) -> (usize, usize) {
        (self.input_dims, self.output_bits)
    }
}

impl NeuralHash for DeepHashEmbedding {
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        // Dimension mismatch: return an all-zero code rather than panicking.
        if vector.len() != self.input_dims {
            return vec![0; self.output_bits.div_ceil(8)];
        }

        let logits = self.forward(vector);

        // Pack one bit per logit: positive logits set the bit, others leave it 0.
        let mut bits = vec![0u8; self.output_bits.div_ceil(8)];

        for (i, &logit) in logits.iter().enumerate() {
            if logit > 0.0 {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                bits[byte_idx] |= 1 << bit_idx;
            }
        }

        bits
    }

    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b.iter())
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        // Map the normalized Hamming distance in [0, 1] onto a cosine-like
        // similarity in [-1, 1]: identical codes give 1.0, opposite codes -1.0.
        let normalized_dist = hamming_dist as f32 / code_bits as f32;
        1.0 - 2.0 * normalized_dist
    }
}

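/// Classic random-hyperplane LSH: each bit records the sign of a dot product
/// with a fixed random direction, so no training is required.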
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleLSH {
    /// Random hyperplane normals, shaped (num_bits, input_dims).
    projections: Array2<f32>,
    /// Number of bits in the output code.
    num_bits: usize,
}

impl SimpleLSH {
    pub fn new(input_dims: usize, num_bits: usize) -> Self {
        let mut rng = rand::thread_rng();

        // One random hyperplane per output bit, with components in [-1, 1).
        let projections =
            Array2::from_shape_fn((num_bits, input_dims), |_| rng.gen::<f32>() * 2.0 - 1.0);

        Self {
            projections,
            num_bits,
        }
    }
}

impl NeuralHash for SimpleLSH {
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let input = Array1::from_vec(vector.to_vec());
        let projections = self.projections.dot(&input);

        // Each bit records which side of its hyperplane the vector falls on.
        let mut bits = vec![0u8; self.num_bits.div_ceil(8)];

        for (i, &val) in projections.iter().enumerate() {
            if val > 0.0 {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                bits[byte_idx] |= 1 << bit_idx;
            }
        }

        bits
    }

    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b.iter())
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        let normalized_dist = hamming_dist as f32 / code_bits as f32;
        1.0 - 2.0 * normalized_dist
    }
}

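/// A hash-bucket index: vectors sharing a code land in the same bucket, and a
/// query scans only buckets within a Hamming-distance budget of its own code.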
pub struct HashIndex<H: NeuralHash + Clone> {
    /// Hash function used to assign codes.
    hasher: H,
    /// Buckets keyed by binary code, each holding the ids hashed to it.
    tables: HashMap<Vec<u8>, Vec<VectorId>>,
    /// Full-precision vectors kept for exact re-ranking of candidates.
    vectors: HashMap<VectorId, Vec<f32>>,
    /// Number of bits per code.
    code_bits: usize,
}

impl<H: NeuralHash + Clone> HashIndex<H> {
    pub fn new(hasher: H, code_bits: usize) -> Self {
        Self {
            hasher,
            tables: HashMap::new(),
            vectors: HashMap::new(),
            code_bits,
        }
    }

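    /// Encode the vector and append its id to the matching bucket; the raw
    /// vector is retained for exact scoring at query time.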
    pub fn insert(&mut self, id: VectorId, vector: Vec<f32>) {
        let code = self.hasher.encode(&vector);

        self.tables.entry(code).or_default().push(id.clone());

        self.vectors.insert(id, vector);
    }

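    /// Return up to `k` results: scan buckets whose codes are within
    /// `max_hamming` bits of the query's code, then re-rank the candidates by
    /// exact cosine similarity.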
    pub fn search(&self, query: &[f32], k: usize, max_hamming: u32) -> Vec<(VectorId, f32)> {
        let query_code = self.hasher.encode(query);

        let mut candidates = Vec::new();

        // Coarse filter: keep only buckets close to the query in Hamming space.
        for (code, ids) in &self.tables {
            let hamming = self.hasher.hamming_distance(&query_code, code);

            if hamming <= max_hamming {
                for id in ids {
                    if let Some(vec) = self.vectors.get(id) {
                        let similarity = cosine_similarity(query, vec);
                        candidates.push((id.clone(), similarity));
                    }
                }
            }
        }

        // Fine re-ranking by exact similarity; total_cmp avoids panicking on NaN.
        candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
        candidates.truncate(k);
        candidates
    }

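    /// Ratio of full-precision vector storage to the storage used by the
    /// distinct codes (one packed code per occupied bucket).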
    pub fn compression_ratio(&self) -> f32 {
        if self.vectors.is_empty() {
            return 0.0;
        }

        let original_size: usize = self
            .vectors
            .values()
            .map(|v| v.len() * std::mem::size_of::<f32>())
            .sum();

        let compressed_size = self.tables.len() * self.code_bits.div_ceil(8);

        original_size as f32 / compressed_size as f32
    }

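    /// Summary statistics: vector count, bucket count, average bucket
    /// occupancy, and the current compression ratio.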
    pub fn stats(&self) -> HashIndexStats {
        let buckets = self.tables.len();
        let total_vectors = self.vectors.len();
        let avg_bucket_size = if buckets > 0 {
            total_vectors as f32 / buckets as f32
        } else {
            0.0
        };

        HashIndexStats {
            total_vectors,
            num_buckets: buckets,
            avg_bucket_size,
            compression_ratio: self.compression_ratio(),
        }
    }
}

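/// Snapshot of a `HashIndex`'s size and bucket distribution.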
#[derive(Debug, Clone)]
pub struct HashIndexStats {
    pub total_vectors: usize,
    pub num_buckets: usize,
    pub avg_bucket_size: f32,
    pub compression_ratio: f32,
}

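/// Cosine similarity of two vectors; returns 0.0 if either has zero norm.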
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deep_hash_encoding() {
        let hash = DeepHashEmbedding::new(4, vec![8], 16);
        let vector = vec![0.1, 0.2, 0.3, 0.4];

        let code = hash.encode(&vector);
        // 16 bits pack into 2 bytes.
        assert_eq!(code.len(), 2);
    }

    #[test]
    fn test_hamming_distance() {
        let hash = DeepHashEmbedding::new(2, vec![], 8);

        let code_a = vec![0b10101010];
        let code_b = vec![0b11001100];

        let dist = hash.hamming_distance(&code_a, &code_b);
        assert_eq!(dist, 4);
    }

    #[test]
    fn test_lsh_encoding() {
        let lsh = SimpleLSH::new(4, 16);
        let vector = vec![1.0, 2.0, 3.0, 4.0];

        let code = lsh.encode(&vector);
        assert_eq!(code.len(), 2);

        // Encoding is deterministic for a fixed set of hyperplanes.
        let code2 = lsh.encode(&vector);
        assert_eq!(code, code2);
    }

    #[test]
    fn test_hash_index() {
        let lsh = SimpleLSH::new(3, 8);
        let mut index = HashIndex::new(lsh, 8);

        index.insert("0".to_string(), vec![1.0, 0.0, 0.0]);
        index.insert("1".to_string(), vec![0.9, 0.1, 0.0]);
        index.insert("2".to_string(), vec![0.0, 1.0, 0.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 2, 4);

        assert!(!results.is_empty());
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 3);
    }

    #[test]
    fn test_compression_ratio() {
        let lsh = SimpleLSH::new(128, 32);
        let mut index = HashIndex::new(lsh, 32);

        for i in 0..10 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 128.0).collect();
            index.insert(i.to_string(), vec);
        }

        let ratio = index.compression_ratio();
        assert!(ratio > 1.0);
    }
}