1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8pub use synaptic_core::{now_iso, Embeddings, Item, Store, SynapticError};
10
/// Builds the flat string key used to index a namespace in the store's maps.
///
/// Segments are separated by `"::"`; an empty namespace produces `""`.
fn namespace_key(namespace: &[&str]) -> String {
    let mut key = String::new();
    for (index, segment) in namespace.iter().enumerate() {
        if index > 0 {
            key.push_str("::");
        }
        key.push_str(segment);
    }
    key
}
14
/// Okapi BM25 relevance of `document` for `query`, with k1 = 1.2 and b = 0.75.
///
/// Terms are whitespace-separated and compared case-insensitively. A query
/// term that occurs more than once in `query` contributes once per
/// occurrence. `avg_doc_len` is clamped to at least 1.0 before length
/// normalisation.
///
/// Limitation: no per-term document frequency is available here, so the IDF
/// factor is computed as if every query term occurred in exactly one of
/// `total_docs` documents. All matching terms therefore carry the same
/// weight.
fn bm25_score(query: &str, document: &str, avg_doc_len: f64, total_docs: usize) -> f64 {
    const K1: f64 = 1.2;
    const B: f64 = 0.75;
    // Assumed document frequency of every query term (see limitation above).
    const DOC_FREQ: f64 = 1.0;

    // Lowercase the document once, rather than once per query term.
    let doc_terms: Vec<String> = document.split_whitespace().map(str::to_lowercase).collect();
    let doc_len = doc_terms.len() as f64;

    // Both factors are independent of the query term, so compute them once.
    let idf = ((total_docs as f64 - DOC_FREQ + 0.5) / (DOC_FREQ + 0.5) + 1.0).ln();
    let length_norm = K1 * (1.0 - B + B * doc_len / avg_doc_len.max(1.0));

    let mut score = 0.0;
    for query_term in query.split_whitespace() {
        let query_term = query_term.to_lowercase();
        let tf = doc_terms.iter().filter(|term| **term == query_term).count() as f64;
        if tf == 0.0 {
            continue;
        }
        score += idf * (tf * (K1 + 1.0)) / (tf + length_norm);
    }
    score
}
40
/// Cosine similarity of two vectors, accumulated in `f64`.
///
/// Returns 0.0 when the vectors differ in length, are empty, or either one has
/// zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        dot / magnitude
    }
}
63
64pub struct InMemoryStore {
68 data: Arc<RwLock<HashMap<String, HashMap<String, Item>>>>,
69 embeddings: Option<Arc<dyn Embeddings>>,
71 vectors: Arc<RwLock<HashMap<String, Vec<f32>>>>,
73 hybrid: bool,
75}
76
77impl Default for InMemoryStore {
78 fn default() -> Self {
79 Self {
80 data: Arc::new(RwLock::new(HashMap::new())),
81 embeddings: None,
82 vectors: Arc::new(RwLock::new(HashMap::new())),
83 hybrid: false,
84 }
85 }
86}
87
88impl InMemoryStore {
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 pub fn with_embeddings(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
99 self.embeddings = Some(embeddings);
100 self
101 }
102
103 pub fn with_hybrid_search(mut self, embeddings: Arc<dyn Embeddings>) -> Self {
111 self.embeddings = Some(embeddings);
112 self.hybrid = true;
113 self
114 }
115}
116
117#[async_trait]
118impl Store for InMemoryStore {
119 async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError> {
120 let data = self.data.read().await;
121 let ns_key = namespace_key(namespace);
122 Ok(data.get(&ns_key).and_then(|ns| ns.get(key).cloned()))
123 }
124
125 async fn search(
126 &self,
127 namespace: &[&str],
128 query: Option<&str>,
129 limit: usize,
130 ) -> Result<Vec<Item>, SynapticError> {
131 let data = self.data.read().await;
132 let ns_key = namespace_key(namespace);
133
134 let Some(ns) = data.get(&ns_key) else {
135 return Ok(vec![]);
136 };
137
138 if self.hybrid {
140 if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
141 let query_vec = embeddings.embed_query(q).await?;
142 let vectors = self.vectors.read().await;
143
144 let mut embed_scored: Vec<(&String, f64)> = ns
146 .keys()
147 .map(|key| {
148 let vec_key = format!("{}::{}", ns_key, key);
149 let score = vectors
150 .get(&vec_key)
151 .map(|v| cosine_similarity(v, &query_vec))
152 .unwrap_or(0.0);
153 (key, score)
154 })
155 .collect();
156 embed_scored
157 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
158
159 let embed_rank: HashMap<&String, usize> = embed_scored
161 .iter()
162 .enumerate()
163 .map(|(rank, (key, _))| (*key, rank))
164 .collect();
165
166 let total_docs = ns.len();
168 let avg_doc_len = if total_docs > 0 {
169 ns.values()
170 .map(|item| {
171 let text = item
172 .value
173 .as_str()
174 .unwrap_or(&item.value.to_string())
175 .to_string();
176 text.split_whitespace().count() as f64
177 })
178 .sum::<f64>()
179 / total_docs as f64
180 } else {
181 1.0
182 };
183
184 let mut bm25_scored: Vec<(&String, f64)> = ns
185 .iter()
186 .map(|(key, item)| {
187 let text = item
188 .value
189 .as_str()
190 .unwrap_or(&item.value.to_string())
191 .to_string();
192 let score = bm25_score(q, &text, avg_doc_len, total_docs);
193 (key, score)
194 })
195 .collect();
196 bm25_scored
197 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198
199 let bm25_rank: HashMap<&String, usize> = bm25_scored
201 .iter()
202 .enumerate()
203 .map(|(rank, (key, _))| (*key, rank))
204 .collect();
205
206 let k = 60.0;
208 let mut fused: Vec<(Item, f64)> = ns
209 .iter()
210 .map(|(key, item)| {
211 let e_rank = embed_rank.get(key).copied().unwrap_or(ns.len()) as f64;
212 let b_rank = bm25_rank.get(key).copied().unwrap_or(ns.len()) as f64;
213 let fused_score = 1.0 / (k + e_rank) + 1.0 / (k + b_rank);
214 let mut item = item.clone();
215 item.score = Some(fused_score);
216 (item, fused_score)
217 })
218 .collect();
219
220 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221 fused.truncate(limit);
222
223 return Ok(fused.into_iter().map(|(item, _)| item).collect());
224 }
225 }
226
227 if let (Some(ref embeddings), Some(q)) = (&self.embeddings, query) {
229 let query_vec = embeddings.embed_query(q).await?;
230 let vectors = self.vectors.read().await;
231
232 let mut scored: Vec<(Item, f64)> = ns
233 .iter()
234 .map(|(key, item)| {
235 let vec_key = format!("{}::{}", ns_key, key);
236 let score = vectors
237 .get(&vec_key)
238 .map(|v| cosine_similarity(v, &query_vec))
239 .unwrap_or(0.0);
240 let mut item = item.clone();
241 item.score = Some(score);
242 (item, score)
243 })
244 .collect();
245
246 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
248 scored.truncate(limit);
249
250 return Ok(scored.into_iter().map(|(item, _)| item).collect());
251 }
252
253 let items: Vec<Item> = ns
255 .values()
256 .filter(|item| {
257 if let Some(q) = query {
258 item.key.contains(q) || item.value.to_string().contains(q)
260 } else {
261 true
262 }
263 })
264 .take(limit)
265 .cloned()
266 .collect();
267
268 Ok(items)
269 }
270
271 async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError> {
272 let mut data = self.data.write().await;
273 let ns_key = namespace_key(namespace);
274 let ns = data.entry(ns_key.clone()).or_default();
275 let now = now_iso();
276
277 let item = if let Some(existing) = ns.get(key) {
278 Item {
279 namespace: namespace.iter().map(|s| s.to_string()).collect(),
280 key: key.to_string(),
281 value: value.clone(),
282 created_at: existing.created_at.clone(),
283 updated_at: now,
284 score: None,
285 }
286 } else {
287 Item {
288 namespace: namespace.iter().map(|s| s.to_string()).collect(),
289 key: key.to_string(),
290 value: value.clone(),
291 created_at: now.clone(),
292 updated_at: now,
293 score: None,
294 }
295 };
296
297 ns.insert(key.to_string(), item);
298
299 if let Some(ref embeddings) = self.embeddings {
301 let text = value.as_str().unwrap_or(&value.to_string()).to_string();
302 let vecs = embeddings.embed_documents(&[&text]).await?;
303 if let Some(vec) = vecs.into_iter().next() {
304 let vec_key = format!("{}::{}", ns_key, key);
305 self.vectors.write().await.insert(vec_key, vec);
306 }
307 }
308
309 Ok(())
310 }
311
312 async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError> {
313 let mut data = self.data.write().await;
314 let ns_key = namespace_key(namespace);
315 if let Some(ns) = data.get_mut(&ns_key) {
316 ns.remove(key);
317 }
318 let vec_key = format!("{}::{}", ns_key, key);
320 self.vectors.write().await.remove(&vec_key);
321 Ok(())
322 }
323
324 async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError> {
325 let data = self.data.read().await;
326 let prefix_str = if prefix.is_empty() {
327 String::new()
328 } else {
329 namespace_key(prefix)
330 };
331
332 let namespaces: Vec<Vec<String>> = data
333 .keys()
334 .filter(|k| prefix.is_empty() || k.starts_with(&prefix_str))
335 .map(|k| k.split("::").map(String::from).collect())
336 .collect();
337
338 Ok(namespaces)
339 }
340}
341
342#[cfg(feature = "filesystem")]
343mod file_store;
344#[cfg(feature = "filesystem")]
345pub use file_store::FileStore;
346
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deterministic stand-in embedding: byte `i` of `text` is added to bucket
    /// `i % 4` of a four-dimensional vector.
    fn text_to_vec(text: &str) -> Vec<f32> {
        text.bytes()
            .enumerate()
            .fold(vec![0.0f32; 4], |mut buckets, (i, byte)| {
                buckets[i % 4] += f32::from(byte);
                buckets
            })
    }

    /// Embeddings provider backed by [`text_to_vec`].
    struct TestEmbeddings;

    #[async_trait]
    impl Embeddings for TestEmbeddings {
        async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
            Ok(texts.iter().copied().map(text_to_vec).collect())
        }

        async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
            Ok(text_to_vec(text))
        }
    }

    /// Stores each `(key, text)` pair as a JSON string in namespace `ns`.
    async fn put_texts(store: &InMemoryStore, ns: &[&str], entries: &[(&str, &str)]) {
        for &(key, text) in entries {
            store.put(ns, key, json!(text)).await.unwrap();
        }
    }

    /// Stores `k0..k4` with values `"item 0".."item 4"` in namespace `["ns"]`.
    async fn put_numbered_items(store: &InMemoryStore) {
        for i in 0..5 {
            let key = format!("k{}", i);
            store
                .put(&["ns"], &key, json!(format!("item {}", i)))
                .await
                .unwrap();
        }
    }

    /// Asserts that every result carries a score and that scores never increase.
    fn assert_scored_descending(results: &[Item]) {
        for item in results {
            assert!(item.score.is_some());
        }
        let scores: Vec<f64> = results.iter().filter_map(|item| item.score).collect();
        for w in scores.windows(2) {
            assert!(w[0] >= w[1], "scores not sorted: {:?}", scores);
        }
    }

    #[tokio::test]
    async fn put_and_get() {
        let store = InMemoryStore::new();
        let ns = ["users", "prefs"];
        store.put(&ns, "theme", json!("dark")).await.unwrap();

        let item = store.get(&ns, "theme").await.unwrap().unwrap();
        assert_eq!(item.key, "theme");
        assert_eq!(item.value, json!("dark"));
        assert_eq!(item.namespace, vec!["users", "prefs"]);
    }

    #[tokio::test]
    async fn get_nonexistent() {
        let store = InMemoryStore::new();
        let missing = store.get(&["a"], "missing").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn delete_item() {
        let store = InMemoryStore::new();
        store.put(&["ns"], "k", json!(1)).await.unwrap();
        store.delete(&["ns"], "k").await.unwrap();

        let after = store.get(&["ns"], "k").await.unwrap();
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn search_items() {
        let store = InMemoryStore::new();
        put_texts(
            &store,
            &["ns"],
            &[("a", "apple"), ("b", "banana"), ("c", "cherry")],
        )
        .await;

        let all = store.search(&["ns"], None, 10).await.unwrap();
        assert_eq!(all.len(), 3);

        let filtered = store.search(&["ns"], Some("apple"), 10).await.unwrap();
        assert_eq!(filtered.len(), 1);
    }

    #[tokio::test]
    async fn list_namespaces_with_prefix() {
        let store = InMemoryStore::new();
        let entries = [
            (["a", "b"], "k1", 1),
            (["a", "c"], "k2", 2),
            (["x", "y"], "k3", 3),
        ];
        for (ns, key, n) in entries {
            store.put(&ns, key, json!(n)).await.unwrap();
        }

        let all = store.list_namespaces(&[]).await.unwrap();
        assert_eq!(all.len(), 3);

        let under_a = store.list_namespaces(&["a"]).await.unwrap();
        assert_eq!(under_a.len(), 2);
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let store = InMemoryStore::new();

        store.put(&["ns"], "k", json!(1)).await.unwrap();
        let first = store.get(&["ns"], "k").await.unwrap().unwrap();

        store.put(&["ns"], "k", json!(2)).await.unwrap();
        let second = store.get(&["ns"], "k").await.unwrap().unwrap();

        assert_eq!(first.created_at, second.created_at);
        assert_eq!(second.value, json!(2));
    }

    #[tokio::test]
    async fn semantic_search_ranked_by_similarity() {
        let store = InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings));
        put_texts(
            &store,
            &["docs"],
            &[
                ("a", "rust programming"),
                ("b", "python programming"),
                ("c", "cooking recipes"),
            ],
        )
        .await;

        let results = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_scored_descending(&results);
    }

    #[tokio::test]
    async fn semantic_search_respects_limit() {
        let store = InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings));
        put_numbered_items(&store).await;

        let results = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn delete_cleans_up_embeddings() {
        let store = InMemoryStore::new().with_embeddings(Arc::new(TestEmbeddings));

        store.put(&["ns"], "k", json!("hello")).await.unwrap();
        assert!(!store.vectors.read().await.is_empty());

        store.delete(&["ns"], "k").await.unwrap();
        assert!(store.vectors.read().await.is_empty());
    }

    #[tokio::test]
    async fn hybrid_search_combines_bm25_and_embeddings() {
        let store = InMemoryStore::new().with_hybrid_search(Arc::new(TestEmbeddings));
        put_texts(
            &store,
            &["docs"],
            &[
                ("a", "rust programming language guide"),
                ("b", "python web development"),
                ("c", "rust cargo build system"),
            ],
        )
        .await;

        let results = store.search(&["docs"], Some("rust"), 10).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_scored_descending(&results);
    }

    #[tokio::test]
    async fn hybrid_search_respects_limit() {
        let store = InMemoryStore::new().with_hybrid_search(Arc::new(TestEmbeddings));
        put_numbered_items(&store).await;

        let results = store.search(&["ns"], Some("item"), 2).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn bm25_score_basic() {
        let score = bm25_score("rust", "rust programming language", 3.0, 10);
        assert!(score > 0.0);
    }

    #[test]
    fn bm25_score_zero_for_no_match() {
        let score = bm25_score("python", "rust programming language", 3.0, 10);
        assert_eq!(score, 0.0);
    }
}
561}