Skip to main content

sql_rs/vector/
store.rs

1use super::{DistanceMetric, VectorCollection};
2use crate::{Result, SqlRsError, Storage};
3use std::collections::HashMap;
4use std::time::{Duration, SystemTime};
5
/// One slot of the in-memory collection cache: the cached collection plus the
/// bookkeeping the store keeps for it.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// In-memory copy of the collection.
    collection: VectorCollection,
    /// Set to `SystemTime::now()` when the entry is inserted and again each time it is
    /// served from the cache; eviction removes the entries with the oldest value first.
    last_accessed: SystemTime,
    /// Entries with this flag set are written back to storage before eviction runs.
    /// NOTE(review): every assignment visible in this file sets it to `false`.
    dirty: bool,
}
12
/// Persistent vector store that saves collections to disk
///
/// Collections are serialized with `bincode` and written to `storage` under the key
/// `__vector_collection__<name>`. Recently used collections are additionally kept in
/// an in-memory cache keyed by collection name.
pub struct VectorStore<S: Storage> {
    /// Backing storage that holds the serialized collections.
    pub storage: S,
    /// In-memory cache of collections, keyed by collection name.
    collections_cache: HashMap<String, CacheEntry>,
    /// Target upper bound on cached entries: once the cache reaches this size,
    /// `evict_if_needed` drops the least recently accessed entries.
    max_cache_entries: usize,
    /// Configured cache time-to-live.
    /// NOTE(review): no method in the visible `impl` block reads this field, so cached
    /// entries are never expired by age — confirm whether TTL expiry is still planned.
    cache_ttl: Duration,
}
20
21impl<S: Storage> VectorStore<S> {
22    pub fn new(storage: S) -> Self {
23        Self {
24            storage,
25            collections_cache: HashMap::new(),
26            max_cache_entries: 10,
27            cache_ttl: Duration::from_secs(300), // 5 minutes
28        }
29    }
30
31    /// Create a new VectorStore with custom cache settings
32    pub fn with_cache_config(storage: S, max_entries: usize, ttl_secs: u64) -> Self {
33        Self {
34            storage,
35            collections_cache: HashMap::new(),
36            max_cache_entries: max_entries,
37            cache_ttl: Duration::from_secs(ttl_secs),
38        }
39    }
40
41    /// Create a new vector collection
42    pub fn create_collection(
43        &mut self,
44        name: &str,
45        dimension: usize,
46        metric: DistanceMetric,
47    ) -> Result<()> {
48        let key = format!("__vector_collection__{}", name);
49
50        // Check if collection already exists
51        if self.storage.get(key.as_bytes())?.is_some() {
52            return Err(SqlRsError::Vector(format!(
53                "Collection '{}' already exists",
54                name
55            )));
56        }
57
58        let collection = VectorCollection::new(name, dimension, metric);
59        let serialized = bincode::serialize(&collection)
60            .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
61
62        self.storage.put(key.as_bytes(), &serialized)?;
63
64        // Add to cache
65        self.collections_cache.insert(
66            name.to_string(),
67            CacheEntry {
68                collection,
69                last_accessed: SystemTime::now(),
70                dirty: false,
71            },
72        );
73
74        Ok(())
75    }
76
77    /// Get a vector collection, loading from disk if not cached
78    pub fn get_collection(&mut self, name: &str) -> Result<VectorCollection> {
79        // Check cache first
80        if let Some(entry) = self.collections_cache.get_mut(name) {
81            entry.last_accessed = SystemTime::now();
82            return Ok(entry.collection.clone());
83        }
84
85        // Load from storage
86        let key = format!("__vector_collection__{}", name);
87        let bytes = self
88            .storage
89            .get(key.as_bytes())?
90            .ok_or_else(|| SqlRsError::Vector(format!("Collection '{}' not found", name)))?;
91
92        let collection: VectorCollection =
93            bincode::deserialize(&bytes).map_err(|e| SqlRsError::Serialization(e.to_string()))?;
94
95        // Check cache size and evict if necessary
96        self.evict_if_needed()?;
97
98        // Add to cache
99        self.collections_cache.insert(
100            name.to_string(),
101            CacheEntry {
102                collection: collection.clone(),
103                last_accessed: SystemTime::now(),
104                dirty: false,
105            },
106        );
107
108        Ok(collection)
109    }
110
111    /// Save a collection to disk and cache
112    pub fn save_collection(&mut self, collection: VectorCollection) -> Result<()> {
113        let name = collection.name();
114        let key = format!("__vector_collection__{}", name);
115
116        let serialized = bincode::serialize(&collection)
117            .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
118
119        self.storage.put(key.as_bytes(), &serialized)?;
120
121        // Update cache
122        if let Some(entry) = self.collections_cache.get_mut(name) {
123            entry.collection = collection;
124            entry.last_accessed = SystemTime::now();
125            entry.dirty = false;
126        } else {
127            // Add to cache if not present
128            self.evict_if_needed()?;
129            self.collections_cache.insert(
130                name.to_string(),
131                CacheEntry {
132                    collection,
133                    last_accessed: SystemTime::now(),
134                    dirty: false,
135                },
136            );
137        }
138
139        Ok(())
140    }
141
142    /// Evict old or dirty entries from cache
143    fn evict_if_needed(&mut self) -> Result<()> {
144        // First, save any dirty entries
145        let dirty_entries: Vec<_> = self
146            .collections_cache
147            .iter()
148            .filter(|(_, entry)| entry.dirty)
149            .map(|(name, _)| name.clone())
150            .collect();
151
152        for name in dirty_entries {
153            if let Some(entry) = self.collections_cache.remove(&name) {
154                let key = format!("__vector_collection__{}", name);
155                let serialized = bincode::serialize(&entry.collection)
156                    .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
157                self.storage.put(key.as_bytes(), &serialized)?;
158
159                // Re-add as clean
160                self.collections_cache.insert(
161                    name,
162                    CacheEntry {
163                        collection: entry.collection,
164                        last_accessed: SystemTime::now(),
165                        dirty: false,
166                    },
167                );
168            }
169        }
170
171        // If cache is still full, remove oldest entries
172        if self.collections_cache.len() >= self.max_cache_entries {
173            let mut entries_by_age: Vec<_> = self
174                .collections_cache
175                .iter()
176                .map(|(name, entry)| (name.clone(), entry.last_accessed))
177                .collect();
178
179            entries_by_age.sort_by_key(|(_, time)| *time);
180
181            let to_remove = entries_by_age.len() - self.max_cache_entries + 1;
182            for (name, _) in entries_by_age.into_iter().take(to_remove) {
183                self.collections_cache.remove(&name);
184            }
185        }
186
187        Ok(())
188    }
189
190    /// Flush all cached collections to disk
191    pub fn flush_all(&mut self) -> Result<()> {
192        let collections: Vec<_> = self.collections_cache.keys().cloned().collect();
193
194        for name in collections {
195            if let Some(entry) = self.collections_cache.remove(&name) {
196                let key = format!("__vector_collection__{}", name);
197                let serialized = bincode::serialize(&entry.collection)
198                    .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
199                self.storage.put(key.as_bytes(), &serialized)?;
200            }
201        }
202
203        Ok(())
204    }
205
206    /// Add a vector to a collection
207    pub fn add_vector(&mut self, collection_name: &str, vector: super::Vector) -> Result<()> {
208        let mut collection = self.get_collection(collection_name)?;
209        collection.add(vector)?;
210        self.save_collection(collection)?;
211        Ok(())
212    }
213
214    /// Add multiple vectors to a collection in batch
215    pub fn add_vectors_batch(
216        &mut self,
217        collection_name: &str,
218        vectors: Vec<super::Vector>,
219    ) -> Result<usize> {
220        let mut collection = self.get_collection(collection_name)?;
221        let added_count = collection.add_batch(vectors)?;
222        self.save_collection(collection)?;
223        Ok(added_count)
224    }
225
226    /// Search for similar vectors in a collection
227    pub fn search(
228        &mut self,
229        collection_name: &str,
230        query: &[f32],
231        k: usize,
232    ) -> Result<Vec<super::SearchResult>> {
233        let collection = self.get_collection(collection_name)?;
234        collection.search(query, k)
235    }
236
237    /// List all collection names
238    pub fn list_collections(&self) -> Result<Vec<String>> {
239        let prefix = b"__vector_collection__";
240        let scan_results = self.storage.scan(prefix, &[0xFF; 256])?;
241
242        let mut names = Vec::new();
243        for (key, _) in scan_results {
244            if let Ok(key_str) = std::str::from_utf8(&key) {
245                if let Some(name) = key_str.strip_prefix("__vector_collection__") {
246                    names.push(name.to_string());
247                }
248            }
249        }
250
251        Ok(names)
252    }
253
254    /// Delete a collection
255    pub fn delete_collection(&mut self, name: &str) -> Result<()> {
256        let key = format!("__vector_collection__{}", name);
257        self.storage.delete(key.as_bytes())?;
258        self.collections_cache.remove(name);
259        Ok(())
260    }
261
262    /// Get statistics about all vector collections
263    pub fn get_collection_stats(&mut self) -> Result<Vec<(String, usize, usize, DistanceMetric)>> {
264        let collection_names = self.list_collections()?;
265        let mut stats = Vec::new();
266
267        for name in collection_names {
268            let collection = self.get_collection(&name)?;
269            stats.push((
270                name,
271                collection.len(),
272                collection.dimension(),
273                collection.metric().clone(),
274            ));
275        }
276
277        Ok(stats)
278    }
279}