Skip to main content

starpod_db/
lib.rs

1//! Unified SQLite database for Starpod's transactional data.
2//!
3//! Consolidates sessions, cron scheduling, and authentication into a single
4//! `core.db` file, replacing the previous `session.db` + `cron.db` + `users.db`
5//! setup. A single shared connection pool (WAL mode, foreign keys enabled)
6//! serves all three domains.
7//!
8//! # Architecture
9//!
//! ```text
//! ┌──────────┐
//! │  CoreDb  │  owns SqlitePool (max 10 conns, WAL, FK ON)
//! └────┬─────┘
//!      │ pool.clone()
//!      ├──────────────► SessionManager::from_pool(pool)
//!      ├──────────────► CronStore::from_pool(pool)
//!      └──────────────► AuthStore::from_pool(pool)
//! ```
19//!
20//! # Databases kept separate
21//!
22//! - **memory.db** — FTS5 + vector blobs, bulk reindex I/O, different access pattern
23//! - **vault.db** — AES-256-GCM encrypted, optional (needs `.vault_key`), isolated security boundary
24//!
25//! # Legacy migration
26//!
27//! On first open, [`CoreDb::new`] detects old `session.db` / `cron.db` / `users.db`
28//! in the same directory and migrates their data into `core.db`, then renames the
29//! old files to `*.db.migrated`.
30//!
31//! # Usage
32//!
33//! ```no_run
34//! # async fn example() -> starpod_core::Result<()> {
35//! use starpod_db::CoreDb;
36//!
37//! let db = CoreDb::new(std::path::Path::new(".starpod/db")).await?;
38//! // Pass db.pool().clone() to SessionManager, CronStore, AuthStore
39//! # Ok(())
40//! # }
41//! ```
42
43mod migrate;
44
use std::path::Path;
use std::str::FromStr;

use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions};
use sqlx::SqlitePool;
use tracing::{debug, info};

use starpod_core::{Result, StarpodError};
53
/// Unified database for sessions, cron, and auth.
///
/// Owns a single `SqlitePool` backed by `core.db`. Individual stores
/// (`SessionManager`, `CronStore`, `AuthStore`) receive a clone of the
/// pool via `from_pool()` instead of opening their own connections.
///
/// The pool is configured with:
/// - **WAL journal mode** — concurrent readers don't block writers
/// - **Foreign keys ON** — referential integrity across all tables
/// - **10 max connections** — shared across all stores
pub struct CoreDb {
    // Shared connection pool; clones handed out via `pool()` observe the
    // same underlying database state (see the `pool_clone_shares_state` test).
    pool: SqlitePool,
}
67
68impl CoreDb {
69    /// Open (or create) `core.db` inside `db_dir`.
70    ///
71    /// Runs all migrations, then checks for legacy database files
72    /// (`session.db`, `cron.db`, `users.db`) and migrates their data
73    /// into the unified database if found.
74    pub async fn new(db_dir: &Path) -> Result<Self> {
75        std::fs::create_dir_all(db_dir)?;
76
77        let db_path = db_dir.join("core.db");
78        let opts = SqliteConnectOptions::from_str(
79            &format!("sqlite://{}?mode=rwc", db_path.display()),
80        )
81        .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?
82        .pragma("journal_mode", "WAL")
83        .pragma("foreign_keys", "ON");
84
85        let pool = SqlitePoolOptions::new()
86            .max_connections(10)
87            .connect_with(opts)
88            .await
89            .map_err(|e| StarpodError::Database(format!("Failed to open core db: {}", e)))?;
90
91        sqlx::migrate!("./migrations")
92            .run(&pool)
93            .await
94            .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
95
96        debug!("core.db ready at {}", db_path.display());
97
98        // Migrate legacy databases if present
99        if migrate::has_legacy_dbs(db_dir) {
100            info!("Legacy database files detected — migrating to core.db");
101            migrate::migrate_legacy_dbs(&pool, db_dir).await?;
102        }
103
104        Ok(Self { pool })
105    }
106
107    /// Create an in-memory `CoreDb` for testing.
108    ///
109    /// Runs all migrations on a shared in-memory database. Each call
110    /// returns a fresh, empty database.
111    pub async fn in_memory() -> Result<Self> {
112        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
113            .map_err(|e| StarpodError::Database(format!("Invalid memory DB: {}", e)))?
114            .pragma("foreign_keys", "ON");
115
116        let pool = SqlitePoolOptions::new()
117            .max_connections(1)
118            .connect_with(opts)
119            .await
120            .map_err(|e| StarpodError::Database(format!("Failed to open in-memory db: {}", e)))?;
121
122        sqlx::migrate!("./migrations")
123            .run(&pool)
124            .await
125            .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
126
127        Ok(Self { pool })
128    }
129
130    /// Get a reference to the shared connection pool.
131    pub fn pool(&self) -> &SqlitePool {
132        &self.pool
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: row count of `table`. Panics (failing the test) if the
    /// table does not exist — which doubles as a schema-existence check.
    async fn count(pool: &SqlitePool, table: &str) -> i64 {
        let row: (i64,) = sqlx::query_as(&format!("SELECT COUNT(*) FROM {}", table))
            .fetch_one(pool)
            .await
            .unwrap();
        row.0
    }

    // ── Basic lifecycle ─────────────────────────────────────────────

    #[tokio::test]
    async fn in_memory_creates_all_tables() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        // Every table from all three domains must exist and start empty.
        let tables = [
            // Auth
            "users",
            "api_keys",
            "telegram_links",
            "auth_audit_log",
            // Sessions
            "session_metadata",
            "session_messages",
            "usage_stats",
            "compaction_log",
            // Cron
            "cron_jobs",
            "cron_runs",
        ];
        for table in tables {
            assert_eq!(count(pool, table).await, 0, "expected empty table: {}", table);
        }
    }

    #[tokio::test]
    async fn on_disk_creates_core_db() {
        let tmp = tempfile::tempdir().unwrap();
        let db = CoreDb::new(tmp.path()).await.unwrap();

        assert!(tmp.path().join("core.db").exists());
        assert_eq!(count(db.pool(), "users").await, 0);
    }

    #[tokio::test]
    async fn on_disk_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let nested = tmp.path().join("deep").join("nested").join("db");
        let db = CoreDb::new(&nested).await.unwrap();

        assert!(nested.join("core.db").exists());
        drop(db);
    }

    #[tokio::test]
    async fn reopen_is_idempotent() {
        let tmp = tempfile::tempdir().unwrap();

        // First open — creates the DB
        let db1 = CoreDb::new(tmp.path()).await.unwrap();
        sqlx::query(
            "INSERT INTO users (id, email, display_name, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'A', 'admin', 1, '2024-01-01', '2024-01-01')"
        ).execute(db1.pool()).await.unwrap();
        drop(db1);

        // Second open — should find existing data, not recreate
        let db2 = CoreDb::new(tmp.path()).await.unwrap();
        let row: (String,) = sqlx::query_as("SELECT email FROM users WHERE id = 'u1'")
            .fetch_one(db2.pool()).await.unwrap();
        assert_eq!(row.0, "a@b.com");
    }

    // ── Foreign key enforcement ─────────────────────────────────────

    #[tokio::test]
    async fn fk_rejects_invalid_api_key_user() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'nonexistent', 'sp_', 'hash', '2024-01-01')"
        ).execute(db.pool()).await;

        assert!(result.is_err(), "FK should reject api_key with invalid user_id");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_telegram_link_user() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (123, 'nonexistent', 'bob', '2024-01-01')"
        ).execute(db.pool()).await;

        assert!(result.is_err(), "FK should reject telegram_link with invalid user_id");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_session_message() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('nonexistent', 'user', 'hello', '2024-01-01')"
        ).execute(db.pool()).await;

        assert!(result.is_err(), "FK should reject message with invalid session_id");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_cron_run_job() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'nonexistent', 1000, 'pending')"
        ).execute(db.pool()).await;

        assert!(result.is_err(), "FK should reject cron_run with invalid job_id");
    }

    // ── CASCADE deletes ─────────────────────────────────────────────

    #[tokio::test]
    async fn cascade_delete_user_removes_api_keys() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'admin', 1, '2024-01-01', '2024-01-01')"
        ).execute(pool).await.unwrap();

        sqlx::query(
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'u1', 'sp_', 'hash1', '2024-01-01')"
        ).execute(pool).await.unwrap();

        sqlx::query(
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k2', 'u1', 'sp_', 'hash2', '2024-01-01')"
        ).execute(pool).await.unwrap();

        // Delete user — both API keys should cascade away.
        sqlx::query("DELETE FROM users WHERE id = 'u1'")
            .execute(pool).await.unwrap();

        assert_eq!(count(pool, "api_keys").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_user_removes_telegram_links() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')"
        ).execute(pool).await.unwrap();

        sqlx::query(
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (999, 'u1', 'bob', '2024-01-01')"
        ).execute(pool).await.unwrap();

        sqlx::query("DELETE FROM users WHERE id = 'u1'")
            .execute(pool).await.unwrap();

        assert_eq!(count(pool, "telegram_links").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_session_removes_messages_and_compaction() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO session_metadata (id, created_at, last_message_at) \
             VALUES ('s1', '2024-01-01', '2024-01-01')"
        ).execute(pool).await.unwrap();

        sqlx::query(
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('s1', 'user', 'hi', '2024-01-01')"
        ).execute(pool).await.unwrap();

        sqlx::query(
            "INSERT INTO compaction_log (session_id, timestamp, trigger, pre_tokens, summary) \
             VALUES ('s1', '2024-01-01', 'auto', 1000, 'summary')"
        ).execute(pool).await.unwrap();

        sqlx::query("DELETE FROM session_metadata WHERE id = 's1'")
            .execute(pool).await.unwrap();

        // Both dependent tables should cascade away with the session.
        assert_eq!(count(pool, "session_messages").await, 0);
        assert_eq!(count(pool, "compaction_log").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_cron_job_removes_runs() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO cron_jobs (id, name, prompt, schedule_type, schedule_value, created_at) \
             VALUES ('j1', 'test', 'do stuff', 'interval', '60000', 1000)"
        ).execute(pool).await.unwrap();

        sqlx::query(
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'j1', 2000, 'success')"
        ).execute(pool).await.unwrap();

        sqlx::query("DELETE FROM cron_jobs WHERE id = 'j1'")
            .execute(pool).await.unwrap();

        assert_eq!(count(pool, "cron_runs").await, 0);
    }

    // ── Cross-domain queries (the whole point of consolidation) ─────

    #[tokio::test]
    async fn cross_domain_join_sessions_with_usage_by_user() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        // Create a user
        sqlx::query(
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'alice@test.com', 'admin', 1, '2024-01-01', '2024-01-01')"
        ).execute(pool).await.unwrap();

        // Create sessions for this user
        sqlx::query(
            "INSERT INTO session_metadata (id, created_at, last_message_at, user_id) \
             VALUES ('s1', '2024-01-01', '2024-01-01', 'u1')"
        ).execute(pool).await.unwrap();

        // Record usage
        sqlx::query(
            "INSERT INTO usage_stats (session_id, turn, input_tokens, output_tokens, cost_usd, timestamp, user_id) \
             VALUES ('s1', 1, 100, 200, 0.01, '2024-01-01', 'u1')"
        ).execute(pool).await.unwrap();

        // Cross-domain query: total cost per user (joins users + usage_stats)
        let row: (String, f64) = sqlx::query_as(
            "SELECT u.email, SUM(us.cost_usd) as total_cost \
             FROM users u \
             JOIN usage_stats us ON us.user_id = u.id \
             GROUP BY u.id"
        ).fetch_one(pool).await.unwrap();

        assert_eq!(row.0, "alice@test.com");
        assert!((row.1 - 0.01).abs() < 0.001);
    }

    #[tokio::test]
    async fn pool_clone_shares_state() {
        let db = CoreDb::in_memory().await.unwrap();

        // Insert on original pool
        sqlx::query(
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')"
        ).execute(db.pool()).await.unwrap();

        // Read from cloned pool (simulates what stores do)
        let pool2 = db.pool().clone();
        assert_eq!(count(&pool2, "users").await, 1);
    }
}
445}