smelt_memory/storage/
vectors.rs1use crate::error::{MemoryError, MemoryResult};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8use uuid::Uuid;
9
10pub struct VectorStore {
12 path: Option<PathBuf>,
14 vectors: HashMap<Uuid, Vec<f32>>,
16 dimension: usize,
18}
19
/// Serialized form of a `VectorStore`, written and read as JSON.
///
/// Episode ids are kept as strings here and converted to/from `Uuid` by the
/// code that loads and saves the store.
#[derive(Serialize, Deserialize)]
struct VectorStoreData {
    /// Dimension the vectors were stored with; checked against the caller's
    /// expected dimension when a store is opened.
    dimension: usize,
    /// `(episode id as string, vector)` pairs.
    vectors: Vec<(String, Vec<f32>)>,
}
26
27impl VectorStore {
28 pub fn new(dimension: usize) -> Self {
30 Self {
31 path: None,
32 vectors: HashMap::new(),
33 dimension,
34 }
35 }
36
37 pub fn open(path: &Path, dimension: usize) -> MemoryResult<Self> {
39 let mut store = Self {
40 path: Some(path.to_path_buf()),
41 vectors: HashMap::new(),
42 dimension,
43 };
44
45 if path.exists() {
47 let data = fs::read_to_string(path)?;
48 let stored: VectorStoreData = serde_json::from_str(&data)?;
49
50 if stored.dimension != dimension {
51 return Err(MemoryError::InvalidConfig(format!(
52 "Dimension mismatch: stored={}, expected={}",
53 stored.dimension, dimension
54 )));
55 }
56
57 for (id_str, vec) in stored.vectors {
58 if let Ok(id) = Uuid::parse_str(&id_str) {
59 store.vectors.insert(id, vec);
60 }
61 }
62 }
63
64 Ok(store)
65 }
66
67 pub fn store(&mut self, episode_id: Uuid, vector: Vec<f32>) -> MemoryResult<()> {
69 if vector.len() != self.dimension {
70 return Err(MemoryError::InvalidConfig(format!(
71 "Vector dimension mismatch: got={}, expected={}",
72 vector.len(),
73 self.dimension
74 )));
75 }
76
77 self.vectors.insert(episode_id, vector);
78 self.persist()?;
79 Ok(())
80 }
81
82 pub fn get(&self, episode_id: Uuid) -> Option<&Vec<f32>> {
84 self.vectors.get(&episode_id)
85 }
86
87 pub fn remove(&mut self, episode_id: Uuid) -> MemoryResult<()> {
89 self.vectors.remove(&episode_id);
90 self.persist()?;
91 Ok(())
92 }
93
94 pub fn search(&self, query: &[f32], limit: usize) -> Vec<(Uuid, f64)> {
96 if query.len() != self.dimension {
97 return Vec::new();
98 }
99
100 let mut results: Vec<(Uuid, f64)> = self
101 .vectors
102 .iter()
103 .map(|(id, vec)| (*id, cosine_similarity(query, vec)))
104 .collect();
105
106 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
108
109 results.truncate(limit);
110 results
111 }
112
113 pub fn len(&self) -> usize {
115 self.vectors.len()
116 }
117
118 pub fn is_empty(&self) -> bool {
120 self.vectors.is_empty()
121 }
122
123 pub fn episode_ids(&self) -> Vec<Uuid> {
125 self.vectors.keys().copied().collect()
126 }
127
128 fn persist(&self) -> MemoryResult<()> {
130 if let Some(ref path) = self.path {
131 let data = VectorStoreData {
132 dimension: self.dimension,
133 vectors: self
134 .vectors
135 .iter()
136 .map(|(id, vec)| (id.to_string(), vec.clone()))
137 .collect(),
138 };
139
140 let json = serde_json::to_string_pretty(&data)?;
141
142 if let Some(parent) = path.parent() {
144 fs::create_dir_all(parent)?;
145 }
146
147 fs::write(path, json)?;
148 }
149 Ok(())
150 }
151}
152
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]` (or is NaN
/// if an input component is NaN).
///
/// All arithmetic is carried out in `f64`: squaring large `f32` components
/// overflows to infinity and squaring tiny ones underflows to zero, either of
/// which would corrupt the result if the sums were accumulated in `f32`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());

    // Rounding can push the quotient marginally outside the valid range.
    // Written as comparisons so a NaN falls through unchanged.
    if cosine > 1.0 {
        1.0
    } else if cosine < -1.0 {
        -1.0
    } else {
        cosine
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Maximum deviation tolerated when comparing similarity scores.
    const TOLERANCE: f64 = 0.001;

    /// A vector stored in an in-memory store can be read back unchanged.
    #[test]
    fn test_store_and_get() {
        let episode = Uuid::new_v4();
        let embedding = vec![1.0, 2.0, 3.0];

        let mut store = VectorStore::new(3);
        store.store(episode, embedding.clone()).unwrap();

        assert_eq!(store.get(episode).unwrap(), &embedding);
    }

    /// Search honours the limit and ranks the exact match first.
    #[test]
    fn test_search() {
        let mut store = VectorStore::new(3);

        let ids = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
        let embeddings = [
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
        ];
        for (id, embedding) in ids.iter().zip(embeddings.iter()) {
            store.store(*id, embedding.clone()).unwrap();
        }

        let query: [f32; 3] = [1.0, 0.0, 0.0];
        let hits = store.search(&query, 2);

        assert_eq!(hits.len(), 2);
        let (best_id, best_score) = hits[0];
        assert_eq!(best_id, ids[0]);
        assert!((best_score - 1.0).abs() < TOLERANCE);
    }

    /// A vector written through one store instance is visible to a second
    /// instance opened on the same file.
    #[test]
    fn test_persistence() {
        let tmp = tempdir().unwrap();
        let file = tmp.path().join("vectors.json");

        let episode = Uuid::new_v4();
        let embedding = vec![1.0, 2.0, 3.0];

        // Write through a first instance, dropped at the end of the scope.
        {
            let mut writer = VectorStore::open(&file, 3).unwrap();
            writer.store(episode, embedding.clone()).unwrap();
        }

        // Re-open from disk and read the vector back.
        {
            let reader = VectorStore::open(&file, 3).unwrap();
            assert_eq!(reader.get(episode).unwrap(), &embedding);
        }
    }

    /// Identical, orthogonal and opposite vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];
        let neg_x_axis = [-1.0_f32, 0.0, 0.0];

        assert!((cosine_similarity(&x_axis, &x_axis) - 1.0).abs() < TOLERANCE);
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < TOLERANCE);
        assert!((cosine_similarity(&x_axis, &neg_x_axis) + 1.0).abs() < TOLERANCE);
    }

    /// Storing a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut store = VectorStore::new(3);

        let outcome = store.store(Uuid::new_v4(), vec![1.0, 2.0]);

        assert!(outcome.is_err());
    }
}