ruvector_core/index/hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) index implementation
2
3use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
13/// Distance function wrapper for hnsw_rs
14struct DistanceFn {
15    metric: DistanceMetric,
16}
17
18impl DistanceFn {
19    fn new(metric: DistanceMetric) -> Self {
20        Self { metric }
21    }
22}
23
24impl Distance<f32> for DistanceFn {
25    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26        distance(a, b, self.metric).unwrap_or(f32::MAX)
27    }
28}
29
30/// HNSW index wrapper
31pub struct HnswIndex {
32    inner: Arc<RwLock<HnswInner>>,
33    config: HnswConfig,
34    metric: DistanceMetric,
35    dimensions: usize,
36}
37
38struct HnswInner {
39    hnsw: Hnsw<'static, f32, DistanceFn>,
40    vectors: DashMap<VectorId, Vec<f32>>,
41    id_to_idx: DashMap<VectorId, usize>,
42    idx_to_id: DashMap<usize, VectorId>,
43    next_idx: usize,
44}
45
/// Serializable HNSW index state.
///
/// Produced by `HnswIndex::serialize` and consumed by `HnswIndex::deserialize`. The graph
/// itself is not stored; it is rebuilt from `vectors` and `idx_to_id` on load.
///
/// NOTE: bincode's derived encoding follows field declaration order, so reordering or
/// retyping these fields breaks previously serialized data.
#[derive(Encode, Decode, Clone)]
pub struct HnswState {
    /// (id, vector) pairs for every stored vector.
    vectors: Vec<(String, Vec<f32>)>,
    /// id -> graph index mappings.
    id_to_idx: Vec<(String, usize)>,
    /// graph index -> id mappings; drives graph reconstruction on deserialize.
    idx_to_id: Vec<(usize, String)>,
    /// Next graph index to assign after the index is restored.
    next_idx: usize,
    /// HNSW parameters the index was built with.
    config: SerializableHnswConfig,
    /// Vector dimensionality of the index.
    dimensions: usize,
    /// Distance metric of the index.
    metric: SerializableDistanceMetric,
}
57
/// Bincode-encodable mirror of `HnswConfig`, field for field.
///
/// NOTE: field order is part of the encoded format; do not reorder.
#[derive(Encode, Decode, Clone)]
struct SerializableHnswConfig {
    /// Mirrors `HnswConfig::m`.
    m: usize,
    /// Mirrors `HnswConfig::ef_construction`.
    ef_construction: usize,
    /// Mirrors `HnswConfig::ef_search`.
    ef_search: usize,
    /// Mirrors `HnswConfig::max_elements`.
    max_elements: usize,
}
65
/// Bincode-encodable mirror of `DistanceMetric`; converted with the `From` impls below.
///
/// NOTE: the derived encoding stores the variant index, so variant order is part of the
/// serialized format; append new variants at the end only.
#[derive(Encode, Decode, Clone, Copy)]
enum SerializableDistanceMetric {
    Euclidean,
    Cosine,
    DotProduct,
    Manhattan,
}
73
74impl From<DistanceMetric> for SerializableDistanceMetric {
75    fn from(metric: DistanceMetric) -> Self {
76        match metric {
77            DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
78            DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
79            DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
80            DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
81        }
82    }
83}
84
85impl From<SerializableDistanceMetric> for DistanceMetric {
86    fn from(metric: SerializableDistanceMetric) -> Self {
87        match metric {
88            SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
89            SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
90            SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
91            SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
92        }
93    }
94}
95
96impl HnswIndex {
97    /// Create a new HNSW index
98    pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
99        let distance_fn = DistanceFn::new(metric);
100
101        // Create HNSW with configured parameters
102        let hnsw = Hnsw::<f32, DistanceFn>::new(
103            config.m,
104            config.max_elements,
105            dimensions,
106            config.ef_construction,
107            distance_fn,
108        );
109
110        Ok(Self {
111            inner: Arc::new(RwLock::new(HnswInner {
112                hnsw,
113                vectors: DashMap::new(),
114                id_to_idx: DashMap::new(),
115                idx_to_id: DashMap::new(),
116                next_idx: 0,
117            })),
118            config,
119            metric,
120            dimensions,
121        })
122    }
123
124    /// Get configuration
125    pub fn config(&self) -> &HnswConfig {
126        &self.config
127    }
128
129    /// Set efSearch parameter for query-time accuracy tuning
130    pub fn set_ef_search(&mut self, _ef_search: usize) {
131        // Note: hnsw_rs controls ef_search via the search method's knbn parameter
132        // We store it in config and use it in search_with_ef
133    }
134
135    /// Serialize the index to bytes using bincode
136    pub fn serialize(&self) -> Result<Vec<u8>> {
137        let inner = self.inner.read();
138
139        let state = HnswState {
140            vectors: inner
141                .vectors
142                .iter()
143                .map(|entry| (entry.key().clone(), entry.value().clone()))
144                .collect(),
145            id_to_idx: inner
146                .id_to_idx
147                .iter()
148                .map(|entry| (entry.key().clone(), *entry.value()))
149                .collect(),
150            idx_to_id: inner
151                .idx_to_id
152                .iter()
153                .map(|entry| (*entry.key(), entry.value().clone()))
154                .collect(),
155            next_idx: inner.next_idx,
156            config: SerializableHnswConfig {
157                m: self.config.m,
158                ef_construction: self.config.ef_construction,
159                ef_search: self.config.ef_search,
160                max_elements: self.config.max_elements,
161            },
162            dimensions: self.dimensions,
163            metric: self.metric.into(),
164        };
165
166        bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
167            RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
168        })
169    }
170
171    /// Deserialize the index from bytes using bincode
172    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
173        let (state, _): (HnswState, usize) =
174            bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
175                RuvectorError::SerializationError(format!(
176                    "Failed to deserialize HNSW index: {}",
177                    e
178                ))
179            })?;
180
181        let config = HnswConfig {
182            m: state.config.m,
183            ef_construction: state.config.ef_construction,
184            ef_search: state.config.ef_search,
185            max_elements: state.config.max_elements,
186        };
187
188        let dimensions = state.dimensions;
189        let metric: DistanceMetric = state.metric.into();
190
191        let distance_fn = DistanceFn::new(metric);
192        let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
193            config.m,
194            config.max_elements,
195            dimensions,
196            config.ef_construction,
197            distance_fn,
198        );
199
200        // Rebuild the index by inserting all vectors
201        let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
202        let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
203
204        // Insert vectors into HNSW in order
205        for entry in idx_to_id.iter() {
206            let idx = *entry.key();
207            let id = entry.value();
208            if let Some(vector) = state.vectors.iter().find(|(vid, _)| vid == id) {
209                // Use insert_data method with slice and idx
210                hnsw.insert_data(&vector.1, idx);
211            }
212        }
213
214        let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
215
216        Ok(Self {
217            inner: Arc::new(RwLock::new(HnswInner {
218                hnsw,
219                vectors: vectors_map,
220                id_to_idx,
221                idx_to_id,
222                next_idx: state.next_idx,
223            })),
224            config,
225            metric,
226            dimensions,
227        })
228    }
229
230    /// Search with custom efSearch parameter
231    pub fn search_with_ef(
232        &self,
233        query: &[f32],
234        k: usize,
235        ef_search: usize,
236    ) -> Result<Vec<SearchResult>> {
237        if query.len() != self.dimensions {
238            return Err(RuvectorError::DimensionMismatch {
239                expected: self.dimensions,
240                actual: query.len(),
241            });
242        }
243
244        let inner = self.inner.read();
245
246        // Use HNSW search with custom ef parameter (knbn)
247        let neighbors = inner.hnsw.search(query, k, ef_search);
248
249        Ok(neighbors
250            .into_iter()
251            .filter_map(|neighbor| {
252                inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
253                    id: id.clone(),
254                    score: neighbor.distance,
255                    vector: None,
256                    metadata: None,
257                })
258            })
259            .collect())
260    }
261}
262
263impl VectorIndex for HnswIndex {
264    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
265        if vector.len() != self.dimensions {
266            return Err(RuvectorError::DimensionMismatch {
267                expected: self.dimensions,
268                actual: vector.len(),
269            });
270        }
271
272        let mut inner = self.inner.write();
273        let idx = inner.next_idx;
274        inner.next_idx += 1;
275
276        // Insert into HNSW graph using insert_data
277        inner.hnsw.insert_data(&vector, idx);
278
279        // Store mappings
280        inner.vectors.insert(id.clone(), vector);
281        inner.id_to_idx.insert(id.clone(), idx);
282        inner.idx_to_id.insert(idx, id);
283
284        Ok(())
285    }
286
287    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
288        // Validate all dimensions first
289        for (_, vector) in &entries {
290            if vector.len() != self.dimensions {
291                return Err(RuvectorError::DimensionMismatch {
292                    expected: self.dimensions,
293                    actual: vector.len(),
294                });
295            }
296        }
297
298        let mut inner = self.inner.write();
299
300        // Prepare batch data for parallel insertion
301        use rayon::prelude::*;
302
303        // First, assign indices and collect vector data
304        let data_with_ids: Vec<_> = entries
305            .iter()
306            .enumerate()
307            .map(|(i, (id, vector))| {
308                let idx = inner.next_idx + i;
309                (id.clone(), idx, vector.clone())
310            })
311            .collect();
312
313        // Update next_idx
314        inner.next_idx += entries.len();
315
316        // Insert into HNSW sequentially
317        // Note: Using sequential insertion to avoid Send requirements with RwLock guard
318        // For large batches, consider restructuring to use hnsw_rs parallel_insert
319        for (_id, idx, vector) in &data_with_ids {
320            inner.hnsw.insert_data(vector, *idx);
321        }
322
323        // Store mappings
324        for (id, idx, vector) in data_with_ids {
325            inner.vectors.insert(id.clone(), vector);
326            inner.id_to_idx.insert(id.clone(), idx);
327            inner.idx_to_id.insert(idx, id);
328        }
329
330        Ok(())
331    }
332
333    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
334        // Use configured ef_search
335        self.search_with_ef(query, k, self.config.ef_search)
336    }
337
338    fn remove(&mut self, id: &VectorId) -> Result<bool> {
339        let mut inner = self.inner.write();
340
341        // Note: hnsw_rs doesn't support direct deletion
342        // We remove from our mappings but the graph structure remains
343        // This is a known limitation of HNSW
344        let removed = inner.vectors.remove(id).is_some();
345
346        if removed {
347            if let Some((_, idx)) = inner.id_to_idx.remove(id) {
348                inner.idx_to_id.remove(&idx);
349            }
350        }
351
352        Ok(removed)
353    }
354
355    fn len(&self) -> usize {
356        self.inner.read().vectors.len()
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// `count` vectors of length `dimensions` with components drawn uniformly from [0, 1).
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        (0..count)
            .map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
            .collect()
    }

    /// Scale `v` to unit L2 norm; zero vectors are returned unchanged.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }

    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let config = HnswConfig::default();
        let index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        // Insert a few vectors
        let vectors = generate_random_vectors(100, 128);
        for (i, vector) in vectors.iter().enumerate() {
            let normalized = normalize_vector(vector);
            index.add(format!("vec_{}", i), normalized)?;
        }

        assert_eq!(index.len(), 100);

        // Search for the first vector
        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let vectors = generate_random_vectors(100, 128);
        let entries: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), normalize_vector(v)))
            .collect();

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        // Insert vectors
        let vectors = generate_random_vectors(50, 128);
        for (i, vector) in vectors.iter().enumerate() {
            let normalized = normalize_vector(vector);
            index.add(format!("vec_{}", i), normalized)?;
        }

        // Serialize
        let bytes = index.serialize()?;

        // Deserialize
        let restored_index = HnswIndex::deserialize(&bytes)?;

        assert_eq!(restored_index.len(), 50);

        // Test search on restored index
        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;

        assert!(!results.is_empty());

        Ok(())
    }

    #[test]
    fn test_hnsw_remove() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };
        let mut index = HnswIndex::new(32, DistanceMetric::Cosine, config)?;

        let vectors = generate_random_vectors(20, 32);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }

        let target = "vec_0".to_string();
        assert!(index.remove(&target)?);
        // A second removal of the same id reports that nothing was removed.
        assert!(!index.remove(&target)?);
        assert_eq!(index.len(), 19);

        // The removed id must not come back from search, even for its own vector.
        let results = index.search(&normalize_vector(&vectors[0]), 5)?;
        assert!(results.iter().all(|r| r.id != "vec_0"));

        Ok(())
    }

    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let result = index.add("test".to_string(), vec![1.0; 64]);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_search_dimension_mismatch() -> Result<()> {
        let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let query = vec![0.0_f32; 64];
        assert!(index.search(&query, 5).is_err());

        Ok(())
    }
}