Skip to main content

sql_rs/vector/
collection.rs

1use super::{DistanceMetric, HnswIndex, SearchResult, Vector};
2use crate::{Result, SqlRsError};
3use std::collections::HashMap;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct VectorCollection {
8    name: String,
9    dimension: usize,
10    metric: DistanceMetric,
11    index: HnswIndex,
12    vectors: HashMap<String, Vector>,
13    id_mapping: HashMap<usize, String>,
14}
15
16impl VectorCollection {
17    pub fn new(name: impl Into<String>, dimension: usize, metric: DistanceMetric) -> Self {
18        Self {
19            name: name.into(),
20            dimension,
21            metric,
22            index: HnswIndex::new(metric),
23            vectors: HashMap::new(),
24            id_mapping: HashMap::new(),
25        }
26    }
27    
28    pub fn add(&mut self, vector: Vector) -> Result<()> {
29        if vector.dimension() != self.dimension {
30            return Err(SqlRsError::Vector(format!(
31                "Expected dimension {}, got {}",
32                self.dimension,
33                vector.dimension()
34            )));
35        }
36        
37        let internal_id = self.index.add(vector.embedding.clone());
38        self.id_mapping.insert(internal_id, vector.id.clone());
39        self.vectors.insert(vector.id.clone(), vector);
40        
41        Ok(())
42    }
43    
44    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
45        if query.len() != self.dimension {
46            return Err(SqlRsError::Vector(format!(
47                "Expected dimension {}, got {}",
48                self.dimension,
49                query.len()
50            )));
51        }
52        
53        let results = self.index.search(query, k);
54        
55        let mut search_results = Vec::new();
56        for (internal_id, distance) in results {
57            if let Some(vector_id) = self.id_mapping.get(&internal_id) {
58                if let Some(vector) = self.vectors.get(vector_id) {
59                    search_results.push(SearchResult {
60                        id: vector.id.clone(),
61                        distance,
62                        metadata: vector.metadata.clone(),
63                    });
64                }
65            }
66        }
67        
68        Ok(search_results)
69    }
70    
71    pub fn get(&self, id: &str) -> Option<&Vector> {
72        self.vectors.get(id)
73    }
74    
75    pub fn len(&self) -> usize {
76        self.vectors.len()
77    }
78    
79    pub fn is_empty(&self) -> bool {
80        self.vectors.is_empty()
81    }
82    
83    pub fn name(&self) -> &str {
84        &self.name
85    }
86    
87    pub fn dimension(&self) -> usize {
88        self.dimension
89    }
90    
91    pub fn metric(&self) -> DistanceMetric {
92        self.metric
93    }
94}