1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8pub use synaptic_core::{Embeddings, Item, Store, SynapticError};
10
/// Flattens a namespace path into the single string used as a map key.
///
/// Components are concatenated in order with `::` between them; an empty
/// path yields an empty string.
fn namespace_key(namespace: &[&str]) -> String {
    let mut joined = String::new();
    for (position, component) in namespace.iter().enumerate() {
        // The separator goes between components, never before the first one.
        if position > 0 {
            joined.push_str("::");
        }
        joined.push_str(component);
    }
    joined
}
14
/// Returns the current UTC wall-clock time as an RFC 3339 / ISO 8601 string
/// with nanosecond precision, e.g. `2024-05-17T09:30:00.123456789Z`.
///
/// The fixed-width layout means timestamps compare chronologically when
/// compared as plain strings. If the system clock reports a time before the
/// Unix epoch, the epoch itself is used.
fn now_iso() -> String {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_rfc3339_utc(since_epoch.as_secs(), since_epoch.subsec_nanos())
}

/// Formats a Unix timestamp (whole seconds plus sub-second nanoseconds) as an
/// RFC 3339 UTC string using only integer arithmetic.
fn format_rfc3339_utc(secs: u64, nanos: u32) -> String {
    // `secs / 86_400` is far below `i64::MAX`, so the cast cannot truncate.
    let days = (secs / 86_400) as i64;
    let second_of_day = secs % 86_400;

    // Days-to-civil conversion (proleptic Gregorian calendar). The epoch is
    // shifted to 0000-03-01 so that leap days fall at the end of each
    // "computational" year, which keeps the month arithmetic branch-free.
    let shifted = days + 719_468;
    let era = shifted.div_euclid(146_097);
    let day_of_era = shifted.rem_euclid(146_097);
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let month_index = (5 * day_of_year + 2) / 153; // 0 = March, 11 = February
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + i64::from(month <= 2);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:09}Z",
        year,
        month,
        day,
        second_of_day / 3_600,
        (second_of_day % 3_600) / 60,
        second_of_day % 60,
        nanos
    )
}
19
/// Computes the cosine similarity of two vectors, accumulating in `f64`.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// vector has zero magnitude (so the result is never a division by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass over both vectors gathers the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&lhs, &rhs)| {
            let (lhs, rhs) = (f64::from(lhs), f64::from(rhs));
            (dot + lhs * rhs, sq_a + lhs * lhs, sq_b + rhs * rhs)
        },
    );

    match sq_a.sqrt() * sq_b.sqrt() {
        magnitude if magnitude == 0.0 => 0.0,
        magnitude => dot / magnitude,
    }
}
42
43pub struct InMemoryStore {
47 data: Arc<RwLock<HashMap<String, HashMap<String, Item>>>>,
48 embeddings: Option<Arc<dyn Embeddings>>,
50 vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
52}
53
54impl Default for InMemoryStore {
55 fn default() -> Self {
56 Self {
57 data: Arc::new(RwLock::new(HashMap::new())),
58 embeddings: None,
59 vectors: Arc::new(RwLock::new(HashMap::new())),
60 }
61 }
62}
63
64impl InMemoryStore {
65 pub fn new() -> Self {
66 Self::default()
67 }
68
69 pub fn with_embeddings(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
75 self.embeddings = Some(embeddings);
76 self
77 }
78}
79
80#[async_trait]
81impl Store for InMemoryStore {
82 async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError> {
83 let data = self.data.read().await;
84 let ns_key = namespace_key(namespace);
85 Ok(data.get(&ns_key).and_then(|ns| ns.get(key).cloned()))
86 }
87
88 async fn search(
89 &self,
90 namespace: &[&str],
91 query: Option<&str>,
92 limit: usize,
93 ) -> Result<Vec<Item>, SynapticError> {
94 let data = self.data.read().await;
95 let ns_key = namespace_key(namespace);
96
97 let Some(ns) = data.get(&ns_key) else {
98 return Ok(vec![]);
99 };
100
101 if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
103 let query_vec = embeddings.embed_query(q).await?;
104 let vectors = self.vectors.read().await;
105
106 let mut scored: Vec<(Item, f64)> = ns
107 .iter()
108 .map(|(key, item)| {
109 let vec_key = format!("{}::{}", ns_key, key);
110 let score = vectors
111 .get(&vec_key)
112 .map(|v| cosine_similarity(v, &query_vec))
113 .unwrap_or(0.0);
114 let mut item = item.clone();
115 item.score = Some(score);
116 (item, score)
117 })
118 .collect();
119
120 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
122 scored.truncate(limit);
123
124 return Ok(scored.into_iter().map(|(item, _)| item).collect());
125 }
126
127 let items: Vec<Item> = ns
129 .values()
130 .filter(|item| {
131 if let Some(q) = query {
132 item.key.contains(q) || item.value.to_string().contains(q)
134 } else {
135 true
136 }
137 })
138 .take(limit)
139 .cloned()
140 .collect();
141
142 Ok(items)
143 }
144
145 async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError> {
146 let mut data = self.data.write().await;
147 let ns_key = namespace_key(namespace);
148 let ns = data.entry(ns_key.clone()).or_default();
149 let now = now_iso();
150
151 let item = if let Some(existing) = ns.get(key) {
152 Item {
153 namespace: namespace.iter().map(|s| s.to_string()).collect(),
154 key: key.to_string(),
155 value: value.clone(),
156 created_at: existing.created_at.clone(),
157 updated_at: now,
158 score: None,
159 }
160 } else {
161 Item {
162 namespace: namespace.iter().map(|s| s.to_string()).collect(),
163 key: key.to_string(),
164 value: value.clone(),
165 created_at: now.clone(),
166 updated_at: now,
167 score: None,
168 }
169 };
170
171 ns.insert(key.to_string(), item);
172
173 if let Some(ref embeddings) = self.embeddings {
175 let text = value.as_str().unwrap_or(&value.to_string()).to_string();
176 let vecs = embeddings.embed_documents(&[&text]).await?;
177 if let Some(vec) = vecs.into_iter().next() {
178 let vec_key = format!("{}::{}", ns_key, key);
179 self.vectors.write().await.insert(vec_key, vec);
180 }
181 }
182
183 Ok(())
184 }
185
186 async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError> {
187 let mut data = self.data.write().await;
188 let ns_key = namespace_key(namespace);
189 if let Some(ns) = data.get_mut(&ns_key) {
190 ns.remove(key);
191 }
192 let vec_key = format!("{}::{}", ns_key, key);
194 self.vectors.write().await.remove(&vec_key);
195 Ok(())
196 }
197
198 async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError> {
199 let data = self.data.read().await;
200 let prefix_str = if prefix.is_empty() {
201 String::new()
202 } else {
203 namespace_key(prefix)
204 };
205
206 let namespaces: Vec<Vec<String>> = data
207 .keys()
208 .filter(|k| prefix.is_empty() || k.starts_with(&prefix_str))
209 .map(|k| k.split("::").map(String::from).collect())
210 .collect();
211
212 Ok(namespaces)
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deterministic stand-in for a real embeddings provider.
    struct TestEmbeddings;

    /// Folds the bytes of `text` into a 4-dimensional vector by adding each
    /// byte value to the slot selected by its position modulo 4.
    fn text_to_vec(text: &str) -> Vec<f32> {
        text.bytes()
            .enumerate()
            .fold(vec![0.0f32; 4], |mut slots, (position, byte)| {
                slots[position % 4] += f32::from(byte);
                slots
            })
    }

    #[async_trait]
    impl Embeddings for TestEmbeddings {
        async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
            let mut vectors = Vec::with_capacity(texts.len());
            for text in texts {
                vectors.push(text_to_vec(text));
            }
            Ok(vectors)
        }

        async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
            Ok(text_to_vec(text))
        }
    }

    /// Store wired up with the deterministic test embedder.
    fn embedding_store() -> InMemoryStore {
        InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings))
    }

    #[tokio::test]
    async fn put_and_get() {
        let ns = ["users", "prefs"];
        let store = InMemoryStore::new();
        store.put(&ns, "theme", json!("dark")).await.unwrap();

        let fetched = store.get(&ns, "theme").await.unwrap().unwrap();
        assert_eq!(fetched.key, "theme");
        assert_eq!(fetched.value, json!("dark"));
        assert_eq!(fetched.namespace, ns.to_vec());
    }

    #[tokio::test]
    async fn get_nonexistent() {
        let store = InMemoryStore::new();
        let lookup = store.get(&["a"], "missing").await.unwrap();
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn delete_item() {
        let store = InMemoryStore::new();
        store.put(&["ns"], "k", json!(1)).await.unwrap();
        store.delete(&["ns"], "k").await.unwrap();

        let after_delete = store.get(&["ns"], "k").await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn search_items() {
        let store = InMemoryStore::new();
        for (key, fruit) in [("a", "apple"), ("b", "banana"), ("c", "cherry")] {
            store.put(&["ns"], key, json!(fruit)).await.unwrap();
        }

        let everything = store.search(&["ns"], None, 10).await.unwrap();
        assert_eq!(everything.len(), 3);

        let apples = store.search(&["ns"], Some("apple"), 10).await.unwrap();
        assert_eq!(apples.len(), 1);
    }

    #[tokio::test]
    async fn list_namespaces_with_prefix() {
        let store = InMemoryStore::new();
        store.put(&["a", "b"], "k1", json!(1)).await.unwrap();
        store.put(&["a", "c"], "k2", json!(2)).await.unwrap();
        store.put(&["x", "y"], "k3", json!(3)).await.unwrap();

        let unfiltered = store.list_namespaces(&[]).await.unwrap();
        assert_eq!(unfiltered.len(), 3);

        let under_a = store.list_namespaces(&["a"]).await.unwrap();
        assert_eq!(under_a.len(), 2);
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let store = InMemoryStore::new();

        store.put(&["ns"], "k", json!(1)).await.unwrap();
        let original = store.get(&["ns"], "k").await.unwrap().unwrap();

        store.put(&["ns"], "k", json!(2)).await.unwrap();
        let replaced = store.get(&["ns"], "k").await.unwrap().unwrap();

        assert_eq!(original.created_at, replaced.created_at);
        assert_eq!(replaced.value, json!(2));
    }

    #[tokio::test]
    async fn semantic_search_ranked_by_similarity() {
        let store = embedding_store();
        let documents = [
            ("a", "rust programming"),
            ("b", "python programming"),
            ("c", "cooking recipes"),
        ];
        for (key, text) in documents {
            store.put(&["docs"], key, json!(text)).await.unwrap();
        }

        let results = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(results.len(), 3);

        // Every hit from the semantic path carries a score.
        assert!(results.iter().all(|item| item.score.is_some()));

        // Scores must be in non-increasing order.
        let scores: Vec<f64> = results.iter().map(|item| item.score.unwrap()).collect();
        for pair in scores.windows(2) {
            assert!(pair[0] >= pair[1], "scores not sorted: {:?}", scores);
        }
    }

    #[tokio::test]
    async fn semantic_search_respects_limit() {
        let store = embedding_store();
        for i in 0..5 {
            let key = format!("k{}", i);
            let body = format!("item {}", i);
            store.put(&["ns"], &key, json!(body)).await.unwrap();
        }

        let results = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn delete_cleans_up_embeddings() {
        let store = embedding_store();

        store.put(&["ns"], "k", json!("hello")).await.unwrap();
        let populated = !store.vectors.read().await.is_empty();
        assert!(populated);

        store.delete(&["ns"], "k").await.unwrap();
        let emptied = store.vectors.read().await.is_empty();
        assert!(emptied);
    }
}