use std::sync::Arc;

use parking_lot::RwLock;
use tracing::{debug, info, instrument, warn};

use crate::distance::{cosine_similarity, normalize_vector};
use crate::domain::{EdgeType, EmbeddingId, HnswConfig, SimilarityEdge, VectorError};
use crate::infrastructure::hnsw_index::HnswIndex;

18#[derive(Debug, Clone)]
20pub struct Neighbor {
21 pub id: EmbeddingId,
23
24 pub distance: f32,
26
27 pub similarity: f32,
29
30 pub rank: usize,
32}
33
34impl Neighbor {
35 pub fn new(id: EmbeddingId, distance: f32, rank: usize) -> Self {
37 Self {
38 id,
39 distance,
40 similarity: 1.0 - distance.clamp(0.0, 1.0),
41 rank,
42 }
43 }
44
45 #[inline]
47 pub fn is_above_threshold(&self, threshold: f32) -> bool {
48 self.similarity >= threshold
49 }
50}
51
/// Parameters controlling a single k-nearest-neighbor query.
///
/// Start from [`SearchOptions::new`] (or [`Default`], which uses `k = 10`)
/// and refine it with the chained `with_*` / [`SearchOptions::include_query`]
/// methods.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// Maximum number of neighbors to return.
    pub k: usize,

    /// When set, neighbors with a similarity below this value are dropped.
    pub min_similarity: Option<f32>,

    /// When set, neighbors with a distance above this value are dropped.
    pub max_distance: Option<f32>,

    /// Requested HNSW `ef` value for this query.
    ///
    /// NOTE(review): not read by the search path in this module — confirm
    /// whether the index should be told about it.
    pub ef_search: Option<usize>,

    /// Flag asking for the query point itself to be kept in the results.
    ///
    /// NOTE(review): not read by the search path in this module.
    pub include_query: bool,
}

impl Default for SearchOptions {
    /// Ten neighbors, no thresholds, no `ef` override, query excluded.
    fn default() -> Self {
        Self::new(10)
    }
}

impl SearchOptions {
    /// Options that ask for `k` neighbors with every other knob left unset.
    pub fn new(k: usize) -> Self {
        Self {
            k,
            min_similarity: None,
            max_distance: None,
            ef_search: None,
            include_query: false,
        }
    }

    /// Keeps only neighbors whose similarity is `>= threshold`.
    pub fn with_min_similarity(self, threshold: f32) -> Self {
        Self {
            min_similarity: Some(threshold),
            ..self
        }
    }

    /// Keeps only neighbors whose distance is `<= distance`.
    pub fn with_max_distance(self, distance: f32) -> Self {
        Self {
            max_distance: Some(distance),
            ..self
        }
    }

    /// Records a per-query `ef` value.
    pub fn with_ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: Some(ef),
            ..self
        }
    }

    /// Sets the `include_query` flag to `true`.
    pub fn include_query(self) -> Self {
        Self {
            include_query: true,
            ..self
        }
    }
}

117pub struct VectorSpaceService {
124 index: Arc<RwLock<HnswIndex>>,
126
127 config: HnswConfig,
129}
130
131impl VectorSpaceService {
132 pub fn new(config: HnswConfig) -> Self {
134 let index = HnswIndex::new(&config);
135 Self {
136 index: Arc::new(RwLock::new(index)),
137 config,
138 }
139 }
140
141 pub fn from_index(index: HnswIndex, config: HnswConfig) -> Self {
143 Self {
144 index: Arc::new(RwLock::new(index)),
145 config,
146 }
147 }
148
149 #[inline]
151 pub fn dimensions(&self) -> usize {
152 self.config.dimensions
153 }
154
155 pub fn len(&self) -> usize {
157 self.index.read().len()
158 }
159
160 pub fn is_empty(&self) -> bool {
162 self.len() == 0
163 }
164
165 pub fn config(&self) -> &HnswConfig {
167 &self.config
168 }
169
170 #[instrument(skip(self, vector), fields(vector_dim = vector.len()))]
174 pub async fn add_embedding(
175 &self,
176 id: EmbeddingId,
177 vector: Vec<f32>,
178 ) -> Result<(), VectorError> {
179 self.validate_vector(&vector)?;
180
181 let vector = if self.config.normalize {
182 normalize_vector(&vector)
183 } else {
184 vector
185 };
186
187 let mut index = self.index.write();
188 index.insert(id, &vector)?;
189
190 debug!(id = %id, "Added embedding to index");
191 Ok(())
192 }
193
194 #[instrument(skip(self, items), fields(batch_size = items.len()))]
199 pub async fn add_embeddings_batch(
200 &self,
201 items: Vec<(EmbeddingId, Vec<f32>)>,
202 ) -> Result<usize, VectorError> {
203 if items.is_empty() {
204 return Ok(0);
205 }
206
207 for (_, vector) in &items {
209 self.validate_vector(vector)?;
210 }
211
212 let items: Vec<_> = if self.config.normalize {
214 items
215 .into_iter()
216 .map(|(id, v)| (id, normalize_vector(&v)))
217 .collect()
218 } else {
219 items
220 };
221
222 let mut index = self.index.write();
223 let mut added = 0;
224
225 for (id, vector) in &items {
226 if let Err(e) = index.insert(*id, vector) {
227 warn!(id = %id, error = %e, "Failed to add embedding in batch");
228 } else {
229 added += 1;
230 }
231 }
232
233 info!(added, total = items.len(), "Batch insert completed");
234 Ok(added)
235 }
236
237 #[instrument(skip(self, query), fields(query_dim = query.len(), k))]
239 pub async fn find_neighbors(
240 &self,
241 query: &[f32],
242 k: usize,
243 ) -> Result<Vec<Neighbor>, VectorError> {
244 self.find_neighbors_with_options(query, SearchOptions::new(k))
245 .await
246 }
247
248 #[instrument(skip(self, query, options), fields(query_dim = query.len()))]
250 pub async fn find_neighbors_with_options(
251 &self,
252 query: &[f32],
253 options: SearchOptions,
254 ) -> Result<Vec<Neighbor>, VectorError> {
255 self.validate_vector(query)?;
256
257 let query = if self.config.normalize {
258 normalize_vector(query)
259 } else {
260 query.to_vec()
261 };
262
263 let index = self.index.read();
264
265 if index.is_empty() {
266 return Ok(Vec::new());
267 }
268
269 let k_fetch = if options.min_similarity.is_some() || options.max_distance.is_some() {
271 options.k * 2
272 } else {
273 options.k
274 };
275
276 let results = index.search(&query, k_fetch);
277
278 let mut neighbors: Vec<_> = results
279 .into_iter()
280 .enumerate()
281 .map(|(rank, (id, distance))| Neighbor::new(id, distance, rank))
282 .collect();
283
284 if let Some(min_sim) = options.min_similarity {
286 neighbors.retain(|n| n.similarity >= min_sim);
287 }
288 if let Some(max_dist) = options.max_distance {
289 neighbors.retain(|n| n.distance <= max_dist);
290 }
291
292 neighbors.truncate(options.k);
294
295 for (rank, neighbor) in neighbors.iter_mut().enumerate() {
297 neighbor.rank = rank;
298 }
299
300 debug!(found = neighbors.len(), "Neighbor search completed");
301 Ok(neighbors)
302 }
303
304 #[instrument(skip(self, query, filter), fields(query_dim = query.len(), k))]
309 pub async fn find_neighbors_with_filter<F>(
310 &self,
311 query: &[f32],
312 k: usize,
313 filter: F,
314 ) -> Result<Vec<Neighbor>, VectorError>
315 where
316 F: Fn(&EmbeddingId) -> bool + Send + Sync,
317 {
318 self.validate_vector(query)?;
319
320 let query = if self.config.normalize {
321 normalize_vector(query)
322 } else {
323 query.to_vec()
324 };
325
326 let index = self.index.read();
327
328 if index.is_empty() {
329 return Ok(Vec::new());
330 }
331
332 let k_fetch = k * 4;
334 let results = index.search(&query, k_fetch);
335
336 let mut neighbors: Vec<_> = results
337 .into_iter()
338 .filter(|(id, _)| filter(id))
339 .take(k)
340 .enumerate()
341 .map(|(rank, (id, distance))| Neighbor::new(id, distance, rank))
342 .collect();
343
344 for (rank, neighbor) in neighbors.iter_mut().enumerate() {
346 neighbor.rank = rank;
347 }
348
349 Ok(neighbors)
350 }
351
352 #[instrument(skip(self))]
354 pub async fn remove_embedding(&self, id: &EmbeddingId) -> Result<(), VectorError> {
355 let mut index = self.index.write();
356 index.remove(id)?;
357 debug!(id = %id, "Removed embedding from index");
358 Ok(())
359 }
360
361 pub fn contains(&self, id: &EmbeddingId) -> bool {
363 self.index.read().contains(id)
364 }
365
366 pub fn get_vector(&self, id: &EmbeddingId) -> Option<Vec<f32>> {
368 self.index.read().get_vector(id)
369 }
370
371 #[instrument(skip(self, vector))]
375 pub async fn build_similarity_edges(
376 &self,
377 id: EmbeddingId,
378 vector: &[f32],
379 k: usize,
380 min_similarity: f32,
381 ) -> Result<Vec<SimilarityEdge>, VectorError> {
382 let neighbors = self
383 .find_neighbors_with_options(
384 vector,
385 SearchOptions::new(k).with_min_similarity(min_similarity),
386 )
387 .await?;
388
389 let edges: Vec<_> = neighbors
390 .into_iter()
391 .filter(|n| n.id != id) .map(|n| {
393 SimilarityEdge::new(id, n.id, n.distance)
394 .with_type(EdgeType::Similar)
395 })
396 .collect();
397
398 Ok(edges)
399 }
400
401 #[instrument(skip(self, vectors))]
403 pub async fn compute_pairwise_similarities(
404 &self,
405 vectors: &[(EmbeddingId, Vec<f32>)],
406 ) -> Result<Vec<(EmbeddingId, EmbeddingId, f32)>, VectorError> {
407 if vectors.len() < 2 {
408 return Ok(Vec::new());
409 }
410
411 for (_, vector) in vectors {
413 self.validate_vector(vector)?;
414 }
415
416 let vectors: Vec<_> = if self.config.normalize {
418 vectors
419 .iter()
420 .map(|(id, v)| (*id, normalize_vector(v)))
421 .collect()
422 } else {
423 vectors.to_vec()
424 };
425
426 let mut similarities = Vec::with_capacity(vectors.len() * (vectors.len() - 1) / 2);
427
428 for i in 0..vectors.len() {
429 for j in (i + 1)..vectors.len() {
430 let sim = cosine_similarity(&vectors[i].1, &vectors[j].1);
431 similarities.push((vectors[i].0, vectors[j].0, sim));
432 }
433 }
434
435 Ok(similarities)
436 }
437
438 pub async fn clear(&self) -> Result<(), VectorError> {
440 let mut index = self.index.write();
441 index.clear();
442 info!("Cleared all embeddings from index");
443 Ok(())
444 }
445
446 pub async fn save(&self, path: &std::path::Path) -> Result<(), VectorError> {
448 let index = self.index.read();
449 index.save(path)?;
450 info!(path = %path.display(), "Saved index to file");
451 Ok(())
452 }
453
454 pub async fn load(path: &std::path::Path, config: HnswConfig) -> Result<Self, VectorError> {
456 let index = HnswIndex::load(path)?;
457 info!(path = %path.display(), "Loaded index from file");
458 Ok(Self::from_index(index, config))
459 }
460
461 pub fn stats(&self) -> IndexStatistics {
463 let index = self.index.read();
464 IndexStatistics {
465 vector_count: index.len(),
466 dimensions: self.config.dimensions,
467 max_capacity: self.config.max_elements,
468 utilization: index.len() as f64 / self.config.max_elements as f64,
469 }
470 }
471
472 fn validate_vector(&self, vector: &[f32]) -> Result<(), VectorError> {
474 if vector.len() != self.config.dimensions {
475 return Err(VectorError::dimension_mismatch(
476 self.config.dimensions,
477 vector.len(),
478 ));
479 }
480
481 for (i, &v) in vector.iter().enumerate() {
483 if v.is_nan() {
484 return Err(VectorError::invalid_vector(format!(
485 "NaN value at index {i}"
486 )));
487 }
488 if v.is_infinite() {
489 return Err(VectorError::invalid_vector(format!(
490 "Infinite value at index {i}"
491 )));
492 }
493 }
494
495 Ok(())
496 }
497}
498
499impl Clone for VectorSpaceService {
500 fn clone(&self) -> Self {
501 Self {
502 index: Arc::clone(&self.index),
503 config: self.config.clone(),
504 }
505 }
506}
507
/// Point-in-time summary of an index's size and capacity.
#[derive(Debug, Clone)]
pub struct IndexStatistics {
    /// Number of vectors stored when the snapshot was taken.
    pub vector_count: usize,

    /// Configured vector dimensionality.
    pub dimensions: usize,

    /// Configured maximum number of elements.
    pub max_capacity: usize,

    /// Ratio of `vector_count` to `max_capacity`.
    pub utilization: f64,
}

524pub struct VectorSpaceServiceBuilder {
526 config: HnswConfig,
527}
528
529impl VectorSpaceServiceBuilder {
530 pub fn new() -> Self {
532 Self {
533 config: HnswConfig::default(),
534 }
535 }
536
537 pub fn dimensions(mut self, dim: usize) -> Self {
539 self.config.dimensions = dim;
540 self
541 }
542
543 pub fn m(mut self, m: usize) -> Self {
545 self.config.m = m;
546 self
547 }
548
549 pub fn ef_construction(mut self, ef: usize) -> Self {
551 self.config.ef_construction = ef;
552 self
553 }
554
555 pub fn ef_search(mut self, ef: usize) -> Self {
557 self.config.ef_search = ef;
558 self
559 }
560
561 pub fn max_elements(mut self, max: usize) -> Self {
563 self.config.max_elements = max;
564 self
565 }
566
567 pub fn normalize(mut self, normalize: bool) -> Self {
569 self.config.normalize = normalize;
570 self
571 }
572
573 pub fn build(self) -> Result<VectorSpaceService, VectorError> {
575 self.config.validate()?;
576 Ok(VectorSpaceService::new(self.config))
577 }
578}
579
580impl Default for VectorSpaceServiceBuilder {
581 fn default() -> Self {
582 Self::new()
583 }
584}
585
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Service shared by the async tests: 128 dimensions, capacity 1000, no normalization.
    fn create_test_service() -> VectorSpaceService {
        VectorSpaceService::new(
            HnswConfig::for_dimension(128)
                .with_max_elements(1000)
                .with_normalize(false),
        )
    }

    /// 128-component vector whose `j`-th entry is `(offset + j) / scale`.
    fn ramp(offset: usize, scale: f32) -> Vec<f32> {
        (0..128).map(|j| (offset + j) as f32 / scale).collect()
    }

    #[tokio::test]
    async fn test_add_and_search() {
        let service = create_test_service();
        let (first, second) = (EmbeddingId::new(), EmbeddingId::new());
        let query = ramp(0, 128.0);

        service.add_embedding(first, query.clone()).await.unwrap();
        service.add_embedding(second, ramp(1, 128.0)).await.unwrap();
        assert_eq!(service.len(), 2);

        let hits = service.find_neighbors(&query, 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        // The query is identical to `first`, so it must come out on top.
        assert_eq!(hits[0].id, first);
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let service = create_test_service();
        let too_short: Vec<f32> = vec![0.1; 64];

        let outcome = service.add_embedding(EmbeddingId::new(), too_short).await;
        assert!(matches!(
            outcome,
            Err(VectorError::DimensionMismatch { .. })
        ));
    }

    #[tokio::test]
    async fn test_batch_insert() {
        let service = create_test_service();
        let batch: Vec<(EmbeddingId, Vec<f32>)> = (0..10)
            .map(|i| (EmbeddingId::new(), ramp(i * 128, 1280.0)))
            .collect();

        let inserted = service.add_embeddings_batch(batch).await.unwrap();
        assert_eq!(inserted, 10);
        assert_eq!(service.len(), 10);
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let service = create_test_service();
        let ids: Vec<EmbeddingId> = (0..5).map(|_| EmbeddingId::new()).collect();

        for (i, id) in ids.iter().enumerate() {
            service.add_embedding(*id, ramp(i * 128, 640.0)).await.unwrap();
        }

        let query = ramp(0, 640.0);
        // Ids at odd positions (1 and 3).
        let odd_ids: HashSet<EmbeddingId> = ids.iter().skip(1).step_by(2).copied().collect();

        let hits = service
            .find_neighbors_with_filter(&query, 10, |id| odd_ids.contains(id))
            .await
            .unwrap();

        assert!(hits.iter().all(|hit| odd_ids.contains(&hit.id)));
    }

    #[test]
    fn test_neighbor() {
        let hit = Neighbor::new(EmbeddingId::new(), 0.2, 0);
        assert!((hit.similarity - 0.8).abs() < 0.001);
        assert!(hit.is_above_threshold(0.7));
        assert!(!hit.is_above_threshold(0.9));
    }

    #[test]
    fn test_search_options() {
        let options = SearchOptions::new(10)
            .with_min_similarity(0.8)
            .with_max_distance(0.3);

        assert_eq!(options.k, 10);
        assert_eq!(options.min_similarity, Some(0.8));
        assert_eq!(options.max_distance, Some(0.3));
    }

    #[test]
    fn test_builder() {
        let built = VectorSpaceServiceBuilder::new()
            .dimensions(256)
            .m(16)
            .ef_construction(100)
            .max_elements(5000)
            .build()
            .expect("a 256-dim config with positive parameters should validate");

        assert_eq!(built.dimensions(), 256);
    }
}
705}