Skip to main content

synaptic_store/
lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8// Re-export Store trait and Item from core
9pub use synaptic_core::{now_iso, Embeddings, Item, Store, SynapticError};
10
/// Flatten a namespace path into the single string used as a map key.
///
/// Segments are joined with `::`; an empty path yields an empty string.
fn namespace_key(namespace: &[&str]) -> String {
    const SEGMENT_SEPARATOR: &str = "::";
    namespace.join(SEGMENT_SEPARATOR)
}
14
/// Cosine of the angle between `a` and `b`, accumulated in `f64`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero magnitude (the ratio is undefined there).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass collecting the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        dot / magnitude
    }
}
37
38/// Thread-safe in-memory implementation of `Store`.
39///
40/// Supports optional embedding-based semantic search via [`with_embeddings`](InMemoryStore::with_embeddings).
41pub struct InMemoryStore {
42    data: Arc<RwLock<HashMap<String, HashMap<String, Item>>>>,
43    /// Optional embeddings model for semantic search.
44    embeddings: Option<Arc<dyn Embeddings>>,
45    /// Pre-computed embedding vectors, keyed by `namespace_key::item_key`.
46    vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
47}
48
49impl Default for InMemoryStore {
50    fn default() -> Self {
51        Self {
52            data: Arc::new(RwLock::new(HashMap::new())),
53            embeddings: None,
54            vectors: Arc::new(RwLock::new(HashMap::new())),
55        }
56    }
57}
58
59impl InMemoryStore {
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    /// Enable embedding-based semantic search.
65    ///
66    /// When configured, [`Store::search()`] with a query will use embedding
67    /// similarity instead of substring matching. Items are ranked by cosine
68    /// similarity and `Item::score` is populated.
69    pub fn with_embeddings(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
70        self.embeddings = Some(embeddings);
71        self
72    }
73}
74
75#[async_trait]
76impl Store for InMemoryStore {
77    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError> {
78        let data = self.data.read().await;
79        let ns_key = namespace_key(namespace);
80        Ok(data.get(&ns_key).and_then(|ns| ns.get(key).cloned()))
81    }
82
83    async fn search(
84        &self,
85        namespace: &[&str],
86        query: Option<&str>,
87        limit: usize,
88    ) -> Result<Vec<Item>, SynapticError> {
89        let data = self.data.read().await;
90        let ns_key = namespace_key(namespace);
91
92        let Some(ns) = data.get(&ns_key) else {
93            return Ok(vec![]);
94        };
95
96        // If embeddings are configured and a query is provided, use semantic search
97        if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
98            let query_vec = embeddings.embed_query(q).await?;
99            let vectors = self.vectors.read().await;
100
101            let mut scored: Vec<(Item, f64)> = ns
102                .iter()
103                .map(|(key, item)| {
104                    let vec_key = format!("{}::{}", ns_key, key);
105                    let score = vectors
106                        .get(&vec_key)
107                        .map(|v| cosine_similarity(v, &query_vec))
108                        .unwrap_or(0.0);
109                    let mut item = item.clone();
110                    item.score = Some(score);
111                    (item, score)
112                })
113                .collect();
114
115            // Sort by score descending
116            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
117            scored.truncate(limit);
118
119            return Ok(scored.into_iter().map(|(item, _)| item).collect());
120        }
121
122        // Fallback: substring search
123        let items: Vec<Item> = ns
124            .values()
125            .filter(|item| {
126                if let Some(q) = query {
127                    // Simple substring search in key and value
128                    item.key.contains(q) || item.value.to_string().contains(q)
129                } else {
130                    true
131                }
132            })
133            .take(limit)
134            .cloned()
135            .collect();
136
137        Ok(items)
138    }
139
140    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError> {
141        let mut data = self.data.write().await;
142        let ns_key = namespace_key(namespace);
143        let ns = data.entry(ns_key.clone()).or_default();
144        let now = now_iso();
145
146        let item = if let Some(existing) = ns.get(key) {
147            Item {
148                namespace: namespace.iter().map(|s| s.to_string()).collect(),
149                key: key.to_string(),
150                value: value.clone(),
151                created_at: existing.created_at.clone(),
152                updated_at: now,
153                score: None,
154            }
155        } else {
156            Item {
157                namespace: namespace.iter().map(|s| s.to_string()).collect(),
158                key: key.to_string(),
159                value: value.clone(),
160                created_at: now.clone(),
161                updated_at: now,
162                score: None,
163            }
164        };
165
166        ns.insert(key.to_string(), item);
167
168        // If embeddings are configured, compute and store the embedding
169        if let Some(ref embeddings) = self.embeddings {
170            let text = value.as_str().unwrap_or(&value.to_string()).to_string();
171            let vecs = embeddings.embed_documents(&[&text]).await?;
172            if let Some(vec) = vecs.into_iter().next() {
173                let vec_key = format!("{}::{}", ns_key, key);
174                self.vectors.write().await.insert(vec_key, vec);
175            }
176        }
177
178        Ok(())
179    }
180
181    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError> {
182        let mut data = self.data.write().await;
183        let ns_key = namespace_key(namespace);
184        if let Some(ns) = data.get_mut(&ns_key) {
185            ns.remove(key);
186        }
187        // Clean up embedding vector
188        let vec_key = format!("{}::{}", ns_key, key);
189        self.vectors.write().await.remove(&vec_key);
190        Ok(())
191    }
192
193    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError> {
194        let data = self.data.read().await;
195        let prefix_str = if prefix.is_empty() {
196            String::new()
197        } else {
198            namespace_key(prefix)
199        };
200
201        let namespaces: Vec<Vec<String>> = data
202            .keys()
203            .filter(|k| prefix.is_empty() || k.starts_with(&prefix_str))
204            .map(|k| k.split("::").map(String::from).collect())
205            .collect();
206
207        Ok(namespaces)
208    }
209}
210
211#[cfg(feature = "filesystem")]
212mod file_store;
213#[cfg(feature = "filesystem")]
214pub use file_store::FileStore;
215
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deterministic stand-in for a real embeddings model.
    struct TestEmbeddings;

    #[async_trait]
    impl Embeddings for TestEmbeddings {
        async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
            Ok(texts.iter().copied().map(text_to_vec).collect())
        }

        async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
            Ok(text_to_vec(text))
        }
    }

    /// Fold the bytes of `text` into a 4-dimensional vector: byte `i` is
    /// added to component `i % 4`.
    fn text_to_vec(text: &str) -> Vec<f32> {
        const DIMS: usize = 4;
        let mut components = vec![0.0f32; DIMS];
        for (pos, byte) in text.bytes().enumerate() {
            components[pos % DIMS] += f32::from(byte);
        }
        components
    }

    /// Store wired up with [`TestEmbeddings`].
    fn semantic_store() -> InMemoryStore {
        InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings))
    }

    #[tokio::test]
    async fn put_and_get() {
        let ns = ["users", "prefs"];
        let store = InMemoryStore::new();
        store.put(&ns, "theme", json!("dark")).await.unwrap();

        let fetched = store.get(&ns, "theme").await.unwrap().unwrap();
        assert_eq!(fetched.key, "theme");
        assert_eq!(fetched.value, json!("dark"));
        assert_eq!(fetched.namespace, vec!["users", "prefs"]);
    }

    #[tokio::test]
    async fn get_nonexistent() {
        let store = InMemoryStore::new();
        let lookup = store.get(&["a"], "missing").await.unwrap();
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn delete_item() {
        let store = InMemoryStore::new();
        store.put(&["ns"], "k", json!(1)).await.unwrap();
        store.delete(&["ns"], "k").await.unwrap();

        let after = store.get(&["ns"], "k").await.unwrap();
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn search_items() {
        let store = InMemoryStore::new();
        for (key, fruit) in [("a", "apple"), ("b", "banana"), ("c", "cherry")] {
            store.put(&["ns"], key, json!(fruit)).await.unwrap();
        }

        // No query: everything in the namespace comes back.
        let unfiltered = store.search(&["ns"], None, 10).await.unwrap();
        assert_eq!(unfiltered.len(), 3);

        // Substring query: only the matching item.
        let matching = store.search(&["ns"], Some("apple"), 10).await.unwrap();
        assert_eq!(matching.len(), 1);
    }

    #[tokio::test]
    async fn list_namespaces_with_prefix() {
        let store = InMemoryStore::new();
        store.put(&["a", "b"], "k1", json!(1)).await.unwrap();
        store.put(&["a", "c"], "k2", json!(2)).await.unwrap();
        store.put(&["x", "y"], "k3", json!(3)).await.unwrap();

        let everything = store.list_namespaces(&[]).await.unwrap();
        assert_eq!(everything.len(), 3);

        let under_a = store.list_namespaces(&["a"]).await.unwrap();
        assert_eq!(under_a.len(), 2);
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let store = InMemoryStore::new();

        store.put(&["ns"], "k", json!(1)).await.unwrap();
        let original = store.get(&["ns"], "k").await.unwrap().unwrap();

        store.put(&["ns"], "k", json!(2)).await.unwrap();
        let overwritten = store.get(&["ns"], "k").await.unwrap().unwrap();

        assert_eq!(original.created_at, overwritten.created_at);
        assert_eq!(overwritten.value, json!(2));
    }

    #[tokio::test]
    async fn semantic_search_ranked_by_similarity() {
        let store = semantic_store();

        let documents = [
            ("a", "rust programming"),
            ("b", "python programming"),
            ("c", "cooking recipes"),
        ];
        for (key, text) in documents {
            store.put(&["docs"], key, json!(text)).await.unwrap();
        }

        // With embeddings on, every item in the namespace is returned, scored.
        let results = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(results.len(), 3);

        // Every hit carries a score...
        assert!(results.iter().all(|item| item.score.is_some()));

        // ...and the hits are ordered best-first.
        let scores: Vec<f64> = results.iter().map(|item| item.score.unwrap()).collect();
        assert!(
            scores.windows(2).all(|pair| pair[0] >= pair[1]),
            "scores not sorted: {:?}",
            scores
        );
    }

    #[tokio::test]
    async fn semantic_search_respects_limit() {
        let store = semantic_store();

        for idx in 0..5 {
            let key = format!("k{}", idx);
            let text = format!("item {}", idx);
            store.put(&["ns"], &key, json!(text)).await.unwrap();
        }

        let results = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn delete_cleans_up_embeddings() {
        let store = semantic_store();

        store.put(&["ns"], "k", json!("hello")).await.unwrap();
        let has_vector = !store.vectors.read().await.is_empty();
        assert!(has_vector);

        store.delete(&["ns"], "k").await.unwrap();
        let vectors_empty = store.vectors.read().await.is_empty();
        assert!(vectors_empty);
    }
}
377}