Skip to main content

tj_core/
memory.rs

1//! Global cross-project memory index (Pillar B).
2//!
3//! A single SQLite file (`data_dir/memory.sqlite`) mirrors the *high-signal*
4//! events — decisions, rejections, constraints (and, later, consolidated
5//! semantic/procedural/preference facts) — from every project, together with
6//! their embeddings. This is what lets the agent recall relevant prior
7//! reasoning across its whole history, not just the current repo — the thing
8//! single-project memory tools can't do.
9//!
10//! The index is a denormalised cache: the per-project JSONL logs remain the
11//! source of truth. It is rebuilt idempotently by [`sync_from_project`] and
12//! queried by [`search`].
13
14use rusqlite::Connection;
15
/// The event types mirrored into the global index and surfaced proactively.
///
/// Each one captures reasoning the agent most wants to see before it repeats
/// itself; the per-entry comments say what each type records.
pub const HIGH_SIGNAL_TYPES: [&str; 3] = [
    "decision",   // a committed choice
    "rejection",  // a ruled-out path
    "constraint", // an external limit
];
20
/// DDL applied by `open` on every connection.
///
/// Every statement uses `IF NOT EXISTS`, so running the batch against an
/// already-initialised database is a no-op.
///
/// * `global_memory` — one row per mirrored event, keyed by `event_id`, holding
///   the event text, its embedding (`model`, `dim`, `vec`) and a `superseded`
///   flag (0 by default).
/// * `idx_gm_type` / `idx_gm_model` — secondary indexes on `type` and `model`.
/// * `global_fts` — FTS5 table over `text`; `event_id` is stored `UNINDEXED`
///   so it can be joined back to `global_memory` without being tokenised.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS global_memory (
  event_id     TEXT PRIMARY KEY,
  project_hash TEXT NOT NULL,
  task_id      TEXT NOT NULL,
  type         TEXT NOT NULL,
  tier         TEXT NOT NULL DEFAULT 'episodic',
  text         TEXT NOT NULL,
  model        TEXT NOT NULL,
  dim          INTEGER NOT NULL,
  vec          BLOB NOT NULL,
  created_at   TEXT NOT NULL,
  superseded   INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_gm_type ON global_memory(type);
CREATE INDEX IF NOT EXISTS idx_gm_model ON global_memory(model);
CREATE VIRTUAL TABLE IF NOT EXISTS global_fts USING fts5(event_id UNINDEXED, text);
"#;
39
40/// Open (creating + migrating) the global memory database at `path`.
41pub fn open(path: impl AsRef<std::path::Path>) -> anyhow::Result<Connection> {
42    if let Some(parent) = path.as_ref().parent() {
43        std::fs::create_dir_all(parent)?;
44    }
45    let conn = Connection::open(path)?;
46    conn.execute_batch(SCHEMA)?;
47    Ok(conn)
48}
49
/// A cross-project recall hit returned by [`keyword_search`] and [`search`].
#[derive(Debug, Clone, PartialEq)]
pub struct GlobalHit {
    /// Primary key of the mirrored event in `global_memory`.
    pub event_id: String,
    /// Hash of the project the event was synced from.
    pub project_hash: String,
    /// Task the event belongs to.
    pub task_id: String,
    /// Value of the `type` column (e.g. one of [`HIGH_SIGNAL_TYPES`]).
    pub event_type: String,
    /// Memory tier of the entry (the schema defaults this to `episodic`).
    pub tier: String,
    /// The event text.
    pub text: String,
    /// Relevance score; higher means more relevant. The scale depends on the
    /// search function that produced the hit.
    pub score: f32,
}
60
61/// Copy this project's high-signal embedded events into the global index.
62/// Idempotent (`INSERT OR REPLACE` on `event_id`); call after embedding a
63/// project. Returns how many rows were synced. `superseded` is flagged from the
64/// `decisions.superseded_by` projection so contradicted decisions can be
65/// down-ranked at query time.
66pub fn sync_from_project(
67    global: &Connection,
68    project: &Connection,
69    project_hash: &str,
70) -> anyhow::Result<usize> {
71    let placeholders = HIGH_SIGNAL_TYPES
72        .iter()
73        .map(|_| "?")
74        .collect::<Vec<_>>()
75        .join(",");
76    let sql = format!(
77        "SELECT e.event_id, e.task_id, f.type, e.tier, f.text, e.model, e.dim, e.vec, e.created_at,
78                CASE WHEN d.superseded_by IS NOT NULL THEN 1 ELSE 0 END
79           FROM embeddings e
80           JOIN search_fts f ON f.event_id = e.event_id
81           LEFT JOIN decisions d ON d.decision_id = e.event_id
82          WHERE f.type IN ({placeholders})"
83    );
84    let mut stmt = project.prepare(&sql)?;
85    let rows = stmt.query_map(rusqlite::params_from_iter(HIGH_SIGNAL_TYPES.iter()), |r| {
86        Ok((
87            r.get::<_, String>(0)?,  // event_id
88            r.get::<_, String>(1)?,  // task_id
89            r.get::<_, String>(2)?,  // type
90            r.get::<_, String>(3)?,  // tier
91            r.get::<_, String>(4)?,  // text
92            r.get::<_, String>(5)?,  // model
93            r.get::<_, i64>(6)?,     // dim
94            r.get::<_, Vec<u8>>(7)?, // vec
95            r.get::<_, String>(8)?,  // created_at
96            r.get::<_, i64>(9)?,     // superseded
97        ))
98    })?;
99
100    let mut n = 0usize;
101    for row in rows {
102        let (event_id, task_id, ty, tier, text, model, dim, vec, created_at, superseded) = row?;
103        global.execute(
104            "INSERT OR REPLACE INTO global_memory
105               (event_id, project_hash, task_id, type, tier, text, model, dim, vec, created_at, superseded)
106             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
107            rusqlite::params![
108                event_id, project_hash, task_id, ty, tier, text, model, dim, vec, created_at, superseded
109            ],
110        )?;
111        // Mirror into FTS5 for the fast keyword path (proactive hook).
112        global.execute(
113            "DELETE FROM global_fts WHERE event_id = ?1",
114            rusqlite::params![event_id],
115        )?;
116        global.execute(
117            "INSERT INTO global_fts(event_id, text) VALUES (?1, ?2)",
118            rusqlite::params![event_id, text],
119        )?;
120        n += 1;
121    }
122    Ok(n)
123}
124
125/// Fast keyword (FTS5) search over the global index — no embedding, so it's
126/// cheap enough to run on every prompt in the proactive hook (loading a model
127/// per prompt would be too slow). Builds an OR query from the prompt's
128/// alphanumeric tokens (≥4 chars) and ranks by BM25.
129pub fn keyword_search(conn: &Connection, prompt: &str, k: usize) -> anyhow::Result<Vec<GlobalHit>> {
130    let mut seen = std::collections::HashSet::new();
131    let terms: Vec<String> = prompt
132        .split(|c: char| !c.is_alphanumeric())
133        .filter(|t| t.chars().count() >= 4)
134        .map(|t| t.to_lowercase())
135        .filter(|t| seen.insert(t.clone()))
136        .collect();
137    if terms.is_empty() {
138        return Ok(Vec::new());
139    }
140    let query = terms.join(" OR ");
141    let mut stmt = conn.prepare(
142        "SELECT g.event_id, g.project_hash, g.task_id, g.type, g.tier, g.text, g.superseded,
143                bm25(global_fts)
144           FROM global_fts
145           JOIN global_memory g ON g.event_id = global_fts.event_id
146          WHERE global_fts MATCH ?1
147          ORDER BY bm25(global_fts)
148          LIMIT ?2",
149    )?;
150    let rows = stmt.query_map(rusqlite::params![query, k as i64], |r| {
151        let bm: f64 = r.get(7)?;
152        let superseded: i64 = r.get(6)?;
153        // BM25 is lower-is-better; negate so higher == more relevant, then
154        // nudge contradicted reasoning down.
155        let score = (-bm) as f32 - if superseded != 0 { 0.5 } else { 0.0 };
156        Ok(GlobalHit {
157            event_id: r.get(0)?,
158            project_hash: r.get(1)?,
159            task_id: r.get(2)?,
160            event_type: r.get(3)?,
161            tier: r.get(4)?,
162            text: r.get(5)?,
163            score,
164        })
165    })?;
166    let mut out = Vec::new();
167    for row in rows {
168        out.push(row?);
169    }
170    Ok(out)
171}
172
173/// Semantic search across the whole global index for the embedder's `model`.
174/// Returns the top `k` hits by cosine, with a small penalty applied to
175/// superseded/contradicted entries so live reasoning ranks above stale.
176pub fn search(
177    conn: &Connection,
178    query_vec: &[f32],
179    model: &str,
180    k: usize,
181) -> anyhow::Result<Vec<GlobalHit>> {
182    let mut stmt = conn.prepare(
183        "SELECT event_id, project_hash, task_id, type, tier, text, vec, superseded
184           FROM global_memory WHERE model = ?1",
185    )?;
186    let rows = stmt.query_map(rusqlite::params![model], |r| {
187        Ok((
188            r.get::<_, String>(0)?,
189            r.get::<_, String>(1)?,
190            r.get::<_, String>(2)?,
191            r.get::<_, String>(3)?,
192            r.get::<_, String>(4)?,
193            r.get::<_, String>(5)?,
194            r.get::<_, Vec<u8>>(6)?,
195            r.get::<_, i64>(7)?,
196        ))
197    })?;
198
199    let mut hits = Vec::new();
200    for row in rows {
201        let (event_id, project_hash, task_id, event_type, tier, text, blob, superseded) = row?;
202        let mut score = crate::embed::cosine(query_vec, &crate::embed::from_blob(&blob));
203        if superseded != 0 {
204            score -= 0.1; // down-rank contradicted reasoning
205        }
206        hits.push(GlobalHit {
207            event_id,
208            project_hash,
209            task_id,
210            event_type,
211            tier,
212            text,
213            score,
214        });
215    }
216    hits.sort_by(|a, b| {
217        b.score
218            .partial_cmp(&a.score)
219            .unwrap_or(std::cmp::Ordering::Equal)
220    });
221    hits.truncate(k);
222    Ok(hits)
223}
224
225/// Count of indexed entries (test/stats helper).
226pub fn count(conn: &Connection) -> anyhow::Result<usize> {
227    let n: i64 = conn.query_row("SELECT COUNT(*) FROM global_memory", [], |r| r.get(0))?;
228    Ok(n as usize)
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embed::Embedder;

    /// Build a user-authored CLI `Decision` event carrying `text`, so that it
    /// passes the `HIGH_SIGNAL_TYPES` filter during sync.
    fn decision_event(text: &str) -> crate::event::Event {
        use crate::event::{Author, Event, EventType, Source};
        Event::new(
            "tj-x",
            EventType::Decision,
            Author::User,
            Source::Cli,
            text.into(),
        )
    }

    /// Two decisions synced from one project: the semantically closest one
    /// must come back first from the global index.
    #[test]
    fn sync_then_search_finds_cross_project_decision() {
        let tmp = tempfile::TempDir::new().unwrap();
        let project_db = crate::db::open(tmp.path().join("p.sqlite")).unwrap();
        let global_db = open(tmp.path().join("memory.sqlite")).unwrap();
        let embedder = crate::embed::HashEmbedder::new(256);

        let decisions = [
            "chose to route refunds through the idempotent payment ledger",
            "use postgres advisory locks for the cron job leader election",
        ];
        for text in decisions {
            crate::db::index_event(&project_db, &decision_event(text)).unwrap();
        }
        crate::db::embed_pending(&project_db, "projhash", &embedder, "t", 100).unwrap();

        let synced = sync_from_project(&global_db, &project_db, "projhash").unwrap();
        assert_eq!(synced, 2);
        assert_eq!(count(&global_db).unwrap(), 2);

        let query = embedder.embed_one("refund ledger idempotent").unwrap();
        let hits = search(&global_db, &query, embedder.model_id(), 5).unwrap();
        assert!(!hits.is_empty());
        let best = &hits[0];
        assert!(
            best.text.contains("refund"),
            "the refund decision must rank first across the global index, got: {}",
            best.text
        );
        assert_eq!(best.project_hash, "projhash");
    }

    /// Prompt tokens of at least four characters drive the FTS match; prompts
    /// without any such overlapping token return nothing.
    #[test]
    fn keyword_search_matches_prompt_terms() {
        let tmp = tempfile::TempDir::new().unwrap();
        let project_db = crate::db::open(tmp.path().join("p.sqlite")).unwrap();
        let global_db = open(tmp.path().join("memory.sqlite")).unwrap();
        let embedder = crate::embed::HashEmbedder::new(64);

        let decisions = [
            "chose the idempotent payment ledger for refunds",
            "rejected kafka for the audit log; too heavy",
        ];
        for text in decisions {
            crate::db::index_event(&project_db, &decision_event(text)).unwrap();
        }
        crate::db::embed_pending(&project_db, "ph", &embedder, "t", 100).unwrap();
        sync_from_project(&global_db, &project_db, "ph").unwrap();

        let hits = keyword_search(&global_db, "should we add a refund ledger here?", 5).unwrap();
        assert!(
            !hits.is_empty(),
            "prompt terms must match the ledger decision"
        );
        assert!(hits[0].text.contains("ledger"));

        // No overlapping ≥4-char term => no hit.
        let misses = keyword_search(&global_db, "tiny ui css fix", 5).unwrap();
        assert!(misses.is_empty());
    }

    /// Rows are only visible to queries made with the model that embedded them.
    #[test]
    fn search_filters_by_model() {
        let tmp = tempfile::TempDir::new().unwrap();
        let project_db = crate::db::open(tmp.path().join("p.sqlite")).unwrap();
        let global_db = open(tmp.path().join("memory.sqlite")).unwrap();
        let embedder = crate::embed::HashEmbedder::new(64);

        let event = decision_event("decided to adopt the outbox pattern");
        crate::db::index_event(&project_db, &event).unwrap();
        crate::db::embed_pending(&project_db, "ph", &embedder, "t", 100).unwrap();
        sync_from_project(&global_db, &project_db, "ph").unwrap();

        let query = embedder.embed_one("outbox").unwrap();
        let foreign = search(&global_db, &query, "other-model", 5).unwrap();
        assert!(foreign.is_empty());
        let own = search(&global_db, &query, embedder.model_id(), 5).unwrap();
        assert_eq!(own.len(), 1);
    }
}