1use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
13struct DistanceFn {
15 metric: DistanceMetric,
16}
17
18impl DistanceFn {
19 fn new(metric: DistanceMetric) -> Self {
20 Self { metric }
21 }
22}
23
24impl Distance<f32> for DistanceFn {
25 fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26 distance(a, b, self.metric).unwrap_or(f32::MAX).max(0.0)
31 }
32}
33
34pub struct HnswIndex {
36 inner: Arc<RwLock<HnswInner>>,
37 config: HnswConfig,
38 metric: DistanceMetric,
39 dimensions: usize,
40}
41
42struct HnswInner {
43 hnsw: Hnsw<'static, f32, DistanceFn>,
44 vectors: DashMap<VectorId, Vec<f32>>,
45 id_to_idx: DashMap<VectorId, usize>,
46 idx_to_id: DashMap<usize, VectorId>,
47 next_idx: usize,
48}
49
50#[derive(Encode, Decode, Clone)]
52pub struct HnswState {
53 vectors: Vec<(String, Vec<f32>)>,
54 id_to_idx: Vec<(String, usize)>,
55 idx_to_id: Vec<(usize, String)>,
56 next_idx: usize,
57 config: SerializableHnswConfig,
58 dimensions: usize,
59 metric: SerializableDistanceMetric,
60}
61
62#[derive(Encode, Decode, Clone)]
63struct SerializableHnswConfig {
64 m: usize,
65 ef_construction: usize,
66 ef_search: usize,
67 max_elements: usize,
68}
69
70#[derive(Encode, Decode, Clone, Copy)]
71enum SerializableDistanceMetric {
72 Euclidean,
73 Cosine,
74 DotProduct,
75 Manhattan,
76}
77
78impl From<DistanceMetric> for SerializableDistanceMetric {
79 fn from(metric: DistanceMetric) -> Self {
80 match metric {
81 DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
82 DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
83 DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
84 DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
85 }
86 }
87}
88
89impl From<SerializableDistanceMetric> for DistanceMetric {
90 fn from(metric: SerializableDistanceMetric) -> Self {
91 match metric {
92 SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
93 SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
94 SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
95 SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
96 }
97 }
98}
99
100impl HnswIndex {
101 pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
103 let distance_fn = DistanceFn::new(metric);
104
105 let hnsw = Hnsw::<f32, DistanceFn>::new(
107 config.m,
108 config.max_elements,
109 dimensions,
110 config.ef_construction,
111 distance_fn,
112 );
113
114 Ok(Self {
115 inner: Arc::new(RwLock::new(HnswInner {
116 hnsw,
117 vectors: DashMap::new(),
118 id_to_idx: DashMap::new(),
119 idx_to_id: DashMap::new(),
120 next_idx: 0,
121 })),
122 config,
123 metric,
124 dimensions,
125 })
126 }
127
128 pub fn config(&self) -> &HnswConfig {
130 &self.config
131 }
132
133 pub fn set_ef_search(&mut self, ef_search: usize) {
138 self.config.ef_search = ef_search;
139 }
140
141 pub fn serialize(&self) -> Result<Vec<u8>> {
143 let inner = self.inner.read();
144
145 let state = HnswState {
146 vectors: inner
147 .vectors
148 .iter()
149 .map(|entry| (entry.key().clone(), entry.value().clone()))
150 .collect(),
151 id_to_idx: inner
152 .id_to_idx
153 .iter()
154 .map(|entry| (entry.key().clone(), *entry.value()))
155 .collect(),
156 idx_to_id: inner
157 .idx_to_id
158 .iter()
159 .map(|entry| (*entry.key(), entry.value().clone()))
160 .collect(),
161 next_idx: inner.next_idx,
162 config: SerializableHnswConfig {
163 m: self.config.m,
164 ef_construction: self.config.ef_construction,
165 ef_search: self.config.ef_search,
166 max_elements: self.config.max_elements,
167 },
168 dimensions: self.dimensions,
169 metric: self.metric.into(),
170 };
171
172 bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
173 RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
174 })
175 }
176
177 pub fn deserialize(bytes: &[u8]) -> Result<Self> {
179 let (state, _): (HnswState, usize) =
180 bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
181 RuvectorError::SerializationError(format!(
182 "Failed to deserialize HNSW index: {}",
183 e
184 ))
185 })?;
186
187 let config = HnswConfig {
188 m: state.config.m,
189 ef_construction: state.config.ef_construction,
190 ef_search: state.config.ef_search,
191 max_elements: state.config.max_elements,
192 };
193
194 let dimensions = state.dimensions;
195 let metric: DistanceMetric = state.metric.into();
196
197 let distance_fn = DistanceFn::new(metric);
198 let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
199 config.m,
200 config.max_elements,
201 dimensions,
202 config.ef_construction,
203 distance_fn,
204 );
205
206 let vectors_lookup: std::collections::HashMap<&str, &Vec<f32>> = state
209 .vectors
210 .iter()
211 .map(|(id, v)| (id.as_str(), v))
212 .collect();
213
214 let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
215 let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
216
217 let mut sorted_entries: Vec<_> = idx_to_id
219 .iter()
220 .map(|e| (*e.key(), e.value().clone()))
221 .collect();
222 sorted_entries.sort_unstable_by_key(|(idx, _)| *idx);
223
224 for (idx, id) in &sorted_entries {
225 if let Some(vector) = vectors_lookup.get(id.as_str()) {
226 hnsw.insert_data(vector, *idx);
227 }
228 }
229
230 let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
231
232 Ok(Self {
233 inner: Arc::new(RwLock::new(HnswInner {
234 hnsw,
235 vectors: vectors_map,
236 id_to_idx,
237 idx_to_id,
238 next_idx: state.next_idx,
239 })),
240 config,
241 metric,
242 dimensions,
243 })
244 }
245
246 pub fn search_with_ef(
252 &self,
253 query: &[f32],
254 k: usize,
255 ef_search: usize,
256 ) -> Result<Vec<SearchResult>> {
257 if query.len() != self.dimensions {
258 return Err(RuvectorError::DimensionMismatch {
259 expected: self.dimensions,
260 actual: query.len(),
261 });
262 }
263
264 if k == 0 {
265 return Ok(vec![]);
266 }
267
268 let inner = self.inner.read();
269
270 if inner.vectors.is_empty() {
275 return Ok(vec![]);
276 }
277
278 let effective_ef = ef_search.max(k);
280
281 let neighbors = inner.hnsw.search(query, k, effective_ef);
283
284 let mut results: Vec<SearchResult> = neighbors
285 .into_iter()
286 .filter_map(|neighbor| {
287 inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
288 id: id.clone(),
289 score: neighbor.distance,
290 vector: None,
291 metadata: None,
292 })
293 })
294 .collect();
295
296 results.sort_unstable_by(|a, b| {
298 a.score
299 .partial_cmp(&b.score)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 });
302
303 Ok(results)
304 }
305}
306
307impl VectorIndex for HnswIndex {
308 fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
309 if vector.len() != self.dimensions {
310 return Err(RuvectorError::DimensionMismatch {
311 expected: self.dimensions,
312 actual: vector.len(),
313 });
314 }
315
316 let mut inner = self.inner.write();
317 let idx = inner.next_idx;
318 inner.next_idx += 1;
319
320 inner.hnsw.insert_data(&vector, idx);
322
323 inner.vectors.insert(id.clone(), vector);
325 inner.id_to_idx.insert(id.clone(), idx);
326 inner.idx_to_id.insert(idx, id);
327
328 Ok(())
329 }
330
331 fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
332 for (_, vector) in &entries {
334 if vector.len() != self.dimensions {
335 return Err(RuvectorError::DimensionMismatch {
336 expected: self.dimensions,
337 actual: vector.len(),
338 });
339 }
340 }
341
342 let mut inner = self.inner.write();
343
344 let data_with_ids: Vec<_> = entries
347 .iter()
348 .enumerate()
349 .map(|(i, (id, vector))| {
350 let idx = inner.next_idx + i;
351 (id.clone(), idx, vector.clone())
352 })
353 .collect();
354
355 inner.next_idx += entries.len();
357
358 for (_id, idx, vector) in &data_with_ids {
362 inner.hnsw.insert_data(vector, *idx);
363 }
364
365 for (id, idx, vector) in data_with_ids {
367 inner.vectors.insert(id.clone(), vector);
368 inner.id_to_idx.insert(id.clone(), idx);
369 inner.idx_to_id.insert(idx, id);
370 }
371
372 Ok(())
373 }
374
375 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
376 self.search_with_ef(query, k, self.config.ef_search)
378 }
379
380 fn remove(&mut self, id: &VectorId) -> Result<bool> {
381 let inner = self.inner.write();
382
383 let removed = inner.vectors.remove(id).is_some();
387
388 if removed {
389 if let Some((_, idx)) = inner.id_to_idx.remove(id) {
390 inner.idx_to_id.remove(&idx);
391 }
392 }
393
394 Ok(removed)
395 }
396
397 fn len(&self) -> usize {
398 self.inner.read().vectors.len()
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `count` vectors of length `dimensions` filled with random `f32` samples.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let mut vectors = Vec::with_capacity(count);
        for _ in 0..count {
            let mut sample = Vec::with_capacity(dimensions);
            for _ in 0..dimensions {
                sample.push(rng.gen::<f32>());
            }
            vectors.push(sample);
        }
        vectors
    }

    /// Scales `v` to unit Euclidean length; a vector whose norm is not positive is copied unchanged.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let squared_sum: f32 = v.iter().map(|x| x * x).sum();
        let norm = squared_sum.sqrt();
        if norm > 0.0 {
            v.iter().map(|component| component / norm).collect()
        } else {
            v.to_vec()
        }
    }

    /// Configuration shared by the tests that need explicit parameters.
    fn explicit_config() -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        }
    }

    // A freshly created index holds no vectors.
    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    // Vectors added one by one can be found again; the query vector itself ranks first.
    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, explicit_config())?;

        let vectors = generate_random_vectors(100, 128);
        for (i, raw) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(raw))?;
        }

        assert_eq!(index.len(), 100);

        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    // A batch insert stores every entry.
    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let mut entries = Vec::new();
        for (i, raw) in generate_random_vectors(100, 128).iter().enumerate() {
            entries.push((format!("vec_{}", i), normalize_vector(raw)));
        }

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    // A serialize/deserialize round trip keeps the vector count and stays searchable.
    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, explicit_config())?;

        let vectors = generate_random_vectors(50, 128);
        for (i, raw) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(raw))?;
        }

        let bytes = index.serialize()?;
        let restored_index = HnswIndex::deserialize(&bytes)?;

        assert_eq!(restored_index.len(), 50);

        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;

        assert!(!results.is_empty());

        Ok(())
    }

    // Adding a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let outcome = index.add("test".to_string(), vec![1.0; 64]);
        assert!(outcome.is_err());

        Ok(())
    }
}