ruvector_core/index/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) index implementation
2
3use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
/// Distance function wrapper for hnsw_rs
///
/// Carries the configured [`DistanceMetric`] so that the `Distance<f32>`
/// implementation on this type can forward every comparison to
/// `crate::distance::distance`.
struct DistanceFn {
    /// Metric forwarded to `crate::distance::distance` on every evaluation.
    metric: DistanceMetric,
}
17
18impl DistanceFn {
19    fn new(metric: DistanceMetric) -> Self {
20        Self { metric }
21    }
22}
23
24impl Distance<f32> for DistanceFn {
25    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26        distance(a, b, self.metric).unwrap_or(f32::MAX)
27    }
28}
29
/// HNSW index wrapper
///
/// Pairs an `hnsw_rs` graph with the bookkeeping needed to translate between
/// caller-supplied [`VectorId`]s and the numeric ids stored in the graph.
pub struct HnswIndex {
    /// Graph plus id mappings, shared behind a read/write lock.
    inner: Arc<RwLock<HnswInner>>,
    /// Parameters the index was built with (`m`, `ef_construction`,
    /// `ef_search`, `max_elements`).
    config: HnswConfig,
    /// Metric used for every distance evaluation.
    metric: DistanceMetric,
    /// Required length of every inserted vector and every query.
    dimensions: usize,
}
37
/// Lock-protected state of an [`HnswIndex`].
struct HnswInner {
    /// The underlying `hnsw_rs` graph; points are inserted under numeric ids.
    hnsw: Hnsw<'static, f32, DistanceFn>,
    /// Raw vector data keyed by external id (kept so the graph can be rebuilt
    /// on deserialization).
    vectors: DashMap<VectorId, Vec<f32>>,
    /// External id -> numeric id used inside the graph.
    id_to_idx: DashMap<VectorId, usize>,
    /// Numeric graph id -> external id, used to translate search results.
    idx_to_id: DashMap<usize, VectorId>,
    /// Next numeric id to hand out; incremented on every insertion.
    next_idx: usize,
}
45
/// Serializable HNSW index state
///
/// This is the bincode payload produced by `HnswIndex::serialize` and consumed
/// by `HnswIndex::deserialize`. The graph itself is not stored; it is rebuilt
/// from `vectors` and `idx_to_id` when the state is loaded.
#[derive(Encode, Decode, Clone)]
pub struct HnswState {
    /// `(external id, vector)` pairs.
    vectors: Vec<(String, Vec<f32>)>,
    /// `(external id, numeric graph id)` pairs.
    id_to_idx: Vec<(String, usize)>,
    /// `(numeric graph id, external id)` pairs.
    idx_to_id: Vec<(usize, String)>,
    /// Next numeric id the index would assign.
    next_idx: usize,
    /// Index construction/search parameters.
    config: SerializableHnswConfig,
    /// Vector length the index was created with.
    dimensions: usize,
    /// Distance metric the index was created with.
    metric: SerializableDistanceMetric,
}
57
/// Bincode-encodable mirror of the `HnswConfig` fields persisted with the
/// index.
#[derive(Encode, Decode, Clone)]
struct SerializableHnswConfig {
    /// Copied from `HnswConfig::m`.
    m: usize,
    /// Copied from `HnswConfig::ef_construction`.
    ef_construction: usize,
    /// Copied from `HnswConfig::ef_search`.
    ef_search: usize,
    /// Copied from `HnswConfig::max_elements`.
    max_elements: usize,
}
65
/// Bincode-encodable mirror of [`DistanceMetric`]; converted in both
/// directions through the `From` implementations below.
#[derive(Encode, Decode, Clone, Copy)]
enum SerializableDistanceMetric {
    /// Mirrors `DistanceMetric::Euclidean`.
    Euclidean,
    /// Mirrors `DistanceMetric::Cosine`.
    Cosine,
    /// Mirrors `DistanceMetric::DotProduct`.
    DotProduct,
    /// Mirrors `DistanceMetric::Manhattan`.
    Manhattan,
}
73
74impl From<DistanceMetric> for SerializableDistanceMetric {
75    fn from(metric: DistanceMetric) -> Self {
76        match metric {
77            DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
78            DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
79            DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
80            DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
81        }
82    }
83}
84
85impl From<SerializableDistanceMetric> for DistanceMetric {
86    fn from(metric: SerializableDistanceMetric) -> Self {
87        match metric {
88            SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
89            SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
90            SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
91            SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
92        }
93    }
94}
95
96impl HnswIndex {
97    /// Create a new HNSW index
98    pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
99        let distance_fn = DistanceFn::new(metric);
100
101        // Create HNSW with configured parameters
102        let hnsw = Hnsw::<f32, DistanceFn>::new(
103            config.m,
104            config.max_elements,
105            dimensions,
106            config.ef_construction,
107            distance_fn,
108        );
109
110        Ok(Self {
111            inner: Arc::new(RwLock::new(HnswInner {
112                hnsw,
113                vectors: DashMap::new(),
114                id_to_idx: DashMap::new(),
115                idx_to_id: DashMap::new(),
116                next_idx: 0,
117            })),
118            config,
119            metric,
120            dimensions,
121        })
122    }
123
124    /// Get configuration
125    pub fn config(&self) -> &HnswConfig {
126        &self.config
127    }
128
129    /// Set efSearch parameter for query-time accuracy tuning
130    pub fn set_ef_search(&mut self, _ef_search: usize) {
131        // Note: hnsw_rs controls ef_search via the search method's knbn parameter
132        // We store it in config and use it in search_with_ef
133    }
134
135    /// Serialize the index to bytes using bincode
136    pub fn serialize(&self) -> Result<Vec<u8>> {
137        let inner = self.inner.read();
138
139        let state = HnswState {
140            vectors: inner
141                .vectors
142                .iter()
143                .map(|entry| (entry.key().clone(), entry.value().clone()))
144                .collect(),
145            id_to_idx: inner
146                .id_to_idx
147                .iter()
148                .map(|entry| (entry.key().clone(), *entry.value()))
149                .collect(),
150            idx_to_id: inner
151                .idx_to_id
152                .iter()
153                .map(|entry| (*entry.key(), entry.value().clone()))
154                .collect(),
155            next_idx: inner.next_idx,
156            config: SerializableHnswConfig {
157                m: self.config.m,
158                ef_construction: self.config.ef_construction,
159                ef_search: self.config.ef_search,
160                max_elements: self.config.max_elements,
161            },
162            dimensions: self.dimensions,
163            metric: self.metric.into(),
164        };
165
166        bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
167            RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
168        })
169    }
170
171    /// Deserialize the index from bytes using bincode
172    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
173        let (state, _): (HnswState, usize) =
174            bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
175                RuvectorError::SerializationError(format!(
176                    "Failed to deserialize HNSW index: {}",
177                    e
178                ))
179            })?;
180
181        let config = HnswConfig {
182            m: state.config.m,
183            ef_construction: state.config.ef_construction,
184            ef_search: state.config.ef_search,
185            max_elements: state.config.max_elements,
186        };
187
188        let dimensions = state.dimensions;
189        let metric: DistanceMetric = state.metric.into();
190
191        let distance_fn = DistanceFn::new(metric);
192        let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
193            config.m,
194            config.max_elements,
195            dimensions,
196            config.ef_construction,
197            distance_fn,
198        );
199
200        // Rebuild the index by inserting all vectors
201        let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
202        let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
203
204        // P-1: O(N²) → O(N log N) optimization (ADR-0012)
205        // Build HashMap for O(1) vector lookups instead of O(N) linear scan
206        let vectors_by_id: std::collections::HashMap<VectorId, Vec<f32>> =
207            state.vectors.iter().cloned().collect();
208
209        // Insert vectors into HNSW in order - now O(N) total instead of O(N²)
210        for entry in idx_to_id.iter() {
211            let idx = *entry.key();
212            let id = entry.value();
213            // O(1) HashMap lookup instead of O(N) linear search
214            if let Some(vector) = vectors_by_id.get(id) {
215                // Use insert_data method with slice and idx
216                hnsw.insert_data(vector, idx);
217            }
218        }
219
220        let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
221
222        Ok(Self {
223            inner: Arc::new(RwLock::new(HnswInner {
224                hnsw,
225                vectors: vectors_map,
226                id_to_idx,
227                idx_to_id,
228                next_idx: state.next_idx,
229            })),
230            config,
231            metric,
232            dimensions,
233        })
234    }
235
236    /// Search with custom efSearch parameter
237    pub fn search_with_ef(
238        &self,
239        query: &[f32],
240        k: usize,
241        ef_search: usize,
242    ) -> Result<Vec<SearchResult>> {
243        if query.len() != self.dimensions {
244            return Err(RuvectorError::DimensionMismatch {
245                expected: self.dimensions,
246                actual: query.len(),
247            });
248        }
249
250        let inner = self.inner.read();
251
252        // Use HNSW search with custom ef parameter (knbn)
253        let neighbors = inner.hnsw.search(query, k, ef_search);
254
255        Ok(neighbors
256            .into_iter()
257            .filter_map(|neighbor| {
258                inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
259                    id: id.clone(),
260                    score: neighbor.distance,
261                    vector: None,
262                    metadata: None,
263                })
264            })
265            .collect())
266    }
267}
268
269impl VectorIndex for HnswIndex {
270    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
271        if vector.len() != self.dimensions {
272            return Err(RuvectorError::DimensionMismatch {
273                expected: self.dimensions,
274                actual: vector.len(),
275            });
276        }
277
278        let mut inner = self.inner.write();
279        let idx = inner.next_idx;
280        inner.next_idx += 1;
281
282        // Insert into HNSW graph using insert_data
283        inner.hnsw.insert_data(&vector, idx);
284
285        // Store mappings
286        inner.vectors.insert(id.clone(), vector);
287        inner.id_to_idx.insert(id.clone(), idx);
288        inner.idx_to_id.insert(idx, id);
289
290        Ok(())
291    }
292
293    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
294        // Validate all dimensions first
295        for (_, vector) in &entries {
296            if vector.len() != self.dimensions {
297                return Err(RuvectorError::DimensionMismatch {
298                    expected: self.dimensions,
299                    actual: vector.len(),
300                });
301            }
302        }
303
304        let mut inner = self.inner.write();
305
306        // Prepare batch data for parallel insertion
307        use rayon::prelude::*;
308
309        // First, assign indices and collect vector data
310        let data_with_ids: Vec<_> = entries
311            .iter()
312            .enumerate()
313            .map(|(i, (id, vector))| {
314                let idx = inner.next_idx + i;
315                (id.clone(), idx, vector.clone())
316            })
317            .collect();
318
319        // Update next_idx
320        inner.next_idx += entries.len();
321
322        // Insert into HNSW sequentially
323        // Note: Using sequential insertion to avoid Send requirements with RwLock guard
324        // For large batches, consider restructuring to use hnsw_rs parallel_insert
325        for (_id, idx, vector) in &data_with_ids {
326            inner.hnsw.insert_data(vector, *idx);
327        }
328
329        // Store mappings
330        for (id, idx, vector) in data_with_ids {
331            inner.vectors.insert(id.clone(), vector);
332            inner.id_to_idx.insert(id.clone(), idx);
333            inner.idx_to_id.insert(idx, id);
334        }
335
336        Ok(())
337    }
338
339    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
340        // Use configured ef_search
341        self.search_with_ef(query, k, self.config.ef_search)
342    }
343
344    fn remove(&mut self, id: &VectorId) -> Result<bool> {
345        let mut inner = self.inner.write();
346
347        // Note: hnsw_rs doesn't support direct deletion
348        // We remove from our mappings but the graph structure remains
349        // This is a known limitation of HNSW
350        let removed = inner.vectors.remove(id).is_some();
351
352        if removed {
353            if let Some((_, idx)) = inner.id_to_idx.remove(id) {
354                inner.idx_to_id.remove(&idx);
355            }
356        }
357
358        Ok(removed)
359    }
360
361    fn len(&self) -> usize {
362        self.inner.read().vectors.len()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `count` vectors of length `dimensions` filled with samples
    /// drawn from the thread-local RNG.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let mut vectors = Vec::with_capacity(count);
        for _ in 0..count {
            let mut vector = Vec::with_capacity(dimensions);
            for _ in 0..dimensions {
                vector.push(rng.gen::<f32>());
            }
            vectors.push(vector);
        }
        vectors
    }

    /// Scales `v` to unit length; a vector whose norm is not positive is
    /// returned unchanged.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            return v.iter().map(|x| x / norm).collect();
        }
        v.to_vec()
    }

    /// Explicit configuration shared by the insert/search and serialization
    /// tests.
    fn small_config() -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        }
    }

    /// Adds each vector, normalized, under the id `vec_<position>`.
    fn add_normalized(index: &mut HnswIndex, vectors: &[Vec<f32>]) -> Result<()> {
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }
        Ok(())
    }

    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, small_config())?;

        // Insert a few vectors
        let vectors = generate_random_vectors(100, 128);
        add_normalized(&mut index, &vectors)?;

        assert_eq!(index.len(), 100);

        // The first inserted vector must be its own nearest neighbour.
        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let vectors = generate_random_vectors(100, 128);
        let mut entries = Vec::with_capacity(vectors.len());
        for (i, v) in vectors.iter().enumerate() {
            entries.push((format!("vec_{}", i), normalize_vector(v)));
        }

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, small_config())?;

        // Insert vectors
        let vectors = generate_random_vectors(50, 128);
        add_normalized(&mut index, &vectors)?;

        // Round-trip through bytes
        let bytes = index.serialize()?;
        let restored_index = HnswIndex::deserialize(&bytes)?;

        assert_eq!(restored_index.len(), 50);

        // The restored index must still answer queries
        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;

        assert!(!results.is_empty());

        Ok(())
    }

    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        // 64 entries against a 128-dimensional index must be rejected.
        let result = index.add("test".to_string(), vec![1.0; 64]);
        assert!(result.is_err());

        Ok(())
    }
}
489}