1use crate::enrichment::{EnrichmentJob, EnrichmentQueue};
2use crate::episode::{Episode, EpisodeId, EpisodeWrite};
3use crate::fact::{FactEdge, FactId};
4use parking_lot::RwLock;
5use sochdb_query::{EmbeddingProvider, MockEmbeddingProvider, trigram_index::TrigramIndex};
6use sochdb_storage::hlc::HybridLogicalClock;
7use sochdb_vector::bm25::BM25Config;
8use sochdb_vector::inverted_index::InvertedIndex;
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12use std::time::Instant;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum MemoryError {
17 #[error("namespace not found: {0}")]
18 NamespaceNotFound(String),
19 #[error("episode not found: {0}")]
20 EpisodeNotFound(u64),
21 #[error("io error: {0}")]
22 Io(#[from] std::io::Error),
23}
24
25pub type MemoryResult<T> = Result<T, MemoryError>;
26
/// Tunables applied when a [`MemoryStore`] is constructed.
#[derive(Debug, Clone)]
pub struct MemoryStoreConfig {
    /// Capacity handed to the enrichment queue when the store is built.
    pub max_enrichment_queue: usize,
    /// When `true`, `write_episode` also enriches each new episode inline,
    /// in addition to offering it to the enrichment queue.
    pub enrich_on_write: bool,
}

impl Default for MemoryStoreConfig {
    /// A queue capacity of 10 000 jobs, with inline enrichment switched off.
    fn default() -> Self {
        const DEFAULT_QUEUE_CAPACITY: usize = 10_000;
        Self {
            enrich_on_write: false,
            max_enrichment_queue: DEFAULT_QUEUE_CAPACITY,
        }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct WriteResult {
45 pub episode_id: EpisodeId,
46 pub t_created: u64,
47 pub lexical_indexed: bool,
48 pub ingestion_lag_us: u64,
49 pub enrichment_queued: bool,
50}
51
52pub(crate) struct NamespaceIndexes {
53 pub(crate) bm25: InvertedIndex,
54 pub(crate) trigram: TrigramIndex,
55 pub(crate) vectors: HashMap<u64, Vec<f32>>,
56 pub(crate) episodes: HashMap<u64, Episode>,
57 facts: Vec<FactEdge>,
58 next_episode_id: u64,
59 next_fact_id: u64,
60}
61
62impl NamespaceIndexes {
63 fn new() -> Self {
64 Self {
65 bm25: InvertedIndex::new(BM25Config::default()),
66 trigram: TrigramIndex::new(),
67 vectors: HashMap::new(),
68 episodes: HashMap::new(),
69 facts: Vec::new(),
70 next_episode_id: 1,
71 next_fact_id: 1,
72 }
73 }
74}
75
76pub struct MemoryStore {
78 hlc: HybridLogicalClock,
79 pub(crate) namespaces: RwLock<HashMap<String, NamespaceIndexes>>,
80 pub(crate) enrichment: EnrichmentQueue,
81 pub(crate) embedder: Arc<dyn EmbeddingProvider>,
82 config: MemoryStoreConfig,
83}
84
85fn default_embedder() -> Arc<dyn EmbeddingProvider> {
86 Arc::new(MockEmbeddingProvider::new(384))
87}
88
89impl MemoryStore {
90 pub fn new(_data_dir: Option<&Path>, config: MemoryStoreConfig) -> Self {
91 Self::with_embedder(_data_dir, config, default_embedder())
92 }
93
94 pub fn with_embedder(
95 _data_dir: Option<&Path>,
96 config: MemoryStoreConfig,
97 embedder: Arc<dyn EmbeddingProvider>,
98 ) -> Self {
99 Self {
100 hlc: HybridLogicalClock::new(),
101 namespaces: RwLock::new(HashMap::new()),
102 enrichment: EnrichmentQueue::new(config.max_enrichment_queue),
103 embedder,
104 config,
105 }
106 }
107
108 pub fn with_defaults() -> Self {
109 Self::new(None, MemoryStoreConfig::default())
110 }
111
112 pub fn enrichment_queue(&self) -> &EnrichmentQueue {
113 &self.enrichment
114 }
115
116 pub fn write_episode(&self, write: EpisodeWrite) -> MemoryResult<WriteResult> {
118 let start = Instant::now();
119 let t_created = self.hlc.next();
120 let t_valid = write
128 .t_valid_from
129 .unwrap_or_else(|| HybridLogicalClock::physical_time(t_created) / 1000);
130
131 let mut namespaces = self.namespaces.write();
132 let ns = namespaces
133 .entry(write.namespace.clone())
134 .or_insert_with(NamespaceIndexes::new);
135
136 let episode_id = EpisodeId(ns.next_episode_id);
137 ns.next_episode_id += 1;
138
139 let doc_id = episode_id.0;
140 ns.bm25.add_document_with_id(doc_id, &write.text);
141 ns.trigram.insert(doc_id, &write.text);
142
143 let episode = Episode {
144 id: episode_id,
145 namespace: write.namespace.clone(),
146 text: write.text.clone(),
147 t_created,
148 t_valid_from: t_valid,
149 enriched: false,
150 metadata: write.metadata.clone(),
151 };
152 ns.episodes.insert(doc_id, episode);
153
154 let job = EnrichmentJob {
155 namespace: write.namespace.clone(),
156 episode_id: doc_id,
157 text: write.text.clone(),
158 };
159
160 let enrichment_queued = self.enrichment.try_enqueue(job.clone()).is_ok();
161 let ingestion_lag_us = start.elapsed().as_micros() as u64;
162
163 let result = WriteResult {
164 episode_id,
165 t_created,
166 lexical_indexed: true,
167 ingestion_lag_us,
168 enrichment_queued,
169 };
170
171 drop(namespaces);
173
174 if self.config.enrich_on_write {
175 let _ = self.enrich_episode(&job);
176 }
177
178 Ok(result)
179 }
180
181 pub fn get_episode(&self, namespace: &str, id: EpisodeId) -> MemoryResult<Episode> {
182 let namespaces = self.namespaces.read();
183 let ns = namespaces
184 .get(namespace)
185 .ok_or_else(|| MemoryError::NamespaceNotFound(namespace.to_string()))?;
186 ns.episodes
187 .get(&id.0)
188 .cloned()
189 .ok_or_else(|| MemoryError::EpisodeNotFound(id.0))
190 }
191
192 pub fn namespace_bm25(&self, namespace: &str) -> Option<Arc<InvertedIndex>> {
193 let _ = namespace;
195 None
196 }
197
198 pub fn search_bm25(&self, namespace: &str, query: &str, k: usize) -> Vec<(u64, f32)> {
199 let namespaces = self.namespaces.read();
200 namespaces
201 .get(namespace)
202 .map(|ns| ns.bm25.search(query, k))
203 .unwrap_or_default()
204 }
205
206 pub fn search_trigram_literal(
207 &self,
208 namespace: &str,
209 literal: &str,
210 k: usize,
211 ) -> Vec<(u64, f32)> {
212 let namespaces = self.namespaces.read();
213 let Some(ns) = namespaces.get(namespace) else {
214 return Vec::new();
215 };
216 let trigrams = sochdb_query::trigram_index::trigrams_of(literal);
217 if trigrams.is_empty() {
218 return Vec::new();
219 }
220 let candidates = ns.trigram.candidates(&trigrams);
221 candidates
222 .into_iter()
223 .take(k)
224 .map(|doc_id| (doc_id, 1.0))
225 .collect()
226 }
227
228 pub fn episode_text(&self, namespace: &str, doc_id: u64) -> Option<String> {
229 let namespaces = self.namespaces.read();
230 namespaces
231 .get(namespace)?
232 .episodes
233 .get(&doc_id)
234 .map(|e| e.text.clone())
235 }
236
237 pub fn add_fact(&self, namespace: &str, mut fact: FactEdge) -> MemoryResult<FactId> {
238 let mut namespaces = self.namespaces.write();
239 let ns = namespaces
240 .entry(namespace.to_string())
241 .or_insert_with(NamespaceIndexes::new);
242 let id = FactId(ns.next_fact_id);
243 ns.next_fact_id += 1;
244 fact.id = id;
245 ns.facts.push(fact);
246 Ok(id)
247 }
248
249 pub fn facts_valid_at(&self, namespace: &str, tau: u64) -> Vec<FactEdge> {
250 let namespaces = self.namespaces.read();
251 namespaces
252 .get(namespace)
253 .map(|ns| {
254 ns.facts
255 .iter()
256 .filter(|f| f.is_valid_at(tau))
257 .cloned()
258 .collect()
259 })
260 .unwrap_or_default()
261 }
262
263 pub fn invalidate_fact(&self, namespace: &str, fact_id: FactId, t_invalid: u64) -> bool {
264 let mut namespaces = self.namespaces.write();
265 let Some(ns) = namespaces.get_mut(namespace) else {
266 return false;
267 };
268 if let Some(fact) = ns.facts.iter_mut().find(|f| f.id == fact_id) {
269 fact.invalidate(t_invalid);
270 return true;
271 }
272 false
273 }
274
275 pub fn episode_count(&self, namespace: &str) -> usize {
276 self.namespaces
277 .read()
278 .get(namespace)
279 .map(|ns| ns.episodes.len())
280 .unwrap_or(0)
281 }
282}