1use crate::distance::calculate_distance;
4use crate::error::{Result, VectorDbError};
5use crate::types::{DistanceMetric, SearchQuery, SearchResult};
6use parking_lot::RwLock;
7use std::cmp::Ordering;
8use std::collections::{BinaryHeap, HashMap, HashSet};
9use std::sync::Arc;
10
/// Tuning parameters for an [`HnswIndex`].
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Maximum number of neighbors a newly inserted vector is linked to. A neighbor list that
    /// grows beyond `2 * m` entries is cut back to `m` entries by `insert`.
    pub m: usize,
    /// Candidate-list size requested while inserting; `insert` caps it at `2 * m`.
    pub ef_construction: usize,
    /// Candidate-list size used by `search` when the query does not carry its own value.
    pub ef_search: usize,
    /// Distance function handed to `calculate_distance` for every comparison.
    pub metric: DistanceMetric,
    /// Exact length every inserted vector must have.
    pub dimensions: usize,
}
25
26impl Default for HnswConfig {
27 fn default() -> Self {
28 Self {
29 m: 32,
30 ef_construction: 200,
31 ef_search: 100,
32 metric: DistanceMetric::Cosine,
33 dimensions: 384,
34 }
35 }
36}
37
/// A search hit: the id of a stored vector together with its distance to the query.
///
/// Equality and ordering look at `distance` only; `id` is ignored. The order is *reversed*
/// (a smaller distance compares as greater) so that a `BinaryHeap<Neighbor>` pops the closest
/// entry first.
///
/// NaN distances are ranked as the furthest possible value and compare equal to each other.
/// This keeps the order total and keeps `==` in agreement with `cmp`, which `Eq`/`Ord` and the
/// heap invariants of `BinaryHeap` rely on.
#[derive(Debug, Clone)]
struct Neighbor {
    /// Id of the stored vector.
    id: String,
    /// Distance between that vector and the query.
    distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so the two can never disagree (a raw `==` says NaN != NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.distance.is_nan(), other.distance.is_nan()) {
            (true, true) => Ordering::Equal,
            // NaN is treated as "infinitely far", i.e. the smallest value in reversed order.
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Operands are swapped on purpose: a smaller distance must compare as greater.
            // Neither side is NaN here, so `partial_cmp` always yields `Some`.
            (false, false) => other
                .distance
                .partial_cmp(&self.distance)
                .unwrap_or(Ordering::Equal),
        }
    }
}
67
/// Nearest-neighbor index that keeps its vectors in a single proximity graph.
///
/// Every piece of state sits behind its own `RwLock`, so all methods take `&self`.
pub struct HnswIndex {
    /// Parameters fixed when the index is created.
    config: HnswConfig,
    /// Stored vectors, keyed by id.
    vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
    /// Adjacency lists: for each id, the ids it is linked to.
    graph: Arc<RwLock<HashMap<String, Vec<String>>>>,
    /// Id at which every graph search starts; `None` while the index holds no vectors.
    entry_point: Arc<RwLock<Option<String>>>,
}
75
76impl HnswIndex {
77 pub fn new(config: HnswConfig) -> Self {
79 Self {
80 config,
81 vectors: Arc::new(RwLock::new(HashMap::new())),
82 graph: Arc::new(RwLock::new(HashMap::new())),
83 entry_point: Arc::new(RwLock::new(None)),
84 }
85 }
86
87 pub fn insert(&self, id: String, vector: Vec<f32>) -> Result<()> {
89 if vector.len() != self.config.dimensions {
90 return Err(VectorDbError::InvalidDimensions {
91 expected: self.config.dimensions,
92 actual: vector.len(),
93 });
94 }
95
96 self.vectors.write().insert(id.clone(), vector.clone());
98
99 let is_first = {
103 let mut graph = self.graph.write();
104 graph.insert(id.clone(), Vec::new());
105
106 let mut entry_point = self.entry_point.write();
107 if entry_point.is_none() {
108 *entry_point = Some(id.clone());
109 return Ok(());
110 }
111 false
112 }; if is_first {
115 return Ok(());
116 }
117
118 let neighbors =
120 self.search_knn_internal(&vector, self.config.ef_construction.min(self.config.m * 2));
121
122 let mut graph = self.graph.write();
124
125 for neighbor in neighbors.iter().take(self.config.m) {
127 if let Some(connections) = graph.get_mut(&id) {
128 connections.push(neighbor.id.clone());
129 }
130
131 if let Some(neighbor_connections) = graph.get_mut(&neighbor.id) {
132 neighbor_connections.push(id.clone());
133
134 if neighbor_connections.len() > self.config.m * 2 {
136 neighbor_connections.truncate(self.config.m);
137 }
138 }
139 }
140
141 Ok(())
142 }
143
144 pub fn insert_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
146 for (id, vector) in vectors {
147 self.insert(id, vector)?;
148 }
149 Ok(())
150 }
151
152 pub fn search(&self, query: &SearchQuery) -> Result<Vec<SearchResult>> {
154 let ef_search = query.ef_search.unwrap_or(self.config.ef_search);
155 let candidates = self.search_knn_internal(&query.vector, ef_search);
156
157 let mut results = Vec::new();
158 for candidate in candidates.into_iter().take(query.k) {
159 if let Some(threshold) = query.threshold {
161 if candidate.distance > threshold {
162 continue;
163 }
164 }
165
166 results.push(SearchResult {
167 id: candidate.id,
168 score: candidate.distance,
169 metadata: HashMap::new(),
170 vector: None,
171 });
172 }
173
174 Ok(results)
175 }
176
177 fn search_knn_internal(&self, query: &[f32], ef: usize) -> Vec<Neighbor> {
179 let vectors = self.vectors.read();
180 let graph = self.graph.read();
181 let entry_point = self.entry_point.read();
182
183 if entry_point.is_none() {
184 return Vec::new();
185 }
186
187 let entry_id = entry_point.as_ref().unwrap();
188 let mut visited = HashSet::new();
189 let mut candidates = BinaryHeap::new();
190 let mut result = BinaryHeap::new();
191
192 if let Some(entry_vec) = vectors.get(entry_id) {
194 let dist = calculate_distance(query, entry_vec, self.config.metric).unwrap_or(f32::MAX);
195
196 let neighbor = Neighbor {
197 id: entry_id.clone(),
198 distance: dist,
199 };
200
201 candidates.push(neighbor.clone());
202 result.push(neighbor);
203 visited.insert(entry_id.clone());
204 }
205
206 while let Some(current) = candidates.pop() {
208 if let Some(furthest) = result.peek() {
210 if current.distance > furthest.distance && result.len() >= ef {
211 break;
212 }
213 }
214
215 if let Some(neighbors) = graph.get(¤t.id) {
217 for neighbor_id in neighbors {
218 if visited.contains(neighbor_id) {
219 continue;
220 }
221
222 visited.insert(neighbor_id.clone());
223
224 if let Some(neighbor_vec) = vectors.get(neighbor_id) {
225 let dist = calculate_distance(query, neighbor_vec, self.config.metric)
226 .unwrap_or(f32::MAX);
227
228 let neighbor = Neighbor {
229 id: neighbor_id.clone(),
230 distance: dist,
231 };
232
233 candidates.push(neighbor.clone());
235
236 if result.len() < ef {
238 result.push(neighbor);
239 } else if let Some(worst) = result.peek() {
240 if dist < worst.distance {
241 result.pop();
242 result.push(neighbor);
243 }
244 }
245 }
246 }
247 }
248 }
249
250 let mut sorted_results: Vec<Neighbor> = result.into_iter().collect();
252 sorted_results.sort_by(|a, b| {
253 a.distance
254 .partial_cmp(&b.distance)
255 .unwrap_or(Ordering::Equal)
256 });
257
258 sorted_results
259 }
260
261 pub fn remove(&self, id: &str) -> Result<bool> {
263 let mut vectors = self.vectors.write();
264 let mut graph = self.graph.write();
265
266 if vectors.remove(id).is_none() {
267 return Ok(false);
268 }
269
270 graph.remove(id);
272
273 for connections in graph.values_mut() {
275 connections.retain(|conn_id| conn_id != id);
276 }
277
278 let mut entry_point = self.entry_point.write();
280 if entry_point.as_ref() == Some(&id.to_string()) {
281 *entry_point = vectors.keys().next().cloned();
282 }
283
284 Ok(true)
285 }
286
287 pub fn len(&self) -> usize {
289 self.vectors.read().len()
290 }
291
292 pub fn is_empty(&self) -> bool {
294 self.vectors.read().is_empty()
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty index with the graph parameters shared by every test in this module.
    fn index_with(metric: DistanceMetric, dimensions: usize) -> HnswIndex {
        HnswIndex::new(HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            metric,
            dimensions,
        })
    }

    /// Builds a plain top-`k` query with no filters, threshold or `ef_search` override.
    fn knn_query(vector: Vec<f32>, k: usize) -> SearchQuery {
        SearchQuery {
            vector,
            k,
            filters: None,
            threshold: None,
            ef_search: None,
        }
    }

    #[test]
    fn test_hnsw_insert_and_search() {
        let index = index_with(DistanceMetric::Euclidean, 3);

        // One unit vector per axis.
        index.insert("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.insert("v2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        index.insert("v3".to_string(), vec![0.0, 0.0, 1.0]).unwrap();

        // The query leans towards the first axis, so "v1" must come back first.
        let results = index.search(&knn_query(vec![0.9, 0.1, 0.0], 2)).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1");
    }

    #[test]
    fn test_hnsw_multiple_inserts_no_deadlock() {
        let index = index_with(DistanceMetric::Cosine, 128);

        // Twenty one-hot vectors, each along a different axis.
        for i in 0..20 {
            let mut vector = vec![0.0f32; 128];
            vector[i % 128] = 1.0;
            index.insert(format!("v{}", i), vector).unwrap();
        }

        assert_eq!(index.len(), 20);

        let results = index.search(&knn_query(vec![1.0; 128], 5)).unwrap();
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_hnsw_concurrent_inserts() {
        use std::sync::Arc;
        use std::thread;

        let index = Arc::new(index_with(DistanceMetric::Euclidean, 3));

        // Four writers, ten inserts each, all started before any of them is joined.
        let workers: Vec<_> = (0..4)
            .map(|t| {
                let shared = Arc::clone(&index);
                thread::spawn(move || {
                    for i in 0..10 {
                        let id = format!("t{}_v{}", t, i);
                        let vector = vec![t as f32, i as f32, 0.0];
                        shared.insert(id, vector).unwrap();
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(index.len(), 40);
    }
}