// ruvector_router_core/index.rs
use crate::distance::calculate_distance;
use crate::error::{Result, VectorDbError};
use crate::types::{DistanceMetric, SearchQuery, SearchResult};
use parking_lot::RwLock;
use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;

11#[derive(Debug, Clone)]
13pub struct HnswConfig {
14 pub m: usize,
16 pub ef_construction: usize,
18 pub ef_search: usize,
20 pub metric: DistanceMetric,
22 pub dimensions: usize,
24}
25
26impl Default for HnswConfig {
27 fn default() -> Self {
28 Self {
29 m: 32,
30 ef_construction: 200,
31 ef_search: 100,
32 metric: DistanceMetric::Cosine,
33 dimensions: 384,
34 }
35 }
36}
37
/// A graph node paired with its distance to the query being searched.
///
/// Equality and ordering look only at `distance` (`id` is ignored), and the
/// ordering is *reversed*: the closest neighbor compares as the greatest, so a
/// `BinaryHeap<Neighbor>` pops nearest-first.
#[derive(Clone, Debug)]
struct Neighbor {
    id: String,
    distance: f32,
}

impl Neighbor {
    /// Total order on raw distances in which every NaN (of either sign) sorts
    /// after all numbers, i.e. counts as the farthest possible distance.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN shows up
    /// (NaN would be "equal" to everything), which breaks `BinaryHeap` and sort
    /// invariants; this keeps `Eq`/`Ord` lawful.
    fn cmp_distance(a: f32, b: f32) -> Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}

impl PartialEq for Neighbor {
    /// Defined through `cmp` so equality always agrees with the ordering
    /// (in particular NaN == NaN, as `Eq` requires reflexivity).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Arguments are swapped on purpose: smaller distance => greater `Neighbor`.
    fn cmp(&self, other: &Self) -> Ordering {
        Neighbor::cmp_distance(other.distance, self.distance)
    }
}

68pub struct HnswIndex {
70 config: HnswConfig,
71 vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
72 graph: Arc<RwLock<HashMap<String, Vec<String>>>>,
73 entry_point: Arc<RwLock<Option<String>>>,
74}
75
76impl HnswIndex {
77 pub fn new(config: HnswConfig) -> Self {
79 Self {
80 config,
81 vectors: Arc::new(RwLock::new(HashMap::new())),
82 graph: Arc::new(RwLock::new(HashMap::new())),
83 entry_point: Arc::new(RwLock::new(None)),
84 }
85 }
86
87 pub fn insert(&self, id: String, vector: Vec<f32>) -> Result<()> {
89 if vector.len() != self.config.dimensions {
90 return Err(VectorDbError::InvalidDimensions {
91 expected: self.config.dimensions,
92 actual: vector.len(),
93 });
94 }
95
96 self.vectors.write().insert(id.clone(), vector.clone());
98
99 let mut graph = self.graph.write();
101 graph.insert(id.clone(), Vec::new());
102
103 let mut entry_point = self.entry_point.write();
105 if entry_point.is_none() {
106 *entry_point = Some(id.clone());
107 return Ok(());
108 }
109
110 let neighbors =
112 self.search_knn_internal(&vector, self.config.ef_construction.min(self.config.m * 2));
113
114 for neighbor in neighbors.iter().take(self.config.m) {
116 graph.get_mut(&id).unwrap().push(neighbor.id.clone());
117
118 if let Some(neighbor_connections) = graph.get_mut(&neighbor.id) {
119 neighbor_connections.push(id.clone());
120
121 if neighbor_connections.len() > self.config.m * 2 {
123 neighbor_connections.truncate(self.config.m);
124 }
125 }
126 }
127
128 Ok(())
129 }
130
131 pub fn insert_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
133 for (id, vector) in vectors {
134 self.insert(id, vector)?;
135 }
136 Ok(())
137 }
138
139 pub fn search(&self, query: &SearchQuery) -> Result<Vec<SearchResult>> {
141 let ef_search = query.ef_search.unwrap_or(self.config.ef_search);
142 let candidates = self.search_knn_internal(&query.vector, ef_search);
143
144 let mut results = Vec::new();
145 for candidate in candidates.into_iter().take(query.k) {
146 if let Some(threshold) = query.threshold {
148 if candidate.distance > threshold {
149 continue;
150 }
151 }
152
153 results.push(SearchResult {
154 id: candidate.id,
155 score: candidate.distance,
156 metadata: HashMap::new(),
157 vector: None,
158 });
159 }
160
161 Ok(results)
162 }
163
164 fn search_knn_internal(&self, query: &[f32], ef: usize) -> Vec<Neighbor> {
166 let vectors = self.vectors.read();
167 let graph = self.graph.read();
168 let entry_point = self.entry_point.read();
169
170 if entry_point.is_none() {
171 return Vec::new();
172 }
173
174 let entry_id = entry_point.as_ref().unwrap();
175 let mut visited = HashSet::new();
176 let mut candidates = BinaryHeap::new();
177 let mut result = BinaryHeap::new();
178
179 if let Some(entry_vec) = vectors.get(entry_id) {
181 let dist = calculate_distance(query, entry_vec, self.config.metric).unwrap_or(f32::MAX);
182
183 let neighbor = Neighbor {
184 id: entry_id.clone(),
185 distance: dist,
186 };
187
188 candidates.push(neighbor.clone());
189 result.push(neighbor);
190 visited.insert(entry_id.clone());
191 }
192
193 while let Some(current) = candidates.pop() {
195 if let Some(furthest) = result.peek() {
197 if current.distance > furthest.distance && result.len() >= ef {
198 break;
199 }
200 }
201
202 if let Some(neighbors) = graph.get(¤t.id) {
204 for neighbor_id in neighbors {
205 if visited.contains(neighbor_id) {
206 continue;
207 }
208
209 visited.insert(neighbor_id.clone());
210
211 if let Some(neighbor_vec) = vectors.get(neighbor_id) {
212 let dist = calculate_distance(query, neighbor_vec, self.config.metric)
213 .unwrap_or(f32::MAX);
214
215 let neighbor = Neighbor {
216 id: neighbor_id.clone(),
217 distance: dist,
218 };
219
220 candidates.push(neighbor.clone());
222
223 if result.len() < ef {
225 result.push(neighbor);
226 } else if let Some(worst) = result.peek() {
227 if dist < worst.distance {
228 result.pop();
229 result.push(neighbor);
230 }
231 }
232 }
233 }
234 }
235 }
236
237 let mut sorted_results: Vec<Neighbor> = result.into_iter().collect();
239 sorted_results.sort_by(|a, b| {
240 a.distance
241 .partial_cmp(&b.distance)
242 .unwrap_or(Ordering::Equal)
243 });
244
245 sorted_results
246 }
247
248 pub fn remove(&self, id: &str) -> Result<bool> {
250 let mut vectors = self.vectors.write();
251 let mut graph = self.graph.write();
252
253 if vectors.remove(id).is_none() {
254 return Ok(false);
255 }
256
257 graph.remove(id);
259
260 for connections in graph.values_mut() {
262 connections.retain(|conn_id| conn_id != id);
263 }
264
265 let mut entry_point = self.entry_point.write();
267 if entry_point.as_ref() == Some(&id.to_string()) {
268 *entry_point = vectors.keys().next().cloned();
269 }
270
271 Ok(true)
272 }
273
274 pub fn len(&self) -> usize {
276 self.vectors.read().len()
277 }
278
279 pub fn is_empty(&self) -> bool {
281 self.vectors.read().is_empty()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean, 3-dimensional configuration shared by the tests below.
    fn small_config() -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            metric: DistanceMetric::Euclidean,
            dimensions: 3,
        }
    }

    /// Query for the `k` nearest neighbours of `vector`, no threshold or filters.
    fn knn(vector: Vec<f32>, k: usize) -> SearchQuery {
        SearchQuery {
            vector,
            k,
            filters: None,
            threshold: None,
            ef_search: None,
        }
    }

    /// Index holding the three unit basis vectors `v1`, `v2`, `v3`.
    fn basis_index() -> HnswIndex {
        let index = HnswIndex::new(small_config());
        index.insert("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.insert("v2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        index.insert("v3".to_string(), vec![0.0, 0.0, 1.0]).unwrap();
        index
    }

    #[test]
    fn test_hnsw_insert_and_search() {
        let index = basis_index();
        let results = index.search(&knn(vec![0.9, 0.1, 0.0], 2)).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1");
    }

    #[test]
    fn threshold_drops_distant_results() {
        let index = basis_index();
        let mut query = knn(vec![0.9, 0.1, 0.0], 3);
        query.threshold = Some(0.5);
        let ids: Vec<String> = index
            .search(&query)
            .unwrap()
            .into_iter()
            .map(|result| result.id)
            .collect();
        assert_eq!(ids, vec!["v1".to_string()]);
    }

    #[test]
    fn insert_rejects_wrong_dimensions() {
        let index = HnswIndex::new(small_config());
        let err = index
            .insert("short".to_string(), vec![1.0, 2.0])
            .unwrap_err();
        assert!(matches!(
            err,
            VectorDbError::InvalidDimensions {
                expected: 3,
                actual: 2
            }
        ));
        assert!(index.is_empty());
    }

    #[test]
    fn remove_deletes_vector_and_reports_presence() {
        let index = basis_index();
        assert!(index.remove("v1").unwrap());
        assert!(!index.remove("v1").unwrap());
        assert_eq!(index.len(), 2);

        let ids: Vec<String> = index
            .search(&knn(vec![1.0, 0.0, 0.0], 3))
            .unwrap()
            .into_iter()
            .map(|result| result.id)
            .collect();
        assert_eq!(ids.len(), 2);
        assert!(!ids.contains(&"v1".to_string()));
    }

    #[test]
    fn empty_index_returns_no_results() {
        let index = HnswIndex::new(small_config());
        assert!(index.is_empty());
        assert!(index.search(&knn(vec![1.0, 0.0, 0.0], 1)).unwrap().is_empty());
    }
}