sochdb_memory/
embedding.rs1use crate::enrichment::EnrichmentJob;
4use crate::store::MemoryStore;
5use sochdb_query::EmbeddingProvider;
6use std::sync::Arc;
7
/// Cosine similarity between two vectors; in `[-1.0, 1.0]` for finite inputs.
///
/// Returns `0.0` when the slices differ in length, are empty, or either vector has
/// zero magnitude (the similarity is undefined in those cases). Inputs do not need
/// to be pre-normalized; for unit-length inputs the result equals their dot product.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    // Take each square root separately to reduce overflow/underflow of the product.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // Clamp away rounding error that can push the ratio just outside [-1, 1].
    (dot / denom).clamp(-1.0, 1.0)
}
19
20impl MemoryStore {
21 pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
22 &self.embedder
23 }
24
25 pub fn enrich_episode(&self, job: &EnrichmentJob) -> Result<(), String> {
27 let mut embedding = self.embedder.embed(&job.text).map_err(|e| e.to_string())?;
28 self.embedder.normalize(&mut embedding);
29
30 let mut namespaces = self.namespaces.write();
31 let ns = namespaces
32 .get_mut(&job.namespace)
33 .ok_or_else(|| format!("namespace not found: {}", job.namespace))?;
34
35 ns.vectors.insert(job.episode_id, embedding);
36
37 if let Some(episode) = ns.episodes.get_mut(&job.episode_id) {
38 episode.enriched = true;
39 }
40
41 Ok(())
42 }
43
44 pub fn drain_enrichment_queue(&self) -> usize {
46 let mut processed = 0usize;
47 while let Some(job) = self.enrichment.pop() {
48 if self.enrich_episode(&job).is_ok() {
49 processed += 1;
50 }
51 self.enrichment.mark_processed();
52 }
53 processed
54 }
55
56 pub fn search_vector(&self, namespace: &str, query: &str, k: usize) -> Vec<(u64, f32)> {
58 let mut query_emb = match self.embedder.embed(query) {
59 Ok(v) => v,
60 Err(_) => return Vec::new(),
61 };
62 self.embedder.normalize(&mut query_emb);
63
64 let namespaces = self.namespaces.read();
65 let Some(ns) = namespaces.get(namespace) else {
66 return Vec::new();
67 };
68 if ns.vectors.is_empty() {
69 return Vec::new();
70 }
71
72 let mut ranked: Vec<(u64, f32)> = ns
73 .vectors
74 .iter()
75 .map(|(id, vec)| (*id, cosine_similarity(&query_emb, vec)))
76 .collect();
77 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
78 ranked.truncate(k);
79 ranked
80 }
81
82 pub fn enriched_episode_count(&self, namespace: &str) -> usize {
83 self.namespaces
84 .read()
85 .get(namespace)
86 .map(|ns| ns.vectors.len())
87 .unwrap_or(0)
88 }
89}