Skip to main content

synaptic_store/
lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8// Re-export Store trait and Item from core
9pub use synaptic_core::{now_iso, Embeddings, Item, Store, SynapticError};
10
/// Flattens a namespace path into the single string used as a map key.
///
/// Segments are concatenated in order with `::` between them; an empty path
/// yields an empty string.
fn namespace_key(namespace: &[&str]) -> String {
    let mut key = String::new();
    for (position, segment) in namespace.iter().enumerate() {
        if position > 0 {
            key.push_str("::");
        }
        key.push_str(segment);
    }
    key
}
14
/// Scores `document` against `query` with a simplified BM25 formula.
///
/// Both strings are split on whitespace and terms are compared
/// case-insensitively (via `to_lowercase`). Each query term found in the
/// document contributes
/// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))`
/// with `k1 = 1.2` and `b = 0.75`. A term that is repeated in the query
/// contributes once per repetition.
///
/// The IDF factor depends on `total_docs` only: the formula plugs in a
/// document frequency of 1 for every term, so it is identical for all terms.
///
/// * `avg_doc_len` - average document length in terms; values below `1.0`
///   are treated as `1.0`.
/// * `total_docs` - number of documents in the corpus; `0` is treated as `1`
///   so the IDF factor (and therefore the score) can never become negative.
///
/// Returns `0.0` when no query term occurs in the document.
fn bm25_score(query: &str, document: &str, avg_doc_len: f64, total_docs: usize) -> f64 {
    let k1 = 1.2;
    let b = 0.75;

    // Lowercase the document once rather than once per query term.
    let doc_terms: Vec<String> = document
        .split_whitespace()
        .map(str::to_lowercase)
        .collect();
    let doc_len = doc_terms.len() as f64;

    // Both factors are independent of the query term, so compute them once.
    // Clamping the corpus size keeps the logarithm's argument >= 1.
    let corpus_size = total_docs.max(1) as f64;
    let idf = ((corpus_size - 1.0 + 0.5) / (1.0 + 0.5) + 1.0).ln();
    let length_norm = k1 * (1.0 - b + b * doc_len / avg_doc_len.max(1.0));

    let mut score = 0.0;
    for query_term in query.split_whitespace() {
        let wanted = query_term.to_lowercase();
        let tf = doc_terms.iter().filter(|term| **term == wanted).count() as f64;
        if tf == 0.0 {
            continue;
        }
        let numerator = tf * (k1 + 1.0);
        let denominator = tf + length_norm;
        score += idf * numerator / denominator;
    }
    score
}
40
/// Cosine similarity between two vectors, accumulated in `f64`.
///
/// Returns `0.0` whenever the similarity is undefined:
/// * the slices differ in length or are empty,
/// * either vector has zero magnitude,
/// * the result is not finite (a component was NaN or infinite).
///
/// Mapping the non-finite case to `0.0` guarantees the returned score can
/// always be ordered against other scores.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product plus both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
63
64/// Thread-safe in-memory implementation of `Store`.
65///
66/// Supports optional embedding-based semantic search via [`with_embeddings`](InMemoryStore::with_embeddings).
67pub struct InMemoryStore {
68    data: Arc<RwLock<HashMap<String, HashMap<String, Item>>>>,
69    /// Optional embeddings model for semantic search.
70    embeddings: Option<Arc<dyn Embeddings>>,
71    /// Pre-computed embedding vectors, keyed by `namespace_key::item_key`.
72    vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
73    /// When true, search uses hybrid BM25 + embedding via Reciprocal Rank Fusion.
74    hybrid: bool,
75}
76
77impl Default for InMemoryStore {
78    fn default() -> Self {
79        Self {
80            data: Arc::new(RwLock::new(HashMap::new())),
81            embeddings: None,
82            vectors: Arc::new(RwLock::new(HashMap::new())),
83            hybrid: false,
84        }
85    }
86}
87
88impl InMemoryStore {
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Enable embedding-based semantic search.
94    ///
95    /// When configured, [`Store::search()`] with a query will use embedding
96    /// similarity instead of substring matching. Items are ranked by cosine
97    /// similarity and `Item::score` is populated.
98    pub fn with_embeddings(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
99        self.embeddings = Some(embeddings);
100        self
101    }
102
103    /// Enable hybrid search combining BM25 text scoring and embedding similarity
104    /// via Reciprocal Rank Fusion (RRF).
105    ///
106    /// When configured, [`Store::search()`] with a query will:
107    /// 1. Compute cosine similarity scores using embeddings
108    /// 2. Compute BM25 text relevance scores
109    /// 3. Fuse rankings via RRF: `score = 1/(60+embed_rank) + 1/(60+bm25_rank)`
110    pub fn with_hybrid_search(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
111        self.embeddings = Some(embeddings);
112        self.hybrid = true;
113        self
114    }
115}
116
117#[async_trait]
118impl Store for InMemoryStore {
119    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError> {
120        let data = self.data.read().await;
121        let ns_key = namespace_key(namespace);
122        Ok(data.get(&ns_key).and_then(|ns| ns.get(key).cloned()))
123    }
124
125    async fn search(
126        &self,
127        namespace: &[&str],
128        query: Option<&str>,
129        limit: usize,
130    ) -> Result<Vec<Item>, SynapticError> {
131        let data = self.data.read().await;
132        let ns_key = namespace_key(namespace);
133
134        let Some(ns) = data.get(&ns_key) else {
135            return Ok(vec![]);
136        };
137
138        // Hybrid search: BM25 + embedding similarity fused via RRF
139        if self.hybrid {
140            if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
141                let query_vec = embeddings.embed_query(q).await?;
142                let vectors = self.vectors.read().await;
143
144                // Compute embedding scores
145                let mut embed_scored: Vec<(&String, f64)> = ns
146                    .keys()
147                    .map(|key| {
148                        let vec_key = format!("{}::{}", ns_key, key);
149                        let score = vectors
150                            .get(&vec_key)
151                            .map(|v| cosine_similarity(v, &query_vec))
152                            .unwrap_or(0.0);
153                        (key, score)
154                    })
155                    .collect();
156                embed_scored
157                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
158
159                // Build embed rank map (rank 0 = best)
160                let embed_rank: HashMap<&String, usize> = embed_scored
161                    .iter()
162                    .enumerate()
163                    .map(|(rank, (key, _))| (*key, rank))
164                    .collect();
165
166                // Compute BM25 scores
167                let total_docs = ns.len();
168                let avg_doc_len = if total_docs > 0 {
169                    ns.values()
170                        .map(|item| {
171                            let text = item
172                                .value
173                                .as_str()
174                                .unwrap_or(&item.value.to_string())
175                                .to_string();
176                            text.split_whitespace().count() as f64
177                        })
178                        .sum::<f64>()
179                        / total_docs as f64
180                } else {
181                    1.0
182                };
183
184                let mut bm25_scored: Vec<(&String, f64)> = ns
185                    .iter()
186                    .map(|(key, item)| {
187                        let text = item
188                            .value
189                            .as_str()
190                            .unwrap_or(&item.value.to_string())
191                            .to_string();
192                        let score = bm25_score(q, &text, avg_doc_len, total_docs);
193                        (key, score)
194                    })
195                    .collect();
196                bm25_scored
197                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198
199                // Build BM25 rank map
200                let bm25_rank: HashMap<&String, usize> = bm25_scored
201                    .iter()
202                    .enumerate()
203                    .map(|(rank, (key, _))| (*key, rank))
204                    .collect();
205
206                // Reciprocal Rank Fusion
207                let k = 60.0;
208                let mut fused: Vec<(Item, f64)> = ns
209                    .iter()
210                    .map(|(key, item)| {
211                        let e_rank = embed_rank.get(key).copied().unwrap_or(ns.len()) as f64;
212                        let b_rank = bm25_rank.get(key).copied().unwrap_or(ns.len()) as f64;
213                        let fused_score = 1.0 / (k + e_rank) + 1.0 / (k + b_rank);
214                        let mut item = item.clone();
215                        item.score = Some(fused_score);
216                        (item, fused_score)
217                    })
218                    .collect();
219
220                fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221                fused.truncate(limit);
222
223                return Ok(fused.into_iter().map(|(item, _)| item).collect());
224            }
225        }
226
227        // If embeddings are configured (non-hybrid) and a query is provided, use semantic search
228        if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
229            let query_vec = embeddings.embed_query(q).await?;
230            let vectors = self.vectors.read().await;
231
232            let mut scored: Vec<(Item, f64)> = ns
233                .iter()
234                .map(|(key, item)| {
235                    let vec_key = format!("{}::{}", ns_key, key);
236                    let score = vectors
237                        .get(&vec_key)
238                        .map(|v| cosine_similarity(v, &query_vec))
239                        .unwrap_or(0.0);
240                    let mut item = item.clone();
241                    item.score = Some(score);
242                    (item, score)
243                })
244                .collect();
245
246            // Sort by score descending
247            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
248            scored.truncate(limit);
249
250            return Ok(scored.into_iter().map(|(item, _)| item).collect());
251        }
252
253        // Fallback: substring search
254        let items: Vec<Item> = ns
255            .values()
256            .filter(|item| {
257                if let Some(q) = query {
258                    // Simple substring search in key and value
259                    item.key.contains(q) || item.value.to_string().contains(q)
260                } else {
261                    true
262                }
263            })
264            .take(limit)
265            .cloned()
266            .collect();
267
268        Ok(items)
269    }
270
271    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError> {
272        let mut data = self.data.write().await;
273        let ns_key = namespace_key(namespace);
274        let ns = data.entry(ns_key.clone()).or_default();
275        let now = now_iso();
276
277        let item = if let Some(existing) = ns.get(key) {
278            Item {
279                namespace: namespace.iter().map(|s| s.to_string()).collect(),
280                key: key.to_string(),
281                value: value.clone(),
282                created_at: existing.created_at.clone(),
283                updated_at: now,
284                score: None,
285            }
286        } else {
287            Item {
288                namespace: namespace.iter().map(|s| s.to_string()).collect(),
289                key: key.to_string(),
290                value: value.clone(),
291                created_at: now.clone(),
292                updated_at: now,
293                score: None,
294            }
295        };
296
297        ns.insert(key.to_string(), item);
298
299        // If embeddings are configured, compute and store the embedding
300        if let Some(ref embeddings) = self.embeddings {
301            let text = value.as_str().unwrap_or(&value.to_string()).to_string();
302            let vecs = embeddings.embed_documents(&[&text]).await?;
303            if let Some(vec) = vecs.into_iter().next() {
304                let vec_key = format!("{}::{}", ns_key, key);
305                self.vectors.write().await.insert(vec_key, vec);
306            }
307        }
308
309        Ok(())
310    }
311
312    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError> {
313        let mut data = self.data.write().await;
314        let ns_key = namespace_key(namespace);
315        if let Some(ns) = data.get_mut(&ns_key) {
316            ns.remove(key);
317        }
318        // Clean up embedding vector
319        let vec_key = format!("{}::{}", ns_key, key);
320        self.vectors.write().await.remove(&vec_key);
321        Ok(())
322    }
323
324    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError> {
325        let data = self.data.read().await;
326        let prefix_str = if prefix.is_empty() {
327            String::new()
328        } else {
329            namespace_key(prefix)
330        };
331
332        let namespaces: Vec<Vec<String>> = data
333            .keys()
334            .filter(|k| prefix.is_empty() || k.starts_with(&prefix_str))
335            .map(|k| k.split("::").map(String::from).collect())
336            .collect();
337
338        Ok(namespaces)
339    }
340}
341
342#[cfg(feature = "filesystem")]
343mod file_store;
344#[cfg(feature = "filesystem")]
345pub use file_store::FileStore;
346
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deterministic embeddings model used by the semantic-search tests.
    struct TestEmbeddings;

    #[async_trait]
    impl Embeddings for TestEmbeddings {
        async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
            let mut vectors = Vec::with_capacity(texts.len());
            for text in texts {
                vectors.push(text_to_vec(text));
            }
            Ok(vectors)
        }

        async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
            Ok(text_to_vec(text))
        }
    }

    /// Folds the bytes of `text` into a 4-dimensional vector: byte `i` is
    /// added to component `i % 4`.
    fn text_to_vec(text: &str) -> Vec<f32> {
        let mut components = vec![0.0f32; 4];
        for (index, byte) in text.bytes().enumerate() {
            components[index % 4] += f32::from(byte);
        }
        components
    }

    /// Stores every `(key, value)` pair in `namespace`, panicking on failure.
    async fn put_all(store: &InMemoryStore, namespace: &[&str], entries: &[(&str, Value)]) {
        for (key, value) in entries {
            store.put(namespace, key, value.clone()).await.unwrap();
        }
    }

    /// Stores five items `k0..k4` with values `"item 0".."item 4"` in `["ns"]`.
    async fn put_five_items(store: &InMemoryStore) {
        for i in 0..5 {
            store
                .put(&["ns"], &format!("k{}", i), json!(format!("item {}", i)))
                .await
                .unwrap();
        }
    }

    /// Asserts that every item carries a score and that scores never increase.
    fn assert_scored_descending(results: &[Item]) {
        for item in results {
            assert!(item.score.is_some());
        }
        let scores: Vec<f64> = results.iter().map(|i| i.score.unwrap()).collect();
        for pair in scores.windows(2) {
            assert!(pair[0] >= pair[1], "scores not sorted: {:?}", scores);
        }
    }

    #[tokio::test]
    async fn put_and_get() {
        let store = InMemoryStore::new();
        let namespace = ["users", "prefs"];
        store.put(&namespace, "theme", json!("dark")).await.unwrap();

        let fetched = store.get(&namespace, "theme").await.unwrap();
        let item = fetched.unwrap();
        assert_eq!(item.key, "theme");
        assert_eq!(item.value, json!("dark"));
        assert_eq!(item.namespace, vec!["users", "prefs"]);
    }

    #[tokio::test]
    async fn get_nonexistent() {
        let store = InMemoryStore::new();
        let lookup = store.get(&["a"], "missing").await.unwrap();
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn delete_item() {
        let store = InMemoryStore::new();
        store.put(&["ns"], "k", json!(1)).await.unwrap();
        store.delete(&["ns"], "k").await.unwrap();

        let after_delete = store.get(&["ns"], "k").await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn search_items() {
        let store = InMemoryStore::new();
        put_all(
            &store,
            &["ns"],
            &[
                ("a", json!("apple")),
                ("b", json!("banana")),
                ("c", json!("cherry")),
            ],
        )
        .await;

        let all = store.search(&["ns"], None, 10).await.unwrap();
        assert_eq!(all.len(), 3);

        let filtered = store.search(&["ns"], Some("apple"), 10).await.unwrap();
        assert_eq!(filtered.len(), 1);
    }

    #[tokio::test]
    async fn list_namespaces_with_prefix() {
        let store = InMemoryStore::new();
        put_all(&store, &["a", "b"], &[("k1", json!(1))]).await;
        put_all(&store, &["a", "c"], &[("k2", json!(2))]).await;
        put_all(&store, &["x", "y"], &[("k3", json!(3))]).await;

        let all = store.list_namespaces(&[]).await.unwrap();
        assert_eq!(all.len(), 3);

        let filtered = store.list_namespaces(&["a"]).await.unwrap();
        assert_eq!(filtered.len(), 2);
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let store = InMemoryStore::new();

        store.put(&["ns"], "k", json!(1)).await.unwrap();
        let first = store.get(&["ns"], "k").await.unwrap().unwrap();

        store.put(&["ns"], "k", json!(2)).await.unwrap();
        let second = store.get(&["ns"], "k").await.unwrap().unwrap();

        assert_eq!(first.created_at, second.created_at);
        assert_eq!(second.value, json!(2));
    }

    #[tokio::test]
    async fn semantic_search_ranked_by_similarity() {
        let store = InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings));
        put_all(
            &store,
            &["docs"],
            &[
                ("a", json!("rust programming")),
                ("b", json!("python programming")),
                ("c", json!("cooking recipes")),
            ],
        )
        .await;

        // Every stored item comes back, each with a score, best first.
        let results = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_scored_descending(&results);
    }

    #[tokio::test]
    async fn semantic_search_respects_limit() {
        let store = InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings));
        put_five_items(&store).await;

        let results = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn delete_cleans_up_embeddings() {
        let store = InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings));

        store.put(&["ns"], "k", json!("hello")).await.unwrap();
        assert!(!store.vectors.read().await.is_empty());

        store.delete(&["ns"], "k").await.unwrap();
        assert!(store.vectors.read().await.is_empty());
    }

    #[tokio::test]
    async fn hybrid_search_combines_bm25_and_embeddings() {
        let store = InMemoryStore::new().with_hybrid_search(Arc::new(TestEmbeddings));
        put_all(
            &store,
            &["docs"],
            &[
                ("a", json!("rust programming language guide")),
                ("b", json!("python web development")),
                ("c", json!("rust cargo build system")),
            ],
        )
        .await;

        let results = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_scored_descending(&results);
    }

    #[tokio::test]
    async fn hybrid_search_respects_limit() {
        let store = InMemoryStore::new().with_hybrid_search(Arc::new(TestEmbeddings));
        put_five_items(&store).await;

        let results = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn bm25_score_basic() {
        let matching = bm25_score("rust", "rust programming language", 3.0, 10);
        assert!(matching > 0.0);
    }

    #[test]
    fn bm25_score_zero_for_no_match() {
        let missing = bm25_score("python", "rust programming language", 3.0, 10);
        assert_eq!(missing, 0.0);
    }
}