Skip to main content

sochdb_memory/
lifecycle.rs

1use crate::fact::FactEdge;
2use crate::store::MemoryStore;
3use parking_lot::Mutex;
4use sochdb_query::memory_compaction::{
5    ExtractiveSummarizer, HierarchicalMemory, MemoryCompactionConfig,
6};
7use sochdb_query::semantic_triggers::{SemanticTrigger, TriggerIndex};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::thread;
11use std::time::Duration;
12
13#[derive(Debug, Clone)]
14pub struct LifecycleConfig {
15    pub enrichment_poll_ms: u64,
16    pub contradiction_bm25_threshold: f32,
17    pub compaction: MemoryCompactionConfig,
18}
19
20impl Default for LifecycleConfig {
21    fn default() -> Self {
22        Self {
23            enrichment_poll_ms: 100,
24            contradiction_bm25_threshold: 0.3,
25            compaction: MemoryCompactionConfig::default(),
26        }
27    }
28}
29
30/// Background daemon: enrichment drain, contradiction pre-filter, compaction.
31pub struct MemoryLifecycleDaemon {
32    store: Arc<MemoryStore>,
33    triggers: Arc<TriggerIndex>,
34    compaction: Arc<Mutex<HierarchicalMemory<ExtractiveSummarizer>>>,
35    running: Arc<AtomicBool>,
36    handle: Mutex<Option<thread::JoinHandle<()>>>,
37}
38
39impl MemoryLifecycleDaemon {
40    pub fn new(store: Arc<MemoryStore>, config: LifecycleConfig) -> Self {
41        Self {
42            store,
43            triggers: Arc::new(TriggerIndex::new()),
44            compaction: Arc::new(Mutex::new(HierarchicalMemory::new(
45                config.compaction.clone(),
46                Arc::new(ExtractiveSummarizer::default()),
47            ))),
48            running: Arc::new(AtomicBool::new(false)),
49            handle: Mutex::new(None),
50        }
51    }
52
53    pub fn register_trigger(&self, trigger: SemanticTrigger) {
54        let _ = self.triggers.register_trigger(trigger);
55    }
56
57    pub fn start(&self, config: &LifecycleConfig) {
58        if self.running.swap(true, Ordering::SeqCst) {
59            return;
60        }
61        let store = Arc::clone(&self.store);
62        let running = Arc::clone(&self.running);
63        let poll = Duration::from_millis(config.enrichment_poll_ms);
64        let threshold = config.contradiction_bm25_threshold;
65
66        let handle = thread::spawn(move || {
67            while running.load(Ordering::SeqCst) {
68                if let Some(job) = store.enrichment_queue().pop() {
69                    // Stage 1: embed episode + index in per-namespace HNSW
70                    if let Err(e) = store.enrich_episode(&job) {
71                        tracing::warn!(
72                            namespace = %job.namespace,
73                            episode_id = job.episode_id,
74                            "enrichment failed: {e}"
75                        );
76                    }
77                    // Stage 2: cheap lexical overlap pre-filter for contradiction candidates
78                    let candidates = store.search_bm25(&job.namespace, &job.text, 8);
79                    let _adjacent: Vec<_> = candidates
80                        .into_iter()
81                        .filter(|(_, score)| *score >= threshold)
82                        .collect();
83                    // Stage 3: LLM judge would run on |C| candidates only (not wired here)
84                    store.enrichment_queue().mark_processed();
85                }
86                thread::sleep(poll);
87            }
88        });
89        *self.handle.lock() = Some(handle);
90    }
91
92    pub fn stop(&self) {
93        self.running.store(false, Ordering::SeqCst);
94        if let Some(h) = self.handle.lock().take() {
95            let _ = h.join();
96        }
97    }
98
99    pub fn check_contradiction_candidates(
100        &self,
101        namespace: &str,
102        new_fact_text: &str,
103        threshold: f32,
104    ) -> Vec<FactEdge> {
105        let tau = u64::MAX;
106        let facts = self.store.facts_valid_at(namespace, tau);
107        let hits = self.store.search_bm25(namespace, new_fact_text, 16);
108        let hit_ids: std::collections::HashSet<u64> = hits
109            .into_iter()
110            .filter(|(_, s)| *s >= threshold)
111            .map(|(id, _)| id)
112            .collect();
113        facts
114            .into_iter()
115            .filter(|f| {
116                hit_ids.contains(&f.episode_id)
117                    || f.subject.contains(new_fact_text)
118                    || f.object.contains(new_fact_text)
119            })
120            .collect()
121    }
122}