Skip to main content

sql_rs/vector/
collection.rs

1use super::{DistanceMetric, HnswIndex, SearchResult, Vector};
2use crate::{Result, SqlRsError};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct VectorCollection {
8    name: String,
9    dimension: usize,
10    metric: DistanceMetric,
11    index: HnswIndex,
12    vectors: HashMap<String, Vector>,
13    id_mapping: HashMap<usize, String>,
14}
15
16impl VectorCollection {
17    pub fn new(name: impl Into<String>, dimension: usize, metric: DistanceMetric) -> Self {
18        Self {
19            name: name.into(),
20            dimension,
21            metric,
22            index: HnswIndex::new(metric),
23            vectors: HashMap::new(),
24            id_mapping: HashMap::new(),
25        }
26    }
27
28    pub fn add(&mut self, vector: Vector) -> Result<()> {
29        if vector.dimension() != self.dimension {
30            return Err(SqlRsError::Vector(format!(
31                "Expected dimension {}, got {}",
32                self.dimension,
33                vector.dimension()
34            )));
35        }
36
37        let internal_id = self.index.add(vector.embedding.clone());
38        self.id_mapping.insert(internal_id, vector.id.clone());
39        self.vectors.insert(vector.id.clone(), vector);
40
41        Ok(())
42    }
43
44    /// Add multiple vectors in batch for better performance
45    pub fn add_batch(&mut self, vectors: Vec<Vector>) -> Result<usize> {
46        let mut added_count = 0;
47
48        for vector in vectors {
49            if vector.dimension() != self.dimension {
50                return Err(SqlRsError::Vector(format!(
51                    "Expected dimension {}, got {}",
52                    self.dimension,
53                    vector.dimension()
54                )));
55            }
56
57            let internal_id = self.index.add(vector.embedding.clone());
58            self.id_mapping.insert(internal_id, vector.id.clone());
59            self.vectors.insert(vector.id.clone(), vector);
60            added_count += 1;
61        }
62
63        Ok(added_count)
64    }
65
66    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
67        if query.len() != self.dimension {
68            return Err(SqlRsError::Vector(format!(
69                "Expected dimension {}, got {}",
70                self.dimension,
71                query.len()
72            )));
73        }
74
75        let results = self.index.search(query, k);
76
77        let mut search_results = Vec::new();
78        for (internal_id, distance) in results {
79            if let Some(vector_id) = self.id_mapping.get(&internal_id) {
80                if let Some(vector) = self.vectors.get(vector_id) {
81                    search_results.push(SearchResult {
82                        id: vector.id.clone(),
83                        distance,
84                        metadata: vector.metadata.clone(),
85                    });
86                }
87            }
88        }
89
90        Ok(search_results)
91    }
92
93    /// Search with metadata filtering for hybrid search
94    pub fn search_filtered(
95        &self,
96        query: &[f32],
97        k: usize,
98        filter: impl Fn(&HashMap<String, String>) -> bool,
99    ) -> Result<Vec<SearchResult>> {
100        if query.len() != self.dimension {
101            return Err(SqlRsError::Vector(format!(
102                "Expected dimension {}, got {}",
103                self.dimension,
104                query.len()
105            )));
106        }
107
108        // Get more candidates initially since we'll filter some out
109        let candidate_k = k * 4; // Get 4x more candidates
110        let results = self.index.search(query, candidate_k);
111
112        let mut search_results = Vec::new();
113        for (internal_id, distance) in results {
114            if let Some(vector_id) = self.id_mapping.get(&internal_id) {
115                if let Some(vector) = self.vectors.get(vector_id) {
116                    // Apply filter
117                    if filter(&vector.metadata) {
118                        search_results.push(SearchResult {
119                            id: vector.id.clone(),
120                            distance,
121                            metadata: vector.metadata.clone(),
122                        });
123
124                        // Stop if we have enough results
125                        if search_results.len() >= k {
126                            break;
127                        }
128                    }
129                }
130            }
131        }
132
133        Ok(search_results)
134    }
135
136    /// Search with exact metadata match
137    pub fn search_by_metadata(
138        &self,
139        query: &[f32],
140        k: usize,
141        key: &str,
142        value: &str,
143    ) -> Result<Vec<SearchResult>> {
144        self.search_filtered(query, k, |metadata| {
145            metadata.get(key).map_or(false, |v| v == value)
146        })
147    }
148
149    /// Search with multiple metadata filters (all must match)
150    pub fn search_by_metadata_multiple(
151        &self,
152        query: &[f32],
153        k: usize,
154        filters: &[(&str, &str)],
155    ) -> Result<Vec<SearchResult>> {
156        self.search_filtered(query, k, |metadata| {
157            filters
158                .iter()
159                .all(|(key, value)| metadata.get(*key).map_or(false, |v| v == *value))
160        })
161    }
162
163    pub fn get(&self, id: &str) -> Option<&Vector> {
164        self.vectors.get(id)
165    }
166
167    pub fn len(&self) -> usize {
168        self.vectors.len()
169    }
170
171    pub fn is_empty(&self) -> bool {
172        self.vectors.is_empty()
173    }
174
175    pub fn name(&self) -> &str {
176        &self.name
177    }
178
179    pub fn dimension(&self) -> usize {
180        self.dimension
181    }
182
183    pub fn metric(&self) -> DistanceMetric {
184        self.metric
185    }
186}