1use std::collections::HashMap;
2use std::sync::RwLock;
3
4use anyhow::Result;
5use uuid::Uuid;
6
7use second_brain_core::schema::{
8 Conversation, Entity, Memory, MemoryType, Relation, RelationType,
9};
10use second_brain_core::store::Store;
11
12type RelationsCache = HashMap<(Uuid, Option<RelationType>), Vec<Relation>>;
13
14pub struct CachingStore<'a> {
20 inner: &'a dyn Store,
21 relations_cache: RwLock<RelationsCache>,
22}
23
24impl<'a> CachingStore<'a> {
25 pub fn new(inner: &'a dyn Store) -> Self {
26 Self {
27 inner,
28 relations_cache: RwLock::new(HashMap::new()),
29 }
30 }
31
32 pub fn prewarm(&self) -> Result<()> {
38 let scored_types = [
39 RelationType::Reinforces,
40 RelationType::RelatesTo,
41 RelationType::DistilledFrom,
42 RelationType::Mentions,
43 RelationType::DerivedFrom,
44 RelationType::Contradicts,
45 RelationType::Supersedes,
46 ];
47
48 let ids = self.inner.all_memory_ids()?;
49 let mut cache = self.relations_cache.write().unwrap();
50 cache.reserve(ids.len() * scored_types.len());
51 for id in &ids {
52 for rt in &scored_types {
53 cache.entry((*id, Some(*rt))).or_default();
54 }
55 }
56
57 for rel in self.inner.all_relations()? {
58 if let Some(bucket) = cache.get_mut(&(rel.from_id, Some(rel.relation_type))) {
62 bucket.push(rel);
63 }
64 }
65
66 Ok(())
67 }
68}
69
70impl Store for CachingStore<'_> {
71 fn get_relations(
72 &self,
73 node_id: Uuid,
74 relation_type: Option<RelationType>,
75 ) -> Result<Vec<Relation>> {
76 let key = (node_id, relation_type);
77 if let Some(hit) = self.relations_cache.read().unwrap().get(&key) {
78 return Ok(hit.clone());
79 }
80 let fetched = self.inner.get_relations(node_id, relation_type)?;
81 self.relations_cache
82 .write()
83 .unwrap()
84 .insert(key, fetched.clone());
85 Ok(fetched)
86 }
87
88 fn store_memory(&self, memory: &Memory) -> Result<()> {
89 self.inner.store_memory(memory)
90 }
91
92 fn get_memory(&self, id: Uuid) -> Result<Option<Memory>> {
93 self.inner.get_memory(id)
94 }
95
96 fn delete_memory(&self, id: Uuid) -> Result<()> {
97 self.inner.delete_memory(id)
98 }
99
100 fn store_entity(&self, entity: &Entity) -> Result<()> {
101 self.inner.store_entity(entity)
102 }
103
104 fn get_entity(&self, id: Uuid) -> Result<Option<Entity>> {
105 self.inner.get_entity(id)
106 }
107
108 fn find_entity_by_name(&self, name: &str) -> Result<Option<Entity>> {
109 self.inner.find_entity_by_name(name)
110 }
111
112 fn store_conversation(&self, conversation: &Conversation) -> Result<()> {
113 self.inner.store_conversation(conversation)
114 }
115
116 fn store_relation(&self, relation: &Relation) -> Result<()> {
117 self.inner.store_relation(relation)
118 }
119
120 fn vector_search(&self, embedding: &[f32], limit: usize) -> Result<Vec<(Memory, f32)>> {
121 self.inner.vector_search(embedding, limit)
122 }
123
124 fn traverse(&self, start_id: Uuid, depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
125 self.inner.traverse(start_id, depth)
126 }
127
128 fn memories_by_source(&self, source: &str) -> Result<Vec<Memory>> {
129 self.inner.memories_by_source(source)
130 }
131
132 fn memories_by_type(&self, memory_type: MemoryType) -> Result<Vec<Memory>> {
133 self.inner.memories_by_type(memory_type)
134 }
135
136 fn memories_needing_decay(&self, threshold_days: u32) -> Result<Vec<Memory>> {
137 self.inner.memories_needing_decay(threshold_days)
138 }
139
140 fn update_memory(&self, memory: &Memory) -> Result<()> {
141 self.inner.update_memory(memory)
142 }
143
144 fn record_access(&self, memory: &Memory) -> Result<()> {
145 self.inner.record_access(memory)
146 }
147
148 fn text_search(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
149 self.inner.text_search(query, limit)
150 }
151
152 fn memory_count(&self) -> Result<usize> {
153 self.inner.memory_count()
154 }
155
156 fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
157 self.inner.all_memory_ids()
158 }
159
160 fn all_relations(&self) -> Result<Vec<Relation>> {
161 self.inner.all_relations()
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use second_brain_core::query::{QueryEngine, QueryFilters, QueryRequest};
    use second_brain_core::schema::MemoryType;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal `Store` double. It serves canned vector hits and relations, and counts live
    /// `get_relations` calls so tests can tell cache hits from pass-through reads.
    struct InMemoryStore {
        vector_results: Vec<(Memory, f32)>,
        relations: Vec<Relation>,
        get_relations_calls: AtomicUsize,
    }

    impl InMemoryStore {
        /// Builds the double from `fixture()` with a zeroed call counter.
        fn from_fixture() -> Self {
            let (vector_results, relations) = fixture();
            Self {
                vector_results,
                relations,
                get_relations_calls: AtomicUsize::new(0),
            }
        }

        /// Number of `get_relations` calls this store has served so far.
        fn live_relation_reads(&self) -> usize {
            self.get_relations_calls.load(Ordering::Relaxed)
        }
    }

    impl Store for InMemoryStore {
        fn vector_search(&self, _embedding: &[f32], _limit: usize) -> Result<Vec<(Memory, f32)>> {
            Ok(self.vector_results.clone())
        }

        fn get_relations(
            &self,
            node_id: Uuid,
            relation_type: Option<RelationType>,
        ) -> Result<Vec<Relation>> {
            self.get_relations_calls.fetch_add(1, Ordering::Relaxed);
            let matching = self
                .relations
                .iter()
                .filter(|rel| rel.from_id == node_id)
                .filter(|rel| match relation_type {
                    Some(wanted) => wanted == rel.relation_type,
                    None => true,
                })
                .cloned()
                .collect();
            Ok(matching)
        }

        fn store_memory(&self, _memory: &Memory) -> Result<()> {
            unimplemented!()
        }
        fn get_memory(&self, _id: Uuid) -> Result<Option<Memory>> {
            unimplemented!()
        }
        fn delete_memory(&self, _id: Uuid) -> Result<()> {
            unimplemented!()
        }
        fn store_entity(&self, _entity: &Entity) -> Result<()> {
            unimplemented!()
        }
        fn get_entity(&self, _id: Uuid) -> Result<Option<Entity>> {
            unimplemented!()
        }
        fn find_entity_by_name(&self, _name: &str) -> Result<Option<Entity>> {
            unimplemented!()
        }
        fn store_conversation(&self, _conversation: &Conversation) -> Result<()> {
            unimplemented!()
        }
        fn store_relation(&self, _relation: &Relation) -> Result<()> {
            unimplemented!()
        }
        fn traverse(&self, _start_id: Uuid, _depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
            unimplemented!()
        }
        fn memories_by_source(&self, _source: &str) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn memories_by_type(&self, _memory_type: MemoryType) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn memories_needing_decay(&self, _threshold_days: u32) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn update_memory(&self, _memory: &Memory) -> Result<()> {
            unimplemented!()
        }
        fn record_access(&self, _memory: &Memory) -> Result<()> {
            unimplemented!()
        }
        fn text_search(&self, _query: &str, _limit: usize) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn memory_count(&self) -> Result<usize> {
            unimplemented!()
        }
        fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
            let ids = self
                .vector_results
                .iter()
                .map(|(memory, _score)| memory.id)
                .collect();
            Ok(ids)
        }
        fn all_relations(&self) -> Result<Vec<Relation>> {
            Ok(self.relations.clone())
        }
    }

    /// Semantic test memory created and last accessed `days_old` days ago.
    fn memory(content: &str, days_old: i64) -> Memory {
        let timestamp = Utc::now() - Duration::days(days_old);
        let mut mem = Memory::new(
            content.to_string(),
            MemoryType::Semantic,
            "test".to_string(),
            String::new(),
        );
        mem.created_at = timestamp;
        mem.last_accessed = timestamp;
        mem
    }

    /// Three scored vector hits plus outgoing relations of mixed types for each of them.
    fn fixture() -> (Vec<(Memory, f32)>, Vec<Relation>) {
        let a = memory("kuzu was chosen as the embedded graph store", 10);
        let b = memory("sync runs bidirectionally over ssh", 40);
        let c = memory("embeddings use the bge model", 5);
        let edge = |from_id: Uuid, relation_type: RelationType, strength: f32| Relation {
            from_id,
            to_id: Uuid::new_v4(),
            relation_type,
            strength,
            context: None,
        };
        let relations = vec![
            edge(a.id, RelationType::Reinforces, 1.0),
            edge(a.id, RelationType::RelatesTo, 0.7),
            edge(b.id, RelationType::Mentions, 1.0),
            edge(c.id, RelationType::RelatesTo, 0.4),
            edge(c.id, RelationType::Supersedes, 1.0),
        ];
        (vec![(a, 0.91), (b, 0.78), (c, 0.66)], relations)
    }

    /// Fixed recall request used by every test.
    fn request() -> QueryRequest {
        QueryRequest {
            text: "graph store choice".to_string(),
            embedding: vec![0.1_f32; 384],
            limit: 10,
            filters: QueryFilters::default(),
        }
    }

    #[test]
    fn caching_store_recall_matches_raw_store() {
        let raw = InMemoryStore::from_fixture();
        let baseline = QueryEngine::new(&raw).recall(&request()).unwrap();

        let cached = CachingStore::new(&raw);
        let first = QueryEngine::new(&cached).recall(&request()).unwrap();
        let second = QueryEngine::new(&cached).recall(&request()).unwrap();

        assert_eq!(baseline.len(), first.len());
        assert_eq!(first.len(), second.len());
        for (expected, actual) in baseline.iter().zip(first.iter()) {
            assert_eq!(expected.memory.id, actual.memory.id, "result order must match");
            assert!(
                (expected.score - actual.score).abs() < 1e-6,
                "scores must match: {} vs {}",
                expected.score,
                actual.score
            );
        }
        // A second recall through the now-populated cache must rank identically.
        for (earlier, later) in first.iter().zip(second.iter()) {
            assert_eq!(earlier.memory.id, later.memory.id);
            assert!((earlier.score - later.score).abs() < 1e-6);
        }
    }

    #[test]
    fn prewarmed_store_recall_matches_raw_and_skips_live_reads() {
        let raw = InMemoryStore::from_fixture();

        let baseline = QueryEngine::new(&raw).recall(&request()).unwrap();
        let calls_after_baseline = raw.live_relation_reads();
        assert!(calls_after_baseline > 0, "baseline must hit the live store");

        let cached = CachingStore::new(&raw);
        cached.prewarm().unwrap();
        let prewarm_calls = raw.live_relation_reads();

        let recalled = QueryEngine::new(&cached).recall(&request()).unwrap();

        assert_eq!(
            raw.live_relation_reads(),
            prewarm_calls,
            "prewarmed recall must not call inner.get_relations"
        );
        assert_eq!(prewarm_calls, calls_after_baseline, "prewarm must not read via get_relations");

        assert_eq!(baseline.len(), recalled.len());
        for (expected, actual) in baseline.iter().zip(recalled.iter()) {
            assert_eq!(expected.memory.id, actual.memory.id, "result order must match");
            assert!(
                (expected.score - actual.score).abs() < 1e-6,
                "scores must match: {} vs {}",
                expected.score,
                actual.score
            );
        }
    }
}