Skip to main content

zagens_runtime_adapters/persist/
session_store_sqlite.rs

1#![allow(dead_code)]
2/// SQLite-backed session store. Provides the same semantics as the
3/// JSON-per-file SessionManager but with far better I/O performance:
4/// a single WAL sync per transaction instead of one fsync per file.
5///
6/// ## Async / blocking policy (A1.3)
7///
8/// - **TUI interactive path:** [`crate::tui::persistence_actor`] coalesces
9///   checkpoint/session snapshots off the event-loop worker (latest-wins).
10/// - **HTTP runtime path:** session persist runs inside `spawn_blocking`
11///   (see `runtime_api::threads`).
12/// - **Direct callers:** treat `save_session_sqlite` as blocking I/O; do not
13///   call from async contexts without `spawn_blocking`.
14use std::path::PathBuf;
15
16use anyhow::{Context, bail};
17use chrono::{DateTime, Utc};
18use rusqlite::{Connection, params};
19
20use crate::persist::session_manager::{SavedSession, SessionContextReference, SessionMetadata};
21
22const CURRENT_META_VERSION: u32 = 1;
23
24fn ensure_sessions_runtime_thread_id_column(db: &Connection) -> anyhow::Result<()> {
25    let mut stmt = db.prepare("PRAGMA table_info(sessions)")?;
26    let has_col = stmt
27        .query_map([], |row| row.get::<_, String>(1))?
28        .filter_map(|r| r.ok())
29        .any(|name| name == "runtime_thread_id");
30    if !has_col {
31        db.execute("ALTER TABLE sessions ADD COLUMN runtime_thread_id TEXT", [])?;
32    }
33    Ok(())
34}
35
/// Converts an optional runtime thread id into the text written to SQLite.
///
/// `None` is represented by the empty string; `Some(id)` is borrowed as-is.
fn runtime_thread_id_sql(id: &Option<String>) -> &str {
    match id {
        Some(thread_id) => thread_id.as_str(),
        None => "",
    }
}
39
/// Converts the text stored in SQLite back into an optional runtime thread id.
///
/// The empty string means "no thread id" and yields `None`; any other value
/// is copied into an owned `String`.
fn runtime_thread_id_from_sql(raw: &str) -> Option<String> {
    match raw {
        "" => None,
        thread_id => Some(thread_id.to_owned()),
    }
}
47
48/// Opens (or creates) the SQLite DB at `db_path`.
49/// If JSON files exist in `sessions_dir` and the DB is empty, auto-migrates.
50pub fn open_sqlite_session_db(
51    db_path: &std::path::Path,
52    sessions_dir: &std::path::Path,
53) -> anyhow::Result<Connection> {
54    let db = Connection::open(db_path).context("Failed to open SQLite session DB")?;
55
56    db.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;")
57        .context("Failed to set SQLite pragmas")?;
58
59    db.execute_batch(
60        "CREATE TABLE IF NOT EXISTS _meta (
61            key TEXT PRIMARY KEY,
62            value TEXT NOT NULL
63        );
64        CREATE TABLE IF NOT EXISTS sessions (
65            id TEXT PRIMARY KEY,
66            title TEXT NOT NULL DEFAULT '',
67            created_at TEXT NOT NULL,
68            updated_at TEXT NOT NULL,
69            message_count INTEGER NOT NULL DEFAULT 0,
70            total_tokens INTEGER NOT NULL DEFAULT 0,
71            model TEXT NOT NULL DEFAULT '',
72            workspace TEXT NOT NULL DEFAULT '.',
73            mode TEXT,
74            system_prompt TEXT,
75            messages_json TEXT NOT NULL DEFAULT '[]',
76            context_refs_json TEXT NOT NULL DEFAULT '[]',
77            runtime_thread_id TEXT
78        );
79        CREATE INDEX IF NOT EXISTS idx_sessions_updated ON sessions(updated_at);
80        CREATE INDEX IF NOT EXISTS idx_sessions_workspace ON sessions(workspace);",
81    )
82    .context("Failed to create session tables")?;
83
84    ensure_sessions_runtime_thread_id_column(&db)?;
85
86    // Check if migration is needed
87    let needs_migration: bool = db
88        .query_row("SELECT value FROM _meta WHERE key = 'version'", [], |row| {
89            row.get::<_, String>(0)
90        })
91        .ok()
92        .is_none();
93
94    if needs_migration {
95        migrate_json_sessions(&db, sessions_dir)?;
96        db.execute(
97            "INSERT OR REPLACE INTO _meta (key, value) VALUES ('version', ?1)",
98            params![CURRENT_META_VERSION.to_string()],
99        )?;
100    }
101
102    Ok(db)
103}
104
105fn migrate_json_sessions(db: &Connection, sessions_dir: &std::path::Path) -> anyhow::Result<()> {
106    let dir = std::fs::read_dir(sessions_dir);
107    let dir = match dir {
108        Ok(d) => d,
109        Err(_) => return Ok(()),
110    };
111
112    let tx = db.unchecked_transaction()?;
113
114    for entry in dir.filter_map(|e| e.ok()) {
115        let path = entry.path();
116        if path.extension().and_then(|e| e.to_str()) != Some("json") {
117            continue;
118        }
119        let content = match std::fs::read_to_string(&path) {
120            Ok(c) => c,
121            Err(_) => continue,
122        };
123        let session: SavedSession = match serde_json::from_str(&content) {
124            Ok(s) => s,
125            Err(_) => continue,
126        };
127
128        let messages_json = serde_json::to_string(&session.messages).unwrap_or_default();
129        let context_refs_json =
130            serde_json::to_string(&session.context_references).unwrap_or_default();
131        let created_at = session.metadata.created_at.to_rfc3339();
132        let updated_at = session.metadata.updated_at.to_rfc3339();
133        let mode = session.metadata.mode.as_deref().unwrap_or("");
134        let workspace = session.metadata.workspace.display().to_string();
135        let system_prompt = session.system_prompt.as_deref().unwrap_or("");
136        let runtime_thread_id = runtime_thread_id_sql(&session.metadata.runtime_thread_id);
137
138        tx.execute(
139            "INSERT OR REPLACE INTO sessions
140             (id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, system_prompt, messages_json, context_refs_json, runtime_thread_id)
141             VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13)",
142            params![
143                session.metadata.id,
144                session.metadata.title,
145                created_at,
146                updated_at,
147                session.metadata.message_count as i64,
148                session.metadata.total_tokens as i64,
149                session.metadata.model,
150                workspace,
151                mode,
152                system_prompt,
153                messages_json,
154                context_refs_json,
155                runtime_thread_id,
156            ],
157        )?;
158    }
159
160    tx.commit()?;
161    eprintln!(
162        "[session-store] migrated {} sessions to SQLite",
163        db.query_row("SELECT COUNT(*) FROM sessions", [], |r| r.get::<_, i64>(0))
164            .unwrap_or(0)
165    );
166    Ok(())
167}
168
169pub fn save_session_sqlite(db: &Connection, session: &SavedSession) -> anyhow::Result<()> {
170    let messages_json = serde_json::to_string(&session.messages).unwrap_or_default();
171    let context_refs_json = serde_json::to_string(&session.context_references).unwrap_or_default();
172    let created_at = session.metadata.created_at.to_rfc3339();
173    let updated_at = session.metadata.updated_at.to_rfc3339();
174    let mode = session.metadata.mode.as_deref().unwrap_or("");
175    let workspace = session.metadata.workspace.display().to_string();
176    let system_prompt = session.system_prompt.as_deref().unwrap_or("");
177    let runtime_thread_id = runtime_thread_id_sql(&session.metadata.runtime_thread_id);
178
179    db.execute(
180        "INSERT OR REPLACE INTO sessions
181         (id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, system_prompt, messages_json, context_refs_json, runtime_thread_id)
182         VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13)",
183        params![
184            session.metadata.id,
185            session.metadata.title,
186            created_at,
187            updated_at,
188            session.metadata.message_count as i64,
189            session.metadata.total_tokens as i64,
190            session.metadata.model,
191            workspace,
192            mode,
193            system_prompt,
194            messages_json,
195            context_refs_json,
196            runtime_thread_id,
197        ],
198    ).context("save_session_sqlite")?;
199
200    // Enforce MAX_SESSIONS via LRU deletion
201    cleanup_old_sqlite(db, 50)?;
202
203    Ok(())
204}
205
206pub fn load_session_sqlite(db: &Connection, id: &str) -> anyhow::Result<SavedSession> {
207    let id = id.trim();
208    if id.is_empty() {
209        bail!("Session id cannot be empty");
210    }
211    if !id
212        .chars()
213        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
214    {
215        bail!("Invalid session id '{id}'");
216    }
217
218    let mut stmt = db.prepare(
219        "SELECT id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, system_prompt, messages_json, context_refs_json, runtime_thread_id
220         FROM sessions WHERE id = ?1",
221    )?;
222
223    stmt.query_row(params![id], |row| {
224        let id: String = row.get(0)?;
225        let title: String = row.get(1)?;
226        let created_at: String = row.get(2)?;
227        let updated_at: String = row.get(3)?;
228        let message_count: i64 = row.get(4)?;
229        let total_tokens: i64 = row.get(5)?;
230        let model: String = row.get(6)?;
231        let workspace: String = row.get(7)?;
232        let mode: String = row.get(8)?;
233        let system_prompt: String = row.get(9)?;
234        let messages_json: String = row.get(10)?;
235        let context_refs_json: String = row.get(11)?;
236        let runtime_thread_id_raw: String = row.get(12)?;
237
238        let metadata = SessionMetadata {
239            id,
240            title,
241            created_at: DateTime::parse_from_rfc3339(&created_at)
242                .map(|d| d.with_timezone(&Utc))
243                .unwrap_or_default(),
244            updated_at: DateTime::parse_from_rfc3339(&updated_at)
245                .map(|d| d.with_timezone(&Utc))
246                .unwrap_or_default(),
247            message_count: message_count as usize,
248            total_tokens: total_tokens as u64,
249            model,
250            workspace: PathBuf::from(workspace),
251            mode: if mode.is_empty() { None } else { Some(mode) },
252            runtime_thread_id: runtime_thread_id_from_sql(&runtime_thread_id_raw),
253        };
254        let messages: Vec<crate::models::Message> =
255            serde_json::from_str(&messages_json).unwrap_or_default();
256        let context_references: Vec<SessionContextReference> =
257            serde_json::from_str(&context_refs_json).unwrap_or_default();
258
259        Ok(SavedSession {
260            schema_version: 1,
261            metadata,
262            messages,
263            system_prompt: if system_prompt.is_empty() {
264                None
265            } else {
266                Some(system_prompt)
267            },
268            context_references,
269        })
270    })
271    .map_err(|e| {
272        if matches!(e, rusqlite::Error::QueryReturnedNoRows) {
273            anyhow::anyhow!("session {id} not found")
274        } else {
275            anyhow::Error::from(e).context("load_session_sqlite query")
276        }
277    })
278}
279
280pub fn list_sessions_sqlite(db: &Connection) -> anyhow::Result<Vec<SessionMetadata>> {
281    let mut stmt = db.prepare(
282        "SELECT id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, runtime_thread_id
283         FROM sessions ORDER BY updated_at DESC",
284    )?;
285
286    let sessions = stmt
287        .query_map([], |row| {
288            let id: String = row.get(0)?;
289            let title: String = row.get(1)?;
290            let created_at: String = row.get(2)?;
291            let updated_at: String = row.get(3)?;
292            let message_count: i64 = row.get(4)?;
293            let total_tokens: i64 = row.get(5)?;
294            let model: String = row.get(6)?;
295            let workspace: String = row.get(7)?;
296            let mode: String = row.get(8)?;
297            let runtime_thread_id_raw: String = row.get(9)?;
298
299            Ok(SessionMetadata {
300                id,
301                title,
302                created_at: DateTime::parse_from_rfc3339(&created_at)
303                    .map(|d| d.with_timezone(&Utc))
304                    .unwrap_or_default(),
305                updated_at: DateTime::parse_from_rfc3339(&updated_at)
306                    .map(|d| d.with_timezone(&Utc))
307                    .unwrap_or_default(),
308                message_count: message_count as usize,
309                total_tokens: total_tokens as u64,
310                model,
311                workspace: PathBuf::from(workspace),
312                mode: if mode.is_empty() { None } else { Some(mode) },
313                runtime_thread_id: runtime_thread_id_from_sql(&runtime_thread_id_raw),
314            })
315        })?
316        .filter_map(|r| r.ok())
317        .collect();
318
319    Ok(sessions)
320}
321
322pub fn search_sessions_sqlite(
323    db: &Connection,
324    query: &str,
325) -> anyhow::Result<Vec<SessionMetadata>> {
326    let all = list_sessions_sqlite(db)?;
327    let query_lower = query.to_lowercase();
328    Ok(all
329        .into_iter()
330        .filter(|s| s.title.to_lowercase().contains(&query_lower))
331        .collect())
332}
333
334pub fn delete_session_sqlite(db: &Connection, id: &str) -> anyhow::Result<()> {
335    let id = id.trim();
336    if id.is_empty()
337        || !id
338            .chars()
339            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
340    {
341        bail!("Invalid session id '{id}'");
342    }
343    let affected = db
344        .execute("DELETE FROM sessions WHERE id = ?1", params![id])
345        .context("delete_session_sqlite")?;
346    if affected == 0 {
347        bail!("session {id} not found");
348    }
349    Ok(())
350}
351
352pub fn get_latest_session_for_workspace_sqlite(
353    db: &Connection,
354    workspace: &std::path::Path,
355) -> anyhow::Result<Option<SessionMetadata>> {
356    let workspace_str = workspace.display().to_string();
357    // Match by path prefix equality (same as JSON version's workspace_scope_matches)
358    let mut stmt = db.prepare(
359        "SELECT id, title, created_at, updated_at, message_count, total_tokens, model, workspace, mode, runtime_thread_id
360         FROM sessions WHERE workspace = ?1
361         ORDER BY updated_at DESC LIMIT 1",
362    )?;
363
364    let result = stmt.query_row(params![workspace_str], |row| {
365        let id: String = row.get(0)?;
366        let title: String = row.get(1)?;
367        let created_at: String = row.get(2)?;
368        let updated_at: String = row.get(3)?;
369        let message_count: i64 = row.get(4)?;
370        let total_tokens: i64 = row.get(5)?;
371        let model: String = row.get(6)?;
372        let workspace: String = row.get(7)?;
373        let mode: String = row.get(8)?;
374        let runtime_thread_id_raw: String = row.get(9)?;
375
376        Ok(SessionMetadata {
377            id,
378            title,
379            created_at: DateTime::parse_from_rfc3339(&created_at)
380                .map(|d| d.with_timezone(&Utc))
381                .unwrap_or_default(),
382            updated_at: DateTime::parse_from_rfc3339(&updated_at)
383                .map(|d| d.with_timezone(&Utc))
384                .unwrap_or_default(),
385            message_count: message_count as usize,
386            total_tokens: total_tokens as u64,
387            model,
388            workspace: PathBuf::from(workspace),
389            mode: if mode.is_empty() { None } else { Some(mode) },
390            runtime_thread_id: runtime_thread_id_from_sql(&runtime_thread_id_raw),
391        })
392    });
393
394    match result {
395        Ok(meta) => Ok(Some(meta)),
396        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
397        Err(e) => Err(anyhow::anyhow!("query error: {e}")),
398    }
399}
400
401fn cleanup_old_sqlite(db: &Connection, max_sessions: usize) -> anyhow::Result<()> {
402    // Delete oldest sessions beyond the limit
403    db.execute(
404        "DELETE FROM sessions WHERE id NOT IN (
405            SELECT id FROM sessions ORDER BY updated_at DESC LIMIT ?1
406        )",
407        params![max_sessions as i64],
408    )?;
409    Ok(())
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use tempfile::tempdir;

    /// A session saved with a runtime thread id must return the same id from
    /// both the full load and the metadata listing.
    #[test]
    fn sqlite_session_runtime_thread_id_round_trip() {
        let dir = tempdir().unwrap();
        let sessions_dir = dir.path().join("sessions");
        std::fs::create_dir_all(&sessions_dir).unwrap();
        let db_path = sessions_dir.join("sessions.db");
        let db = open_sqlite_session_db(&db_path, &sessions_dir).unwrap();

        let now = Utc::now();
        let session = SavedSession {
            schema_version: 1,
            metadata: SessionMetadata {
                id: "sess-1".to_string(),
                title: "test".to_string(),
                created_at: now,
                updated_at: now,
                message_count: 0,
                total_tokens: 0,
                model: "m".to_string(),
                workspace: PathBuf::from("."),
                mode: None,
                runtime_thread_id: Some("thr_abc".to_string()),
            },
            messages: vec![],
            system_prompt: None,
            context_references: vec![],
        };

        save_session_sqlite(&db, &session).unwrap();
        let loaded = load_session_sqlite(&db, "sess-1").unwrap();
        assert_eq!(
            loaded.metadata.runtime_thread_id.as_deref(),
            Some("thr_abc")
        );

        let listed = list_sessions_sqlite(&db).unwrap();
        assert_eq!(listed[0].runtime_thread_id.as_deref(), Some("thr_abc"));
    }

    /// Upgrading a legacy table adds `runtime_thread_id` exactly once, and a
    /// second call is a no-op.
    #[test]
    fn sqlite_alter_adds_runtime_thread_id_column() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("legacy.db");
        {
            // Legacy schema: identical to the current one minus the
            // `runtime_thread_id` column.
            let db = Connection::open(&db_path).unwrap();
            db.execute_batch(
                "CREATE TABLE sessions (
                    id TEXT PRIMARY KEY,
                    title TEXT NOT NULL DEFAULT '',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    message_count INTEGER NOT NULL DEFAULT 0,
                    total_tokens INTEGER NOT NULL DEFAULT 0,
                    model TEXT NOT NULL DEFAULT '',
                    workspace TEXT NOT NULL DEFAULT '.',
                    mode TEXT,
                    system_prompt TEXT,
                    messages_json TEXT NOT NULL DEFAULT '[]',
                    context_refs_json TEXT NOT NULL DEFAULT '[]'
                );",
            )
            .unwrap();
        }
        let db = Connection::open(&db_path).unwrap();
        ensure_sessions_runtime_thread_id_column(&db).unwrap();
        ensure_sessions_runtime_thread_id_column(&db).unwrap();

        // Verify the outcome, not just the absence of an error: the column
        // must now be present exactly once.
        let mut stmt = db.prepare("PRAGMA table_info(sessions)").unwrap();
        let columns: Vec<String> = stmt
            .query_map([], |row| row.get::<_, String>(1))
            .unwrap()
            .map(|name| name.unwrap())
            .collect();
        let occurrences = columns
            .iter()
            .filter(|name| name.as_str() == "runtime_thread_id")
            .count();
        assert_eq!(occurrences, 1);
    }
}
485}