1use crate::distance::distance;
4use crate::error::{Result, RuvectorError};
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
7use bincode::{Decode, Encode};
8use dashmap::DashMap;
9use hnsw_rs::prelude::*;
10use parking_lot::RwLock;
11use std::sync::Arc;
12
13struct DistanceFn {
15 metric: DistanceMetric,
16}
17
18impl DistanceFn {
19 fn new(metric: DistanceMetric) -> Self {
20 Self { metric }
21 }
22}
23
24impl Distance<f32> for DistanceFn {
25 fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
26 distance(a, b, self.metric).unwrap_or(f32::MAX)
27 }
28}
29
30pub struct HnswIndex {
32 inner: Arc<RwLock<HnswInner>>,
33 config: HnswConfig,
34 metric: DistanceMetric,
35 dimensions: usize,
36}
37
38struct HnswInner {
39 hnsw: Hnsw<'static, f32, DistanceFn>,
40 vectors: DashMap<VectorId, Vec<f32>>,
41 id_to_idx: DashMap<VectorId, usize>,
42 idx_to_id: DashMap<usize, VectorId>,
43 next_idx: usize,
44}
45
46#[derive(Encode, Decode, Clone)]
48pub struct HnswState {
49 vectors: Vec<(String, Vec<f32>)>,
50 id_to_idx: Vec<(String, usize)>,
51 idx_to_id: Vec<(usize, String)>,
52 next_idx: usize,
53 config: SerializableHnswConfig,
54 dimensions: usize,
55 metric: SerializableDistanceMetric,
56}
57
58#[derive(Encode, Decode, Clone)]
59struct SerializableHnswConfig {
60 m: usize,
61 ef_construction: usize,
62 ef_search: usize,
63 max_elements: usize,
64}
65
66#[derive(Encode, Decode, Clone, Copy)]
67enum SerializableDistanceMetric {
68 Euclidean,
69 Cosine,
70 DotProduct,
71 Manhattan,
72}
73
74impl From<DistanceMetric> for SerializableDistanceMetric {
75 fn from(metric: DistanceMetric) -> Self {
76 match metric {
77 DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
78 DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
79 DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
80 DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
81 }
82 }
83}
84
85impl From<SerializableDistanceMetric> for DistanceMetric {
86 fn from(metric: SerializableDistanceMetric) -> Self {
87 match metric {
88 SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
89 SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
90 SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
91 SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
92 }
93 }
94}
95
96impl HnswIndex {
97 pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
99 let distance_fn = DistanceFn::new(metric);
100
101 let hnsw = Hnsw::<f32, DistanceFn>::new(
103 config.m,
104 config.max_elements,
105 dimensions,
106 config.ef_construction,
107 distance_fn,
108 );
109
110 Ok(Self {
111 inner: Arc::new(RwLock::new(HnswInner {
112 hnsw,
113 vectors: DashMap::new(),
114 id_to_idx: DashMap::new(),
115 idx_to_id: DashMap::new(),
116 next_idx: 0,
117 })),
118 config,
119 metric,
120 dimensions,
121 })
122 }
123
124 pub fn config(&self) -> &HnswConfig {
126 &self.config
127 }
128
129 pub fn set_ef_search(&mut self, _ef_search: usize) {
131 }
134
135 pub fn serialize(&self) -> Result<Vec<u8>> {
137 let inner = self.inner.read();
138
139 let state = HnswState {
140 vectors: inner
141 .vectors
142 .iter()
143 .map(|entry| (entry.key().clone(), entry.value().clone()))
144 .collect(),
145 id_to_idx: inner
146 .id_to_idx
147 .iter()
148 .map(|entry| (entry.key().clone(), *entry.value()))
149 .collect(),
150 idx_to_id: inner
151 .idx_to_id
152 .iter()
153 .map(|entry| (*entry.key(), entry.value().clone()))
154 .collect(),
155 next_idx: inner.next_idx,
156 config: SerializableHnswConfig {
157 m: self.config.m,
158 ef_construction: self.config.ef_construction,
159 ef_search: self.config.ef_search,
160 max_elements: self.config.max_elements,
161 },
162 dimensions: self.dimensions,
163 metric: self.metric.into(),
164 };
165
166 bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
167 RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
168 })
169 }
170
171 pub fn deserialize(bytes: &[u8]) -> Result<Self> {
173 let (state, _): (HnswState, usize) =
174 bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
175 RuvectorError::SerializationError(format!(
176 "Failed to deserialize HNSW index: {}",
177 e
178 ))
179 })?;
180
181 let config = HnswConfig {
182 m: state.config.m,
183 ef_construction: state.config.ef_construction,
184 ef_search: state.config.ef_search,
185 max_elements: state.config.max_elements,
186 };
187
188 let dimensions = state.dimensions;
189 let metric: DistanceMetric = state.metric.into();
190
191 let distance_fn = DistanceFn::new(metric);
192 let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
193 config.m,
194 config.max_elements,
195 dimensions,
196 config.ef_construction,
197 distance_fn,
198 );
199
200 let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
202 let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
203
204 for entry in idx_to_id.iter() {
206 let idx = *entry.key();
207 let id = entry.value();
208 if let Some(vector) = state.vectors.iter().find(|(vid, _)| vid == id) {
209 hnsw.insert_data(&vector.1, idx);
211 }
212 }
213
214 let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
215
216 Ok(Self {
217 inner: Arc::new(RwLock::new(HnswInner {
218 hnsw,
219 vectors: vectors_map,
220 id_to_idx,
221 idx_to_id,
222 next_idx: state.next_idx,
223 })),
224 config,
225 metric,
226 dimensions,
227 })
228 }
229
230 pub fn search_with_ef(
232 &self,
233 query: &[f32],
234 k: usize,
235 ef_search: usize,
236 ) -> Result<Vec<SearchResult>> {
237 if query.len() != self.dimensions {
238 return Err(RuvectorError::DimensionMismatch {
239 expected: self.dimensions,
240 actual: query.len(),
241 });
242 }
243
244 let inner = self.inner.read();
245
246 let neighbors = inner.hnsw.search(query, k, ef_search);
248
249 Ok(neighbors
250 .into_iter()
251 .filter_map(|neighbor| {
252 inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
253 id: id.clone(),
254 score: neighbor.distance,
255 vector: None,
256 metadata: None,
257 })
258 })
259 .collect())
260 }
261}
262
263impl VectorIndex for HnswIndex {
264 fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
265 if vector.len() != self.dimensions {
266 return Err(RuvectorError::DimensionMismatch {
267 expected: self.dimensions,
268 actual: vector.len(),
269 });
270 }
271
272 let mut inner = self.inner.write();
273 let idx = inner.next_idx;
274 inner.next_idx += 1;
275
276 inner.hnsw.insert_data(&vector, idx);
278
279 inner.vectors.insert(id.clone(), vector);
281 inner.id_to_idx.insert(id.clone(), idx);
282 inner.idx_to_id.insert(idx, id);
283
284 Ok(())
285 }
286
287 fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
288 for (_, vector) in &entries {
290 if vector.len() != self.dimensions {
291 return Err(RuvectorError::DimensionMismatch {
292 expected: self.dimensions,
293 actual: vector.len(),
294 });
295 }
296 }
297
298 let mut inner = self.inner.write();
299
300 use rayon::prelude::*;
302
303 let data_with_ids: Vec<_> = entries
305 .iter()
306 .enumerate()
307 .map(|(i, (id, vector))| {
308 let idx = inner.next_idx + i;
309 (id.clone(), idx, vector.clone())
310 })
311 .collect();
312
313 inner.next_idx += entries.len();
315
316 for (_id, idx, vector) in &data_with_ids {
320 inner.hnsw.insert_data(vector, *idx);
321 }
322
323 for (id, idx, vector) in data_with_ids {
325 inner.vectors.insert(id.clone(), vector);
326 inner.id_to_idx.insert(id.clone(), idx);
327 inner.idx_to_id.insert(idx, id);
328 }
329
330 Ok(())
331 }
332
333 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
334 self.search_with_ef(query, k, self.config.ef_search)
336 }
337
338 fn remove(&mut self, id: &VectorId) -> Result<bool> {
339 let mut inner = self.inner.write();
340
341 let removed = inner.vectors.remove(id).is_some();
345
346 if removed {
347 if let Some((_, idx)) = inner.id_to_idx.remove(id) {
348 inner.idx_to_id.remove(&idx);
349 }
350 }
351
352 Ok(removed)
353 }
354
355 fn len(&self) -> usize {
356 self.inner.read().vectors.len()
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` vectors, each holding `dimensions` random `f32` values.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let mut vectors = Vec::with_capacity(count);
        for _ in 0..count {
            let vector: Vec<f32> = (0..dimensions).map(|_| rng.gen::<f32>()).collect();
            vectors.push(vector);
        }
        vectors
    }

    /// Scales `v` to unit length; a vector whose norm is not positive is
    /// returned unchanged.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }

    /// Explicit configuration shared by the search and serialization tests.
    fn small_config() -> HnswConfig {
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        }
    }

    /// A freshly created index holds no vectors.
    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    /// Querying with a stored vector returns that vector as the top hit.
    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, small_config())?;

        let vectors = generate_random_vectors(100, 128);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }
        assert_eq!(index.len(), 100);

        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    /// A batch insert stores every entry.
    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let mut entries = Vec::new();
        for (i, v) in generate_random_vectors(100, 128).iter().enumerate() {
            entries.push((format!("vec_{}", i), normalize_vector(v)));
        }

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    /// An index survives a serialize/deserialize round trip and stays searchable.
    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, small_config())?;

        let vectors = generate_random_vectors(50, 128);
        for (i, vector) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), normalize_vector(vector))?;
        }

        let bytes = index.serialize()?;
        let restored_index = HnswIndex::deserialize(&bytes)?;
        assert_eq!(restored_index.len(), 50);

        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;
        assert!(!results.is_empty());

        Ok(())
    }

    /// Adding a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, HnswConfig::default())?;

        let outcome = index.add("test".to_string(), vec![1.0; 64]);
        assert!(outcome.is_err());

        Ok(())
    }
}