sql_rs/vector/
collection.rs1use super::{DistanceMetric, HnswIndex, SearchResult, Vector};
2use crate::{Result, SqlRsError};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct VectorCollection {
8 name: String,
9 dimension: usize,
10 metric: DistanceMetric,
11 index: HnswIndex,
12 vectors: HashMap<String, Vector>,
13 id_mapping: HashMap<usize, String>,
14}
15
16impl VectorCollection {
17 pub fn new(name: impl Into<String>, dimension: usize, metric: DistanceMetric) -> Self {
18 Self {
19 name: name.into(),
20 dimension,
21 metric,
22 index: HnswIndex::new(metric),
23 vectors: HashMap::new(),
24 id_mapping: HashMap::new(),
25 }
26 }
27
28 pub fn add(&mut self, vector: Vector) -> Result<()> {
29 if vector.dimension() != self.dimension {
30 return Err(SqlRsError::Vector(format!(
31 "Expected dimension {}, got {}",
32 self.dimension,
33 vector.dimension()
34 )));
35 }
36
37 let internal_id = self.index.add(vector.embedding.clone());
38 self.id_mapping.insert(internal_id, vector.id.clone());
39 self.vectors.insert(vector.id.clone(), vector);
40
41 Ok(())
42 }
43
44 pub fn add_batch(&mut self, vectors: Vec<Vector>) -> Result<usize> {
46 let mut added_count = 0;
47
48 for vector in vectors {
49 if vector.dimension() != self.dimension {
50 return Err(SqlRsError::Vector(format!(
51 "Expected dimension {}, got {}",
52 self.dimension,
53 vector.dimension()
54 )));
55 }
56
57 let internal_id = self.index.add(vector.embedding.clone());
58 self.id_mapping.insert(internal_id, vector.id.clone());
59 self.vectors.insert(vector.id.clone(), vector);
60 added_count += 1;
61 }
62
63 Ok(added_count)
64 }
65
66 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
67 if query.len() != self.dimension {
68 return Err(SqlRsError::Vector(format!(
69 "Expected dimension {}, got {}",
70 self.dimension,
71 query.len()
72 )));
73 }
74
75 let results = self.index.search(query, k);
76
77 let mut search_results = Vec::new();
78 for (internal_id, distance) in results {
79 if let Some(vector_id) = self.id_mapping.get(&internal_id) {
80 if let Some(vector) = self.vectors.get(vector_id) {
81 search_results.push(SearchResult {
82 id: vector.id.clone(),
83 distance,
84 metadata: vector.metadata.clone(),
85 });
86 }
87 }
88 }
89
90 Ok(search_results)
91 }
92
93 pub fn search_filtered(
95 &self,
96 query: &[f32],
97 k: usize,
98 filter: impl Fn(&HashMap<String, String>) -> bool,
99 ) -> Result<Vec<SearchResult>> {
100 if query.len() != self.dimension {
101 return Err(SqlRsError::Vector(format!(
102 "Expected dimension {}, got {}",
103 self.dimension,
104 query.len()
105 )));
106 }
107
108 let candidate_k = k * 4; let results = self.index.search(query, candidate_k);
111
112 let mut search_results = Vec::new();
113 for (internal_id, distance) in results {
114 if let Some(vector_id) = self.id_mapping.get(&internal_id) {
115 if let Some(vector) = self.vectors.get(vector_id) {
116 if filter(&vector.metadata) {
118 search_results.push(SearchResult {
119 id: vector.id.clone(),
120 distance,
121 metadata: vector.metadata.clone(),
122 });
123
124 if search_results.len() >= k {
126 break;
127 }
128 }
129 }
130 }
131 }
132
133 Ok(search_results)
134 }
135
136 pub fn search_by_metadata(
138 &self,
139 query: &[f32],
140 k: usize,
141 key: &str,
142 value: &str,
143 ) -> Result<Vec<SearchResult>> {
144 self.search_filtered(query, k, |metadata| {
145 metadata.get(key).map_or(false, |v| v == value)
146 })
147 }
148
149 pub fn search_by_metadata_multiple(
151 &self,
152 query: &[f32],
153 k: usize,
154 filters: &[(&str, &str)],
155 ) -> Result<Vec<SearchResult>> {
156 self.search_filtered(query, k, |metadata| {
157 filters
158 .iter()
159 .all(|(key, value)| metadata.get(*key).map_or(false, |v| v == *value))
160 })
161 }
162
163 pub fn get(&self, id: &str) -> Option<&Vector> {
164 self.vectors.get(id)
165 }
166
167 pub fn len(&self) -> usize {
168 self.vectors.len()
169 }
170
171 pub fn is_empty(&self) -> bool {
172 self.vectors.is_empty()
173 }
174
175 pub fn name(&self) -> &str {
176 &self.name
177 }
178
179 pub fn dimension(&self) -> usize {
180 self.dimension
181 }
182
183 pub fn metric(&self) -> DistanceMetric {
184 self.metric
185 }
186}