// synaptic_store/lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
// Re-export the `Store` and `Embeddings` traits plus `Item` and `SynapticError` from core
9pub use synaptic_core::{Embeddings, Item, Store, SynapticError};
10
/// Builds the flat map key for a namespace path.
///
/// The path components are concatenated in order with `::` between
/// neighbouring components; an empty path yields an empty string.
fn namespace_key(namespace: &[&str]) -> String {
    const SEPARATOR: &str = "::";

    let mut key = String::new();
    for (index, part) in namespace.iter().enumerate() {
        // The separator goes between components, never before the first one.
        if index > 0 {
            key.push_str(SEPARATOR);
        }
        key.push_str(part);
    }
    key
}
14
/// Returns the current wall-clock time as an RFC 3339 / ISO 8601 UTC
/// timestamp with nanosecond precision, e.g. `2024-05-17T09:41:03.123456789Z`.
///
/// Implemented with the standard library only, so no `chrono` dependency is
/// needed. The fixed-width, zero-padded layout means timestamps (for years
/// 0000-9999) sort lexicographically in chronological order.
fn now_iso() -> String {
    // Split "now" into whole seconds relative to the Unix epoch (possibly
    // negative if the system clock is set before 1970) plus a non-negative
    // nanosecond fraction.
    let (secs, nanos) = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => (
            i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX),
            elapsed.subsec_nanos(),
        ),
        Err(err) => {
            let before = err.duration();
            let whole = -i64::try_from(before.as_secs()).unwrap_or(i64::MAX);
            match before.subsec_nanos() {
                0 => (whole, 0),
                // Borrow one second so the fractional part stays positive.
                frac => (whole - 1, 1_000_000_000 - frac),
            }
        }
    };

    let days = secs.div_euclid(86_400);
    let second_of_day = secs.rem_euclid(86_400);
    let (year, month, day) = civil_from_days(days);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:09}Z",
        year,
        month,
        day,
        second_of_day / 3_600,
        (second_of_day / 60) % 60,
        second_of_day % 60,
        nanos
    )
}

/// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple (month in 1..=12, day in 1..=31).
fn civil_from_days(days: i64) -> (i64, i64, i64) {
    // Shift the origin to 0000-03-01 so that leap days fall at the end of the
    // (March-based) year, which makes the arithmetic below branch-free.
    let shifted = days + 719_468;
    let era = shifted.div_euclid(146_097); // 400-year cycles
    let day_of_era = shifted.rem_euclid(146_097); // [0, 146096]
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // [0, 399]
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // [0, 365]
    let march_month = (5 * day_of_year + 2) / 153; // [0, 11], 0 = March
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + if month <= 2 { 1 } else { 0 };
    (year, month, day)
}
19
/// Cosine similarity between two vectors, computed in `f64`.
///
/// Returns a value in `[-1.0, 1.0]`. The similarity is undefined in several
/// situations, and all of them are reported as `0.0`:
///
/// * the slices differ in length or are empty,
/// * either vector has zero magnitude,
/// * a component is NaN or infinite (the quotient would not be finite).
///
/// Never returning NaN keeps scores safe to sort and to expose to callers.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_finite() {
        // Rounding can push the quotient marginally outside the valid range.
        similarity.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
42
43/// Thread-safe in-memory implementation of `Store`.
44///
45/// Supports optional embedding-based semantic search via [`with_embeddings`](InMemoryStore::with_embeddings).
46pub struct InMemoryStore {
47    data: Arc<RwLock<HashMap<String, HashMap<String, Item>>>>,
48    /// Optional embeddings model for semantic search.
49    embeddings: Option<Arc<dyn Embeddings>>,
50    /// Pre-computed embedding vectors, keyed by `namespace_key::item_key`.
51    vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
52}
53
54impl Default for InMemoryStore {
55    fn default() -> Self {
56        Self {
57            data: Arc::new(RwLock::new(HashMap::new())),
58            embeddings: None,
59            vectors: Arc::new(RwLock::new(HashMap::new())),
60        }
61    }
62}
63
64impl InMemoryStore {
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    /// Enable embedding-based semantic search.
70    ///
71    /// When configured, [`Store::search()`] with a query will use embedding
72    /// similarity instead of substring matching. Items are ranked by cosine
73    /// similarity and `Item::score` is populated.
74    pub fn with_embeddings(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
75        self.embeddings = Some(embeddings);
76        self
77    }
78}
79
80#[async_trait]
81impl Store for InMemoryStore {
82    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError> {
83        let data = self.data.read().await;
84        let ns_key = namespace_key(namespace);
85        Ok(data.get(&ns_key).and_then(|ns| ns.get(key).cloned()))
86    }
87
88    async fn search(
89        &self,
90        namespace: &[&str],
91        query: Option<&str>,
92        limit: usize,
93    ) -> Result<Vec<Item>, SynapticError> {
94        let data = self.data.read().await;
95        let ns_key = namespace_key(namespace);
96
97        let Some(ns) = data.get(&ns_key) else {
98            return Ok(vec![]);
99        };
100
101        // If embeddings are configured and a query is provided, use semantic search
102        if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
103            let query_vec = embeddings.embed_query(q).await?;
104            let vectors = self.vectors.read().await;
105
106            let mut scored: Vec<(Item, f64)> = ns
107                .iter()
108                .map(|(key, item)| {
109                    let vec_key = format!("{}::{}", ns_key, key);
110                    let score = vectors
111                        .get(&vec_key)
112                        .map(|v| cosine_similarity(v, &query_vec))
113                        .unwrap_or(0.0);
114                    let mut item = item.clone();
115                    item.score = Some(score);
116                    (item, score)
117                })
118                .collect();
119
120            // Sort by score descending
121            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
122            scored.truncate(limit);
123
124            return Ok(scored.into_iter().map(|(item, _)| item).collect());
125        }
126
127        // Fallback: substring search
128        let items: Vec<Item> = ns
129            .values()
130            .filter(|item| {
131                if let Some(q) = query {
132                    // Simple substring search in key and value
133                    item.key.contains(q) || item.value.to_string().contains(q)
134                } else {
135                    true
136                }
137            })
138            .take(limit)
139            .cloned()
140            .collect();
141
142        Ok(items)
143    }
144
145    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError> {
146        let mut data = self.data.write().await;
147        let ns_key = namespace_key(namespace);
148        let ns = data.entry(ns_key.clone()).or_default();
149        let now = now_iso();
150
151        let item = if let Some(existing) = ns.get(key) {
152            Item {
153                namespace: namespace.iter().map(|s| s.to_string()).collect(),
154                key: key.to_string(),
155                value: value.clone(),
156                created_at: existing.created_at.clone(),
157                updated_at: now,
158                score: None,
159            }
160        } else {
161            Item {
162                namespace: namespace.iter().map(|s| s.to_string()).collect(),
163                key: key.to_string(),
164                value: value.clone(),
165                created_at: now.clone(),
166                updated_at: now,
167                score: None,
168            }
169        };
170
171        ns.insert(key.to_string(), item);
172
173        // If embeddings are configured, compute and store the embedding
174        if let Some(ref embeddings) = self.embeddings {
175            let text = value.as_str().unwrap_or(&value.to_string()).to_string();
176            let vecs = embeddings.embed_documents(&[&text]).await?;
177            if let Some(vec) = vecs.into_iter().next() {
178                let vec_key = format!("{}::{}", ns_key, key);
179                self.vectors.write().await.insert(vec_key, vec);
180            }
181        }
182
183        Ok(())
184    }
185
186    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError> {
187        let mut data = self.data.write().await;
188        let ns_key = namespace_key(namespace);
189        if let Some(ns) = data.get_mut(&ns_key) {
190            ns.remove(key);
191        }
192        // Clean up embedding vector
193        let vec_key = format!("{}::{}", ns_key, key);
194        self.vectors.write().await.remove(&vec_key);
195        Ok(())
196    }
197
198    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError> {
199        let data = self.data.read().await;
200        let prefix_str = if prefix.is_empty() {
201            String::new()
202        } else {
203            namespace_key(prefix)
204        };
205
206        let namespaces: Vec<Vec<String>> = data
207            .keys()
208            .filter(|k| prefix.is_empty() || k.starts_with(&prefix_str))
209            .map(|k| k.split("::").map(String::from).collect())
210            .collect();
211
212        Ok(namespaces)
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deterministic stand-in for an embeddings model, used by the
    /// semantic-search tests below.
    struct TestEmbeddings;

    #[async_trait]
    impl Embeddings for TestEmbeddings {
        async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
            let mut vectors = Vec::with_capacity(texts.len());
            for text in texts {
                vectors.push(text_to_vec(text));
            }
            Ok(vectors)
        }

        async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
            Ok(text_to_vec(text))
        }
    }

    /// Folds the bytes of `text` into a 4-dimensional vector: byte `i` is
    /// added to dimension `i % 4`.
    fn text_to_vec(text: &str) -> Vec<f32> {
        let mut dims = vec![0.0f32; 4];
        for (index, byte) in text.bytes().enumerate() {
            dims[index % 4] += byte as f32;
        }
        dims
    }

    /// A store wired up with the deterministic test embeddings.
    fn semantic_store() -> InMemoryStore {
        InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings))
    }

    // A value written with `put` is returned unchanged by `get`.
    #[tokio::test]
    async fn put_and_get() {
        let store = InMemoryStore::new();
        let namespace = ["users", "prefs"];
        store.put(&namespace, "theme", json!("dark")).await.unwrap();

        let fetched = store.get(&namespace, "theme").await.unwrap().unwrap();
        assert_eq!(fetched.key, "theme");
        assert_eq!(fetched.value, json!("dark"));
        assert_eq!(fetched.namespace, vec!["users", "prefs"]);
    }

    // Looking up a key that was never stored yields `None`, not an error.
    #[tokio::test]
    async fn get_nonexistent() {
        let store = InMemoryStore::new();
        let lookup = store.get(&["a"], "missing").await.unwrap();
        assert!(lookup.is_none());
    }

    // After `delete`, the item can no longer be fetched.
    #[tokio::test]
    async fn delete_item() {
        let store = InMemoryStore::new();
        store.put(&["ns"], "k", json!(1)).await.unwrap();
        store.delete(&["ns"], "k").await.unwrap();

        let lookup = store.get(&["ns"], "k").await.unwrap();
        assert!(lookup.is_none());
    }

    // Without embeddings, `search` lists everything or filters by substring.
    #[tokio::test]
    async fn search_items() {
        let store = InMemoryStore::new();
        for (key, fruit) in [("a", "apple"), ("b", "banana"), ("c", "cherry")] {
            store.put(&["ns"], key, json!(fruit)).await.unwrap();
        }

        let everything = store.search(&["ns"], None, 10).await.unwrap();
        assert_eq!(everything.len(), 3);

        let apples = store.search(&["ns"], Some("apple"), 10).await.unwrap();
        assert_eq!(apples.len(), 1);
    }

    // An empty prefix lists all namespaces; a prefix narrows the listing.
    #[tokio::test]
    async fn list_namespaces_with_prefix() {
        let store = InMemoryStore::new();
        store.put(&["a", "b"], "k1", json!(1)).await.unwrap();
        store.put(&["a", "c"], "k2", json!(2)).await.unwrap();
        store.put(&["x", "y"], "k3", json!(3)).await.unwrap();

        let unfiltered = store.list_namespaces(&[]).await.unwrap();
        assert_eq!(unfiltered.len(), 3);

        let under_a = store.list_namespaces(&["a"]).await.unwrap();
        assert_eq!(under_a.len(), 2);
    }

    // Overwriting an item keeps its creation timestamp but updates the value.
    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let store = InMemoryStore::new();

        store.put(&["ns"], "k", json!(1)).await.unwrap();
        let original = store.get(&["ns"], "k").await.unwrap().unwrap();

        store.put(&["ns"], "k", json!(2)).await.unwrap();
        let replaced = store.get(&["ns"], "k").await.unwrap().unwrap();

        assert_eq!(original.created_at, replaced.created_at);
        assert_eq!(replaced.value, json!(2));
    }

    // With embeddings, every hit carries a score and hits are ordered by it.
    #[tokio::test]
    async fn semantic_search_ranked_by_similarity() {
        let store = semantic_store();

        let documents = [
            ("a", "rust programming"),
            ("b", "python programming"),
            ("c", "cooking recipes"),
        ];
        for (key, text) in documents {
            store.put(&["docs"], key, json!(text)).await.unwrap();
        }

        let hits = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(hits.len(), 3);

        // All should have scores populated
        assert!(hits.iter().all(|hit| hit.score.is_some()));

        // Scores should be descending
        let scores: Vec<f64> = hits.iter().map(|hit| hit.score.unwrap()).collect();
        for pair in scores.windows(2) {
            assert!(pair[0] >= pair[1], "scores not sorted: {:?}", scores);
        }
    }

    // `limit` caps the number of semantic hits.
    #[tokio::test]
    async fn semantic_search_respects_limit() {
        let store = semantic_store();

        for i in 0..5 {
            let key = format!("k{}", i);
            let body = json!(format!("item {}", i));
            store.put(&["ns"], &key, body).await.unwrap();
        }

        let hits = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    // Deleting an item also removes its stored embedding vector.
    #[tokio::test]
    async fn delete_cleans_up_embeddings() {
        let store = semantic_store();

        store.put(&["ns"], "k", json!("hello")).await.unwrap();
        assert!(!store.vectors.read().await.is_empty());

        store.delete(&["ns"], "k").await.unwrap();
        assert!(store.vectors.read().await.is_empty());
    }
}
377}