Skip to main content

ruvector_core/index/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) index implementation
2
3use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
13/// Distance function wrapper for hnsw_rs
14struct DistanceFn {
15    metric: DistanceMetric,
16}
17
18impl DistanceFn {
19    fn new(metric: DistanceMetric) -> Self {
20        Self { metric }
21    }
22}
23
24impl Distance<f32> for DistanceFn {
25    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26        // hnsw_rs asserts `dist_to_ref >= 0` in its search loop.  Clamp any
27        // tiny negative values caused by floating-point rounding (e.g. cosine
28        // distance between two nearly-identical normalised vectors can be
29        // marginally below zero).  f32::MAX is the safe sentinel for errors.
30        distance(a, b, self.metric).unwrap_or(f32::MAX).max(0.0)
31    }
32}
33
34/// HNSW index wrapper
35pub struct HnswIndex {
36    inner: Arc<RwLock<HnswInner>>,
37    config: HnswConfig,
38    metric: DistanceMetric,
39    dimensions: usize,
40}
41
42struct HnswInner {
43    hnsw: Hnsw<'static, f32, DistanceFn>,
44    vectors: DashMap<VectorId, Vec<f32>>,
45    id_to_idx: DashMap<VectorId, usize>,
46    idx_to_id: DashMap<usize, VectorId>,
47    next_idx: usize,
48}
49
50/// Serializable HNSW index state
51#[derive(Encode, Decode, Clone)]
52pub struct HnswState {
53    vectors: Vec<(String, Vec<f32>)>,
54    id_to_idx: Vec<(String, usize)>,
55    idx_to_id: Vec<(usize, String)>,
56    next_idx: usize,
57    config: SerializableHnswConfig,
58    dimensions: usize,
59    metric: SerializableDistanceMetric,
60}
61
62#[derive(Encode, Decode, Clone)]
63struct SerializableHnswConfig {
64    m: usize,
65    ef_construction: usize,
66    ef_search: usize,
67    max_elements: usize,
68}
69
70#[derive(Encode, Decode, Clone, Copy)]
71enum SerializableDistanceMetric {
72    Euclidean,
73    Cosine,
74    DotProduct,
75    Manhattan,
76}
77
78impl From<DistanceMetric> for SerializableDistanceMetric {
79    fn from(metric: DistanceMetric) -> Self {
80        match metric {
81            DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
82            DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
83            DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
84            DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
85        }
86    }
87}
88
89impl From<SerializableDistanceMetric> for DistanceMetric {
90    fn from(metric: SerializableDistanceMetric) -> Self {
91        match metric {
92            SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
93            SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
94            SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
95            SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
96        }
97    }
98}
99
100impl HnswIndex {
101    /// Create a new HNSW index
102    pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
103        let distance_fn = DistanceFn::new(metric);
104
105        // Create HNSW with configured parameters
106        let hnsw = Hnsw::<f32, DistanceFn>::new(
107            config.m,
108            config.max_elements,
109            dimensions,
110            config.ef_construction,
111            distance_fn,
112        );
113
114        Ok(Self {
115            inner: Arc::new(RwLock::new(HnswInner {
116                hnsw,
117                vectors: DashMap::new(),
118                id_to_idx: DashMap::new(),
119                idx_to_id: DashMap::new(),
120                next_idx: 0,
121            })),
122            config,
123            metric,
124            dimensions,
125        })
126    }
127
128    /// Get configuration
129    pub fn config(&self) -> &HnswConfig {
130        &self.config
131    }
132
133    /// Set efSearch parameter for query-time accuracy tuning.
134    ///
135    /// Higher values increase recall at the cost of search latency.
136    /// Typical range: 50–500. Must be >= k for meaningful results.
137    pub fn set_ef_search(&mut self, ef_search: usize) {
138        self.config.ef_search = ef_search;
139    }
140
141    /// Serialize the index to bytes using bincode
142    pub fn serialize(&self) -> Result<Vec<u8>> {
143        let inner = self.inner.read();
144
145        let state = HnswState {
146            vectors: inner
147                .vectors
148                .iter()
149                .map(|entry| (entry.key().clone(), entry.value().clone()))
150                .collect(),
151            id_to_idx: inner
152                .id_to_idx
153                .iter()
154                .map(|entry| (entry.key().clone(), *entry.value()))
155                .collect(),
156            idx_to_id: inner
157                .idx_to_id
158                .iter()
159                .map(|entry| (*entry.key(), entry.value().clone()))
160                .collect(),
161            next_idx: inner.next_idx,
162            config: SerializableHnswConfig {
163                m: self.config.m,
164                ef_construction: self.config.ef_construction,
165                ef_search: self.config.ef_search,
166                max_elements: self.config.max_elements,
167            },
168            dimensions: self.dimensions,
169            metric: self.metric.into(),
170        };
171
172        bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
173            RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
174        })
175    }
176
177    /// Deserialize the index from bytes using bincode
178    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
179        let (state, _): (HnswState, usize) =
180            bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
181                RuvectorError::SerializationError(format!(
182                    "Failed to deserialize HNSW index: {}",
183                    e
184                ))
185            })?;
186
187        let config = HnswConfig {
188            m: state.config.m,
189            ef_construction: state.config.ef_construction,
190            ef_search: state.config.ef_search,
191            max_elements: state.config.max_elements,
192        };
193
194        let dimensions = state.dimensions;
195        let metric: DistanceMetric = state.metric.into();
196
197        let distance_fn = DistanceFn::new(metric);
198        let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
199            config.m,
200            config.max_elements,
201            dimensions,
202            config.ef_construction,
203            distance_fn,
204        );
205
206        // Rebuild the index by inserting all vectors.
207        // Build a HashMap first to avoid O(n^2) linear search in the loop below.
208        let vectors_lookup: std::collections::HashMap<&str, &Vec<f32>> = state
209            .vectors
210            .iter()
211            .map(|(id, v)| (id.as_str(), v))
212            .collect();
213
214        let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
215        let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
216
217        // Insert vectors into HNSW in index order for deterministic reconstruction.
218        let mut sorted_entries: Vec<_> = idx_to_id
219            .iter()
220            .map(|e| (*e.key(), e.value().clone()))
221            .collect();
222        sorted_entries.sort_unstable_by_key(|(idx, _)| *idx);
223
224        for (idx, id) in &sorted_entries {
225            if let Some(vector) = vectors_lookup.get(id.as_str()) {
226                hnsw.insert_data(vector, *idx);
227            }
228        }
229
230        let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
231
232        Ok(Self {
233            inner: Arc::new(RwLock::new(HnswInner {
234                hnsw,
235                vectors: vectors_map,
236                id_to_idx,
237                idx_to_id,
238                next_idx: state.next_idx,
239            })),
240            config,
241            metric,
242            dimensions,
243        })
244    }
245
246    /// Search with custom efSearch parameter.
247    ///
248    /// `ef_search` must be >= `k`; values smaller than `k` are clamped to `k`
249    /// to avoid silent under-recall.  Results are returned sorted by ascending
250    /// distance (closest first).
251    pub fn search_with_ef(
252        &self,
253        query: &[f32],
254        k: usize,
255        ef_search: usize,
256    ) -> Result<Vec<SearchResult>> {
257        if query.len() != self.dimensions {
258            return Err(RuvectorError::DimensionMismatch {
259                expected: self.dimensions,
260                actual: query.len(),
261            });
262        }
263
264        if k == 0 {
265            return Ok(vec![]);
266        }
267
268        let inner = self.inner.read();
269
270        // hnsw_rs panics in its BinaryHeap traversal when the index is empty
271        // or contains only a single element (the candidate/return-point loop
272        // calls .peek().unwrap() without an emptiness guard).  Return early
273        // to surface a clean error instead of an assertion panic.
274        if inner.vectors.is_empty() {
275            return Ok(vec![]);
276        }
277
278        // ef_search < k causes hnsw_rs to return fewer than k candidates; clamp.
279        let effective_ef = ef_search.max(k);
280
281        // Use HNSW search with custom ef parameter (knbn)
282        let neighbors = inner.hnsw.search(query, k, effective_ef);
283
284        let mut results: Vec<SearchResult> = neighbors
285            .into_iter()
286            .filter_map(|neighbor| {
287                inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
288                    id: id.clone(),
289                    score: neighbor.distance,
290                    vector: None,
291                    metadata: None,
292                })
293            })
294            .collect();
295
296        // hnsw_rs does not guarantee sort order — ensure ascending distance.
297        results.sort_unstable_by(|a, b| {
298            a.score
299                .partial_cmp(&b.score)
300                .unwrap_or(std::cmp::Ordering::Equal)
301        });
302
303        Ok(results)
304    }
305}
306
307impl VectorIndex for HnswIndex {
308    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
309        if vector.len() != self.dimensions {
310            return Err(RuvectorError::DimensionMismatch {
311                expected: self.dimensions,
312                actual: vector.len(),
313            });
314        }
315
316        let mut inner = self.inner.write();
317        let idx = inner.next_idx;
318        inner.next_idx += 1;
319
320        // Insert into HNSW graph using insert_data
321        inner.hnsw.insert_data(&vector, idx);
322
323        // Store mappings
324        inner.vectors.insert(id.clone(), vector);
325        inner.id_to_idx.insert(id.clone(), idx);
326        inner.idx_to_id.insert(idx, id);
327
328        Ok(())
329    }
330
331    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
332        // Validate all dimensions first
333        for (_, vector) in &entries {
334            if vector.len() != self.dimensions {
335                return Err(RuvectorError::DimensionMismatch {
336                    expected: self.dimensions,
337                    actual: vector.len(),
338                });
339            }
340        }
341
342        let mut inner = self.inner.write();
343
344        // Prepare batch data for insertion
345        // First, assign indices and collect vector data
346        let data_with_ids: Vec<_> = entries
347            .iter()
348            .enumerate()
349            .map(|(i, (id, vector))| {
350                let idx = inner.next_idx + i;
351                (id.clone(), idx, vector.clone())
352            })
353            .collect();
354
355        // Update next_idx
356        inner.next_idx += entries.len();
357
358        // Insert into HNSW sequentially
359        // Note: Using sequential insertion to avoid Send requirements with RwLock guard
360        // For large batches, consider restructuring to use hnsw_rs parallel_insert
361        for (_id, idx, vector) in &data_with_ids {
362            inner.hnsw.insert_data(vector, *idx);
363        }
364
365        // Store mappings
366        for (id, idx, vector) in data_with_ids {
367            inner.vectors.insert(id.clone(), vector);
368            inner.id_to_idx.insert(id.clone(), idx);
369            inner.idx_to_id.insert(idx, id);
370        }
371
372        Ok(())
373    }
374
375    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
376        // Use configured ef_search
377        self.search_with_ef(query, k, self.config.ef_search)
378    }
379
380    fn remove(&mut self, id: &VectorId) -> Result<bool> {
381        let inner = self.inner.write();
382
383        // Note: hnsw_rs doesn't support direct deletion
384        // We remove from our mappings but the graph structure remains
385        // This is a known limitation of HNSW
386        let removed = inner.vectors.remove(id).is_some();
387
388        if removed {
389            if let Some((_, idx)) = inner.id_to_idx.remove(id) {
390                inner.idx_to_id.remove(&idx);
391            }
392        }
393
394        Ok(removed)
395    }
396
397    fn len(&self) -> usize {
398        self.inner.read().vectors.len()
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Produce `count` random vectors of length `dimensions`.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        (0..count)
            .map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
            .collect()
    }

    /// Scale `v` to unit L2 norm; a zero vector is returned unchanged.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }

    /// Small explicit configuration shared by most tests.
    fn small_config() -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        }
    }

    /// `count` normalized random vectors of length `dimensions`.
    fn normalized_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        generate_random_vectors(count, dimensions)
            .into_iter()
            .map(|v| normalize_vector(&v))
            .collect()
    }

    /// Build a cosine index holding `count` normalized vectors named `vec_{i}`.
    /// Returns the index together with the inserted vectors in order.
    fn build_index(count: usize, dimensions: usize) -> Result<(HnswIndex, Vec<Vec<f32>>)> {
        let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, small_config())?;
        let vectors = normalized_vectors(count, dimensions);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), vector.clone())?;
        }
        Ok((index, vectors))
    }

    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let config = HnswConfig::default();
        let index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let (index, vectors) = build_index(100, 128)?;
        assert_eq!(index.len(), 100);

        // Search for the first vector: it must be its own nearest neighbour.
        let results = index.search(&vectors[0], 10)?;
        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        assert_eq!(results[0].id, "vec_0");
        // Results are ordered closest first.
        assert!(results.windows(2).all(|w| w[0].score <= w[1].score));

        Ok(())
    }

    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let vectors = normalized_vectors(100, 128);
        let entries: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        let results = index.search_with_ef(&vectors[0], 5, 100)?;
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let (index, vectors) = build_index(50, 128)?;

        let bytes = index.serialize()?;
        let restored_index = HnswIndex::deserialize(&bytes)?;

        assert_eq!(restored_index.len(), 50);
        assert_eq!(restored_index.config().m, 16);
        assert_eq!(restored_index.config().ef_search, 50);

        // The rebuilt graph must still find the query vector itself first.
        let results = restored_index.search(&vectors[0], 5)?;
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_deserialize_rejects_garbage() {
        assert!(HnswIndex::deserialize(&[]).is_err());
    }

    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let result = index.add("test".to_string(), vec![1.0; 64]);
        assert!(result.is_err());

        // A single bad entry rejects the whole batch before anything is inserted.
        let batch = vec![
            ("ok".to_string(), vec![1.0; 128]),
            ("bad".to_string(), vec![1.0; 3]),
        ];
        assert!(index.add_batch(batch).is_err());
        assert_eq!(index.len(), 0);

        // Queries of the wrong length are rejected too.
        assert!(index.search(&[1.0; 64], 5).is_err());

        Ok(())
    }

    #[test]
    fn test_search_edge_cases() -> Result<()> {
        let empty = HnswIndex::new(8, DistanceMetric::Euclidean, small_config())?;
        assert!(empty.search(&[0.0; 8], 5)?.is_empty());

        let (index, vectors) = build_index(10, 8)?;
        assert!(index.search(&vectors[0], 0)?.is_empty());

        Ok(())
    }

    #[test]
    fn test_remove() -> Result<()> {
        let (mut index, vectors) = build_index(20, 16)?;
        let target: VectorId = "vec_3".to_string();

        assert!(index.remove(&target)?);
        assert!(!index.remove(&target)?);
        assert_eq!(index.len(), 19);

        // The removed vector must never be reported, even for its own query.
        let results = index.search_with_ef(&vectors[3], 5, 50)?;
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id != target));

        Ok(())
    }
}