Skip to main content

ruvector_router_core/
index.rs

1//! HNSW index implementation
2
3use crate::distance::calculate_distance;
4use crate::error::{Result, VectorDbError};
5use crate::types::{DistanceMetric, SearchQuery, SearchResult};
6use parking_lot::RwLock;
7use std::cmp::Ordering;
8use std::collections::{BinaryHeap, HashMap, HashSet};
9use std::sync::Arc;
10
11/// HNSW Index configuration
12#[derive(Debug, Clone)]
13pub struct HnswConfig {
14    /// M parameter - number of connections per node
15    pub m: usize,
16    /// ef_construction - size of dynamic candidate list during construction
17    pub ef_construction: usize,
18    /// ef_search - size of dynamic candidate list during search
19    pub ef_search: usize,
20    /// Distance metric
21    pub metric: DistanceMetric,
22    /// Number of dimensions
23    pub dimensions: usize,
24}
25
26impl Default for HnswConfig {
27    fn default() -> Self {
28        Self {
29            m: 32,
30            ef_construction: 200,
31            ef_search: 100,
32            metric: DistanceMetric::Cosine,
33            dimensions: 384,
34        }
35    }
36}
37
/// A node id paired with its distance to the current query.
///
/// Equality and ordering look only at `distance` (the `id` is ignored) and
/// the ordering is *reversed*: the neighbour with the smallest distance is
/// the greatest element, so a `BinaryHeap<Neighbor>` pops the closest
/// neighbour first.
///
/// NaN distances are ranked as farther than every real distance and equal to
/// each other. This keeps `Eq` reflexive and `Ord` a genuine total order,
/// both of which `BinaryHeap` and the sort routines rely on.
#[derive(Debug, Clone)]
struct Neighbor {
    id: String,
    distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `==` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.distance, other.distance);

        // `partial_cmp` only fails when a NaN is involved. In that case the
        // NaN side counts as the larger distance (two NaNs tie).
        let by_distance = lhs
            .partial_cmp(&rhs)
            .unwrap_or_else(|| lhs.is_nan().cmp(&rhs.is_nan()));

        // Reverse ordering for min-heap behavior.
        by_distance.reverse()
    }
}
67
68/// Simplified HNSW index
69pub struct HnswIndex {
70    config: HnswConfig,
71    vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
72    graph: Arc<RwLock<HashMap<String, Vec<String>>>>,
73    entry_point: Arc<RwLock<Option<String>>>,
74}
75
76impl HnswIndex {
77    /// Create a new HNSW index
78    pub fn new(config: HnswConfig) -> Self {
79        Self {
80            config,
81            vectors: Arc::new(RwLock::new(HashMap::new())),
82            graph: Arc::new(RwLock::new(HashMap::new())),
83            entry_point: Arc::new(RwLock::new(None)),
84        }
85    }
86
87    /// Insert a vector into the index
88    pub fn insert(&self, id: String, vector: Vec<f32>) -> Result<()> {
89        if vector.len() != self.config.dimensions {
90            return Err(VectorDbError::InvalidDimensions {
91                expected: self.config.dimensions,
92                actual: vector.len(),
93            });
94        }
95
96        // Store vector
97        self.vectors.write().insert(id.clone(), vector.clone());
98
99        // Initialize graph connections and check if this is the first vector
100        // IMPORTANT: Release all locks before calling search_knn_internal to avoid deadlock
101        // (parking_lot::RwLock is NOT reentrant)
102        let is_first = {
103            let mut graph = self.graph.write();
104            graph.insert(id.clone(), Vec::new());
105
106            let mut entry_point = self.entry_point.write();
107            if entry_point.is_none() {
108                *entry_point = Some(id.clone());
109                return Ok(());
110            }
111            false
112        }; // All locks released here
113
114        if is_first {
115            return Ok(());
116        }
117
118        // Find nearest neighbors (safe now - no locks held)
119        let neighbors =
120            self.search_knn_internal(&vector, self.config.ef_construction.min(self.config.m * 2));
121
122        // Re-acquire graph lock for modifications
123        let mut graph = self.graph.write();
124
125        // Connect to nearest neighbors (bidirectional)
126        for neighbor in neighbors.iter().take(self.config.m) {
127            if let Some(connections) = graph.get_mut(&id) {
128                connections.push(neighbor.id.clone());
129            }
130
131            if let Some(neighbor_connections) = graph.get_mut(&neighbor.id) {
132                neighbor_connections.push(id.clone());
133
134                // Prune connections if needed
135                if neighbor_connections.len() > self.config.m * 2 {
136                    neighbor_connections.truncate(self.config.m);
137                }
138            }
139        }
140
141        Ok(())
142    }
143
144    /// Insert multiple vectors in batch
145    pub fn insert_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
146        for (id, vector) in vectors {
147            self.insert(id, vector)?;
148        }
149        Ok(())
150    }
151
152    /// Search for k nearest neighbors
153    pub fn search(&self, query: &SearchQuery) -> Result<Vec<SearchResult>> {
154        let ef_search = query.ef_search.unwrap_or(self.config.ef_search);
155        let candidates = self.search_knn_internal(&query.vector, ef_search);
156
157        let mut results = Vec::new();
158        for candidate in candidates.into_iter().take(query.k) {
159            // Apply distance threshold if specified
160            if let Some(threshold) = query.threshold {
161                if candidate.distance > threshold {
162                    continue;
163                }
164            }
165
166            results.push(SearchResult {
167                id: candidate.id,
168                score: candidate.distance,
169                metadata: HashMap::new(),
170                vector: None,
171            });
172        }
173
174        Ok(results)
175    }
176
177    /// Internal k-NN search implementation
178    fn search_knn_internal(&self, query: &[f32], ef: usize) -> Vec<Neighbor> {
179        let vectors = self.vectors.read();
180        let graph = self.graph.read();
181        let entry_point = self.entry_point.read();
182
183        if entry_point.is_none() {
184            return Vec::new();
185        }
186
187        let entry_id = entry_point.as_ref().unwrap();
188        let mut visited = HashSet::new();
189        let mut candidates = BinaryHeap::new();
190        let mut result = BinaryHeap::new();
191
192        // Calculate distance to entry point
193        if let Some(entry_vec) = vectors.get(entry_id) {
194            let dist = calculate_distance(query, entry_vec, self.config.metric).unwrap_or(f32::MAX);
195
196            let neighbor = Neighbor {
197                id: entry_id.clone(),
198                distance: dist,
199            };
200
201            candidates.push(neighbor.clone());
202            result.push(neighbor);
203            visited.insert(entry_id.clone());
204        }
205
206        // Search phase
207        while let Some(current) = candidates.pop() {
208            // Check if we should continue
209            if let Some(furthest) = result.peek() {
210                if current.distance > furthest.distance && result.len() >= ef {
211                    break;
212                }
213            }
214
215            // Explore neighbors
216            if let Some(neighbors) = graph.get(&current.id) {
217                for neighbor_id in neighbors {
218                    if visited.contains(neighbor_id) {
219                        continue;
220                    }
221
222                    visited.insert(neighbor_id.clone());
223
224                    if let Some(neighbor_vec) = vectors.get(neighbor_id) {
225                        let dist = calculate_distance(query, neighbor_vec, self.config.metric)
226                            .unwrap_or(f32::MAX);
227
228                        let neighbor = Neighbor {
229                            id: neighbor_id.clone(),
230                            distance: dist,
231                        };
232
233                        // Add to candidates
234                        candidates.push(neighbor.clone());
235
236                        // Add to results if better than current worst
237                        if result.len() < ef {
238                            result.push(neighbor);
239                        } else if let Some(worst) = result.peek() {
240                            if dist < worst.distance {
241                                result.pop();
242                                result.push(neighbor);
243                            }
244                        }
245                    }
246                }
247            }
248        }
249
250        // Convert to sorted vector
251        let mut sorted_results: Vec<Neighbor> = result.into_iter().collect();
252        sorted_results.sort_by(|a, b| {
253            a.distance
254                .partial_cmp(&b.distance)
255                .unwrap_or(Ordering::Equal)
256        });
257
258        sorted_results
259    }
260
261    /// Remove a vector from the index
262    pub fn remove(&self, id: &str) -> Result<bool> {
263        let mut vectors = self.vectors.write();
264        let mut graph = self.graph.write();
265
266        if vectors.remove(id).is_none() {
267            return Ok(false);
268        }
269
270        // Remove from graph
271        graph.remove(id);
272
273        // Remove references from other nodes
274        for connections in graph.values_mut() {
275            connections.retain(|conn_id| conn_id != id);
276        }
277
278        // Update entry point if needed
279        let mut entry_point = self.entry_point.write();
280        if entry_point.as_ref() == Some(&id.to_string()) {
281            *entry_point = vectors.keys().next().cloned();
282        }
283
284        Ok(true)
285    }
286
287    /// Get total number of vectors in index
288    pub fn len(&self) -> usize {
289        self.vectors.read().len()
290    }
291
292    /// Check if index is empty
293    pub fn is_empty(&self) -> bool {
294        self.vectors.read().is_empty()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Small-graph configuration shared by every test; only the metric and
    /// the dimensionality vary between them.
    fn small_config(metric: DistanceMetric, dimensions: usize) -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            metric,
            dimensions,
        }
    }

    /// Plain top-`k` query: no filters, no threshold, default `ef_search`.
    fn top_k(vector: Vec<f32>, k: usize) -> SearchQuery {
        SearchQuery {
            vector,
            k,
            filters: None,
            threshold: None,
            ef_search: None,
        }
    }

    #[test]
    fn test_hnsw_insert_and_search() {
        let index = HnswIndex::new(small_config(DistanceMetric::Euclidean, 3));

        // One unit vector per axis.
        let axes = [
            ("v1", [1.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0]),
            ("v3", [0.0, 0.0, 1.0]),
        ];
        for (name, coords) in axes.iter() {
            index.insert(name.to_string(), coords.to_vec()).unwrap();
        }

        // The query leans heavily towards the first axis.
        let hits = index.search(&top_k(vec![0.9, 0.1, 0.0], 2)).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "v1"); // Should be closest
    }

    #[test]
    fn test_hnsw_multiple_inserts_no_deadlock() {
        // Regression test for issue #133: VectorDb.insert() deadlocks on second call.
        // The bug was caused by holding write locks while calling search_knn_internal,
        // which tries to acquire read locks on the same RwLocks (parking_lot is not reentrant).
        let index = HnswIndex::new(small_config(DistanceMetric::Cosine, 128));

        // Insert many vectors to ensure we exercise the KNN search path.
        for i in 0..20 {
            let mut one_hot = vec![0.0f32; 128];
            one_hot[i % 128] = 1.0;
            index.insert(format!("v{}", i), one_hot).unwrap();
        }

        assert_eq!(index.len(), 20);

        // Verify search still works.
        let hits = index.search(&top_k(vec![1.0; 128], 5)).unwrap();
        assert_eq!(hits.len(), 5);
    }

    #[test]
    fn test_hnsw_concurrent_inserts() {
        use std::sync::Arc;
        use std::thread;

        let index = Arc::new(HnswIndex::new(small_config(DistanceMetric::Euclidean, 3)));

        // Four writer threads, ten inserts each, all hitting the same index.
        let workers: Vec<_> = (0..4)
            .map(|t| {
                let shared = Arc::clone(&index);
                thread::spawn(move || {
                    for i in 0..10 {
                        let id = format!("t{}_v{}", t, i);
                        let vector = vec![t as f32, i as f32, 0.0];
                        shared.insert(id, vector).unwrap();
                    }
                })
            })
            .collect();

        // Wait for all threads.
        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(index.len(), 40);
    }
}