Skip to main content

zagens_runtime_adapters/persist/
session_store_sqlite.rs

1#![allow(dead_code)]
//! SQLite-backed session store. Provides the same semantics as the
//! JSON-per-file `SessionManager` but with far better I/O performance:
//! one WAL append per transaction (with `synchronous=NORMAL`, fsync happens
//! only at checkpoints) instead of one fsync per session file.
//!
//! ## Async / blocking policy (A1.3)
//!
//! - **TUI interactive path:** `crate::tui::persistence_actor` coalesces
//!   checkpoint/session snapshots off the event-loop worker (latest-wins).
//! - **HTTP runtime path:** session persist runs inside `spawn_blocking`
//!   (see `runtime_api::threads`).
//! - **Direct callers:** treat `save_session_sqlite` as blocking I/O; do not
//!   call it from async contexts without `spawn_blocking`.
14use std::path::PathBuf;
15
16use anyhow::{Context, bail};
17use chrono::{DateTime, Utc};
18use rusqlite::{Connection, params};
19
20use crate::persist::session_manager::{SavedSession, SessionContextReference, SessionMetadata};
21
/// Value stored under `_meta.version` once the one-time JSON→SQLite import has run.
/// `open_sqlite_session_db` runs the import only while this row is absent.
const CURRENT_META_VERSION: u32 = 1;
23
24fn ensure_sessions_runtime_thread_id_column(db: &Connection) -> anyhow::Result<()> {
25    let mut stmt = db.prepare("PRAGMA table_info(sessions)")?;
26    let has_col = stmt
27        .query_map([], |row| row.get::<_, String>(1))?
28        .filter_map(|r| r.ok())
29        .any(|name| name == "runtime_thread_id");
30    if !has_col {
31        db.execute("ALTER TABLE sessions ADD COLUMN runtime_thread_id TEXT", [])?;
32    }
33    Ok(())
34}
35
/// Encodes an optional runtime thread id for storage in the `runtime_thread_id` column.
///
/// `None` is written as the empty-string sentinel rather than SQL `NULL`;
/// `runtime_thread_id_from_sql` maps that sentinel back to `None`.
fn runtime_thread_id_sql(id: &Option<String>) -> &str {
    match id {
        Some(thread_id) => thread_id.as_str(),
        None => "",
    }
}
39
/// Decodes a stored runtime thread id: the empty-string sentinel becomes `None`,
/// any other text (including whitespace) is returned as-is.
fn runtime_thread_id_from_sql(raw: &str) -> Option<String> {
    (!raw.is_empty()).then(|| raw.to_owned())
}
47
48/// Opens (or creates) the SQLite DB at `db_path`.
49/// If JSON files exist in `sessions_dir` and the DB is empty, auto-migrates.
50pub fn open_sqlite_session_db(
51    db_path: &std::path::Path,
52    sessions_dir: &std::path::Path,
53) -> anyhow::Result<Connection> {
54    let db = Connection::open(db_path).context("Failed to open SQLite session DB")?;
55
56    db.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;")
57        .context("Failed to set SQLite pragmas")?;
58
59    db.execute_batch(
60        "CREATE TABLE IF NOT EXISTS _meta (
61            key TEXT PRIMARY KEY,
62            value TEXT NOT NULL
63        );
64        CREATE TABLE IF NOT EXISTS sessions (
65            id TEXT PRIMARY KEY,
66            title TEXT NOT NULL DEFAULT '',
67            created_at TEXT NOT NULL,
68            updated_at TEXT NOT NULL,
69            message_count INTEGER NOT NULL DEFAULT 0,
70            total_tokens INTEGER NOT NULL DEFAULT 0,
71            model TEXT NOT NULL DEFAULT '',
72            workspace TEXT NOT NULL DEFAULT '.',
73            mode TEXT,
74            system_prompt TEXT,
75            messages_json TEXT NOT NULL DEFAULT '[]',
76            context_refs_json TEXT NOT NULL DEFAULT '[]',
77            runtime_thread_id TEXT
78        );
79        CREATE INDEX IF NOT EXISTS idx_sessions_updated ON sessions(updated_at);
80        CREATE INDEX IF NOT EXISTS idx_sessions_workspace ON sessions(workspace);",
81    )
82    .context("Failed to create session tables")?;
83
84    ensure_sessions_runtime_thread_id_column(&db)?;
85
86    // P2-C: additive migration — ensure compaction artifacts table exists.
87    // Older binaries that don't call this function simply ignore the new table.
88    crate::persist::compaction_artifact_store::ensure_compaction_artifacts_table(&db)?;
89
90    // Check if migration is needed
91    let needs_migration: bool = db
92        .query_row("SELECT value FROM _meta WHERE key = 'version'", [], |row| {
93            row.get::<_, String>(0)
94        })
95        .ok()
96        .is_none();
97
98    if needs_migration {
99        migrate_json_sessions(&db, sessions_dir)?;
100        db.execute(
101            "INSERT OR REPLACE INTO _meta (key, value) VALUES ('version', ?1)",
102            params![CURRENT_META_VERSION.to_string()],
103        )?;
104    }
105
106    Ok(db)
107}
108
109fn migrate_json_sessions(db: &Connection, sessions_dir: &std::path::Path) -> anyhow::Result<()> {
110    let dir = std::fs::read_dir(sessions_dir);
111    let dir = match dir {
112        Ok(d) => d,
113        Err(_) => return Ok(()),
114    };
115
116    let tx = db.unchecked_transaction()?;
117
118    for entry in dir.filter_map(|e| e.ok()) {
119        let path = entry.path();
120        if path.extension().and_then(|e| e.to_str()) != Some("json") {
121            continue;
122        }
123        let content = match std::fs::read_to_string(&path) {
124            Ok(c) => c,
125            Err(_) => continue,
126        };
127        let session: SavedSession = match serde_json::from_str(&content) {
128            Ok(s) => s,
129            Err(_) => continue,
130        };
131
132        let messages_json = serde_json::to_string(&session.messages).unwrap_or_default();
133        let context_refs_json =
134            serde_json::to_string(&session.context_references).unwrap_or_default();
135        let created_at = session.metadata.created_at.to_rfc3339();
136        let updated_at = session.metadata.updated_at.to_rfc3339();
137        let mode = session.metadata.mode.as_deref().unwrap_or("");
138        let workspace = session.metadata.workspace.display().to_string();
139        let system_prompt = session.system_prompt.as_deref().unwrap_or("");
140        let runtime_thread_id = runtime_thread_id_sql(&session.metadata.runtime_thread_id);
141
142        tx.execute(
143            "INSERT OR REPLACE INTO sessions
144             (id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, system_prompt, messages_json, context_refs_json, runtime_thread_id)
145             VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13)",
146            params![
147                session.metadata.id,
148                session.metadata.title,
149                created_at,
150                updated_at,
151                session.metadata.message_count as i64,
152                session.metadata.total_tokens as i64,
153                session.metadata.model,
154                workspace,
155                mode,
156                system_prompt,
157                messages_json,
158                context_refs_json,
159                runtime_thread_id,
160            ],
161        )?;
162    }
163
164    tx.commit()?;
165    eprintln!(
166        "[session-store] migrated {} sessions to SQLite",
167        db.query_row("SELECT COUNT(*) FROM sessions", [], |r| r.get::<_, i64>(0))
168            .unwrap_or(0)
169    );
170    Ok(())
171}
172
173pub fn save_session_sqlite(db: &Connection, session: &SavedSession) -> anyhow::Result<()> {
174    let messages_json = serde_json::to_string(&session.messages).unwrap_or_default();
175    let context_refs_json = serde_json::to_string(&session.context_references).unwrap_or_default();
176    let created_at = session.metadata.created_at.to_rfc3339();
177    let updated_at = session.metadata.updated_at.to_rfc3339();
178    let mode = session.metadata.mode.as_deref().unwrap_or("");
179    let workspace = session.metadata.workspace.display().to_string();
180    let system_prompt = session.system_prompt.as_deref().unwrap_or("");
181    let runtime_thread_id = runtime_thread_id_sql(&session.metadata.runtime_thread_id);
182
183    db.execute(
184        "INSERT OR REPLACE INTO sessions
185         (id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, system_prompt, messages_json, context_refs_json, runtime_thread_id)
186         VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13)",
187        params![
188            session.metadata.id,
189            session.metadata.title,
190            created_at,
191            updated_at,
192            session.metadata.message_count as i64,
193            session.metadata.total_tokens as i64,
194            session.metadata.model,
195            workspace,
196            mode,
197            system_prompt,
198            messages_json,
199            context_refs_json,
200            runtime_thread_id,
201        ],
202    ).context("save_session_sqlite")?;
203
204    // Enforce MAX_SESSIONS via LRU deletion
205    cleanup_old_sqlite(db, 50)?;
206
207    Ok(())
208}
209
210pub fn load_session_sqlite(db: &Connection, id: &str) -> anyhow::Result<SavedSession> {
211    let id = id.trim();
212    if id.is_empty() {
213        bail!("Session id cannot be empty");
214    }
215    if !id
216        .chars()
217        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
218    {
219        bail!("Invalid session id '{id}'");
220    }
221
222    let mut stmt = db.prepare(
223        "SELECT id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, system_prompt, messages_json, context_refs_json, runtime_thread_id
224         FROM sessions WHERE id = ?1",
225    )?;
226
227    stmt.query_row(params![id], |row| {
228        let id: String = row.get(0)?;
229        let title: String = row.get(1)?;
230        let created_at: String = row.get(2)?;
231        let updated_at: String = row.get(3)?;
232        let message_count: i64 = row.get(4)?;
233        let total_tokens: i64 = row.get(5)?;
234        let model: String = row.get(6)?;
235        let workspace: String = row.get(7)?;
236        let mode: String = row.get(8)?;
237        let system_prompt: String = row.get(9)?;
238        let messages_json: String = row.get(10)?;
239        let context_refs_json: String = row.get(11)?;
240        let runtime_thread_id_raw: String = row.get(12)?;
241
242        let metadata = SessionMetadata {
243            id,
244            title,
245            created_at: DateTime::parse_from_rfc3339(&created_at)
246                .map(|d| d.with_timezone(&Utc))
247                .unwrap_or_default(),
248            updated_at: DateTime::parse_from_rfc3339(&updated_at)
249                .map(|d| d.with_timezone(&Utc))
250                .unwrap_or_default(),
251            message_count: message_count as usize,
252            total_tokens: total_tokens as u64,
253            model,
254            workspace: PathBuf::from(workspace),
255            mode: if mode.is_empty() { None } else { Some(mode) },
256            runtime_thread_id: runtime_thread_id_from_sql(&runtime_thread_id_raw),
257        };
258        let messages: Vec<crate::models::Message> =
259            serde_json::from_str(&messages_json).unwrap_or_default();
260        let context_references: Vec<SessionContextReference> =
261            serde_json::from_str(&context_refs_json).unwrap_or_default();
262
263        Ok(SavedSession {
264            schema_version: 1,
265            metadata,
266            messages,
267            system_prompt: if system_prompt.is_empty() {
268                None
269            } else {
270                Some(system_prompt)
271            },
272            context_references,
273        })
274    })
275    .map_err(|e| {
276        if matches!(e, rusqlite::Error::QueryReturnedNoRows) {
277            anyhow::anyhow!("session {id} not found")
278        } else {
279            anyhow::Error::from(e).context("load_session_sqlite query")
280        }
281    })
282}
283
284pub fn list_sessions_sqlite(db: &Connection) -> anyhow::Result<Vec<SessionMetadata>> {
285    let mut stmt = db.prepare(
286        "SELECT id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, runtime_thread_id
287         FROM sessions ORDER BY updated_at DESC",
288    )?;
289
290    let sessions = stmt
291        .query_map([], |row| {
292            let id: String = row.get(0)?;
293            let title: String = row.get(1)?;
294            let created_at: String = row.get(2)?;
295            let updated_at: String = row.get(3)?;
296            let message_count: i64 = row.get(4)?;
297            let total_tokens: i64 = row.get(5)?;
298            let model: String = row.get(6)?;
299            let workspace: String = row.get(7)?;
300            let mode: String = row.get(8)?;
301            let runtime_thread_id_raw: String = row.get(9)?;
302
303            Ok(SessionMetadata {
304                id,
305                title,
306                created_at: DateTime::parse_from_rfc3339(&created_at)
307                    .map(|d| d.with_timezone(&Utc))
308                    .unwrap_or_default(),
309                updated_at: DateTime::parse_from_rfc3339(&updated_at)
310                    .map(|d| d.with_timezone(&Utc))
311                    .unwrap_or_default(),
312                message_count: message_count as usize,
313                total_tokens: total_tokens as u64,
314                model,
315                workspace: PathBuf::from(workspace),
316                mode: if mode.is_empty() { None } else { Some(mode) },
317                runtime_thread_id: runtime_thread_id_from_sql(&runtime_thread_id_raw),
318            })
319        })?
320        .filter_map(|r| r.ok())
321        .collect();
322
323    Ok(sessions)
324}
325
326pub fn search_sessions_sqlite(
327    db: &Connection,
328    query: &str,
329) -> anyhow::Result<Vec<SessionMetadata>> {
330    let all = list_sessions_sqlite(db)?;
331    let query_lower = query.to_lowercase();
332    Ok(all
333        .into_iter()
334        .filter(|s| s.title.to_lowercase().contains(&query_lower))
335        .collect())
336}
337
338pub fn delete_session_sqlite(db: &Connection, id: &str) -> anyhow::Result<()> {
339    let id = id.trim();
340    if id.is_empty()
341        || !id
342            .chars()
343            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
344    {
345        bail!("Invalid session id '{id}'");
346    }
347    let affected = db
348        .execute("DELETE FROM sessions WHERE id = ?1", params![id])
349        .context("delete_session_sqlite")?;
350    if affected == 0 {
351        bail!("session {id} not found");
352    }
353    Ok(())
354}
355
356/// Resolve the most recently updated session id linked to a runtime thread.
357pub fn find_session_id_by_runtime_thread_id_sqlite(
358    db: &Connection,
359    runtime_thread_id: &str,
360) -> anyhow::Result<Option<String>> {
361    let mut stmt = db.prepare(
362        "SELECT id FROM sessions WHERE runtime_thread_id = ?1 ORDER BY updated_at DESC LIMIT 1",
363    )?;
364    let result = stmt.query_row(params![runtime_thread_id], |row| row.get(0));
365    match result {
366        Ok(id) => Ok(Some(id)),
367        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
368        Err(e) => Err(anyhow::anyhow!("query error: {e}")),
369    }
370}
371
372pub fn get_latest_session_for_workspace_sqlite(
373    db: &Connection,
374    workspace: &std::path::Path,
375) -> anyhow::Result<Option<SessionMetadata>> {
376    let workspace_str = workspace.display().to_string();
377    // Match by path prefix equality (same as JSON version's workspace_scope_matches)
378    let mut stmt = db.prepare(
379        "SELECT id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, runtime_thread_id
380         FROM sessions WHERE workspace = ?1
381         ORDER BY updated_at DESC LIMIT 1",
382    )?;
383
384    let result = stmt.query_row(params![workspace_str], |row| {
385        let id: String = row.get(0)?;
386        let title: String = row.get(1)?;
387        let created_at: String = row.get(2)?;
388        let updated_at: String = row.get(3)?;
389        let message_count: i64 = row.get(4)?;
390        let total_tokens: i64 = row.get(5)?;
391        let model: String = row.get(6)?;
392        let workspace: String = row.get(7)?;
393        let mode: String = row.get(8)?;
394        let runtime_thread_id_raw: String = row.get(9)?;
395
396        Ok(SessionMetadata {
397            id,
398            title,
399            created_at: DateTime::parse_from_rfc3339(&created_at)
400                .map(|d| d.with_timezone(&Utc))
401                .unwrap_or_default(),
402            updated_at: DateTime::parse_from_rfc3339(&updated_at)
403                .map(|d| d.with_timezone(&Utc))
404                .unwrap_or_default(),
405            message_count: message_count as usize,
406            total_tokens: total_tokens as u64,
407            model,
408            workspace: PathBuf::from(workspace),
409            mode: if mode.is_empty() { None } else { Some(mode) },
410            runtime_thread_id: runtime_thread_id_from_sql(&runtime_thread_id_raw),
411        })
412    });
413
414    match result {
415        Ok(meta) => Ok(Some(meta)),
416        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
417        Err(e) => Err(anyhow::anyhow!("query error: {e}")),
418    }
419}
420
421fn cleanup_old_sqlite(db: &Connection, max_sessions: usize) -> anyhow::Result<()> {
422    // Delete oldest sessions beyond the limit
423    db.execute(
424        "DELETE FROM sessions WHERE id NOT IN (
425            SELECT id FROM sessions ORDER BY updated_at DESC LIMIT ?1
426        )",
427        params![max_sessions as i64],
428    )?;
429    Ok(())
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use tempfile::{TempDir, tempdir};

    /// Opens a fresh session DB in a temp dir. The caller must keep the returned
    /// `TempDir` alive for the whole test so the files are not removed early.
    fn open_test_db() -> (TempDir, Connection) {
        let dir = tempdir().unwrap();
        let sessions_dir = dir.path().join("sessions");
        std::fs::create_dir_all(&sessions_dir).unwrap();
        let db_path = sessions_dir.join("sessions.db");
        let db = open_sqlite_session_db(&db_path, &sessions_dir).unwrap();
        (dir, db)
    }

    /// Builds a minimal session whose title equals its id.
    fn sample_session(id: &str, runtime_thread_id: Option<&str>) -> SavedSession {
        let now = Utc::now();
        SavedSession {
            schema_version: 1,
            metadata: SessionMetadata {
                id: id.to_string(),
                title: id.to_string(),
                created_at: now,
                updated_at: now,
                message_count: 0,
                total_tokens: 0,
                model: "m".to_string(),
                workspace: PathBuf::from("."),
                mode: None,
                runtime_thread_id: runtime_thread_id.map(str::to_string),
            },
            messages: vec![],
            system_prompt: None,
            context_references: vec![],
        }
    }

    #[test]
    fn sqlite_session_runtime_thread_id_round_trip() {
        let (_dir, db) = open_test_db();
        let session = sample_session("sess-1", Some("thr_abc"));

        save_session_sqlite(&db, &session).unwrap();
        let loaded = load_session_sqlite(&db, "sess-1").unwrap();
        assert_eq!(
            loaded.metadata.runtime_thread_id.as_deref(),
            Some("thr_abc")
        );

        let listed = list_sessions_sqlite(&db).unwrap();
        assert_eq!(listed[0].runtime_thread_id.as_deref(), Some("thr_abc"));
    }

    #[test]
    fn sqlite_find_session_id_by_runtime_thread_id() {
        let (_dir, db) = open_test_db();
        for (id, thread_id) in [("sess-a", "thr_1"), ("sess-b", "thr_2")] {
            save_session_sqlite(&db, &sample_session(id, Some(thread_id))).unwrap();
        }

        assert_eq!(
            find_session_id_by_runtime_thread_id_sqlite(&db, "thr_2").unwrap(),
            Some("sess-b".to_string())
        );
        assert_eq!(
            find_session_id_by_runtime_thread_id_sqlite(&db, "missing").unwrap(),
            None
        );
    }

    #[test]
    fn sqlite_search_is_case_insensitive_and_delete_removes_session() {
        let (_dir, db) = open_test_db();
        let mut alpha = sample_session("sess-alpha", None);
        alpha.metadata.title = "Alpha Plan".to_string();
        let mut beta = sample_session("sess-beta", None);
        beta.metadata.title = "beta notes".to_string();
        save_session_sqlite(&db, &alpha).unwrap();
        save_session_sqlite(&db, &beta).unwrap();

        let hits = search_sessions_sqlite(&db, "PLAN").unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "sess-alpha");

        delete_session_sqlite(&db, "sess-alpha").unwrap();
        assert!(load_session_sqlite(&db, "sess-alpha").is_err());
        assert!(delete_session_sqlite(&db, "sess-alpha").is_err());
        assert_eq!(list_sessions_sqlite(&db).unwrap().len(), 1);
    }

    #[test]
    fn sqlite_load_rejects_invalid_ids() {
        let (_dir, db) = open_test_db();
        assert!(load_session_sqlite(&db, "   ").is_err());
        assert!(load_session_sqlite(&db, "../etc/passwd").is_err());
        assert!(delete_session_sqlite(&db, "bad id").is_err());
    }

    #[test]
    fn sqlite_alter_adds_runtime_thread_id_column() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("legacy.db");
        {
            let db = Connection::open(&db_path).unwrap();
            db.execute_batch(
                "CREATE TABLE sessions (
                    id TEXT PRIMARY KEY,
                    title TEXT NOT NULL DEFAULT '',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    message_count INTEGER NOT NULL DEFAULT 0,
                    total_tokens INTEGER NOT NULL DEFAULT 0,
                    model TEXT NOT NULL DEFAULT '',
                    workspace TEXT NOT NULL DEFAULT '.',
                    mode TEXT,
                    system_prompt TEXT,
                    messages_json TEXT NOT NULL DEFAULT '[]',
                    context_refs_json TEXT NOT NULL DEFAULT '[]'
                );",
            )
            .unwrap();
        }
        let db = Connection::open(&db_path).unwrap();
        ensure_sessions_runtime_thread_id_column(&db).unwrap();
        // A second call must be a no-op, not a "duplicate column" error.
        ensure_sessions_runtime_thread_id_column(&db).unwrap();

        let mut stmt = db.prepare("PRAGMA table_info(sessions)").unwrap();
        let columns: Vec<String> = stmt
            .query_map([], |row| row.get::<_, String>(1))
            .unwrap()
            .collect::<Result<_, _>>()
            .unwrap();
        assert_eq!(
            columns
                .iter()
                .filter(|c| c.as_str() == "runtime_thread_id")
                .count(),
            1
        );
    }
}