ruvector_router_core/
index.rs

1//! HNSW index implementation
2
3use crate::distance::calculate_distance;
4use crate::error::{Result, VectorDbError};
5use crate::types::{DistanceMetric, SearchQuery, SearchResult};
6use parking_lot::RwLock;
7use std::cmp::Ordering;
8use std::collections::{BinaryHeap, HashMap, HashSet};
9use std::sync::Arc;
10
11/// HNSW Index configuration
12#[derive(Debug, Clone)]
13pub struct HnswConfig {
14    /// M parameter - number of connections per node
15    pub m: usize,
16    /// ef_construction - size of dynamic candidate list during construction
17    pub ef_construction: usize,
18    /// ef_search - size of dynamic candidate list during search
19    pub ef_search: usize,
20    /// Distance metric
21    pub metric: DistanceMetric,
22    /// Number of dimensions
23    pub dimensions: usize,
24}
25
26impl Default for HnswConfig {
27    fn default() -> Self {
28        Self {
29            m: 32,
30            ef_construction: 200,
31            ef_search: 100,
32            metric: DistanceMetric::Cosine,
33            dimensions: 384,
34        }
35    }
36}
37
/// A graph node paired with its distance to the current query.
///
/// Equality and ordering look only at `distance` (the `id` is ignored), and the
/// ordering is *reversed*: a smaller distance compares as greater. A plain
/// `BinaryHeap<Neighbor>` therefore behaves as a min-heap and pops the closest
/// node first.
#[derive(Clone, Debug)]
struct Neighbor {
    id: String,
    distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `==` always agrees with `Ord` (including NaN
        // and signed zeros), as the `Eq`/`Ord` contracts require.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped for min-heap behavior. `total_cmp` is a genuine total
        // order (unlike `partial_cmp` + `unwrap_or(Equal)`, which is not transitive
        // for NaN); a positive NaN distance sorts above every real number, so after
        // the swap it behaves as the farthest possible neighbor.
        other.distance.total_cmp(&self.distance)
    }
}

68/// Simplified HNSW index
69pub struct HnswIndex {
70    config: HnswConfig,
71    vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
72    graph: Arc<RwLock<HashMap<String, Vec<String>>>>,
73    entry_point: Arc<RwLock<Option<String>>>,
74}
75
76impl HnswIndex {
77    /// Create a new HNSW index
78    pub fn new(config: HnswConfig) -> Self {
79        Self {
80            config,
81            vectors: Arc::new(RwLock::new(HashMap::new())),
82            graph: Arc::new(RwLock::new(HashMap::new())),
83            entry_point: Arc::new(RwLock::new(None)),
84        }
85    }
86
87    /// Insert a vector into the index
88    pub fn insert(&self, id: String, vector: Vec<f32>) -> Result<()> {
89        if vector.len() != self.config.dimensions {
90            return Err(VectorDbError::InvalidDimensions {
91                expected: self.config.dimensions,
92                actual: vector.len(),
93            });
94        }
95
96        // Store vector
97        self.vectors.write().insert(id.clone(), vector.clone());
98
99        // Initialize graph connections
100        let mut graph = self.graph.write();
101        graph.insert(id.clone(), Vec::new());
102
103        // Set entry point if this is the first vector
104        let mut entry_point = self.entry_point.write();
105        if entry_point.is_none() {
106            *entry_point = Some(id.clone());
107            return Ok(());
108        }
109
110        // Find nearest neighbors
111        let neighbors =
112            self.search_knn_internal(&vector, self.config.ef_construction.min(self.config.m * 2));
113
114        // Connect to nearest neighbors (bidirectional)
115        for neighbor in neighbors.iter().take(self.config.m) {
116            graph.get_mut(&id).unwrap().push(neighbor.id.clone());
117
118            if let Some(neighbor_connections) = graph.get_mut(&neighbor.id) {
119                neighbor_connections.push(id.clone());
120
121                // Prune connections if needed
122                if neighbor_connections.len() > self.config.m * 2 {
123                    neighbor_connections.truncate(self.config.m);
124                }
125            }
126        }
127
128        Ok(())
129    }
130
131    /// Insert multiple vectors in batch
132    pub fn insert_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
133        for (id, vector) in vectors {
134            self.insert(id, vector)?;
135        }
136        Ok(())
137    }
138
139    /// Search for k nearest neighbors
140    pub fn search(&self, query: &SearchQuery) -> Result<Vec<SearchResult>> {
141        let ef_search = query.ef_search.unwrap_or(self.config.ef_search);
142        let candidates = self.search_knn_internal(&query.vector, ef_search);
143
144        let mut results = Vec::new();
145        for candidate in candidates.into_iter().take(query.k) {
146            // Apply distance threshold if specified
147            if let Some(threshold) = query.threshold {
148                if candidate.distance > threshold {
149                    continue;
150                }
151            }
152
153            results.push(SearchResult {
154                id: candidate.id,
155                score: candidate.distance,
156                metadata: HashMap::new(),
157                vector: None,
158            });
159        }
160
161        Ok(results)
162    }
163
164    /// Internal k-NN search implementation
165    fn search_knn_internal(&self, query: &[f32], ef: usize) -> Vec<Neighbor> {
166        let vectors = self.vectors.read();
167        let graph = self.graph.read();
168        let entry_point = self.entry_point.read();
169
170        if entry_point.is_none() {
171            return Vec::new();
172        }
173
174        let entry_id = entry_point.as_ref().unwrap();
175        let mut visited = HashSet::new();
176        let mut candidates = BinaryHeap::new();
177        let mut result = BinaryHeap::new();
178
179        // Calculate distance to entry point
180        if let Some(entry_vec) = vectors.get(entry_id) {
181            let dist = calculate_distance(query, entry_vec, self.config.metric).unwrap_or(f32::MAX);
182
183            let neighbor = Neighbor {
184                id: entry_id.clone(),
185                distance: dist,
186            };
187
188            candidates.push(neighbor.clone());
189            result.push(neighbor);
190            visited.insert(entry_id.clone());
191        }
192
193        // Search phase
194        while let Some(current) = candidates.pop() {
195            // Check if we should continue
196            if let Some(furthest) = result.peek() {
197                if current.distance > furthest.distance && result.len() >= ef {
198                    break;
199                }
200            }
201
202            // Explore neighbors
203            if let Some(neighbors) = graph.get(&current.id) {
204                for neighbor_id in neighbors {
205                    if visited.contains(neighbor_id) {
206                        continue;
207                    }
208
209                    visited.insert(neighbor_id.clone());
210
211                    if let Some(neighbor_vec) = vectors.get(neighbor_id) {
212                        let dist = calculate_distance(query, neighbor_vec, self.config.metric)
213                            .unwrap_or(f32::MAX);
214
215                        let neighbor = Neighbor {
216                            id: neighbor_id.clone(),
217                            distance: dist,
218                        };
219
220                        // Add to candidates
221                        candidates.push(neighbor.clone());
222
223                        // Add to results if better than current worst
224                        if result.len() < ef {
225                            result.push(neighbor);
226                        } else if let Some(worst) = result.peek() {
227                            if dist < worst.distance {
228                                result.pop();
229                                result.push(neighbor);
230                            }
231                        }
232                    }
233                }
234            }
235        }
236
237        // Convert to sorted vector
238        let mut sorted_results: Vec<Neighbor> = result.into_iter().collect();
239        sorted_results.sort_by(|a, b| {
240            a.distance
241                .partial_cmp(&b.distance)
242                .unwrap_or(Ordering::Equal)
243        });
244
245        sorted_results
246    }
247
248    /// Remove a vector from the index
249    pub fn remove(&self, id: &str) -> Result<bool> {
250        let mut vectors = self.vectors.write();
251        let mut graph = self.graph.write();
252
253        if vectors.remove(id).is_none() {
254            return Ok(false);
255        }
256
257        // Remove from graph
258        graph.remove(id);
259
260        // Remove references from other nodes
261        for connections in graph.values_mut() {
262            connections.retain(|conn_id| conn_id != id);
263        }
264
265        // Update entry point if needed
266        let mut entry_point = self.entry_point.write();
267        if entry_point.as_ref() == Some(&id.to_string()) {
268            *entry_point = vectors.keys().next().cloned();
269        }
270
271        Ok(true)
272    }
273
274    /// Get total number of vectors in index
275    pub fn len(&self) -> usize {
276        self.vectors.read().len()
277    }
278
279    /// Check if index is empty
280    pub fn is_empty(&self) -> bool {
281        self.vectors.read().is_empty()
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean config with the given link count and dimensionality.
    fn euclidean_config(m: usize, dimensions: usize) -> HnswConfig {
        HnswConfig {
            m,
            ef_construction: 100,
            ef_search: 50,
            metric: DistanceMetric::Euclidean,
            dimensions,
        }
    }

    /// Unfiltered, unthresholded query for the `k` nearest neighbors of `vector`.
    fn query(vector: Vec<f32>, k: usize) -> SearchQuery {
        SearchQuery {
            vector,
            k,
            filters: None,
            threshold: None,
            ef_search: None,
        }
    }

    #[test]
    fn test_hnsw_insert_and_search() {
        let index = HnswIndex::new(euclidean_config(16, 3));

        index.insert("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.insert("v2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        index.insert("v3".to_string(), vec![0.0, 0.0, 1.0]).unwrap();

        let results = index.search(&query(vec![0.9, 0.1, 0.0], 2)).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1"); // Should be closest
    }

    #[test]
    fn test_hnsw_search_returns_exact_neighbors_in_order() {
        // Enough points to fill the `ef` beam, so result eviction is exercised.
        let index = HnswIndex::new(HnswConfig {
            m: 4,
            ef_construction: 20,
            ef_search: 10,
            metric: DistanceMetric::Euclidean,
            dimensions: 3,
        });
        for i in 0..50 {
            index
                .insert(format!("p{i}"), vec![i as f32, 0.0, 0.0])
                .unwrap();
        }

        let results = index.search(&query(vec![37.2, 0.0, 0.0], 5)).unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["p37", "p38", "p36", "p39", "p35"]);
    }

    #[test]
    fn test_hnsw_remove() {
        let index = HnswIndex::new(euclidean_config(16, 3));
        index.insert("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.insert("v2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        index.insert("v3".to_string(), vec![0.0, 0.0, 1.0]).unwrap();

        assert!(index.remove("v1").unwrap());
        assert!(!index.remove("v1").unwrap());
        assert_eq!(index.len(), 2);

        let results = index.search(&query(vec![0.9, 0.1, 0.0], 3)).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != "v1"));
    }

    #[test]
    fn test_hnsw_rejects_wrong_dimensions() {
        let index = HnswIndex::new(euclidean_config(16, 3));
        assert!(index.insert("bad".to_string(), vec![1.0, 0.0]).is_err());
        assert!(index.is_empty());
    }
}