1use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
/// Adapter that lets `hnsw_rs` evaluate one of this crate's [`DistanceMetric`]s.
struct DistanceFn {
    /// Metric passed to [`distance`] on every evaluation.
    metric: DistanceMetric,
}
17
18impl DistanceFn {
19 fn new(metric: DistanceMetric) -> Self {
20 Self { metric }
21 }
22}
23
24impl Distance<f32> for DistanceFn {
25 fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26 distance(a, b, self.metric).unwrap_or(f32::MAX)
27 }
28}
29
/// Approximate nearest-neighbour index backed by an `hnsw_rs` graph.
///
/// Caller-facing [`VectorId`]s are mapped to the dense `usize` indices stored in the
/// graph. The raw vectors are kept alongside, so the index can be serialized and the
/// graph rebuilt on deserialization.
pub struct HnswIndex {
    /// Graph plus id bookkeeping, shared behind a read/write lock.
    inner: Arc<RwLock<HnswInner>>,
    /// Build/search parameters; `config.ef_search` is the `ef` used by `search`.
    config: HnswConfig,
    /// Metric the graph compares vectors with.
    metric: DistanceMetric,
    /// Required length of every inserted and query vector.
    dimensions: usize,
}
37
/// Mutable state of an [`HnswIndex`], guarded by its `RwLock`.
struct HnswInner {
    /// The graph itself; nodes are identified by the internal `usize` index.
    hnsw: Hnsw<'static, f32, DistanceFn>,
    /// Original vectors keyed by caller id (used when serializing / rebuilding).
    vectors: DashMap<VectorId, Vec<f32>>,
    /// Caller id -> internal graph index.
    id_to_idx: DashMap<VectorId, usize>,
    /// Internal graph index -> caller id. Search hits are resolved through this map,
    /// so a graph node with no entry here is dropped from results.
    idx_to_id: DashMap<usize, VectorId>,
    /// Next internal index to hand out; it only ever grows, so indices are not reused.
    next_idx: usize,
}
45
/// Serialized snapshot of an [`HnswIndex`], encoded with `bincode`.
///
/// Only the vectors and id mappings are stored; the graph is rebuilt from them on
/// deserialization. The derived encoding writes fields in declaration order, so
/// reordering fields breaks compatibility with previously written snapshots.
#[derive(Encode, Decode, Clone)]
pub struct HnswState {
    /// `(caller id, vector)` pairs.
    vectors: Vec<(String, Vec<f32>)>,
    /// `(caller id, internal index)` pairs.
    id_to_idx: Vec<(String, usize)>,
    /// `(internal index, caller id)` pairs.
    idx_to_id: Vec<(usize, String)>,
    /// Value of the index counter at snapshot time.
    next_idx: usize,
    /// Build/search parameters.
    config: SerializableHnswConfig,
    /// Vector dimensionality of the index.
    dimensions: usize,
    /// Distance metric of the index.
    metric: SerializableDistanceMetric,
}
57
/// `bincode`-encodable mirror of [`HnswConfig`].
///
/// Field order is part of the encoded format; do not reorder.
#[derive(Encode, Decode, Clone)]
struct SerializableHnswConfig {
    m: usize,
    ef_construction: usize,
    ef_search: usize,
    max_elements: usize,
}
65
/// `bincode`-encodable mirror of [`DistanceMetric`].
///
/// The derived encoding stores the variant position, so new variants must be
/// appended at the end to stay compatible with existing snapshots.
#[derive(Encode, Decode, Clone, Copy)]
enum SerializableDistanceMetric {
    Euclidean,
    Cosine,
    DotProduct,
    Manhattan,
}
73
74impl From<DistanceMetric> for SerializableDistanceMetric {
75 fn from(metric: DistanceMetric) -> Self {
76 match metric {
77 DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
78 DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
79 DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
80 DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
81 }
82 }
83}
84
85impl From<SerializableDistanceMetric> for DistanceMetric {
86 fn from(metric: SerializableDistanceMetric) -> Self {
87 match metric {
88 SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
89 SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
90 SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
91 SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
92 }
93 }
94}
95
96impl HnswIndex {
97 pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
99 let distance_fn = DistanceFn::new(metric);
100
101 let hnsw = Hnsw::<f32, DistanceFn>::new(
103 config.m,
104 config.max_elements,
105 dimensions,
106 config.ef_construction,
107 distance_fn,
108 );
109
110 Ok(Self {
111 inner: Arc::new(RwLock::new(HnswInner {
112 hnsw,
113 vectors: DashMap::new(),
114 id_to_idx: DashMap::new(),
115 idx_to_id: DashMap::new(),
116 next_idx: 0,
117 })),
118 config,
119 metric,
120 dimensions,
121 })
122 }
123
124 pub fn config(&self) -> &HnswConfig {
126 &self.config
127 }
128
129 pub fn set_ef_search(&mut self, _ef_search: usize) {
131 }
134
135 pub fn serialize(&self) -> Result<Vec<u8>> {
137 let inner = self.inner.read();
138
139 let state = HnswState {
140 vectors: inner
141 .vectors
142 .iter()
143 .map(|entry| (entry.key().clone(), entry.value().clone()))
144 .collect(),
145 id_to_idx: inner
146 .id_to_idx
147 .iter()
148 .map(|entry| (entry.key().clone(), *entry.value()))
149 .collect(),
150 idx_to_id: inner
151 .idx_to_id
152 .iter()
153 .map(|entry| (*entry.key(), entry.value().clone()))
154 .collect(),
155 next_idx: inner.next_idx,
156 config: SerializableHnswConfig {
157 m: self.config.m,
158 ef_construction: self.config.ef_construction,
159 ef_search: self.config.ef_search,
160 max_elements: self.config.max_elements,
161 },
162 dimensions: self.dimensions,
163 metric: self.metric.into(),
164 };
165
166 bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
167 RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
168 })
169 }
170
171 pub fn deserialize(bytes: &[u8]) -> Result<Self> {
173 let (state, _): (HnswState, usize) =
174 bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
175 RuvectorError::SerializationError(format!(
176 "Failed to deserialize HNSW index: {}",
177 e
178 ))
179 })?;
180
181 let config = HnswConfig {
182 m: state.config.m,
183 ef_construction: state.config.ef_construction,
184 ef_search: state.config.ef_search,
185 max_elements: state.config.max_elements,
186 };
187
188 let dimensions = state.dimensions;
189 let metric: DistanceMetric = state.metric.into();
190
191 let distance_fn = DistanceFn::new(metric);
192 let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
193 config.m,
194 config.max_elements,
195 dimensions,
196 config.ef_construction,
197 distance_fn,
198 );
199
200 let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
202 let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
203
204 let vectors_by_id: std::collections::HashMap<VectorId, Vec<f32>> =
207 state.vectors.iter().cloned().collect();
208
209 for entry in idx_to_id.iter() {
211 let idx = *entry.key();
212 let id = entry.value();
213 if let Some(vector) = vectors_by_id.get(id) {
215 hnsw.insert_data(vector, idx);
217 }
218 }
219
220 let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
221
222 Ok(Self {
223 inner: Arc::new(RwLock::new(HnswInner {
224 hnsw,
225 vectors: vectors_map,
226 id_to_idx,
227 idx_to_id,
228 next_idx: state.next_idx,
229 })),
230 config,
231 metric,
232 dimensions,
233 })
234 }
235
236 pub fn search_with_ef(
238 &self,
239 query: &[f32],
240 k: usize,
241 ef_search: usize,
242 ) -> Result<Vec<SearchResult>> {
243 if query.len() != self.dimensions {
244 return Err(RuvectorError::DimensionMismatch {
245 expected: self.dimensions,
246 actual: query.len(),
247 });
248 }
249
250 let inner = self.inner.read();
251
252 let neighbors = inner.hnsw.search(query, k, ef_search);
254
255 Ok(neighbors
256 .into_iter()
257 .filter_map(|neighbor| {
258 inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
259 id: id.clone(),
260 score: neighbor.distance,
261 vector: None,
262 metadata: None,
263 })
264 })
265 .collect())
266 }
267}
268
269impl VectorIndex for HnswIndex {
270 fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
271 if vector.len() != self.dimensions {
272 return Err(RuvectorError::DimensionMismatch {
273 expected: self.dimensions,
274 actual: vector.len(),
275 });
276 }
277
278 let mut inner = self.inner.write();
279 let idx = inner.next_idx;
280 inner.next_idx += 1;
281
282 inner.hnsw.insert_data(&vector, idx);
284
285 inner.vectors.insert(id.clone(), vector);
287 inner.id_to_idx.insert(id.clone(), idx);
288 inner.idx_to_id.insert(idx, id);
289
290 Ok(())
291 }
292
293 fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
294 for (_, vector) in &entries {
296 if vector.len() != self.dimensions {
297 return Err(RuvectorError::DimensionMismatch {
298 expected: self.dimensions,
299 actual: vector.len(),
300 });
301 }
302 }
303
304 let mut inner = self.inner.write();
305
306 use rayon::prelude::*;
308
309 let data_with_ids: Vec<_> = entries
311 .iter()
312 .enumerate()
313 .map(|(i, (id, vector))| {
314 let idx = inner.next_idx + i;
315 (id.clone(), idx, vector.clone())
316 })
317 .collect();
318
319 inner.next_idx += entries.len();
321
322 for (_id, idx, vector) in &data_with_ids {
326 inner.hnsw.insert_data(vector, *idx);
327 }
328
329 for (id, idx, vector) in data_with_ids {
331 inner.vectors.insert(id.clone(), vector);
332 inner.id_to_idx.insert(id.clone(), idx);
333 inner.idx_to_id.insert(idx, id);
334 }
335
336 Ok(())
337 }
338
339 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
340 self.search_with_ef(query, k, self.config.ef_search)
342 }
343
344 fn remove(&mut self, id: &VectorId) -> Result<bool> {
345 let mut inner = self.inner.write();
346
347 let removed = inner.vectors.remove(id).is_some();
351
352 if removed {
353 if let Some((_, idx)) = inner.id_to_idx.remove(id) {
354 inner.idx_to_id.remove(&idx);
355 }
356 }
357
358 Ok(removed)
359 }
360
361 fn len(&self) -> usize {
362 self.inner.read().vectors.len()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `count` vectors of length `dimensions` with random `f32` components.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let mut batch = Vec::with_capacity(count);
        for _ in 0..count {
            let row: Vec<f32> = (0..dimensions).map(|_| rng.gen::<f32>()).collect();
            batch.push(row);
        }
        batch
    }

    /// Scales `v` to unit L2 norm; returns it unchanged when the norm is not positive.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|component| component / norm).collect()
        } else {
            v.to_vec()
        }
    }

    /// Explicitly sized configuration shared by the search and serialization tests.
    fn small_config() -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        }
    }

    /// Adds each vector (normalized) one at a time under ids `vec_0`, `vec_1`, ...
    fn insert_all(index: &mut HnswIndex, vectors: &[Vec<f32>]) -> Result<()> {
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }
        Ok(())
    }

    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, small_config())?;
        let vectors = generate_random_vectors(100, 128);
        insert_all(&mut index, &vectors)?;
        assert_eq!(index.len(), 100);

        // Querying with a stored vector should return that vector first.
        let results = index.search(&normalize_vector(&vectors[0]), 10)?;
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let entries: Vec<_> = generate_random_vectors(100, 128)
            .into_iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), normalize_vector(&v)))
            .collect();

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, small_config())?;
        let vectors = generate_random_vectors(50, 128);
        insert_all(&mut index, &vectors)?;

        let bytes = index.serialize()?;
        let restored_index = HnswIndex::deserialize(&bytes)?;
        assert_eq!(restored_index.len(), 50);

        let results = restored_index.search(&normalize_vector(&vectors[0]), 5)?;
        assert!(!results.is_empty());

        Ok(())
    }

    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;
        let outcome = index.add("test".to_string(), vec![1.0; 64]);
        assert!(outcome.is_err());
        Ok(())
    }
}