sevensense_vector/application/
services.rs

1//! Application services for the Vector Space bounded context.
2//!
3//! These services implement the use cases for vector indexing and search,
4//! providing a high-level API that coordinates domain objects and repositories.
5
6use std::sync::Arc;
7
8use parking_lot::RwLock;
9use tracing::{debug, info, instrument, warn};
10
11use crate::distance::{cosine_similarity, normalize_vector};
12use crate::domain::{
13    EmbeddingId, HnswConfig, SimilarityEdge, EdgeType,
14    VectorError,
15};
16use crate::infrastructure::hnsw_index::HnswIndex;
17
18/// A search result neighbor with similarity information.
19#[derive(Debug, Clone)]
20pub struct Neighbor {
21    /// The embedding ID of this neighbor.
22    pub id: EmbeddingId,
23
24    /// Distance from the query vector.
25    pub distance: f32,
26
27    /// Similarity score (1 - distance for cosine).
28    pub similarity: f32,
29
30    /// Rank in the result set (0 = closest).
31    pub rank: usize,
32}
33
34impl Neighbor {
35    /// Create a new neighbor from search results.
36    pub fn new(id: EmbeddingId, distance: f32, rank: usize) -> Self {
37        Self {
38            id,
39            distance,
40            similarity: 1.0 - distance.clamp(0.0, 1.0),
41            rank,
42        }
43    }
44
45    /// Check if this neighbor exceeds a similarity threshold.
46    #[inline]
47    pub fn is_above_threshold(&self, threshold: f32) -> bool {
48        self.similarity >= threshold
49    }
50}
51
/// Tunable parameters for a nearest-neighbor query.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// Upper bound on the number of results returned.
    pub k: usize,

    /// Lowest acceptable similarity; results scoring below it are filtered out.
    pub min_similarity: Option<f32>,

    /// Largest acceptable distance.
    pub max_distance: Option<f32>,

    /// Override for the `ef_search` parameter (higher = more accurate but slower).
    pub ef_search: Option<usize>,

    /// Whether the query vector itself should appear in the results if it exists.
    pub include_query: bool,
}

impl Default for SearchOptions {
    /// Ten results, no thresholds, no `ef_search` override, query excluded.
    fn default() -> Self {
        Self::new(10)
    }
}

impl SearchOptions {
    /// Options requesting `k` results with every other setting left unset.
    pub fn new(k: usize) -> Self {
        Self {
            k,
            min_similarity: None,
            max_distance: None,
            ef_search: None,
            include_query: false,
        }
    }

    /// Require results to have at least `threshold` similarity.
    pub fn with_min_similarity(self, threshold: f32) -> Self {
        Self {
            min_similarity: Some(threshold),
            ..self
        }
    }

    /// Require results to be at most `distance` away.
    pub fn with_max_distance(self, distance: f32) -> Self {
        Self {
            max_distance: Some(distance),
            ..self
        }
    }

    /// Override the `ef_search` parameter for this query.
    pub fn with_ef_search(self, ef: usize) -> Self {
        Self {
            ef_search: Some(ef),
            ..self
        }
    }

    /// Ask for the query vector to be included in the results.
    pub fn include_query(self) -> Self {
        Self {
            include_query: true,
            ..self
        }
    }
}
116
117/// The main service for vector space operations.
118///
119/// This service provides a thread-safe interface for:
120/// - Adding and removing embeddings
121/// - Nearest neighbor search
122/// - Building similarity graphs
123pub struct VectorSpaceService {
124    /// The underlying HNSW index.
125    index: Arc<RwLock<HnswIndex>>,
126
127    /// Configuration for this service.
128    config: HnswConfig,
129}
130
131impl VectorSpaceService {
132    /// Create a new vector space service with the given configuration.
133    pub fn new(config: HnswConfig) -> Self {
134        let index = HnswIndex::new(&config);
135        Self {
136            index: Arc::new(RwLock::new(index)),
137            config,
138        }
139    }
140
141    /// Create a service from an existing index.
142    pub fn from_index(index: HnswIndex, config: HnswConfig) -> Self {
143        Self {
144            index: Arc::new(RwLock::new(index)),
145            config,
146        }
147    }
148
149    /// Get the index dimensions.
150    #[inline]
151    pub fn dimensions(&self) -> usize {
152        self.config.dimensions
153    }
154
155    /// Get the current number of vectors.
156    pub fn len(&self) -> usize {
157        self.index.read().len()
158    }
159
160    /// Check if the index is empty.
161    pub fn is_empty(&self) -> bool {
162        self.len() == 0
163    }
164
165    /// Get a reference to the configuration.
166    pub fn config(&self) -> &HnswConfig {
167        &self.config
168    }
169
170    /// Add a single embedding to the index.
171    ///
172    /// The vector will be normalized if the configuration specifies normalization.
173    #[instrument(skip(self, vector), fields(vector_dim = vector.len()))]
174    pub async fn add_embedding(
175        &self,
176        id: EmbeddingId,
177        vector: Vec<f32>,
178    ) -> Result<(), VectorError> {
179        self.validate_vector(&vector)?;
180
181        let vector = if self.config.normalize {
182            normalize_vector(&vector)
183        } else {
184            vector
185        };
186
187        let mut index = self.index.write();
188        index.insert(id, &vector)?;
189
190        debug!(id = %id, "Added embedding to index");
191        Ok(())
192    }
193
194    /// Add multiple embeddings in a batch.
195    ///
196    /// This is more efficient than multiple single adds due to
197    /// amortized locking overhead.
198    #[instrument(skip(self, items), fields(batch_size = items.len()))]
199    pub async fn add_embeddings_batch(
200        &self,
201        items: Vec<(EmbeddingId, Vec<f32>)>,
202    ) -> Result<usize, VectorError> {
203        if items.is_empty() {
204            return Ok(0);
205        }
206
207        // Validate all vectors first
208        for (_, vector) in &items {
209            self.validate_vector(vector)?;
210        }
211
212        // Normalize if needed
213        let items: Vec<_> = if self.config.normalize {
214            items
215                .into_iter()
216                .map(|(id, v)| (id, normalize_vector(&v)))
217                .collect()
218        } else {
219            items
220        };
221
222        let mut index = self.index.write();
223        let mut added = 0;
224
225        for (id, vector) in &items {
226            if let Err(e) = index.insert(*id, vector) {
227                warn!(id = %id, error = %e, "Failed to add embedding in batch");
228            } else {
229                added += 1;
230            }
231        }
232
233        info!(added, total = items.len(), "Batch insert completed");
234        Ok(added)
235    }
236
237    /// Find the k nearest neighbors to a query vector.
238    #[instrument(skip(self, query), fields(query_dim = query.len(), k))]
239    pub async fn find_neighbors(
240        &self,
241        query: &[f32],
242        k: usize,
243    ) -> Result<Vec<Neighbor>, VectorError> {
244        self.find_neighbors_with_options(query, SearchOptions::new(k))
245            .await
246    }
247
248    /// Find neighbors with custom search options.
249    #[instrument(skip(self, query, options), fields(query_dim = query.len()))]
250    pub async fn find_neighbors_with_options(
251        &self,
252        query: &[f32],
253        options: SearchOptions,
254    ) -> Result<Vec<Neighbor>, VectorError> {
255        self.validate_vector(query)?;
256
257        let query = if self.config.normalize {
258            normalize_vector(query)
259        } else {
260            query.to_vec()
261        };
262
263        let index = self.index.read();
264
265        if index.is_empty() {
266            return Ok(Vec::new());
267        }
268
269        // Request more results if we're filtering
270        let k_fetch = if options.min_similarity.is_some() || options.max_distance.is_some() {
271            options.k * 2
272        } else {
273            options.k
274        };
275
276        let results = index.search(&query, k_fetch);
277
278        let mut neighbors: Vec<_> = results
279            .into_iter()
280            .enumerate()
281            .map(|(rank, (id, distance))| Neighbor::new(id, distance, rank))
282            .collect();
283
284        // Apply filters
285        if let Some(min_sim) = options.min_similarity {
286            neighbors.retain(|n| n.similarity >= min_sim);
287        }
288        if let Some(max_dist) = options.max_distance {
289            neighbors.retain(|n| n.distance <= max_dist);
290        }
291
292        // Truncate to requested k
293        neighbors.truncate(options.k);
294
295        // Re-rank after filtering
296        for (rank, neighbor) in neighbors.iter_mut().enumerate() {
297            neighbor.rank = rank;
298        }
299
300        debug!(found = neighbors.len(), "Neighbor search completed");
301        Ok(neighbors)
302    }
303
304    /// Find neighbors using a filter predicate.
305    ///
306    /// The filter function receives an EmbeddingId and returns true if the
307    /// embedding should be included in results.
308    #[instrument(skip(self, query, filter), fields(query_dim = query.len(), k))]
309    pub async fn find_neighbors_with_filter<F>(
310        &self,
311        query: &[f32],
312        k: usize,
313        filter: F,
314    ) -> Result<Vec<Neighbor>, VectorError>
315    where
316        F: Fn(&EmbeddingId) -> bool + Send + Sync,
317    {
318        self.validate_vector(query)?;
319
320        let query = if self.config.normalize {
321            normalize_vector(query)
322        } else {
323            query.to_vec()
324        };
325
326        let index = self.index.read();
327
328        if index.is_empty() {
329            return Ok(Vec::new());
330        }
331
332        // Fetch more results to account for filtering
333        let k_fetch = k * 4;
334        let results = index.search(&query, k_fetch);
335
336        let mut neighbors: Vec<_> = results
337            .into_iter()
338            .filter(|(id, _)| filter(id))
339            .take(k)
340            .enumerate()
341            .map(|(rank, (id, distance))| Neighbor::new(id, distance, rank))
342            .collect();
343
344        // Re-rank
345        for (rank, neighbor) in neighbors.iter_mut().enumerate() {
346            neighbor.rank = rank;
347        }
348
349        Ok(neighbors)
350    }
351
352    /// Remove an embedding from the index.
353    #[instrument(skip(self))]
354    pub async fn remove_embedding(&self, id: &EmbeddingId) -> Result<(), VectorError> {
355        let mut index = self.index.write();
356        index.remove(id)?;
357        debug!(id = %id, "Removed embedding from index");
358        Ok(())
359    }
360
361    /// Check if an embedding exists in the index.
362    pub fn contains(&self, id: &EmbeddingId) -> bool {
363        self.index.read().contains(id)
364    }
365
366    /// Get a vector by its ID.
367    pub fn get_vector(&self, id: &EmbeddingId) -> Option<Vec<f32>> {
368        self.index.read().get_vector(id)
369    }
370
371    /// Build similarity edges for an embedding.
372    ///
373    /// This finds the k nearest neighbors and creates edges to them.
374    #[instrument(skip(self, vector))]
375    pub async fn build_similarity_edges(
376        &self,
377        id: EmbeddingId,
378        vector: &[f32],
379        k: usize,
380        min_similarity: f32,
381    ) -> Result<Vec<SimilarityEdge>, VectorError> {
382        let neighbors = self
383            .find_neighbors_with_options(
384                vector,
385                SearchOptions::new(k).with_min_similarity(min_similarity),
386            )
387            .await?;
388
389        let edges: Vec<_> = neighbors
390            .into_iter()
391            .filter(|n| n.id != id) // Exclude self
392            .map(|n| {
393                SimilarityEdge::new(id, n.id, n.distance)
394                    .with_type(EdgeType::Similar)
395            })
396            .collect();
397
398        Ok(edges)
399    }
400
401    /// Compute pairwise similarities for a set of embeddings.
402    #[instrument(skip(self, vectors))]
403    pub async fn compute_pairwise_similarities(
404        &self,
405        vectors: &[(EmbeddingId, Vec<f32>)],
406    ) -> Result<Vec<(EmbeddingId, EmbeddingId, f32)>, VectorError> {
407        if vectors.len() < 2 {
408            return Ok(Vec::new());
409        }
410
411        // Validate all vectors
412        for (_, vector) in vectors {
413            self.validate_vector(vector)?;
414        }
415
416        // Normalize if needed
417        let vectors: Vec<_> = if self.config.normalize {
418            vectors
419                .iter()
420                .map(|(id, v)| (*id, normalize_vector(v)))
421                .collect()
422        } else {
423            vectors.to_vec()
424        };
425
426        let mut similarities = Vec::with_capacity(vectors.len() * (vectors.len() - 1) / 2);
427
428        for i in 0..vectors.len() {
429            for j in (i + 1)..vectors.len() {
430                let sim = cosine_similarity(&vectors[i].1, &vectors[j].1);
431                similarities.push((vectors[i].0, vectors[j].0, sim));
432            }
433        }
434
435        Ok(similarities)
436    }
437
438    /// Clear all embeddings from the index.
439    pub async fn clear(&self) -> Result<(), VectorError> {
440        let mut index = self.index.write();
441        index.clear();
442        info!("Cleared all embeddings from index");
443        Ok(())
444    }
445
446    /// Save the index to a file.
447    pub async fn save(&self, path: &std::path::Path) -> Result<(), VectorError> {
448        let index = self.index.read();
449        index.save(path)?;
450        info!(path = %path.display(), "Saved index to file");
451        Ok(())
452    }
453
454    /// Load an index from a file.
455    pub async fn load(path: &std::path::Path, config: HnswConfig) -> Result<Self, VectorError> {
456        let index = HnswIndex::load(path)?;
457        info!(path = %path.display(), "Loaded index from file");
458        Ok(Self::from_index(index, config))
459    }
460
461    /// Get index statistics.
462    pub fn stats(&self) -> IndexStatistics {
463        let index = self.index.read();
464        IndexStatistics {
465            vector_count: index.len(),
466            dimensions: self.config.dimensions,
467            max_capacity: self.config.max_elements,
468            utilization: index.len() as f64 / self.config.max_elements as f64,
469        }
470    }
471
472    /// Validate a vector.
473    fn validate_vector(&self, vector: &[f32]) -> Result<(), VectorError> {
474        if vector.len() != self.config.dimensions {
475            return Err(VectorError::dimension_mismatch(
476                self.config.dimensions,
477                vector.len(),
478            ));
479        }
480
481        // Check for NaN or Inf
482        for (i, &v) in vector.iter().enumerate() {
483            if v.is_nan() {
484                return Err(VectorError::invalid_vector(format!(
485                    "NaN value at index {i}"
486                )));
487            }
488            if v.is_infinite() {
489                return Err(VectorError::invalid_vector(format!(
490                    "Infinite value at index {i}"
491                )));
492            }
493        }
494
495        Ok(())
496    }
497}
498
499impl Clone for VectorSpaceService {
500    fn clone(&self) -> Self {
501        Self {
502            index: Arc::clone(&self.index),
503            config: self.config.clone(),
504        }
505    }
506}
507
/// Point-in-time figures describing the vector index.
#[derive(Debug, Clone)]
pub struct IndexStatistics {
    /// How many vectors the index holds.
    pub vector_count: usize,

    /// Number of components in each vector.
    pub dimensions: usize,

    /// Configured maximum number of elements.
    pub max_capacity: usize,

    /// Ratio of `vector_count` to `max_capacity` (0.0 - 1.0).
    pub utilization: f64,
}
523
524/// Builder for `VectorSpaceService`.
525pub struct VectorSpaceServiceBuilder {
526    config: HnswConfig,
527}
528
529impl VectorSpaceServiceBuilder {
530    /// Create a new builder with default configuration.
531    pub fn new() -> Self {
532        Self {
533            config: HnswConfig::default(),
534        }
535    }
536
537    /// Set the dimensions.
538    pub fn dimensions(mut self, dim: usize) -> Self {
539        self.config.dimensions = dim;
540        self
541    }
542
543    /// Set the M parameter.
544    pub fn m(mut self, m: usize) -> Self {
545        self.config.m = m;
546        self
547    }
548
549    /// Set ef_construction.
550    pub fn ef_construction(mut self, ef: usize) -> Self {
551        self.config.ef_construction = ef;
552        self
553    }
554
555    /// Set ef_search.
556    pub fn ef_search(mut self, ef: usize) -> Self {
557        self.config.ef_search = ef;
558        self
559    }
560
561    /// Set max elements.
562    pub fn max_elements(mut self, max: usize) -> Self {
563        self.config.max_elements = max;
564        self
565    }
566
567    /// Enable or disable normalization.
568    pub fn normalize(mut self, normalize: bool) -> Self {
569        self.config.normalize = normalize;
570        self
571    }
572
573    /// Build the service.
574    pub fn build(self) -> Result<VectorSpaceService, VectorError> {
575        self.config.validate()?;
576        Ok(VectorSpaceService::new(self.config))
577    }
578}
579
580impl Default for VectorSpaceServiceBuilder {
581    fn default() -> Self {
582        Self::new()
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Service over 128-dimensional vectors, capacity 1000, normalization off.
    fn create_test_service() -> VectorSpaceService {
        VectorSpaceService::new(
            HnswConfig::for_dimension(128)
                .with_max_elements(1000)
                .with_normalize(false),
        )
    }

    /// 128-element ramp whose j-th component is `(offset + j) / scale`.
    fn ramp(offset: usize, scale: f32) -> Vec<f32> {
        (0..128).map(|j| (offset + j) as f32 / scale).collect()
    }

    /// Two inserted vectors are both found, with the exact match ranked first.
    #[tokio::test]
    async fn test_add_and_search() {
        let service = create_test_service();
        let (first, second) = (EmbeddingId::new(), EmbeddingId::new());
        let probe = ramp(0, 128.0);

        service.add_embedding(first, probe.clone()).await.unwrap();
        service.add_embedding(second, ramp(1, 128.0)).await.unwrap();

        assert_eq!(service.len(), 2);

        let neighbors = service.find_neighbors(&probe, 2).await.unwrap();
        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0].id, first);
    }

    /// A vector of the wrong length is rejected with a dimension mismatch.
    #[tokio::test]
    async fn test_dimension_mismatch() {
        let service = create_test_service();
        let too_short = vec![0.1_f32; 64];

        let outcome = service.add_embedding(EmbeddingId::new(), too_short).await;

        assert!(matches!(
            outcome,
            Err(VectorError::DimensionMismatch { .. })
        ));
    }

    /// A batch of ten valid vectors is inserted in full.
    #[tokio::test]
    async fn test_batch_insert() {
        let service = create_test_service();

        let mut batch = Vec::with_capacity(10);
        for i in 0..10 {
            batch.push((EmbeddingId::new(), ramp(i * 128, 1280.0)));
        }

        let added = service.add_embeddings_batch(batch).await.unwrap();
        assert_eq!(added, 10);
        assert_eq!(service.len(), 10);
    }

    /// Filtered search only returns ids accepted by the predicate.
    #[tokio::test]
    async fn test_search_with_filter() {
        let service = create_test_service();

        let ids: Vec<EmbeddingId> = std::iter::repeat_with(EmbeddingId::new).take(5).collect();

        for (i, id) in ids.iter().enumerate() {
            service
                .add_embedding(*id, ramp(i * 128, 640.0))
                .await
                .unwrap();
        }

        let query = ramp(0, 640.0);

        // Keep only the ids stored at odd positions (1 and 3).
        let odd_ids: std::collections::HashSet<_> =
            ids.iter().copied().skip(1).step_by(2).collect();

        let neighbors = service
            .find_neighbors_with_filter(&query, 10, |id| odd_ids.contains(id))
            .await
            .unwrap();

        assert!(neighbors.iter().all(|n| odd_ids.contains(&n.id)));
    }

    /// Similarity is derived from distance and compared against thresholds.
    #[test]
    fn test_neighbor() {
        let hit = Neighbor::new(EmbeddingId::new(), 0.2, 0);

        assert!((hit.similarity - 0.8).abs() < 0.001);
        assert!(hit.is_above_threshold(0.7));
        assert!(!hit.is_above_threshold(0.9));
    }

    /// Builder-style setters record the values they are given.
    #[test]
    fn test_search_options() {
        let configured = SearchOptions::new(10)
            .with_min_similarity(0.8)
            .with_max_distance(0.3);

        assert_eq!(configured.k, 10);
        assert_eq!(configured.min_similarity, Some(0.8));
        assert_eq!(configured.max_distance, Some(0.3));
    }

    /// The service builder carries the requested dimensions through `build`.
    #[test]
    fn test_builder() {
        let built = VectorSpaceServiceBuilder::new()
            .dimensions(256)
            .m(16)
            .ef_construction(100)
            .max_elements(5000)
            .build();

        let service = built.unwrap();
        assert_eq!(service.dimensions(), 256);
    }
}