Skip to main content

sql_rs/vector/
store.rs

1use super::{VectorCollection, DistanceMetric};
2use crate::{Result, SqlRsError, Storage};
3use std::collections::HashMap;
4
5/// Persistent vector store that saves collections to disk
6pub struct VectorStore<S: Storage> {
7    storage: S,
8    collections_cache: HashMap<String, VectorCollection>,
9}
10
11impl<S: Storage> VectorStore<S> {
12    pub fn new(storage: S) -> Self {
13        Self {
14            storage,
15            collections_cache: HashMap::new(),
16        }
17    }
18
19    /// Create a new vector collection
20    pub fn create_collection(&mut self, name: &str, dimension: usize, metric: DistanceMetric) -> Result<()> {
21        let key = format!("__vector_collection__{}", name);
22
23        // Check if collection already exists
24        if self.storage.get(key.as_bytes())?.is_some() {
25            return Err(SqlRsError::Vector(format!("Collection '{}' already exists", name)));
26        }
27
28        let collection = VectorCollection::new(name, dimension, metric);
29        let serialized = bincode::serialize(&collection)
30            .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
31
32        self.storage.put(key.as_bytes(), &serialized)?;
33        self.collections_cache.insert(name.to_string(), collection);
34
35        Ok(())
36    }
37
38    /// Get a vector collection, loading from disk if not cached
39    pub fn get_collection(&mut self, name: &str) -> Result<VectorCollection> {
40        // Check cache first
41        if let Some(collection) = self.collections_cache.get(name) {
42            return Ok(collection.clone());
43        }
44
45        // Load from storage
46        let key = format!("__vector_collection__{}", name);
47        let bytes = self.storage.get(key.as_bytes())?
48            .ok_or_else(|| SqlRsError::Vector(format!("Collection '{}' not found", name)))?;
49
50        let collection: VectorCollection = bincode::deserialize(&bytes)
51            .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
52
53        self.collections_cache.insert(name.to_string(), collection.clone());
54        Ok(collection)
55    }
56
57    /// Save a collection to disk and cache
58    pub fn save_collection(&mut self, collection: VectorCollection) -> Result<()> {
59        let name = collection.name();
60        let key = format!("__vector_collection__{}", name);
61
62        let serialized = bincode::serialize(&collection)
63            .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
64
65        self.storage.put(key.as_bytes(), &serialized)?;
66        self.collections_cache.insert(name.to_string(), collection);
67
68        Ok(())
69    }
70
71    /// Add a vector to a collection
72    pub fn add_vector(&mut self, collection_name: &str, vector: super::Vector) -> Result<()> {
73        let mut collection = self.get_collection(collection_name)?;
74        collection.add(vector)?;
75        self.save_collection(collection)?;
76        Ok(())
77    }
78
79    /// Search for similar vectors in a collection
80    pub fn search(&mut self, collection_name: &str, query: &[f32], k: usize) -> Result<Vec<super::SearchResult>> {
81        let collection = self.get_collection(collection_name)?;
82        collection.search(query, k)
83    }
84
85    /// List all collection names
86    pub fn list_collections(&self) -> Result<Vec<String>> {
87        let prefix = b"__vector_collection__";
88        let scan_results = self.storage.scan(prefix, &[0xFF; 256])?;
89
90        let mut names = Vec::new();
91        for (key, _) in scan_results {
92            if let Ok(key_str) = std::str::from_utf8(&key) {
93                if let Some(name) = key_str.strip_prefix("__vector_collection__") {
94                    names.push(name.to_string());
95                }
96            }
97        }
98
99        Ok(names)
100    }
101
102    /// Delete a collection
103    pub fn delete_collection(&mut self, name: &str) -> Result<()> {
104        let key = format!("__vector_collection__{}", name);
105        self.storage.delete(key.as_bytes())?;
106        self.collections_cache.remove(name);
107        Ok(())
108    }
109}