use crate::types::VectorId;
use ndarray::{Array1, Array2};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Maps float vectors to compact binary codes that approximately preserve
/// similarity: nearby vectors should receive codes with small Hamming distance.
pub trait NeuralHash {
    /// Encodes a vector into a packed binary code (8 bits per byte).
    fn encode(&self, vector: &[f32]) -> Vec<u8>;

    /// Counts the differing bits between two codes of equal length.
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32;

    /// Maps a Hamming distance back to an approximate similarity in [-1, 1].
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32;
}

/// A small feed-forward network whose sign-thresholded outputs form a learned
/// binary hash code.
#[derive(Clone, Serialize, Deserialize)]
pub struct DeepHashEmbedding {
    /// Per-layer projection matrices, shaped (out_dim, in_dim).
    projections: Vec<Array2<f32>>,
    /// Per-layer bias vectors.
    biases: Vec<Array1<f32>>,
    /// Number of bits in the output code.
    output_bits: usize,
    /// Expected input dimensionality.
    input_dims: usize,
}

impl DeepHashEmbedding {
    /// Builds a randomly initialized network with layer widths
    /// `input_dims -> hidden_dims... -> output_bits`.
    pub fn new(input_dims: usize, hidden_dims: Vec<usize>, output_bits: usize) -> Self {
        let mut rng = rand::thread_rng();
        let mut projections = Vec::new();
        let mut biases = Vec::new();

        let mut layer_dims = vec![input_dims];
        layer_dims.extend(&hidden_dims);
        layer_dims.push(output_bits);

        for i in 0..layer_dims.len() - 1 {
            let in_dim = layer_dims[i];
            let out_dim = layer_dims[i + 1];

            // Xavier/Glorot-style uniform initialization in [-scale, scale].
            let scale = (2.0 / (in_dim + out_dim) as f32).sqrt();
            let proj = Array2::from_shape_fn((out_dim, in_dim), |_| {
                rng.gen::<f32>() * 2.0 * scale - scale
            });

            let bias = Array1::zeros(out_dim);

            projections.push(proj);
            biases.push(bias);
        }

        Self {
            projections,
            biases,
            output_bits,
            input_dims,
        }
    }

    /// Runs the network forward, returning the raw pre-binarization logits.
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut activations = Array1::from_vec(input.to_vec());
        let last = self.projections.len() - 1;

        for (idx, (proj, bias)) in self.projections.iter().zip(self.biases.iter()).enumerate() {
            activations = proj.dot(&activations) + bias;

            // ReLU on every layer except the last; keying this off the layer
            // index (rather than comparing widths to `output_bits`) stays
            // correct even when a hidden layer is as wide as the output layer.
            if idx != last {
                activations.mapv_inplace(|x| x.max(0.0));
            }
        }

        activations.to_vec()
    }

    /// Heuristic contrastive training: positive pairs whose codes disagree on
    /// more than 30% of bits are pulled together; negative pairs whose codes
    /// agree on more than 40% of bits are pushed apart. All pair vectors must
    /// have length `input_dims`. This is a crude sign-based update, not full
    /// backpropagation.
    pub fn train(
        &mut self,
        positive_pairs: &[(Vec<f32>, Vec<f32>)],
        negative_pairs: &[(Vec<f32>, Vec<f32>)],
        learning_rate: f32,
        epochs: usize,
    ) {
        for _ in 0..epochs {
            for (a, b) in positive_pairs {
                let code_a = self.encode(a);
                let code_b = self.encode(b);
                let dist = self.hamming_distance(&code_a, &code_b);

                if dist as f32 > self.output_bits as f32 * 0.3 {
                    self.update_weights(a, b, learning_rate, true);
                }
            }

            for (a, b) in negative_pairs {
                let code_a = self.encode(a);
                let code_b = self.encode(b);
                let dist = self.hamming_distance(&code_a, &code_b);

                if (dist as f32) < self.output_bits as f32 * 0.6 {
                    self.update_weights(a, b, learning_rate, false);
                }
            }
        }
    }

    fn update_weights(&mut self, a: &[f32], b: &[f32], lr: f32, attract: bool) {
        let direction = if attract { 1.0 } else { -1.0 };

        // The output layer sees the penultimate activations, not the raw
        // inputs; indexing the raw inputs would go out of bounds whenever a
        // hidden layer changes the dimensionality.
        let h_a = self.penultimate_activations(a);
        let h_b = self.penultimate_activations(b);

        if let Some(last_proj) = self.projections.last_mut() {
            for i in 0..last_proj.nrows() {
                for j in 0..last_proj.ncols() {
                    let grad = direction * lr * (h_a[j] - h_b[j]);
                    last_proj[[i, j]] += grad * 0.001;
                }
            }
        }
    }

    /// Activations feeding the final projection layer (the input itself when
    /// there are no hidden layers).
    fn penultimate_activations(&self, input: &[f32]) -> Array1<f32> {
        let mut activations = Array1::from_vec(input.to_vec());
        let n = self.projections.len();
        for (proj, bias) in self.projections.iter().zip(self.biases.iter()).take(n - 1) {
            activations = proj.dot(&activations) + bias;
            activations.mapv_inplace(|x| x.max(0.0));
        }
        activations
    }

    /// Returns `(input_dims, output_bits)`.
    pub fn dimensions(&self) -> (usize, usize) {
        (self.input_dims, self.output_bits)
    }
}

impl NeuralHash for DeepHashEmbedding {
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        // A wrong-length input encodes to the all-zero code instead of panicking.
        if vector.len() != self.input_dims {
            return vec![0; (self.output_bits + 7) / 8];
        }

        let logits = self.forward(vector);

        // Binarize by sign and pack 8 bits per byte, least significant bit first.
        let mut bits = vec![0u8; (self.output_bits + 7) / 8];

        for (i, &logit) in logits.iter().enumerate() {
            if logit > 0.0 {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                bits[byte_idx] |= 1 << bit_idx;
            }
        }

        bits
    }

    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b.iter())
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        // Linearly maps the normalized distance in [0, 1] to a similarity in
        // [-1, 1]: identical codes -> 1.0, complementary codes -> -1.0.
        let normalized_dist = hamming_dist as f32 / code_bits as f32;
        1.0 - 2.0 * normalized_dist
    }
}

/// Classic random-hyperplane LSH (sign random projections): each output bit
/// is the sign of the input's dot product with a random direction.
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleLSH {
    /// Random projection matrix, shaped (num_bits, input_dims).
    projections: Array2<f32>,
    /// Number of bits in the output code.
    num_bits: usize,
}

impl SimpleLSH {
    /// Creates `num_bits` random hyperplanes with coefficients drawn
    /// uniformly from [-1, 1].
    pub fn new(input_dims: usize, num_bits: usize) -> Self {
        let mut rng = rand::thread_rng();

        let projections =
            Array2::from_shape_fn((num_bits, input_dims), |_| rng.gen::<f32>() * 2.0 - 1.0);

        Self {
            projections,
            num_bits,
        }
    }
}

impl NeuralHash for SimpleLSH {
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let input = Array1::from_vec(vector.to_vec());
        let projections = self.projections.dot(&input);

        // Binarize by sign and pack 8 bits per byte, least significant bit first.
        let mut bits = vec![0u8; (self.num_bits + 7) / 8];

        for (i, &val) in projections.iter().enumerate() {
            if val > 0.0 {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                bits[byte_idx] |= 1 << bit_idx;
            }
        }

        bits
    }

    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b.iter())
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        // Linear map of the normalized distance in [0, 1] to a similarity in [-1, 1].
        let normalized_dist = hamming_dist as f32 / code_bits as f32;
        1.0 - 2.0 * normalized_dist
    }
}

/// A hash index with exact reranking: vectors are bucketed by binary code,
/// candidate buckets are selected by Hamming distance, and candidates are
/// re-scored with exact cosine similarity.
pub struct HashIndex<H: NeuralHash + Clone> {
    /// The hash function used to produce bucket codes.
    hasher: H,
    /// Buckets keyed by binary code.
    tables: HashMap<Vec<u8>, Vec<VectorId>>,
    /// Full-precision vectors kept for reranking.
    vectors: HashMap<VectorId, Vec<f32>>,
    /// Number of bits per code.
    code_bits: usize,
}

impl<H: NeuralHash + Clone> HashIndex<H> {
    /// Creates an empty index that buckets vectors by `code_bits`-bit codes.
    pub fn new(hasher: H, code_bits: usize) -> Self {
        Self {
            hasher,
            tables: HashMap::new(),
            vectors: HashMap::new(),
            code_bits,
        }
    }

    /// Hashes the vector into its bucket and stores the full vector for
    /// later reranking.
    pub fn insert(&mut self, id: VectorId, vector: Vec<f32>) {
        let code = self.hasher.encode(&vector);

        self.tables.entry(code).or_default().push(id.clone());

        self.vectors.insert(id, vector);
    }

    /// Returns up to `k` results among vectors whose bucket code is within
    /// `max_hamming` bits of the query's code, reranked by exact cosine
    /// similarity.
    pub fn search(&self, query: &[f32], k: usize, max_hamming: u32) -> Vec<(VectorId, f32)> {
        let query_code = self.hasher.encode(query);

        let mut candidates = Vec::new();

        // Linear scan over distinct bucket codes; cheap as long as the number
        // of buckets stays small relative to the number of vectors.
        for (code, ids) in &self.tables {
            let hamming = self.hasher.hamming_distance(&query_code, code);

            if hamming <= max_hamming {
                for id in ids {
                    if let Some(vec) = self.vectors.get(id) {
                        let similarity = cosine_similarity(query, vec);
                        candidates.push((id.clone(), similarity));
                    }
                }
            }
        }

        // `total_cmp` gives a total order over f32, so a NaN similarity
        // cannot panic the sort.
        candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
        candidates.truncate(k);
        candidates
    }

    /// Ratio of raw vector storage to hash-code storage, counting one stored
    /// code per occupied bucket (colliding vectors share a code).
    pub fn compression_ratio(&self) -> f32 {
        if self.vectors.is_empty() {
            return 0.0;
        }

        let original_size: usize = self
            .vectors
            .values()
            .map(|v| v.len() * std::mem::size_of::<f32>())
            .sum();

        let compressed_size = self.tables.len() * ((self.code_bits + 7) / 8);

        original_size as f32 / compressed_size as f32
    }

    /// Returns bucket-occupancy and compression statistics.
    pub fn stats(&self) -> HashIndexStats {
        let buckets = self.tables.len();
        let total_vectors = self.vectors.len();
        let avg_bucket_size = if buckets > 0 {
            total_vectors as f32 / buckets as f32
        } else {
            0.0
        };

        HashIndexStats {
            total_vectors,
            num_buckets: buckets,
            avg_bucket_size,
            compression_ratio: self.compression_ratio(),
        }
    }
}

/// Summary statistics for a [`HashIndex`].
#[derive(Debug, Clone)]
pub struct HashIndexStats {
    pub total_vectors: usize,
    pub num_buckets: usize,
    pub avg_bucket_size: f32,
    pub compression_ratio: f32,
}

/// Cosine similarity of two vectors; returns 0.0 if either has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deep_hash_encoding() {
        let hash = DeepHashEmbedding::new(4, vec![8], 16);
        let vector = vec![0.1, 0.2, 0.3, 0.4];

        let code = hash.encode(&vector);
        assert_eq!(code.len(), 2); // 16 bits pack into 2 bytes
    }

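    // Illustrative sketch of the documented fallback path: an input whose
    // length does not match `input_dims` encodes to the all-zero code.
    #[test]
    fn test_encode_dimension_mismatch() {
        let hash = DeepHashEmbedding::new(4, vec![8], 16);

        let code = hash.encode(&[0.1, 0.2]); // two values instead of four
        assert_eq!(code, vec![0u8, 0u8]);
    }
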
    #[test]
    fn test_hamming_distance() {
        let hash = DeepHashEmbedding::new(2, vec![], 8);

        let code_a = vec![0b10101010];
        let code_b = vec![0b11001100];

        let dist = hash.hamming_distance(&code_a, &code_b);
        assert_eq!(dist, 4); // 0b10101010 ^ 0b11001100 = 0b01100110, four set bits
    }

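    // Illustrative values for the linear estimate `1 - 2 * dist / bits`:
    // 0 differing bits -> 1.0, half -> 0.0, all -> -1.0. All three results
    // are exact in f32, so strict equality is safe here.
    #[test]
    fn test_estimate_similarity() {
        let hash = DeepHashEmbedding::new(2, vec![], 8);

        assert_eq!(hash.estimate_similarity(0, 8), 1.0);
        assert_eq!(hash.estimate_similarity(4, 8), 0.0);
        assert_eq!(hash.estimate_similarity(8, 8), -1.0);
    }
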
    #[test]
    fn test_lsh_encoding() {
        let lsh = SimpleLSH::new(4, 16);
        let vector = vec![1.0, 2.0, 3.0, 4.0];

        let code = lsh.encode(&vector);
        assert_eq!(code.len(), 2);

        // Encoding is deterministic for a fixed hasher.
        let code2 = lsh.encode(&vector);
        assert_eq!(code, code2);
    }

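    // Illustrative property of sign-based hashing: each bit depends only on
    // the sign of a projection, so scaling the input by a positive constant
    // leaves the code unchanged.
    #[test]
    fn test_lsh_scale_invariance() {
        let lsh = SimpleLSH::new(4, 16);
        let v = vec![1.0, -2.0, 3.0, -4.0];
        let scaled: Vec<f32> = v.iter().map(|x| x * 2.5).collect();

        assert_eq!(lsh.encode(&v), lsh.encode(&scaled));
    }
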
    #[test]
    fn test_hash_index() {
        let lsh = SimpleLSH::new(3, 8);
        let mut index = HashIndex::new(lsh, 8);

        index.insert("0".to_string(), vec![1.0, 0.0, 0.0]);
        index.insert("1".to_string(), vec![0.9, 0.1, 0.0]);
        index.insert("2".to_string(), vec![0.0, 1.0, 0.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 2, 4);

        assert!(!results.is_empty());
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 3);
    }

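    // Illustrative checks of the reranking metric: parallel vectors score
    // 1.0, orthogonal vectors 0.0, and a zero vector falls back to 0.0.
    #[test]
    fn test_cosine_similarity() {
        assert!((cosine_similarity(&[1.0, 0.0], &[2.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }
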
    #[test]
    fn test_compression_ratio() {
        let lsh = SimpleLSH::new(128, 32);
        let mut index = HashIndex::new(lsh, 32);

        for i in 0..10 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 128.0).collect();
            index.insert(i.to_string(), vec);
        }

        let ratio = index.compression_ratio();
        assert!(ratio > 1.0); // 512 raw bytes per vector vs a 4-byte code per bucket
    }
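
    // Smoke-test sketch of the heuristic trainer with a hidden layer present.
    // The update depends on random initialization, so this only asserts that
    // training runs and leaves the code width unchanged, not that code
    // quality improves.
    #[test]
    fn test_train_smoke() {
        let mut hash = DeepHashEmbedding::new(4, vec![8], 16);
        let positives = vec![(vec![1.0, 0.0, 0.0, 0.0], vec![0.9, 0.1, 0.0, 0.0])];
        let negatives = vec![(vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 0.0, 0.0, 1.0])];

        hash.train(&positives, &negatives, 0.01, 5);

        let code = hash.encode(&[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(code.len(), 2);
    }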
}