1use super::{DistanceMetric, VectorCollection};
2use crate::{Result, SqlRsError, Storage};
3use std::collections::HashMap;
4use std::time::{Duration, SystemTime};
5
6#[derive(Debug, Clone)]
7struct CacheEntry {
8 collection: VectorCollection,
9 last_accessed: SystemTime,
10 dirty: bool,
11}
12
13pub struct VectorStore<S: Storage> {
15 pub storage: S,
16 collections_cache: HashMap<String, CacheEntry>,
17 max_cache_entries: usize,
18 cache_ttl: Duration,
19}
20
21impl<S: Storage> VectorStore<S> {
22 pub fn new(storage: S) -> Self {
23 Self {
24 storage,
25 collections_cache: HashMap::new(),
26 max_cache_entries: 10,
27 cache_ttl: Duration::from_secs(300), }
29 }
30
31 pub fn with_cache_config(storage: S, max_entries: usize, ttl_secs: u64) -> Self {
33 Self {
34 storage,
35 collections_cache: HashMap::new(),
36 max_cache_entries: max_entries,
37 cache_ttl: Duration::from_secs(ttl_secs),
38 }
39 }
40
41 pub fn create_collection(
43 &mut self,
44 name: &str,
45 dimension: usize,
46 metric: DistanceMetric,
47 ) -> Result<()> {
48 let key = format!("__vector_collection__{}", name);
49
50 if self.storage.get(key.as_bytes())?.is_some() {
52 return Err(SqlRsError::Vector(format!(
53 "Collection '{}' already exists",
54 name
55 )));
56 }
57
58 let collection = VectorCollection::new(name, dimension, metric);
59 let serialized = bincode::serialize(&collection)
60 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
61
62 self.storage.put(key.as_bytes(), &serialized)?;
63
64 self.collections_cache.insert(
66 name.to_string(),
67 CacheEntry {
68 collection,
69 last_accessed: SystemTime::now(),
70 dirty: false,
71 },
72 );
73
74 Ok(())
75 }
76
77 pub fn get_collection(&mut self, name: &str) -> Result<VectorCollection> {
79 if let Some(entry) = self.collections_cache.get_mut(name) {
81 entry.last_accessed = SystemTime::now();
82 return Ok(entry.collection.clone());
83 }
84
85 let key = format!("__vector_collection__{}", name);
87 let bytes = self
88 .storage
89 .get(key.as_bytes())?
90 .ok_or_else(|| SqlRsError::Vector(format!("Collection '{}' not found", name)))?;
91
92 let collection: VectorCollection =
93 bincode::deserialize(&bytes).map_err(|e| SqlRsError::Serialization(e.to_string()))?;
94
95 self.evict_if_needed()?;
97
98 self.collections_cache.insert(
100 name.to_string(),
101 CacheEntry {
102 collection: collection.clone(),
103 last_accessed: SystemTime::now(),
104 dirty: false,
105 },
106 );
107
108 Ok(collection)
109 }
110
111 pub fn save_collection(&mut self, collection: VectorCollection) -> Result<()> {
113 let name = collection.name();
114 let key = format!("__vector_collection__{}", name);
115
116 let serialized = bincode::serialize(&collection)
117 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
118
119 self.storage.put(key.as_bytes(), &serialized)?;
120
121 if let Some(entry) = self.collections_cache.get_mut(name) {
123 entry.collection = collection;
124 entry.last_accessed = SystemTime::now();
125 entry.dirty = false;
126 } else {
127 self.evict_if_needed()?;
129 self.collections_cache.insert(
130 name.to_string(),
131 CacheEntry {
132 collection,
133 last_accessed: SystemTime::now(),
134 dirty: false,
135 },
136 );
137 }
138
139 Ok(())
140 }
141
142 fn evict_if_needed(&mut self) -> Result<()> {
144 let dirty_entries: Vec<_> = self
146 .collections_cache
147 .iter()
148 .filter(|(_, entry)| entry.dirty)
149 .map(|(name, _)| name.clone())
150 .collect();
151
152 for name in dirty_entries {
153 if let Some(entry) = self.collections_cache.remove(&name) {
154 let key = format!("__vector_collection__{}", name);
155 let serialized = bincode::serialize(&entry.collection)
156 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
157 self.storage.put(key.as_bytes(), &serialized)?;
158
159 self.collections_cache.insert(
161 name,
162 CacheEntry {
163 collection: entry.collection,
164 last_accessed: SystemTime::now(),
165 dirty: false,
166 },
167 );
168 }
169 }
170
171 if self.collections_cache.len() >= self.max_cache_entries {
173 let mut entries_by_age: Vec<_> = self
174 .collections_cache
175 .iter()
176 .map(|(name, entry)| (name.clone(), entry.last_accessed))
177 .collect();
178
179 entries_by_age.sort_by_key(|(_, time)| *time);
180
181 let to_remove = entries_by_age.len() - self.max_cache_entries + 1;
182 for (name, _) in entries_by_age.into_iter().take(to_remove) {
183 self.collections_cache.remove(&name);
184 }
185 }
186
187 Ok(())
188 }
189
190 pub fn flush_all(&mut self) -> Result<()> {
192 let collections: Vec<_> = self.collections_cache.keys().cloned().collect();
193
194 for name in collections {
195 if let Some(entry) = self.collections_cache.remove(&name) {
196 let key = format!("__vector_collection__{}", name);
197 let serialized = bincode::serialize(&entry.collection)
198 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
199 self.storage.put(key.as_bytes(), &serialized)?;
200 }
201 }
202
203 Ok(())
204 }
205
206 pub fn add_vector(&mut self, collection_name: &str, vector: super::Vector) -> Result<()> {
208 let mut collection = self.get_collection(collection_name)?;
209 collection.add(vector)?;
210 self.save_collection(collection)?;
211 Ok(())
212 }
213
214 pub fn add_vectors_batch(
216 &mut self,
217 collection_name: &str,
218 vectors: Vec<super::Vector>,
219 ) -> Result<usize> {
220 let mut collection = self.get_collection(collection_name)?;
221 let added_count = collection.add_batch(vectors)?;
222 self.save_collection(collection)?;
223 Ok(added_count)
224 }
225
226 pub fn search(
228 &mut self,
229 collection_name: &str,
230 query: &[f32],
231 k: usize,
232 ) -> Result<Vec<super::SearchResult>> {
233 let collection = self.get_collection(collection_name)?;
234 collection.search(query, k)
235 }
236
237 pub fn list_collections(&self) -> Result<Vec<String>> {
239 let prefix = b"__vector_collection__";
240 let scan_results = self.storage.scan(prefix, &[0xFF; 256])?;
241
242 let mut names = Vec::new();
243 for (key, _) in scan_results {
244 if let Ok(key_str) = std::str::from_utf8(&key) {
245 if let Some(name) = key_str.strip_prefix("__vector_collection__") {
246 names.push(name.to_string());
247 }
248 }
249 }
250
251 Ok(names)
252 }
253
254 pub fn delete_collection(&mut self, name: &str) -> Result<()> {
256 let key = format!("__vector_collection__{}", name);
257 self.storage.delete(key.as_bytes())?;
258 self.collections_cache.remove(name);
259 Ok(())
260 }
261
262 pub fn get_collection_stats(&mut self) -> Result<Vec<(String, usize, usize, DistanceMetric)>> {
264 let collection_names = self.list_collections()?;
265 let mut stats = Vec::new();
266
267 for name in collection_names {
268 let collection = self.get_collection(&name)?;
269 stats.push((
270 name,
271 collection.len(),
272 collection.dimension(),
273 collection.metric().clone(),
274 ));
275 }
276
277 Ok(stats)
278 }
279}