Skip to main content

second_brain_api/eval/
caching_store.rs

1use std::collections::HashMap;
2use std::sync::RwLock;
3
4use anyhow::Result;
5use uuid::Uuid;
6
7use second_brain_core::schema::{
8    Conversation, Entity, Memory, MemoryType, Relation, RelationType,
9};
10use second_brain_core::store::Store;
11
12type RelationsCache = HashMap<(Uuid, Option<RelationType>), Vec<Relation>>;
13
14// get_relations(memory_id, rel_type) is deterministic and query-independent, so
15// caching it is sound: the eval replays the same snapshot read-only and recall
16// asks for the same (id, rel_type) pairs across all 836 queries x 2 arms.
17// Kuzu serializes these reads, so memoizing turns repeated graph hits into
18// HashMap lookups without changing any recall output.
19pub struct CachingStore<'a> {
20    inner: &'a dyn Store,
21    relations_cache: RwLock<RelationsCache>,
22}
23
24impl<'a> CachingStore<'a> {
25    pub fn new(inner: &'a dyn Store) -> Self {
26        Self {
27            inner,
28            relations_cache: RwLock::new(HashMap::new()),
29        }
30    }
31
32    // Populate the relation cache in two bulk scans (ids + all edges) instead of
33    // letting recall trigger ~corpus_size x 7 serialized Kuzu point reads. Every
34    // (memory_id, Some(rt)) the recall path queries is seeded, empty included, so
35    // the parallel arms run entirely from the in-memory map. get_relations is
36    // query-independent, so a prewarmed cache is interchangeable with live reads.
37    pub fn prewarm(&self) -> Result<()> {
38        let scored_types = [
39            RelationType::Reinforces,
40            RelationType::RelatesTo,
41            RelationType::DistilledFrom,
42            RelationType::Mentions,
43            RelationType::DerivedFrom,
44            RelationType::Contradicts,
45            RelationType::Supersedes,
46        ];
47
48        let ids = self.inner.all_memory_ids()?;
49        let mut cache = self.relations_cache.write().unwrap();
50        cache.reserve(ids.len() * scored_types.len());
51        for id in &ids {
52            for rt in &scored_types {
53                cache.entry((*id, Some(*rt))).or_default();
54            }
55        }
56
57        for rel in self.inner.all_relations()? {
58            // get_relations only matches edges whose source is the queried node,
59            // so key on from_id; an unseeded type (none of the scored set) is
60            // ignored here and would fall through to a live read if ever asked.
61            if let Some(bucket) = cache.get_mut(&(rel.from_id, Some(rel.relation_type))) {
62                bucket.push(rel);
63            }
64        }
65
66        Ok(())
67    }
68}
69
70impl Store for CachingStore<'_> {
71    fn get_relations(
72        &self,
73        node_id: Uuid,
74        relation_type: Option<RelationType>,
75    ) -> Result<Vec<Relation>> {
76        let key = (node_id, relation_type);
77        if let Some(hit) = self.relations_cache.read().unwrap().get(&key) {
78            return Ok(hit.clone());
79        }
80        let fetched = self.inner.get_relations(node_id, relation_type)?;
81        self.relations_cache
82            .write()
83            .unwrap()
84            .insert(key, fetched.clone());
85        Ok(fetched)
86    }
87
88    fn store_memory(&self, memory: &Memory) -> Result<()> {
89        self.inner.store_memory(memory)
90    }
91
92    fn get_memory(&self, id: Uuid) -> Result<Option<Memory>> {
93        self.inner.get_memory(id)
94    }
95
96    fn delete_memory(&self, id: Uuid) -> Result<()> {
97        self.inner.delete_memory(id)
98    }
99
100    fn store_entity(&self, entity: &Entity) -> Result<()> {
101        self.inner.store_entity(entity)
102    }
103
104    fn get_entity(&self, id: Uuid) -> Result<Option<Entity>> {
105        self.inner.get_entity(id)
106    }
107
108    fn find_entity_by_name(&self, name: &str) -> Result<Option<Entity>> {
109        self.inner.find_entity_by_name(name)
110    }
111
112    fn store_conversation(&self, conversation: &Conversation) -> Result<()> {
113        self.inner.store_conversation(conversation)
114    }
115
116    fn store_relation(&self, relation: &Relation) -> Result<()> {
117        self.inner.store_relation(relation)
118    }
119
120    fn vector_search(&self, embedding: &[f32], limit: usize) -> Result<Vec<(Memory, f32)>> {
121        self.inner.vector_search(embedding, limit)
122    }
123
124    fn traverse(&self, start_id: Uuid, depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
125        self.inner.traverse(start_id, depth)
126    }
127
128    fn memories_by_source(&self, source: &str) -> Result<Vec<Memory>> {
129        self.inner.memories_by_source(source)
130    }
131
132    fn memories_by_type(&self, memory_type: MemoryType) -> Result<Vec<Memory>> {
133        self.inner.memories_by_type(memory_type)
134    }
135
136    fn memories_needing_decay(&self, threshold_days: u32) -> Result<Vec<Memory>> {
137        self.inner.memories_needing_decay(threshold_days)
138    }
139
140    fn update_memory(&self, memory: &Memory) -> Result<()> {
141        self.inner.update_memory(memory)
142    }
143
144    fn record_access(&self, memory: &Memory) -> Result<()> {
145        self.inner.record_access(memory)
146    }
147
148    fn text_search(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
149        self.inner.text_search(query, limit)
150    }
151
152    fn memory_count(&self) -> Result<usize> {
153        self.inner.memory_count()
154    }
155
156    fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
157        self.inner.all_memory_ids()
158    }
159
160    fn all_relations(&self) -> Result<Vec<Relation>> {
161        self.inner.all_relations()
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use second_brain_core::query::{QueryEngine, QueryFilters, QueryRequest};
    use second_brain_core::schema::MemoryType;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal `Store` backing the recall tests. It serves canned vector
    /// results and relations, and counts every `get_relations` call so tests
    /// can tell cache hits from live reads.
    struct InMemoryStore {
        vector_results: Vec<(Memory, f32)>,
        relations: Vec<Relation>,
        get_relations_calls: AtomicUsize,
    }

    impl InMemoryStore {
        /// Builds a store over the shared `fixture()` data with a zeroed counter.
        fn from_fixture() -> Self {
            let (vector_results, relations) = fixture();
            Self {
                vector_results,
                relations,
                get_relations_calls: AtomicUsize::new(0),
            }
        }

        /// Number of `get_relations` calls this store has answered so far.
        fn live_reads(&self) -> usize {
            self.get_relations_calls.load(Ordering::Relaxed)
        }
    }

    impl Store for InMemoryStore {
        fn vector_search(&self, _embedding: &[f32], _limit: usize) -> Result<Vec<(Memory, f32)>> {
            Ok(self.vector_results.clone())
        }

        fn get_relations(
            &self,
            node_id: Uuid,
            relation_type: Option<RelationType>,
        ) -> Result<Vec<Relation>> {
            self.get_relations_calls.fetch_add(1, Ordering::Relaxed);
            Ok(self
                .relations
                .iter()
                .filter(|r| r.from_id == node_id)
                .filter(|r| relation_type.map(|rt| rt == r.relation_type).unwrap_or(true))
                .cloned()
                .collect())
        }

        fn store_memory(&self, _m: &Memory) -> Result<()> {
            unimplemented!()
        }
        fn get_memory(&self, _id: Uuid) -> Result<Option<Memory>> {
            unimplemented!()
        }
        fn delete_memory(&self, _id: Uuid) -> Result<()> {
            unimplemented!()
        }
        fn store_entity(&self, _e: &Entity) -> Result<()> {
            unimplemented!()
        }
        fn get_entity(&self, _id: Uuid) -> Result<Option<Entity>> {
            unimplemented!()
        }
        fn find_entity_by_name(&self, _name: &str) -> Result<Option<Entity>> {
            unimplemented!()
        }
        fn store_conversation(&self, _c: &Conversation) -> Result<()> {
            unimplemented!()
        }
        fn store_relation(&self, _r: &Relation) -> Result<()> {
            unimplemented!()
        }
        fn traverse(&self, _id: Uuid, _depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
            unimplemented!()
        }
        fn memories_by_source(&self, _s: &str) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn memories_by_type(&self, _mt: MemoryType) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn memories_needing_decay(&self, _days: u32) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn update_memory(&self, _m: &Memory) -> Result<()> {
            unimplemented!()
        }
        fn record_access(&self, _memory: &Memory) -> Result<()> {
            unimplemented!()
        }
        fn text_search(&self, _q: &str, _limit: usize) -> Result<Vec<Memory>> {
            unimplemented!()
        }
        fn memory_count(&self) -> Result<usize> {
            unimplemented!()
        }
        fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
            Ok(self.vector_results.iter().map(|(m, _)| m.id).collect())
        }
        fn all_relations(&self) -> Result<Vec<Relation>> {
            Ok(self.relations.clone())
        }
    }

    /// A semantic test memory created and last accessed `days_old` days ago.
    fn memory(content: &str, days_old: i64) -> Memory {
        let when = Utc::now() - Duration::days(days_old);
        let mut m = Memory::new(
            content.to_string(),
            MemoryType::Semantic,
            "test".to_string(),
            String::new(),
        );
        m.created_at = when;
        m.last_accessed = when;
        m
    }

    /// Three scored memories plus outgoing edges of several scored relation types.
    fn fixture() -> (Vec<(Memory, f32)>, Vec<Relation>) {
        let a = memory("kuzu was chosen as the embedded graph store", 10);
        let b = memory("sync runs bidirectionally over ssh", 40);
        let c = memory("embeddings use the bge model", 5);
        let rel = |from: Uuid, rt: RelationType, strength: f32| Relation {
            from_id: from,
            to_id: Uuid::new_v4(),
            relation_type: rt,
            strength,
            context: None,
        };
        let relations = vec![
            rel(a.id, RelationType::Reinforces, 1.0),
            rel(a.id, RelationType::RelatesTo, 0.7),
            rel(b.id, RelationType::Mentions, 1.0),
            rel(c.id, RelationType::RelatesTo, 0.4),
            rel(c.id, RelationType::Supersedes, 1.0),
        ];
        let vector_results = vec![(a, 0.91), (b, 0.78), (c, 0.66)];
        (vector_results, relations)
    }

    fn request() -> QueryRequest {
        QueryRequest {
            text: "graph store choice".to_string(),
            embedding: vec![0.1_f32; 384],
            limit: 10,
            filters: QueryFilters::default(),
        }
    }

    #[test]
    fn caching_store_recall_matches_raw_store() {
        let raw = InMemoryStore::from_fixture();

        let baseline = QueryEngine::new(&raw).recall(&request()).unwrap();

        let cached = CachingStore::new(&raw);
        let first = QueryEngine::new(&cached).recall(&request()).unwrap();
        let reads_after_first = raw.live_reads();
        let second = QueryEngine::new(&cached).recall(&request()).unwrap();

        // The identical request asks for the same (id, rel_type) keys, so the
        // repeat must be served entirely from the cache. Without this check,
        // the second pass would not prove the memoized path was exercised.
        assert_eq!(
            raw.live_reads(),
            reads_after_first,
            "second recall must be served from the cache"
        );

        assert_eq!(baseline.len(), first.len());
        assert_eq!(first.len(), second.len());
        for (b, c) in baseline.iter().zip(first.iter()) {
            assert_eq!(b.memory.id, c.memory.id, "result order must match");
            assert!(
                (b.score - c.score).abs() < 1e-6,
                "scores must match: {} vs {}",
                b.score,
                c.score
            );
        }
        // The second recall through the cache must produce the identical
        // ranking, proving the memoized path is faithful across repeats.
        for (a, c) in first.iter().zip(second.iter()) {
            assert_eq!(a.memory.id, c.memory.id);
            assert!((a.score - c.score).abs() < 1e-6);
        }
    }

    #[test]
    fn prewarmed_store_recall_matches_raw_and_skips_live_reads() {
        let raw = InMemoryStore::from_fixture();

        let baseline = QueryEngine::new(&raw).recall(&request()).unwrap();
        let calls_after_baseline = raw.live_reads();
        assert!(calls_after_baseline > 0, "baseline must hit the live store");

        let cached = CachingStore::new(&raw);
        cached.prewarm().unwrap();
        let prewarm_calls = raw.live_reads();

        let recalled = QueryEngine::new(&cached).recall(&request()).unwrap();

        // prewarm uses all_relations, never get_relations, and seeds every
        // (id, scored_type) pair, so recall must trigger zero live reads.
        assert_eq!(
            raw.live_reads(),
            prewarm_calls,
            "prewarmed recall must not call inner.get_relations"
        );
        assert_eq!(prewarm_calls, calls_after_baseline, "prewarm must not read via get_relations");

        assert_eq!(baseline.len(), recalled.len());
        for (b, c) in baseline.iter().zip(recalled.iter()) {
            assert_eq!(b.memory.id, c.memory.id, "result order must match");
            assert!(
                (b.score - c.score).abs() < 1e-6,
                "scores must match: {} vs {}",
                b.score,
                c.score
            );
        }
    }
}