Skip to main content

smelt_memory/storage/
vectors.rs

1//! Vector storage for episode embeddings
2
3use crate::error::{MemoryError, MemoryResult};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8use uuid::Uuid;
9
10/// In-memory vector store with persistence
11pub struct VectorStore {
12    /// Path for persistence (None for in-memory only)
13    path: Option<PathBuf>,
14    /// Vectors indexed by episode ID
15    vectors: HashMap<Uuid, Vec<f32>>,
16    /// Embedding dimension
17    dimension: usize,
18}
19
20/// Serializable format for persistence
21#[derive(Serialize, Deserialize)]
22struct VectorStoreData {
23    dimension: usize,
24    vectors: Vec<(String, Vec<f32>)>,
25}
26
27impl VectorStore {
28    /// Create a new vector store
29    pub fn new(dimension: usize) -> Self {
30        Self {
31            path: None,
32            vectors: HashMap::new(),
33            dimension,
34        }
35    }
36
37    /// Open or create a persistent vector store
38    pub fn open(path: &Path, dimension: usize) -> MemoryResult<Self> {
39        let mut store = Self {
40            path: Some(path.to_path_buf()),
41            vectors: HashMap::new(),
42            dimension,
43        };
44
45        // Load existing data if file exists
46        if path.exists() {
47            let data = fs::read_to_string(path)?;
48            let stored: VectorStoreData = serde_json::from_str(&data)?;
49
50            if stored.dimension != dimension {
51                return Err(MemoryError::InvalidConfig(format!(
52                    "Dimension mismatch: stored={}, expected={}",
53                    stored.dimension, dimension
54                )));
55            }
56
57            for (id_str, vec) in stored.vectors {
58                if let Ok(id) = Uuid::parse_str(&id_str) {
59                    store.vectors.insert(id, vec);
60                }
61            }
62        }
63
64        Ok(store)
65    }
66
67    /// Store a vector for an episode
68    pub fn store(&mut self, episode_id: Uuid, vector: Vec<f32>) -> MemoryResult<()> {
69        if vector.len() != self.dimension {
70            return Err(MemoryError::InvalidConfig(format!(
71                "Vector dimension mismatch: got={}, expected={}",
72                vector.len(),
73                self.dimension
74            )));
75        }
76
77        self.vectors.insert(episode_id, vector);
78        self.persist()?;
79        Ok(())
80    }
81
82    /// Get a vector by episode ID
83    pub fn get(&self, episode_id: Uuid) -> Option<&Vec<f32>> {
84        self.vectors.get(&episode_id)
85    }
86
87    /// Remove a vector
88    pub fn remove(&mut self, episode_id: Uuid) -> MemoryResult<()> {
89        self.vectors.remove(&episode_id);
90        self.persist()?;
91        Ok(())
92    }
93
94    /// Search for similar vectors using cosine similarity
95    pub fn search(&self, query: &[f32], limit: usize) -> Vec<(Uuid, f64)> {
96        if query.len() != self.dimension {
97            return Vec::new();
98        }
99
100        let mut results: Vec<(Uuid, f64)> = self
101            .vectors
102            .iter()
103            .map(|(id, vec)| (*id, cosine_similarity(query, vec)))
104            .collect();
105
106        // Sort by similarity (descending)
107        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
108
109        results.truncate(limit);
110        results
111    }
112
113    /// Get the number of stored vectors
114    pub fn len(&self) -> usize {
115        self.vectors.len()
116    }
117
118    /// Check if empty
119    pub fn is_empty(&self) -> bool {
120        self.vectors.is_empty()
121    }
122
123    /// Get all episode IDs
124    pub fn episode_ids(&self) -> Vec<Uuid> {
125        self.vectors.keys().copied().collect()
126    }
127
128    /// Persist to disk
129    fn persist(&self) -> MemoryResult<()> {
130        if let Some(ref path) = self.path {
131            let data = VectorStoreData {
132                dimension: self.dimension,
133                vectors: self
134                    .vectors
135                    .iter()
136                    .map(|(id, vec)| (id.to_string(), vec.clone()))
137                    .collect(),
138            };
139
140            let json = serde_json::to_string_pretty(&data)?;
141
142            // Ensure parent directory exists
143            if let Some(parent) = path.parent() {
144                fs::create_dir_all(parent)?;
145            }
146
147            fs::write(path, json)?;
148        }
149        Ok(())
150    }
151}
152
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// has zero magnitude. The result is NaN if any component is NaN or infinite.
///
/// All arithmetic is done in `f64`: the square of a finite `f32` always fits
/// in an `f64`, so very large components cannot overflow to infinity and very
/// small ones cannot underflow to a zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return 0.0;
    }

    // Take the roots separately rather than rooting the product, which keeps
    // the intermediate values small.
    dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A stored vector can be read back unchanged by its episode ID.
    #[test]
    fn test_store_and_get() {
        let episode = Uuid::new_v4();
        let embedding = vec![1.0, 2.0, 3.0];

        let mut store = VectorStore::new(3);
        store.store(episode, embedding.clone()).unwrap();

        assert_eq!(store.get(episode), Some(&embedding));
    }

    /// Search honours the limit and ranks the exact match first.
    #[test]
    fn test_search() {
        let exact = Uuid::new_v4();
        let entries = vec![
            (exact, vec![1.0, 0.0, 0.0]),
            (Uuid::new_v4(), vec![0.9, 0.1, 0.0]),
            (Uuid::new_v4(), vec![0.0, 1.0, 0.0]),
        ];

        let mut store = VectorStore::new(3);
        for (episode, embedding) in entries {
            store.store(episode, embedding).unwrap();
        }

        let hits = store.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(hits.len(), 2);

        // First result should be exact match
        let (best_id, best_score) = hits[0];
        assert_eq!(best_id, exact);
        assert!((best_score - 1.0).abs() < 0.001);
    }

    /// A vector written through one store instance is visible after reopening
    /// the same file.
    #[test]
    fn test_persistence() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("vectors.json");

        let episode = Uuid::new_v4();
        let embedding = vec![1.0, 2.0, 3.0];

        // Store, then let the writer go before reopening.
        let mut writer = VectorStore::open(&path, 3).unwrap();
        writer.store(episode, embedding.clone()).unwrap();
        drop(writer);

        // Load
        let reader = VectorStore::open(&path, 3).unwrap();
        assert_eq!(reader.get(episode), Some(&embedding));
    }

    /// Parallel, orthogonal and opposite vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0, 0.0];

        let parallel = cosine_similarity(&unit_x, &[1.0_f32, 0.0, 0.0]);
        assert!((parallel - 1.0).abs() < 0.001);

        let orthogonal = cosine_similarity(&unit_x, &[0.0_f32, 1.0, 0.0]);
        assert!(orthogonal.abs() < 0.001);

        let opposite = cosine_similarity(&unit_x, &[-1.0_f32, 0.0, 0.0]);
        assert!((opposite + 1.0).abs() < 0.001);
    }

    /// Storing a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut store = VectorStore::new(3);

        // Two components where three are required.
        let outcome = store.store(Uuid::new_v4(), vec![1.0, 2.0]);
        assert!(outcome.is_err());
    }
}