Skip to main content

sochdb_memory/
store.rs

1use crate::enrichment::{EnrichmentJob, EnrichmentQueue};
2use crate::episode::{Episode, EpisodeId, EpisodeWrite};
3use crate::fact::{FactEdge, FactId};
4use parking_lot::RwLock;
5use sochdb_query::{EmbeddingProvider, MockEmbeddingProvider, trigram_index::TrigramIndex};
6use sochdb_storage::hlc::HybridLogicalClock;
7use sochdb_vector::bm25::BM25Config;
8use sochdb_vector::inverted_index::InvertedIndex;
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12use std::time::Instant;
13use thiserror::Error;
14
/// Errors returned by fallible memory-store operations.
#[derive(Debug, Error)]
pub enum MemoryError {
    /// The requested namespace name is not present in the store.
    #[error("namespace not found: {0}")]
    NamespaceNotFound(String),
    /// The namespace exists but holds no episode with this raw id.
    #[error("episode not found: {0}")]
    EpisodeNotFound(u64),
    /// Wraps a `std::io::Error`; `#[from]` provides the `From` conversion
    /// so `?` can lift I/O errors into this type.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

/// Shorthand for `Result<T, MemoryError>`.
pub type MemoryResult<T> = Result<T, MemoryError>;
26
/// Construction-time settings for the memory store.
#[derive(Debug, Clone)]
pub struct MemoryStoreConfig {
    /// Value handed to `EnrichmentQueue::new` when the store is built
    /// (presumably the queue's capacity bound — confirm against the queue).
    pub max_enrichment_queue: usize,
    /// Run embedding + HNSW insert synchronously on write (bench/tests).
    pub enrich_on_write: bool,
}

impl Default for MemoryStoreConfig {
    /// Returns the stock configuration: `max_enrichment_queue` of `10_000`
    /// and synchronous enrichment switched off.
    fn default() -> Self {
        let max_enrichment_queue = 10_000;
        let enrich_on_write = false;
        Self {
            max_enrichment_queue,
            enrich_on_write,
        }
    }
}
42
/// Outcome of a single `MemoryStore::write_episode` call.
#[derive(Debug, Clone)]
pub struct WriteResult {
    /// Id allocated to the new episode from its namespace's counter.
    pub episode_id: EpisodeId,
    /// Creation timestamp taken from the store's hybrid logical clock.
    pub t_created: u64,
    /// Whether the text was added to the lexical indexes; `write_episode`
    /// always sets this to `true`.
    pub lexical_indexed: bool,
    /// Microseconds elapsed between entering `write_episode` and finishing the
    /// enqueue attempt (measured before any synchronous enrichment runs).
    pub ingestion_lag_us: u64,
    /// `true` when the enrichment queue accepted the job for this episode.
    pub enrichment_queued: bool,
}
51
52pub(crate) struct NamespaceIndexes {
53    pub(crate) bm25: InvertedIndex,
54    pub(crate) trigram: TrigramIndex,
55    pub(crate) vectors: HashMap<u64, Vec<f32>>,
56    pub(crate) episodes: HashMap<u64, Episode>,
57    facts: Vec<FactEdge>,
58    next_episode_id: u64,
59    next_fact_id: u64,
60}
61
62impl NamespaceIndexes {
63    fn new() -> Self {
64        Self {
65            bm25: InvertedIndex::new(BM25Config::default()),
66            trigram: TrigramIndex::new(),
67            vectors: HashMap::new(),
68            episodes: HashMap::new(),
69            facts: Vec::new(),
70            next_episode_id: 1,
71            next_fact_id: 1,
72        }
73    }
74}
75
/// Agent memory store: write-time lexical recall + async enrichment queue.
pub struct MemoryStore {
    /// Clock whose `next()` value stamps each written episode's `t_created`.
    hlc: HybridLogicalClock,
    /// Per-namespace state keyed by namespace name; one lock guards the whole map.
    pub(crate) namespaces: RwLock<HashMap<String, NamespaceIndexes>>,
    /// Queue that is offered one `EnrichmentJob` per written episode.
    pub(crate) enrichment: EnrichmentQueue,
    /// Embedding backend. Not called by the methods in this file —
    /// presumably used by the enrichment path (confirm).
    pub(crate) embedder: Arc<dyn EmbeddingProvider>,
    /// Settings captured at construction time.
    config: MemoryStoreConfig,
}
84
85fn default_embedder() -> Arc<dyn EmbeddingProvider> {
86    Arc::new(MockEmbeddingProvider::new(384))
87}
88
89impl MemoryStore {
90    pub fn new(_data_dir: Option<&Path>, config: MemoryStoreConfig) -> Self {
91        Self::with_embedder(_data_dir, config, default_embedder())
92    }
93
94    pub fn with_embedder(
95        _data_dir: Option<&Path>,
96        config: MemoryStoreConfig,
97        embedder: Arc<dyn EmbeddingProvider>,
98    ) -> Self {
99        Self {
100            hlc: HybridLogicalClock::new(),
101            namespaces: RwLock::new(HashMap::new()),
102            enrichment: EnrichmentQueue::new(config.max_enrichment_queue),
103            embedder,
104            config,
105        }
106    }
107
108    pub fn with_defaults() -> Self {
109        Self::new(None, MemoryStoreConfig::default())
110    }
111
    /// Borrows the store's enrichment queue.
    pub fn enrichment_queue(&self) -> &EnrichmentQueue {
        &self.enrichment
    }
115
116    /// Write episode: lexical lanes indexed synchronously; enrichment queued async.
117    pub fn write_episode(&self, write: EpisodeWrite) -> MemoryResult<WriteResult> {
118        let start = Instant::now();
119        let t_created = self.hlc.next();
120        // Default validity-start to wall-clock unix milliseconds, matching the
121        // `as_of=<unix_ms>` query contract. `t_created` is a raw HLC tick
122        // (`physical_micros << 16 | logical`, ~1e21), so defaulting to it made
123        // the bi-temporal filter `t_valid_from <= as_of` always false for any
124        // realistic `as_of` — silently returning zero results for every
125        // episode written without an explicit validity time (the common case).
126        // Callers that pass `t_valid_from` keep their own time domain.
127        let t_valid = write
128            .t_valid_from
129            .unwrap_or_else(|| HybridLogicalClock::physical_time(t_created) / 1000);
130
131        let mut namespaces = self.namespaces.write();
132        let ns = namespaces
133            .entry(write.namespace.clone())
134            .or_insert_with(NamespaceIndexes::new);
135
136        let episode_id = EpisodeId(ns.next_episode_id);
137        ns.next_episode_id += 1;
138
139        let doc_id = episode_id.0;
140        ns.bm25.add_document_with_id(doc_id, &write.text);
141        ns.trigram.insert(doc_id, &write.text);
142
143        let episode = Episode {
144            id: episode_id,
145            namespace: write.namespace.clone(),
146            text: write.text.clone(),
147            t_created,
148            t_valid_from: t_valid,
149            enriched: false,
150            metadata: write.metadata.clone(),
151        };
152        ns.episodes.insert(doc_id, episode);
153
154        let job = EnrichmentJob {
155            namespace: write.namespace.clone(),
156            episode_id: doc_id,
157            text: write.text.clone(),
158        };
159
160        let enrichment_queued = self.enrichment.try_enqueue(job.clone()).is_ok();
161        let ingestion_lag_us = start.elapsed().as_micros() as u64;
162
163        let result = WriteResult {
164            episode_id,
165            t_created,
166            lexical_indexed: true,
167            ingestion_lag_us,
168            enrichment_queued,
169        };
170
171        // Release namespace lock before enrichment (embed + vector insert re-lock).
172        drop(namespaces);
173
174        if self.config.enrich_on_write {
175            let _ = self.enrich_episode(&job);
176        }
177
178        Ok(result)
179    }
180
181    pub fn get_episode(&self, namespace: &str, id: EpisodeId) -> MemoryResult<Episode> {
182        let namespaces = self.namespaces.read();
183        let ns = namespaces
184            .get(namespace)
185            .ok_or_else(|| MemoryError::NamespaceNotFound(namespace.to_string()))?;
186        ns.episodes
187            .get(&id.0)
188            .cloned()
189            .ok_or_else(|| MemoryError::EpisodeNotFound(id.0))
190    }
191
    /// Placeholder accessor for a namespace's BM25 index; it always returns
    /// `None` regardless of `namespace`.
    pub fn namespace_bm25(&self, namespace: &str) -> Option<Arc<InvertedIndex>> {
        // BM25 index is behind RwLock in namespace — expose search via store methods instead
        // The parameter is deliberately ignored; this binding only marks it as used.
        let _ = namespace;
        None
    }
197
198    pub fn search_bm25(&self, namespace: &str, query: &str, k: usize) -> Vec<(u64, f32)> {
199        let namespaces = self.namespaces.read();
200        namespaces
201            .get(namespace)
202            .map(|ns| ns.bm25.search(query, k))
203            .unwrap_or_default()
204    }
205
206    pub fn search_trigram_literal(
207        &self,
208        namespace: &str,
209        literal: &str,
210        k: usize,
211    ) -> Vec<(u64, f32)> {
212        let namespaces = self.namespaces.read();
213        let Some(ns) = namespaces.get(namespace) else {
214            return Vec::new();
215        };
216        let trigrams = sochdb_query::trigram_index::trigrams_of(literal);
217        if trigrams.is_empty() {
218            return Vec::new();
219        }
220        let candidates = ns.trigram.candidates(&trigrams);
221        candidates
222            .into_iter()
223            .take(k)
224            .map(|doc_id| (doc_id, 1.0))
225            .collect()
226    }
227
228    pub fn episode_text(&self, namespace: &str, doc_id: u64) -> Option<String> {
229        let namespaces = self.namespaces.read();
230        namespaces
231            .get(namespace)?
232            .episodes
233            .get(&doc_id)
234            .map(|e| e.text.clone())
235    }
236
237    pub fn add_fact(&self, namespace: &str, mut fact: FactEdge) -> MemoryResult<FactId> {
238        let mut namespaces = self.namespaces.write();
239        let ns = namespaces
240            .entry(namespace.to_string())
241            .or_insert_with(NamespaceIndexes::new);
242        let id = FactId(ns.next_fact_id);
243        ns.next_fact_id += 1;
244        fact.id = id;
245        ns.facts.push(fact);
246        Ok(id)
247    }
248
249    pub fn facts_valid_at(&self, namespace: &str, tau: u64) -> Vec<FactEdge> {
250        let namespaces = self.namespaces.read();
251        namespaces
252            .get(namespace)
253            .map(|ns| {
254                ns.facts
255                    .iter()
256                    .filter(|f| f.is_valid_at(tau))
257                    .cloned()
258                    .collect()
259            })
260            .unwrap_or_default()
261    }
262
263    pub fn invalidate_fact(&self, namespace: &str, fact_id: FactId, t_invalid: u64) -> bool {
264        let mut namespaces = self.namespaces.write();
265        let Some(ns) = namespaces.get_mut(namespace) else {
266            return false;
267        };
268        if let Some(fact) = ns.facts.iter_mut().find(|f| f.id == fact_id) {
269            fact.invalidate(t_invalid);
270            return true;
271        }
272        false
273    }
274
275    pub fn episode_count(&self, namespace: &str) -> usize {
276        self.namespaces
277            .read()
278            .get(namespace)
279            .map(|ns| ns.episodes.len())
280            .unwrap_or(0)
281    }
282}