Skip to main content

sochdb_memory/
embedding.rs

1//! Episode embedding + per-namespace vector store for the vector retrieval lane.
2
3use crate::enrichment::EnrichmentJob;
4use crate::store::MemoryStore;
5use sochdb_query::EmbeddingProvider;
6use std::sync::Arc;
7
/// Dot product of two equal-length vectors, which equals their cosine
/// similarity when both inputs are already L2-normalized.
///
/// Returns `0.0` when the slices are empty or differ in length.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Explicit fold (rather than `sum`) keeps the accumulation order and the
    // `0.0` seed fixed, so results are reproducible bit-for-bit.
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y)
}
19
20impl MemoryStore {
21    pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
22        &self.embedder
23    }
24
25    /// Embed an episode and store its vector for semantic retrieval.
26    pub fn enrich_episode(&self, job: &EnrichmentJob) -> Result<(), String> {
27        let mut embedding = self.embedder.embed(&job.text).map_err(|e| e.to_string())?;
28        self.embedder.normalize(&mut embedding);
29
30        let mut namespaces = self.namespaces.write();
31        let ns = namespaces
32            .get_mut(&job.namespace)
33            .ok_or_else(|| format!("namespace not found: {}", job.namespace))?;
34
35        ns.vectors.insert(job.episode_id, embedding);
36
37        if let Some(episode) = ns.episodes.get_mut(&job.episode_id) {
38            episode.enriched = true;
39        }
40
41        Ok(())
42    }
43
44    /// Drain all pending enrichment jobs synchronously (tests / bench warmup).
45    pub fn drain_enrichment_queue(&self) -> usize {
46        let mut processed = 0usize;
47        while let Some(job) = self.enrichment.pop() {
48            if self.enrich_episode(&job).is_ok() {
49                processed += 1;
50            }
51            self.enrichment.mark_processed();
52        }
53        processed
54    }
55
56    /// Vector lane search over enriched episodes (brute-force; tuned for agent-memory scale).
57    pub fn search_vector(&self, namespace: &str, query: &str, k: usize) -> Vec<(u64, f32)> {
58        let mut query_emb = match self.embedder.embed(query) {
59            Ok(v) => v,
60            Err(_) => return Vec::new(),
61        };
62        self.embedder.normalize(&mut query_emb);
63
64        let namespaces = self.namespaces.read();
65        let Some(ns) = namespaces.get(namespace) else {
66            return Vec::new();
67        };
68        if ns.vectors.is_empty() {
69            return Vec::new();
70        }
71
72        let mut ranked: Vec<(u64, f32)> = ns
73            .vectors
74            .iter()
75            .map(|(id, vec)| (*id, cosine_similarity(&query_emb, vec)))
76            .collect();
77        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
78        ranked.truncate(k);
79        ranked
80    }
81
82    pub fn enriched_episode_count(&self, namespace: &str) -> usize {
83        self.namespaces
84            .read()
85            .get(namespace)
86            .map(|ns| ns.vectors.len())
87            .unwrap_or(0)
88    }
89}