Skip to main content

ruvector_core/index/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) index implementation
2
3use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
13/// Distance function wrapper for hnsw_rs
14struct DistanceFn {
15    metric: DistanceMetric,
16}
17
18impl DistanceFn {
19    fn new(metric: DistanceMetric) -> Self {
20        Self { metric }
21    }
22}
23
24impl Distance<f32> for DistanceFn {
25    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26        distance(a, b, self.metric).unwrap_or(f32::MAX)
27    }
28}
29
30/// HNSW index wrapper
31pub struct HnswIndex {
32    inner: Arc<RwLock<HnswInner>>,
33    config: HnswConfig,
34    metric: DistanceMetric,
35    dimensions: usize,
36}
37
38struct HnswInner {
39    hnsw: Hnsw<'static, f32, DistanceFn>,
40    vectors: DashMap<VectorId, Vec<f32>>,
41    id_to_idx: DashMap<VectorId, usize>,
42    idx_to_id: DashMap<usize, VectorId>,
43    next_idx: usize,
44}
45
46/// Serializable HNSW index state
47#[derive(Encode, Decode, Clone)]
48pub struct HnswState {
49    vectors: Vec<(String, Vec<f32>)>,
50    id_to_idx: Vec<(String, usize)>,
51    idx_to_id: Vec<(usize, String)>,
52    next_idx: usize,
53    config: SerializableHnswConfig,
54    dimensions: usize,
55    metric: SerializableDistanceMetric,
56}
57
58#[derive(Encode, Decode, Clone)]
59struct SerializableHnswConfig {
60    m: usize,
61    ef_construction: usize,
62    ef_search: usize,
63    max_elements: usize,
64}
65
66#[derive(Encode, Decode, Clone, Copy)]
67enum SerializableDistanceMetric {
68    Euclidean,
69    Cosine,
70    DotProduct,
71    Manhattan,
72}
73
74impl From<DistanceMetric> for SerializableDistanceMetric {
75    fn from(metric: DistanceMetric) -> Self {
76        match metric {
77            DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
78            DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
79            DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
80            DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
81        }
82    }
83}
84
85impl From<SerializableDistanceMetric> for DistanceMetric {
86    fn from(metric: SerializableDistanceMetric) -> Self {
87        match metric {
88            SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
89            SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
90            SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
91            SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
92        }
93    }
94}
95
96impl HnswIndex {
97    /// Create a new HNSW index
98    pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
99        let distance_fn = DistanceFn::new(metric);
100
101        // Create HNSW with configured parameters
102        let hnsw = Hnsw::<f32, DistanceFn>::new(
103            config.m,
104            config.max_elements,
105            dimensions,
106            config.ef_construction,
107            distance_fn,
108        );
109
110        Ok(Self {
111            inner: Arc::new(RwLock::new(HnswInner {
112                hnsw,
113                vectors: DashMap::new(),
114                id_to_idx: DashMap::new(),
115                idx_to_id: DashMap::new(),
116                next_idx: 0,
117            })),
118            config,
119            metric,
120            dimensions,
121        })
122    }
123
124    /// Get configuration
125    pub fn config(&self) -> &HnswConfig {
126        &self.config
127    }
128
129    /// Set efSearch parameter for query-time accuracy tuning
130    pub fn set_ef_search(&mut self, _ef_search: usize) {
131        // Note: hnsw_rs controls ef_search via the search method's knbn parameter
132        // We store it in config and use it in search_with_ef
133    }
134
135    /// Serialize the index to bytes using bincode
136    pub fn serialize(&self) -> Result<Vec<u8>> {
137        let inner = self.inner.read();
138
139        let state = HnswState {
140            vectors: inner
141                .vectors
142                .iter()
143                .map(|entry| (entry.key().clone(), entry.value().clone()))
144                .collect(),
145            id_to_idx: inner
146                .id_to_idx
147                .iter()
148                .map(|entry| (entry.key().clone(), *entry.value()))
149                .collect(),
150            idx_to_id: inner
151                .idx_to_id
152                .iter()
153                .map(|entry| (*entry.key(), entry.value().clone()))
154                .collect(),
155            next_idx: inner.next_idx,
156            config: SerializableHnswConfig {
157                m: self.config.m,
158                ef_construction: self.config.ef_construction,
159                ef_search: self.config.ef_search,
160                max_elements: self.config.max_elements,
161            },
162            dimensions: self.dimensions,
163            metric: self.metric.into(),
164        };
165
166        bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
167            RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
168        })
169    }
170
171    /// Deserialize the index from bytes using bincode
172    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
173        let (state, _): (HnswState, usize) =
174            bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
175                RuvectorError::SerializationError(format!(
176                    "Failed to deserialize HNSW index: {}",
177                    e
178                ))
179            })?;
180
181        let config = HnswConfig {
182            m: state.config.m,
183            ef_construction: state.config.ef_construction,
184            ef_search: state.config.ef_search,
185            max_elements: state.config.max_elements,
186        };
187
188        let dimensions = state.dimensions;
189        let metric: DistanceMetric = state.metric.into();
190
191        let distance_fn = DistanceFn::new(metric);
192        let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
193            config.m,
194            config.max_elements,
195            dimensions,
196            config.ef_construction,
197            distance_fn,
198        );
199
200        // Rebuild the index by inserting all vectors
201        let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
202        let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
203
204        // Insert vectors into HNSW in order
205        for entry in idx_to_id.iter() {
206            let idx = *entry.key();
207            let id = entry.value();
208            if let Some(vector) = state.vectors.iter().find(|(vid, _)| vid == id) {
209                // Use insert_data method with slice and idx
210                hnsw.insert_data(&vector.1, idx);
211            }
212        }
213
214        let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
215
216        Ok(Self {
217            inner: Arc::new(RwLock::new(HnswInner {
218                hnsw,
219                vectors: vectors_map,
220                id_to_idx,
221                idx_to_id,
222                next_idx: state.next_idx,
223            })),
224            config,
225            metric,
226            dimensions,
227        })
228    }
229
230    /// Search with custom efSearch parameter
231    pub fn search_with_ef(
232        &self,
233        query: &[f32],
234        k: usize,
235        ef_search: usize,
236    ) -> Result<Vec<SearchResult>> {
237        if query.len() != self.dimensions {
238            return Err(RuvectorError::DimensionMismatch {
239                expected: self.dimensions,
240                actual: query.len(),
241            });
242        }
243
244        let inner = self.inner.read();
245
246        // Use HNSW search with custom ef parameter (knbn)
247        let neighbors = inner.hnsw.search(query, k, ef_search);
248
249        Ok(neighbors
250            .into_iter()
251            .filter_map(|neighbor| {
252                inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
253                    id: id.clone(),
254                    score: neighbor.distance,
255                    vector: None,
256                    metadata: None,
257                })
258            })
259            .collect())
260    }
261}
262
263impl VectorIndex for HnswIndex {
264    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
265        if vector.len() != self.dimensions {
266            return Err(RuvectorError::DimensionMismatch {
267                expected: self.dimensions,
268                actual: vector.len(),
269            });
270        }
271
272        let mut inner = self.inner.write();
273        let idx = inner.next_idx;
274        inner.next_idx += 1;
275
276        // Insert into HNSW graph using insert_data
277        inner.hnsw.insert_data(&vector, idx);
278
279        // Store mappings
280        inner.vectors.insert(id.clone(), vector);
281        inner.id_to_idx.insert(id.clone(), idx);
282        inner.idx_to_id.insert(idx, id);
283
284        Ok(())
285    }
286
287    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
288        // Validate all dimensions first
289        for (_, vector) in &entries {
290            if vector.len() != self.dimensions {
291                return Err(RuvectorError::DimensionMismatch {
292                    expected: self.dimensions,
293                    actual: vector.len(),
294                });
295            }
296        }
297
298        let mut inner = self.inner.write();
299
300        // Prepare batch data for insertion
301        // First, assign indices and collect vector data
302        let data_with_ids: Vec<_> = entries
303            .iter()
304            .enumerate()
305            .map(|(i, (id, vector))| {
306                let idx = inner.next_idx + i;
307                (id.clone(), idx, vector.clone())
308            })
309            .collect();
310
311        // Update next_idx
312        inner.next_idx += entries.len();
313
314        // Insert into HNSW sequentially
315        // Note: Using sequential insertion to avoid Send requirements with RwLock guard
316        // For large batches, consider restructuring to use hnsw_rs parallel_insert
317        for (_id, idx, vector) in &data_with_ids {
318            inner.hnsw.insert_data(vector, *idx);
319        }
320
321        // Store mappings
322        for (id, idx, vector) in data_with_ids {
323            inner.vectors.insert(id.clone(), vector);
324            inner.id_to_idx.insert(id.clone(), idx);
325            inner.idx_to_id.insert(idx, id);
326        }
327
328        Ok(())
329    }
330
331    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
332        // Use configured ef_search
333        self.search_with_ef(query, k, self.config.ef_search)
334    }
335
336    fn remove(&mut self, id: &VectorId) -> Result<bool> {
337        let inner = self.inner.write();
338
339        // Note: hnsw_rs doesn't support direct deletion
340        // We remove from our mappings but the graph structure remains
341        // This is a known limitation of HNSW
342        let removed = inner.vectors.remove(id).is_some();
343
344        if removed {
345            if let Some((_, idx)) = inner.id_to_idx.remove(id) {
346                inner.idx_to_id.remove(&idx);
347            }
348        }
349
350        Ok(removed)
351    }
352
353    fn len(&self) -> usize {
354        self.inner.read().vectors.len()
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate `count` vectors with `dimensions` uniform random components.
    ///
    /// Uses a fixed seed so the test data, and therefore test outcomes, are reproducible.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::rngs::StdRng;
        use rand::{Rng, SeedableRng};

        let mut rng = StdRng::seed_from_u64(0x5EED);
        (0..count)
            .map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
            .collect()
    }

    /// Scale `v` to unit L2 norm; the zero vector is returned unchanged.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }

    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let config = HnswConfig::default();
        let index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let vectors = generate_random_vectors(100, 128);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }

        assert_eq!(index.len(), 100);

        // The query is exactly vec_0, so it must be its own nearest neighbour.
        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let entries: Vec<_> = generate_random_vectors(100, 128)
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), normalize_vector(v)))
            .collect();

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    #[test]
    fn test_hnsw_remove() -> Result<()> {
        let mut index = HnswIndex::new(16, DistanceMetric::Cosine, HnswConfig::default())?;

        let vectors = generate_random_vectors(10, 16);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }

        let target = "vec_3".to_string();
        assert!(index.remove(&target)?);
        assert!(!index.remove(&target)?);
        assert_eq!(index.len(), 9);

        // The removed vector's graph node remains but must never be reported.
        let results = index.search(&normalize_vector(&vectors[3]), 10)?;
        assert!(results.iter().all(|r| r.id != target));

        Ok(())
    }

    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let vectors = generate_random_vectors(50, 128);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }

        let bytes = index.serialize()?;
        let restored_index = HnswIndex::deserialize(&bytes)?;

        assert_eq!(restored_index.len(), 50);
        assert_eq!(restored_index.config().m, 16);
        assert_eq!(restored_index.config().ef_construction, 100);
        assert_eq!(restored_index.config().ef_search, 50);
        assert_eq!(restored_index.config().max_elements, 1000);

        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;
        assert!(!results.is_empty());

        Ok(())
    }

    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let result = index.add("test".to_string(), vec![1.0; 64]);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_search_dimension_mismatch() -> Result<()> {
        let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        assert!(index.search(&[0.5f32; 64], 5).is_err());

        Ok(())
    }
}